diff --git a/go.mod b/go.mod
index f7a0ad2..e2fb7d7 100644
--- a/go.mod
+++ b/go.mod
@@ -1,3 +1,5 @@
 module git.vakhrushev.me/av/trackers
 
 go 1.25.5
+
+require github.com/pelletier/go-toml/v2 v2.2.4
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..3cf50e1
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,2 @@
+github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
+github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..7805f1c
--- /dev/null
+++ b/main.go
@@ -0,0 +1,316 @@
+package main
+
+import (
+	"context"
+	"crypto/sha1"
+	"encoding/hex"
+	"errors"
+	"flag"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"net/url"
+	"os"
+	"os/signal"
+	"path/filepath"
+	"sort"
+	"strings"
+	"sync"
+	"syscall"
+	"time"
+
+	toml "github.com/pelletier/go-toml/v2"
+)
+
+type Config struct {
+	Port         int      `toml:"port"`
+	PollInterval string   `toml:"poll_interval"`
+	CacheDir     string   `toml:"cache_dir"`
+	Sources      []string `toml:"sources"`
+}
+
+func loadConfig(path string) (Config, time.Duration, error) {
+	data, err := os.ReadFile(path)
+	if err != nil {
+		return Config{}, 0, fmt.Errorf("read config: %w", err)
+	}
+
+	var cfg Config
+	if err := toml.Unmarshal(data, &cfg); err != nil {
+		return Config{}, 0, fmt.Errorf("parse config: %w", err)
+	}
+
+	if cfg.Port == 0 {
+		cfg.Port = 8080
+	}
+	if cfg.CacheDir == "" {
+		cfg.CacheDir = "cache"
+	}
+	intervalText := cfg.PollInterval
+	if intervalText == "" {
+		intervalText = "60m"
+	}
+
+	interval, err := time.ParseDuration(intervalText)
+	if err != nil {
+		return Config{}, 0, fmt.Errorf("parse poll_interval: %w", err)
+	}
+	if interval <= 0 {
+		return Config{}, 0, errors.New("poll_interval must be positive")
+	}
+
+	if len(cfg.Sources) == 0 {
+		return Config{}, 0, errors.New("no sources configured")
+	}
+
+	return cfg, interval, nil
+}
+
+type Aggregator struct {
+	mu        sync.RWMutex
+	perSource map[string]map[string]struct{}
+}
+
+func NewAggregator() *Aggregator {
+	return &Aggregator{perSource: make(map[string]map[string]struct{})}
+}
+
+func (a *Aggregator) Update(source string, links []string) {
+	set := make(map[string]struct{}, len(links))
+	for _, link := range links {
+		set[link] = struct{}{}
+	}
+
+	a.mu.Lock()
+	a.perSource[source] = set
+	a.mu.Unlock()
+}
+
+func (a *Aggregator) List() []string {
+	a.mu.RLock()
+	defer a.mu.RUnlock()
+
+	combined := make(map[string]struct{})
+	for _, set := range a.perSource {
+		for link := range set {
+			combined[link] = struct{}{}
+		}
+	}
+
+	list := make([]string, 0, len(combined))
+	for link := range combined {
+		list = append(list, link)
+	}
+
+	sort.Strings(list)
+	return list
+}
+
+func main() {
+	configPath := flag.String("config", "config.toml", "path to config file")
+	flag.Parse()
+
+	cfg, interval, err := loadConfig(*configPath)
+	if err != nil {
+		log.Fatalf("config error: %v", err)
+	}
+
+	if err := os.MkdirAll(cfg.CacheDir, 0o755); err != nil {
+		log.Fatalf("cache dir: %v", err)
+	}
+
+	agg := NewAggregator()
+
+	client := &http.Client{Timeout: 15 * time.Second}
+
+	for _, source := range cfg.Sources {
+		cached, err := loadCachedLinks(cfg.CacheDir, source)
+		if err != nil {
+			log.Printf("load cache for %s: %v", source, err)
+		}
+		if len(cached) > 0 {
+			agg.Update(source, cached)
+		}
+	}
+
+	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+	defer stop()
+
+	for _, source := range cfg.Sources {
+		go pollSource(ctx, source, interval, cfg.CacheDir, agg, client)
+	}
+
+	mux := http.NewServeMux()
+	mux.HandleFunc("/list", func(w http.ResponseWriter, r *http.Request) {
+		links := agg.List()
+		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+		for i, link := range links {
+			if i > 0 {
+				_, _ = w.Write([]byte("\n"))
+			}
+			_, _ = w.Write([]byte(link))
+		}
+	})
+
+	server := &http.Server{
+		Addr:              fmt.Sprintf(":%d", cfg.Port),
+		Handler:           mux,
+		ReadHeaderTimeout: 5 * time.Second,
+		ReadTimeout:       10 * time.Second,
+		WriteTimeout:      10 * time.Second,
+	}
+
+	go func() {
+		<-ctx.Done()
+		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		if err := server.Shutdown(shutdownCtx); err != nil {
+			log.Printf("server shutdown: %v", err)
+		}
+	}()
+
+	log.Printf("listening on :%d", cfg.Port)
+	if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+		log.Fatalf("server error: %v", err)
+	}
+}
+
+func pollSource(ctx context.Context, source string, interval time.Duration, cacheDir string, agg *Aggregator, client *http.Client) {
+	ticker := time.NewTicker(interval)
+	defer ticker.Stop()
+
+	runOnce(ctx, source, cacheDir, agg, client)
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-ticker.C:
+			runOnce(ctx, source, cacheDir, agg, client)
+		}
+	}
+}
+
+func runOnce(ctx context.Context, source string, cacheDir string, agg *Aggregator, client *http.Client) {
+	links, err := fetchSource(ctx, source, client)
+	if err != nil {
+		log.Printf("poll %s: %v", source, err)
+		return
+	}
+	if len(links) == 0 {
+		agg.Update(source, nil)
+		if err := writeCache(cacheDir, source, nil); err != nil {
+			log.Printf("write cache %s: %v", source, err)
+		}
+		return
+	}
+
+	agg.Update(source, links)
+	if err := writeCache(cacheDir, source, links); err != nil {
+		log.Printf("write cache %s: %v", source, err)
+	}
+}
+
+func fetchSource(ctx context.Context, source string, client *http.Client) ([]string, error) {
+	u, err := url.Parse(source)
+	if err != nil {
+		return nil, fmt.Errorf("invalid source url: %w", err)
+	}
+
+	switch u.Scheme {
+	case "http", "https":
+		req, err := http.NewRequestWithContext(ctx, http.MethodGet, source, nil)
+		if err != nil {
+			return nil, fmt.Errorf("build request: %w", err)
+		}
+		resp, err := client.Do(req)
+		if err != nil {
+			return nil, fmt.Errorf("request failed: %w", err)
+		}
+		defer resp.Body.Close()
+		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+			return nil, fmt.Errorf("unexpected status: %s", resp.Status)
+		}
+		body, err := io.ReadAll(resp.Body)
+		if err != nil {
+			return nil, fmt.Errorf("read body: %w", err)
+		}
+		return normalizeLinks(string(body)), nil
+	case "file":
+		path := u.Path
+		if path == "" {
+			return nil, errors.New("file source path is empty")
+		}
+		data, err := os.ReadFile(path)
+		if err != nil {
+			return nil, fmt.Errorf("read file: %w", err)
+		}
+		return normalizeLinks(string(data)), nil
+	default:
+		return nil, fmt.Errorf("unsupported source scheme: %s", u.Scheme)
+	}
+}
+
+func normalizeLinks(content string) []string {
+	rawLines := strings.Split(content, "\n")
+	set := make(map[string]struct{})
+
+	for _, line := range rawLines {
+		link := strings.TrimSpace(line)
+		if link == "" {
+			continue
+		}
+		if !isValidTrackerLink(link) {
+			continue
+		}
+		set[link] = struct{}{}
+	}
+
+	result := make([]string, 0, len(set))
+	for link := range set {
+		result = append(result, link)
+	}
+	sort.Strings(result)
+	return result
+}
+
+func isValidTrackerLink(link string) bool {
+	u, err := url.Parse(link)
+	if err != nil {
+		return false
+	}
+
+	switch u.Scheme {
+	case "http", "https", "udp", "ws", "wss":
+	default:
+		return false
+	}
+
+	if u.Hostname() == "" {
+		return false
+	}
+
+	return true
+}
+
+func cacheFilePath(cacheDir, source string) string {
+	sum := sha1.Sum([]byte(source))
+	filename := hex.EncodeToString(sum[:]) + ".txt"
+	return filepath.Join(cacheDir, filename)
+}
+
+func writeCache(cacheDir, source string, links []string) error {
+	path := cacheFilePath(cacheDir, source)
+	data := strings.Join(links, "\n")
+	return os.WriteFile(path, []byte(data), 0o644)
+}
+
+func loadCachedLinks(cacheDir, source string) ([]string, error) {
+	path := cacheFilePath(cacheDir, source)
+	data, err := os.ReadFile(path)
+	if err != nil {
+		return nil, err
+	}
+	return normalizeLinks(string(data)), nil
+}
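
Note (not part of the patch): an illustrative config.toml for the loader above. The keys mirror the toml struct tags on Config, the values shown are the defaults applied by loadConfig, and the source URLs are placeholders, not real endpoints.

port = 8080
poll_interval = "60m"
cache_dir = "cache"
sources = [
    "https://example.com/trackers.txt",
    "file:///var/lib/trackers/local.txt",
]

With a file like this the service can be started as "go run . -config config.toml"; each source is polled on the configured interval, results are cached under cache_dir, and the merged, deduplicated tracker list is served as plain text from GET /list on the configured port.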