# import torch
# from .utils import parent_module, brackets_to_periods
# import transformers
# import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# def euc(query, key):
#     # Euclidean distance
#     if len(key.shape) < 2:
#         key = key.view(1, -1)
#     return torch.cdist(key, query, p=2)

# def perturb_values(chosen_value, num_pert, device):
#     # Create a bunch of noised versions of the value, then create batch, then train value
#     chosen_value = chosen_value
#     noise = torch.normal(0, 1, chosen_value.shape, device=device)
#     noise[0] = noise[0]*0
#     noise.requires_grad = True
#     chosen_value = chosen_value + noise
#     return chosen_value

# class GRACE(torch.nn.Module):
#     def __init__(self, config, model, device):
#         super(GRACE, self).__init__()
#         self.config = config
#         self.log_dict = {}
#         self.model = model
#         # self.tokenizer = model.tokenizer
#         layer = config.inner_params[0]
#         self.device = device

#         # --- ensure proper formatting (GRACE edits ~layers~ not weights matrices) ---
#         suffixes = [".weight", ".bias"]
#         self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer

#         for n, p in self.model.named_parameters():
#             p.requires_grad = False

#         if isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
#             transpose = False
#         else:
#             transpose = True

#         # --- Add GRACE to chosen layers ---
#         edit_module = parent_module(self.model, brackets_to_periods(self.layer))
#         layer_name = self.layer.rsplit(".", 1)[-1]
#         original_layer = getattr(edit_module, layer_name)

#         if type(original_layer) is not GRACEAdapter:
#             setattr(edit_module, layer_name, GRACEAdapter(config, original_layer, transpose=transpose).to(self.device))

#     def __call__(self, **kwargs):
#         # if self.config.task == "hallucination":
#         #     print(kwargs)
#         #     key_id = (kwargs["labels"] == -100).sum() - 1
#         #     setattr(eval(f"self.model.{self.layer}"), "key_id", key_id) # Tell GRACE which token to use for its query (default is the last token)
#         return self.model(**kwargs)

#     def generate(self, *args, **kwargs):
#         setattr(eval(f"self.model.{self.layer}"), "key_id", -1)
#         return self.model.generate(*args, **kwargs)

#     def edit(self, config, tokens):
#         key_id = (tokens["labels"] == -100).sum() - 1
#         setattr(eval(f"self.model.{self.layer}"), "key_id", key_id)

#         # --- pass edit label, training mode, and key_id into GRACE ---
#         setattr(eval(f"self.model.{self.layer}"), "training", True)
#         setattr(eval(f"self.model.{self.layer}"), "edit_label", tokens["labels"])

#         self.losses = []
#         # --- train GRACE value ---
#         for i in range(config.n_iter):
#             # --- insert iteration into each layer (only initiate keys on iteration 1) ---
#             setattr(eval(f"self.model.{self.layer}"), "iter", i)

#             # --- pass tokens through model (including through the GRACE layer) ---
#             outputs = self.model(**tokens)
#             if i == 0:
#                 # --- we only need to create an optimizer for the first iteration (but forward pass instantiates the key, so optimzer is passed after first inference) ---
#                 optimizer = torch.optim.Adam(self.model.parameters(), config.edit_lr)
#             loss = outputs.loss
#             loss.backward()
#             optimizer.step()
#             optimizer.zero_grad()
#             self.losses.append(loss.detach().cpu().numpy())

#         self.loss = loss # Log final loss

#         # --- pull out info we want to log from the GRACE layer ---
#         setattr(eval(f"self.model.{self.layer}"), "training", False)
#         chosen_key = getattr(eval(f"self.model.{self.layer}"), "chosen_key")
#         nkeys = len(getattr(eval(f"self.model.{self.layer}"), "keys"))

#         self.log_dict["chosen_key"] =  chosen_key
#         self.log_dict["nkeys"] = nkeys

# class GRACEAdapter(torch.nn.Module):
#     def __init__(self, config, layer, transpose):
#         super(GRACEAdapter, self).__init__()

