import json
import math
from functools import partial
from typing import Optional, Tuple

import einops
import numpy as np
import pandas as pd
from datasets import Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from EXP3 import calculate_ppl

from KnowledgeSynapticNetwork.BaseNeuroSynapticEditing import BaseNeuroSynapticEdit
from .patch_mlp import *

class NeuroSynapticEdit(BaseNeuroSynapticEdit):
    """Knowledge-editing experiments built on ``BaseNeuroSynapticEdit``.

    The class offers three groups of operations:

    * ``suppress_mlp`` / ``enhance_mlp`` -- delegate to ``_modify_neurons``.
    * ``erase_knowledge`` / ``cal_ppl_wikitext`` -- temporarily zero slices of
      the feed-forward output weights and measure the effect.
    * ``suppress_attn`` / ``enhance_attn`` -- delegate to ``_modify_synapses``.

    ``_modify_neurons``, ``_modify_synapses``, ``_generate``,
    ``_get_output_ff_layer`` and the attributes ``model``, ``tokenizer``,
    ``device``, ``model_name`` and ``mlp_input`` are expected to come from the
    base class (not visible in this file).
    """

    # Substrings of ``self.model_name`` that decide how one neuron is addressed
    # inside the tensor returned by ``_get_output_ff_layer``.
    _ROW_INDEXED_MODEL_TAGS = ('gpt',)
    _COLUMN_INDEXED_MODEL_TAGS = ('Llama', 'llama', 'gemma')

    def __init__(self, model_name_or_path='gpt2', device='cuda'):
        """Forward ``model_name_or_path`` and ``device`` to the base class."""
        super().__init__(model_name_or_path, device)

    def suppress_mlp(
            self,
            query: str,
            ground_truth: str,
            neurons: List[List[int]],
            undo_modification: bool = True,
    ) -> Tuple[dict, Callable]:
        """Query the model with the given neurons suppressed.

        Delegates to ``_modify_neurons`` with ``mode="suppress"`` and
        ``layer_attr=self.mlp_input``. According to the original documentation
        of this method, the activations at the positions in ``neurons`` are
        zeroed and the effect on the probability of ``ground_truth`` is
        measured; the actual mechanics live in ``_modify_neurons``.

        Args:
            query: Prompt passed through as ``query``.
            ground_truth: Expected completion, passed through as ``ground_truth``.
            neurons: Neurons to modify, passed through unchanged.
            undo_modification: Passed through unchanged.

        Returns:
            Whatever ``_modify_neurons`` returns (annotated as a
            ``(dict, Callable)`` pair).
        """
        return self._modify_neurons(
            query=query,
            ground_truth=ground_truth,
            neurons=neurons,
            mode="suppress",
            undo_modification=undo_modification,
            layer_attr=self.mlp_input,
        )

    def enhance_mlp(
            self,
            query: str,
            ground_truth: str,
            neurons: List[List[int]],
            undo_modification: bool = True,
    ) -> Tuple[dict, Callable]:
        """Query the model with the given neurons enhanced.

        Delegates to ``_modify_neurons`` with ``mode="enhance"`` and
        ``layer_attr=self.mlp_input``. According to the original documentation
        of this method, the activations at the positions in ``neurons`` are
        multiplied by 2 and the effect on the probability of ``ground_truth``
        is measured; the actual mechanics live in ``_modify_neurons``.

        Args:
            query: Prompt passed through as ``query``.
            ground_truth: Expected completion, passed through as ``ground_truth``.
            neurons: Neurons to modify, passed through unchanged.
            undo_modification: Passed through unchanged.

        Returns:
            Whatever ``_modify_neurons`` returns (annotated as a
            ``(dict, Callable)`` pair).
        """
        return self._modify_neurons(
            query=query,
            ground_truth=ground_truth,
            neurons=neurons,
            mode="enhance",
            undo_modification=undo_modification,
            layer_attr=self.mlp_input,
        )

    def _output_ff_neuron_index(self, position):
        """Return the index expression that addresses neuron ``position``.

        Model names containing ``'gpt'`` address a neuron as
        ``weights[position, :]``; names containing ``'Llama'``, ``'llama'`` or
        ``'gemma'`` address it as ``weights[:, position]``.

        Raises:
            NotImplementedError: If ``self.model_name`` matches neither family.
        """
        model_name = self.model_name
        if any(tag in model_name for tag in self._ROW_INDEXED_MODEL_TAGS):
            return (position, slice(None))
        # Each tag is tested separately: `'Llama' or 'gemma' in name` would
        # always be truthy and make the error branch below unreachable.
        if any(tag in model_name for tag in self._COLUMN_INDEXED_MODEL_TAGS):
            return (slice(None), position)
        raise NotImplementedError(
            f"Unsupported model for feed-forward neuron editing: {model_name!r}"
        )

    @torch.no_grad()
    def _zero_output_ff_neurons(self, neurons, erase_value=0):
        """Overwrite the weight slice of every neuron with ``erase_value``.

        Args:
            neurons: Iterable of ``(layer_idx, position)`` pairs.
            erase_value: Value written into each addressed slice.

        Returns:
            A list of ``(layer_idx, index, original_slice)`` records, in the
            order the neurons were processed, suitable for
            ``_restore_output_ff_neurons``.
        """
        saved = []
        try:
            for layer_idx, position in neurons:
                weights = self._get_output_ff_layer(layer_idx)
                index = self._output_ff_neuron_index(position)
                saved.append((layer_idx, index, weights[index].detach().clone()))
                weights[index] = erase_value
        except Exception:
            # Do not leave the model half-erased if something fails midway.
            self._restore_output_ff_neurons(saved)
            raise
        return saved

    @torch.no_grad()
    def _restore_output_ff_neurons(self, saved):
        """Write back the slices recorded by ``_zero_output_ff_neurons``.

        Records are replayed in reverse order so that, when the same neuron
        was listed more than once, the earliest (truly original) copy is the
        one written last.
        """
        for layer_idx, index, original in reversed(saved):
            self._get_output_ff_layer(layer_idx)[index] = original

    def _generation_summary(self, query, answer):
        """Run ``_generate`` and keep the three reported fields as a dict."""
        gt_prob, argmax_prob, argmax_completion, _ = self._generate(query, answer)
        return {
            "gt_prob": gt_prob,
            "argmax_completion": argmax_completion,
            "argmax_prob": argmax_prob,
        }

    @staticmethod
    def _relative_change(before, after):
        """Return ``(after - before) / before``, or 0 when ``before`` is 0."""
        return (after - before) / before if before != 0 else 0

    @torch.no_grad()
    def erase_knowledge(self, query, answer, neurons, query_related, query_unrelated, answer_unrelated):
        """Zero the given feed-forward neurons and report the effect.

        The model is probed with three prompts before and after the erasure:
        ``query`` (accuracy), ``query_related`` with the same ``answer``
        (generalization) and ``query_unrelated`` with ``answer_unrelated``
        (specificity). The weights are always restored before returning, even
        if a probe raises.

        Args:
            query: Main prompt.
            answer: Ground-truth completion for ``query`` and ``query_related``.
            neurons: Iterable of ``(layer_idx, position)`` pairs to erase.
            query_related: Prompt used for the generalization probe.
            query_unrelated: Prompt used for the specificity probe.
            answer_unrelated: Ground-truth completion for ``query_unrelated``.

        Returns:
            dict with the keys

            * ``"before"``, ``"before_generalize"``, ``"before_specific"`` --
              probe summaries before the erasure,
            * ``"after"``, ``"generalize"``, ``"specific"`` -- the matching
              summaries after the erasure,
            * ``"baseline_ppl"`` / ``"new_ppl"`` -- ``calculate_ppl`` on
              ``query`` before / after the erasure,
            * ``"prob_change_acc"``, ``"prob_change_generalize"``,
              ``"prob_change_specific"`` -- relative change of the
              ground-truth probability (0 when the baseline is 0).

            Each summary holds ``"gt_prob"``, ``"argmax_completion"`` and
            ``"argmax_prob"``.

        Raises:
            NotImplementedError: If ``self.model_name`` is not a supported
                model family.
        """
        neurons = list(neurons)  # iterated more than once below
        results_dict = {}

        # Each baseline gets its own key; storing all three under "before"
        # would keep only the last one.
        results_dict["before"] = self._generation_summary(query, answer)
        results_dict["before_generalize"] = self._generation_summary(query_related, answer)
        results_dict["before_specific"] = self._generation_summary(query_unrelated, answer_unrelated)

        results_dict['baseline_ppl'] = calculate_ppl(sentence=query, model=self.model,
                                                     tokenizer=self.tokenizer, device=self.device)

        # Knock the selected neurons out by zeroing their output weights.
        saved_weights = self._zero_output_ff_neurons(neurons, erase_value=0)
        try:
            results_dict["after"] = self._generation_summary(query, answer)
            results_dict["generalize"] = self._generation_summary(query_related, answer)
            results_dict["specific"] = self._generation_summary(query_unrelated, answer_unrelated)

            comparisons = (
                ('prob_change_acc', "before", "after"),
                ('prob_change_generalize', "before_generalize", "generalize"),
                ('prob_change_specific', "before_specific", "specific"),
            )
            for change_key, before_key, after_key in comparisons:
                results_dict[change_key] = self._relative_change(
                    results_dict[before_key]["gt_prob"],
                    results_dict[after_key]["gt_prob"],
                )

            results_dict['new_ppl'] = calculate_ppl(sentence=query, model=self.model,
                                                    tokenizer=self.tokenizer, device=self.device)
        finally:
            # Always undo the edit so later calls see the unmodified model.
            self._restore_output_ff_neurons(saved_weights)

        return results_dict

    @torch.no_grad()
    def cal_ppl_wikitext(self, neurons_for_erase, save_path,
        wiki_path = '/home/chenyuheng/chenyuheng/NIPS2024/Datasets/EXP3/test-wikitext-2-v1.parquet',
        undo_modification=True):
        """Compute perplexity on a WikiText parquet file, optionally with neurons erased.

        The non-empty ``text`` rows are joined with blank lines, tokenized once
        and scored in non-overlapping windows of 512 tokens; the perplexity is
        the exponential of the mean window loss. One line of the form
        ``"Model: <name>, Perplexity: <ppl>"`` (JSON-encoded) is appended to
        ``save_path``.

        Args:
            neurons_for_erase: Iterable of ``(layer_idx, position)`` pairs whose
                feed-forward output weights are zeroed during the measurement.
                A falsy value (``None``, empty list) measures the unmodified
                model.
            save_path: File the result line is appended to.
            wiki_path: Parquet file with a ``text`` column. The default is an
                absolute path on the original author's machine.
            undo_modification: When true (default) the erased weights are
                restored once the perplexity has been computed, also on error.
                Pass ``False`` to leave the model erased after the call.

        Returns:
            The perplexity as a float.

        Raises:
            NotImplementedError: If neurons are given and ``self.model_name``
                is not a supported model family.
        """
        neurons = list(neurons_for_erase) if neurons_for_erase else []
        saved_weights = self._zero_output_ff_neurons(neurons, erase_value=0)
        try:
            data = pd.read_parquet(wiki_path)
            hf_dataset_full = Dataset.from_pandas(data)["text"]
            hf_dataset = [s for s in hf_dataset_full if s != '']

            encodings = self.tokenizer("\n\n".join(hf_dataset), return_tensors="pt").to(self.device)
            max_length = 512
            stride = 512
            seq_len = encodings.input_ids.size(1)
            nlls = []
            prev_end_loc = 0
            for begin_loc in tqdm(range(0, seq_len, stride)):
                end_loc = min(begin_loc + max_length, seq_len)
                if nlls and end_loc - begin_loc < 2:
                    # A trailing one-token window has no next-token target;
                    # its loss would be NaN and spoil the mean.
                    break
                trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
                input_ids = encodings.input_ids[:, begin_loc:end_loc].to(self.device)
                target_ids = input_ids.clone()
                # Mask tokens already scored by a previous window.
                target_ids[:, :-trg_len] = -100

                outputs = self.model(input_ids, labels=target_ids)
                nlls.append(outputs.loss)
                prev_end_loc = end_loc
                if end_loc == seq_len:
                    break
            ppl = torch.exp(torch.stack(nlls).mean()).item()
        finally:
            if undo_modification:
                self._restore_output_ff_neurons(saved_weights)

        # No sample generation here: the generated text was never stored or
        # reported, so running `generate` over the corpus only cost time.
        result = f"Model: {self.model_name}, Perplexity: {ppl}"

        with open(save_path, 'a') as json_file:
            json_file.write(json.dumps(result) + '\n')

        return ppl

    def suppress_attn(self, query, answer, synapses):
        """Delegate to ``_modify_synapses`` with ``answer`` as the ground truth.

        NOTE(review): this call is identical to the one in ``enhance_attn``;
        no mode is passed. Confirm against ``_modify_synapses`` whether a
        suppress/enhance switch is missing.
        """
        return self._modify_synapses(
            query=query, ground_truth=answer, synapses=synapses)

    def enhance_attn(self, query, answer, synapses):
        """Delegate to ``_modify_synapses`` with ``answer`` as the ground truth.

        NOTE(review): this call is identical to the one in ``suppress_attn``;
        no mode is passed. Confirm against ``_modify_synapses`` whether a
        suppress/enhance switch is missing.
        """
        return self._modify_synapses(
            query=query, ground_truth=answer, synapses=synapses)

    def test_generalization(self, input_texts):
        """
        Test the model's ability to generalize the knowledge edit across different expressions.

        Not implemented yet; currently returns ``None``.
        """
        # Placeholder for generalization testing logic
        pass

    def test_specificity(self, unrelated_input_texts):
        """
        Test the specificity of the knowledge edit, ensuring unrelated knowledge remains unaffected.

        Not implemented yet; currently returns ``None``.
        """
        # Placeholder for specificity testing logic
        pass

    def iterative_editing_process(self):
        """
        The main iterative process for editing knowledge within the model.

        Not implemented yet; currently returns ``None``.
        """
        # Placeholder for the iterative editing process logic
        pass







def main():
    """Build a GPT-2 editor on CUDA and print the shapes of two first-layer weights."""
    editor = NeuroSynapticEdit(model_name_or_path='gpt2', device='cuda')

    # Only the first transformer block is inspected.
    first_block = editor.model.transformer.h[0]

    # The last dotted component of the configured path names the attribute
    # looked up on the block.
    mlp_attr_name = editor.mlp_output.split('.')[-1]
    mlp_weights = getattr(first_block, mlp_attr_name)
    print(f"Shape of MLP output weights: {mlp_weights.shape}")

    # Same lookup for the combined query/key/value projection.
    qkv_attr_name = editor.attn_QKV_input.split('.')[-1]
    qkv_weights = getattr(first_block, qkv_attr_name)
    print(f"Shape of combined QKV weights: {qkv_weights.shape}")


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()