import torch
import re
import math
import transformers
from functools import partial
from easyeditor.models.grace.utils import parent_module, brackets_to_periods
# from .agg_utils import AverageMeter, get_context_templates
from ..rome import agg_utils
from ..rome import repr_tools
from ...util import nethook

class Aggregator(torch.nn.Module):
    """Frozen model plus a trainable ``Adaptor`` applied to one layer's input.

    The hooked layer is the first entry of ``hparams.layers`` formatted with
    ``hparams.rewrite_module_tmp``.  ``edit`` trains the adaptor for a single
    request; ``__call__`` runs the wrapped model with the adaptor installed
    through ``nethook.TraceInput``.
    """

    def __init__(self, hparams, model, tok, device):
        """Freeze ``model`` and create the adaptor for the hooked layer.

        Args:
            hparams: hyper-parameter object; ``layers``, ``rewrite_module_tmp``
                and ``model_name`` are read here, the rest in ``edit``.
            model: model to wrap; every parameter gets ``requires_grad=False``.
            tok: tokenizer, used by ``edit``.
            device: stored as ``self.device`` (not otherwise used here).

        Raises:
            NotImplementedError: if the hooked layer is not a
                ``torch.nn.Linear`` and ``hparams.model_name`` does not
                contain ``'gpt2'``.
        """
        super().__init__()
        self.log_dict = {}
        self.model = model
        self.tok = tok
        self.hparams = hparams
        self.device = device

        # Only the adaptor is trained; the wrapped model stays frozen.
        for param in self.model.parameters():
            param.requires_grad = False

        self.layer_names = [
            hparams.rewrite_module_tmp.format(layer) for layer in hparams.layers
        ]
        # __call__ and edit hook exactly one layer; use the same one that
        # determines the adaptor's feature dimension.
        self.layer_name = self.layer_names[0]
        layer = agg_utils.get_module(self.model, self.layer_name)

        if isinstance(layer, torch.nn.Linear):
            feature_dim = layer.in_features
        elif 'gpt2' in hparams.model_name:
            feature_dim = layer.weight.shape[0]
        else:
            raise NotImplementedError(
                f"Cannot infer the input dimension of layer {self.layer_name!r} "
                f"({type(layer).__name__}) for model {hparams.model_name!r}"
            )

        # The adaptor consumes this layer's input, so it lives on the layer's
        # device instead of a hard-coded 'cuda'.
        self.adaptor = Adaptor(hparams, feature_dim).to(layer.weight.device)
        # In training mode Adaptor.forward returns a (tensor, gate) tuple;
        # start in eval mode so the hook in __call__ receives a single tensor.
        self.adaptor.eval()

    def __call__(self, **kwargs):
        """Run ``self.model(**kwargs)`` with the adaptor hooked onto ``self.layer_name``."""
        cur_adapt_fn = partial(adapt_fn, layer_name=self.layer_name, adaptor=self.adaptor)
        with nethook.TraceInput(
            self.model,
            self.layer_name,
            edit_input=cur_adapt_fn,
        ):
            return self.model(**kwargs)

    def generate(self, *args, **kwargs):
        """Delegate to ``self.model.generate``.

        NOTE(review): unlike ``__call__`` this does not install the adaptor
        hook, so generation appears to run without the adaptor — confirm
        that this is intended.
        """
        return self.model.generate(*args, **kwargs)

    def edit(self, request, target_vector):
        """Train the adaptor for one edit request.

        Args:
            request: dict; reads ``'prompt'``, ``'subject'`` and
                ``'rephrase_subjects'``, and depending on ``hparams`` also
                ``'rephrase_prompt'``, ``'para_attack'``, ``'long_context'``,
                ``'long_context_held'`` and ``'shuffled_subject'``.  The
                request is not modified.
            target_vector: tensor the adapted subject-token inputs are pulled
                towards; a detached clone is used.  Presumably 1-D with the
                hooked layer's input dimension — TODO confirm.

        Raises:
            NotImplementedError: for an unknown ``hparams.kl_loss_type``.
            ValueError: if no training context could be built.
        """
        self.training = True
        # train()/eval() rather than assigning ``.training`` so the adaptor's
        # Dropout child switches mode together with it.
        self.adaptor.train()
        try:
            self._fit_adaptor(request, target_vector)
        finally:
            self.training = False
            self.adaptor.eval()

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _first_half(items):
        """Return the first ``ceil(len(items) / 2)`` items."""
        return items[:math.ceil(len(items) / 2)]

    @staticmethod
    def _is_valid_format_string(s):
        """Return True when the ``{`` / ``}`` characters in ``s`` are balanced."""
        depth = 0
        for char in s:
            if char == '{':
                depth += 1
            elif char == '}':
                if depth == 0:
                    return False  # unmatched closing brace
                depth -= 1
        return depth == 0

    @staticmethod
    def _collect_prompt_templates(request, hparams):
        """Return prompt templates in which the subject is replaced by ``{}``."""
        prompt = request['prompt']
        subject = request['subject']
        if '{}' in prompt:
            prompt_template = prompt
        else:
            assert subject in prompt, \
                f"Subject:{subject} do not exist in prompt: {prompt}"
            prompt_template = prompt.replace(subject, '{}')
        prompt_templates = [prompt_template]

        if hparams.agg_add_rephrases:
            rephrases = [p.replace(subject, "{}") for p in request['rephrase_prompt']]
            if hparams.agg_train_on_held:
                rephrases = Aggregator._first_half(rephrases)
            prompt_templates += rephrases
        if hparams.agg_add_para_attack and 'para_attack' in request:
            attacks = [p.replace(subject, "{}") for p in request['para_attack']]
            if hparams.agg_train_on_held:
                attacks = Aggregator._first_half(attacks)
            prompt_templates += attacks
        return prompt_templates

    @staticmethod
    def _collect_context_templates(base_templates, request, hparams):
        """Return a copy of ``base_templates`` plus the request's long contexts."""
        context_templates = list(base_templates)
        if hparams.agg_add_long_context:
            if hparams.agg_train_on_held:
                context_templates += request['long_context_held']
            else:
                assert 'long_context' in request
                context_templates += [request['long_context']]
        return context_templates

    @staticmethod
    def _expand_templates(context_templates, prompt_templates):
        """Insert every prompt template into every context template.

        Duplicates are dropped, and only results with exactly one ``{}`` slot
        and balanced braces are kept.
        """
        combined = [
            c_templ.format(p_templ)
            for c_templ in context_templates
            for p_templ in prompt_templates
        ]
        # dict.fromkeys de-duplicates in first-seen order, so the result does
        # not depend on string hashing the way a set would.
        return [
            text for text in dict.fromkeys(combined)
            if text.count('{}') == 1 and Aggregator._is_valid_format_string(text)
        ]

    @staticmethod
    def _collect_subjects(request, hparams):
        """Return the subject strings to train on, as a new list."""
        rephrased = request['rephrase_subjects']
        if hparams.agg_train_on_held:
            rephrased = Aggregator._first_half(rephrased)  # only use first half
        # Copy: extending the request's own list would change the request.
        subjects = list(rephrased)
        if hparams.agg_add_ori:
            subjects.append(request['subject'])
        if hparams.agg_add_shuffle and 'shuffled_subject' in request:
            shuffled = request['shuffled_subject']
            if hparams.agg_train_on_held:
                shuffled = Aggregator._first_half(shuffled)
            subjects += shuffled
        return subjects

    @staticmethod
    def _next_batch(data, index, batch_size):
        """Return ``(batch, next_index)``, reshuffling ``data`` in place at the end.

        Raises:
            ValueError: if ``data`` is empty.
        """
        import random

        if not data:
            raise ValueError("No training contexts were built for this request")
        if index + batch_size > len(data):
            random.shuffle(data)
            # Continue after the batch just served so that it is not
            # returned a second time by the following call.
            return data[:batch_size], batch_size
        return data[index:index + batch_size], index + batch_size

    def _whitening_matrix(self):
        """Return ``get_inv_cov`` for the hooked layer, optionally normalised.

        With ``hparams.agg_norm_c`` the matrix is divided element-wise by the
        outer product of the square roots of its diagonal.
        """
        from .compute_u import get_inv_cov

        C = get_inv_cov(
            self.model,
            self.tok,
            self.layer_name,
            self.hparams.mom2_dataset,
            self.hparams.mom2_n_samples,
            self.hparams.mom2_dtype,
            hparams=self.hparams,
        )
        if not self.hparams.agg_norm_c:
            return C
        dd = torch.sqrt(torch.diag(C))
        return C / torch.outer(dd, dd)

    def _whiten(self, reprs, norm_C):
        """Right-multiply ``reprs`` by ``norm_C`` unless ``hparams.agg_remove_c`` is set."""
        return reprs if self.hparams.agg_remove_c else reprs @ norm_C

    def _kl_loss(self, init_repr, new_repr, norm_C):
        """Loss tying adapted non-subject inputs to their original values.

        Raises:
            NotImplementedError: for an unknown ``hparams.kl_loss_type``.
        """
        kind = self.hparams.kl_loss_type
        if kind == 'whiten_diff':
            loss = torch.abs(self._whiten(init_repr, norm_C) @ (new_repr - init_repr).T)
        elif kind == 'whiten':
            loss = -1 * self._whiten(init_repr, norm_C) @ new_repr.T
        elif kind == 'mse':
            loss = torch.nn.functional.mse_loss(init_repr, new_repr, reduction='mean')
        else:
            raise NotImplementedError(f"Unsupported kl_loss_type: {kind!r}")
        return loss.mean()

    def _step_losses(self, batch_contexts, batch_idxs, target_vector, norm_C):
        """Return ``(agg_loss, spread_loss, kl_loss, cls_loss)`` for one batch."""
        contexts_tok = self.tok(
            batch_contexts,
            return_tensors="pt",
            padding=True,
        ).to(f"cuda:{self.hparams.device}")

        with nethook.TraceInput(
            module=self.model,
            layer=self.layer_name,
            retain_input=True,
        ) as tr:
            _ = self.model(**contexts_tok)

        # Input recorded by the trace, detached so gradients reach only the adaptor.
        init_repr = tr.input.detach()
        adapted, cls = self.adaptor(init_repr)

        # mask is False at the single subject position of each row, True elsewhere.
        # NOTE(review): batch_idxs are used directly as positions in the padded
        # batch — confirm this holds for the tokenizer's padding side.
        assert all(len(idx) == 1 for idx in batch_idxs)
        idxs_t = torch.tensor(batch_idxs).long().to(adapted.device)
        mask = torch.ones(
            adapted.size(0), adapted.size(1), dtype=torch.bool, device=adapted.device
        )
        rows = torch.arange(adapted.size(0), device=mask.device)[:, None]
        mask[rows, idxs_t] = False

        # NOTE(review): Adaptor.forward already applies a sigmoid to the gate
        # when agg_use_cls is set, so BCEWithLogitsLoss applies a second one —
        # confirm whether BCELoss was intended.
        cls_loss = torch.nn.BCEWithLogitsLoss()(cls.squeeze(-1), (~mask).float())

        agg_repr = adapted[~mask]       # adapted subject-token inputs
        kl_repr = adapted[mask]         # adapted inputs at all other positions
        init_kl_repr = init_repr[mask]  # the same positions before the adaptor

        if self.hparams.agg_loss_norm:
            agg_repr = agg_repr / agg_repr.norm(dim=-1, keepdim=True)
        agg_loss = (-1 * (self._whiten(agg_repr, norm_C) @ target_vector)).mean()
        kl_loss = self._kl_loss(init_kl_repr, kl_repr, norm_C)
        spread_loss = torch.abs(self._whiten(kl_repr, norm_C) @ target_vector).mean()
        return agg_loss, spread_loss, kl_loss, cls_loss

    def _fit_adaptor(self, request, target_vector):
        """Build the training contexts for ``request`` and optimise the adaptor."""
        import random

        hparams = self.hparams
        # Work on a private copy so the caller's tensor is never modified.
        target_vector = target_vector.detach().clone()

        base_templates = agg_utils.get_context_templates(
            self.model, self.tok, hparams.context_template_length_params
        )
        norm_C = self._whitening_matrix()

        lr = 5e-4  # llama 5e-3
        opt = torch.optim.Adam(self.adaptor.parameters(), lr=lr, weight_decay=0)
        loss_meter = agg_utils.AverageMeter()

        prompt_templates = self._collect_prompt_templates(request, hparams)
        context_templates = self._collect_context_templates(base_templates, request, hparams)
        templates = self._expand_templates(context_templates, prompt_templates)
        print(f'# of templates: {len(templates)}')

        subtoken = hparams.fact_token[len("subject_"):]
        subjects = self._collect_subjects(request, hparams)

        # Every subject is paired with every template.
        words = [word for word in subjects for _ in templates]
        stack_templates = [template for _ in subjects for template in templates]
        contexts = [template.format(word) for template, word in zip(stack_templates, words)]
        idxs = repr_tools.get_words_idxs_in_templates(self.tok, stack_templates, words, subtoken)

        samples = [(context, idxs[i]) for i, context in enumerate(contexts)]
        random.shuffle(samples)

        batch_size = hparams.agg_batch_size
        index = 0
        for it in range(hparams.num_train_agg_steps):
            loss_meter.reset()
            opt.zero_grad()

            batch, index = self._next_batch(samples, index, batch_size)
            batch_contexts = [context for context, _ in batch]
            batch_idxs = [idx for _, idx in batch]

            agg_loss, spread_loss, kl_loss, cls_loss = self._step_losses(
                batch_contexts, batch_idxs, target_vector, norm_C
            )
            loss = (
                agg_loss * hparams.agg_loss_weight
                + spread_loss * hparams.spread_loss_weight
                + kl_loss * hparams.kl_loss_weight
                + cls_loss * hparams.cls_loss_weight
            )
            print(f"Step {it}: Agg loss {agg_loss.item()}; Spread loss: {spread_loss}; KL loss: {kl_loss}; CLS loss: {cls_loss}")
            loss_meter.update(loss.item(), n=1)
            loss.backward()
            opt.step()

