import json
import math
from itertools import chain
import evaluate
import pandas as pd
import torch
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, AutoTokenizer, AutoModelForCausalLM, AutoModelWithLMHead

from KnowledgeSynapticNetwork.NeuroSynapticEditing import NeuroSynapticEdit
from KnowledgeSynapticNetwork.patch_mlp import get_ff_layer
from KnowledgeSynapticNetwork.utils import read_lama_json
import lmppl

def calculate_ppl(sentence, model, tokenizer, device='cuda'):
    """Return the perplexity of ``model`` on a single ``sentence``.

    The sentence is turned into ids via ``tokenizer.tokenize`` followed by
    ``tokenizer.convert_tokens_to_ids``. The resulting single-row id tensor is
    fed to the model as both input and labels. The first element of the model
    output (its loss) is exponentiated.

    Args:
        sentence: Text to score.
        model: Language model called as ``model(ids, labels=ids)``.
        tokenizer: Tokenizer providing ``tokenize`` / ``convert_tokens_to_ids``.
        device: Device the id tensor is moved to before the forward pass.

    Returns:
        The perplexity as a Python float.
    """
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentence))
    batch = torch.tensor([token_ids]).to(device)
    with torch.no_grad():
        outputs = model(batch, labels=batch)
    mean_nll = outputs[0]
    return torch.exp(mean_nll).item()


def load_evaldata_from_json(eval_data_path):
    """Load every sentence from a JSON file into a one-column ``Dataset``.

    The JSON top level must be an object whose values each contain a
    ``'sentences'`` list. All of those lists are flattened, in order, into one
    ``'text'`` column.

    Args:
        eval_data_path: Path to the JSON file.

    Returns:
        A ``datasets.Dataset`` with a single ``'text'`` column.
    """
    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
    with open(eval_data_path, 'r', encoding='utf-8') as file:
        data_json = json.load(file)

    sentences = [sentence for item in data_json.values() for sentence in item['sentences']]
    return Dataset.from_dict({'text': sentences})


def load_evaldata(eval_data_path, tokenizer, max_seq_length=128, batch_size=8):
    """Build a causal-LM evaluation DataLoader from a sentences JSON file.

    The pipeline is:

    1. Read sentences with ``load_evaldata_from_json``.
    2. Tokenize with ``padding="max_length"`` and ``truncation=True`` so every
       row is exactly ``max_seq_length`` tokens.
    3. Concatenate each map batch and re-cut it into ``max_seq_length`` chunks.
    4. Batch with ``DataCollatorForLanguageModeling(mlm=False)``.

    NOTE: ``padding="max_length"`` requires ``tokenizer`` to define a pad token
    (GPT-2's tokenizer has none by default); set one before calling.

    Args:
        eval_data_path: Path to the JSON file read by ``load_evaldata_from_json``.
        tokenizer: Hugging Face tokenizer used for tokenization and collation.
        max_seq_length: Length of each tokenized row and of each chunk.
        batch_size: DataLoader batch size (defaults to the previously
            hard-coded 8).

    Returns:
        A ``torch.utils.data.DataLoader`` yielding collated batches.
    """
    def tokenize_function(examples):
        # Pad or truncate every sentence to exactly max_seq_length tokens.
        # (datasets stores lists, so asking for "pt" tensors here is pointless.)
        return tokenizer(examples["text"], padding="max_length", truncation=True,
                         max_length=max_seq_length)

    eval_dataset = load_evaldata_from_json(eval_data_path)
    tokenized_datasets = eval_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
        load_from_cache_file=False,
        desc="Running tokenizer on every text in dataset",
    )

    def group_texts(examples):
        # Chunk EVERY tokenized column (input_ids, attention_mask, ...) with the
        # same boundaries. Grouping only input_ids would leave the other columns
        # at their original row count, which only matches by coincidence when
        # each row is already exactly max_seq_length long.
        concatenated = {key: list(chain.from_iterable(examples[key])) for key in examples.keys()}
        total_length = len(concatenated["input_ids"])
        if total_length >= max_seq_length:
            # Drop the trailing remainder shorter than one chunk; a batch shorter
            # than a single chunk is kept as one short chunk.
            total_length = (total_length // max_seq_length) * max_seq_length
        return {
            key: [values[i:i + max_seq_length] for i in range(0, total_length, max_seq_length)]
            for key, values in concatenated.items()
        }

    tokenized_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        load_from_cache_file=False,
        desc=f"Grouping texts in chunks of {max_seq_length}",
    )

    # mlm=False: causal-LM collation (labels derived from input_ids).
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    return DataLoader(tokenized_datasets, collate_fn=data_collator, batch_size=batch_size)

