#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
LGCD (LoRA-Gated Contrastive Decoding) — lm-eval adapter for your QA-style code.

- Reuses your LoRA extraction (knowledge model → LoRA deltas) and QA forward logic:
  * keep only language_model + lora_params
  * on low confidence: approximate knowledge logits by temporarily patching weights
- Implements lm-eval LM API: generate_until, loglikelihood, loglikelihood_rolling
- Memory-conscious: knowledge model freed right after LoRA extraction
"""

import os
import gc
import json
import time
import logging
import warnings
import traceback
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn.functional as F
from collections import defaultdict
from transformers import AutoModelForCausalLM, AutoTokenizer

# lm-eval
from lm_eval.api.model import LM
from lm_eval import evaluator

warnings.filterwarnings("ignore")

# ---------------------------------------------------------------------
# Logging & determinism
# ---------------------------------------------------------------------
# Configure logging once at import time: every record at INFO or above is
# written both to the file "lgcd_lmeval_from_qa.log" (relative path, so it is
# created in the current working directory) and to a stream handler.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.FileHandler("lgcd_lmeval_from_qa.log"), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# Seed the torch and NumPy RNGs with a fixed value so that the sampling done
# later in this module (torch.multinomial) starts from a known RNG state.
torch.manual_seed(42)
np.random.seed(42)

# ---------------------------------------------------------------------
# Device
# ---------------------------------------------------------------------
def _resolve_device(requested_gpu: int) -> str:
    """Return the torch device string for the requested CUDA index.

    Falls back to ``"cpu"`` when CUDA is unavailable and to ``"cuda:0"`` when
    the requested index is not one of the visible devices, so that a stale
    default or environment value does not produce an invalid device string.
    """
    if not torch.cuda.is_available():
        return "cpu"
    visible = torch.cuda.device_count()
    if 0 <= requested_gpu < visible:
        return f"cuda:{requested_gpu}"
    logging.getLogger(__name__).warning(
        "LGCD_GPU_ID=%d is not a visible CUDA device (%d available); using cuda:0",
        requested_gpu,
        visible,
    )
    return "cuda:0"


# Requested GPU index, read from the LGCD_GPU_ID environment variable (default 2).
GPU_ID = int(os.environ.get("LGCD_GPU_ID", "2"))
# Device string actually used; validated against the visible CUDA devices.
DEVICE = _resolve_device(GPU_ID)
# bfloat16 on GPU, float32 on CPU.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32


def clear_gpu_memory():
    """Run the garbage collector and, when CUDA is available, release cached GPU memory."""
    gc.collect()
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        # Same two calls, same order: drop the caching allocator's free
        # blocks, then collect CUDA IPC handles.
        for release in (torch.cuda.empty_cache, torch.cuda.ipc_collect):
            release()
    logger.info("GPU memory cleared")



@dataclass
class LGCDConfig:
    """Settings shared by LoRA extraction, confidence gating and decoding."""

    # Upper bound on the rank kept from each weight delta's SVD.
    lora_rank: int = 32
    # The low-rank delta is scaled by lora_alpha / rank when it is applied.
    lora_alpha: float = 32
    # Not referenced by the code in this module.
    lora_dropout: float = 0.1

    # Which slice of the transformer layers (ordered by layer number) to extract deltas for.
    layer_group: str = "all"   # 'all'|'lower'|'middle'|'upper'

    # Fixed gate: base logits are used unchanged when confidence exceeds this value.
    # Also the fallback while the adaptive threshold has fewer than 10 samples.
    confidence_threshold: float = 0.7
    confidence_method: str = "max_prob"  # 'entropy'|'max_prob'|'variance'
    # When True, the threshold is derived from the mean/std of the last 100
    # confidence values, bounded to [0.1, 0.9].
    adaptive_threshold: bool = False

    # Low-confidence blend applied on the base top-k tokens:
    # base + contrastive_beta * knowledge - contrastive_alpha * base.
    contrastive_alpha: float = 0.1
    contrastive_beta: float = 1.0

    # Number of highest base logits kept in contrastive mode / base mode;
    # every other token's logit is set to -inf.
    contrastive_top_k: int = 100
    generation_top_k: int = 100

    # Maximum number of decoding steps per generate_until request.
    max_length: int = 512
    # Logits are divided by this value before sampling.
    generation_temperature: float = 0.7
    # Nucleus (top-p) cutoff; filtering is skipped when this is >= 1.0.
    top_p: float = 0.9

    # Device on which LoRA factors and tokenized inputs are placed.
    device: str = DEVICE


class LoRAExtractor:
    """Approximate (knowledge - base) weight deltas with low-rank factors.

    Only 2-D tensors whose names contain ``gate_proj``, ``up_proj`` or
    ``down_proj`` and that exist in both state dicts are considered. Each
    delta is factorised by a truncated SVD into ``lora_A @ lora_B``.
    """

    def __init__(self, config: LGCDConfig):
        self.config = config
        # Parameter name -> dict of factors/metadata; filled by
        # extract_lora_from_knowledge().
        self.lora_params = {}

    @staticmethod
    def _layer_index(name: str) -> Optional[int]:
        """Return the integer that follows a ``layers`` component in *name*, if any."""
        parts = name.split(".")
        for i, part in enumerate(parts[:-1]):
            if part != "layers":
                continue
            try:
                return int(parts[i + 1])
            except ValueError:
                continue
        return None

    def get_layer_groups(self, layer_names: List[str]) -> List[str]:
        """Select the parameter names that belong to ``config.layer_group``.

        Names are ordered by layer number and the distinct layer numbers are
        split into thirds ('lower', 'middle', 'upper'); any other group value
        selects everything. Splitting is done on layer numbers rather than on
        the flat name list so that all projections of one layer always end up
        in the same group.

        Args:
            layer_names: Candidate parameter names.

        Returns:
            The selected names sorted by layer number. If no name contains a
            ``layers.<int>`` component, *layer_names* is returned unchanged.
        """
        indexed = []
        for name in layer_names:
            idx = self._layer_index(name)
            if idx is not None:
                indexed.append((idx, name))
        if not indexed:
            logger.warning("No layer numbers found, using all layers")
            return layer_names

        # Stable sort: names of the same layer keep their incoming order.
        indexed.sort(key=lambda pair: pair[0])
        layer_ids = sorted({idx for idx, _ in indexed})
        third = max(1, len(layer_ids) // 3)

        g = self.config.layer_group
        if g == "lower":
            keep = set(layer_ids[:third])
        elif g == "middle":
            keep = set(layer_ids[third:2 * third])
        elif g == "upper":
            keep = set(layer_ids[2 * third:])
        else:
            keep = set(layer_ids)

        sel = [name for idx, name in indexed if idx in keep]
        logger.info(
            f"Selected layer group '{g}' with {len(keep)} layers ({len(sel)} tensors)"
        )
        return sel

    def extract_lora_from_knowledge(self, language_model, knowledge_model) -> Dict:
        """Factorise the selected weight deltas and return them keyed by parameter name.

        For each selected 2-D tensor the delta ``knowledge - base`` is computed
        in float32, decomposed with ``torch.linalg.svd`` and truncated to at
        most ``config.lora_rank`` components. The singular values are split
        evenly between the two factors so that ``lora_A @ lora_B`` is the
        truncated-SVD approximation of the delta.

        Args:
            language_model: Base model; only ``state_dict()`` is used.
            knowledge_model: Knowledge model; only ``state_dict()`` is used.

        Returns:
            Dict mapping parameter name to a dict with keys ``lora_A``,
            ``lora_B``, ``scaling``, ``original_shape``, ``rank`` and
            ``explained_variance``. Also stored on ``self.lora_params``.
        """
        logger.info("Extracting LoRA params (knowledge - base)...")
        lora_params = {}
        ranks, explained = [], []

        base_sd = language_model.state_dict()
        know_sd = knowledge_model.state_dict()

        targets = ("gate_proj", "up_proj", "down_proj")
        candidates = [
            n for n in know_sd
            if n in base_sd and any(t in n for t in targets)
        ]
        selected = set(self.get_layer_groups(candidates))

        for name in know_sd:
            if name not in selected:
                continue
            know_w, base_w = know_sd[name], base_sd[name]
            if know_w.dim() != 2:
                continue
            if know_w.shape != base_w.shape:
                # Subtracting would either raise or silently broadcast.
                logger.warning(
                    f"Shape mismatch for {name}: "
                    f"{tuple(know_w.shape)} vs {tuple(base_w.shape)}; skipped"
                )
                continue
            try:
                # Subtract in float32: the difference of two low-precision
                # weights loses most of its significant bits otherwise.
                delta = know_w.float() - base_w.to(know_w.device).float()
                U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
                r = min(self.config.lora_rank, min(delta.shape))
                if r <= 0:
                    continue
                sqrt_s = torch.sqrt(S[:r])
                # Column/row scaling is equivalent to multiplying by diag(sqrt_s).
                A = U[:, :r] * sqrt_s
                B = sqrt_s.unsqueeze(1) * Vh[:r, :]
                # Fraction of the delta's squared Frobenius norm captured by
                # the kept components.
                energy = S.square()
                captured = float(energy[:r].sum() / (energy.sum() + 1e-12))
                lora_params[name] = {
                    # Stored in the base weight's dtype because the product is
                    # later added to that weight.
                    "lora_A": A.to(base_w.dtype).to(self.config.device),
                    "lora_B": B.to(base_w.dtype).to(self.config.device),
                    "scaling": self.config.lora_alpha / r,
                    "original_shape": tuple(delta.shape),
                    "rank": int(r),
                    "explained_variance": captured,
                }
                ranks.append(r)
                explained.append(captured)
            except Exception as e:
                # Best effort: one failing tensor must not abort the extraction.
                logger.warning(f"SVD failed for {name}: {e}")

        if explained:
            logger.info(
                f"LoRA extracted: {len(lora_params)} layers | "
                f"avg rank={np.mean(ranks):.1f}, "
                f"avg explained={np.mean(explained):.3f}"
            )
        else:
            logger.warning("No LoRA parameters were extracted")
        self.lora_params = lora_params
        return lora_params


class ConfidenceEstimator:
    """Turn next-token logits into a confidence score and supply the gate threshold.

    Every method returns a score that grows as the distribution becomes more
    peaked, so a single "confidence > threshold" test works for all of them.
    """

    def __init__(self, config: "LGCDConfig"):
        self.config = config
        # Recent confidence values; only used when adaptive_threshold is on.
        self.hist: List[float] = []

    def compute(self, logits: torch.Tensor) -> torch.Tensor:
        """Return a confidence score per row of *logits* (last dim = vocabulary).

        Methods (``config.confidence_method``):
          * ``entropy``  - 1 - H(p) / log(V): 0 for uniform, 1 for one-hot.
          * ``max_prob`` - largest softmax probability.
          * ``variance`` - variance of the probabilities scaled by V, which is
            0 for a uniform distribution and 1 for a one-hot one.

        The softmax is evaluated in float32 regardless of the input dtype so
        that sums over a large vocabulary are not dominated by rounding.

        Raises:
            ValueError: If the configured method is not one of the above.
        """
        m = self.config.confidence_method
        if m not in ("entropy", "max_prob", "variance"):
            raise ValueError(f"Unknown confidence method: {m}")

        probs = F.softmax(logits.float(), dim=-1)
        if m == "max_prob":
            return probs.max(dim=-1)[0]

        vocab = probs.size(-1)
        if vocab < 2:
            # A single-token vocabulary is fully determined; this also avoids
            # 0/0 in the entropy ratio and NaN from the unbiased variance.
            return torch.ones_like(probs[..., 0])

        if m == "entropy":
            ent = -(probs * probs.clamp_min(1e-8).log()).sum(dim=-1)
            return 1.0 - ent / float(np.log(vocab))

        # "variance": the unbiased variance of a probability vector ranges
        # from 0 (uniform) to 1/V (one-hot), so multiplying by V maps it to [0, 1].
        return (probs.var(dim=-1) * vocab).clamp(0.0, 1.0)

    def threshold(self, current: float) -> float:
        """Return the gate threshold, recording *current* when adaptive mode is on.

        With ``adaptive_threshold`` disabled the configured constant is
        returned. Otherwise *current* is appended to a 100-entry history and,
        once at least 10 values are available, the threshold is
        ``mean - 0.5 * std`` of that history bounded to [0.1, 0.9].
        """
        if not self.config.adaptive_threshold:
            return self.config.confidence_threshold
        self.hist.append(current)
        self.hist = self.hist[-100:]
        if len(self.hist) < 10:
            return self.config.confidence_threshold
        mean, std = float(np.mean(self.hist)), float(np.std(self.hist))
        thr = max(0.1, mean - 0.5 * std)
        return min(0.9, thr)


# ---------------------------------------------------------------------
# lm-eval wrapper
# ---------------------------------------------------------------------
class LGCDHarnessLM(LM):
    """
    lm-eval ``LM`` adapter around a base (language) model plus LoRA deltas.

    Uses:
      - base (language) model
      - lora_params approximating knowledge deltas
    On low confidence:
      - temporarily patches weights with (scale * A @ B), runs forward to get
        approximated "knowledge logits" and blends contrastively.

    All requests are processed one at a time; every decoding/scoring step is a
    full forward pass over the current prefix (no KV cache).

    NOTE(review): both decoding modes set every logit outside the base top-k
    to -inf, so a scored target token outside that set gets log-probability
    -inf in loglikelihood()/loglikelihood_rolling().
    """

    def __init__(
        self,
        language_model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        lora_params: Dict[str, Dict],
        cfg: LGCDConfig,
        terminators: Optional[List[int]] = None,
    ):
        super().__init__()
        self.base_model = language_model
        self.tok = tokenizer
        self.lora_params = lora_params
        self.cfg = cfg
        self.conf = ConfidenceEstimator(cfg)
        candidates = terminators or [self.tok.eos_token_id]
        # A tokenizer without an EOS token yields None, which can never match
        # a sampled id; drop it instead of carrying it around.
        self.terminators = [t for t in candidates if t is not None]

    # lm-eval shims
    @property
    def tokenizer(self):
        return self.tok

    @property
    def model(self):
        return self.base_model

    @property
    def vocab_size(self):
        return len(self.tok) if hasattr(self.tok, "__len__") else self.tok.vocab_size

    def tok_encode(self, s: str, **kwargs):
        return self.tok.encode(s, **kwargs)

    def tok_decode(self, ids):
        return self.tok.decode(ids)

    # ---- helpers from your QA model ----
    def _topk_mask(self, logits: torch.Tensor, k: int) -> torch.Tensor:
        """Return a boolean mask that is True at the k largest entries of each row."""
        k = max(1, min(k, logits.size(-1)))
        _, idx = torch.topk(logits, k, dim=-1)
        mask = torch.zeros_like(logits, dtype=torch.bool)
        mask.scatter_(dim=-1, index=idx, src=torch.ones_like(idx, dtype=torch.bool))
        return mask

    def _apply_lora_delta_to_weight(self, weight: torch.Tensor, name: str) -> torch.Tensor:
        """Return ``weight + scaling * (A @ B)`` as a new tensor (``weight`` is not modified).

        Returns *weight* itself when no LoRA entry exists for *name*.

        Raises:
            ValueError: If the reconstructed delta does not have the weight's shape
                (adding it would fail or silently broadcast).
        """
        info = self.lora_params.get(name)
        if info is None:
            return weight
        A, B, scale = info["lora_A"], info["lora_B"], info["scaling"]
        delta = scale * (A @ B)
        if tuple(delta.shape) != tuple(weight.shape):
            raise ValueError(
                f"LoRA delta for {name} has shape {tuple(delta.shape)}, "
                f"expected {tuple(weight.shape)}"
            )
        return weight + delta.to(weight.dtype)

    def _get_module_for_name(self, model, name: str):
        """Resolve a dotted parameter name (optionally ending in ``.weight``) to its module.

        Returns None when any component of the path is missing.
        """
        parts = name.split(".")
        if parts[-1] == "weight":
            parts = parts[:-1]
        mod = model
        for p in parts:
            if not hasattr(mod, p):
                return None
            mod = getattr(mod, p)
        return mod

    @torch.no_grad()
    def _approx_knowledge_logits(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Temporarily patch selected modules' weights, forward once, then restore.

        The patched weight is a freshly allocated tensor, so the original
        tensor is kept by reference (no clone) and simply re-attached in the
        ``finally`` block; the restore is therefore exact.

        Returns:
            Logits for the last position, shape ``out.logits[:, -1, :]``.
        """
        restore = []
        try:
            for lname in self.lora_params:
                module = self._get_module_for_name(self.base_model, lname)
                if module is None or not hasattr(module, "weight"):
                    continue
                original = module.weight.data
                patched = self._apply_lora_delta_to_weight(original, lname)
                # Record before swapping so a later failure still restores it.
                restore.append((module, original))
                module.weight.data = patched
            out = self.base_model(input_ids)
            return out.logits[:, -1, :]
        finally:
            for module, original in restore:
                module.weight.data = original

    @torch.no_grad()
    def _lgcd_step(self, input_ids: torch.Tensor) -> Tuple[torch.Tensor, str]:
        """Compute gated next-token logits for the prefix *input_ids*.

        Returns:
            ``(logits, mode)`` where mode is ``"base"`` (confidence above the
            threshold; base logits restricted to ``generation_top_k``) or
            ``"contrastive"`` (base top-``contrastive_top_k`` logits plus
            ``beta * knowledge - alpha * base``). Tokens outside the top-k are
            set to -inf in both modes.
        """
        # Hidden states are not used, so they are not requested.
        base_logits = self.base_model(input_ids).logits[:, -1, :]
        # The confidence is averaged over the batch dimension into one scalar.
        base_conf = self.conf.compute(base_logits).mean().item()
        thr = self.conf.threshold(base_conf)

        if base_conf > thr:
            mask = self._topk_mask(base_logits, self.cfg.generation_top_k)
            masked = base_logits.clone()
            masked[~mask] = float("-inf")
            return masked, "base"

        # Low confidence: compute approximated "knowledge logits" via temp patch
        know_logits = self._approx_knowledge_logits(input_ids)

        # Contrastive blend on base top-k only
        mask = self._topk_mask(base_logits, self.cfg.contrastive_top_k)
        delta = (self.cfg.contrastive_beta * know_logits) - (self.cfg.contrastive_alpha * base_logits)
        blended = base_logits.clone()
        blended[mask] += delta[mask]
        blended[~mask] = float("-inf")
        return blended, "contrastive"

    @staticmethod
    def _parse_generation_request(req) -> Tuple[str, List[str], Optional[int]]:
        """Extract ``(context, stop_strings, max_gen_toks)`` from a request.

        lm-eval passes ``(context, gen_kwargs)`` where ``gen_kwargs`` is a dict
        with an optional ``"until"`` entry (string or list of strings) and an
        optional ``"max_gen_toks"``. A bare list/string of stop sequences is
        accepted as well.
        """
        if not hasattr(req, "args"):
            return str(req), [], None
        context = req.args[0]
        spec = req.args[1] if len(req.args) > 1 else None

        max_new = None
        if isinstance(spec, dict):
            until = spec.get("until") or []
            raw_max = spec.get("max_gen_toks")
            if raw_max is not None:
                max_new = int(raw_max)
        else:
            until = spec or []
        if isinstance(until, str):
            until = [until]
        # Empty stop strings match everywhere and make str.split raise.
        stops = [s for s in until if isinstance(s, str) and s]
        return context, stops, max_new

    def _sample_next(self, logits: torch.Tensor) -> torch.Tensor:
        """Pick the next token id (shape ``[batch, 1]``) from gated logits.

        Applies temperature and nucleus (top-p) filtering, then samples. A
        non-positive temperature selects the arg-max instead of dividing by zero.
        """
        temperature = self.cfg.generation_temperature
        if temperature <= 0:
            return logits.argmax(dim=-1, keepdim=True)

        probs = F.softmax(logits.float() / temperature, dim=-1)
        if self.cfg.top_p < 1.0:
            sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
            cum = torch.cumsum(sorted_probs, dim=-1)
            remove = cum > self.cfg.top_p
            # Shift right so the token that crosses the cutoff is still kept.
            remove[..., 1:] = remove[..., :-1].clone()
            remove[..., 0] = False
            kept = sorted_probs.masked_fill(remove, 0.0)
            probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, kept)
            probs = probs / probs.sum(dim=-1, keepdim=True)
        return torch.multinomial(probs, 1)

    def _score_tokens(self, ids: torch.Tensor, start: int) -> Tuple[float, bool]:
        """Score ``ids[0, start:]`` token by token, each conditioned on its prefix.

        Returns:
            ``(total_logprob, is_greedy)`` where ``is_greedy`` is True only if
            every scored token is the arg-max of its gated logits.
        """
        total, greedy = 0.0, True
        for i in range(start, ids.size(1)):
            target = ids[0, i].item()
            logits, _ = self._lgcd_step(ids[:, :i])
            total += F.log_softmax(logits.float(), dim=-1)[0, target].item()
            if greedy and logits[0].argmax().item() != target:
                greedy = False
        return total, greedy

    # ---- lm-eval API ----
    def generate_until(self, requests):
        """Generate a continuation for each request and return the decoded strings.

        Decoding stops at a terminator token, as soon as a stop string appears
        in the decoded continuation, or after the step budget
        (``cfg.max_length``, further limited by the request's ``max_gen_toks``
        when given). The returned text is cut before the first stop string.
        A request that raises is logged and answered with ``""``.
        """
        outs = []
        for req in requests:
            try:
                context, stops, max_new = self._parse_generation_request(req)
                budget = self.cfg.max_length
                if max_new is not None and max_new > 0:
                    budget = min(budget, max_new)

                ids = self.tok.encode(context, return_tensors="pt").to(self.cfg.device)
                prompt_len = ids.size(1)
                gen = ids.clone()
                stat = {"base": 0, "contrastive": 0}

                for _ in range(budget):
                    logits, mode = self._lgcd_step(gen)
                    stat[mode] += 1
                    next_id = self._sample_next(logits)
                    gen = torch.cat([gen, next_id], dim=1)
                    if next_id.item() in self.terminators:
                        break
                    if stops:
                        # Everything after a stop string is discarded below,
                        # so further forward passes would be wasted.
                        tail = self.tok.decode(gen[0, prompt_len:], skip_special_tokens=True)
                        if any(s in tail for s in stops):
                            break

                text = self.tok.decode(gen[0, prompt_len:], skip_special_tokens=True) if gen.size(1) > prompt_len else ""
                for s in stops:
                    if s in text:
                        text = text.split(s, 1)[0]

                logger.info(f"generate_until: base={stat['base']} contrastive={stat['contrastive']}")
                outs.append(text)
            except Exception as e:
                logger.error(f"generate_until error: {e}")
                logger.error(traceback.format_exc())
                outs.append("")
        return outs

    def loglikelihood(self, requests):
        """Return ``(logprob, is_greedy)`` of each continuation given its context.

        The continuation tokens are the tokens of ``context + continuation``
        that follow the tokenized context. An empty continuation scores
        ``(0.0, False)``. A request that raises is logged and scored
        ``(-inf, False)`` so that a failure can never outrank a real score.
        """
        res = []
        for req in requests:
            try:
                ctx, cont = req.args
                ids_full = self.tok.encode(ctx + cont, return_tensors="pt").to(self.cfg.device)
                ids_ctx = self.tok.encode(ctx, return_tensors="pt").to(self.cfg.device)
                Lc = ids_ctx.size(1)
                if Lc >= ids_full.size(1):
                    res.append((0.0, False))
                    continue
                res.append(self._score_tokens(ids_full, Lc))
            except Exception as e:
                logger.error(f"loglikelihood error: {e}")
                logger.error(traceback.format_exc())
                res.append((float("-inf"), False))
        return res

    def loglikelihood_rolling(self, requests):
        """Return the summed log-probability of each request's full text as a float.

        Every token after the first is scored given all preceding tokens.
        Texts of at most one token score 0.0; a request that raises is logged
        and scored ``-inf``.
        """
        res = []
        for req in requests:
            try:
                ctx = req.args[0]
                ids = self.tok.encode(ctx, return_tensors="pt").to(self.cfg.device)
                if ids.size(1) <= 1:
                    res.append(0.0)
                    continue
                total, _ = self._score_tokens(ids, 1)
                res.append(total)
            except Exception as e:
                logger.error(f"loglikelihood_rolling error: {e}")
                logger.error(traceback.format_exc())
                res.append(float("-inf"))
        return res


