import os
import sys
import inspect

from tqdm import tqdm
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

import json, pickle
import torch
import logging
import numpy as np
from copy import deepcopy
from alignment import build_aligner
from utils.model_utils import (project_into_vocabluary, is_key, is_value, is_gate,
                               get_lm_head, get_last_transformer_layer, load_from_pickle, save_to_pickle,
                               get_num_transformer_layers, get_hidden_dim, get_model_category)


class Prob():
    def __init__(self, model, tokenizer, pref_data_dps, cache_path, language_list, save_tag, centering=True, top_k_ranks=2, edit_layer_range=None, random_dps=True, scale_rate=1.0, dare_drop_rate=0.0, align_method='lsar+1'):
        """Store the model, the tokenizer and every editing hyper-parameter.

        Args:
            model: Language model to probe and edit; switched to eval mode here.
            tokenizer: Tokenizer used to encode the prompts.
            pref_data_dps: Number of prompts to keep per language
                (``-1`` keeps all of them).
            cache_path: Root directory for the pickled intermediate results.
            language_list: Comma separated language codes, e.g. ``"en,zh"``.
            save_tag: Sub-directory of ``cache_path`` used for this run.
            centering: Flag read by ``_get_preference_matrix``.
            top_k_ranks: How many leading singular vectors ``find_p_toxic``
                accumulates into the projector.
            edit_layer_range: Layer indices eligible for editing; defaults to
                every transformer layer.
            random_dps: Sub-sample prompts at random instead of taking a prefix.
            scale_rate: Multiplier applied to the projector in ``edit_model``.
            dare_drop_rate: Stored only; the visible code never reads it
                outside commented-out lines.
            align_method: Method name forwarded to ``build_aligner``.
        """
        model.eval()
        self.model = model
        self.tokenizer = tokenizer

        # Architecture facts, queried once through the project helpers.
        self.model_category = get_model_category(model)  # 'gpt2' or 'llama' like architectures
        print(self.model_category)
        self.D = get_hidden_dim(model)  # Hidden dimension of the model
        self.num_layers = get_num_transformer_layers(model)
        self.E = get_lm_head(model)  # (V, D) for GPT-2

        # Data selection.
        self.pref_data_dps = pref_data_dps
        self.random_dps = random_dps
        self.language_list = language_list

        # Editing hyper-parameters.
        self.centering = centering
        self.scale_rate = scale_rate
        self.dare_drop_rate = dare_drop_rate
        self.top_k_ranks = top_k_ranks
        self.align_method = align_method

        # Where intermediate artefacts are cached.
        self.cache_path = cache_path
        self.save_tag = save_tag

        self.edit_layer_range = (
            np.arange(self.num_layers) if edit_layer_range is None else edit_layer_range
        )


    def _load_preference_data(self):
        """Load the parallel multilingual prompts and tokenize them per language.

        Reads ``test.jsonl`` from ``$DATASET_DIR/$DATASET_NAME``. Each
        non-empty line must be a JSON object holding one text per language
        code listed in ``self.language_list`` (comma separated). Every text is
        wrapped in an instruction template. When ``self.pref_data_dps`` is not
        ``-1`` the prompts are reduced to that many entries: a prefix when
        ``self.random_dps`` is false, otherwise a random sample drawn without
        replacement. The same positions are kept for every language so the
        prompts stay parallel.

        Returns:
            A dict mapping each language code to whatever ``self.tokenizer``
            returns for that language's prompts, or ``None`` (after logging an
            error) when the dataset file does not exist.

        Raises:
            KeyError: If ``DATASET_DIR`` / ``DATASET_NAME`` are not set, or a
                record lacks one of the requested languages.
        """
        num_dps = self.pref_data_dps
        filedir = os.path.join(os.environ["DATASET_DIR"], os.environ["DATASET_NAME"])
        filepath = os.path.join(filedir, 'test.jsonl')

        if not os.path.exists(filepath):
            logging.error(f'File not found at: {filepath}')
            return None

        lang_list = [item.strip() for item in self.language_list.split(',')]
        lang_data = {lang: [] for lang in lang_list}

        template = """Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n###Instruction:\n{query}\n\n### Response:\n"""

        # The data is multilingual, so decode explicitly as UTF-8 instead of
        # relying on the platform default encoding.
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines (e.g. a trailing newline)
                record = json.loads(line)  # parse once per line, not once per language
                for lang in lang_list:
                    lang_data[lang].append(template.format(query=record[lang].strip()))

        for k, v in lang_data.items():
            if v:
                print(f"{k}_data_example: \n{v[0]}\n\n")

        if num_dps != -1:
            if not self.random_dps:
                lang_data = {k: v[:num_dps] for k, v in lang_data.items()}
            else:
                total = min((len(v) for v in lang_data.values()), default=0)
                if num_dps > total:
                    logging.warning(f'Requested {num_dps} samples but only {total} are available.')
                    num_dps = total
                # One index set shared by all languages keeps the prompts aligned.
                indices = np.random.choice(total, num_dps, replace=False)
                lang_data = {k: [v[i] for i in indices] for k, v in lang_data.items()}

        for k, v in lang_data.items():
            logging.info(f'Loaded {len(v)} {k} samples. \n')

        lang_data = {k: self.tokenizer(v, return_tensors="pt", padding=True, truncation=True, max_length=64) for k, v in lang_data.items()}

        return lang_data

    def _get_hidden_sentence_embeddings(self, inputs):
        """Return the final-position hidden state of every layer for each input.

        The inputs are pushed through ``self.model`` in batches of at most 50
        rows with ``output_hidden_states=True``. The first hidden state (the
        input embeddings) is dropped. For ``gpt2`` / ``opt`` categories the
        last entry is replaced by the output of the last transformer layer
        applied to the penultimate entry; for the other supported categories
        the model's own last hidden state is used unchanged.

        Every row is processed, including a final batch smaller than the
        batch size.

        Args:
            inputs: Object exposing ``input_ids`` and ``attention_mask``
                tensors whose first dimension is the number of samples.

        Returns:
            CPU tensor of shape (L, N, D): L hidden states after dropping the
            embeddings, N samples, D hidden size.

        Raises:
            ValueError: If ``inputs`` contains no samples.
            NotImplementedError: If ``self.model_category`` is not supported.
        """
        input_ids = inputs.input_ids
        attention_mask = inputs.attention_mask

        num_samples = input_ids.size(0)
        if num_samples == 0:
            raise ValueError('Cannot compute sentence embeddings for an empty set of inputs.')

        recompute_last = self.model_category in ['gpt2', 'opt']
        if not recompute_last and self.model_category not in ['llama', 'gemma', 'mistral', 'phi', 'gptj']:
            raise NotImplementedError(f'Model category not recognized: {self.model_category}')
        # Only the gpt2/opt path needs the last layer module; fetch it once.
        last_layer = get_last_transformer_layer(self.model) if recompute_last else None

        batch_size = min(50, num_samples)
        # Stepping by batch_size (instead of N // batch_size) keeps the final
        # partial batch rather than silently dropping it.
        batch_starts = range(0, num_samples, batch_size)
        num_batches = len(batch_starts)
        sent_embs = []

        for i, start in enumerate(batch_starts):
            batch_input_ids = input_ids[start: start + batch_size]
            batch_attention_mask = attention_mask[start: start + batch_size]
            logging.info(f'Batch {i + 1}/{num_batches} of size {batch_input_ids.size(0)}')

            with torch.no_grad():
                outputs = self.model(input_ids=batch_input_ids.to(self.model.device), attention_mask=batch_attention_mask.to(self.model.device), output_hidden_states=True)
                # Drop the input-embedding entry, then stack: (L, N, seq_len, D).
                hidden_states = torch.stack(outputs.hidden_states[1:])
                del outputs

                if recompute_last:
                    # Re-run the last block on the penultimate hidden state.
                    # Inside no_grad so no autograd graph is retained.
                    hidden_states[-1] = last_layer(hidden_states[-2])[0]  # (N, seq_len, D)

            # Hidden state at the final sequence position of each row.
            # NOTE(review): with right-padded batches this position can be a
            # padding token -- confirm the tokenizer pads on the left.
            hidden_sent_embs = hidden_states[:, :, -1, :]
            sent_embs.append(hidden_sent_embs.detach().to('cpu'))
            del hidden_sent_embs, hidden_states
            torch.cuda.empty_cache()

        # sent_embs holds (L, n_i, D) tensors; concatenate along the sample axis.
        hidden_sent_embs = torch.cat(sent_embs, dim=1)  # (L, N, D)
        del sent_embs
        logging.info(f'Hidden sent: {hidden_sent_embs.shape}')
        torch.cuda.empty_cache()
        return hidden_sent_embs

    def projection(self, emb, lang_dir):
        """Map ``emb`` onto the row-normalised directions in ``lang_dir``.

        Each row of ``lang_dir`` is scaled to unit L2 norm. ``emb`` is then
        expressed through its dot products with those unit rows and mapped
        back into the original space.

        NOTE(review): this is an orthogonal projection only if the rows of
        ``lang_dir`` are mutually orthogonal -- confirm with the aligner.

        Args:
            emb: Tensor of shape (N, D).
            lang_dir: Tensor of shape (K, D) whose rows are the directions.

        Returns:
            Tensor of shape (N, D).
        """
        row_norms = torch.linalg.norm(lang_dir, dim=1, keepdim=True)
        unit_dirs = lang_dir / row_norms
        coefficients = emb @ unit_dirs.T  # (N, K)
        return coefficients @ unit_dirs

    def _get_preference_matrix(self):
        """Build per-layer language subspaces and save per-language directions.

        Steps performed by the visible code:

        1. Load the tokenized prompts with ``_load_preference_data``.
        2. For every language, load the cached sentence embeddings from
           ``<cache_path>/<save_tag>/<lang>_hidden_last.pkl``; when
           ``load_from_pickle`` returns ``None`` they are computed with
           ``_get_hidden_sentence_embeddings`` and saved to that path.
        3. For every layer, call ``build_aligner(self.align_method, ...)`` on
           the per-language embeddings (as NumPy arrays). It returns
           ``(Wu, aligner)``; the transposes of both are stacked over layers
           and saved as ``lang_shared_space_last.pkl`` (``Wu``) and
           ``lang_specific_space_last.pkl`` (``aligner``).
        4. Project each language's embeddings with ``self.projection`` onto
           the stacked ``aligner`` rows and save, per language, the list of
           per-layer ``mean(lang) - mean(en)`` vectors to
           ``<lang>_direction_last.pkl``.

        NOTE(review): the method requires an ``"en"`` entry in the language
        list; it is indexed directly.

        NOTE(review): ``print(stop)`` below references a name that is not
        defined anywhere in the visible source, so as written the method
        raises ``NameError`` right after the direction files are saved and
        never reaches the centering block or the ``return``. This looks like
        a deliberate debugging halt -- confirm before relying on the return
        value.

        NOTE(review): the centering block reads ``preferred_sent_embs`` and
        ``non_preferred_sent_embs``, which are not assigned in this method.

        Returns:
            The stacked preference matrix (only reachable if the halt above
            is removed).
        """
        lang_data = self._load_preference_data()

        source_lan_emb = {}
        # Check whether the lang_data entries are exactly identical
        # (the check itself is commented out; ``keys`` is otherwise unused).
        keys = list(lang_data.keys())
        # for i in range(len(keys)):
        #     for j in range(i+1, len(keys)):
        #         assert torch.equal(lang_data[keys[i]]['input_ids'], lang_data[keys[j]]['input_ids'])

        # Per-language sentence embeddings, served from the pickle cache when present.
        for lang, data in lang_data.items():
            sent_embs = load_from_pickle(os.path.join(self.cache_path,self.save_tag, f'{lang}_hidden_last.pkl'))
            if sent_embs is None:
                sent_embs = self._get_hidden_sentence_embeddings(data)  # (L, N, D)
                save_to_pickle(sent_embs, os.path.join(self.cache_path, self.save_tag, f'{lang}_hidden_last.pkl'))
            source_lan_emb[lang] = sent_embs

        # difference_matrix = (preferred_sent_embs - non_preferred_sent_embs) / 2  # (L, N, D)
        # logging.info('Preference matrix calculated.')
        # # del non_preferred_sent_embs
        # for layer_num in range(preference_matrix.shape[0]):
        #     sorted_tokens = project_into_vocabluary(preference_matrix[layer_num].mean(dim=0).squeeze(), self.E.cpu(), self.tokenizer, top_k=10)
        #     print(f'Layer {layer_num} - Rank 0: {" | ".join([x for x in sorted_tokens])}')
        
        # One aligner call per layer; the layer count is taken from the "en" embeddings.
        preference_matrix, Wu_matrix = [], []
        for layer_num in range(source_lan_emb["en"].shape[0]):
            cur_source_lan_emb = {lang: emb[layer_num].numpy() for lang, emb in source_lan_emb.items()}
            print("cur_source_lan_emb", cur_source_lan_emb)
            Wu, aligner = build_aligner(self.align_method, cur_source_lan_emb)
            preference_matrix.append(torch.tensor(aligner.T))
            print("Wu: ", Wu.shape)
            print("Aligner: ", aligner.shape)
            Wu_matrix.append(torch.tensor(Wu.T))

        preference_matrix = torch.stack(preference_matrix, dim=0)
        Wu_matrix = torch.stack(Wu_matrix, dim=0)
        logging.info('Preference matrix calculated.')

        save_to_pickle(preference_matrix, os.path.join(self.cache_path, self.save_tag, 'lang_specific_space_last.pkl'))
        save_to_pickle(Wu_matrix, os.path.join(self.cache_path, self.save_tag, 'lang_shared_space_last.pkl'))
        
        lang_specific_proj = preference_matrix

        # Reference projections: the "en" embeddings projected layer by layer.
        en_direction = []
        for layer_idx in range(len(source_lan_emb['en'])):
            en_direction.append(self.projection(source_lan_emb['en'][layer_idx], lang_specific_proj[layer_idx].to(source_lan_emb['en'].dtype).to(source_lan_emb['en'].device)))

        # For every language (including "en"), save the per-layer difference
        # between its mean projected embedding and the "en" one.
        for lang in source_lan_emb.keys():
            lang_direction = []
            for layer_idx in range(len(source_lan_emb[lang])):
                lang_direction.append(self.projection(source_lan_emb[lang][layer_idx], lang_specific_proj[layer_idx].to(source_lan_emb[lang].dtype).to(source_lan_emb[lang].device)))

            delta = [lang.mean(dim=0) - en.mean(dim=0) for lang, en in zip(lang_direction, en_direction)]

            save_to_pickle(delta, os.path.join(self.cache_path, self.save_tag, f'{lang}_direction_last.pkl'))
        # NOTE(review): `stop` is not defined in the visible source, so this
        # line raises NameError and everything below is unreachable as written.
        print(stop)

        # for layer_num in range(preference_matrix.shape[0]):
        #     sorted_tokens = project_into_vocabluary(preferred_sent_embs[layer_num][-1].squeeze(), self.E.cpu(), self.tokenizer, top_k=10)
        #     print(f'Layer {layer_num} - Rank 0: {" | ".join([x for x in sorted_tokens])}')
        # print("-" * 20)
        # for layer_num in range(preference_matrix.shape[0]):
        #     zh_direction = preference_matrix[layer_num]
        #     last_token_states = preferred_sent_embs[layer_num][-1].unsqueeze(0).to(zh_direction.dtype)
        #     zh_direction /= torch.linalg.norm(zh_direction, axis=1, keepdims=True)
        #     proj = torch.matmul(last_token_states, zh_direction.T)
        #     last_token_states = torch.matmul(proj, zh_direction)
        #     # sorted_tokens = project_into_vocabluary(preference_matrix[layer_num].mean(dim=0).squeeze(), self.E.cpu(), self.tokenizer, top_k=10)
        #     sorted_tokens = project_into_vocabluary(last_token_states.squeeze(), self.E.cpu(), self.tokenizer, top_k=10)
        #     print(f'Layer {layer_num} - Rank 0: {" | ".join([x for x in sorted_tokens])}')
        # print(stop)

        if self.centering:
            logging.info('Centering: Removing first singular vector from preference matrix.')

            # NOTE(review): `preferred_sent_embs` / `non_preferred_sent_embs`
            # are not assigned anywhere in this method.
            # NOTE(review): `projection_vector` is a raw mean difference and
            # is not normalised, so `I - P` is not a true projector unless it
            # happens to have unit norm -- confirm the intent.
            for layer_num in range(preference_matrix.shape[0]):
                d = preference_matrix[layer_num].to(torch.float32)
                pref = deepcopy(preferred_sent_embs[layer_num].to(torch.float32))
                no_pref = deepcopy(non_preferred_sent_embs[layer_num].to(torch.float32))
                projection_vector = (pref.mean(dim=0) - no_pref.mean(dim=0)).unsqueeze(dim=1)

                # u, s, vt = torch.linalg.svd(pref, full_matrices=False)  # (N, D) -> (N, N), (N,), (N, D)
                # projection_vector = vt[0].unsqueeze(dim=1)  # (D, 1)
                P = projection_vector @ projection_vector.T  # (D, D)
                I = torch.eye(projection_vector.shape[0]).to(pref.device)  # (D, D)
                d = d @ (I - P)  # (N, D) @ (D, D) -> (N, D)

                # sorted_tokens = project_into_vocabluary(projection_vector.squeeze(), self.E.cpu(), self.tokenizer, top_k=10)
                # print(f'Layer {layer_num} - Rank 0: {" | ".join([x for x in sorted_tokens])}')

                preference_matrix[layer_num] = d.to(preference_matrix[layer_num].dtype) # d

        return preference_matrix


    def get_ats(self):
        """Map every MLP weight name to its layer's preference matrix.

        A state-dict entry is selected when its name contains both
        ``'weight'`` and ``'mlp'``. The layer index is the third
        dot-separated field of the name.

        Returns:
            dict from parameter name to ``preference_matrix[layer_index]``.
        """
        per_layer = self._get_preference_matrix()  # indexed by layer along the first axis

        def layer_of(name):
            # Format: transformer.h.19.mlp.c_fc.weight
            return int(name.split('.')[2])

        return {
            name: per_layer[layer_of(name)]
            for name in self.model.state_dict()
            if 'weight' in name and 'mlp' in name
        }


    def svd_on_ats(self, ats):
        '''
        Compute the reduced ("skinny") SVD of every matrix in ``ats``.

        Each matrix is cast to float32 and decomposed on the GPU when CUDA is
        available, otherwise on the CPU, so the method also works on hosts
        without a GPU. Results are moved back to the CPU.

        Shape notes carried over from the original implementation:
            Key(D, 4D) -> U(D, D) S(D) V^T(D, 4D)
            Value(4D, D) -> U(4D, D) S(4D) V^T(D, D)
            x_l (N, D) -> U(N, N); S(N,); V^T(N, D)

        Note: v @ v.T is not numerically I, but plotting it as a heatmap shows that it is close to I.

        Args:
            ats: dict mapping a parameter name to a 2-D tensor.

        Returns:
            dict mapping each name to ``{'u': U, 's': S, 'v': V}`` where the
            columns of ``V`` are the right singular vectors (``V = Vt.T``).
        '''
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        svd = {}
        for key, matrix in ats.items():
            logging.info(f'Calculating SVD for: {key}')
            # float32 cast kept from the original ("SVD function only works with float32").
            M = matrix.to(torch.float32).to(device)

            u, s, vt = torch.linalg.svd(M, full_matrices=False)  # Skinny SVD, vt is V^T
            svd[key] = {'u': u.cpu(), 's': s.cpu(), 'v': vt.T.cpu()}
        logging.info('SVD of ATS calculated.')
        return svd


    def find_p_toxic(self, svd, rank_range=20):
        """Accumulate the leading right singular vectors into one projector per key.

        For every key whose layer index (third dot-separated field of the
        name) is in ``self.edit_layer_range``, the outer products of the first
        ``self.top_k_ranks`` columns of ``svd[key]['v']`` are summed into a
        (D, D) matrix. The top vocabulary tokens for each of those vectors
        are printed via ``project_into_vocabluary``.

        Args:
            svd: dict as produced by ``svd_on_ats``.
            rank_range: Accepted but not used by this method.

        Returns:
            dict from key to its (D, D) accumulated matrix; keys of skipped
            layers are absent.
        """
        toxic_subspace = {}
        num_ranks = self.top_k_ranks  # 2 by default -> ranks 0 and 1

        for key in tqdm(svd.keys()):
            layer_num = int(key.split('.')[2])  # Format: transformer.h.19.mlp.c_fc.weight
            if layer_num in self.edit_layer_range:
                logging.info(f'Calculating toxic subspace for: {key}')

                right_vectors = svd[key]['v']  # (D, N): N cols of (D,) vectors

                # Sum of rank-one outer products of the selected columns.
                projector = torch.zeros(self.D, self.D)
                for r in range(num_ranks):
                    column = right_vectors[:, r].unsqueeze(dim=1)  # (D, 1)
                    projector += column @ column.T  # (D, 1) @ (1, D) -> (D, D)

                    sorted_tokens = project_into_vocabluary(column.squeeze(), self.E.cpu(), self.tokenizer, top_k=10)
                    print(f'Layer {layer_num} - Rank {r}: {" | ".join(sorted_tokens)}')

                toxic_subspace[key] = projector
            else:
                logging.info(f'Skipping layer {layer_num}')

        logging.info('Toxic subspace calculated.')
        return toxic_subspace


    def edit_model(self, toxic_subspace, edit_keys=True, edit_gates=True, edit_values=True, layer_range=None):
        """Apply ``I + scale_rate * P`` to the selected MLP weights of the model.

        For every state-dict entry that has a matrix ``P`` in
        ``toxic_subspace`` and whose layer index is in ``layer_range``:

        * key and gate weights are left-multiplied: ``P_filter @ W``;
        * value weights are right-multiplied: ``W @ P_filter``.

        For the ``llama`` / ``mistral`` / ``opt`` / ``gptj`` categories the
        stored weight is transposed before the product and transposed back
        afterwards. The edited tensors are written back with
        ``load_state_dict(..., assign=True)``.

        The weight kind is resolved with a single ``if/elif`` chain. The
        previous code tested keys in a separate ``if``; a key weight then
        fell into the gate/value chain's ``else: continue`` and its edit was
        silently discarded.

        ``self.dare_drop_rate`` is not used here.

        Args:
            toxic_subspace: dict from parameter name to a (D, D) matrix.
            edit_keys: Edit weights for which ``is_key`` is true.
            edit_gates: Edit weights for which ``is_gate`` is true.
            edit_values: Edit weights for which ``is_value`` is true.
            layer_range: Layer indices to edit; defaults to every layer.

        Returns:
            ``self.model`` with the edited weights loaded.
        """
        assert edit_keys or edit_gates or edit_values, 'At least one of edit_keys, edit_gates or edit_values should be True'
        logging.info(f'Editing keys: {edit_keys}, Editing gates: {edit_gates}, Editing values: {edit_values}.')

        if layer_range is None:
            layer_range = np.arange(get_num_transformer_layers(self.model))
        logging.info(f'Editing layers: {layer_range}')

        # These categories store the weight transposed relative to the
        # (D, 4D) key / (4D, D) value convention used below.
        needs_transpose = self.model_category in ['llama', 'mistral', 'opt', 'gptj']

        edited_state_dict = self.model.state_dict()
        for key in edited_state_dict:
            if key not in toxic_subspace:
                continue

            layer_num = int(key.split('.')[2])
            if layer_num not in layer_range:
                continue

            # Decide once how this weight is edited (or skip it).
            if edit_keys and is_key(key, self.model_category):
                left_multiply = True
            elif edit_gates and is_gate(key, self.model_category):
                left_multiply = True
            elif edit_values and is_value(key, self.model_category):
                left_multiply = False
            else:
                continue

            original = edited_state_dict[key]

            # Modified - -> +  (the subspace is amplified, not removed)
            P_filter = torch.eye(self.D) + self.scale_rate * toxic_subspace[key]
            P_filter = P_filter.to(original.device).to(self.model.dtype)

            weight = original.T if needs_transpose else original

            if left_multiply:
                modified_weight = P_filter @ weight  # (D, D) @ (D, 4D) -> (D, 4D)
            else:
                modified_weight = weight @ P_filter  # (4D, D) @ (D, D) -> (4D, D)
            logging.info(f'Editing: {key}')
            logging.info(f'Module {key}: P_toxic mean: {toxic_subspace[key].mean()}.')

            if torch.allclose(weight, modified_weight) and ('gate_proj' not in key):
                logging.warning(f'Module {key} not edited after projection.')

            if needs_transpose:
                modified_weight = modified_weight.T

            # Keep the tensor on the device the weight already lives on
            # (instead of a hard-coded 'cuda'); contiguous for saving to disk.
            edited_state_dict[key] = modified_weight.to(original.device).contiguous()

        self.model.load_state_dict(edited_state_dict, assign=True)
        logging.info('Edited model created.')
        return self.model


    def setup_for_edits(self):
        """Compute ``self.toxic_subspace`` from scratch.

        Runs ``get_ats`` -> ``svd_on_ats`` -> ``find_p_toxic`` and stores the
        result on the instance. The intermediate results are never bound to a
        local name, so each one can be released as soon as the next stage has
        consumed it. The CUDA cache is emptied at the end.
        """
        self.toxic_subspace = self.find_p_toxic(self.svd_on_ats(self.get_ats()))
        torch.cuda.empty_cache()


    def apply_edit_end_to_end(self, edit_keys=True, edit_gates=True, edit_values=True, layer_range=None):
        """Run the whole edit pipeline and report its time and memory cost.

        Calls ``setup_for_edits`` followed by ``edit_model``, then prints the
        elapsed time and the change in used system memory and used GPU memory
        (device index 0, read through NVML).

        NVML is shut down again in a ``finally`` block, so the handle is
        released even when the pipeline raises.

        Args:
            edit_keys: Forwarded to ``edit_model``.
            edit_gates: Forwarded to ``edit_model``.
            edit_values: Forwarded to ``edit_model``.
            layer_range: Forwarded to ``edit_model``.

        Returns:
            The edited model returned by ``edit_model``.
        """
        # Measure speed and memory use
        import time
        import psutil
        import pynvml

        start_time = time.perf_counter()  # monotonic clock for elapsed time
        before_memory = psutil.virtual_memory().used
        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            before_gpu_memory_used = pynvml.nvmlDeviceGetMemoryInfo(handle).used

            # Find P_toxic
            self.setup_for_edits()

            # Apply edit
            edited_model = self.edit_model(self.toxic_subspace, edit_keys, edit_gates, edit_values, layer_range)
            torch.cuda.empty_cache()

            end_time = time.perf_counter()
            # Pause kept from the original; presumably lets the memory
            # counters settle before they are read -- TODO confirm.
            time.sleep(1)
            after_memory = psutil.virtual_memory().used
            after_gpu_memory_used = pynvml.nvmlDeviceGetMemoryInfo(handle).used
        finally:
            pynvml.nvmlShutdown()

        print(f"Elapsed time: {end_time - start_time} seconds")
        print(f"System Memory Used: {(after_memory - before_memory) / (1024 * 1024)} MB")
        print(f"GPU Memory Used: {(after_gpu_memory_used - before_gpu_memory_used) / (1024 ** 2)} MB")

        return edited_model