def eval_ppl(eval_dataloader, device, model):
    """Compute perplexity of ``model`` over every batch of ``eval_dataloader``.

    Each batch is a mapping of tensors, moved to ``device`` and passed to the
    model as keyword arguments; the model output must expose ``.loss``. Each
    batch's loss is weighted by that batch's row count (``input_ids.size(0)``).
    As a result, a short final batch does not count as much as a full one.

    Args:
        eval_dataloader: Iterable of tensor mappings containing ``input_ids``.
        device: Device (or device string) the batch tensors are moved to.
        model: Model returning an object with a scalar ``loss`` attribute.

    Returns:
        ``exp(mean loss)`` as a float, or ``inf`` if that overflows.

    Raises:
        ValueError: If ``eval_dataloader`` yields no examples.
    """
    total_nll = 0.0
    total_examples = 0
    for batch in eval_dataloader:
        # Rebuild on the device: works for plain dicts as well as BatchEncoding.
        batch = {key: value.to(device) for key, value in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        n_examples = batch["input_ids"].size(0)
        total_nll += outputs.loss.item() * n_examples
        total_examples += n_examples

    if total_examples == 0:
        raise ValueError("eval_dataloader yielded no examples; cannot compute perplexity")

    eval_loss = total_nll / total_examples
    try:
        return math.exp(eval_loss)
    except OverflowError:
        return float("inf")


# Module-wide default device: the first CUDA GPU when one is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

_DEFAULT_WIKI_PATH = '/home/chenyuheng/chenyuheng/NIPS2024/Datasets/EXP3/test-wikitext-2-v1.parquet'


def _load_causal_lm(model_name):
    """Load ``model_name``: plain ``gpt2`` onto ``device``, anything else with ``device_map="auto"``."""
    if model_name == 'gpt2':
        return AutoModelForCausalLM.from_pretrained(model_name).to(device)
    return AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")


def _get_ff_input_layer(model, model_name, layer_idx):
    """Return the first (input-side) MLP projection of transformer layer ``layer_idx``.

    Raises:
        NotImplementedError: If ``model_name`` is neither a GPT-2 nor a Llama name.
    """
    if "gpt2" in model_name:
        return get_ff_layer(model, layer_idx, transformer_layers_attr="transformer.h",
                            ff_attrs="mlp.c_fc")
    if "Llama" in model_name:
        return get_ff_layer(model, layer_idx, transformer_layers_attr="model.layers",
                            ff_attrs="mlp.gate_proj")
    raise NotImplementedError(f"Neuron erasure is not supported for model {model_name!r}")


def _erase_ff_neurons(model, model_name, neuron_for_erase):
    """Zero, in place, the weights feeding each ``(layer_idx, position)`` FF neuron.

    NOTE(review): only ``.weight`` is modified; the layer's bias (if any) is
    left untouched, so an erased neuron's pre-activation becomes its bias.
    """
    # In-place writes into a Parameter that requires grad must happen under no_grad.
    with torch.no_grad():
        for layer_idx, position in neuron_for_erase:
            ff_layer = _get_ff_input_layer(model, model_name, layer_idx)
            weight = ff_layer.weight
            if isinstance(ff_layer, torch.nn.Linear):
                # nn.Linear weight is (out_features, in_features): a neuron is a row.
                weight[position, :] = 0
            else:
                # Non-Linear projections (GPT-2's c_fc, presumably transformers'
                # Conv1D with transposed (in, out) weights) index neurons by column.
                weight[:, position] = 0


def _sliding_window_ppl(model, input_ids, max_length, stride):
    """Token-weighted sliding-window perplexity of ``input_ids`` (a single-row id tensor).

    Raises:
        ValueError: If no window contains a token the model can score.
    """
    seq_len = input_ids.size(1)
    nll_sum = 0.0
    n_scored = 0
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may differ from stride on the last window
        window_ids = input_ids[:, begin_loc:end_loc]
        target_ids = window_ids.clone()
        # Tokens already scored by the previous window only serve as context.
        target_ids[:, :-trg_len] = -100
        # Causal-LM losses are computed against labels shifted by one, so only
        # labels[:, 1:] are scored; weight this window's mean loss by that count.
        n_window = int((target_ids[:, 1:] != -100).sum().item())
        if n_window > 0:  # e.g. a 1-token final window would give a NaN loss
            with torch.no_grad():
                loss = model(window_ids, labels=target_ids).loss
            nll_sum += loss.item() * n_window
            n_scored += n_window
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break
    if n_scored == 0:
        raise ValueError("Not enough tokens to compute perplexity")
    return torch.exp(torch.tensor(nll_sum / n_scored)).item()


def _generate_samples(model, tokenizer, texts, max_new_tokens):
    """Greedy-generate ``max_new_tokens`` continuations for each of ``texts``."""
    samples = []
    for text_input in texts:
        inputs = tokenizer(text_input, return_tensors="pt").to(device)
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=max_new_tokens,
            )
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        samples.append({'original': text_input, 'generated': generated_text})
    return samples


def calculate_ppl_wikitext(model_name, neuron_for_erase, wiki_path=_DEFAULT_WIKI_PATH,
                           num_texts=5, max_length=512, stride=512, max_new_tokens=50):
    """Measure WikiText perplexity (optionally with FF neurons erased) and sample generations.

    Args:
        model_name: Hugging Face model id; ``"gpt2"`` / ``"Llama"`` substrings
            select the FF layer layout used for neuron erasure.
        neuron_for_erase: Iterable of ``(layer_idx, position)`` pairs to zero
            out, or a falsy value to evaluate the unmodified model.
        wiki_path: Parquet file with a ``"text"`` column.
        num_texts: Number of non-empty lines used for both perplexity and generation.
        max_length: Sliding-window length in tokens.
        stride: Step between window starts.
        max_new_tokens: Tokens generated per sample.

    Returns:
        ``{'model_name', 'perplexity', 'generated_texts'}`` where
        ``generated_texts`` is a list of ``{'original', 'generated'}`` dicts.
    """
    model = _load_causal_lm(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if neuron_for_erase:
        _erase_ff_neurons(model, model_name, neuron_for_erase)

    data = pd.read_parquet(wiki_path)
    texts = [s for s in data["text"] if s != ""][:num_texts]
    encodings = tokenizer("\n\n".join(texts), return_tensors="pt").to(device)
    ppl = _sliding_window_ppl(model, encodings.input_ids, max_length, stride)

    generated_texts = _generate_samples(model, tokenizer, texts, max_new_tokens)

    print(f"Model: {model_name}, Perplexity: {ppl}, generated_texts: {generated_texts}")

    return {'model_name': model_name, 'perplexity': ppl, 'generated_texts': generated_texts}