#         self.layer = layer
#         self.weight = self.layer.weight
#         self.init_epsilon = config.eps
#         self.dist_fn = config.dist_fn
#         self.replacement = config.replacement
#         self.device = layer.weight.device
#         self.config = config
#         self.num_pert = config.num_pert
#         self.key_id = -1
#         self.ensure_replace_token_loc = False

#         if transpose:
#             self.key_shape = layer.weight.shape[1]
#             self.value_shape = layer.weight.shape[0]
#         else:
#             self.key_shape = layer.weight.shape[0]
#             self.value_shape = layer.weight.shape[1]
#         self.training = False

#     def add_key(self, new_key, new_value):
#         keys = torch.vstack([self.keys, new_key.detach()]) # Add new key to list of keys

#         values = torch.nn.Parameter(torch.vstack([self.values, new_value]), requires_grad=True) # Add new value to list of values

#         new_epsilon = torch.tensor(self.init_epsilon, device=self.device).view(1)
#         epsilons = torch.vstack([self.epsilons, new_epsilon]) # Add new epsilon to list of epsilons

#         key_labels = self.key_labels + [self.edit_label] # Add new key_label to list of key_labels

#         return keys, values, epsilons, key_labels

#     def init_key_value(self, query, value):
#         key = query.detach()
#         epsilon = torch.tensor(self.init_epsilon, device=self.device, requires_grad=False).view(1)
#         key_label = [self.edit_label]
#         return key, value, epsilon, key_label

#     def label_match(self, edit_label, key_label):
#         return edit_label.float().mean() == key_label.float().mean()

#     def split_epsilons_in_half(self, nearest_key, smallest_distance):
#         self.epsilons[nearest_key] = (smallest_distance / 2) - 1e-5 # Cut nearest epsilon in half
#         self.epsilons[-1] = smallest_distance / 2 # Cut new epsilon in half

#     def forward(self, *args):
#         # Run layer forward and save what it would have returned for this instance
#         layer_out = self.layer(*args)

#         ### If training, we need to modify the codebook
#         if (not self.training) & ('keys' not in self.__dict__):
#             # If it's not training time and we haven't added any keys yet (this is before doing any editing)
#             # print(self.__dict__)
#             return layer_out
#         else:
#             if not self.training and not self.ensure_replace_token_loc and self.key_id == -1:
#                 token_to_edit = args[0].shape[1]-1
#                 self.key_id = args[0].shape[1]-1
#                 self.ensure_replace_token_loc = True
#             else:
#                 token_to_edit = min(self.key_id, args[0].shape[1]-1) # args[0].shape[1] - 1 is sequence length
#             query = args[0][:, token_to_edit, :] # Just use activation for last token
#             if self.config.val_init == "cold":
#                 new_value = torch.nn.Parameter(torch.rand(1, self.value_shape, requires_grad=True, device=self.device))
#             elif self.config.val_init == "warm":
#                 new_value = torch.nn.Parameter(layer_out[:, token_to_edit, :].detach(), requires_grad=True)

#             if 'keys' not in self.__dict__:
#                 # If no keys exist, initialize keys, values, epsilons, and key labels
#                 self.keys, self.values, self.epsilons, self.key_labels = self.init_key_value(query, new_value)
#             elif self.iter == 0:
#                 # Keys exist, so we have decide whether or not to update them (the fact that we've made it to this point means there was an error!)

#                 # --- search through keys for a match for query ---
#                 dists = torch.cdist(self.keys, query, p=2).view(-1, len(query))
#                 smallest_distance, nearest_key = dists.min(0)

