import json
import os
import random
import re

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


@torch.no_grad()
def get_proj(V, emb, norm=None):
    """Project ``emb`` onto the columns of ``V`` without tracking gradients.

    Args:
        V: Right-hand operand of the matrix product (projection basis).
        emb: Embedding tensor to project.
        norm: Optional callable applied to ``emb`` before the product.

    Returns:
        ``norm(emb) @ V`` when ``norm`` is given, otherwise ``emb @ V``.
    """
    source = emb if norm is None else norm(emb)
    return source @ V


def seed_everything(seed: int):
    """Seed every RNG used by this module and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    and exports ``PYTHONHASHSEED``.

    Args:
        seed: Seed value shared by all generators.
    """
    random.seed(seed)
    # NOTE: changing PYTHONHASHSEED at runtime only affects subprocesses; the
    # current interpreter fixed its hash seed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels between runs, which
    # defeats reproducibility, so it must be disabled rather than enabled.
    torch.backends.cudnn.benchmark = False


def load_model(model_name, **kwargs):
    """Load a causal language model and its tokenizer by name.

    Args:
        model_name: Identifier or path passed to both ``from_pretrained`` calls.
        **kwargs: Extra keyword arguments forwarded to
            ``AutoModelForCausalLM.from_pretrained`` only.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
    return model, AutoTokenizer.from_pretrained(model_name)


def jsonlload(path: str) -> list:
    """Read a JSON Lines file into a list of decoded objects.

    Blank or whitespace-only lines (for example a trailing empty line) are
    skipped instead of raising ``json.JSONDecodeError``.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        One decoded JSON value per non-empty line, in file order.

    Raises:
        json.JSONDecodeError: If a non-empty line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        # Stream line by line so the raw text is never held twice in memory.
        return [json.loads(line) for line in f if line.strip()]


def split_data(data, train_ratio, shuffle=False):
    """Split ``data`` into a training part and a validation part.

    Args:
        data: Indexable, sliceable sequence of samples.
        train_ratio: Fraction of samples for the training part; the count is
            ``int(len(data) * train_ratio)``.
        shuffle: When true, samples are drawn in the order of a
            ``torch.randperm`` permutation and returned as new lists.

    Returns:
        A ``(train_data, valid_data)`` tuple.
    """
    cut = int(len(data) * train_ratio)
    if not shuffle:
        return data[:cut], data[cut:]

    order = torch.randperm(len(data))
    train_part = [data[idx] for idx in order[:cut]]
    valid_part = [data[idx] for idx in order[cut:]]
    return train_part, valid_part


def load_data(pt_path, key, device, need_proj: bool = False, **kwargs):
    """Load saved embeddings and hallucination flags as ``(embedding, flag)`` pairs.

    The file at ``pt_path`` must contain ``"all_hallucination_flag"`` and an
    ``"all_emb"`` mapping indexed by ``key``.

    Args:
        pt_path: Path of the saved results, read with ``torch.load``.
        key: Which entry of ``res["all_emb"]`` to use.
        device: Device that loaded tensors and the flags are placed on.
        need_proj: When true, every embedding is projected with ``get_proj``.
        **kwargs: Required only when ``need_proj`` is true:
            ``svd_path`` (file holding a 3-tuple whose last item is ``V``),
            ``norm`` (callable or ``None`` forwarded to ``get_proj``) and
            ``proj_dim`` (number of trailing columns of ``V`` to keep).

    Returns:
        A list of ``(embedding, flag)`` tuples.

    Raises:
        KeyError: If ``need_proj`` is true and a required keyword is missing.
    """
    with torch.no_grad():
        res = torch.load(pt_path, map_location=device)
        flags = torch.tensor(res["all_hallucination_flag"]).to(device)
        sentences = res["all_emb"][key]

        if need_proj:
            svd_path = kwargs["svd_path"]
            norm = kwargs["norm"]
            proj_dim = kwargs["proj_dim"]

            # Load the basis onto the same device as the embeddings; without
            # map_location it is restored to whatever device it was saved from.
            _, _, V = torch.load(svd_path, map_location=device)
            basis = V[:, -proj_dim:]  # slice once instead of per embedding
            sentences = [get_proj(basis, emb, norm) for emb in sentences]

        return list(zip(sentences, flags))


def get_prob_in_pos(generated_ans, pos):
    """Return the logits of generation step ``pos`` and the first row's probabilities.

    Args:
        generated_ans: Object whose ``logits`` attribute is indexable by step.
        pos: Index into ``generated_ans.logits``.

    Returns:
        ``(logits_pos, prob_list)`` where ``logits_pos`` is the untouched entry
        ``generated_ans.logits[pos]`` and ``prob_list`` is row 0 of its softmax
        over the last dimension.
    """
    step_logits = generated_ans.logits[pos]
    step_probs = step_logits.softmax(dim=-1)
    return step_logits, step_probs[0, :]