# ---------------------------------------------------------------------
# Loading & Runner (pairs → lm-eval)
# ---------------------------------------------------------------------
def load_and_extract_lora(language_model_id: str, knowledge_model_id: str, cfg: LGCDConfig):
    """Load the base model, extract LoRA deltas from the knowledge model, then free it.

    Both models are placed on ``cfg.device`` (the same device the extracted
    LoRA factors and the tokenized inputs use). The knowledge model is only
    needed for the extraction and is released afterwards, including when the
    extraction fails.

    Args:
        language_model_id: Identifier passed to ``from_pretrained`` for the base model
            and its tokenizer.
        knowledge_model_id: Identifier passed to ``from_pretrained`` for the knowledge model.
        cfg: Extraction settings.

    Returns:
        ``(tokenizer, base_model, lora_params)``.
    """
    logger.info("Loading tokenizers...")
    # Only the base tokenizer is used downstream; the knowledge model
    # contributes weights only.
    base_tok = AutoTokenizer.from_pretrained(language_model_id)
    if base_tok.pad_token is None:
        base_tok.pad_token = base_tok.eos_token

    logger.info("Loading models...")
    lang_model = AutoModelForCausalLM.from_pretrained(language_model_id, torch_dtype=DTYPE).to(cfg.device)
    lang_model.eval()

    know_model = None
    try:
        know_model = AutoModelForCausalLM.from_pretrained(knowledge_model_id, torch_dtype=DTYPE).to(cfg.device)
        know_model.eval()

        logger.info("Extracting LoRA deltas & freeing knowledge model...")
        extractor = LoRAExtractor(cfg)
        lora_params = extractor.extract_lora_from_knowledge(lang_model, know_model)
    finally:
        # Drop the only local reference before asking the allocator to
        # release its cache, on success and on failure alike.
        del know_model
        clear_gpu_memory()

    return base_tok, lang_model, lora_params


