import torch

from baselines.rome.layer_stats import layer_stats
from util.globals import *

class OursPlugBatch:
    """Loss helpers for batched editing; every ``compute_*`` method back-propagates itself.

    Each loss method calls ``backward()`` internally (scaled by the matching
    ``*_ratio`` where one exists), so gradients accumulate on the parameters
    as a side effect. The returned values are intended for logging.
    """

    # (model name, layer name) -> cached moment matrix. Class-level, so the
    # cache is shared by every instance of this class.
    KK_cache = {}

    # Softmax temperature used by compute_kl_loss.
    align_tau = 1
    # Number of tokens in each sliding window fed to the model by compute_kl_loss.
    slide_size = 5
    # Multiplier applied to the KL loss before its backward pass.
    kl_ratio = 0.05
    # Multiplier applied to the locality loss before its backward pass.
    loc_ratio = 1e-5

    def get_KK(self, model, tok, layer_name):
        """Return the cached ``mom2`` moment matrix for a layer, computing it on first use.

        Args:
            model: model passed to ``layer_stats``; its ``config._name_or_path``
                (with ``/`` replaced by ``_``) is part of the cache key.
            tok: tokenizer forwarded to ``layer_stats``.
            layer_name: layer or parameter name; every ``.weight`` substring
                is removed before it is used.

        Returns:
            ``stat.mom2.moment()`` as a float tensor on ``"cuda"``.
        """
        layer_name = layer_name.replace('.weight', '')
        model_name = model.config._name_or_path.replace("/", "_")
        key = (model_name, layer_name)
        if key not in self.KK_cache:
            # NOTE(review): the message says "inverse covariance", but the value
            # cached below is stat.mom2.moment() with no inversion applied here.
            print(
                f"Retrieving inverse covariance statistics for {model_name} @ {layer_name}. "
                f"The result will be cached to avoid repetitive computation."
            )
            # Presumably gathers second-moment statistics over Wikipedia samples;
            # confirm against baselines.rome.layer_stats.
            stat = layer_stats(
                model,
                tok,
                layer_name,
                STATS_DIR,
                'wikipedia',
                to_collect=["mom2"],
                sample_size=100000,
                precision='float32',
            )
            self.KK_cache[key] = stat.mom2.moment().to("cuda").float()
        return self.KK_cache[key]

    def compute_loc_loss(self, model, tok, weights, weights_copy):
        """Back-propagate the locality penalty and return its unscaled value.

        For every entry the penalty is the sum of squares of
        ``(w - w_copy) @ KK``; the per-entry penalties are averaged.

        Args:
            model, tok: forwarded to ``get_KK``.
            weights: mapping from parameter name to the current weight tensor.
            weights_copy: mapping with the same keys holding reference weights.

        Returns:
            The mean penalty (before multiplication by ``loc_ratio``).

        Raises:
            ValueError: if ``weights`` is empty (previously an opaque
                ``ZeroDivisionError``).
        """
        if not weights:
            raise ValueError("compute_loc_loss needs at least one weight tensor")

        loc_loss_list = []
        for name, w in weights.items():
            wc = weights_copy[name]
            # The cache lives on the default CUDA device; follow the weight so
            # the matmul also works when the weight sits elsewhere.
            kk = self.get_KK(model, tok, name).to(w.device)
            delta = (w - wc) @ kk
            loc_loss_list.append((delta * delta).sum())
        loc_loss = sum(loc_loss_list) / len(loc_loss_list)
        (loc_loss * self.loc_ratio).backward()

        return loc_loss

    def compute_kl_loss(self, model, inputs):
        """Back-propagate the sliding-window KL loss and return its unscaled total.

        For each row, the model's full-context distribution at positions
        ``slide_size .. n - 3`` (``n`` = sum of that row's attention mask) is
        compared with the distribution predicted from only the last
        ``slide_size`` tokens ending at the same position.

        Args:
            model: callable accepting the entries of ``inputs`` as keyword
                arguments and returning an object with ``.logits``.
            inputs: mapping with at least ``input_ids`` and ``attention_mask``.
                Every value is sliced along dims 0 and 1.
                NOTE(review): using the mask sum as the sequence length
                assumes right padding; confirm with callers.

        Returns:
            A detached 0-dim tensor: the sum over rows of ``mean KL / batch size``.
            Rows that are too short for a single window contribute nothing.
        """
        input_ids = inputs["input_ids"]
        bs = input_ids.shape[0]
        start = self.slide_size
        tau = self.align_tau
        lengths = inputs["attention_mask"].sum(dim=1).tolist()

        # Accumulate on the inputs' device instead of a hard-coded .cuda().
        kl_loss_all = torch.zeros((), device=input_ids.device)
        for bsid, length in enumerate(lengths):
            stop = int(length) - 2
            if start >= stop:
                continue

            row = {k: v[bsid:bsid + 1] for k, v in inputs.items()}
            logits = model(**row).logits
            # One window of `slide_size` tokens ending at each scored position.
            kl_inputs = {
                k: torch.stack([v[bsid, i - start + 1:i + 1] for i in range(start, stop)])
                for k, v in inputs.items()
            }
            kl_logits = model(**kl_inputs).logits

            # The temperature scales the logits *before* normalisation so both
            # sides stay proper distributions; log_softmax avoids log(0).
            full_log_probs = torch.log_softmax(logits[0, start:stop, :] / tau, dim=-1)
            window_probs = torch.softmax(kl_logits[:, -1, :] / tau, dim=-1)
            kl_loss_each = torch.nn.functional.kl_div(
                full_log_probs,
                window_probs,
                reduction='none'
            ).sum(-1)
            kl_loss_bs = kl_loss_each.mean() / bs

            (kl_loss_bs * self.kl_ratio).backward()
            # Already back-propagated: keep only the value for reporting.
            kl_loss_all += kl_loss_bs.detach().to(kl_loss_all.device)

        return kl_loss_all

    def compute_pt_loss(self, model, inputs, target, chunk_size):
        """Back-propagate the target NLL in chunks and return its total.

        Args:
            model: callable taking ``input_ids`` and ``attention_mask`` keyword
                arguments and returning an object with ``.logits``.
            inputs: mapping with ``input_ids`` and ``attention_mask``.
            target: label ids aligned with the logits; ``-100`` marks ignored
                positions.
            chunk_size: number of rows per forward/backward pass.

        Returns:
            The sum over chunks of ``sum(per-row mean NLL) / total rows`` as a
            detached tensor (``0`` when there are no rows). A row without any
            supervised position contributes 0 rather than NaN.
        """
        all_ids = inputs['input_ids']
        all_mask = inputs['attention_mask']
        total_rows = len(all_ids)

        pt_loss_all = 0
        for begin in range(0, total_rows, chunk_size):
            end = begin + chunk_size
            chunk_target = target[begin:end]
            logits = model(input_ids=all_ids[begin:end], attention_mask=all_mask[begin:end]).logits
            log_probs = torch.log_softmax(logits, dim=2)

            supervised = chunk_target != -100
            # -100 is not a valid gather index; use 0 there and mask it out below.
            picked = torch.gather(
                log_probs,
                2,
                torch.where(supervised, chunk_target, 0).unsqueeze(2),
            ).squeeze(2)
            mask = supervised.float()
            # clamp: a fully masked row would otherwise give 0 / 0 = NaN.
            nll_loss_each = -(picked * mask).sum(1) / mask.sum(1).clamp(min=1)

            pt_loss = nll_loss_each.sum() / total_rows
            pt_loss.backward()
            pt_loss_all += pt_loss.detach()

        return pt_loss_all