class ModelUtils:
    """Pair a causal language model with its tokenizer for generation and inspection."""

    def __init__(self, model, tokenizer):
        """Store the model and tokenizer used by every other method."""
        self.model = model
        self.tokenizer = tokenizer

    def get_response(self, text, **kwargs):
        """Generate from ``text`` and build a per-position debug report.

        Args:
            text: A single prompt string, or a list of prompts (tokenized with
                padding). Decoded outputs are produced for the first batch
                element only.
            **kwargs: Forwarded to ``self.model.generate``. The flags
                ``return_dict_in_generate``, ``output_attentions``,
                ``output_hidden_states``, ``output_scores`` and
                ``output_logits`` are always forced to ``True``.

        Returns:
            ``(model_inputs, response, pos2text, generated_ans, debug_str)``:
            the tokenized inputs, the decoded continuation of the first
            sequence, the per-position decoded tokens (prompt followed by
            continuation) of the first sequence, the raw ``generate`` output,
            and a text report listing the top-5 candidates for every generated
            position.
        """
        if isinstance(text, list):
            encoded = self.tokenizer(text, return_tensors="pt", padding=True)
        else:
            encoded = self.tokenizer([text], return_tensors="pt")
        model_inputs = encoded.to(self.model.device)
        input_len = model_inputs.input_ids.size(1)

        # These outputs are needed by the report below, so callers cannot disable them.
        kwargs.update(
            return_dict_in_generate=True,
            output_attentions=True,
            output_hidden_states=True,
            output_scores=True,
            output_logits=True,
        )
        generated_ans = self.model.generate(**model_inputs, **kwargs)

        # Strip the prompt prefix from every returned sequence.
        generated_ids = [
            full_ids[len(prompt_ids) :]
            for prompt_ids, full_ids in zip(
                model_inputs.input_ids, generated_ans["sequences"]
            )
        ]

        decode = self.tokenizer.decode
        pos2text = [decode(token) for token in model_inputs.input_ids[0]]
        pos2text.extend(decode(token) for token in generated_ids[0])
        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[
            0
        ]

        top_k = 5
        all_pos_token_prob = self.get_top_k_prob(
            top_k=top_k, generated_ans=generated_ans
        )
        prob_lines = []
        for pos in range(input_len, generated_ans.sequences.size(1)):
            candidates = all_pos_token_prob[pos]
            described = " | ".join(
                f"{cand['word']}:p={cand['prob']:.2f}, logits={cand['logits']:.2f}"
                for cand in candidates[:top_k]
            )
            prob_lines.append(repr(f"{pos}[{pos2text[pos]}](" + described + ")"))

        debug_str = (
            "input:\n" + "".join(pos2text[:input_len]) + "\n"
            + "output:\n" + "".join(pos2text[input_len:]) + "\n"
            + "prob:\n" + "\n".join(prob_lines) + "\n"
        )

        return model_inputs, response, pos2text, generated_ans, debug_str

    def get_chat_text(self, prompt):
        """Wrap ``prompt`` as a single user turn and render the chat template as text."""
        conversation = [{"role": "user", "content": prompt}]
        return self.tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

    def get_top_k_prob(self, top_k, generated_ans):
        """Collect the ``top_k`` most probable tokens for every generated position.

        Args:
            top_k: Number of candidates to keep per position.
            generated_ans: ``generate`` output exposing ``sequences`` and ``logits``.

        Returns:
            A dict mapping each absolute sequence position (prompt length
            onwards) to a list of ``top_k`` dicts with keys ``"token"``,
            ``"word"``, ``"prob"`` and ``"logits"``, taken from the first
            batch row.
        """
        total_len = generated_ans.sequences.size(1)
        # One logits entry exists per generated step, so the rest is prompt.
        input_len = total_len - len(generated_ans.logits)

        all_pos_token_prob = {}
        for step, pos in enumerate(range(input_len, total_len)):
            logits_pos, prob_list = get_prob_in_pos(generated_ans, step)
            best_probs, best_ids = torch.topk(prob_list, k=top_k)
            all_pos_token_prob[pos] = [
                {
                    "token": token_id,
                    "word": self.tokenizer.decode(token_id),
                    "prob": prob,
                    "logits": logits_pos[0, token_id].item(),
                }
                for token_id, prob in zip(best_ids.tolist(), best_probs.tolist())
            ]
        return all_pos_token_prob


