import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from datasets import load_dataset
from tqdm import tqdm
import argparse
import json
import numpy as np
import math
import random

# === Utility Functions ===

def top_p_sampling(probs, p=0.9):
    """Select the nucleus (top-p) subset of a probability vector.

    Entries are ranked from most to least probable and kept up to and
    including the first one whose running total exceeds ``p``. When the
    running total never exceeds ``p``, every entry is kept.

    Args:
        probs: Tensor of probabilities (``main`` passes a 1-D vector).
        p: Cumulative-probability threshold.

    Returns:
        A ``(indices, probabilities)`` pair for the kept entries, ordered by
        descending probability. The probabilities are not renormalized.
    """
    ranked_probs, ranked_ids = probs.sort(descending=True)
    running_total = ranked_probs.cumsum(dim=-1)

    # Positions whose running total is already past the threshold.
    over_threshold = torch.nonzero(running_total > p)
    if over_threshold.numel() == 0:
        keep = len(probs)
    else:
        # +1 so the entry that crosses the threshold is itself included.
        keep = over_threshold[0, 0].item() + 1

    return ranked_ids[:keep], ranked_probs[:keep]


def recall_weighted_mass(short_probs, full_probs):
    """Return the total ``full_probs`` mass on the tokens of ``short_probs``.

    Args:
        short_probs: Mapping whose keys are the tokens to look up.
        full_probs: Mapping from token to probability; tokens missing from it
            contribute 0.0.

    Returns:
        The sum of ``full_probs`` values over the keys of ``short_probs``.
    """
    covered = [full_probs.get(token, 0.0) for token in short_probs]
    return sum(covered)

def kl_divergence(p, q, epsilon=1e-9):
    """Return the epsilon-smoothed KL divergence of ``p`` from ``q`` in nats.

    Both arguments map tokens to probabilities; a token absent from one
    mapping is treated as having probability 0.0 there.

    Args:
        p: Reference distribution (weights each log-ratio term).
        q: Distribution compared against ``p``.
        epsilon: Added to numerator and denominator inside the logarithm so
            that zero probabilities do not produce ``log(0)`` or a division
            by zero.

    Returns:
        Sum over the union of tokens of ``p[t] * log((p[t]+eps)/(q[t]+eps))``.
    """
    terms = []
    for tok in set(p)|set(q):
        p_tok = p.get(tok, 0.0)
        q_tok = q.get(tok, 0.0)
        terms.append(p_tok * math.log((p_tok + epsilon) / (q_tok + epsilon)))
    return sum(terms)

def js_divergence(p, q, epsilon=1e-9):
    """Return the Jensen-Shannon divergence between ``p`` and ``q``.

    The midpoint distribution ``m = (p + q) / 2`` is built over the union of
    tokens, and the result is the average of ``kl_divergence(p, m)`` and
    ``kl_divergence(q, m)``.

    Args:
        p: Mapping from token to probability.
        q: Mapping from token to probability.
        epsilon: Smoothing constant forwarded to ``kl_divergence``.

    Returns:
        ``0.5 * KL(p || m) + 0.5 * KL(q || m)``.
    """
    midpoint = {}
    for tok in set(p)|set(q):
        midpoint[tok] = 0.5 * (p.get(tok, 0.0) + q.get(tok, 0.0))

    p_to_mid = kl_divergence(p, midpoint, epsilon)
    q_to_mid = kl_divergence(q, midpoint, epsilon)
    return 0.5 * p_to_mid + 0.5 * q_to_mid

def l1_distance_between_probs(p_s, p_l):
    """Return the L1 distance between two token-probability mappings.

    Args:
        p_s: Mapping from token to probability.
        p_l: Mapping from token to probability.

    Returns:
        The sum of absolute probability differences over the union of tokens,
        where a token missing from a mapping counts as 0.0 there.
    """
    gaps = []
    for tok in set(p_s)|set(p_l):
        gaps.append(abs(p_s.get(tok, 0.0) - p_l.get(tok, 0.0)))
    return sum(gaps)

