from __future__ import annotations

import os, time, math, json, hashlib, argparse, importlib, random
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import openai
import textgrad as tg

from llm_provider import get_provider

# Module-level side effect: configure the OpenAI client key from the
# environment (None when OPENAI_API_KEY is unset).
openai.api_key = os.environ.get("OPENAI_API_KEY")


# =========================
# Config & CLI
# =========================


@dataclass
class Config:
    """Run configuration; usually populated from the command line via :meth:`from_args`."""

    dataset: str = "glue"
    # May be None when built by from_args without --task and without a "-" in --dataset.
    task: Optional[str] = "sst2"
    model_name: str = "gpt-4o"

    # Minimum pseudo-label score for an example to count as high-confidence.
    gamma: float = 0.65
    confidence_metric: str = "vote"  # vote | logprob

    # JSONL file of pseudo labels (one example per line).
    pseudo_cache_path: Optional[str] = None
    # Skip the initial labelling pass when some example already has a label.
    reuse_pseudo: bool = False
    regen_strategy: str = "missing"  # all | missing | low_conf
    # For "low_conf": a score <= this threshold triggers relabelling.
    regen_threshold: float = 0.5
    # Cap on examples selected for labelling per pass (None = no cap).
    pseudo_sample_size: Optional[int] = None

    seed: int = 42

    provider: str = "openai"
    temperature: float = 0.3
    # Self-consistency samples drawn per query.
    sc_paths: int = 5
    top_logprobs: int = 5
    # Maximum LLM requests per second (must be > 0).
    rate_limit_rps: float = 3.0

    # Number of few-shot demonstrations per query.
    K: int = 4
    embed_dim: int = 768

    @staticmethod
    def from_args(argv: Optional[List[str]] = None) -> "Config":
        """Parse command-line arguments into a :class:`Config`.

        Args:
            argv: argument list to parse; ``None`` means ``sys.argv[1:]``
                (argparse's default), so existing callers are unaffected.

        Returns:
            The populated configuration. When ``--task`` is omitted and
            ``--dataset`` contains a ``-``, the dataset is split on the first
            ``-`` into dataset and task (``glue-sst2`` -> ``glue``, ``sst2``).

        Raises:
            SystemExit: via ``argparse`` on invalid arguments, including a
                non-positive ``--rate-limit-rps`` (which would otherwise cause a
                division by zero in the request throttler).
        """
        p = argparse.ArgumentParser()
        p.add_argument("--dataset", default="glue")
        p.add_argument("--task", default=None)
        p.add_argument("--model", dest="model_name", default="gpt-4o")

        p.add_argument("--gamma", type=float, default=0.65)
        p.add_argument("--confidence-metric", choices=["vote", "logprob"], default="vote")

        p.add_argument("--pseudo-cache-path")
        p.add_argument("--reuse-pseudo", action="store_true")
        p.add_argument("--regen-strategy", choices=["all", "missing", "low_conf"], default="missing")
        p.add_argument("--regen-threshold", type=float, default=0.5)
        p.add_argument("--pseudo-sample-size", type=int)

        p.add_argument("--seed", type=int, default=42)

        p.add_argument("--provider", default="openai")
        p.add_argument("--temperature", type=float, default=0.3)
        p.add_argument("--sc-paths", type=int, default=5)
        p.add_argument("--top-logprobs", type=int, default=5)
        p.add_argument("--rate-limit-rps", type=float, default=3.0)

        p.add_argument("--K", type=int, default=4)
        p.add_argument("--embed-dim", type=int, default=768)

        args = p.parse_args(argv)

        if args.rate_limit_rps <= 0:
            p.error("--rate-limit-rps must be > 0")

        dataset, task = args.dataset, args.task
        if task is None and "-" in dataset:
            dataset, task = dataset.split("-", 1)

        # All remaining argparse dests match Config field names one-to-one.
        cfg_kwargs = {k: v for k, v in vars(args).items() if k not in {"dataset", "task"}}
        return Config(dataset=dataset, task=task, **cfg_kwargs)