#                 if smallest_distance > (self.init_epsilon + self.epsilons[nearest_key]):
#                     # If there's no close key, make a new key
#                     self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
#                 else:
#                     # If there is a close key, we need to handle conflicts
#                     if not self.label_match(self.edit_label, self.key_labels[nearest_key]):
#                         self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
#                         self.split_epsilons_in_half(nearest_key, smallest_distance)
#                     else:
#                         # If the current label is the SAME as the nearest label, just make the nearest epsilon bigger
#                         if smallest_distance > self.epsilons[nearest_key]:
#                             if self.config.eps_expand== "coverage":
#                                 self.epsilons[nearest_key] = smallest_distance # Replace nearest epsilon with dist between old key and new key
#                             elif self.config.eps_expand == "moving_average":
#                                 a = 0.5
#                                 self.keys[nearest_key] = a*self.keys[nearest_key] + (1-a)*query # Move old key to be halfway between
#                                 self.epsilons[nearest_key] = smallest_distance
#                                 # self.epsilons[nearest_key] = smallest_distance + self.init_epsilon
#             else:
#                 # If not iter 0, we don't need to change keys, we just need to learn the value
#                 pass
#         # print(token_to_edit)
#         # compute distance from query to all keys and find the closest keys
#         dists = torch.cdist(self.keys, query, p=2).view(-1, len(query))
#         smallest_dist, self.chosen_key = dists.min(0)
#         smallest_dist = smallest_dist.view(-1, 1)
#         chosen_value = self.values[self.chosen_key]
#         eps = self.epsilons[self.chosen_key].view(-1, 1)

#         if (self.config.val_train == "adv") and (self.training):
#             chosen_value = perturb_values(chosen_value, self.num_pert, self.device)

#         if self.replacement == "replace_all":
#             layer_out = torch.where((smallest_dist <= eps).view(-1, 1, 1), chosen_value.unsqueeze(1).repeat_interleave(layer_out.shape[1], 1), layer_out)
#         elif self.replacement == "replace_last":
#             layer_out[:, token_to_edit] = torch.where((smallest_dist <= eps), chosen_value, layer_out[:, token_to_edit])
#         elif self.replacement == "replace_prompt":
#             layer_out[:, :token_to_edit] = torch.where((smallest_dist <= eps), chosen_value, layer_out[:, :token_to_edit])
#         else:
#             print("token replacement choice not found")
#         return layer_out
import copy

import torch
from .utils import parent_module, brackets_to_periods
import transformers
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"


def euc(query, key):
    # Euclidean distance
    if len(key.shape) < 2:
        key = key.view(1, -1)
    return torch.cdist(key, query, p=2)


def perturb_values(chosen_value, num_pert, device):
    # Create a bunch of noised versions of the value, then create batch, then train value
    chosen_value = chosen_value
    noise = torch.normal(0, 1, chosen_value.shape, device=device)
    noise[0] = noise[0] * 0
    noise.requires_grad = True
    chosen_value = chosen_value + noise
    return chosen_value