def _json_default(obj):
    """Fallback encoder for values ``json`` cannot serialise natively."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    if isinstance(obj, (set, frozenset)):
        return list(obj)
    # Anything else (callables, dtypes, config objects, ...) is stored as text
    # rather than aborting the whole dump.
    return str(obj)


def _collect_terminators(tokenizer) -> List[int]:
    """Return the token ids that end generation: EOS plus ``<|eot_id|>`` when it exists."""
    terms = []
    if tokenizer.eos_token_id is not None:
        terms.append(tokenizer.eos_token_id)
    try:
        eot = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    except Exception as exc:
        # Best effort: a tokenizer that cannot resolve the token simply has no
        # extra terminator.
        logger.debug(f"Could not resolve <|eot_id|>: {exc}")
        eot = None
    # An unknown token maps to the unk id (or None), not to a real terminator.
    if eot is not None and eot != tokenizer.unk_token_id and eot not in terms:
        terms.append(eot)
    return terms


def _evaluate_pair(
    base_id: str,
    know_id: str,
    cfg: LGCDConfig,
    tasks: List[str],
    shots: List[int],
    model_tag_prefix: str,
    ts: str,
) -> None:
    """Load one model pair, evaluate every (shots, task) combination and save each result.

    Kept as its own function so that the model, tokenizer and LoRA tensors are
    ordinary locals that go out of scope when it returns or raises.
    """
    # 1) Load base + extract LoRA (then free knowledge)
    base_tok, base_model, lora_params = load_and_extract_lora(base_id, know_id, cfg)

    # 2) + 3) Terminators and LM wrapper
    lm = LGCDHarnessLM(
        language_model=base_model,
        tokenizer=base_tok,
        lora_params=lora_params,
        cfg=cfg,
        terminators=_collect_terminators(base_tok),
    )

    # 4) Evaluate (per shots) and write one JSON file per run
    for s in shots:
        for task in tasks:
            logger.info(f"Evaluating {task} ({s}-shot)")
            results = evaluator.simple_evaluate(
                model=lm,
                tasks=[task],
                num_fewshot=s,
                rewrite_requests_cache=True,
                cache_requests=False,
                batch_size=20,
                use_cache=None,
                device=DEVICE,
                apply_chat_template=False,
            )
            if results is None:
                logger.warning(f"No results returned for {task} ({s}-shot); nothing saved")
                continue
            out = (
                f"{model_tag_prefix}-{base_id.split('/')[-1]}"
                f"_ct{cfg.confidence_threshold}"
                f"_ca{cfg.contrastive_alpha}"
                f"_cb{cfg.contrastive_beta}"
                f"_lr{cfg.lora_rank}"
                f"_{task}_{s}shot_{ts}.json"
            )
            with open(out, "w", encoding="utf-8") as f:
                json.dump(results, f, indent=2, default=_json_default)
            logger.info(f"Saved: {out}")


def run_lmeval_for_pairs(
    model_pairs: Dict[str, List[Dict[str, str]]],
    cfg: LGCDConfig,
    tasks_by_lang: Optional[Dict[str, List[str]]] = None,
    shots: Optional[List[int]] = None,
    model_tag_prefix: str = "lgcdqa",
):
    """Evaluate every (language model, knowledge model) pair with lm-eval.

    Args:
        model_pairs: Language code -> list of dicts with the keys
            ``"language_model"`` and ``"knowledge_model"``.
        cfg: LGCD settings shared by all pairs.
        tasks_by_lang: Optional language code -> task names; a language
            without an entry is evaluated on ``global_mmlu_full_{lang}``.
        shots: Few-shot counts to evaluate; defaults to ``[0, 5]``.
        model_tag_prefix: Prefix of the result file names.

    A failure while processing one pair is logged and the run continues with
    the next pair. GPU memory is released after every pair, after the pair's
    locals (and any exception referencing them) are gone.
    """
    shots = [0, 5] if shots is None else shots
    ts = time.strftime("%Y%m%d_%H%M%S")

    for lang, pairs in model_pairs.items():
        tasks = (tasks_by_lang or {}).get(lang, [f"global_mmlu_full_{lang}"])
        for idx, pair in enumerate(pairs, start=1):
            base_id = pair["language_model"]
            know_id = pair["knowledge_model"]
            logger.info(f"[{lang}] Pair {idx}: base={base_id} | knowledge={know_id}")

            try:
                _evaluate_pair(base_id, know_id, cfg, tasks, shots, model_tag_prefix, ts)
            except Exception as e:
                logger.error(f"Fatal error for [{lang}] pair {idx}: {e}")
                logger.error(traceback.format_exc())
            finally:
                # Runs after the except clause has finished, i.e. once the
                # traceback no longer pins the frames holding the model.
                clear_gpu_memory()


# ---------------------------------------------------------------------
# Example MODEL_PAIRS
# ---------------------------------------------------------------------
# Language code -> list of model pairs. Each pair names the "language_model"
# (loaded as the base model) and the "knowledge_model" (only used to extract
# LoRA deltas); both values are passed to `from_pretrained`.
MODEL_PAIRS = {
    "zh": [
        {"language_model": "hfl/llama-3-chinese-8b-instruct", "knowledge_model": "meta-llama/Meta-Llama-3-8B-Instruct"},
        {"language_model": "shenzhi-wang/Llama3-8B-Chinese-Chat", "knowledge_model": "meta-llama/Meta-Llama-3-8B-Instruct"},
    ],
    "de": [
        {"language_model": "DiscoResearch/Llama3-DiscoLeo-Instruct-8B-v0.1", "knowledge_model": "meta-llama/Meta-Llama-3-8B-Instruct"},
    ],
    "pt": [
        {"language_model": "rhaymison/gemma-portuguese-luana-2b", "knowledge_model": "google/gemma-2b-it"},
    ],
    "ar": [
        {"language_model": "MohamedRashad/Arabic-Orpo-Llama-3-8B-Instruct", "knowledge_model": "meta-llama/Meta-Llama-3-8B-Instruct"},
    ],
    "fa": [
        {"language_model": "PartAI/Dorna-Llama3-8B-Instruct", "knowledge_model": "meta-llama/Meta-Llama-3-8B-Instruct"},
    ],
    "ja": [
        {"language_model": "elyza/Llama-3-ELYZA-JP-8B", "knowledge_model": "meta-llama/Meta-Llama-3-8B-Instruct"},
        {"language_model": "tokyotech-llm/Llama-3-Swallow-8B-Instruct-v0.1", "knowledge_model": "meta-llama/Meta-Llama-3-8B-Instruct"},
    ],
    "ko": [
        {"language_model": "KISTI-KONI/KONI-Llama3-8B-Instruct-20240729", "knowledge_model": "meta-llama/Meta-Llama-3-8B-Instruct"},
        {"language_model": "MLP-KTLim/llama-3-Korean-Bllossom-8B", "knowledge_model": "meta-llama/Meta-Llama-3-8B-Instruct"},
    ],
    "id": [
        {"language_model": "GoToCompany/llama3-8b-cpt-sahabatai-v1-instruct", "knowledge_model": "meta-llama/Meta-Llama-3-8B-Instruct"},
    ],
    "sw": [
        {"language_model": "Jacaranda/UlizaLlama3", "knowledge_model": "meta-llama/Meta-Llama-3-8B-Instruct"},
    ],
}

# Optional: customize tasks per language (language code -> list of task names),
# otherwise defaults to global_mmlu_full_{lang}. Left empty here, so every
# language uses the default task.
TASKS_BY_LANG = {

}


def main():
    """Build the LGCD configuration and run the evaluation sweep over MODEL_PAIRS."""
    logger.info("Starting lm-eval over QA-style LGCD pipeline")

    # Settings are spelled out explicitly so a run is reproducible from this
    # function alone.
    settings = {
        "lora_rank": 32,
        "layer_group": "all",
        "confidence_threshold": 0.7,
        "contrastive_alpha": 0.1,
        "contrastive_beta": 1.0,
        "contrastive_top_k": 100,
        "generation_top_k": 100,
        "max_length": 512,
        "generation_temperature": 0.7,
        "top_p": 0.9,
        "device": DEVICE,
    }
    run_config = LGCDConfig(**settings)

    run_lmeval_for_pairs(
        MODEL_PAIRS,
        run_config,
        tasks_by_lang=TASKS_BY_LANG,
        shots=[0, 5],
        model_tag_prefix="lgcdqa",
    )

    logger.info("All lm-eval runs finished.")


if __name__ == "__main__":
    main()