def compute_metrics_against_full(short_probs, full_probs, full_set, ctx_len):
    """Compare a short-context token distribution with the full-context one.

    Args:
        short_probs: Mapping from token to probability for the short context.
        full_probs: Mapping from token to probability for the full context.
        full_set: Set of tokens treated as the reference support (``main``
            passes ``set(full_probs)``).
        ctx_len: Context length, copied into the result unchanged.

    Returns:
        A dict with the keys ``ctx_len``, ``recall``, ``precision``,
        ``support_size``, ``l1_diff``, ``js_distance``, ``recall_mass``,
        ``entropy`` (always ``None``; it is not computed here), ``kl_div``
        and ``f1``.
    """
    short_set = set(short_probs)

    # Support overlap: recall is relative to the reference set, precision is
    # relative to the short-context set. Empty sets yield 0.0, not an error.
    shared = len(short_set & full_set)
    recall = shared / len(full_set) if full_set else 0.0
    precision = shared / len(short_set) if short_set else 0.0
    if recall + precision > 0:
        f1 = 2 * recall * precision / (recall + precision)
    else:
        f1 = 0.0

    l1 = l1_distance_between_probs(short_probs, full_probs)

    # The JS divergence is non-negative in exact arithmetic, but for nearly
    # identical distributions floating-point rounding can leave a tiny
    # negative value, and math.sqrt would then raise ValueError. Clamp at 0.
    jsd = math.sqrt(max(js_divergence(short_probs, full_probs), 0.0))

    mass = recall_weighted_mass(short_probs, full_probs)
    kl = kl_divergence(full_probs, short_probs)

    return {
        "ctx_len": ctx_len,
        "recall": recall,
        "precision": precision,
        "support_size": len(short_set),
        "l1_diff": l1,
        "js_distance": jsd,
        "recall_mass": mass,
        "entropy": None,
        "kl_div": kl,
        "f1": f1,
    }



# === Main ===

def _resolve_model_id(model_name):
    """Map a CLI model alias to the Hugging Face checkpoint id to load."""
    alias = model_name.lower()
    if alias == "llama":
        return "meta-llama/Meta-Llama-3-8B"
    if alias == "mistral":
        return "mistralai/Mistral-7B-Instruct-v0.2"
    if alias == "qwen":
        return "Qwen/Qwen2-7B"  # base model, not Instruct
    raise ValueError("Model name must be one of 'llama', 'mistral' or 'qwen'")


def _load_text_dataset(name, cache_dir):
    """Load the train split for dataset alias ``name``; return (dataset, text_key)."""
    if name == "reddit":
        return load_dataset("reddit_tifu", "long", split="train", cache_dir=cache_dir), "documents"
    if name == "cnn":
        return load_dataset("cnn_dailymail", "3.0.0", split="train", cache_dir=cache_dir), "article"
    if name == "gov":
        return load_dataset("ccdv/govreport-summarization", split="train", cache_dir=cache_dir), "report"
    if name == "wiki":
        dataset = load_dataset(
            "wikipedia", "20220301.en",
            split="train",
            cache_dir=cache_dir,
            trust_remote_code=True,
        )
        return dataset, "text"
    # NOTE(review): the two loaders below do not receive cache_dir, unlike the
    # ones above; left unchanged so the download location does not move.
    if name == "qmsum":
        return load_dataset("pszemraj/qmsum-cleaned", split="train"), "input"
    if name == "booksum":
        return load_dataset("kmfoda/booksum", split="train"), "chapter"
    raise ValueError("Unsupported dataset specified.")


