diff --git a/check.go b/check.go index 1586436..1475b47 100644 --- a/check.go +++ b/check.go @@ -95,7 +95,7 @@ func check(files []string) chan error { results := []<-chan error{} for w := 0; w < *ngo; w++ { - results = append(results, compute(jobs)) + results = append(results, verify(jobs)) } return merge(results) @@ -125,7 +125,7 @@ func merge(cs []<-chan error) chan error { return out } -func compute(jobs chan work) chan error { +func verify(jobs chan work) chan error { r := make(chan error) go func() { for job := range jobs { diff --git a/main.go b/main.go index 6bd0464..2588d10 100644 --- a/main.go +++ b/main.go @@ -7,9 +7,11 @@ import ( "crypto/sha512" "flag" "fmt" + "hash" "io" "os" "runtime" + "sync" ) var algo = flag.String("a", "sha1", "algorithm to use") @@ -30,51 +32,118 @@ func main() { os.Exit(1) } case false: - if err := hsh(files); err != nil { - fmt.Fprintf(os.Stderr, "%v\n", err) + c := 0 + for res := range hsh(files) { + if res.err != nil { + c++ + fmt.Fprintf(os.Stderr, "%v\n", res.err) + } else { + fmt.Printf("%v\n", res.msg) + } + } + if c > 0 { os.Exit(1) } } } -func hsh(files []string) error { - h := sha256.New() +type hashr func() hash.Hash + +func hsh(files []string) chan result { + var h hashr switch *algo { case "sha1", "1": - h = sha1.New() + h = sha1.New case "sha256", "256": - h = sha256.New() + h = sha256.New case "sha512", "512": - h = sha512.New() + h = sha512.New case "md5": - h = md5.New() + h = md5.New default: - return fmt.Errorf("unsupported algorithm: %v", *algo) + r := make(chan result) + go func() { + r <- result{err: fmt.Errorf("unsupported algorithm: %v", *algo)} + }() + return r } if len(files) == 0 { - _, err := io.Copy(h, os.Stdin) + hsh := h() + _, err := io.Copy(hsh, os.Stdin) if err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) os.Exit(1) } - fmt.Printf("%x -\n", h.Sum(nil)) - } else { + fmt.Printf("%x -\n", hsh.Sum(nil)) + return nil + } + + jobs := make(chan work) + go func() { for _, name := range files { - f, err := os.Open(name) + jobs <- work{cs: checksum{filename: name}} + } + close(jobs) + }() + + res := []<-chan result{} + for w := 0; w < *ngo; w++ { + res = append(res, compute(h, jobs)) + } + + return rmerge(res) +} + +type result struct { + msg string + err error +} + +func compute(h hashr, jobs chan work) chan result { + hsh := h() + r := make(chan result) + go func() { + for job := range jobs { + f, err := os.Open(job.cs.filename) if err != nil { - fmt.Fprintf(os.Stderr, "%v\n", err) + r <- result{err: err} continue } - h.Reset() - _, err = io.Copy(h, f) + hsh.Reset() + _, err = io.Copy(hsh, f) f.Close() if err != nil { - fmt.Fprintf(os.Stderr, "%v\n", err) + r <- result{err: err} continue } - fmt.Printf("%x %s\n", h.Sum(nil), name) + r <- result{msg: fmt.Sprintf("%x %s", hsh.Sum(nil), job.cs.filename)} } - } - return nil + close(r) + }() + return r +} + +func rmerge(cs []<-chan result) chan result { + out := make(chan result) + + var wg sync.WaitGroup + + output := func(c <-chan result) { + for n := range c { + out <- n + } + wg.Done() + } + + wg.Add(len(cs)) + for _, c := range cs { + go output(c) + } + + go func() { + wg.Wait() + close(out) + }() + return out }