Fix: let lib/traceroute wait all goroutine finish

This commit is contained in:
186526 2022-05-05 20:50:28 +08:00
parent e4af31660d
commit 0d2e6e1e17
Signed by: 186526
GPG Key ID: C7EB1E6B8CC5E51D
3 changed files with 36 additions and 14 deletions

2
.gitignore vendored
View File

@ -9,6 +9,8 @@
*.so *.so
*.dylib *.dylib
tracer
# Test binary, built with `go test -c` # Test binary, built with `go test -c`
*.test *.test

View File

@ -154,9 +154,12 @@ type TracerouteResult struct {
Hops []TracerouteHop Hops []TracerouteHop
} }
func notify(hop TracerouteHop, channels []chan TracerouteHop) { func notify(hop TracerouteHop, channels []chan TracerouteHop, DoneChannels chan bool) {
// fmt.Print(hop)
for _, c := range channels { for _, c := range channels {
c <- hop c <- hop
<-DoneChannels
// fmt.Print("Done")
} }
} }
@ -166,7 +169,7 @@ func closeNotify(channels []chan TracerouteHop) {
} }
} }
func Traceroute(dest string, options *TracerouteOptions, c ...chan TracerouteHop) (result TracerouteResult, err error) { func Traceroute(dest string, options *TracerouteOptions, d chan bool, c ...chan TracerouteHop) (result TracerouteResult, err error) {
result.Hops = []TracerouteHop{} result.Hops = []TracerouteHop{}
destAddrBytes, destIPAddr, err := destAddr(dest) destAddrBytes, destIPAddr, err := destAddr(dest)
result.DestinationAddress = destAddrBytes result.DestinationAddress = destAddrBytes
@ -237,7 +240,7 @@ func Traceroute(dest string, options *TracerouteOptions, c ...chan TracerouteHop
if err != nil { if err != nil {
if err, ok := err.(net.Error); ok && err.Timeout() { if err, ok := err.(net.Error); ok && err.Timeout() {
// means timeout here // means timeout here
notify(TracerouteHop{Success: false, TTL: ttl}, c) notify(TracerouteHop{Success: false, TTL: ttl}, c, d)
retry += 1 retry += 1
if retry > options.Retries() { if retry > options.Retries() {
ttl += 1 ttl += 1
@ -271,7 +274,7 @@ func Traceroute(dest string, options *TracerouteOptions, c ...chan TracerouteHop
} }
if rm.Type == ipv4.ICMPTypeEchoReply || rm.Type == ipv4.ICMPTypeTimeExceeded { if rm.Type == ipv4.ICMPTypeEchoReply || rm.Type == ipv4.ICMPTypeTimeExceeded {
notify(hop, c) notify(hop, c, d)
result.Hops = append(result.Hops, hop) result.Hops = append(result.Hops, hop)
} }
@ -282,7 +285,5 @@ func Traceroute(dest string, options *TracerouteOptions, c ...chan TracerouteHop
closeNotify(c) closeNotify(c)
return result, nil return result, nil
} }
time.Sleep(time.Millisecond * 500)
} }
} }

View File

@ -20,6 +20,9 @@ import (
"github.com/oschwald/maxminddb-golang" "github.com/oschwald/maxminddb-golang"
) )
var isFinish bool = false
var destIP net.IP
func remove(slice []string, s int) []string { func remove(slice []string, s int) []string {
if s < len(slice) { if s < len(slice) {
return append(slice[:s], slice[s+1:]...) return append(slice[:s], slice[s+1:]...)
@ -27,6 +30,15 @@ func remove(slice []string, s int) []string {
return slice return slice
} }
func removeEmpty(slice []string) (ret []string) {
for i := 0; i < len(slice); i++ {
if (len(slice[i])) != 0 {
ret = append(ret, slice[i])
}
}
return
}
type record struct { type record struct {
ASN int `maxminddb:"autonomous_system_number"` ASN int `maxminddb:"autonomous_system_number"`
ASO string `maxminddb:"autonomous_system_organization"` ASO string `maxminddb:"autonomous_system_organization"`
@ -88,7 +100,7 @@ func parseIPFromMap42(ip net.IP) (record, APIResponse, error) {
APIResponse := APIResponse{ APIResponse := APIResponse{
Code: 0, Code: 0,
Detail: IPDetail{ Detail: IPDetail{
ISP: strings.Join(remove(remove(strings.Split(res.Area, "\t"), 6), 5), " "), ISP: strings.Join(removeEmpty(remove(remove(strings.Split(res.Area, "\t"), 6), 5)), " "),
}, },
} }
return record, APIResponse, nil return record, APIResponse, nil
@ -150,7 +162,8 @@ func printHop(hop traceroute.TracerouteHop) {
if _, Prefix, _ := net.ParseCIDR("172.16.0.0/12"); Prefix.Contains(ip) { if _, Prefix, _ := net.ParseCIDR("172.16.0.0/12"); Prefix.Contains(ip) {
info, APIResponse, err = parseIPFromMap42(ip) info, APIResponse, err = parseIPFromMap42(ip)
if err != nil { if err != nil {
log.Fatal(err) info = parseIPFromMaxminddb(ip)
APIResponse = parseIPFromBilibiliAPI(ip)
} }
} else { } else {
info = parseIPFromMaxminddb(ip) info = parseIPFromMaxminddb(ip)
@ -172,9 +185,9 @@ func printHop(hop traceroute.TracerouteHop) {
ASN = fmt.Sprintf("AS%v", info.ASN) ASN = fmt.Sprintf("AS%v", info.ASN)
} }
if APIResponse.Code != 0 { if APIResponse.Code != 0 {
fmt.Printf("%-3d %8v %15v (%v) %8v\n", hop.TTL, ASN, hostOrAddr, addr, hop.ElapsedTime.Round(time.Microsecond)) fmt.Printf("%-3d %8v %15v %18v %8v\n", hop.TTL, ASN, hostOrAddr, "("+addr+")", hop.ElapsedTime.Round(time.Microsecond))
} else { } else {
fmt.Printf("%-3d %8v %15v (%v) %8v ", hop.TTL, ASN, hostOrAddr, addr, hop.ElapsedTime.Round(time.Microsecond)) fmt.Printf("%-3d %8v %15v %18v %8v ", hop.TTL, ASN, hostOrAddr, "("+addr+")", hop.ElapsedTime.Round(time.Microsecond))
var flag bool = false var flag bool = false
@ -211,6 +224,11 @@ func printHop(hop traceroute.TracerouteHop) {
} else { } else {
fmt.Printf("%-3d *\n", hop.TTL) fmt.Printf("%-3d *\n", hop.TTL)
} }
if destIP.Equal(ip) {
isFinish = true
}
} }
func address(address [4]byte) string { func address(address [4]byte) string {
@ -237,6 +255,7 @@ func main() {
fmt.Printf("traceroute to %v (%v), %v hops max, %v byte packets, using ICMP methods.\n", host, ipAddr, options.MaxHops(), options.PacketSize()) fmt.Printf("traceroute to %v (%v), %v hops max, %v byte packets, using ICMP methods.\n", host, ipAddr, options.MaxHops(), options.PacketSize())
c := make(chan traceroute.TracerouteHop, 0) c := make(chan traceroute.TracerouteHop, 0)
d := make(chan bool)
go func() { go func() {
for { for {
hop, ok := <-c hop, ok := <-c
@ -245,14 +264,14 @@ func main() {
return return
} }
printHop(hop) printHop(hop)
d <- true
} }
}() }()
res, err := traceroute.Traceroute(ipAddr.String(), &options, c) destIP = ipAddr.IP
_, err = traceroute.Traceroute(ipAddr.String(), &options, d, c)
if err != nil { if err != nil {
fmt.Printf("Error: %v", err) fmt.Printf("Error: %v", err)
} }
printHop(res.Hops[len(res.Hops)-1])
} }