import torch
import os
import argparse
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from models import Llama2Tokenizer, PruneLlama2ForCausalLM, PruneLlama2DecoderLayer, PruneLlama3ForCausalLM
from pruning import help_functions_hn, collect_info_reg_llama
# from models.llama3 import modeling_llama
from transformers.models.llama import modeling_llama
from datasets import load_dataset
import math
import random
from tqdm import tqdm
import numpy as np
import gc
import time
import pickle
import pathlib

from lib.lrp_utils import get_lrp_masks, create_full_pass_masks, get_lrp_input_relevance
from lib.lrp_utils2 import get_progressive_lrp_masks
from lxt.efficient import monkey_patch

from lib.dataset_loader import calculate_perplexity, sample_wikitext_sequences, build_wikitext_ids, build_c4_ids, build_alpaca_ids, build_redpajama_ids, load_mc_dataset, format_mc_example, evaluate_mc_example, evaluate_mc_dataset, calculate_sequence_log_prob, get_mc_context_for_features, get_prompt_and_answer_position, format_mc_prompt_with_ans

class LRPDataCollector:
    """In-memory buffer of per-sample LRP records that can be pickled to disk.

    Every record is a plain dict holding ``"sample_id"``, ``"lrp"`` and
    ``"activations"``. Records added through ``add_sample_with_lrp`` also
    carry ``"label"`` and, when supplied, ``"original_example"``.
    """

    def __init__(self):
        # Records in insertion order.
        self.samples_data = []

    def add_sample(self, sample_ids, lrp_scores, activations):
        """Append a record without label information."""
        record = dict(sample_id=sample_ids, lrp=lrp_scores, activations=activations)
        self.samples_data.append(record)

    def add_sample_with_lrp(self, sample_ids, label_pos, lrp_scores, activations, original_example=None):
        """Append a record that also stores the label position and, optionally, the source example."""
        record = dict(
            sample_id=sample_ids,
            label=label_pos,
            lrp=lrp_scores,
            activations=activations,
        )
        if original_example is None:
            self.samples_data.append(record)
            return
        record["original_example"] = original_example
        self.samples_data.append(record)

    def save_all(self, save_path="./data/lrp_train_samples.pkl"):
        """Pickle every collected record to ``save_path``, creating parent directories as needed."""
        target = pathlib.Path(save_path)
        target.parent.mkdir(parents=True, exist_ok=True)

        with open(target, "wb") as handle:
            pickle.dump(self.samples_data, handle, protocol=pickle.HIGHEST_PROTOCOL)

        print(f"Saved {len(self.samples_data)} samples to {target}")

    def clear(self):
        """Drop all collected records by rebinding to a fresh list."""
        self.samples_data = []


# ===================== LRP Mask Utilities =====================

def lrp_masks_to_vectors(masks, component_order=("Ind1", "Ind2", "Ind3", "Ind4", "Ind5")):
    """Convert an LRP mask dictionary into the flat list of gate vectors.

    Args:
        masks: Either a list (returned unchanged) or a dict keyed by
            ``(layer_idx, component_name)`` tuples.
        component_order: Component names emitted per layer, in this order.
            Keys whose component name is not listed here are skipped.

    Returns:
        list: Mask values ordered by ascending layer index, then by
        ``component_order``. Missing ``(layer, component)`` pairs are skipped.
    """
    # Already in vector form: pass through untouched.
    if isinstance(masks, list):
        return masks

    layer_indices = sorted({layer_idx for layer_idx, _ in masks})
    return [
        masks[(layer_idx, comp_name)]
        for layer_idx in layer_indices
        for comp_name in component_order
        if (layer_idx, comp_name) in masks
    ]

def apply_lrp_masks(model, masks, param_reg_structures, use_gate=True):
    """Install LRP-derived gate vectors on ``model`` and set its gate status.

    Args:
        model: Model whose gates receive the vectors.
        masks: LRP masks, either a list of vectors or a dict keyed by
            ``(layer_idx, component_name)`` (see ``lrp_masks_to_vectors``).
        param_reg_structures: Structures used to build the ``help_functions_hn`` helper.
        use_gate: Value passed to ``set_gate_status`` after the vectors are installed.

    Returns:
        The ``help_functions_hn`` instance, so callers can toggle the gates later.
    """
    gate_vectors = lrp_masks_to_vectors(masks)

    helper = help_functions_hn(param_reg_structures)
    helper.set_gate_vectors(model, gate_vectors)
    helper.set_gate_status(model, use_gate=use_gate)
    return helper

# ===================== Sample-wise Evaluation =====================

def evaluate_sample_with_lrp(
    eval_model,
    lrp_model,
    tokenizer,
    sample,
    dataset_name=None,
    param_reg_structures=None,
    mask_ratio=0.1,
    output_dir="./output",
    device1="cuda",
    device2="cuda",
    is_wikitext=False,
    seqlen=2048,
    lrp_data_collector=None
):
    """Build LRP gate masks for one sample and evaluate ``eval_model`` with and without them.

    The masks are produced on ``lrp_model`` by ``get_lrp_masks``. The returned LRP
    scores and activations are recorded in ``lrp_data_collector`` when one is given.
    The sample is then evaluated on ``eval_model`` before and after the masks are
    applied via ``apply_lrp_masks``. The gates are always switched off again before
    returning, even if the masked evaluation raises.

    Args:
        eval_model: Model that receives the gate masks and is evaluated.
        lrp_model: Model used for the LRP computation. ``apply_monkey_patch=False``
            is passed, so it is expected to be patched already.
        tokenizer: Used to encode MC samples; also forwarded to ``get_lrp_masks``.
        sample: In WikiText mode, token ids passed directly as ``input_ids``.
            Otherwise, a raw multiple-choice example for ``format_mc_example``.
        dataset_name: Multiple-choice dataset name (not used in WikiText mode).
        param_reg_structures: Gate structures forwarded to ``apply_lrp_masks``.
        mask_ratio: Currently unused; the per-gate ``mask_ratios`` list below is used.
        output_dir: Forwarded to ``get_lrp_masks``.
        device1: Device for evaluation.
        device2: Device for LRP computation.
        is_wikitext: Take the perplexity path instead of the multiple-choice path.
        seqlen: ``limit_length`` for perplexity and ``max_length`` for MC tokenization.
        lrp_data_collector: Optional ``LRPDataCollector``. When ``None``, nothing is recorded.

    Returns:
        dict: WikiText mode has keys ``dataset``, ``original_ppl``, ``masked_ppl`` and
        ``change_percentage``. MC mode has correctness and prediction fields before and
        after masking, plus ``label``.
    """
    # Per-gate mask ratios handed to get_lrp_masks (160 entries).
    # NOTE(review): presumably one entry per (layer, Ind1..Ind5) gate of a 32-layer
    # model — confirm against get_lrp_masks / help_functions_hn.
    mask_ratios = [0.542, 0.159, 0.178, 0.381, 0.118, 0.155, 0.046, 0.159, 0.16 ,0.039, 0.068, 0.036, 0.062, 0.095, 0.026, 0.072, 0.054, 0.031,
        0.117, 0.039, 0.092, 0.052, 0.04 , 0.12 , 0.034, 0.108, 0.045,
        0.06 , 0.142, 0.052, 0.101, 0.045, 0.06 , 0.153, 0.034, 0.119,
        0.05 , 0.076, 0.168, 0.052, 0.131, 0.064, 0.086, 0.176, 0.056,
        0.121, 0.06 , 0.106, 0.186, 0.055, 0.122, 0.061, 0.112, 0.217,
        0.06 , 0.15 , 0.083, 0.104, 0.204, 0.05 , 0.133, 0.072, 0.095,
        0.206, 0.073, 0.154, 0.063, 0.098, 0.194, 0.044, 0.13 , 0.057,
        0.079, 0.197, 0.037, 0.114, 0.048, 0.065, 0.189, 0.037, 0.114,
        0.047, 0.07 , 0.173, 0.028, 0.112, 0.061, 0.053, 0.127, 0.022,
        0.114, 0.07 , 0.043, 0.148, 0.028, 0.112, 0.05 , 0.051, 0.122,
        0.032, 0.143, 0.053, 0.034, 0.139, 0.028, 0.17 , 0.11 , 0.042,
        0.111, 0.021, 0.135, 0.089, 0.041, 0.141, 0.031, 0.145, 0.092,
        0.036, 0.17 , 0.032, 0.204, 0.126, 0.034, 0.165, 0.038, 0.19 ,
        0.123, 0.038, 0.18 , 0.047, 0.204, 0.125, 0.04 , 0.204, 0.065,
        0.167, 0.119, 0.036, 0.216, 0.088, 0.189, 0.126, 0.057, 0.244,
        0.093, 0.235, 0.144, 0.04 , 0.241, 0.119, 0.209, 0.171, 0.038,
        0.227, 0.08 , 0.219, 0.14 , 0.03 , 0.212, 0.008]

    if is_wikitext:
        # The caller in this file passes a slice ``samples[i:i+1]``.
        # Presumably shape [1, seqlen] — verify.
        sample_ids = sample

        # Alternatives tried previously: lib.lrp_utils2.get_progressive_lrp_masks,
        # optionally with input relevance from get_lrp_input_relevance.
        print("Generating LRP masks for WikiText sample...")
        masks, lrp_scores, activations = get_lrp_masks(
            model=lrp_model,
            tokenizer=tokenizer,
            mask_ratios=mask_ratios,
            input_ids=sample_ids,
            output_dir=output_dir,
            device=device2,
            apply_monkey_patch=False,  # Model should already be patched
            iteration=0,
        )

        if lrp_data_collector is not None:
            lrp_data_collector.add_sample(sample_ids[0], lrp_scores, activations)

        print("Evaluating WikiText sample without masks...")
        original_ppl = calculate_perplexity(eval_model, sample_ids, limit_length=seqlen, device=device1)

        print("Applying LRP masks...")
        hn_helper = apply_lrp_masks(eval_model, masks, param_reg_structures, use_gate=True)
        try:
            print("Evaluating WikiText sample with masks...")
            masked_ppl = calculate_perplexity(eval_model, sample_ids, limit_length=seqlen, device=device1)
        finally:
            # Always turn the gates off so a failure cannot leave eval_model masked
            # for subsequent samples.
            hn_helper.set_gate_status(eval_model, use_gate=False)

        change_percentage = (masked_ppl - original_ppl) / original_ppl * 100

        return {
            "dataset": "wikitext",
            "original_ppl": original_ppl,
            "masked_ppl": masked_ppl,
            "change_percentage": change_percentage
        }

    # ---- Multiple-choice sample ----
    formatted_example = format_mc_example(sample, dataset_name)
    formatted_example["dataset_name"] = dataset_name

    question = formatted_example["question"] + " " + "Answer:"
    options = formatted_example["options"]
    label = formatted_example["label"]
    label_context = options[label]

    # Encode (question, correct answer) as a text pair. label_pos is the token length
    # of the question encoded on its own.
    # NOTE(review): this presumably marks where the answer tokens start in
    # full_input_ids, but that depends on the tokenizer's pair template. It is also
    # computed without truncation — verify for very long questions.
    full_input_ids = tokenizer(question, label_context, return_tensors="pt", truncation=True, max_length=seqlen).input_ids
    label_pos = len(tokenizer(question)['input_ids'])

    print(f"Generating LRP masks for {dataset_name} sample...")
    masks, lrp_scores, activations = get_lrp_masks(
        model=lrp_model,
        tokenizer=tokenizer,
        mask_ratios=mask_ratios,
        input_ids=full_input_ids,
        output_dir=output_dir,
        device=device2,
        apply_monkey_patch=False,  # Model should already be patched
        iteration=0,
        label_context=label_pos
    )

    if lrp_data_collector is not None:
        lrp_data_collector.add_sample_with_lrp(full_input_ids[0], label_pos, lrp_scores, activations, original_example=formatted_example)

    print(f"Evaluating {dataset_name} sample without masks...")
    original_result = evaluate_mc_example(eval_model, tokenizer, formatted_example, device=device1)

    print("Applying LRP masks...")
    hn_helper = apply_lrp_masks(eval_model, masks, param_reg_structures, use_gate=True)
    try:
        print(f"Evaluating {dataset_name} sample with masks...")
        masked_result = evaluate_mc_example(eval_model, tokenizer, formatted_example, device=device1)
    finally:
        # Always turn the gates off so a failure cannot leave eval_model masked.
        hn_helper.set_gate_status(eval_model, use_gate=False)

    return {
        "dataset": dataset_name,
        "original_correct": original_result["is_correct"],
        "original_correct_norm": original_result["is_correct_normalized"],
        "masked_correct": masked_result["is_correct"],
        "masked_correct_norm": masked_result["is_correct_normalized"],
        "original_prediction": original_result["prediction"],
        "masked_prediction": masked_result["prediction"],
        "original_prediction_norm": original_result["normalized_prediction"],
        "masked_prediction_norm": masked_result["normalized_prediction"],
        "label": original_result["label"]
    }

def _load_eval_samples(tokenizer, dataset_name, is_wikitext, num_samples, seqlen):
    """Load the evaluation samples for ``dataset_name`` (helper for test_dataset_with_lrp)."""
    if is_wikitext:
        # NOTE: despite the "wikitext" mode name, this currently reads C4 train data.
        # Alternatives: build_wikitext_ids(tokenizer, split="test"), build_alpaca_ids,
        # build_redpajama_ids.
        input_ids = build_c4_ids(tokenizer, split="train")
        # Commented upstream as shape [num_samples, seqlen]; n=None presumably takes all.
        samples = sample_wikitext_sequences(input_ids,
                                            seqlen=seqlen,
                                            n=num_samples,
                                            random_sample=False)
        print(samples.shape)
        return samples

    dataset = load_mc_dataset(dataset_name, split="train")
    if num_samples is None or num_samples > len(dataset):
        print(f"Using all {len(dataset)} examples from {dataset_name}")
        return dataset
    print(f"Using {num_samples} examples from {dataset_name}")
    return dataset.shuffle(seed=58).select(range(num_samples))


def _summarize_wikitext(all_results, total_nll, total_nll_origin, total_tokens):
    """Print and return the aggregated perplexity statistics (helper for test_dataset_with_lrp)."""
    # torch.as_tensor keeps tensors as-is and also allows plain Python numbers.
    avg_original_ppl = torch.exp(torch.as_tensor(total_nll_origin / total_tokens))
    avg_masked_ppl = torch.exp(torch.as_tensor(total_nll / total_tokens))

    print("\n===== WikiText LRP Mask Evaluation Summary =====")
    print(f"Number of samples: {len(all_results)}")
    print(f"Average original PPL: {avg_original_ppl:.2f}")
    print(f"Average masked PPL: {avg_masked_ppl:.2f}")

    return {
        "avg_original_ppl": avg_original_ppl,
        "avg_masked_ppl": avg_masked_ppl,
        "num_samples": len(all_results)
    }


def _summarize_mc(dataset_name, all_results):
    """Print and return multiple-choice accuracy statistics (helper for test_dataset_with_lrp)."""
    n = len(all_results)
    original_correct = sum(1 for r in all_results if r["original_correct"])
    masked_correct = sum(1 for r in all_results if r["masked_correct"])
    original_correct_norm = sum(1 for r in all_results if r["original_correct_norm"])
    masked_correct_norm = sum(1 for r in all_results if r["masked_correct_norm"])

    original_acc = original_correct / n
    masked_acc = masked_correct / n
    original_acc_norm = original_correct_norm / n
    masked_acc_norm = masked_correct_norm / n

    # Relative change in percent. Reported as 0 when the baseline accuracy is 0.
    acc_change = ((masked_acc - original_acc) / original_acc * 100) if original_acc > 0 else 0
    acc_norm_change = ((masked_acc_norm - original_acc_norm) / original_acc_norm * 100) if original_acc_norm > 0 else 0

    print(f"\n===== {dataset_name} LRP Mask Evaluation Summary =====")
    print(f"Number of samples: {n}")
    print(f"Original accuracy: {original_acc:.4f} ({original_correct}/{n})")
    print(f"Masked accuracy: {masked_acc:.4f} ({masked_correct}/{n})")
    print(f"Accuracy change: {acc_change:.2f}%")
    print(f"Original normalized accuracy: {original_acc_norm:.4f} ({original_correct_norm}/{n})")
    print(f"Masked normalized accuracy: {masked_acc_norm:.4f} ({masked_correct_norm}/{n})")
    print(f"Normalized accuracy change: {acc_norm_change:.2f}%")

    return {
        "original_acc": original_acc,
        "masked_acc": masked_acc,
        "acc_change": acc_change,
        "original_acc_norm": original_acc_norm,
        "masked_acc_norm": masked_acc_norm,
        "acc_norm_change": acc_norm_change,
        "num_samples": n
    }


def test_dataset_with_lrp(
    eval_model,
    lrp_model,
    tokenizer,
    param_reg_structures,
    dataset_name,
    device1="cuda",
    device2="cuda",
    mask_ratio=0.1,
    output_dir="./output",
    num_samples=None,
    seqlen=2048,
    lrp_path=None
):
    """Evaluate every sample of one dataset with its own LRP masks and aggregate the results.

    Each sample goes through ``evaluate_sample_with_lrp``. The collected LRP data is
    pickled to ``lrp_path``, and summary statistics are printed and returned.

    Args:
        eval_model, lrp_model, tokenizer, param_reg_structures, device1, device2,
        mask_ratio, output_dir, seqlen: Forwarded to ``evaluate_sample_with_lrp``.
        dataset_name: ``"wikitext"`` (case-insensitive) selects the perplexity path.
            Any other name is loaded with ``load_mc_dataset``.
        num_samples: Number of samples to use; ``None`` uses all.
        lrp_path: Destination for the pickled LRP records (required).

    Returns:
        dict: Perplexity statistics (``avg_original_ppl``, ``avg_masked_ppl``,
        ``num_samples``) for wikitext, otherwise accuracy statistics.

    Raises:
        ValueError: If ``lrp_path`` is missing, or if no sample was evaluated.
    """
    # Fail fast: the LRP data is always saved, so a missing path would otherwise only
    # surface after every (expensive) sample had been processed.
    if lrp_path is None:
        raise ValueError("lrp_path is not provided")

    os.makedirs(output_dir, exist_ok=True)

    is_wikitext = (dataset_name.lower() == "wikitext")
    samples = _load_eval_samples(tokenizer, dataset_name, is_wikitext, num_samples, seqlen)

    all_results = []
    lrp_data_collector = LRPDataCollector()
    eval_kwargs = dict(
        eval_model=eval_model,
        lrp_model=lrp_model,
        tokenizer=tokenizer,
        param_reg_structures=param_reg_structures,
        mask_ratio=mask_ratio,
        output_dir=output_dir,
        device1=device1,
        device2=device2,
        is_wikitext=is_wikitext,
        seqlen=seqlen,
        lrp_data_collector=lrp_data_collector
    )

    total_nll, total_nll_origin, total_tokens = 0, 0, 0
    if is_wikitext:
        bs = 1
        # Step by bs so consecutive batches do not overlap.
        for i in tqdm(range(0, len(samples), bs), desc=f"Processing {dataset_name}"):
            batch = samples[i:i + bs]
            result = evaluate_sample_with_lrp(sample=batch, dataset_name=None, **eval_kwargs)

            # The per-sample values are summed and exponentiated after division by the
            # token count. NOTE(review): so calculate_perplexity presumably returns a
            # summed NLL despite the "*_ppl" key names — verify.
            total_nll += result["masked_ppl"]
            total_nll_origin += result["original_ppl"]
            total_tokens += (batch.size(0) * (batch.size(1) - 1))

            result["sample_id"] = i + 1
            all_results.append(result)
    else:
        for i, sample in enumerate(tqdm(samples, desc=f"Processing {dataset_name}")):
            result = evaluate_sample_with_lrp(sample=sample, dataset_name=dataset_name, **eval_kwargs)
            result["sample_id"] = i + 1
            all_results.append(result)

    lrp_data_collector.save_all(save_path=lrp_path)

    if not all_results:
        raise ValueError(f"No samples were evaluated for {dataset_name}")

    if is_wikitext:
        return _summarize_wikitext(all_results, total_nll, total_nll_origin, total_tokens)
    return _summarize_mc(dataset_name, all_results)

def test_all_datasets(
    model_path,
    device1="cuda",
    device2="cuda",
    mask_ratio=0.1,
    output_dir="./output",
    apply_monkey_patch=True,
    num_samples=None,
    seqlen=2048,
    lrp_path=None,
    datasets=None
):
    """Load the evaluation and LRP models once, then run ``test_dataset_with_lrp`` per dataset.

    Args:
        model_path: Checkpoint loaded twice. First as ``PruneLlama2ForCausalLM``
            (float16, ``device_map=device1``) for evaluation. Then as the Transformers
            ``LlamaForCausalLM`` (bfloat16, ``device_map=device2``) for LRP.
        device1: Device for evaluation.
        device2: Device for LRP computation.
        mask_ratio: Passed to ``collect_info_reg_llama`` and forwarded per dataset.
        output_dir: Created if missing and forwarded per dataset.
        apply_monkey_patch: Apply ``lxt`` ``monkey_patch`` to ``modeling_llama`` before
            loading the LRP model.
        num_samples, seqlen, lrp_path: Forwarded to ``test_dataset_with_lrp``.
            NOTE(review): the same ``lrp_path`` is used for every dataset, so with
            several datasets each run overwrites the previous file.
        datasets: A single dataset name (e.g. ``"arc-c"`` or ``"wikitext"``) or an
            iterable of names.

    Raises:
        ValueError: If ``datasets`` is ``None``. This is checked before any model is loaded.
    """
    # Example names: "wikitext", "winogrande", "arc-e", "arc-c", "Rowan/hellaswag", "ybisk/piqa"
    if datasets is None:
        raise ValueError("datasets must be provided, e.g. 'arc-c' or ['wikitext', 'arc-c']")
    # Wrap a single name. The old unconditional [datasets] nested lists and turned None into [None].
    datasets = [datasets] if isinstance(datasets, str) else list(datasets)

    os.makedirs(output_dir, exist_ok=True)

    # Set random seed for reproducibility
    random.seed(42)
    torch.manual_seed(42)
    np.random.seed(42)

    # Evaluation model and tokenizer are loaded once and shared across datasets.
    print(f"Loading evaluation model: {model_path}")
    eval_model = PruneLlama2ForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map=device1
    )
    eval_model.config.use_cache = False
    eval_model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # The LRP model must be built after patching modeling_llama so it uses the patched classes.
    print(f"Loading LRP computation model: {model_path}")
    if apply_monkey_patch:
        monkey_patch(modeling_llama, verbose=True)

    lrp_model = modeling_llama.LlamaForCausalLM.from_pretrained(
        model_path,
        device_map=device2,
        torch_dtype=torch.bfloat16
    )
    if apply_monkey_patch:
        lrp_model._monkey_patched = True
    lrp_model.eval()

    # Gate/regularization structure information, computed once from the eval model.
    param_reg = collect_info_reg_llama(eval_model, p=mask_ratio)

    summary_results = {}

    for dataset_name in datasets:
        print(f"\n\n===== Testing {dataset_name} dataset =====\n")

        result = test_dataset_with_lrp(
            eval_model=eval_model,
            lrp_model=lrp_model,
            tokenizer=tokenizer,
            param_reg_structures=param_reg.structures,
            dataset_name=dataset_name,
            device1=device1,
            device2=device2,
            mask_ratio=mask_ratio,
            output_dir=output_dir,
            num_samples=num_samples,
            seqlen=seqlen,
            lrp_path=lrp_path
        )

        summary_results[dataset_name] = result

        print(f"\n===== Completed testing {dataset_name} dataset =====\n")

    print("\n\n===== OVERALL EVALUATION SUMMARY =====\n")

    for dataset_name, result in summary_results.items():
        print(f"\n{dataset_name} Summary:")

        if dataset_name.lower() == "wikitext":
            print(f"  Average original PPL: {result['avg_original_ppl']:.2f}")
            print(f"  Average masked PPL: {result['avg_masked_ppl']:.2f}")
        else:
            print(f"  Original accuracy: {result['original_acc']:.4f}")
            print(f"  Masked accuracy: {result['masked_acc']:.4f}")
            print(f"  Accuracy change: {result['acc_change']:.2f}%")
            print(f"  Original normalized accuracy: {result['original_acc_norm']:.4f}")
            print(f"  Masked normalized accuracy: {result['masked_acc_norm']:.4f}")
            print(f"  Normalized accuracy change: {result['acc_norm_change']:.2f}%")

        print(f"  Samples evaluated: {result['num_samples']}")

    print("\n===== Evaluation complete =====")

if __name__ == "__main__":
    # Command-line entry point: parse options and run the evaluation once.
    cli = argparse.ArgumentParser(description="Evaluate model with LRP masks on multiple datasets")
    cli.add_argument(
        "--model_path",
        type=str,
        default="xxx/llms/meta/Llama-2-7B-hf",  # alternative: xxx/llms/meta/Llama-3.1-8B
        help="Path to the model",
    )
    cli.add_argument("--device1", type=str, default="cuda:0", help="Device for evaluation")
    cli.add_argument("--device2", type=str, default="cuda:2", help="Device for LRP computation")
    cli.add_argument("--output_dir", type=str, default="./output", help="Directory for outputs")
    cli.add_argument("--lrp_path", type=str, default="xxx/project/DISP/arc-c/lrp_train_ppl.pkl")
    cli.add_argument("--mask_ratio", type=float, default=0.2, help="Mask ratio for LRP")
    cli.add_argument("--seqlen", type=int, default=2048, help="Sequence length for WikiText")
    cli.add_argument(
        "--num_samples",
        # The literal string "none" (any case) maps to None, meaning "use every sample".
        type=lambda value: None if value.lower() == 'none' else int(value),
        default=None,
        help="Number of samples per dataset (None for all)",
    )
    # Known choices: arc-c, arc-e, winogrande, hellaswag, piqa, commonsense_qa, wikitext
    cli.add_argument("--datasets", type=str, default="arc-c", help="Datasets to test")

    opts = cli.parse_args()

    test_all_datasets(
        model_path=opts.model_path,
        device1=opts.device1,
        device2=opts.device2,
        mask_ratio=opts.mask_ratio,
        output_dir=opts.output_dir,
        apply_monkey_patch=True,
        num_samples=opts.num_samples,
        seqlen=opts.seqlen,
        lrp_path=opts.lrp_path,
        datasets=opts.datasets,
    )