diff --git a/check.go b/check.go new file mode 100644 index 0000000..4e067e9 --- /dev/null +++ b/check.go @@ -0,0 +1,103 @@ +package main + +import ( + "bufio" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "errors" + "fmt" + "hash" + "io" + "log" + "os" + "strings" +) + +type checksum struct { + filename string + hash hash.Hash + checksum string +} + +func parseCS(line string) (checksum, error) { + elems := strings.Fields(line) + if len(elems) != 2 { + return checksum{}, fmt.Errorf("unexpected content: %d != 2", len(elems)) + } + cs, f := elems[0], elems[1] + var hsh hash.Hash + switch len(cs) { + case 32: + hsh = md5.New() + case 40: + hsh = sha1.New() + case 64: + hsh = sha256.New() + case 128: + hsh = sha512.New() + default: + return checksum{}, fmt.Errorf("unknown format: %q", line) + } + return checksum{filename: f, hash: hsh, checksum: cs}, nil +} + +func check(files []string) error { + streams := []io.ReadCloser{} + defer func() { + for _, stream := range streams { + stream.Close() + } + }() + + for _, name := range files { + f, err := os.Open(name) + if err != nil { + return err + } + streams = append(streams, f) + } + if len(files) == 0 { + streams = append(streams, os.Stdin) + } + + jobs := []checksum{} + for _, stream := range streams { + s := bufio.NewScanner(stream) + for s.Scan() { + cs, err := parseCS(s.Text()) + if err != nil { + return err + } + jobs = append(jobs, cs) + } + if s.Err() != nil { + return s.Err() + } + } + + errs := 0 + for _, job := range jobs { + f, err := os.Open(job.filename) + if err != nil { + return fmt.Errorf("open: %v", err) + } + if _, err := io.Copy(job.hash, f); err != nil { + log.Printf("%+v", err) + } + f.Close() + if fmt.Sprintf("%x", job.hash.Sum(nil)) == job.checksum { + fmt.Printf("%s: OK\n", job.filename) + } else { + errs++ + fmt.Fprintf(os.Stderr, "%s: bad\n", job.filename) + } + } + + var err error + if errs != 0 { + err = errors.New("bad files found") + } + return err +} diff --git a/main.go b/main.go index 5a5b019..cd2de4d 100644 --- a/main.go +++ b/main.go @@ -12,10 +12,26 @@ import ( ) var algo = flag.String("a", "sha1", "algorithm to use") +var mode = flag.Bool("c", false, "check") func main() { flag.Parse() files := flag.Args() + switch *mode { + case true: + if err := check(files); err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } + case false: + if err := hsh(files); err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } + } +} + +func hsh(files []string) error { h := sha256.New() switch *algo { case "sha1", "1": @@ -27,8 +43,7 @@ func main() { case "md5": h = md5.New() default: - fmt.Fprintf(os.Stderr, "unsupported algorithm: %v\n", *algo) - os.Exit(1) + return fmt.Errorf("unsupported algorithm: %v", *algo) } if len(files) == 0 { @@ -55,4 +70,5 @@ func main() { fmt.Printf("%x %s\n", h.Sum(nil), name) } } + return nil }