262 lines
6.2 KiB
Go
262 lines
6.2 KiB
Go
package search
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"math/rand/v2"
|
|
"sync"
|
|
"time"
|
|
|
|
"git.vakhrushev.me/av/remembos/internal/config"
|
|
"git.vakhrushev.me/av/remembos/internal/memos"
|
|
"git.vakhrushev.me/av/remembos/internal/storage"
|
|
)
|
|
|
|
// Selector implements the memory search algorithm: for a given day it tries
// date-range tiers in a weight-biased random order, filters out memos in
// cooldown, and as a last resort falls back to a random memo.
type Selector struct {
	// client is the Memos API client (ListMemos for tier ranges,
	// GetRandomMemo for the full fallback).
	client *memos.Client
	// store is local state queried for the cooldown set
	// (GetCooldownMemoNames) and per-memo show counts (GetShowCounts).
	store *storage.Storage
	// cfg is the search configuration, copied by value in NewSelector.
	cfg config.SearchConfig
	// loc is the time zone handed to the tierNRanges date-range builders.
	loc *time.Location
	// logger receives selection progress (Info) and per-tier diagnostics (Debug).
	logger *slog.Logger
}
|
|
|
|
// NewSelector returns a Selector that reads memos through client, keeps show
// history in store, builds date ranges in loc and logs through logger.
//
// cfg must be non-nil: a shallow copy of *cfg is stored, so reassigning
// fields of *cfg after this call does not affect the returned Selector.
func NewSelector(client *memos.Client, store *storage.Storage, cfg *config.SearchConfig, loc *time.Location, logger *slog.Logger) *Selector {
	sel := &Selector{
		client: client,
		store:  store,
		loc:    loc,
		logger: logger,
	}
	sel.cfg = *cfg
	return sel
}
|
|
|
|
// Select runs the full search algorithm and returns one Memory for today.
//
// The tier visiting order is drawn once, biased by the configured tier
// weights. The tiered search is then attempted with the normal cooldown and,
// if that finds nothing, with the relaxed cooldown. When both passes are
// empty the result comes from fullFallback, which may be (nil, nil) if no
// memo is available at all. Errors from any stage are returned unchanged.
func (s *Selector) Select(ctx context.Context, today time.Time) (*Memory, error) {
	// Entry i produces the date ranges for tier i+1.
	tierFuncs := [7]func(time.Time) []DateRange{
		func(t time.Time) []DateRange { return tier1Ranges(t, s.cfg.MaxYearsBack, s.loc) },
		func(t time.Time) []DateRange { return tier2Ranges(t, s.loc) },
		func(t time.Time) []DateRange { return tier3Ranges(t, s.cfg.MaxYearsBack, s.loc) },
		func(t time.Time) []DateRange { return tier4Ranges(t, s.cfg.MaxYearsBack, s.loc) },
		func(t time.Time) []DateRange { return tier5Ranges(t, s.cfg.MaxYearsBack, s.loc) },
		func(t time.Time) []DateRange { return tier6Ranges(t, s.cfg.MaxYearsBack, s.loc) },
		func(t time.Time) []DateRange { return tier7Ranges(t, s.loc) },
	}

	// The same tier order is reused for every cooldown pass.
	order := weightedTierOrder(s.cfg.TierWeights.AsSlice())

	passes := []struct {
		cooldownDays int
		emptyMsg     string
	}{
		{cooldownDays: s.cfg.CooldownDays, emptyMsg: "no candidates with normal cooldown, trying relaxed"},
		{cooldownDays: s.cfg.RelaxedCooldownDays, emptyMsg: "no candidates with relaxed cooldown, full fallback"},
	}

	for _, pass := range passes {
		found, err := s.tryTiers(ctx, today, order, tierFuncs, pass.cooldownDays)
		if err != nil {
			return nil, err
		}
		if found != nil {
			return found, nil
		}
		s.logger.Info(pass.emptyMsg)
	}

	// Full fallback: any memo.
	return s.fullFallback(ctx)
}
|
|
|
|
func (s *Selector) tryTiers(
|
|
ctx context.Context,
|
|
today time.Time,
|
|
order []int,
|
|
tierFuncs [7]func(time.Time) []DateRange,
|
|
cooldownDays int,
|
|
) (*Memory, error) {
|
|
cooldownSet, err := s.store.GetCooldownMemoNames(ctx, cooldownDays)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get cooldown set: %w", err)
|
|
}
|
|
|
|
for _, tierIdx := range order {
|
|
tier := tierIdx + 1
|
|
ranges := tierFuncs[tierIdx](today)
|
|
if len(ranges) == 0 {
|
|
continue
|
|
}
|
|
|
|
allMemos, err := s.fetchRanges(ctx, ranges)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("fetch tier %d: %w", tier, err)
|
|
}
|
|
|
|
// Filter by cooldown
|
|
var filtered []*memos.Memo
|
|
for _, m := range allMemos {
|
|
if _, blocked := cooldownSet[m.Name]; !blocked {
|
|
filtered = append(filtered, m)
|
|
}
|
|
}
|
|
|
|
if len(filtered) == 0 {
|
|
s.logger.Debug("tier empty after cooldown filter", "tier", tier, "total", len(allMemos))
|
|
continue
|
|
}
|
|
|
|
// Get show counts for scoring
|
|
names := make([]string, len(filtered))
|
|
for i, m := range filtered {
|
|
names[i] = m.Name
|
|
}
|
|
showCounts, err := s.store.GetShowCounts(ctx, names)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get show counts: %w", err)
|
|
}
|
|
|
|
// Build candidates
|
|
candidates := make([]candidate, len(filtered))
|
|
for i, m := range filtered {
|
|
yearsAgo := today.Year() - m.DisplayTime.Year()
|
|
if yearsAgo < 0 {
|
|
yearsAgo = 0
|
|
}
|
|
candidates[i] = candidate{
|
|
memo: m,
|
|
yearsAgo: yearsAgo,
|
|
showCount: showCounts[m.Name],
|
|
}
|
|
}
|
|
|
|
picked := weightedSelect(candidates, s.cfg.PreferOlder, s.cfg.MaxYearsBack)
|
|
if picked == nil {
|
|
continue
|
|
}
|
|
|
|
s.logger.Info("selected memory",
|
|
"tier", tier,
|
|
"memo", picked.memo.Name,
|
|
"years_ago", picked.yearsAgo,
|
|
)
|
|
|
|
return &Memory{
|
|
Memo: picked.memo,
|
|
Tier: tier,
|
|
YearsAgo: picked.yearsAgo,
|
|
ShowCount: picked.showCount,
|
|
Date: picked.memo.DisplayTime,
|
|
}, nil
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
// fetchRanges queries the Memos API for all date ranges concurrently and
// returns the memos from every range concatenated in the order of ranges.
//
// Each range is fetched with a single ListMemos call (page size
// s.cfg.PageSize, empty page token), so at most one page is read per range.
// NOTE(review): confirm whether ranges can exceed one page.
//
// At most maxConcurrentFetches requests run at once. The first error, or
// cancellation of ctx, aborts the requests that are still pending and is
// returned. Partial results are discarded in that case.
func (s *Selector) fetchRanges(ctx context.Context, ranges []DateRange) ([]*memos.Memo, error) {
	// Limit concurrency to avoid overwhelming the API.
	const maxConcurrentFetches = 5

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	var (
		firstErr error
		errOnce  sync.Once
	)
	// fail keeps only the first error and stops the remaining work. Later
	// errors are usually just the resulting context cancellation.
	fail := func(err error) {
		errOnce.Do(func() {
			firstErr = err
			cancel()
		})
	}

	pages := make([][]*memos.Memo, len(ranges))
	sem := make(chan struct{}, maxConcurrentFetches)
	var wg sync.WaitGroup

	for i, dr := range ranges {
		wg.Add(1)
		go func(idx int, dr DateRange) {
			defer wg.Done()

			// Wait for a slot, but do not keep waiting once the fetch is aborted.
			select {
			case sem <- struct{}{}:
			case <-ctx.Done():
				fail(ctx.Err())
				return
			}
			defer func() { <-sem }()

			resp, err := s.client.ListMemos(ctx, BuildCELFilter(dr), s.cfg.PageSize, "")
			if err != nil {
				fail(err)
				return
			}
			// Each goroutine writes only its own index, so no lock is needed.
			pages[idx] = resp.GetMemos()
		}(i, dr)
	}

	// wg.Wait also orders the firstErr write before this read.
	wg.Wait()

	if firstErr != nil {
		return nil, firstErr
	}

	total := 0
	for _, p := range pages {
		total += len(p)
	}
	all := make([]*memos.Memo, 0, total)
	for _, p := range pages {
		all = append(all, p...)
	}

	return all, nil
}
|
|
|
|
// fullFallback returns a random memo from the Memos API, with no date-range
// tiers or cooldown filtering. The result carries Tier 0, which distinguishes
// it from the 1-based tiers of the tiered search. YearsAgo and ShowCount are
// left at their zero values. It returns (nil, nil) when GetRandomMemo yields
// no memo.
func (s *Selector) fullFallback(ctx context.Context) (*Memory, error) {
	picked, err := s.client.GetRandomMemo(ctx)
	switch {
	case err != nil:
		return nil, fmt.Errorf("full fallback: %w", err)
	case picked == nil:
		return nil, nil
	}

	return &Memory{
		Memo: picked,
		Tier: 0,
		Date: picked.DisplayTime,
	}, nil
}
|
|
|
|
// weightedTierOrder returns every tier index (0-based) exactly once.
//
// The order is random, produced by weighted sampling without replacement.
// At each step, the next tier is drawn with probability proportional to its
// weight among the tiers not yet placed. Higher-weight tiers are therefore
// more likely, but not guaranteed, to come first.
//
// Negative weights are treated as zero. Zero-weight tiers are never drawn by
// the weighted step. They are appended at the end in ascending index order
// once all positive-weight tiers are placed.
// NOTE(review): a zero weight therefore demotes a tier rather than disabling it.
func weightedTierOrder(weights [7]int) []int {
	type entry struct {
		idx    int
		weight int
	}

	remaining := make([]entry, 0, len(weights))
	for i, w := range weights {
		// A negative weight would corrupt the cumulative scan below and can
		// make the total negative, which makes rand.IntN panic.
		if w < 0 {
			w = 0
		}
		remaining = append(remaining, entry{idx: i, weight: w})
	}

	order := make([]int, 0, len(weights))
	for len(remaining) > 0 {
		totalWeight := 0
		for _, e := range remaining {
			totalWeight += e.weight
		}
		if totalWeight == 0 {
			// Only zero-weight tiers are left; keep them in index order.
			for _, e := range remaining {
				order = append(order, e.idx)
			}
			break
		}

		r := rand.IntN(totalWeight) //nolint:gosec // non-cryptographic use
		cumulative := 0
		for i, e := range remaining {
			cumulative += e.weight
			if r < cumulative {
				order = append(order, e.idx)
				// Removing in place preserves the relative (index) order of the rest.
				remaining = append(remaining[:i], remaining[i+1:]...)
				break
			}
		}
	}

	return order
}
|