https://todo.sr.ht/~emersion/soju/207 diff -u b/user.go b/user.go --- b/user.go +++ b/user.go @@ -218,6 +218,7 @@ net.user.srv.metrics.upstreams.Add(1) defer net.user.srv.metrics.upstreams.Add(-1) + done := ctx.Done() ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() @@ -227,6 +228,12 @@ } defer uc.Close() + // The context is cancelled by the caller when the network is stopped. + go func() { + <-done + uc.Close() + }() + if net.user.srv.Identd != nil { net.user.srv.Identd.Store(uc.RemoteAddr().String(), uc.LocalAddr().String(), userIdent(&net.user.User)) defer net.user.srv.Identd.Delete(uc.RemoteAddr().String(), uc.LocalAddr().String()) @@ -239,9 +246,6 @@ return fmt.Errorf("failed to register: %w", err) } - // TODO: this is racy with net.stopped. If the network is stopped - // before the user goroutine receives eventUpstreamConnected, the - // connection won't be closed. net.user.events <- eventUpstreamConnected{uc} defer func() { net.user.events <- eventUpstreamDisconnected{uc} @@ -259,6 +263,12 @@ return } + ctx, cancel := context.WithCancel(context.TODO()) + go func() { + <-net.stopped + cancel() + }() + var lastTry time.Time backoff := newBackoffer(retryConnectMinDelay, retryConnectMaxDelay, retryConnectJitter) for { @@ -273,7 +283,7 @@ } lastTry = time.Now() - if err := net.runConn(context.TODO()); err != nil { + if err := net.runConn(ctx); err != nil { text := err.Error() temp := true var regErr registrationError @@ -299,10 +309,6 @@ if !net.isStopped() { close(net.stopped) } - - if net.conn != nil { - net.conn.Close() - } } func (net *network) detach(ch *database.Channel) {