from __future__ import annotations
import os,json
from typing import Dict, Tuple
from sentence_transformers import SentenceTransformer

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from vllm import LLM, SamplingParams
from data_utils import get_task
from prompts import build_prompt,TASK_STYLES

# Process-wide memo of loaded backends: "<model>|vllm" -> (tokenizer, llm).
_CACHE: Dict[str, Tuple[object, object]] = {}
# Hugging Face download/cache directory; override with the MODEL_CACHE_DIR env var.
_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "./model_cache")


# Short model name -> Hugging Face hub id.
MODEL_ID_MAP = {
    "qwen3-4b": "Qwen/Qwen3-4B-Instruct-2507",
    "llama3.1-8b": "meta-llama/Llama-3.1-8B-Instruct",
}

# Hub id -> location handed to vLLM. Each entry can be redirected (e.g. to a local
# checkout) through its environment variable; otherwise the hub id itself is used.
LOCAL_LLM_PATH = {
    hub_id: os.getenv(env_var, hub_id)
    for hub_id, env_var in (
        ("Qwen/Qwen3-4B-Instruct-2507", "QWEN3_4B_PATH"),
        ("meta-llama/Llama-3.1-8B-Instruct", "LLAMA3_1_8B_PATH"),
    )
}

# Alternative spellings accepted by resolve_model_id, mapped to MODEL_ID_MAP keys.
ALIASES = {
    "qwen-3-4b": "qwen3-4b",
    "qwen3_4b": "qwen3-4b",
    "llama-3.1-8b": "llama3.1-8b",
    "llama3_1-8b": "llama3.1-8b",
}

def resolve_model_id(name: str) -> str:
    """Map a user-facing short model name to its Hugging Face hub id.

    The name is whitespace-stripped and lower-cased, translated through
    ``ALIASES`` when it is an alternative spelling, and finally looked up in
    ``MODEL_ID_MAP``.

    Raises:
        KeyError: if the normalized name is not a known short name; the message
            lists the accepted names.
    """
    normalized = name.strip().lower()
    canonical = ALIASES.get(normalized, normalized)
    try:
        hub_id = MODEL_ID_MAP[canonical]
    except KeyError as err:
        known = sorted(MODEL_ID_MAP.keys())
        raise KeyError(f"Unknown model short name '{name}'. Known: {known}") from err
    return hub_id


def logp_y_given_userprompt(tok, llm, task, insturct, ex, examples=None):
    """Return the mean per-token log-probability of ``ex["target"]`` given a prompt.

    The prompt is ``build_prompt(task=task, instruction=insturct,
    query=ex["input"], examples=examples)``. The target text is appended to it,
    and the concatenation is scored in a single vLLM pass via
    ``prompt_logprobs``. Only positions belonging to the target are averaged;
    a trailing EOS token, if present, is excluded.

    Args:
        tok: Hugging Face tokenizer; expected to match the tokenizer ``llm`` uses.
        llm: vLLM ``LLM`` engine.
        task: task name forwarded to ``build_prompt``.
        insturct: instruction text forwarded to ``build_prompt`` (spelling kept
            for backward compatibility with existing callers).
        ex: mapping with ``"input"`` and ``"target"`` string fields.
        examples: optional few-shot examples forwarded to ``build_prompt``;
            ``None`` (the default) means no examples.

    Returns:
        The average of the per-token log-probabilities vLLM reports for the
        target tokens.

    Raises:
        ValueError: if no target token could be scored (e.g. an empty target).
    """
    # A fresh list per call; a shared mutable default could leak state between calls.
    if examples is None:
        examples = []

    prefix_text = build_prompt(task=task, instruction=insturct, query=ex["input"], examples=examples)
    full_text = prefix_text + ex["target"]

    prefix_ids = tok(prefix_text, add_special_tokens=False).input_ids
    full_ids = tok(full_text, add_special_tokens=False).input_ids

    # One generated token is enough: we only need the prompt logprobs computed
    # during prefill.
    sp = SamplingParams(
        temperature=0.0,
        max_tokens=1,
        prompt_logprobs=1,
        n=1,
        stop_token_ids=[tok.eos_token_id] if tok.eos_token_id is not None else None,
    )

    out = llm.generate([full_text], sp, use_tqdm=False)[0]

    prompt_ids = out.prompt_token_ids
    plps = out.prompt_logprobs  # list[Optional[dict[token_id -> logprob_obj/float]]]

    # vLLM tokenizes the raw text itself and may prepend special tokens (e.g. a
    # BOS) that the add_special_tokens=False tokenization above omits. Shift by
    # however many extra leading tokens it produced so `start` lands on the target.
    offset = max(len(prompt_ids) - len(full_ids), 0)

    # Tokens at the prefix/target seam can merge when tokenized jointly, so only
    # the leading prefix tokens that survive unchanged in `full_ids` count as prefix.
    shared = 0
    for prefix_tok, full_tok in zip(prefix_ids, full_ids):
        if prefix_tok != full_tok:
            break
        shared += 1

    start = offset + shared
    end = len(prompt_ids)

    # Do not score a trailing EOS as part of the target.
    if tok.eos_token_id is not None and end > start and prompt_ids[end - 1] == tok.eos_token_id:
        end -= 1

    s = 0.0
    n = 0
    for i in range(start, end):
        if plps[i] is None:
            continue
        tid = prompt_ids[i]
        lp = _as_float_logprob(plps[i].get(tid, None))
        if lp is None:
            # In rare cases where value cannot be retrieved, use a small fallback value
            lp = -100.0
        s += lp
        n += 1

    if n == 0:
        raise ValueError(
            f"No target tokens could be scored for input {ex['input']!r}; "
            "is the target empty?"
        )
    return s / n

