Files
remembos/internal/search/selector.go
T
av 868c90c896
release / docker-image (push) Successful in 1m9s
release / goreleaser (push) Successful in 10m14s
fix memo selection
2026-02-12 18:33:05 +03:00

295 lines
6.9 KiB
Go

package search
import (
"context"
"fmt"
"log/slog"
"math/rand/v2"
"sync"
"time"
"git.vakhrushev.me/av/remembos/internal/config"
"git.vakhrushev.me/av/remembos/internal/memos"
"git.vakhrushev.me/av/remembos/internal/storage"
)
// Selector implements the memory search algorithm: it picks one memo to
// resurface as a "memory" for a given day.
//
// Candidate memos come from the Memos API client. The storage layer supplies
// cooldown sets and show counts. The behaviour is tuned by a private copy of
// the search configuration.
type Selector struct {
	// Where memos and their selection history come from.
	client *memos.Client
	store  *storage.Storage

	// Search tuning (copied at construction) and the time zone handed to the
	// tier range builders.
	cfg config.SearchConfig
	loc *time.Location

	// Destination for selection diagnostics.
	logger *slog.Logger
}
// NewSelector builds a Selector that reads memos through client and the
// selection history (cooldowns, show counts) through store.
//
// cfg must be non-nil. It is copied, so later changes to *cfg do not affect
// the returned Selector. loc is passed unchanged to the tier date-range
// builders. A nil logger falls back to slog.Default(), so the Selector's
// logging calls never run on a nil *slog.Logger, which would panic.
func NewSelector(client *memos.Client, store *storage.Storage, cfg *config.SearchConfig, loc *time.Location, logger *slog.Logger) *Selector {
	if logger == nil {
		logger = slog.Default()
	}
	return &Selector{
		client: client,
		store:  store,
		cfg:    *cfg,
		loc:    loc,
		logger: logger,
	}
}
// Select runs the full search algorithm and returns one Memory for the given day.
//
// The seven tiers are visited in a weighted random order derived from the
// configured tier weights. That same order is used for two passes: first with
// the normal cooldown, then with the relaxed cooldown. If neither pass yields
// a memory, fullFallback is consulted. fullFallback may itself return
// (nil, nil) when the Memos API has no memos at all. Any error aborts the
// search immediately.
func (s *Selector) Select(ctx context.Context, today time.Time) (*Memory, error) {
	// rangesFor[i] produces the date ranges of tier i+1 for a given day.
	rangesFor := [7]func(time.Time) []DateRange{
		func(day time.Time) []DateRange { return tier1Ranges(day, s.cfg.MaxYearsBack, s.loc) },
		func(day time.Time) []DateRange { return tier2Ranges(day, s.loc) },
		func(day time.Time) []DateRange { return tier3Ranges(day, s.cfg.MaxYearsBack, s.loc) },
		func(day time.Time) []DateRange { return tier4Ranges(day, s.cfg.MaxYearsBack, s.loc) },
		func(day time.Time) []DateRange { return tier5Ranges(day, s.cfg.MaxYearsBack, s.loc) },
		func(day time.Time) []DateRange { return tier6Ranges(day, s.cfg.MaxYearsBack, s.loc) },
		func(day time.Time) []DateRange { return tier7Ranges(day, s.loc) },
	}
	tierOrder := weightedTierOrder(s.cfg.TierWeights.AsSlice())

	// Each pass loosens the cooldown. The message is logged when the pass
	// comes up empty and the search moves on.
	passes := [...]struct {
		cooldownDays int
		onEmpty      string
	}{
		{cooldownDays: s.cfg.CooldownDays, onEmpty: "no candidates with normal cooldown, trying relaxed"},
		{cooldownDays: s.cfg.RelaxedCooldownDays, onEmpty: "no candidates with relaxed cooldown, full fallback"},
	}
	for _, pass := range passes {
		found, err := s.tryTiers(ctx, today, tierOrder, rangesFor, pass.cooldownDays)
		switch {
		case err != nil:
			return nil, err
		case found != nil:
			return found, nil
		}
		s.logger.Info(pass.onEmpty)
	}

	// Last resort: any memo at all.
	return s.fullFallback(ctx)
}
// tryTiers walks the tiers in the given order. It returns a Memory built from
// the first tier that yields a pick, or (nil, nil) when no tier does.
//
// For each 0-based tier index in order (reported as Tier index+1) it:
//   - builds the tier's date ranges for today, skipping the tier if there are none;
//   - fetches the memos in those ranges;
//   - drops memos whose names are in the cooldown set for cooldownDays;
//   - looks up show counts and lets weightedSelect pick among the rest.
//
// yearsAgo is a difference of calendar years, both taken in s.loc when it is
// set. This keeps a memo's year independent of the zone its timestamp
// carries. Storage and Memos API errors abort the walk and are returned
// wrapped.
func (s *Selector) tryTiers(
	ctx context.Context,
	today time.Time,
	order []int,
	tierFuncs [7]func(time.Time) []DateRange,
	cooldownDays int,
) (*Memory, error) {
	cooldownSet, err := s.store.GetCooldownMemoNames(ctx, cooldownDays)
	if err != nil {
		return nil, fmt.Errorf("get cooldown set: %w", err)
	}

	// Compare years in the selector's zone. DisplayTime values may carry
	// another zone (presumably UTC from the API — TODO confirm). Mixing zones
	// makes yearsAgo off by one for memos written near New Year.
	todayYear := today.Year()
	if s.loc != nil {
		todayYear = today.In(s.loc).Year()
	}

	for _, tierIdx := range order {
		tier := tierIdx + 1
		ranges := tierFuncs[tierIdx](today)
		if len(ranges) == 0 {
			continue
		}

		allMemos, err := s.fetchRanges(ctx, ranges)
		if err != nil {
			return nil, fmt.Errorf("fetch tier %d: %w", tier, err)
		}

		// Filter by cooldown.
		filtered := make([]*memos.Memo, 0, len(allMemos))
		for _, m := range allMemos {
			if _, blocked := cooldownSet[m.Name]; !blocked {
				filtered = append(filtered, m)
			}
		}
		if len(filtered) == 0 {
			s.logger.Debug("tier empty after cooldown filter", "tier", tier, "total", len(allMemos))
			continue
		}

		// Get show counts for scoring.
		names := make([]string, len(filtered))
		for i, m := range filtered {
			names[i] = m.Name
		}
		showCounts, err := s.store.GetShowCounts(ctx, names)
		if err != nil {
			return nil, fmt.Errorf("get show counts: %w", err)
		}

		candidates := make([]candidate, len(filtered))
		for i, m := range filtered {
			memoYear := m.DisplayTime.Year()
			if s.loc != nil {
				memoYear = m.DisplayTime.In(s.loc).Year()
			}
			candidates[i] = candidate{
				memo:      m,
				yearsAgo:  max(todayYear-memoYear, 0), // memos dated after today count as 0 years ago
				showCount: showCounts[m.Name],
			}
		}

		picked := weightedSelect(candidates, s.cfg.PreferOlder, s.cfg.MaxYearsBack)
		if picked == nil {
			continue
		}

		s.logger.Info("selected memory",
			"tier", tier,
			"memo", picked.memo.Name,
			"years_ago", picked.yearsAgo,
		)
		return &Memory{
			Memo:      picked.memo,
			Tier:      tier,
			YearsAgo:  picked.yearsAgo,
			ShowCount: picked.showCount,
			Date:      picked.memo.DisplayTime,
		}, nil
	}
	return nil, nil
}
// fetchRanges queries the Memos API for every date range concurrently. Each
// range uses the CEL filter from BuildCELFilter. The results are concatenated
// in range order, with duplicates (same Name) removed, so a memo matched by
// several ranges is a single candidate rather than one with extra weight.
//
// At most maxConcurrentFetches requests are in flight at once. The first
// failure (or cancellation of ctx) cancels the requests still running. That
// error is returned, wrapped with the failing range's index, and no memos
// are returned.
//
// NOTE(review): each range is fetched with one ListMemos call (page size
// s.cfg.PageSize, empty page token). A range with more than one page of memos
// is therefore truncated — confirm whether pagination is needed.
func (s *Selector) fetchRanges(ctx context.Context, ranges []DateRange) ([]*memos.Memo, error) {
	// Limit concurrency to avoid overwhelming the API.
	const maxConcurrentFetches = 5

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	perRange := make([][]*memos.Memo, len(ranges))
	var (
		wg       sync.WaitGroup
		errOnce  sync.Once
		firstErr error
	)
	// fail keeps only the first error and stops the remaining requests.
	// Errors caused by that cancellation are discarded by the Once.
	fail := func(err error) {
		errOnce.Do(func() {
			firstErr = err
			cancel()
		})
	}

	sem := make(chan struct{}, maxConcurrentFetches)
	for i, dr := range ranges {
		wg.Add(1)
		go func(idx int, dr DateRange) {
			defer wg.Done()
			select {
			case sem <- struct{}{}:
			case <-ctx.Done():
				fail(ctx.Err())
				return
			}
			defer func() { <-sem }()

			resp, err := s.client.ListMemos(ctx, BuildCELFilter(dr), s.cfg.PageSize, "")
			if err != nil {
				// Don't touch resp on failure; it may be nil.
				fail(fmt.Errorf("range %d: %w", idx, err))
				return
			}
			perRange[idx] = resp.GetMemos()
		}(i, dr)
	}
	wg.Wait() // also makes firstErr and perRange safe to read below

	if firstErr != nil {
		return nil, firstErr
	}

	seen := make(map[string]struct{})
	var all []*memos.Memo
	for _, batch := range perRange {
		for _, m := range batch {
			if _, dup := seen[m.Name]; dup {
				continue
			}
			seen[m.Name] = struct{}{}
			all = append(all, m)
		}
	}
	return all, nil
}
// fullFallback is the last resort of the search. It lists memos with an
// empty filter (a single ListMemos call of up to s.cfg.PageSize items) and
// drops those under the relaxed cooldown. If that would leave nothing, the
// cooldown is ignored. It then picks one memo with weightedSelect, scored by
// show count like the tiered passes. The returned Memory has Tier 0 and no
// YearsAgo.
//
// It returns (nil, nil) when the Memos API returns no memos at all. Storage
// and API errors are returned wrapped.
//
// NOTE(review): only the first page of memos is considered here — confirm
// whether that is intentional.
func (s *Selector) fullFallback(ctx context.Context) (*Memory, error) {
	resp, err := s.client.ListMemos(ctx, "", s.cfg.PageSize, "")
	if err != nil {
		return nil, fmt.Errorf("full fallback: %w", err)
	}
	allMemos := resp.GetMemos()
	if len(allMemos) == 0 {
		return nil, nil
	}

	// Try to apply the relaxed cooldown.
	cooldownSet, err := s.store.GetCooldownMemoNames(ctx, s.cfg.RelaxedCooldownDays)
	if err != nil {
		return nil, fmt.Errorf("full fallback cooldown: %w", err)
	}
	filtered := make([]*memos.Memo, 0, len(allMemos))
	for _, m := range allMemos {
		if _, blocked := cooldownSet[m.Name]; !blocked {
			filtered = append(filtered, m)
		}
	}
	// If nothing survives the cooldown, use all memos.
	if len(filtered) == 0 {
		filtered = allMemos
	}

	// Score by show count, exactly as tryTiers does, so that even the fallback
	// receives show counts for its candidates.
	names := make([]string, len(filtered))
	for i, m := range filtered {
		names[i] = m.Name
	}
	showCounts, err := s.store.GetShowCounts(ctx, names)
	if err != nil {
		return nil, fmt.Errorf("full fallback show counts: %w", err)
	}

	candidates := make([]candidate, len(filtered))
	for i, m := range filtered {
		candidates[i] = candidate{memo: m, showCount: showCounts[m.Name]}
	}
	picked := weightedSelect(candidates, false, 0)
	if picked == nil {
		return nil, nil
	}

	s.logger.Info("selected memory via full fallback", "memo", picked.memo.Name)
	return &Memory{
		Memo:      picked.memo,
		Tier:      0,
		ShowCount: picked.showCount,
		Date:      picked.memo.DisplayTime,
	}, nil
}
// weightedTierOrder returns tier indices (0-based) shuffled by weight.
// Higher weight tiers come first, with randomization within the ordering.
func weightedTierOrder(weights [7]int) []int {
type entry struct {
idx int
weight int
}
entries := make([]entry, 7)
for i, w := range weights {
entries[i] = entry{idx: i, weight: w}
}
// Shuffle using weighted random selection without replacement
order := make([]int, 0, 7)
remaining := make([]entry, len(entries))
copy(remaining, entries)
for len(remaining) > 0 {
totalWeight := 0
for _, e := range remaining {
totalWeight += e.weight
}
if totalWeight == 0 {
// All remaining have zero weight, just append them
for _, e := range remaining {
order = append(order, e.idx)
}
break
}
r := rand.IntN(totalWeight) //nolint:gosec // non-cryptographic use
cumulative := 0
for i, e := range remaining {
cumulative += e.weight
if r < cumulative {
order = append(order, e.idx)
remaining = append(remaining[:i], remaining[i+1:]...)
break
}
}
}
return order
}