diff --git a/cmd/crawl/main.go b/cmd/crawl/main.go
index 8de42bd..b18ffc4 100644
--- a/cmd/crawl/main.go
+++ b/cmd/crawl/main.go
@@ -2,6 +2,8 @@ package main
 
 import (
 	"fmt"
+	"io"
+	"io/ioutil"
 	"net/http"
 	"os"
 
@@ -20,9 +22,14 @@ func main() {
 	for p := range spider.Pages(os.Args[1]) {
 		resp, err := http.Get(p.To)
 		if err != nil {
+			p.Err = err
 			failures = append(failures, p)
+			continue
 		}
+		io.Copy(ioutil.Discard, resp.Body)
+		resp.Body.Close()
 		if resp.StatusCode != http.StatusOK {
+			p.Err = fmt.Errorf("http status; got %s, want %s", http.StatusText(resp.StatusCode), http.StatusText(http.StatusOK))
 			failures = append(failures, p)
 		}
 	}
diff --git a/cmd/lnks/main.go b/cmd/lnks/main.go
index 01eaf48..5a3f427 100644
--- a/cmd/lnks/main.go
+++ b/cmd/lnks/main.go
@@ -2,6 +2,7 @@ package main
 
 import (
 	"fmt"
+	"net/http"
 	"os"
 
 	"mcquay.me/spider"
@@ -14,7 +15,15 @@ func main() {
 		fmt.Fprintf(os.Stderr, "%s\n", usage)
 		os.Exit(1)
 	}
-	for _, l := range spider.URLs(os.Args[1]) {
+	resp, err := http.Get(os.Args[1])
+	if err != nil {
+		panic(err)
+	}
+	links, err := spider.URLs(resp.Body)
+	if err != nil {
+		panic(err)
+	}
+	for _, l := range links {
 		fmt.Println(l)
 	}
 }
diff --git a/spider.go b/spider.go
index 35eae71..6971b2b 100644
--- a/spider.go
+++ b/spider.go
@@ -2,33 +2,27 @@ package spider
 
 import (
 	"fmt"
+	"io"
 	"net/http"
-	"os"
 	"strings"
 
 	"golang.org/x/net/html"
 )
 
 // URLs returns all links on a page
-func URLs(url string) []Link {
-	resp, err := http.Get(url)
+func URLs(page io.Reader) ([]string, error) {
+	doc, err := html.Parse(page)
 	if err != nil {
-		fmt.Fprintf(os.Stderr, "%v\n", err)
-		os.Exit(1)
+		return nil, fmt.Errorf("parsing html: %v", err)
 	}
-	doc, err := html.Parse(resp.Body)
-	if err != nil {
-		fmt.Fprintf(os.Stderr, "%v\n", err)
-		os.Exit(1)
-	}
-	paths := []Link{}
+	paths := []string{}
 	var f func(*html.Node)
 	f = func(n *html.Node) {
 		if n.Type == html.ElementNode {
 			for _, a := range n.Attr {
 				switch a.Key {
 				case "href", "src":
-					paths = append(paths, Link{From: url, To: a.Val})
+					paths = append(paths, a.Val)
 					break
 				}
 			}
@@ -38,25 +32,52 @@ func URLs(url string) []Link {
 		}
 	}
 	f(doc)
-	return paths
+	return paths, nil
 }
 
 type Link struct {
 	From string
 	To   string
+
+	Err error
 }
 
 func (l Link) String() string {
-	return fmt.Sprintf("%s > %s", l.From, l.To)
+	r := fmt.Sprintf("%s > %s", l.From, l.To)
+	if l.Err != nil {
+		r = fmt.Sprintf("%v (%v)", r, l.Err)
+	}
+	return r
 }
 
 // Pages returns a stream of full urls starting at a given base page.
 func Pages(base string) <-chan Link {
+	r := make(chan Link)
+
 	base = strings.TrimRight(base, "/")
 	visited := map[string]bool{base: true}
-	links := URLs(base)
+	links := []Link{}
 
-	r := make(chan Link)
+	resp, err := http.Get(base)
+	if err != nil {
+		go func() {
+			r <- Link{To: base, From: "start", Err: err}
+			close(r)
+		}()
+		return r
+	}
+	lks, err := URLs(resp.Body)
+	if err != nil {
+		go func() {
+			r <- Link{To: base, From: "start", Err: err}
+			close(r)
+		}()
+		return r
+	}
+
+	for _, l := range lks {
+		links = append(links, Link{From: base, To: l})
+	}
 
 	go func() {
 		for len(links) > 0 {
@@ -77,8 +98,18 @@ func Pages(base string) <-chan Link {
 			if _, ok := visited[l.To]; !ok {
 				r <- l
-				for _, lk := range URLs(l.To) {
-					links = append(links, lk)
+
+				resp, err := http.Get(l.To)
+				if err != nil {
+					panic(err)
+				}
+				lks, err := URLs(resp.Body)
+				if err != nil {
+					panic(err)
+				}
+
+				for _, lk := range lks {
+					links = append(links, Link{From: l.From, To: lk})
 				}
 			}
 			visited[l.To] = true