class ModelHook(ModelUtils):
    """Generation helper that records or rewrites module outputs via forward hooks.

    Hooks are registered with ``torch.nn.Module.register_forward_hook``. Their
    handles are kept in ``hook_handle_list`` so they can be removed later, and
    captured tensors are appended to ``input_dict`` / ``output_dict`` keyed by
    module name.
    """

    def __init__(self, model, tokenizer):
        """Initialise empty capture buffers and hook settings; no hook is registered."""
        super().__init__(model, tokenizer)
        self.module2name = {}
        self.input_dict = {}
        self.output_dict = {}

        self.hook_handle_list = []
        self.is_hook = False

        # Settings read by _hook_remove_bigger.
        self.remove_bigger_num = 0
        self.remove_bigger_target_value = 0

        # Settings read by _hook_remove_smaller.
        self.remove_smaller_num = 0
        self.remove_smaller_target_value = 0

        # Module names used by _hook_skip_output.
        self.skip_output_from = None
        self.skip_output_to = None

        # Basis used by _hook_clean_proj.
        self.proj_base_matrix = None

    def remove_all_hook(self):
        """Remove every registered hook and mark the instance as unhooked."""
        for handle in self.hook_handle_list:
            handle.remove()

        self.hook_handle_list.clear()
        self.is_hook = False

    def remove_hook_handle_list(self, hook_handle_list: list):
        """Remove the given handles, ignoring any that this instance does not track."""
        for handle in hook_handle_list:
            if handle in self.hook_handle_list:
                handle.remove()
                self.hook_handle_list.remove(handle)

        if not self.hook_handle_list:
            self.is_hook = False

    def clear_io(self):
        """Drop all captured inputs and outputs."""
        self.input_dict.clear()
        self.output_dict.clear()

    @staticmethod
    def _primary_tensor(output):
        """Return the tensor a hook should edit: ``output[0]`` for tuples, else ``output``."""
        return output[0] if isinstance(output, tuple) else output

    @staticmethod
    def _replace_by_magnitude(output, remove_num, target_value, largest):
        """Overwrite the highest- or lowest-magnitude entries along the last dim, in place.

        ``remove_num`` is an absolute count when ``>= 1`` and otherwise a
        fraction of the last dimension. Entries whose absolute value lies
        strictly beyond the value at that rank are set to ``target_value``.
        """
        all_data = ModelHook._primary_tensor(output)
        width = all_data.size(-1)
        cutoff_index = int(remove_num) if remove_num >= 1 else int(remove_num * width)

        magnitudes = torch.abs(all_data)
        ranked = torch.sort(magnitudes, dim=-1, descending=largest)[0]
        # Plain indexing (not slicing) so an out-of-range rank still raises IndexError.
        cutoff = ranked[..., cutoff_index].unsqueeze(-1)
        mask = magnitudes > cutoff if largest else magnitudes < cutoff
        all_data[mask] = target_value
        return output

    def _register_hooks(self, should_hook, hook_fn):
        """Register ``hook_fn`` on every named module accepted by ``should_hook(name)``.

        Records the module-to-name mapping, tracks the new handles and sets
        ``is_hook`` when at least one hook was registered.

        Returns:
            The list of handles created by this call.
        """
        cur_handle = []
        for name, module in self.model.named_modules():
            if should_hook(name):
                self.module2name[module] = name
                cur_handle.append(module.register_forward_hook(hook_fn))
        self.hook_handle_list.extend(cur_handle)
        if cur_handle:
            self.is_hook = True
        return cur_handle

    def _hook_all(self, module, input, output):
        """Forward hook: append the module's raw input and output to the capture dicts."""
        name = self.module2name[module]
        if name not in self.input_dict:
            self.input_dict[name] = []
            self.output_dict[name] = []

        self.input_dict[name].append(input)
        self.output_dict[name].append(output)

    def _hook_remove_bigger(self, module, input, output):
        """Forward hook: replace the largest-magnitude activations per position."""
        return self._replace_by_magnitude(
            output,
            self.remove_bigger_num,
            self.remove_bigger_target_value,
            largest=True,
        )

    def _hook_remove_smaller(self, module, input, output):
        """Forward hook: replace the smallest-magnitude activations per position."""
        return self._replace_by_magnitude(
            output,
            self.remove_smaller_num,
            self.remove_smaller_target_value,
            largest=False,
        )

    def _hook_skip_output(self, module, input, output):
        """Forward hook: substitute the latest captured output of ``skip_output_from``.

        The source module must already have been recorded in ``output_dict``
        (i.e. captured by ``_hook_all``); otherwise the lookup raises ``KeyError``.
        """
        name = self.module2name[module]
        if name == self.skip_output_to:
            output = self.output_dict[self.skip_output_from][-1]
        return output

    def _hook_clean_proj(self, module, input, output):
        """Forward hook: subtract, in place, the component lying in ``proj_base_matrix``.

        Computes ``x - (x @ B.T) @ B`` for every position at once, where ``B``
        is ``self.proj_base_matrix``.
        """
        all_data = self._primary_tensor(output)
        basis = self.proj_base_matrix
        coefficients = torch.matmul(all_data, basis.t())
        all_data.sub_(torch.matmul(coefficients, basis))
        return output

    def hook_model_input_output(self, hook_name_list=None):
        """Capture inputs and outputs of the named modules, or of all modules.

        Args:
            hook_name_list: Module names to capture. ``None`` or an empty
                collection captures every module of the model.

        Returns:
            The handles registered by this call.
        """
        if hook_name_list:
            return self._register_hooks(
                lambda name: name in hook_name_list, self._hook_all
            )
        return self._register_hooks(lambda name: True, self._hook_all)

    def hook_model_remove_bigger(self, name_pattern, remove_num, target_value=0):
        """Install the largest-magnitude replacement hook on modules matching ``name_pattern``.

        Args:
            name_pattern: Regex matched from the start of each module name.
            remove_num: Count (``>= 1``) or fraction (``< 1``) of entries to replace.
            target_value: Value written over the selected entries.

        Returns:
            The handles registered by this call.
        """
        self.remove_bigger_num = remove_num
        self.remove_bigger_target_value = target_value

        pattern = re.compile(name_pattern)
        return self._register_hooks(pattern.match, self._hook_remove_bigger)

    def hook_model_remove_smaller(self, name_pattern, remove_num, target_value=0):
        """Install the smallest-magnitude replacement hook on modules matching ``name_pattern``.

        Args:
            name_pattern: Regex matched from the start of each module name.
            remove_num: Count (``>= 1``) or fraction (``< 1``) of entries to replace.
            target_value: Value written over the selected entries.

        Returns:
            The handles registered by this call.
        """
        # These must be the *smaller* settings: _hook_remove_smaller reads them,
        # and writing the "bigger" ones would silently reconfigure the other hook.
        self.remove_smaller_num = remove_num
        self.remove_smaller_target_value = target_value

        pattern = re.compile(name_pattern)
        return self._register_hooks(pattern.match, self._hook_remove_smaller)

    def hook_skip_layer(self, module_name_from, module_name_to):
        """Make ``module_name_to`` emit the last captured output of ``module_name_from``.

        Returns:
            The handles registered by this call.
        """
        self.skip_output_from = module_name_from
        self.skip_output_to = module_name_to

        return self._register_hooks(
            lambda name: name == module_name_to, self._hook_skip_output
        )

    def hook_model_clean_proj(self, name_pattern, proj_base_matrix):
        """Install the projection-removal hook on modules matching ``name_pattern``.

        Args:
            name_pattern: Regex matched from the start of each module name.
            proj_base_matrix: 2-D matrix whose rows span the subspace to remove.

        Returns:
            The handles registered by this call.
        """
        self.proj_base_matrix = proj_base_matrix

        pattern = re.compile(name_pattern)
        return self._register_hooks(pattern.match, self._hook_clean_proj)

    def generate_with_hook(self, text, **kwargs):
        """Run ``get_response`` with hooks active and bundle the captured tensors.

        If no hook is registered yet, every module is hooked for capture first.
        Capture buffers are not cleared here; call ``clear_io`` between runs to
        avoid accumulating results.

        Args:
            text: Prompt or list of prompts passed to ``get_response``.
            **kwargs: Forwarded to ``get_response``.

        Returns:
            A dict with keys ``"hook_input_dict"``, ``"hook_output_dict"``,
            ``"input_tokens"``, ``"response"``, ``"pos2text"``,
            ``"generated_ans"`` and ``"debug_str"``.
        """
        if not self.is_hook:
            self.hook_model_input_output()

        input_tokens, response, pos2text, generated_ans, debug_str = self.get_response(
            text=text, **kwargs
        )

        # Expose the input of each post-attention layernorm as an extra output
        # entry (labelled "self_attn + residual" by the original author).
        for layer in range(self.model.config.num_hidden_layers):
            norm_name = f"model.layers.{layer}.post_attention_layernorm"
            if norm_name in self.input_dict:
                self.output_dict[f"model.layers.{layer}.self_attn+res"] = (
                    self.input_dict[norm_name]
                )

        # Keep only the first element of tuple-valued captures.
        hook_output_dict = {
            key: [
                value[0] if isinstance(value, tuple) else value for value in captured
            ]
            for key, captured in self.output_dict.items()
        }

        return {
            "hook_input_dict": self.input_dict,
            "hook_output_dict": hook_output_dict,
            "input_tokens": input_tokens,
            "response": response,
            "pos2text": pos2text,
            "generated_ans": generated_ans,
            "debug_str": debug_str,
        }