def main(args):
    """Score sampled token positions under full and short context; write JSONL.

    For up to 100 documents with at least ``args.min_tokens`` tokens, up to
    100 positions are sampled from a fixed index window. At each position the
    model is run twice on the same prefix: once attending to the whole prefix
    and once with an attention mask that keeps only the last 32 tokens. The
    record written per position contains the log-probability of the actual
    next token under the full context (``lcl``), its gain over the short
    context (``lsd``), a "long"/"short" label derived from both, and the
    metrics from ``compute_metrics_against_full`` on the top-p supports.

    Records are appended to ``<args.output_dir>/records.jsonl``.

    Args:
        args: Namespace providing ``model_name``, ``dataset``, ``cache_dir``,
            ``output_dir`` and ``min_tokens``.

    Raises:
        ValueError: If the model or dataset alias is not supported, or if
            ``min_tokens`` is too small for the sampling window to contain
            any position.
    """
    torch.manual_seed(42)
    random.seed(42)

    # Constants
    idx_low, idx_high = 100, 1000
    if args.dataset in ("gov", "qmsum", "booksum"):
        idx_low, idx_high = 3000, 4000
    num_stories = 100
    num_samples_per_doc = 100
    SHORT_THRESH = 32  # tokens kept visible in the short-context pass

    alpha = 2   # minimum lsd for the "long" label
    beta = -1   # minimum lcl for the "long" label

    # Every accepted document is cut to exactly min_tokens tokens, so when the
    # window starts at or beyond that length no position can ever be sampled.
    # Previously this scanned the whole dataset and wrote nothing; fail before
    # loading the model instead.
    if 0 <= args.min_tokens <= idx_low + 1:
        raise ValueError(
            f"min_tokens={args.min_tokens} leaves no positions to sample in "
            f"[{idx_low}, {idx_high}); use --min_tokens greater than {idx_low + 1}"
        )

    model_id = _resolve_model_id(args.model_name)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        use_cache=True
    ).cuda().eval()

    dataset, text_key = _load_text_dataset(args.dataset, args.cache_dir)

    os.makedirs(args.output_dir, exist_ok=True)
    records_path = os.path.join(args.output_dir, "records.jsonl")

    story_counter = 0
    processed = 0
    short_counter = 0
    long_counter = 0

    # The context manager closes the file even if a forward pass raises.
    with open(records_path, "a") as out_file:
        for ex in tqdm(dataset, desc="Scanning stories"):
            if processed >= num_stories:
                break
            text = ex[text_key]
            tokens = tokenizer(text, return_tensors="pt", truncation=False)["input_ids"]
            if tokens.shape[1] < args.min_tokens:
                continue

            story_counter += 1
            doc_ids = tokens[0]
            # Only the length of the truncated document is needed, so nothing
            # is copied to the GPU here.
            usable_len = doc_ids[:args.min_tokens].size(0)

            print("\n" + "+" * 40)
            print(f"number of short: {short_counter}, number of long: {long_counter}")
            print(f"Analyzing story {story_counter}")

            # --- Sample up to num_samples_per_doc positions in [idx_low, max_idx) ---
            max_idx = min(idx_high, usable_len - 1)
            all_indices = list(range(idx_low, max_idx))
            if not all_indices:
                continue
            target_indices = random.sample(all_indices, min(num_samples_per_doc, len(all_indices)))
            target_indices.sort()

            # --- For each sampled position, compute lcl, lsd and the metrics ---
            for tok in target_indices:
                # Build the prefix once and reuse it for both forward passes.
                context = doc_ids[:tok].unsqueeze(0).cuda()
                with torch.no_grad():
                    out_full = model(input_ids=context)
                    # Cast to float32 before leaving the GPU: softmax,
                    # log-softmax and cumsum over the vocabulary lose
                    # precision in float16 and are not implemented for
                    # half tensors on every CPU build.
                    logits_full = out_full.logits[0, -1].float().cpu()

                    # Short context: hide everything but the last SHORT_THRESH tokens.
                    mask = torch.zeros(1, tok, dtype=torch.long, device=context.device)
                    mask[0, -SHORT_THRESH:] = 1
                    out_short = model(input_ids=context, attention_mask=mask)
                    logits_short = out_short.logits[0, -1].float().cpu()

                # log probs of the token that actually follows the prefix
                logp_full = torch.log_softmax(logits_full, dim=-1)
                logp_short = torch.log_softmax(logits_short, dim=-1)
                actual = int(doc_ids[tok].item())
                lcl = logp_full[actual].item()
                lsd = lcl - logp_short[actual].item()

                # classification
                if lsd > alpha and lcl > beta:
                    group = "long"
                    long_counter += 1
                else:
                    group = "short"
                    short_counter += 1

                # set-based probs: top-p supports, renormalized to sum to 1
                probs_full = F.softmax(logits_full, dim=-1)
                probs_short = F.softmax(logits_short, dim=-1)
                idx_f, p_f = top_p_sampling(probs_full)
                idx_s, p_s = top_p_sampling(probs_short)
                p_f = p_f / p_f.sum()
                p_s = p_s / p_s.sum()
                full_probs = {int(i): float(p) for i, p in zip(idx_f.tolist(), p_f.tolist())}
                short_probs = {int(i): float(p) for i, p in zip(idx_s.tolist(), p_s.tolist())}

                metrics = compute_metrics_against_full(short_probs, full_probs, set(full_probs), ctx_len=tok)

                record = {"story": story_counter, "tok": tok, "group": group,
                          **metrics, "lcl": lcl, "lsd": lsd}
                out_file.write(json.dumps(record) + "\n")
                # Flush per record so partial results survive an interrupted run.
                out_file.flush()

            processed += 1

    print(f"✅ Wrote records to {records_path}")

if __name__ == "__main__":
    # Command-line entry point: parse options, derive the per-run output
    # directory, then hand over to main().
    model_choices = ["llama", "mistral", "qwen"]
    # NOTE: main() only handles reddit, cnn, gov, wiki, qmsum and booksum; the
    # remaining names pass argument parsing but are rejected there with a
    # ValueError.
    dataset_choices = [
        "reddit", "cnn", "gov", "wiki", "qmsum", "booksum",
        "wiki-en", "wiki-ar", "wiki-fr", "wiki-de", "wiki-zh",
        "wiki-ru", "wiki-ko", "wiki-th", "openmath", "ccdv",
    ]

    cli = argparse.ArgumentParser()
    cli.add_argument("--model_name", type=str, default="llama", choices=model_choices)
    # NOTE: --conf_thresh is accepted but main() does not read it.
    cli.add_argument("--conf_thresh", type=float, default=0.2)
    cli.add_argument("--output_dir", type=str, default="/output_path")
    cli.add_argument("--cache_dir", type=str, default="/cache_path")
    cli.add_argument("--dataset", type=str, choices=dataset_choices, default="reddit")
    cli.add_argument("--min_tokens", type=int, default=1024)
    args = cli.parse_args()

    # Each dataset/model pair gets its own sub-directory under --output_dir.
    run_dir_name = f"{args.dataset}_{args.model_name}_longshort"
    args.output_dir = os.path.join(args.output_dir, run_dir_name)
    os.makedirs(args.output_dir, exist_ok=True)

    main(args)