class GRACE(torch.nn.Module):
    def __init__(self, config, model, device):
        super(GRACE, self).__init__()
        self.config = config
        self.log_dict = {}
        self.model = model
        self.config = config
        # self.tokenizer = model.tokenizer
        layer = config.inner_params[0]
        self.device = device
        self.original_layer = None

        # --- ensure proper formatting (GRACE edits ~layers~ not weights matrices) ---
        suffixes = [".weight", ".bias"]
        self.layer = layer.rsplit(".", 1)[0] if any(layer.endswith(x) for x in suffixes) else layer

        for n, p in self.model.named_parameters():
            p.requires_grad = False

        if isinstance(self.model, transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel):
            transpose = False
        else:
            transpose = True

        # --- Add GRACE to chosen layers ---
        self.edit_module = parent_module(self.model, brackets_to_periods(self.layer))
        self.layer_name = self.layer.rsplit(".", 1)[-1]
        original_layer = getattr(self.edit_module, self.layer_name)
        if type(original_layer) is not GRACEAdapter:
            setattr(self.edit_module, self.layer_name,
                    GRACEAdapter(config, original_layer, transpose=transpose).to(self.device))
            self.original_layer = copy.deepcopy(original_layer)

    def __call__(self, **kwargs):
        # if self.config.task == "hallucination":
        #     print(kwargs)
        #     key_id = (kwargs["labels"] == -100).sum() - 1
        #     setattr(eval(f"self.model.{self.layer}"), "key_id", key_id) # Tell GRACE which token to use for its query (default is the last token)
        return self.model(**kwargs)

    def get_adapter_layer(self):
        adapter_layer = getattr(self.edit_module, self.layer_name)
        assert type(adapter_layer) is GRACEAdapter, print('Adapter Layer is not added correctly....')
        return adapter_layer

    def reset_layer(self):
        layer = getattr(self.edit_module, self.layer_name)
        del layer
        setattr(self.edit_module, self.layer_name, self.get_adapter_layer().original_layer)

    def generate(self, *args, **kwargs):
        setattr(eval(f"self.model.{self.layer}"), "key_id", -1)
        return self.model.generate(*args, **kwargs)

    def edit(self, config, tokens):
        key_id = (tokens["labels"] == -100).sum() - 1
        setattr(eval(f"self.model.{self.layer}"), "key_id", key_id)

        # --- pass edit label, training mode, and key_id into GRACE ---
        setattr(eval(f"self.model.{self.layer}"), "training", True)
        setattr(eval(f"self.model.{self.layer}"), "edit_label", tokens["labels"])

        self.losses = []
        # --- train GRACE value ---
        for i in range(config.n_iter):
            # --- insert iteration into each layer (only initiate keys on iteration 1) ---
            setattr(eval(f"self.model.{self.layer}"), "iter", i)

            # --- pass tokens through model (including through the GRACE layer) ---
            outputs = self.model(**tokens)
            if i == 0:
                # --- we only need to create an optimizer for the first iteration (but forward pass instantiates the key, so optimzer is passed after first inference) ---
                optimizer = torch.optim.Adam(self.model.parameters(), config.edit_lr)
            loss = outputs.loss
            try:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                self.losses.append(loss.detach().cpu().numpy())
            except Exception as e:
                pass

        self.loss = loss  # Log final loss

        # --- pull out info we want to log from the GRACE layer ---
        setattr(eval(f"self.model.{self.layer}"), "training", False)
        chosen_key = getattr(eval(f"self.model.{self.layer}"), "chosen_key")
        nkeys = len(getattr(eval(f"self.model.{self.layer}"), "keys"))

        self.log_dict["chosen_key"] = chosen_key
        self.log_dict["nkeys"] = nkeys


