import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
from datasets import load_dataset
import os
from tqdm import tqdm
import argparse
from collections import defaultdict
import json
import numpy as np


def find_tokenwise_true_context_lengths(model, tokenizer, input_ids, target_indices, threshold=0.2,
                                        step_size=8, min_len=32, use_masking=True,
                                        batch_sizes=None):
    """
    For each index in target_indices, find the smallest tested context length at
    which the model predicts the token at that index confidently and correctly.

    A prediction counts as confident-correct when the top-1 token equals the
    actual token and (p_top1 - p_top2) is strictly greater than `threshold`.

    The tested context lengths are min_len, min_len + step_size, ... up to (and
    including, if it falls on that grid) idx. The full prefix of length idx is
    therefore only tested when (idx - min_len) is a multiple of step_size.

    Args:
        model: causal LM called as model(input_ids=..., attention_mask=...) and
            returning an object with a `.logits` attribute
        tokenizer: huggingface tokenizer (only `pad_token_id` is read, and only
            when use_masking is False)
        input_ids: (seq_len,) tokenized input; moved to the GPU with .cuda()
        target_indices: iterable of token positions to check; positions smaller
            than min_len are skipped and do not appear in the result
        threshold: required probability gap between top-1 and top-2
        step_size: granularity to increase context size
        min_len: minimum context window size to start with
        use_masking: if True, feed the whole prefix input_ids[:idx] and unmask only
            the last ctx_len positions; if False, feed the sliced window
            input_ids[idx - ctx_len:idx] (left-padded inside a batch)
        batch_sizes: map from window size to batch size. The batch size is chosen
            once per token from the largest tested window: the value of the
            largest key <= that window, else the value of the smallest key.
            Defaults to {32: 1, 512: 1}.

    Returns:
        Dictionary mapping token index to {"conf_ctx_len": int or None}, where
        None means no tested context length produced a confident-correct
        prediction.
    """
    if batch_sizes is None:
        # Built per call so the default is never a shared mutable object.
        batch_sizes = {32: 1, 512: 1}

    def get_batch_size(window_len):
        for k in sorted(batch_sizes, reverse=True):
            if window_len >= k:
                return batch_sizes[k]
        return batch_sizes[min(batch_sizes)]

    results = {}
    input_ids = input_ids.cuda()

    # Pad value for the slicing mode. Padded positions are masked out below, so
    # the concrete id does not matter; fall back to 0 if no pad token is set.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    for idx in tqdm(target_indices, desc="Checking tokens for context lengths"):
        max_ctx = idx
        if max_ctx < min_len:
            continue  # skip tokens too early in the sequence

        candidate_lens = list(range(min_len, max_ctx + 1, step_size))
        actual_token = input_ids[idx].item()

        conf_ctx_len = None

        # Batch size is decided by the largest window tested for this token.
        batch_size = get_batch_size(candidate_lens[-1])

        with torch.no_grad():
            for batch_start in range(0, len(candidate_lens), batch_size):
                batch_lens = candidate_lens[batch_start: batch_start + batch_size]

                if use_masking:
                    context = input_ids[:idx].unsqueeze(0).repeat(len(batch_lens), 1)
                    attention_mask = torch.zeros_like(context)
                    for i, ctx_len in enumerate(batch_lens):
                        attention_mask[i, -ctx_len:] = 1
                else:
                    widest = max(batch_lens)
                    padded_contexts = []
                    padded_masks = []

                    for ctx_len in batch_lens:
                        window = input_ids[idx - ctx_len:idx]
                        pad_len = widest - window.size(0)
                        padded_contexts.append(F.pad(window, (pad_len, 0), value=pad_id))
                        padded_masks.append(F.pad(torch.ones_like(window), (pad_len, 0), value=0))

                    context = torch.stack(padded_contexts)
                    # Left-padded rows must not attend to their pad tokens. When no
                    # row needed padding, keep passing None as before.
                    attention_mask = torch.stack(padded_masks) if min(batch_lens) < widest else None

                outputs = model(input_ids=context, attention_mask=attention_mask)
                logits = outputs.logits[:, -1, :]
                probs = F.softmax(logits, dim=-1).cpu()

                # batch_lens is ascending, so the first hit is the smallest length.
                for i, ctx_len in enumerate(batch_lens):
                    top_probs, top_ids = torch.topk(probs[i], 2)
                    top1, _ = top_ids.tolist()
                    p1, p2 = top_probs.tolist()
                    margin = p1 - p2

                    if conf_ctx_len is None and top1 == actual_token and margin > threshold:
                        conf_ctx_len = ctx_len

                del outputs, logits, probs
                torch.cuda.empty_cache()

                # Early exit once a confident-correct context length is found
                if conf_ctx_len is not None:
                    print("Breaking")
                    print(idx)
                    print(conf_ctx_len)
                    break

        results[idx] = {
            "conf_ctx_len": conf_ctx_len
        }

    return results



