from __future__ import annotations

import json
import os
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

import torch
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm

from ..utils.io import read_json, write_json


@dataclass
class AnalyzeConfig:
    """Tunable settings for :func:`analyze_news`."""

    # Name passed to ``SentenceTransformer`` when the model is loaded.
    model_name: str = "all-MiniLM-L6-v2"
    # Window, in whole days after a point's start date, in which later points
    # are compared against it when deciding its trend label.
    trend_lookahead_days: int = 3
    # Minimum change in "center" (strictly greater, either sign) for a point
    # to be labelled "Up" or "Down" instead of "Still".
    trend_threshold: float = 0.05
    device: Optional[str] = None  # "cuda" or "cpu"; None auto-selects cuda when available
    # JSON file used to persist text embeddings between runs; None keeps the
    # cache in memory only (nothing is written to disk).
    encoding_cache_file: Optional[str] = None


class EncodingCache:
    """Map text strings to embedding vectors, optionally persisted as JSON.

    Existing entries are loaded from *path* on construction when the file
    exists. By default every :meth:`set` rewrites the file immediately. With
    ``path=None`` the cache is purely in-memory.
    """

    def __init__(self, path: Optional[str], *, autosave: bool = True) -> None:
        """Create the cache, loading entries from *path* if that file exists.

        Args:
            path: JSON file backing the cache, or ``None`` for no persistence.
            autosave: when true (the default, matching the historical
                behavior), each :meth:`set` immediately calls :meth:`save`.
                Every save serializes the entire cache, so saving per insert
                is quadratic in the number of entries; pass ``False`` to batch
                insertions and call :meth:`save` explicitly.

        A cache file that is not valid UTF-8, not valid JSON, or whose top
        level is not a JSON object is treated as empty. Other I/O errors,
        such as permission problems, propagate.
        """
        self.path = path
        self.autosave = autosave
        self._data: Dict[str, List[float]] = {}
        if path and os.path.exists(path):
            try:
                with open(path, "r", encoding="utf-8") as f:
                    loaded = json.load(f)
            except (json.JSONDecodeError, UnicodeDecodeError):
                # Best-effort cache: a corrupt file only means recomputing embeddings.
                loaded = {}
            # e.g. a JSON list would make get()/set() fail later on.
            self._data = loaded if isinstance(loaded, dict) else {}

    def get(self, key: str) -> Optional[List[float]]:
        """Return the cached vector for *key*, or ``None`` if absent."""
        return self._data.get(key)

    def set(self, key: str, value: List[float]) -> None:
        """Store *value* under *key*, saving to disk when ``autosave`` is on."""
        self._data[key] = value
        if self.autosave:
            self.save()

    def save(self) -> None:
        """Write all entries to ``self.path``; a no-op when no path is set.

        The JSON is first written to ``<path>.tmp`` and then moved over the
        target with :func:`os.replace`. An interrupted write therefore never
        truncates the existing cache file. The rename is atomic on POSIX when
        both files share a filesystem.
        """
        if not self.path:
            return
        os.makedirs(os.path.dirname(self.path) or ".", exist_ok=True)
        tmp_path = f"{self.path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self._data, f)
            os.replace(tmp_path, self.path)
        except BaseException:
            # Don't leave a half-written temp file behind; re-raise the original error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise


def _calculate_trend(history: List[Dict[str, Any]], lookahead_days: int, threshold: float) -> None:
    for i, hist in enumerate(history):
        if "center" not in hist or "start_time" not in hist:
            hist["trend"] = "Still"
            continue
        current_center = hist["center"]
        start_date = datetime.strptime(hist["start_time"], "%Y-%m-%d")
        hist["trend"] = "Still"
        for future_hist in history[i + 1 :]:
            if "center" in future_hist and "start_time" in future_hist:
                future_date = datetime.strptime(future_hist["start_time"], "%Y-%m-%d")
                if (future_date - start_date).days > lookahead_days:
                    break
                center_diff = future_hist["center"] - current_center
                if center_diff > threshold:
                    hist["trend"] = "Up"
                    break
                elif center_diff < -threshold:
                    hist["trend"] = "Down"
                    break


def _find_best_news(hist: Dict[str, Any], model: SentenceTransformer, cache: EncodingCache) -> None:
    if "comment_text" not in hist or "news" not in hist or not hist["news"]:
        return
    comment = hist["comment_text"]
    cached = cache.get(comment)
    if cached is not None:
        comment_embedding = torch.tensor(cached, device=model.device)
    else:
        comment_embedding = model.encode(comment, convert_to_tensor=True)
        cache.set(comment, comment_embedding.tolist())

    news_embeddings: List[torch.Tensor] = []
    for news in hist["news"]:
        cached_news = cache.get(news)
        if cached_news is not None:
            news_embeddings.append(torch.tensor(cached_news, device=model.device))
        else:
            emb = model.encode(news, convert_to_tensor=True)
            cache.set(news, emb.tolist())
            news_embeddings.append(emb)

    if not news_embeddings:
        return
    news_embeddings_t = torch.stack(news_embeddings)
    similarities = util.pytorch_cos_sim(comment_embedding, news_embeddings_t)[0]
    best_idx = int(torch.argmax(similarities).item())
    hist["best_news"] = hist["news"][best_idx]


def analyze_news(
    input_file: str,
    output_file: str,
    cfg: Optional[AnalyzeConfig] = None,
) -> None:
    """Annotate each entry's history with trend labels and best-matching news.

    For every entry loaded from *input_file* via ``read_json``:

    * ``_calculate_trend`` adds a ``"trend"`` label to each history point.
    * ``_find_best_news`` adds ``"best_news"`` where the point has a comment
      and news.

    NOTE(review): assumes the input is a list of dicts, each with an
    optional ``"history"`` list of dicts. Confirm against the producer.

    After every entry, all entries processed so far are written to
    *output_file* via ``write_json``, so an interrupted run keeps partial
    results. An empty input still produces an output file, containing an
    empty list.

    Args:
        input_file: path handed to ``read_json``.
        output_file: path handed to ``write_json``.
        cfg: analysis settings; ``None`` uses ``AnalyzeConfig()`` defaults.
    """
    cfg = cfg or AnalyzeConfig()
    device = cfg.device or ("cuda" if torch.cuda.is_available() else "cpu")
    model = SentenceTransformer(cfg.model_name, device=device)
    cache = EncodingCache(cfg.encoding_cache_file)

    # A JSON ``null`` document would otherwise crash tqdm's iteration.
    questions_data = read_json(input_file) or []
    processed: List[Dict[str, Any]] = []

    for entry in tqdm(questions_data, desc="Processing Entries"):
        # Read once; ``or []`` also covers an explicit "history": null, which
        # previously crashed inside _calculate_trend's enumerate().
        history = entry.get("history") or []
        _calculate_trend(history, cfg.trend_lookahead_days, cfg.trend_threshold)
        for hist in history:
            _find_best_news(hist, model, cache)
        processed.append(entry)
        write_json(output_file, processed, indent=2)  # incremental checkpoint after every entry

    if not processed:
        # The loop never ran, so no checkpoint was written; still emit an output file.
        write_json(output_file, processed, indent=2)