class GRACEAdapter(torch.nn.Module):
    def __init__(self, config, layer, transpose):
        super(GRACEAdapter, self).__init__()

        self.layer = layer
        self.original_layer = copy.deepcopy(self.layer)
        self.weight = self.layer.weight
        self.init_epsilon = config.eps
        self.dist_fn = config.dist_fn
        self.replacement = config.replacement
        self.device = layer.weight.device
        self.config = config
        self.num_pert = config.num_pert
        self.key_id = -1

        if transpose:
            self.key_shape = layer.weight.shape[1]
            self.value_shape = layer.weight.shape[0]
        else:
            self.key_shape = layer.weight.shape[0]
            self.value_shape = layer.weight.shape[1]
        self.training = False

    def add_key(self, new_key, new_value):
        keys = torch.vstack([self.keys, new_key.detach()])  # Add new key to list of keys

        values = torch.nn.Parameter(torch.vstack([self.values, new_value]),
                                    requires_grad=True)  # Add new value to list of values

        new_epsilon = torch.tensor(self.init_epsilon, device=self.device).view(1)
        epsilons = torch.vstack([self.epsilons, new_epsilon])  # Add new epsilon to list of epsilons

        key_labels = self.key_labels + [self.edit_label]  # Add new key_label to list of key_labels

        return keys, values, epsilons, key_labels

    def init_key_value(self, query, value):
        key = query.detach()
        epsilon = torch.tensor(self.init_epsilon, device=self.device, requires_grad=False).view(1)
        key_label = [self.edit_label]
        return key, value, epsilon, key_label

    def label_match(self, edit_label, key_label):
        return edit_label.float().mean() == key_label.float().mean()

    def split_epsilons_in_half(self, nearest_key, smallest_distance):
        self.epsilons[nearest_key] = (smallest_distance / 2) - 1e-5  # Cut nearest epsilon in half
        self.epsilons[-1] = smallest_distance / 2  # Cut new epsilon in half

    def forward(self, *args):
        # Run layer forward and save what it would have returned for this instance
        layer_out = self.layer(*args)

        ### If training, we need to modify the codebook
        if (not self.training) & ('keys' not in self.__dict__):
            # If it's not training time and we haven't added any keys yet (this is before doing any editing)
            # print(self.__dict__)
            return layer_out
        else:
            if not self.training:
                if self.key_id == -1:
                    token_to_edit = args[0].shape[1] - 1
                    self.key_id = args[0].shape[1] - 1
                else:
                    token_to_edit = min(self.key_id, args[0].shape[1] - 1)
            else:
                token_to_edit = min(self.key_id, args[0].shape[1] - 1)  # args[0].shape[1] - 1 is sequence length
            query = args[0][:, token_to_edit, :]  # Just use activation for last token
            if self.config.val_init == "cold":
                new_value = torch.nn.Parameter(torch.rand(1, self.value_shape, requires_grad=True, device=self.device))
            elif self.config.val_init == "warm":
                new_value = torch.nn.Parameter(layer_out[:, token_to_edit, :].detach(), requires_grad=True)

            if 'keys' not in self.__dict__:
                # If no keys exist, initialize keys, values, epsilons, and key labels
                self.keys, self.values, self.epsilons, self.key_labels = self.init_key_value(query, new_value)
            elif self.iter == 0:
                # Keys exist, so we have decide whether or not to update them (the fact that we've made it to this point means there was an error!)

                # --- search through keys for a match for query ---
                dists = torch.cdist(self.keys, query, p=2).view(-1, len(query))
                smallest_distance, nearest_key = dists.min(0)

                if smallest_distance > (self.init_epsilon + self.epsilons[nearest_key]):
                    # If there's no close key, make a new key
                    self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
                else:
                    # If there is a close key, we need to handle conflicts
                    if not self.label_match(self.edit_label, self.key_labels[nearest_key]):
                        self.keys, self.values, self.epsilons, self.key_labels = self.add_key(query, new_value)
                        self.split_epsilons_in_half(nearest_key, smallest_distance)
                    else:
                        # If the current label is the SAME as the nearest label, just make the nearest epsilon bigger
                        if smallest_distance > self.epsilons[nearest_key]:
                            if self.config.eps_expand == "coverage":
                                self.epsilons[
                                    nearest_key] = smallest_distance  # Replace nearest epsilon with dist between old key and new key
                            elif self.config.eps_expand == "moving_average":
                                a = 0.5
                                self.keys[nearest_key] = a * self.keys[nearest_key] + (
                                            1 - a) * query  # Move old key to be halfway between
                                self.epsilons[nearest_key] = smallest_distance
                                # self.epsilons[nearest_key] = smallest_distance + self.init_epsilon
            else:
                # If not iter 0, we don't need to change keys, we just need to learn the value
                pass
        # print(token_to_edit)
        # compute distance from query to all keys and find the closest keys
        dists = torch.cdist(self.keys, query, p=2).view(-1, len(query))
        smallest_dist, self.chosen_key = dists.min(0)
        smallest_dist = smallest_dist.view(-1, 1)
        chosen_value = self.values[self.chosen_key]
        eps = self.epsilons[self.chosen_key].view(-1, 1)

        if (self.config.val_train == "adv") and (self.training):
            chosen_value = perturb_values(chosen_value, self.num_pert, self.device)

        if self.replacement == "replace_all":
            layer_out = torch.where((smallest_dist <= eps).view(-1, 1, 1),
                                    chosen_value.unsqueeze(1).repeat_interleave(layer_out.shape[1], 1), layer_out)
        elif self.replacement == "replace_last":
            layer_out[:, token_to_edit] = torch.where((smallest_dist <= eps), chosen_value, layer_out[:, token_to_edit])
        elif self.replacement == "replace_prompt":
            layer_out[:, :token_to_edit] = torch.where((smallest_dist <= eps), chosen_value,
                                                       layer_out[:, :token_to_edit])
        else:
            print("token replacement choice not found")
        return layer_out