def _as_float_logprob(v):
    """Coerce one vLLM prompt-logprob entry to a plain ``float``.

    Entries may be objects exposing a ``logprob`` attribute or bare numbers.
    ``None`` (no value available) is passed through unchanged.
    """
    if v is None:
        return None
    return float(getattr(v, "logprob", v))

def instruction_information_gain_vllm(model: str, task, insturcts, ex):
    """Score how much each candidate instruction raises the target's likelihood.

    IG = logp(y | T, x) - logp(y | T0, x)

    ``logp`` is the mean per-token target log-probability returned by
    ``logp_y_given_userprompt``. ``T0`` is the task's default instruction
    ``TASK_STYLES[task]["instruction"]``, ``x`` is ``ex["input"]`` and ``y`` is
    ``ex["target"]``.

    The tokenizer and vLLM engine for ``model`` are created on first use and
    memoised in ``_CACHE`` for later calls.

    Args:
        model: Hugging Face id or model path. If it is a key of
            ``LOCAL_LLM_PATH`` it is remapped to that location; otherwise it is
            used as-is.
        task: task name (key of ``TASK_STYLES``).
        insturcts: iterable of candidate instruction strings (spelling kept for
            backward compatibility).
        ex: mapping with ``"input"`` and ``"target"`` fields.

    Returns:
        List of IG values, one per instruction, in the order given.
    """
    key = f"{model}|vllm"
    tok, llm = _CACHE.get(key, (None, None))

    if tok is None:
        # Load tokenizer and engine from the same location so the HF tokenization
        # used for target-span bookkeeping matches the one vLLM applies.
        model_path = LOCAL_LLM_PATH.get(model, model)
        tok = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            cache_dir=_CACHE_DIR,
        )
        if tok.pad_token_id is None:
            tok.pad_token = tok.eos_token

        llm = LLM(
            model=model_path,
            max_model_len=16384,
            tensor_parallel_size=1,
            gpu_memory_utilization=0.4,
        )
        _CACHE[key] = (tok, llm)

    # Baseline: the task's default instruction T0.
    lp_no = logp_y_given_userprompt(tok, llm, task, TASK_STYLES[task]["instruction"], ex)

    return [
        logp_y_given_userprompt(tok, llm, task, instruction, ex) - lp_no
        for instruction in insturcts
    ]


def run_eval(
    task: str,
    model: str = "qwen3-4b",
):
    """Compute instruction information gain for every training example of a task.

    Candidate instructions are read from ``./data/{task}/{task}_instrs.json``.
    Each example of ``get_task(task, 'train')`` is scored with
    ``instruction_information_gain_vllm``.

    Args:
        task: task name; selects both the instruction file and the dataset.
        model: a short name accepted by ``resolve_model_id`` (such as the
            default ``"qwen3-4b"``) or an already-resolved Hugging Face id / path.

    Returns:
        A list with one single-key dict per example, mapping ``ex["input"]``
        to that example's list of IG values (one per instruction, file order).
    """
    # The default "qwen3-4b" is a short name, not a hub id or a LOCAL_LLM_PATH
    # key, so it must be resolved before loading. Full ids/paths are not known
    # short names and are passed through unchanged.
    try:
        model = resolve_model_id(model)
    except KeyError:
        pass

    instpath = f"./data/{task}/{task}_instrs.json"
    with open(instpath, "r", encoding="utf-8") as f:
        insturcts = json.load(f)

    trainset = get_task(task, 'train')
    return [
        {ex["input"]: instruction_information_gain_vllm(model, task, insturcts, ex)}
        for ex in trainset
    ]


if __name__ == "__main__":
    # Script entry: score every task's training split with one model and write
    # the results next to the task data as ./data/<task>/train_ifgain.json.
    # Swap in "qwen3-4b" to score with the Qwen model instead.
    short_name = "llama3.1-8b"
    hub_model_id = resolve_model_id(short_name)

    for task_name in ['gsm8k', 'gpqa', 'fp', 'xsum', 'date', 'salient']:
        task_dir = f"./data/{task_name}"
        task_results = run_eval(task=task_name, model=hub_model_id)
        output_path = os.path.join(task_dir, "train_ifgain.json")
        with open(output_path, "w", encoding="utf-8") as fh:
            json.dump(task_results, fh, ensure_ascii=False, indent=2)
