import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from datasets import load_dataset
from tqdm import tqdm
import argparse
import json
import numpy as np
import math
import random

# === Utility Functions ===

def top_p_sampling(probs, p=0.9):
    """Select the nucleus (top-p) entries of a probability vector.

    The entries are ranked from most to least probable and kept up to and
    including the first one at which the running total exceeds ``p``. If the
    running total never exceeds ``p``, every entry is kept.

    Args:
        probs: tensor of probabilities. The slicing below is along the first
            dimension while ranking is along the last, so the two only agree
            for a 1-D tensor.
        p: cumulative-probability threshold.

    Returns:
        A ``(indices, probabilities)`` pair, both in descending order of
        probability. The probabilities are NOT renormalised here.
    """
    ordered_probs, ordered_ids = probs.sort(descending=True)
    over_threshold = ordered_probs.cumsum(dim=-1) > p

    if over_threshold.any():
        # Position of the first entry that pushes the running total past p;
        # the +1 makes the slice include that entry.
        first_hit = torch.where(over_threshold)[0][0]
        keep = first_hit + 1
    else:
        keep = len(probs)

    return ordered_ids[:keep], ordered_probs[:keep]


def recall_weighted_mass(short_probs, full_probs):
    """Sum the ``full_probs`` values of every key that appears in ``short_probs``.

    Keys of ``short_probs`` that are missing from ``full_probs`` contribute 0.0.
    Only the keys of ``short_probs`` are used; its values are ignored.
    """
    matched_mass = [full_probs.get(token_id, 0.0) for token_id in short_probs]
    return sum(matched_mass)

def kl_divergence(p, q, epsilon=1e-9):
    """Epsilon-smoothed KL divergence KL(p || q) between two sparse distributions.

    Args:
        p: mapping from token to probability (missing tokens count as 0.0).
        q: mapping from token to probability (missing tokens count as 0.0).
        epsilon: added to both numerator and denominator inside the log so
            that tokens absent from one side do not cause a division by zero
            or ``log(0)``.

    Returns:
        The sum over the union of keys of ``p[t] * log((p[t]+eps)/(q[t]+eps))``,
        in nats.
    """

    def contribution(tok):
        p_tok = p.get(tok, 0.0)
        q_tok = q.get(tok, 0.0)
        return p_tok * math.log((p_tok + epsilon) / (q_tok + epsilon))

    return sum(map(contribution, set(p) | set(q)))

def js_divergence(p, q, epsilon=1e-9):
    """Jensen-Shannon divergence between two sparse distributions.

    Builds the equal-weight mixture ``m = 0.5 * (p + q)`` over the union of
    keys (missing tokens count as 0.0) and returns
    ``0.5 * KL(p || m) + 0.5 * KL(q || m)``, where each KL term is computed by
    ``kl_divergence`` with the same ``epsilon``.
    """
    midpoint = {}
    for tok in set(p) | set(q):
        midpoint[tok] = 0.5 * (p.get(tok, 0.0) + q.get(tok, 0.0))

    p_side = kl_divergence(p, midpoint, epsilon)
    q_side = kl_divergence(q, midpoint, epsilon)
    return 0.5 * p_side + 0.5 * q_side

def l1_distance_between_probs(p_s, p_l):
    """L1 distance between two sparse distributions.

    Both arguments map token -> probability; a token missing from one mapping
    is treated as having probability 0.0 there. The result is the sum of
    absolute differences over the union of keys.
    """
    support = set(p_s) | set(p_l)
    gaps = [abs(p_s.get(tok, 0.0) - p_l.get(tok, 0.0)) for tok in support]
    return sum(gaps)

