diff --git a/check.go b/check.go index 4e067e9..7820a39 100644 --- a/check.go +++ b/check.go @@ -6,13 +6,13 @@ import ( "crypto/sha1" "crypto/sha256" "crypto/sha512" - "errors" "fmt" "hash" "io" "log" "os" "strings" + "sync" ) type checksum struct { @@ -43,61 +43,113 @@ func parseCS(line string) (checksum, error) { return checksum{filename: f, hash: hsh, checksum: cs}, nil } -func check(files []string) error { - streams := []io.ReadCloser{} - defer func() { - for _, stream := range streams { - stream.Close() +type input struct { + f io.ReadCloser + err error +} + +type work struct { + cs checksum + err error +} + +func streams(files []string) chan input { + r := make(chan input) + + go func() { + for _, name := range files { + f, err := os.Open(name) + r <- input{f, err} } + if len(files) == 0 { + r <- input{f: os.Stdin} + } + close(r) }() - for _, name := range files { - f, err := os.Open(name) - if err != nil { - return err - } - streams = append(streams, f) - } - if len(files) == 0 { - streams = append(streams, os.Stdin) - } - - jobs := []checksum{} - for _, stream := range streams { - s := bufio.NewScanner(stream) - for s.Scan() { - cs, err := parseCS(s.Text()) - if err != nil { - return err - } - jobs = append(jobs, cs) - } - if s.Err() != nil { - return s.Err() - } - } - - errs := 0 - for _, job := range jobs { - f, err := os.Open(job.filename) - if err != nil { - return fmt.Errorf("open: %v", err) - } - if _, err := io.Copy(job.hash, f); err != nil { - log.Printf("%+v", err) - } - f.Close() - if fmt.Sprintf("%x", job.hash.Sum(nil)) == job.checksum { - fmt.Printf("%s: OK\n", job.filename) - } else { - errs++ - fmt.Fprintf(os.Stderr, "%s: bad\n", job.filename) - } - } - - var err error - if errs != 0 { - err = errors.New("bad files found") - } - return err + return r +} + +func check(files []string) chan error { + jobs := make(chan work) + + go func() { + for stream := range streams(files) { + if stream.err != nil { + jobs <- work{err: stream.err} + break + } + s := bufio.NewScanner(stream.f) + for s.Scan() { + cs, err := parseCS(s.Text()) + jobs <- work{cs, err} + } + stream.f.Close() + if s.Err() != nil { + jobs <- work{err: s.Err()} + } + } + close(jobs) + }() + + results := []<-chan error{} + + workers := 8 + for w := 0; w < workers; w++ { + results = append(results, compute(jobs)) + } + + return merge(results) +} + +func merge(cs []<-chan error) chan error { + out := make(chan error) + + var wg sync.WaitGroup + + output := func(c <-chan error) { + for n := range c { + out <- n + } + wg.Done() + } + + wg.Add(len(cs)) + for _, c := range cs { + go output(c) + } + + go func() { + wg.Wait() + close(out) + }() + return out +} + +func compute(jobs chan work) chan error { + r := make(chan error) + go func() { + for job := range jobs { + if job.err != nil { + log.Printf("%+v", job.err) + continue + } + f, err := os.Open(job.cs.filename) + if err != nil { + r <- fmt.Errorf("open: %v", err) + continue + } + if _, err := io.Copy(job.cs.hash, f); err != nil { + log.Printf("%+v", err) + } + f.Close() + if fmt.Sprintf("%x", job.cs.hash.Sum(nil)) == job.cs.checksum { + fmt.Printf("%s: OK\n", job.cs.filename) + } else { + r <- fmt.Errorf("%s: bad", job.cs.filename) + } + } + close(r) + }() + return r } diff --git a/main.go b/main.go index cd2de4d..c261dd5 100644 --- a/main.go +++ b/main.go @@ -19,8 +19,12 @@ func main() { files := flag.Args() switch *mode { case true: - if err := check(files); err != nil { + c := 0 + for err := range check(files) { + c++ fmt.Fprintf(os.Stderr, "%v\n", err) + } + if c > 0 { os.Exit(1) } case false: