import argparse
from importlib.metadata import version
import os 
import random
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm 

from lib.dataset_loader import build_wikitext_ids, build_ptb_ids, sample_wikitext_sequences, calculate_perplexity, evaluate_mc_dataset

def set_seed(seed):
    """Seed every global RNG this script relies on, for reproducible runs.

    Seeds Python's ``random`` module (used by the calibration-data samplers),
    NumPy's global RNG and PyTorch's default generator.

    Args:
        seed: Integer seed applied to all three generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    # torch.random.manual_seed seeds the RNG on all devices (CPU and CUDA).
    torch.random.manual_seed(seed)

class TokenizerWrapper:
    """Minimal stand-in for a tokenizer encoding.

    Exposes a single ``input_ids`` attribute so that an already-tokenized
    tensor can be handed to code that reads ``obj.input_ids`` (for example
    ``eval_ppl_wikitext``).
    """

    def __init__(self, input_ids):
        # Stored by reference: no copy, conversion or validation happens here.
        self.input_ids = input_ids

def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    """Build WikiText-2 calibration windows and the tokenized test split.

    Args:
        nsamples: Number of random training windows to draw.
        seed: Seed applied to Python's ``random`` module before sampling.
        seqlen: Length, in tokens, of every sampled window.
        tokenizer: Hugging Face tokenizer, called with ``return_tensors='pt'``.

    Returns:
        ``(trainloader, testenc)``. ``trainloader`` is a list of
        ``(inp, tar)`` pairs, each ``inp`` being a ``seqlen``-token slice of
        the whole training split joined with single spaces, and ``tar`` a copy
        of ``inp`` with every position except the last set to -100.
        ``testenc`` is the tokenizer output for the test split joined with
        blank lines.
    """
    splits = {
        split_name: load_dataset('wikitext', 'wikitext-2-raw-v1', split=split_name)
        for split_name in ('train', 'test')
    }
    trainenc = tokenizer(" ".join(splits['train']['text']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(splits['test']['text']), return_tensors='pt')

    random.seed(seed)
    train_ids = trainenc.input_ids
    # Inclusive upper bound for a window's random start offset.
    max_start = train_ids.shape[1] - seqlen - 1

    def _draw_window():
        start = random.randint(0, max_start)
        window = train_ids[:, start:start + seqlen]
        target = window.clone()
        # -100 is the usual ignore index for PyTorch/HF losses, so presumably
        # only the final position is meant to be supervised.
        target[:, :-1] = -100
        return window, target

    trainloader = [_draw_window() for _ in range(nsamples)]
    return trainloader, testenc

def get_c4(nsamples, seed, seqlen, tokenizer):
    """Build C4 calibration windows and a fixed-size validation tensor.

    Args:
        nsamples: Number of training windows to draw.
        seed: Seed applied to Python's ``random`` module before sampling.
        seqlen: Length, in tokens, of every sampled window.
        tokenizer: Hugging Face tokenizer, called with ``return_tensors='pt'``.

    Returns:
        ``(trainloader, valenc)``. ``trainloader`` is a list of
        ``(inp, tar)`` pairs. Each ``inp`` is a ``seqlen``-token slice of a
        randomly chosen training document that has more than ``seqlen``
        tokens. ``tar`` is a copy of ``inp`` with every position except the
        last set to -100. ``valenc`` is a ``TokenizerWrapper`` around the first
        ``256 * seqlen`` tokens of the first 1100 validation documents joined
        with single spaces.
    """
    # Only one shard of each split is loaded. The shards are selected through
    # `data_files` alone: the legacy config name 'allenai--c4' is not a
    # BuilderConfig of the current `allenai/c4` Hub repo, and recent `datasets`
    # releases reject it before any data is loaded.
    traindata = load_dataset(
        'allenai/c4',
        data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
        split='train',
    )
    valdata = load_dataset(
        'allenai/c4',
        data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
        split='validation',
    )

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Rejection-sample documents until one is strictly longer than seqlen.
        # NOTE(review): this never terminates if no such document exists.
        while True:
            doc_idx = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[doc_idx]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        start = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        inp = trainenc.input_ids[:, start:start + seqlen]
        tar = inp.clone()
        # -100 is the usual ignore index for PyTorch/HF losses, so presumably
        # only the final position is meant to be supervised.
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    # Validation: concatenate the first 1100 documents, then truncate.
    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
    valenc = valenc.input_ids[:, :(256 * seqlen)]
    return trainloader, TokenizerWrapper(valenc)

def get_loaders(name, nsamples=128, seed=0, seqlen=4096, tokenizer=None):
    """Dispatch to the dataset loader selected by ``name``.

    ``name`` is matched by substring. A name containing ``'wikitext2'`` uses
    ``get_wikitext2``. Otherwise, a name containing ``'c4'`` uses ``get_c4``.

    Args:
        name: Dataset identifier.
        nsamples: Number of calibration windows to draw.
        seed: Seed for the sampler.
        seqlen: Window length in tokens.
        tokenizer: Tokenizer forwarded to the loader.

    Returns:
        Whatever the selected loader returns: a ``(trainloader, eval_data)`` pair.

    Raises:
        ValueError: if ``name`` matches no supported dataset. Returning
            ``None`` here would only surface later as a confusing unpacking
            error at the call site.
    """
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, tokenizer)
    if "c4" in name:
        return get_c4(nsamples, seed, seqlen, tokenizer)
    raise ValueError(
        f"Unsupported dataset name {name!r}; expected a name containing 'wikitext2' or 'c4'"
    )

def eval_ppl(args, model, tokenizer, device=torch.device("cuda:0")):
    """Return the WikiText-2 test-split perplexity of ``model``.

    Args:
        args: Not read here; kept for call-site compatibility.
        model: Causal LM carrying a ``seqlen`` attribute (set by ``get_llm``).
        tokenizer: Tokenizer forwarded to ``get_loaders``.
        device: Device the evaluation windows are moved to.

    Returns:
        The perplexity as a Python float, from ``eval_ppl_wikitext``.
    """
    dataset = "wikitext2"
    print(f"evaluating on {dataset}")

    # Only the test split is used; the sampled calibration windows are dropped.
    _calibration, testloader = get_loaders(
        dataset,
        seed=0,
        seqlen=model.seqlen,
        tokenizer=tokenizer,
    )

    with torch.no_grad():
        return eval_ppl_wikitext(model, testloader, 1, device)

def eval_ppl_wikitext_train(model, trainloader, bs=1, device=None):
    """Compute perplexity over a list of pre-sampled calibration windows.

    Args:
        model: Causal LM exposing ``model.seqlen`` and called as
            ``model(inputs).logits``. Every window must hold exactly ``seqlen``
            tokens.
        trainloader: Sequence of ``(inp, tar)`` pairs, as produced by
            ``get_wikitext2`` / ``get_c4``. Only ``inp`` is used.
        bs: Number of windows evaluated per forward pass.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: if ``trainloader`` is empty.
    """
    nsamples = len(trainloader)
    print(f"nsamples {nsamples}")
    if nsamples == 0:
        raise ValueError("trainloader is empty; cannot compute perplexity")

    loss_fct = nn.CrossEntropyLoss()
    nlls = []
    # Without no_grad, every NLL stored in `nlls` would keep its whole autograd
    # graph alive, so memory would grow with the number of batches.
    with torch.no_grad():
        for i in tqdm(range(0, nsamples, bs), desc="Evaluating train set"):
            j = min(i + bs, nsamples)
            # Gather every window of this batch. Taking only trainloader[i]
            # makes the reshape below fail whenever bs > 1.
            inputs = torch.cat([trainloader[k][0] for k in range(i, j)], dim=0).to(device)
            inputs = inputs.reshape(j - i, model.seqlen)

            lm_logits = model(inputs).logits

            # Position t predicts token t+1.
            shift_logits = lm_logits[:, :-1, :].contiguous()
            shift_labels = inputs[:, 1:]
            loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

            # `loss` is a mean over seqlen-1 predictions but is scaled by seqlen,
            # matching eval_ppl_wikitext so the two metrics stay comparable.
            nlls.append(loss.float() * model.seqlen * (j - i))

    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    torch.cuda.empty_cache()
    return ppl.item()

def eval_ppl_wikitext(model, testenc, bs=1, device=None):
    """Compute perplexity over consecutive ``model.seqlen``-token windows.

    Args:
        model: Causal LM exposing ``model.seqlen`` and called as
            ``model(inputs).logits``.
        testenc: Object with an ``input_ids`` tensor (a tokenizer output or a
            ``TokenizerWrapper``). It is sliced as ``[:, start:end]`` while
            windows are counted from ``numel()``, so a single row of shape
            ``(1, N)`` is assumed.
        bs: Number of windows per forward pass.
        device: Device the windows are moved to.

    Returns:
        The perplexity as a Python float. Trailing tokens that do not fill a
        complete window are ignored.

    Raises:
        ValueError: if fewer than ``model.seqlen`` tokens are available.
    """
    testenc = testenc.input_ids
    nsamples = testenc.numel() // model.seqlen
    print(f"nsamples {nsamples}")
    if nsamples == 0:
        raise ValueError(
            f"need at least model.seqlen={model.seqlen} tokens, got {testenc.numel()}"
        )

    loss_fct = nn.CrossEntropyLoss()
    nlls = []
    # no_grad keeps the stored per-batch NLLs from holding autograd graphs when
    # this is called outside an enclosing torch.no_grad() context.
    with torch.no_grad():
        for i in tqdm(range(0, nsamples, bs), desc="Evaluating test set"):
            j = min(i + bs, nsamples)
            inputs = testenc[:, (i * model.seqlen):(j * model.seqlen)].to(device)
            inputs = inputs.reshape(j - i, model.seqlen)

            lm_logits = model(inputs).logits

            # Position t predicts token t+1.
            shift_logits = lm_logits[:, :-1, :].contiguous()
            shift_labels = inputs[:, 1:]
            loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

            # `loss` is a mean over seqlen-1 predictions per window, but it is
            # scaled by seqlen. This is the convention used across this file
            # and is kept so results stay comparable.
            nlls.append(loss.float() * model.seqlen * (j - i))

    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    torch.cuda.empty_cache()
    return ppl.item()

# Print the versions of the key libraries and the number of visible GPUs at
# import time.
for _pkg in ('torch', 'transformers', 'accelerate'):
    print(_pkg, version(_pkg))
del _pkg
print('# of gpus: ', torch.cuda.device_count())

def get_llm(model_name, seqlen=4096, device=None):
    """Load a causal LM in float16 and attach the evaluation window length.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
        seqlen: Stored as ``model.seqlen``; the evaluation helpers in this file
            read it as the window length.
        device: Forwarded as ``device_map``. Any falsy value selects
            ``"auto"`` placement.

    Returns:
        The loaded model, with an added ``seqlen`` attribute.
    """
    placement = device or "auto"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map=placement,
    )
    model.seqlen = seqlen
    print(f"model.seqlen: {model.seqlen}")
    return model

def default_evaluation(args):
    """Load ``args.model`` and print its WikiText-2 test perplexity.

    Reads ``args.model``, ``args.seqlen`` and ``args.device``. The model is
    loaded without an explicit device (so ``get_llm`` uses its default
    placement), while the evaluation windows are sent to ``args.device``.
    """
    checkpoint = args.model
    print(f"loading llm model {checkpoint}")
    model = get_llm(checkpoint, args.seqlen)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=False)
    device = torch.device(args.device)
    print("use device ", device)

    ppl_test = eval_ppl(args, model, tokenizer, device)
    print(f"wikitext perplexity {ppl_test}")

def sample_evaluation(args, name="wikitext", per_sample_path="xxx/raw_llama_per_sample_ppls.pkl"):
    """Evaluate test-split perplexity of ``args.model`` on WikiText or PTB.

    The tokenized test split is cut into ``args.seqlen``-token sequences by
    ``sample_wikitext_sequences`` (``n=None``, ``random_sample=False``). Each
    sequence is scored with ``calculate_perplexity``. The per-sequence
    perplexities are pickled to ``per_sample_path`` for later visualization.

    Args:
        args: Namespace providing ``model``, ``seqlen`` and ``device``.
        name: ``"wikitext"`` or ``"ptb"``.
        per_sample_path: Destination of the pickled list of per-sequence
            perplexities. Missing parent directories are created.

    Returns:
        The corpus-level perplexity as a Python float.

    Raises:
        ValueError: if ``name`` is unsupported or no sequences are produced.
    """
    import pickle

    builders = {"wikitext": build_wikitext_ids, "ptb": build_ptb_ids}
    # Validate before loading the model, so a typo does not cost a full model load.
    if name not in builders:
        raise ValueError(f"Unsupported dataset {name!r}; expected one of {sorted(builders)}")

    print(f"loading llm model {args.model}")
    model = get_llm(args.model, args.seqlen, args.device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    device = torch.device(args.device)
    print("use device ", device)

    input_ids = builders[name](tokenizer, split="test")
    samples = sample_wikitext_sequences(
        input_ids,
        seqlen=args.seqlen,
        n=None,  # get all inputs
        random_sample=False,
    )
    if samples.size(0) == 0:
        raise ValueError(f"no sequences of length {args.seqlen} were produced for {name!r}")

    bs = 1
    total_nll, total_tokens = 0, 0
    per_sample_ppls = []  # per-sequence perplexities, only for visualization
    for i in tqdm(range(0, samples.size(0), bs), desc=f"Processing {name}"):
        batch = samples[i : i + bs]  # assumed shape [B, seqlen]
        nll = calculate_perplexity(
            model,
            batch,
            limit_length=args.seqlen,
            device=device,
        )
        # Token count assumes calculate_perplexity scores L - 1 shifted
        # next-token predictions per sequence of length L.
        n_tokens = batch.size(0) * (batch.size(1) - 1)
        total_nll += nll
        total_tokens += n_tokens

        per_sample_ppl = torch.exp(nll / n_tokens)
        print(f"per_sample_ppl: {per_sample_ppl}")
        per_sample_ppls.append(per_sample_ppl)

    ppl = torch.exp(total_nll / total_tokens)
    # Report before touching the filesystem, so a bad save path cannot lose the result.
    print(f"{name} perplexity {ppl:.4f}")

    per_sample_ppls_cpu = [p.cpu().item() for p in per_sample_ppls]
    parent_dir = os.path.dirname(per_sample_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(per_sample_path, 'wb') as f:
        pickle.dump(per_sample_ppls_cpu, f)
    print(f"saved per_sample_ppls to {per_sample_path}")

    return ppl.item()

def mc_evaluation(args, datasets):
    """Run multiple-choice benchmarks on ``args.model`` and print a summary table.

    Args:
        args: Namespace providing ``model``, ``seqlen`` and ``device``.
        datasets: Iterable of dataset names passed to ``evaluate_mc_dataset``.

    Returns:
        Dict mapping each dataset name to the result dict returned by
        ``evaluate_mc_dataset``. This function reads the ``acc``,
        ``acc_norm`` and ``num_examples`` keys, plus an optional ``note``.
    """
    model_name = args.model.split("/")[-1]
    print(f"loading llm model {args.model}")
    model = get_llm(args.model, args.seqlen, args.device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    device = torch.device(args.device)
    print("use device ", device)

    all_results = {}
    for dataset_name in datasets:
        print(f"\nEvaluating on {dataset_name}...")
        results = evaluate_mc_dataset(
            model,
            tokenizer,
            dataset_name,
            device=device if torch.cuda.is_available() else "cpu",
            num_examples=None,
            split="test",
        )
        all_results[dataset_name] = results

        print(f"\nResults for {dataset_name}:")
        print(f"  Accuracy (acc): {results['acc']:.4f}")
        print(f"  Normalized Accuracy (acc_norm): {results['acc_norm']:.4f}")
        if "note" in results:
            print(f"  Note: {results['note']}")

    rule = "=" * 60
    sep = "-" * 60
    print("\n" + rule)
    print("SUMMARY OF ALL RESULTS")
    print(rule)
    print(f"Model: {model_name}")
    print(sep)
    print(f"{'Dataset':<20} {'ACC':>10} {'ACC_NORM':>10} {'Samples':>10}")
    print(sep)

    for dataset_name, results in all_results.items():
        print(f"{dataset_name:<20} {results['acc']:>10.4f} {results['acc_norm']:>10.4f} {results['num_examples']:>10}")

    # Unweighted mean across datasets: each benchmark counts equally,
    # regardless of how many examples it has.
    if all_results:
        count = len(all_results)
        mean_acc = sum(r['acc'] for r in all_results.values()) / count
        mean_acc_norm = sum(r['acc_norm'] for r in all_results.values()) / count
        print(sep)
        print(f"{'Average':<20} {mean_acc:>10.4f} {mean_acc_norm:>10.4f} {'':>10}")
    print(rule)

    return all_results


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Example --model paths:
    #   xxx/llms/meta/Llama-3.1-8B
    #   xxx/llms/meta/Llama-2-13b
    #   xxx/llms/phi-2
    parser.add_argument('--model', type=str, default="xxx/llms/phi-2", help='LLaMA model')
    parser.add_argument('--seed', type=int, default=58, help='Seed for sampling the calibration data.')
    parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
    parser.add_argument('--seqlen', type=int, default=2048, help='Sequence length.')
    parser.add_argument('--device', type=str, default="cuda:3", help='Device.')
    args = parser.parse_args()

    # Seed through the shared helper instead of duplicating its body here.
    set_seed(args.seed)

    # Other entry points defined in this file:
    #   default_evaluation(args)             -> WikiText-2 perplexity
    #   sample_evaluation(args, name="ptb")  -> per-sample perplexity on PTB
    # NOTE(review): the dataset names below must match those accepted by
    # lib.dataset_loader.evaluate_mc_dataset.
    mc_evaluation(args, datasets=["boolq", "openbookqa", "piqa", "winogrande", "arc-e", "arc-c", "hellaswag"])