# =========================
# Data structures
# =========================


class Example:
    """One input text together with its (pseudo) label and that label's score.

    ``label`` and ``score`` are ``None`` until the example has been labelled.
    """

    def __init__(self, uid: str, text: str, label: Optional[str] = None, score: Optional[float] = None):
        self.uid, self.text = uid, text
        self.label, self.score = label, score

    def to_json(self) -> Dict[str, Any]:
        """Return a JSON-serialisable dict with keys uid, text, label, score."""
        return dict(uid=self.uid, text=self.text, label=self.label, score=self.score)

    @staticmethod
    def from_json(d: Dict[str, Any]) -> "Example":
        """Rebuild an Example from :meth:`to_json` output; label/score may be absent."""
        uid, text = d["uid"], d["text"]
        return Example(uid, text, label=d.get("label"), score=d.get("score"))


class DatasetManager:
    """Owns the training examples and their pseudo labels.

    Raw rows come from a ``datasets.<dataset>`` handler module; if a pseudo
    label cache exists at ``cfg.pseudo_cache_path`` its labels/scores are
    merged in by ``uid``.
    """

    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.examples: List[Example] = self._load_raw(cfg.dataset, cfg.task)
        if cfg.pseudo_cache_path and Path(cfg.pseudo_cache_path).exists():
            self._load_pseudo_cache(cfg.pseudo_cache_path)

    def unlabeled_iter(self):
        """Yield the examples that need (re)labelling under ``cfg.regen_strategy``.

        - ``"all"``: every example.
        - ``"missing"``: examples whose label is ``None``.
        - ``"low_conf"``: unlabelled examples plus those with score
          ``<= cfg.regen_threshold`` (a missing score counts as 0.0).

        At most ``cfg.pseudo_sample_size`` examples are yielded when it is not
        ``None``; they are taken in dataset order.
        """
        strategy, th = self.cfg.regen_strategy, self.cfg.regen_threshold
        limit = self.cfg.pseudo_sample_size
        count = 0
        for ex in self.examples:
            need = (
                strategy == "all"
                or (strategy == "missing" and ex.label is None)
                or (strategy == "low_conf" and (ex.label is None or (ex.score or 0.0) <= th))
            )
            if not need:
                continue
            # Compare against None so an explicit limit of 0 yields nothing
            # rather than being mistaken for "no limit".
            if limit is not None and count >= limit:
                break
            yield ex
            count += 1

    def high_conf_examples(self) -> List[Example]:
        """Return labelled examples whose score is ``>= cfg.gamma`` (missing score = 0.0)."""
        g = self.cfg.gamma
        return [ex for ex in self.examples if ex.label is not None and (ex.score or 0.0) >= g]

    def update_label(self, ex: Example, label: str, score: float):
        """Set ``ex``'s label and score in place."""
        ex.label, ex.score = label, score

    def save_pseudo_cache(self):
        """Write all examples as UTF-8 JSONL to ``cfg.pseudo_cache_path`` (no-op if unset).

        The file is written to a sibling ``.tmp`` file and then swapped in with
        ``os.replace`` so an interrupted run never leaves a truncated cache.
        """
        if not self.cfg.pseudo_cache_path:
            return
        path = Path(self.cfg.pseudo_cache_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        tmp = path.with_name(path.name + ".tmp")
        with tmp.open("w", encoding="utf-8") as f:
            f.writelines(json.dumps(ex.to_json(), ensure_ascii=False) + "\n" for ex in self.examples)
        os.replace(tmp, path)

    def _load_raw(self, dataset: str, task: Optional[str]) -> List[Example]:
        """Load the train split via ``datasets.<dataset>`` and wrap rows as Examples.

        The handler must expose ``load_split`` or ``load_subject``; each row is
        expected to have a ``"text"`` key and optionally ``"label"``. Example
        uids are the row indices as strings.

        Raises:
            ValueError: if the handler module or a supported loader is missing.
        """
        module_name = f"datasets.{dataset}"
        try:
            module = importlib.import_module(module_name)
        except ModuleNotFoundError as e:
            # Only translate "the handler itself does not exist"; if the handler
            # exists but one of *its* imports is missing, surface that error.
            if e.name not in {"datasets", module_name}:
                raise
            raise ValueError(f"No handler for dataset '{dataset}'. Create datasets/{dataset}.py") from e

        if hasattr(module, "load_split"):
            rows = module.load_split(task or "", split="train")
        elif hasattr(module, "load_subject"):
            rows = module.load_subject(task or "", split="train")
        else:
            raise ValueError(f"Handler datasets/{dataset}.py missing expected load_* function")

        return [Example(str(i), row["text"], row.get("label")) for i, row in enumerate(rows)]

    def _load_pseudo_cache(self, cache_path: str):
        """Merge labels/scores from a JSONL cache into matching examples (by uid).

        Read as UTF-8 to match :meth:`save_pseudo_cache`; blank lines and
        entries whose uid is unknown are skipped.
        """
        print(f"[DatasetManager] Loading cached pseudo labels from {cache_path}")
        uid2ex = {ex.uid: ex for ex in self.examples}
        with Path(cache_path).open(encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                d = json.loads(line)
                ex = uid2ex.get(d.get("uid"))
                if ex is not None:
                    ex.label, ex.score = d.get("label"), d.get("score")


# =========================
# Embedding (hashing trick)
# =========================


class HashingEmbedder:
    """Deterministic bag-of-words text embedding via the hashing trick.

    Each lower-cased alphanumeric token is MD5-hashed into one of ``dim``
    buckets; the resulting count vector is L2-normalised (all-zero for text
    with no tokens).
    """

    def __init__(self, dim: int = 768, seed: int = 42):
        """
        Args:
            dim: number of hash buckets (vector length); must be positive.
            seed: recorded only. The embedding uses no randomness, so there is
                nothing to seed, and reseeding the global ``random`` module here
                would silently perturb unrelated code.

        Raises:
            ValueError: if ``dim`` is not positive.
        """
        if dim <= 0:
            raise ValueError(f"dim must be positive, got {dim}")
        self.dim = dim
        self.seed = seed

    def _tokenize(self, s: str) -> List[str]:
        """Lower-case ``s`` and split it on every non-alphanumeric character."""
        return [t for t in "".join(ch.lower() if ch.isalnum() else " " for ch in s).split() if t]

    def __call__(self, s: str):
        """Embed ``s`` as a float32 numpy vector of length ``dim`` with unit L2 norm."""
        import numpy as np

        vec = np.zeros(self.dim, dtype="float32")
        for tok in self._tokenize(s):
            h = int(hashlib.md5(tok.encode()).hexdigest(), 16)
            vec[h % self.dim] += 1.0
        # The epsilon keeps empty/token-less text from dividing by zero.
        n = np.linalg.norm(vec) + 1e-12
        return vec / n


class EmbeddingManager:
    """Brute-force nearest-neighbour index over example texts (dot-product similarity)."""

    def __init__(self, cfg: Config, embed_fn=None):
        """
        Args:
            cfg: supplies ``embed_dim`` and ``seed`` for the default embedder.
            embed_fn: callable mapping text to a 1-D numpy vector; defaults to
                a :class:`HashingEmbedder`.
        """
        self.cfg = cfg
        self.embed_fn = embed_fn or HashingEmbedder(cfg.embed_dim, cfg.seed)
        self._vecs = None  # (n_examples, dim) matrix, or None when the index is empty
        self.examples: List[Example] = []

    def build(self, examples: List[Example]):
        """(Re)build the index from ``examples``; an empty input clears the index."""
        import numpy as np

        self.examples = list(examples)
        if not self.examples:
            # np.vstack raises on an empty sequence; None is treated as "empty" by knn.
            self._vecs = None
            return
        self._vecs = np.vstack([self.embed_fn(ex.text) for ex in self.examples])

    def knn(self, query: str, K: int) -> List[Example]:
        """Return up to ``max(K, 1)`` indexed examples most similar to ``query``.

        Similarity is a plain dot product, i.e. cosine similarity when
        ``embed_fn`` returns unit-norm vectors (the default embedder does).
        Returns ``[]`` if the index is empty.
        """
        import numpy as np

        if self._vecs is None or not self.examples:
            return []
        q = self.embed_fn(query)
        sims = self._vecs @ q
        idx = np.argsort(-sims)[: max(K, 1)]
        return [self.examples[i] for i in idx]


# =========================
# LLM pseudo supervisor
# =========================


class PseudoSupervisor:
    """Produces pseudo labels and confidence scores by self-consistency sampling an LLM.

    Each prompt is sent ``cfg.sc_paths`` times (at least once); the majority
    answer is the label. Confidence is the majority's vote share
    (``confidence_metric == "vote"``) or the mean token probability of the
    agreeing samples (``"logprob"``, falling back to vote share when no
    probabilities were returned). Results are memoised as JSON files under
    ``.upfl_cache/``.
    """

    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.llm = get_provider(cfg.provider, cfg)
        self.cache_dir = Path(".upfl_cache")
        self.cache_dir.mkdir(exist_ok=True)
        self._last_call_ts: float = 0.0

    def generate_labels(self, ds: DatasetManager, prompt: str, builder=None):
        """Label every example selected by ``ds.unlabeled_iter()``, then save the cache.

        Args:
            ds: dataset whose examples are updated in place.
            prompt: template with a ``{text}`` placeholder for the input.
            builder: optional ``(text, K) -> [(text, label), ...]`` few-shot
                context builder; ``None`` means zero-shot.
        """
        for ex in ds.unlabeled_iter():
            if builder is None:
                label, score = self._infer_llm_noctx(prompt, ex.text)
            else:
                Dl = builder(ex.text, self.cfg.K)
                label, score = self._infer_llm_with_ctx(prompt, ex.text, Dl)
            ds.update_label(ex, label, score)
        ds.save_pseudo_cache()

    def _infer_llm_noctx(self, prompt_tpl: str, x: str) -> Tuple[str, float]:
        """Zero-shot (label, confidence) for ``x`` using the rendered template; cached."""
        cache_key = self._hash(prompt_tpl, x, None)
        cached = self._read_cache(cache_key)
        if cached is not None:
            return cached["label"], cached["score"]
        votes, probs = self._multi_samples(self._render(prompt_tpl, x))
        return self._aggregate_and_cache(cache_key, votes, probs)

    def _infer_llm_with_ctx(self, prompt_tpl: str, x: str, Dl: List[Tuple[str, str]]) -> Tuple[str, float]:
        """Few-shot (label, confidence) for ``x``: the ``Dl`` demonstrations, then the rendered template; cached."""
        cache_key = self._hash(prompt_tpl, x, Dl)
        cached = self._read_cache(cache_key)
        if cached is not None:
            return cached["label"], cached["score"]

        shots = "\n\n".join(f"Example:\nInput: {tx}\nLabel: {lb}" for tx, lb in Dl) if Dl else ""
        # The template is what TextGrad optimises, so it must be part of the
        # request; otherwise prompt optimisation cannot affect few-shot labelling.
        full_prompt = f"{shots}\n\n{self._render(prompt_tpl, x)}".strip()
        votes, probs = self._multi_samples(full_prompt)
        return self._aggregate_and_cache(cache_key, votes, probs)

    @staticmethod
    def _render(prompt_tpl: str, x: str) -> str:
        """Substitute ``x`` into the template's ``{text}`` placeholder.

        Optimised templates are free text, so other brace groups that
        ``str.format`` rejects fall back to literal replacement, and a template
        that lost its placeholder gets the input appended rather than silently
        dropped.
        """
        if "{text}" not in prompt_tpl:
            return f"{prompt_tpl}\n\nInput: {x}"
        try:
            return prompt_tpl.format(text=x)
        except (KeyError, IndexError, ValueError):
            return prompt_tpl.replace("{text}", x)

    def _aggregate_and_cache(self, cache_key: str, votes: List[str], probs: List[Optional[float]]) -> Tuple[str, float]:
        """Majority-vote ``votes`` into (label, confidence) and memoise it under ``cache_key``."""
        maj_label, maj_cnt = Counter(votes).most_common(1)[0]
        vote_share = maj_cnt / len(votes)
        if self.cfg.confidence_metric == "vote":
            confidence = vote_share
        else:
            agree = [p for lbl, p in zip(votes, probs) if lbl == maj_label and p is not None]
            confidence = sum(agree) / len(agree) if agree else vote_share
        self._write_cache(cache_key, {"label": maj_label, "score": confidence})
        return maj_label, confidence

    def _multi_samples(self, full_prompt: str) -> Tuple[List[str], List[Optional[float]]]:
        """Query the LLM ``max(1, cfg.sc_paths)`` times; return parallel label/probability lists."""
        votes, probs = [], []
        n = max(1, self.cfg.sc_paths)
        for _ in range(n):
            resp = self._safe_chat(full_prompt)
            lbl, p = self._parse(resp, allow_prob_missing=True)
            votes.append(lbl)
            probs.append(p)
        return votes, probs

    def _safe_chat(self, prompt: str) -> dict:
        """Call the provider with client-side rate limiting and exponential-backoff retries.

        Raises:
            RuntimeError: after 6 failed attempts, chained to the last provider error.
        """
        rps = self.cfg.rate_limit_rps
        if rps > 0:  # a non-positive rate disables throttling instead of dividing by zero
            wait = 1.0 / rps - (time.time() - self._last_call_ts)
            if wait > 0:
                time.sleep(wait)
        attempts = 6
        last_err: Optional[BaseException] = None
        for attempt in range(attempts):
            try:
                resp = self.llm.chat(prompt, temperature=max(self.cfg.temperature, 0.05))
            except Exception as e:  # provider error types are not known here; retry any of them
                last_err = e
                if attempt == attempts - 1:
                    break  # no point sleeping before giving up
                backoff = 2**attempt
                print(f"[warn] LLM error ({e}); retry in {backoff}s")
                time.sleep(backoff)
            else:
                self._last_call_ts = time.time()
                return resp
        raise RuntimeError("LLM backend failed too many times") from last_err

    def _parse(self, resp: dict, allow_prob_missing: bool = False) -> Tuple[str, Optional[float]]:
        """Extract (label, probability) from a provider response dict.

        The label is the first whitespace-separated word of ``resp["content"]``,
        lower-cased and stripped of surrounding punctuation (so "Positive." and
        "positive" are the same vote), or ``"unknown"`` if nothing remains. The
        probability is ``exp(logprob)`` of the first ``resp["logprobs"]`` entry
        whose token normalises to the label.

        Raises:
            ValueError: if no probability was found and ``allow_prob_missing`` is False.
        """
        import string

        raw = (resp.get("content") or "").strip()
        label = raw.split()[0].lower().strip(string.punctuation) if raw else ""
        label = label or "unknown"
        for t in resp.get("logprobs") or ():
            tok = (t.get("token") or "").strip().lower().strip(string.punctuation)
            if tok == label:
                try:
                    return label, math.exp(float(t.get("logprob")))
                except (TypeError, ValueError, OverflowError):
                    break
        if allow_prob_missing:
            return label, None
        raise ValueError("No token logprobs returned and fallback disabled")

    def _hash(self, prompt: str, x: str, Dl: Optional[List[Tuple[str, str]]]) -> str:
        """Cache key covering the prompt, input, few-shot context, and every
        config setting that changes the answer or how its score is computed."""
        ctx = "" if not Dl else json.dumps(Dl, ensure_ascii=False)
        settings = json.dumps(
            [
                self.cfg.provider,
                self.cfg.model_name,
                self.cfg.temperature,
                self.cfg.sc_paths,
                self.cfg.confidence_metric,
            ]
        )
        return hashlib.md5(f"{prompt}§{x}§{ctx}§{settings}".encode("utf-8")).hexdigest()[:16]

    def _read_cache(self, key: str) -> Optional[dict]:
        """Return the cached ``{"label", "score"}`` dict for ``key``, or None if absent/unreadable."""
        fp = self.cache_dir / f"{key}.json"
        if not fp.exists():
            return None
        try:
            v = json.loads(fp.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            return None
        return v if isinstance(v, dict) and "label" in v and "score" in v else None

    def _write_cache(self, key: str, obj: dict):
        """Store ``obj`` as UTF-8 JSON under ``key``."""
        (self.cache_dir / f"{key}.json").write_text(json.dumps(obj, ensure_ascii=False), encoding="utf-8")


# =========================
# TextGrad optimizer
# =========================


class TextGradOptimizer:
    """Refines a classification prompt template with TextGrad's TGD optimiser."""

    def __init__(self, cfg_or_model_name="gpt-4o", feedback_model_name="gpt-4o", cache=True):
        """Create the forward engine and register the backward (feedback) engine.

        Args:
            cfg_or_model_name: a model name string, or an object exposing a
                ``model_name`` attribute (such as ``Config``).
            feedback_model_name: engine name passed to ``tg.set_backward_engine``.
            cache: forwarded to both TextGrad engine calls.
        """
        if isinstance(cfg_or_model_name, str):
            model_name = cfg_or_model_name
        else:
            model_name = cfg_or_model_name.model_name
        self.model_engine = tg.get_engine(model_name, cache=cache)
        tg.set_backward_engine(feedback_model_name, override=True, cache=cache)

    def step(self, ds: DatasetManager, prompt_tpl: str, build_context_fn, steps: int = 3, K: int = 4) -> str:
        """Run ``steps`` TGD updates on ``prompt_tpl`` and return the refined template.

        The loss is evaluated against a usage context built from up to 64
        randomly chosen high-confidence examples: each shows the few-shot
        demonstrations from ``build_context_fn(text, K)``, the query, and the
        pseudo label as target.
        """
        eval_template = (
            "Optimize the classification prompt so that, when used with the same few-shot high-confidence examples "
            "at inference time, the model's predicted label matches the pseudo supervision. "
            "Penalize inconsistencies between few-shot usage and a no-shot baseline. "
            "Keep the prompt concise and ensure the model outputs the label first."
        )
        var = tg.Variable(prompt_tpl, requires_grad=True, role_description="prompt template for classification")
        optimizer = tg.TGD(parameters=[var])
        loss_fn = tg.TextLoss(eval_template)

        # The usage context depends only on the dataset, which this method does
        # not modify, so build it once instead of once per step (each
        # build_context_fn call may be expensive).
        ctx = self._usage_context(ds, build_context_fn, K)

        for _ in range(steps):
            # Drop feedback from the previous step so textual gradients don't accumulate.
            optimizer.zero_grad()
            # NOTE(review): assumes the installed textgrad Variable provides
            # ``with_context``; confirm against the pinned textgrad version.
            loss = loss_fn(var.with_context(ctx))
            loss.backward()
            optimizer.step()

        return var.value

    @staticmethod
    def _usage_context(ds: DatasetManager, build_context_fn, K: int, max_examples: int = 64) -> str:
        """Render up to ``max_examples`` shuffled high-confidence examples as loss context."""
        high_conf = ds.high_conf_examples()
        random.shuffle(high_conf)
        chunks = []
        for ex in high_conf[:max_examples]:
            shots = "\n".join(f"Input: {t}\nLabel: {y}" for t, y in build_context_fn(ex.text, K))
            chunks.append(f"[USAGE]\n{shots}\nQuery: {ex.text}\nTarget: {ex.label}\n[BASELINE] No-shot should not contradict.")
        return "\n\n".join(chunks) if chunks else "No examples."


# =========================
# Trainer
# =========================


class UPFLTrainer:
    """Alternates LLM pseudo-labelling with TextGrad prompt optimisation."""

    def __init__(self, cfg: Config):
        self.cfg = cfg
        self.ds = DatasetManager(cfg)
        self.ps = PseudoSupervisor(cfg)
        self.opt = TextGradOptimizer(cfg)
        self.prompt = "{text} →"
        self.emb = EmbeddingManager(cfg)
        # uids of the examples currently indexed in ``self.emb`` (None = not built yet).
        self._ctx_pool_key: Optional[Tuple[str, ...]] = None

    def build_context(self, x: str, K: int) -> List[Tuple[str, str]]:
        """Return up to ``K`` (text, label) demonstrations most similar to ``x``.

        Candidates are labelled examples with score ``>= cfg.gamma``. Examples
        whose text equals ``x`` are excluded so a query never sees its own
        pseudo label as a demonstration. The embedding index is rebuilt only
        when the candidate set changes instead of on every call.
        """
        if K <= 0:
            return []
        g = self.cfg.gamma
        pool = [ex for ex in self.ds.examples if ex.label is not None and (ex.score or 0.0) >= g]
        if not pool:
            return []
        key = tuple(ex.uid for ex in pool)
        if key != self._ctx_pool_key:
            self.emb.build(pool)
            self._ctx_pool_key = key
        # Ask for one extra neighbour in case the query itself is among the hits.
        nn = self.emb.knn(x, K + 1)
        return [(ex.text, ex.label) for ex in nn if ex.text != x][:K]

    def run(self, rounds: int = 3):
        """Run ``rounds`` rounds of (optional) context-aware relabelling + prompt optimisation.

        Initial zero-shot labels are generated unless ``cfg.reuse_pseudo`` is
        set and some example already has a label.
        """
        random.seed(self.cfg.seed)

        if not (self.cfg.reuse_pseudo and any(ex.label for ex in self.ds.examples)):
            self.ps.generate_labels(self.ds, self.prompt)

        for r in range(rounds):
            print(f"===== Round {r+1}/{rounds} =====")

            if self.cfg.regen_strategy != "missing":
                # unlabeled_iter applies the strategy ("all" vs "low_conf" with
                # its threshold) and pseudo_sample_size, as in the initial pass.
                for ex in self.ds.unlabeled_iter():
                    Dl = self.build_context(ex.text, self.cfg.K)
                    if not Dl:
                        continue
                    label, score = self.ps._infer_llm_with_ctx(self.prompt, ex.text, Dl)
                    self.ds.update_label(ex, label, score)
                self.ds.save_pseudo_cache()

            self.prompt = self.opt.step(self.ds, self.prompt, self.build_context, steps=3, K=self.cfg.K)
            self._evaluate()

    def _evaluate(self):
        """Print the fraction of examples currently labelled "positive"."""
        total = len(self.ds.examples) or 1
        pos = sum(1 for e in self.ds.examples if e.label == "positive")
        print(f"[Eval] positive ratio: {pos/total:.2f}\n")


# =========================
# Entry
# =========================


def main():
    """CLI entry point: parse arguments, build the trainer and run it with default rounds."""
    UPFLTrainer(Config.from_args()).run()


if __name__ == "__main__":
    main()