def analyze_prediction_confidence_fast(model, input_ids, conf_thresh=0.2):
    """
    Run one forward pass over the sequence and list the positions whose token
    was predicted correctly with a top-1/top-2 probability gap above conf_thresh.

    Args:
        model: causal LM called as model(input_ids=..., use_cache=False) and
            returning an object with a `.logits` attribute
        input_ids: 1-D tensor of token ids; a batch dimension is added and the
            tensor is moved to the GPU with .cuda()
        conf_thresh: required gap (strictly greater) between top-1 and top-2
            probabilities

    Returns:
        (positions, ids): `positions` is a list of indices into the sequence
        (position i is scored using the logits produced at position i - 1, so
        index 0 is never included); `ids` is the GPU copy of the input sequence.
    """
    batch = input_ids.unsqueeze(0).cuda()

    with torch.no_grad():
        # Drop the last step: it predicts a token beyond the end of the sequence.
        step_logits = model(input_ids=batch, use_cache=False).logits[:, :-1, :]
        step_probs = F.softmax(step_logits, dim=-1).squeeze(0).cpu()

    gold_ids = batch[0, 1:].cpu().tolist()
    best_probs, best_ids = torch.topk(step_probs, 2, dim=-1)

    confident_positions = []
    scored_steps = zip(gold_ids, best_ids.tolist(), best_probs.tolist())
    for position, (gold, (first_id, _), (p_first, p_second)) in enumerate(scored_steps, start=1):
        if first_id != gold:
            continue
        if p_first - p_second > conf_thresh:
            confident_positions.append(position)

    return confident_positions, batch[0]



def load_and_filter_dataset(dataset_name, subset=None, text_key="text", tokenizer_name="your-model-name", min_tokens=1024, max_docs=None, cache_dir=None):
    """
    Load the "train" split of a dataset and keep only documents that tokenize
    to at least `min_tokens` tokens.

    Args:
        dataset_name (str): Name of the dataset to load.
        subset (str, optional): Dataset configuration; only passed on when truthy.
        text_key (str): Key under which each example stores its text.
        tokenizer_name (str): Name of the tokenizer to load. Its pad token is
            set to its EOS token.
        min_tokens (int): Minimum number of tokens a document must have.
        max_docs (int, optional): Stop after this many documents were kept.
            A falsy value (None or 0) means no limit.
        cache_dir (str, optional): Directory to cache the dataset.

    Returns:
        List[Dict]: One dict per kept document with 'input_ids' (the 1-D token
        id tensor, untruncated) and 'raw_text'.
    """
    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    tok.pad_token = tok.eos_token

    # Only forward the subset/config name when one was actually given.
    positional = (dataset_name, subset) if subset else (dataset_name,)
    dataset = load_dataset(*positional, split="train", cache_dir=cache_dir)

    kept = []
    for example in tqdm(dataset, desc=f"Processing {dataset_name}"):
        raw = example[text_key]
        ids = tok(raw, return_tensors="pt", truncation=False)["input_ids"]
        if ids.shape[1] < min_tokens:
            continue

        kept.append({"input_ids": ids[0], "raw_text": raw})
        if max_docs and len(kept) >= max_docs:
            break

    return kept

def stratified_sample(indices, seq_len, num_samples=100, bins=10):
    """
    Sample up to `num_samples` positions from `indices`, spread across the sequence.

    The range [0, seq_len) is cut into `bins` equal-width bins and up to
    num_samples // bins positions are drawn without replacement from each bin.
    If that yields fewer than `num_samples`, the shortfall is filled with
    positions drawn without replacement from the not-yet-chosen indices; when
    fewer remain than are needed, all of them are taken.

    Sampling uses NumPy's global RNG (np.random.choice), so seed it with
    np.random.seed for reproducible output.

    Args:
        indices: candidate positions (sequence of ints)
        seq_len: length of the sequence the positions refer to
        num_samples: target number of positions to return
        bins: number of equal-width bins over [0, seq_len)

    Returns:
        Sorted list of sampled positions, containing at most `num_samples`
        entries when `indices` has no duplicates.
    """
    indices = np.array(indices)
    sampled_indices = []
    bin_edges = np.linspace(0, seq_len, bins + 1, dtype=int)
    samples_per_bin = num_samples // bins

    for start, end in zip(bin_edges[:-1], bin_edges[1:]):
        bin_members = indices[(indices >= start) & (indices < end)]
        if len(bin_members) > 0:
            num = min(samples_per_bin, len(bin_members))
            sampled = np.random.choice(bin_members, size=num, replace=False)
            sampled_indices.extend(sampled.tolist())

    extra_needed = num_samples - len(sampled_indices)
    if extra_needed > 0:
        extra_pool = np.setdiff1d(indices, sampled_indices)
        # Take what is available instead of giving up when the pool is too
        # small to cover the whole shortfall.
        take = min(extra_needed, len(extra_pool))
        if take > 0:
            extra = np.random.choice(extra_pool, size=take, replace=False)
            sampled_indices.extend(extra.tolist())

    return sorted(sampled_indices)



