import torch
import re
import math
import transformers
from functools import partial
from easyeditor.models.grace.utils import parent_module, brackets_to_periods
# from .agg_utils import AverageMeter, get_context_templates
from ..rome import agg_utils
from ..rome.aggregator import Aggregator
from ..rome import repr_tools
from ...util import nethook

class SequentialAggregator(Aggregator):
    def __call__(self, **kwargs):
        """Run ``self.model(**kwargs)`` with every adaptor hooked into its layer.

        Each name in ``self.layer_names`` is paired, by position, with the
        adaptor at the same index of ``self.adaptors``. ``adapt_fn`` is bound to
        that pair and the resulting callables are passed to
        ``nethook.TraceInputDict`` through ``edit_input_dict``; the model is
        then called inside that context.

        Returns:
            Whatever ``self.model(**kwargs)`` returns.
        """
        edit_fns = {
            name: partial(adapt_fn, layer_name=name, adaptor=self.adaptors[pos])
            for pos, name in enumerate(self.layer_names)
        }
        with nethook.TraceInputDict(
            self.model,
            layers=self.layer_names,
            edit_input_dict=edit_fns,
        ):
            output = self.model(**kwargs)
        return output
        
    def edit(self, request, target_vectors):
        """Train the per-layer adaptors so the edit subject is aggregated onto target vectors.

        Adaptors are trained one after another, in the order of
        ``self.hparams.layers``. Adaptor ``i`` gets
        ``self.hparams.num_train_agg_steps`` Adam steps on mini-batches of
        prompts built from ``request``. Each step minimises a weighted sum of:

        * ``agg`` loss -- negative (optionally C-weighted, optionally
          length-normalised) inner product between the adapted representation
          at the subject token and ``target_vectors[i]``;
        * ``spread`` loss -- absolute (optionally C-weighted) inner product
          between the adapted representations at every other position and the
          target;
        * ``kl`` loss -- keeps the other positions close to their un-adapted
          values. Its form is chosen by ``hparams.kl_loss_type``
          (``'whiten_diff'``, ``'whiten'`` or ``'mse'``);
        * ``cls`` loss -- BCE on the adaptor's per-position logit, with the
          subject position as the positive class.

        ``C`` is the layer's inverse covariance from ``get_inv_cov``. When
        ``hparams.agg_norm_c`` is set, it is divided by
        ``outer(sqrt(diag(C)), sqrt(diag(C)))``. It is left out entirely when
        ``hparams.agg_remove_c`` is set.

        Args:
            request: Edit request dict. Keys read here: ``'prompt'``,
                ``'subject'`` and ``'rephrase_subjects'``. Depending on hparams,
                also ``'shuffled_subject'``, ``'para_attack'``,
                ``'long_context'`` or ``'long_context_held'``.
            target_vectors: One target vector per entry of
                ``self.hparams.layers``. Accepts a list of tensors, a stacked
                tensor, or a single 1-D tensor when there is one layer. A
                detached float copy is used, so the caller's tensors are never
                modified.

        Raises:
            NotImplementedError: if ``agg_add_rephrases`` is enabled, if
                ``agg_add_para_attack`` is enabled and the request has
                ``'para_attack'``, or if ``kl_loss_type`` is unknown.
            AssertionError: on inconsistent inputs. These are: a target count
                that differs from the number of layers, a subject missing from
                the prompt, ``agg_add_ori`` not True, or a subject that does not
                map to exactly one token index.
        """
        import random

        import torch.nn.functional as F

        from .compute_u import get_inv_cov

        hp = self.hparams
        layers = hp.layers
        device = f"cuda:{hp.device}"

        # --- Validate the configuration before doing any expensive work. ------
        # NotImplementedError is a RuntimeError subclass, so handlers that catch
        # RuntimeError keep working.
        if hp.agg_add_rephrases:
            raise NotImplementedError(
                "agg_add_rephrases is not supported by SequentialAggregator"
            )
        if hp.agg_add_para_attack and 'para_attack' in request:
            raise NotImplementedError(
                "agg_add_para_attack is not supported by SequentialAggregator"
            )
        if hp.kl_loss_type not in ('whiten_diff', 'whiten', 'mse'):
            raise NotImplementedError(f"Unsupported kl_loss_type: {hp.kl_loss_type!r}")
        assert hp.agg_add_ori is True

        def first_half_if_held(items):
            # With agg_train_on_held, only the first ceil(n/2) items are used
            # for training (presumably the rest are held out for evaluation).
            if hp.agg_train_on_held:
                return items[:math.ceil(len(items) / 2)]
            return items

        def aggregation_losses(agg_repr, kl_repr, ref_repr, target, norm_C):
            """Return ``(agg_loss, spread_loss, kl_loss)`` for one layer.

            ``agg_repr`` holds the adapted subject-token vectors. ``kl_repr``
            holds the adapted vectors at every other position (including any
            padding). ``ref_repr`` holds the un-adapted vectors at those same
            positions.
            """
            def weighted(x):
                # x @ C, or x itself when the covariance is disabled.
                return x if hp.agg_remove_c else x @ norm_C

            if hp.agg_loss_norm:
                agg_repr = agg_repr / agg_repr.norm(dim=-1, keepdim=True)
            agg_loss = (-1 * (weighted(agg_repr) @ target)).mean()
            spread_loss = torch.abs(weighted(kl_repr) @ target).mean()

            if hp.kl_loss_type == 'whiten_diff':
                kl_loss = torch.abs(weighted(ref_repr) @ (kl_repr - ref_repr).T).mean()
            elif hp.kl_loss_type == 'whiten':
                kl_loss = (-1 * (weighted(ref_repr) @ kl_repr.T)).mean()
            elif hp.kl_loss_type == 'mse':
                kl_loss = F.mse_loss(ref_repr, kl_repr, reduction='mean')
            else:  # unreachable: validated at the top of edit()
                raise NotImplementedError(f"Unsupported kl_loss_type: {hp.kl_loss_type!r}")
            return agg_loss, spread_loss, kl_loss

        self.training = True
        for adaptor in self.adaptors:
            adaptor.training = True
        try:
            # Work on a detached float copy. The caller's target vectors are
            # never modified, and no autograd history leaks into training.
            if isinstance(target_vectors, list):
                target_vectors = torch.stack(target_vectors).detach().clone()
            else:
                target_vectors = target_vectors.detach().clone()
                if target_vectors.ndim == 1:
                    target_vectors = target_vectors[None]
            assert target_vectors.shape[0] == len(layers), (
                f"expected {len(layers)} target vectors, got {target_vectors.shape[0]}"
            )
            target_vectors = target_vectors.float()

            # Use ROME-style context templates. If the ROME helper raises,
            # fall back to the (flattened) MEMIT-style ones -- presumably e.g.
            # when hparams has no ``context_template_length_params``; confirm.
            try:
                context_templates = agg_utils.get_context_templates_rome(
                    self.model, self.tok, hp.context_template_length_params
                )
            except Exception:
                context_templates = [
                    template
                    for group in agg_utils.get_context_templates_memit(self.model, self.tok)
                    for template in group
                ]

            # Per-layer inverse covariance, optionally normalised by its diagonal.
            Cs = []
            for layer in layers:
                C = get_inv_cov(
                    self.model,
                    self.tok,
                    hp.rewrite_module_tmp.format(layer),
                    hp.mom2_dataset,
                    hp.mom2_n_samples,
                    hp.mom2_dtype,
                    hparams=hp,
                )
                if hp.agg_norm_c:
                    dd = torch.sqrt(torch.diag(C))
                    C = C / torch.outer(dd, dd)
                Cs.append(C)

            loss_weights = {
                'agg': hp.agg_loss_weight,
                'spread': hp.spread_loss_weight,
                'kl': hp.kl_loss_weight,
            }
            opt = torch.optim.Adam(
                (param for adaptor in self.adaptors for param in adaptor.parameters()),
                lr=hp.agg_lr,
                weight_decay=0,
            )
            loss_meter = agg_utils.AverageMeter()
            bce = torch.nn.BCEWithLogitsLoss()

            # --- Build (context, subject-token indices) training pairs. -------
            subject = request['subject']
            if '{}' in request['prompt']:
                prompt_template = request['prompt']
            else:
                assert subject in request['prompt'], (
                    f"Subject:{subject} do not exist in prompt: {request['prompt']}"
                )
                prompt_template = request['prompt'].replace(subject, '{}')
            prompt_templates = [prompt_template]
            subtoken = hp.fact_token[len("subject_"):]
            base_templates = get_templates(context_templates, prompt_templates)

            # The original subject in the base context templates.
            ctxs, idxs = get_ctx_idx(self.tok, [subject], base_templates, subtoken)

            # The original subject, with the long context text(s) as templates.
            if hp.agg_add_long_context:
                if hp.agg_train_on_held:
                    long_contexts = list(request['long_context_held'])
                else:
                    assert 'long_context' in request
                    long_contexts = [request['long_context']]
                long_templates = get_templates(long_contexts, prompt_templates)
                ctxs_l, idxs_l = get_ctx_idx(self.tok, [subject], long_templates, subtoken)
                ctxs += ctxs_l
                idxs += idxs_l

            # Rephrased subjects (and, optionally, shuffled ones) in the base templates.
            extra_subject_lists = [first_half_if_held(request['rephrase_subjects'])]
            if hp.agg_add_shuffle and 'shuffled_subject' in request:
                extra_subject_lists.append(first_half_if_held(request['shuffled_subject']))
            for extra_subjects in extra_subject_lists:
                ctxs_x, idxs_x = get_ctx_idx(self.tok, extra_subjects, base_templates, subtoken)
                ctxs += ctxs_x
                idxs += idxs_x

            data = list(zip(ctxs, idxs))
            print(f'Data Size: {len(data)}')
            random.shuffle(data)
            batch_size = hp.agg_batch_size

            def next_batch(index):
                # Return (batch, next_index). Once fewer than batch_size unseen
                # items remain, reshuffle and start a new pass. The first batch
                # of that pass is returned here, so the next call continues
                # after it instead of repeating it.
                if index + batch_size > len(data):
                    random.shuffle(data)
                    return data[:batch_size], batch_size
                return data[index:index + batch_size], index + batch_size

            # --- Train the adaptors one after another. ------------------------
            index = 0
            for i in range(len(layers)):
                layer_name = self.layer_names[i]
                adaptor = self.adaptors[i]
                for it in range(hp.num_train_agg_steps):
                    loss_meter.reset()
                    # set_to_none lets Adam skip the other adaptors' parameters,
                    # which get no gradient while adaptor ``i`` is trained.
                    opt.zero_grad(set_to_none=True)

                    batch, index = next_batch(index)
                    contexts_tok = self.tok(
                        [ctx for ctx, _ in batch],
                        return_tensors="pt",
                        padding=True,
                    ).to(device)
                    bsz, seq_len = contexts_tok.input_ids.shape[:2]

                    # Record this layer's un-adapted input. The model is called
                    # directly (not through ``self``), so no adaptor is applied.
                    # NOTE(review): earlier layers' adaptors are not applied
                    # either; confirm that is intended for sequential training.
                    # This also assumes TraceInputDict accepts ``retain_input``
                    # like nethook.TraceDict -- confirm.
                    with torch.no_grad(), nethook.TraceInputDict(
                        self.model, layers=[layer_name], retain_input=True
                    ) as tr:
                        self.model(**contexts_tok)
                    init_repr = tr[layer_name].input.detach()

                    # mask is False exactly at each example's subject token.
                    batch_idxs = [idx for _, idx in batch]
                    assert all(len(idx) == 1 for idx in batch_idxs)
                    idxs_t = torch.tensor(batch_idxs, dtype=torch.long, device=device)
                    mask = torch.ones(bsz, seq_len, dtype=torch.bool, device=device)
                    mask[torch.arange(bsz, device=device)[:, None], idxs_t] = False

                    adapted_repr, cls_logits = adaptor(init_repr)
                    cls_loss = bce(cls_logits.squeeze(-1), (~mask).float())

                    # Compute the remaining losses in at least float32, on the
                    # adaptor output's device, so the representations, target
                    # and C agree (a no-op when they already match).
                    spec = dict(
                        device=adapted_repr.device,
                        dtype=torch.promote_types(adapted_repr.dtype, torch.float32),
                    )
                    adapted = adapted_repr.to(**spec)
                    ref = init_repr.to(**spec)
                    agg_loss, spread_loss, kl_loss = aggregation_losses(
                        agg_repr=adapted[~mask],
                        kl_repr=adapted[mask],
                        ref_repr=ref[mask],
                        target=target_vectors[i].to(**spec),
                        norm_C=Cs[i].to(**spec),
                    )

                    loss = (
                        agg_loss * loss_weights['agg']
                        + spread_loss * loss_weights['spread']
                        + kl_loss * loss_weights['kl']
                        + cls_loss * hp.cls_loss_weight
                    )
                    print(
                        f"Step {it}: Agg loss {agg_loss.item()}; Spread loss: {spread_loss.item()}; "
                        f"KL loss: {kl_loss.item()}; CLS loss: {cls_loss.item()}"
                    )
                    loss_meter.update(loss.item(), n=1)
                    loss.backward()
                    opt.step()
        finally:
            # Always leave the aggregator in inference mode, even if training failed.
            self.training = False
            for adaptor in self.adaptors:
                adaptor.training = False