def compute_metrics_against_full(short_probs, full_probs, full_set, ctx_len):
    """Score a short-context token distribution against the full-context one.

    Args:
        short_probs: mapping token -> probability for the short context.
        full_probs: mapping token -> probability for the full context.
        full_set: set of tokens used as the reference support for the
            recall/precision computation.
        ctx_len: value copied unchanged into the ``"ctx_len"`` field.

    Returns:
        A dict with the keys ``ctx_len``, ``recall``, ``precision``,
        ``support_size``, ``l1_diff``, ``js_distance``, ``recall_mass``,
        ``entropy`` (always ``None`` here), ``kl_div`` and ``f1``.
    """
    short_set = set(short_probs)
    overlap = len(short_set & full_set)

    # Set-overlap scores; an empty denominator yields 0.0 instead of raising.
    recall = overlap / len(full_set) if full_set else 0.0
    precision = overlap / len(short_set) if short_set else 0.0
    if recall + precision > 0:
        f1 = 2 * recall * precision / (recall + precision)
    else:
        f1 = 0.0

    l1 = l1_distance_between_probs(short_probs, full_probs)

    # JS divergence is non-negative in exact arithmetic, but rounding can push
    # it a hair below zero for near-identical distributions, which would make
    # math.sqrt raise ValueError. Clamp before taking the root.
    jsd = math.sqrt(max(js_divergence(short_probs, full_probs), 0.0))

    mass = recall_weighted_mass(short_probs, full_probs)
    # Direction is KL(full || short).
    kl = kl_divergence(full_probs, short_probs)

    return {
        "ctx_len": ctx_len,
        "recall": recall,
        "precision": precision,
        "support_size": len(short_set),
        "l1_diff": l1,
        "js_distance": jsd,
        "recall_mass": mass,
        "entropy": None,
        "kl_div": kl,
        "f1": f1,
    }


# === Main ===