def _save_summary_pie(confident_pct, processed, summary_path):
    """Draw the confident-correct vs. other pie chart and save it to summary_path."""
    other_pct = 1 - confident_pct

    plt.figure(figsize=(6, 6))
    plt.pie(
        [confident_pct, other_pct],
        labels=[f"Confident-Correct ({confident_pct*100:.1f}%)", f"Other ({other_pct*100:.1f}%)"],
        colors=["green", "lightgray"],
        autopct="%1.1f%%",
        startangle=90,
        counterclock=False
    )
    plt.title(f"Token Classification After {processed} Stories")
    plt.tight_layout()
    plt.savefig(summary_path)
    plt.close()


def _save_ctx_len_histogram(conf_lens, processed, hist_path, bin_width=16):
    """Draw a log-scale histogram of confident context lengths and save it to hist_path."""
    bins = list(range(0, max(conf_lens) + bin_width, bin_width))

    plt.figure(figsize=(8, 5))
    plt.hist(conf_lens, bins=bins, alpha=0.8, label="Confident Context Length", color="green", edgecolor="black")
    plt.xlabel("Context Length (Tokens)")
    plt.ylabel("Token Count (log scale)")
    plt.yscale("log")
    plt.title(f"Confident Tokens: Context Length Distribution After {processed} Stories")
    plt.legend()
    plt.tight_layout()
    plt.savefig(hist_path)
    plt.close()