class OursPlug:
    """Loss helpers for a single editing step; every loss is back-propagated internally.

    ``forward`` computes the target NLL and, optionally, the sliding-window KL
    loss and the locality loss. Each of them calls ``backward()`` itself, so
    gradients accumulate on the parameters as a side effect and the returned
    values are intended for logging.
    """

    # (model name, layer name) -> cached moment matrix. Class-level, so the
    # cache is shared by every instance of this class.
    KK_cache = {}

    # Softmax temperature used by compute_kl_loss.
    align_tau = 1
    # Number of tokens in each sliding window fed to the model by compute_kl_loss.
    slide_size = 5
    # Multiplier applied to the KL loss before its backward pass.
    kl_ratio = 0.1
    # Multiplier applied to the locality loss before its backward pass.
    loc_ratio = 0.0001

    def get_KK(self, model, tok, layer_name):
        """Return the cached ``mom2`` moment matrix for a layer, computing it on first use.

        Args:
            model: model passed to ``layer_stats``; its ``config._name_or_path``
                (with ``/`` replaced by ``_``) is part of the cache key.
            tok: tokenizer forwarded to ``layer_stats``.
            layer_name: layer or parameter name; every ``.weight`` substring
                is removed before it is used.

        Returns:
            ``stat.mom2.moment()`` as a float tensor on ``"cuda"``.
        """
        layer_name = layer_name.replace('.weight', '')
        model_name = model.config._name_or_path.replace("/", "_")
        key = (model_name, layer_name)
        if key not in self.KK_cache:
            # NOTE(review): the message says "inverse covariance", but the value
            # cached below is stat.mom2.moment() with no inversion applied here.
            print(
                f"Retrieving inverse covariance statistics for {model_name} @ {layer_name}. "
                f"The result will be cached to avoid repetitive computation."
            )
            # Presumably gathers second-moment statistics over Wikipedia samples;
            # confirm against baselines.rome.layer_stats.
            stat = layer_stats(
                model,
                tok,
                layer_name,
                STATS_DIR,
                'wikipedia',
                to_collect=["mom2"],
                sample_size=100000,
                precision='float32',
            )
            self.KK_cache[key] = stat.mom2.moment().to("cuda").float()
        return self.KK_cache[key]

    def compute_loc_loss(self, model, tok, weights, weights_copy):
        """Back-propagate the locality penalty and return its unscaled value.

        For every entry the penalty is the sum of squares of
        ``(w - w_copy) @ KK``; the per-entry penalties are summed (not averaged).

        Args:
            model, tok: forwarded to ``get_KK``.
            weights: mapping from parameter name to the current weight tensor.
            weights_copy: mapping with the same keys holding reference weights.

        Returns:
            The summed penalty (before multiplication by ``loc_ratio``).

        Raises:
            ValueError: if ``weights`` is empty (previously an opaque
                ``AttributeError`` from calling ``backward`` on a number).
        """
        if not weights:
            raise ValueError("compute_loc_loss needs at least one weight tensor")

        loc_loss_list = []
        for name, w in weights.items():
            wc = weights_copy[name]
            # The cache lives on the default CUDA device; follow the weight so
            # the matmul also works when the weight sits elsewhere.
            kk = self.get_KK(model, tok, name).to(w.device)
            delta = (w - wc) @ kk
            loc_loss_list.append((delta * delta).sum())
        loc_loss = sum(loc_loss_list)
        (loc_loss * self.loc_ratio).backward()

        return loc_loss

    def split_forward(self, model, inputs, limit=50):
        """Run ``model`` over ``inputs`` in row chunks and concatenate the logits.

        Only ``input_ids`` and ``attention_mask`` are forwarded; any other
        entries of ``inputs`` are ignored.

        Args:
            model: callable returning an object with ``.logits``.
            inputs: mapping with ``input_ids`` and ``attention_mask`` tensors.
            limit: maximum number of rows per model call.

        Returns:
            The logits of all chunks concatenated along dim 0.
        """
        chunks = zip(inputs['input_ids'].split(limit), inputs['attention_mask'].split(limit))
        outputs = [
            model(input_ids=input_ids, attention_mask=attention_mask).logits
            for input_ids, attention_mask in chunks
        ]
        return torch.cat(outputs, dim=0)

    def compute_kl_loss(self, model, inputs, logits):
        """Back-propagate the sliding-window KL loss and return its unscaled total.

        For each row, the full-context distribution in ``logits`` at positions
        ``slide_size .. n - 3`` (``n`` = sum of that row's attention mask) is
        compared with the distribution predicted from only the last
        ``slide_size`` tokens ending at the same position.

        Args:
            model: callable used by ``split_forward`` for the window passes.
            inputs: mapping with at least ``input_ids`` and ``attention_mask``.
                Every value is sliced along dims 0 and 1.
                NOTE(review): using the mask sum as the sequence length
                assumes right padding; confirm with callers.
            logits: full-context logits for ``inputs``. When they carry a
                graph, it must still be alive (``forward`` keeps it with
                ``retain_graph=True``).

        Returns:
            A detached 0-dim tensor: the sum over rows of ``mean KL / batch size``.
            Rows that are too short for a single window contribute nothing.
        """
        input_ids = inputs["input_ids"]
        bs = input_ids.shape[0]
        start = self.slide_size
        tau = self.align_tau
        lengths = inputs["attention_mask"].sum(dim=1).tolist()
        eligible = [
            (bsid, int(length) - 2)
            for bsid, length in enumerate(lengths)
            if start < int(length) - 2
        ]

        # Accumulate on the inputs' device instead of a hard-coded .cuda().
        kl_loss_all = torch.zeros((), device=input_ids.device)
        for order, (bsid, stop) in enumerate(eligible):
            # One window of `slide_size` tokens ending at each scored position.
            kl_inputs = {
                k: torch.stack([v[bsid, i - start + 1:i + 1] for i in range(start, stop)])
                for k, v in inputs.items()
            }
            kl_logits = self.split_forward(model, kl_inputs)

            # The temperature scales the logits *before* normalisation so both
            # sides stay proper distributions; log_softmax avoids log(0).
            full_log_probs = torch.log_softmax(logits[bsid, start:stop, :] / tau, dim=-1)
            window_probs = torch.softmax(kl_logits[:, -1, :] / tau, dim=-1)
            kl_loss_each = torch.nn.functional.kl_div(
                full_log_probs,
                window_probs,
                reduction='none'
            ).sum(-1)
            kl_loss_bs = kl_loss_each.mean() / bs

            # Every row back-propagates through the shared `logits` graph, so
            # that graph must survive until the last eligible row; otherwise
            # the second row fails with "backward through the graph a second
            # time".
            is_last = order == len(eligible) - 1
            (kl_loss_bs * self.kl_ratio).backward(retain_graph=not is_last)
            # Detach so the running total does not pin the retained graphs.
            kl_loss_all += kl_loss_bs.detach().to(kl_loss_all.device)
            # Release this row's window graph before the next forward pass.
            del kl_logits, full_log_probs, window_probs, kl_loss_each, kl_loss_bs

        return kl_loss_all

    def forward(self, model, tok, inputs, target, weights=None, weights_copy=None, compute_kl_loss=True, compute_loc_loss=True):
        """Back-propagate the target NLL plus the optional auxiliary losses.

        Args:
            model: callable accepting ``**inputs`` and returning an object with
                ``.logits``.
            tok: tokenizer, only forwarded to ``compute_loc_loss``.
            inputs: mapping of model inputs.
            target: label ids aligned with the logits; ``-100`` marks ignored
                positions.
            weights, weights_copy: required when ``compute_loc_loss`` is true;
                see ``compute_loc_loss``.
            compute_kl_loss: also run ``compute_kl_loss``.
            compute_loc_loss: also run ``compute_loc_loss``.

        Returns:
            ``[nll_loss]``, followed by the KL loss and then the locality loss
            when the corresponding flag is set. A row without any supervised
            position contributes 0 to the NLL rather than NaN.
        """
        if compute_loc_loss:
            # Checked up front so a bad call fails before gradients are touched.
            assert weights is not None and weights_copy is not None

        logits = model(**inputs).logits
        log_probs = torch.log_softmax(logits, dim=2)
        supervised = target != -100
        # -100 is not a valid gather index; use 0 there and mask it out below.
        picked = torch.gather(
            log_probs,
            2,
            torch.where(supervised, target, 0).unsqueeze(2),
        ).squeeze(2)
        mask = supervised.float()
        # clamp: a fully masked row would otherwise give 0 / 0 = NaN.
        nll_loss_each = -(picked * mask).sum(1) / mask.sum(1).clamp(min=1)
        loss = nll_loss_each.mean()
        # Keep the graph: compute_kl_loss back-propagates through `logits` again.
        loss.backward(retain_graph=True)

        return_data = [loss]
        if compute_kl_loss:
            kl_loss = self.compute_kl_loss(model, inputs, logits)
            return_data.append(kl_loss)
        if compute_loc_loss:
            loc_loss = self.compute_loc_loss(model, tok, weights, weights_copy)
            return_data.append(loc_loss)
        return return_data

# Module-level instances created at import time: one per helper class.
plug, batch_plug = OursPlug(), OursPlugBatch()