class Adaptor(torch.nn.Module):
    """Low-rank residual adaptor with an optional learned per-position gate.

    ``forward`` adds ``lora_B(lora_A(dropout(x))) * scaling`` to ``x``,
    multiplied by a gate.  ``lora_B`` starts at zero, so a freshly built
    adaptor adds nothing to its input.
    """

    def __init__(self, hparams, feature_dim):
        """Create the low-rank pair and, if ``hparams.agg_use_cls``, the gate MLP.

        Args:
            hparams: reads ``agg_use_cls`` and ``cls_threshold``.
            feature_dim: size of the last dimension of the tensors to adapt.
        """
        super().__init__()
        self.hparams = hparams

        # Fixed settings, same defaults as the LoRA aggregator.
        rank, alpha, dropout_p = 32, 32, 0.1

        self.lora_A = torch.nn.Linear(feature_dim, rank, bias=False)
        self.lora_B = torch.nn.Linear(rank, feature_dim, bias=False)
        torch.nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_B.weight)

        self.scaling = alpha / rank
        self.dropout = (
            torch.nn.Dropout(p=dropout_p) if dropout_p > 0.0 else torch.nn.Identity()
        )

        # Gate: a two-layer MLP followed by a sigmoid.
        if hparams.agg_use_cls:
            hidden_dim = int(feature_dim * 0.1)
            self.cls_mlp1 = torch.nn.Linear(feature_dim, hidden_dim, bias=True)
            self.cls_act = torch.nn.GELU()
            self.cls_mlp2 = torch.nn.Linear(hidden_dim, 1, bias=True)
            self.cls_sigmoid = torch.nn.Sigmoid()

        self.cls_threshold = self.hparams.cls_threshold

    def forward(self, x):
        """Apply the gated low-rank update to ``x``.

        Returns:
            In training mode a tuple ``(x + update * gate, gate)``; otherwise
            only ``x + update * (gate > cls_threshold)``.
        """
        update = self.lora_A(self.dropout(x))
        update = self.lora_B(update) * self.scaling  # [bsz, L, D]

        if self.hparams.agg_use_cls:
            gate_logits = self.cls_mlp2(self.cls_act(self.cls_mlp1(x)))
            gate = self.cls_sigmoid(gate_logits)
        else:
            # Without the gate MLP every position is fully open.
            gate = torch.ones((update.shape[0], update.shape[1], 1)).to(update.device)

        if self.training is not True:
            # Hard gate at inference time.
            return x + update * (gate > self.cls_threshold).float()
        return x + update * gate, gate

def adapt_fn(layer_name, adaptor, inputs, layer):
    """Input hook: run ``adaptor`` on the input of the layer named ``layer_name``.

    ``inputs`` must be a one-element tuple.  Its single element is returned
    unchanged for any other ``layer``, and passed through ``adaptor`` when
    ``layer == layer_name``.
    """
    assert isinstance(inputs, tuple)
    assert len(inputs) == 1  # only the single-input case is supported
    (hidden,) = inputs
    return adaptor(hidden) if layer == layer_name else hidden