def main(args):
    """
    Measure, for confidently predicted tokens, how much context the model needs.

    For each of the first `args.num_stories` documents with at least
    `args.min_tokens` tokens (truncated to exactly that many tokens):
      1. find the positions predicted correctly with a top-1/top-2 gap above
         `args.conf_thresh` using the full prefix,
      2. sample up to `args.num_samples_per_doc` of those positions,
      3. search for the smallest tested context length that keeps each sampled
         position confident-correct,
      4. rewrite the JSON results and the two progress plots in
         `args.output_dir`.

    Args:
        args: namespace with model_name, dataset, cache_dir, output_dir,
            num_stories, min_tokens, conf_thresh and num_samples_per_doc.

    Raises:
        ValueError: if args.model_name or args.dataset is not supported.
    """
    # Seed NumPy as well as torch: stratified_sample draws with np.random.choice.
    torch.manual_seed(42)
    np.random.seed(42)

    # Smallest context window tested per token, and the step between tested
    # windows (also used as the histogram bin width).
    min_ctx_len = 32
    ctx_step = 16

    model_ids = {
        # "llama": "meta-llama/Llama-2-7b-hf" was the earlier non-chat choice
        "llama": "meta-llama/Meta-Llama-3-8B",
        "mistral": "mistralai/Mistral-7B-Instruct-v0.2",
        "qwen": "Qwen/Qwen2-7B",  # base model, not Instruct
    }
    model_key = args.model_name.lower()
    if model_key not in model_ids:
        raise ValueError("Model name must be one of 'llama', 'mistral' or 'qwen'")
    model_id = model_ids[model_key]

    # The loop below needs both objects; without them it fails with NameError.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        use_cache=True
    ).cuda().eval()

    if args.dataset == "reddit":
        dataset = load_dataset("reddit_tifu", "long", split="train", cache_dir=args.cache_dir)
        text_key = "documents"
    elif args.dataset == "cnn":
        dataset = load_dataset("cnn_dailymail", "3.0.0", split="train", cache_dir=args.cache_dir)
        text_key = "article"
    elif args.dataset == "gov":
        dataset = load_dataset("ccdv/govreport-summarization", split="train", cache_dir=args.cache_dir)
        text_key = "report"
    elif args.dataset == "wiki":
        dataset = load_dataset(
            "wikipedia", "20220301.en",
            split="train",
            cache_dir=args.cache_dir,
            trust_remote_code=True
        )
        text_key = "text"
    else:
        raise ValueError("Unsupported dataset specified.")

    os.makedirs(args.output_dir, exist_ok=True)
    partial_json_path = os.path.join(args.output_dir, "context_lengths_partial.json")
    json_path = os.path.join(args.output_dir, "context_lengths.json")
    summary_path = os.path.join(args.output_dir, "summary_classification_progress.png")
    hist_path = os.path.join(args.output_dir, "confident_token_ctx_lengths.png")

    total_correct_conf = 0
    total_tokens = 0
    processed = 0
    all_ctx_lens = {}
    conf_lens = []

    for ex in tqdm(dataset, desc="Scanning stories"):
        if processed >= args.num_stories:
            break

        text = ex[text_key]
        tokens = tokenizer(text, return_tensors="pt", truncation=False)["input_ids"]

        if tokens.shape[1] < args.min_tokens:
            continue

        print("+" * 40)
        print("Analyzing story " + str(processed + 1))

        # Every analysed story is truncated to exactly min_tokens tokens.
        input_ids = tokens[0][:args.min_tokens].cuda()

        correct_conf, input_ids = analyze_prediction_confidence_fast(
            model, input_ids, args.conf_thresh
        )
        print("len(all_candidates): " + str(len(correct_conf)))

        # Positions with less than min_ctx_len tokens of prefix cannot be tested.
        valid_targets = [idx for idx in correct_conf if idx >= min_ctx_len]

        if len(valid_targets) > args.num_samples_per_doc:
            target_indices = stratified_sample(
                valid_targets,
                seq_len=len(input_ids),
                num_samples=args.num_samples_per_doc,
            )
        else:
            target_indices = valid_targets

        ctx_lens = find_tokenwise_true_context_lengths(
            model=model,
            tokenizer=tokenizer,
            input_ids=input_ids,
            target_indices=target_indices,
            threshold=args.conf_thresh,
            step_size=ctx_step,
            min_len=min_ctx_len,
            use_masking=True
        )

        # Collect lengths for the confident-token histogram
        conf_lens.extend(
            entry["conf_ctx_len"]
            for entry in ctx_lens.values()
            if entry["conf_ctx_len"] is not None
        )

        all_ctx_lens[f"story_{processed}"] = ctx_lens

        # Save immediately so results survive a failure in the plotting below.
        with open(partial_json_path, "w") as f:
            json.dump(all_ctx_lens, f, indent=2)

        total_correct_conf += len(correct_conf)
        total_tokens += input_ids.shape[0] - 1  # actual # of prediction targets
        processed += 1

        confident_pct = total_correct_conf / total_tokens if total_tokens > 0 else 0
        _save_summary_pie(confident_pct, processed, summary_path)

        with open(json_path, "w") as f:
            json.dump(all_ctx_lens, f, indent=2)
        print(f"📝 Saved context length results to {json_path}")

        if conf_lens:
            _save_ctx_len_histogram(conf_lens, processed, hist_path, bin_width=ctx_step)

        print(f"\n📊 Updated summary plot at {summary_path}")

        print(f"🧮 Totals — Correct-Conf: {total_correct_conf}, Tokens: {total_tokens}")

    print("All Done !")



if __name__ == "__main__":
    # Command-line entry point: parse the options, derive a per-run output
    # directory from dataset / model / sample count, create it, and run main().
    cli = argparse.ArgumentParser()

    cli.add_argument(
        "--model_name",
        type=str,
        default="llama",
        choices=["llama", "mistral","qwen"],
    )
    cli.add_argument("--num_stories", type=int, default=1)
    cli.add_argument("--min_tokens", type=int, default=1024)
    cli.add_argument("--conf_thresh", type=float, default=0.2)
    cli.add_argument("--output_dir", type=str, default="/output_path")
    cli.add_argument("--cache_dir", type=str, default="/cache_path")
    cli.add_argument(
        "--dataset",
        type=str,
        choices=["reddit", "cnn", "gov", "wiki"],
        default="reddit",
    )
    cli.add_argument(
        "--num_samples_per_doc",
        type=int,
        default=100,
        help="Number of token positions to sample per document",
    )
    args = cli.parse_args()

    # Results for each (dataset, model, sample count) combination get their own
    # sub-directory below the requested output directory.
    run_tag = f"{args.dataset}_{args.model_name}_samples{args.num_samples_per_doc}"
    args.output_dir = os.path.join(args.output_dir, run_tag)
    os.makedirs(args.output_dir, exist_ok=True)

    main(args)