def main(args):
    """Probe a causal LM on LongChat-Lines with full vs. truncated context.

    For every example, two prompts are built ("short" and "long" suffix). For
    each prompt the model's next-token distribution is computed once with the
    whole prompt and once with only the last ``SHORT_THRESH`` tokens, and the
    two are compared. One JSON line per kept (example, group) pair is written
    to ``<args.output_dir>/synthetic_records.jsonl``.

    Args:
        args: namespace providing ``model_name`` (one of the keys of
            ``model_ids`` below, case-insensitive) and ``output_dir``.
            No other attribute is read here.

    Raises:
        ValueError: if ``args.model_name`` is not a supported model key.
    """
    torch.manual_seed(42)
    random.seed(42)

    # Number of trailing prompt tokens that form the "short" context.
    SHORT_THRESH = 32

    model_ids = {
        "llama": "meta-llama/Meta-Llama-3-8B",
        "mistral": "mistralai/Mistral-7B-Instruct-v0.2",
        "qwen": "Qwen/Qwen2-7B",  # base model, not Instruct
        "mixtral8x7b": "mistralai/Mixtral-8x7B-v0.1",
    }
    model_key = args.model_name.lower()
    if model_key not in model_ids:
        raise ValueError(
            "Model name must be one of: "
            + ", ".join(f"'{name}'" for name in model_ids)
        )
    model_id = model_ids[model_key]

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        use_cache=True
    ).cuda().eval()

    dataset_dict = load_dataset("abacusai/LongChat-Lines")
    os.makedirs(args.output_dir, exist_ok=True)
    out_path = os.path.join(args.output_dir, "synthetic_records.jsonl")

    # The context manager guarantees the file is flushed and closed even if a
    # model call or tokenisation step raises part-way through the run.
    with open(out_path, "w", encoding="utf-8") as out_file:
        for split_name, split_dataset in dataset_dict.items():
            print(f"Processing split: {split_name}")
            for ex in tqdm(split_dataset, desc=f"{split_name}"):
                prompt = ex["prompt"]
                expected = str(ex["expected_number"])

                # Extract the line ID from the question: the text between the
                # last occurrence of "line" and the following "?".
                # NOTE(review): the replace("-", "-") below is a no-op as
                # written; it may have been meant to normalise a Unicode dash
                # — confirm against the dataset.
                line_id = prompt.split("line")[-1].split("?")[0].strip().replace("-", "-").strip()

                # "short": the next token is the start of the line ID.
                # "long": the next token is the start of the expected number.
                short_suffix = "A: The <REGISTER_CONTENT> in line"
                long_suffix = f"A: The <REGISTER_CONTENT> in line {line_id} is <"

                for group, suffix in (("short", short_suffix), ("long", long_suffix)):
                    input_ids = tokenizer(prompt + "\n" + suffix, return_tensors="pt")["input_ids"][0].cuda()
                    with torch.no_grad():
                        gen = model.generate(input_ids.unsqueeze(0), max_new_tokens=8, do_sample=False)[0]

                    # Decode only the newly generated continuation.
                    pred = tokenizer.decode(gen[input_ids.size(0):], skip_special_tokens=True)
                    print(f"expect:{expected}, pred:{[pred]}")
                    if group == "long" and expected not in pred:
                        continue  # only keep examples the model answers correctly

                    if group == "short":
                        context_label = "short context tokens"
                        line_tokens = tokenizer(line_id, add_special_tokens=False)["input_ids"]
                        if not line_tokens:
                            print(f"[Warning] line_id='{line_id}' produced no tokens. Skipping.")
                            continue
                        target = line_tokens[0]
                    else:
                        context_label = "long context tokens"
                        expected_tokens = tokenizer(expected, add_special_tokens=False)["input_ids"]
                        # Two tokens are required because the fallback below
                        # may need the second one.
                        if len(expected_tokens) < 2:
                            print(f"[Warning] expected='{expected}' produced fewer than 2 tokens. Skipping.")
                            continue
                        target = expected_tokens[0]
                        # If the first token does not decode to the first
                        # character of the answer, use the second token.
                        if tokenizer.decode([target]) != expected[0]:
                            target = expected_tokens[1]

                    # One forward pass per context; the logits are reused for
                    # both the per-token scores and the distribution metrics.
                    with torch.no_grad():
                        full_logits = model(input_ids.unsqueeze(0)).logits[0, -1]
                        short_logits = model(input_ids[-SHORT_THRESH:].unsqueeze(0)).logits[0, -1]

                    # LCL: log-prob of the target under the full context.
                    # LSD: full-context log-prob minus short-context log-prob.
                    lcl = F.log_softmax(full_logits, dim=-1)[target].item()
                    lsd = lcl - F.log_softmax(short_logits, dim=-1)[target].item()
                    records = [{
                        "token": tokenizer.decode([target]),
                        "label": context_label,
                        "LCL": lcl,
                        "LSD": lsd
                    }]

                    # Distribution metrics on the renormalised top-p nuclei.
                    pf = F.softmax(full_logits, dim=-1)
                    ps = F.softmax(short_logits, dim=-1)
                    idx_f, p_f = top_p_sampling(pf)
                    idx_s, p_s = top_p_sampling(ps)
                    p_f = p_f / p_f.sum()
                    p_s = p_s / p_s.sum()

                    full_probs = {int(i): float(p) for i, p in zip(idx_f.tolist(), p_f.tolist())}
                    short_probs = {int(i): float(p) for i, p in zip(idx_s.tolist(), p_s.tolist())}

                    metrics = compute_metrics_against_full(
                        short_probs, full_probs, set(full_probs),
                        ctx_len=input_ids.size(0)
                    )
                    record = {
                        "group": group,
                        "line_id": line_id,
                        "expected": expected,
                        **metrics,
                        "lsd_lcl": records
                    }
                    out_file.write(json.dumps(record) + "\n")


if __name__ == "__main__":
    # Command-line entry point: parse options, derive the per-model output
    # directory, make sure it exists, then run the probe.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_name",
        type=str,
        default="llama",
        choices=["llama", "mistral", "qwen", "mixtral8x7b"],
    )
    cli.add_argument("--output_dir", type=str, default="/output_path")
    # Parsed here; main() as written does not read it.
    cli.add_argument("--cache_dir", type=str, default="/cache_path")
    args = cli.parse_args()

    # Results go into a sub-directory named after the chosen model.
    run_dir = os.path.join(args.output_dir, f"{args.model_name}_LongEval")
    args.output_dir = run_dir
    os.makedirs(run_dir, exist_ok=True)

    main(args)
