from copy import deepcopy
from typing import Any, Dict, List, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from util import nethook

from peft import get_peft_model, LoraConfig, AdaLoraConfig, TaskType
from .lora_hparams import LORAHyperParams


def apply_lora_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: LORAHyperParams,
    copy=False,
    return_orig_weights=False,
    **kwargs: Any,
) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]:
    """
    Entry point: train LoRA adapters on ``requests`` via ``lora_base``.

    :param model: base causal LM handed to ``lora_base``.
    :param tok: tokenizer handed to ``lora_base``.
    :param requests: edit requests handed to ``lora_base``.
    :param hparams: LoRA hyper-parameters handed to ``lora_base``.
    :param copy: accepted but not read by this function.
    :param return_orig_weights: accepted but not read by this function.
    :param kwargs: accepted but not forwarded.
    :return: ``(lora_model, model)`` -- the object returned by ``lora_base``
        followed by the very same ``model`` object that was passed in
        (not a dictionary of original weights, despite the annotation).
    """
    # lora_base runs first; the untouched `model` reference is returned second.
    return lora_base(model, tok, requests, hparams), model

def lora_base(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: LORAHyperParams,
    **kwargs: Any,
) -> torch.nn.Module:
    """
    Train LoRA / AdaLoRA adapters on the request texts with a next-token
    negative log-likelihood objective.

    ``model`` is wrapped with ``peft.get_peft_model``; per the PEFT
    documentation that call modifies the base model in place, so ``model`` is
    NOT left unchanged by this function.

    :param model: base causal LM to wrap with adapters.
    :param tok: tokenizer used to batch ``request["text"]``.
        NOTE(review): the label construction assumes right padding (real
        tokens first) -- confirm ``tok.padding_side``.
    :param requests: list of dicts; only the ``"text"`` entry is read.
    :param hparams: supplies ``lora_type``, ``rank``, ``lora_alpha``,
        ``lora_dropout``, ``layers``, ``target_modules``, ``lr``,
        ``weight_decay``, ``num_steps`` and ``batch_size``.
    :param kwargs: accepted for signature compatibility; ignored.
    :return: the PEFT-wrapped model after training.
    """

    lora_config_dict = {
        'lora': LoraConfig,
        'adalora': AdaLoraConfig
    }

    lora_config = lora_config_dict[hparams.lora_type](
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=hparams.rank,
        lora_alpha=hparams.lora_alpha, lora_dropout=hparams.lora_dropout,
        layers_to_transform=hparams.layers if len(hparams.layers) > 0 else None,
        target_modules=hparams.target_modules
    )
    lora_model = get_peft_model(model, lora_config)
    lora_model.print_trainable_parameters()

    # Only the text is read and str.strip() returns new strings, so the
    # caller's request dicts are never mutated (no defensive copy needed).
    texts = [r['text'].strip() for r in requests]

    # Follow the model's own device instead of assuming "cuda".
    device = next(lora_model.parameters()).device

    # Configure optimizer / gradients
    opt = torch.optim.Adam(
        lora_model.parameters(),
        lr=hparams.lr,
        weight_decay=hparams.weight_decay,
    )

    loss_meter = AverageMeter()
    for it in range(hparams.num_steps):
        print(20 * "=")
        print(f"Epoch: {it}")
        print(20 * "=")
        loss_meter.reset()

        for txt in chunks(texts, hparams.batch_size):
            inputs = tok(txt, return_tensors="pt", padding=True).to(device)
            input_ids = inputs["input_ids"]

            # Labels are the inputs shifted left by one; everything else
            # (including padding) stays at the ignore value -100.
            target = torch.full_like(input_ids, -100)
            lengths = inputs["attention_mask"].sum(dim=1).tolist()
            for i, length in enumerate(lengths):
                n_pred = max(length - 1, 0)
                target[i, :n_pred] = input_ids[i, 1:1 + n_pred]

            opt.zero_grad()
            bs = input_ids.shape[0]
            logits = lora_model(**inputs).logits
            log_probs = torch.log_softmax(logits, dim=2)
            # Ignored positions gather index 0; they are masked out below.
            token_log_probs = torch.gather(
                log_probs,
                2,
                torch.where(target != -100, target, 0).unsqueeze(2),
            ).squeeze(2)
            mask = (target != -100).float()
            # clamp: a sequence with <= 1 real token has no label, which
            # would otherwise give 0 / 0 = NaN and poison the batch loss.
            nll_loss_each = -(token_log_probs * mask).sum(1) / mask.sum(1).clamp(min=1)
            loss = nll_loss_each.mean()

            loss_value = loss.item()  # single host sync per batch
            print(f"Batch loss {loss_value}")
            loss_meter.update(loss_value, n=bs)

            # Skip the update once this batch is already fitted.
            if loss_value >= 1e-2:
                loss.backward()
                opt.step()

        print(f"Total loss {loss_meter.avg}")

        # Early stop when the epoch-average loss is small enough.
        if loss_meter.avg < 1e-2:
            break

    return lora_model

from baselines.rome.layer_stats import layer_stats
from util.globals import *

# Process-wide memo: (sanitised model name, layer name) -> second-moment matrix.
KK_cache = {}


def get_KK(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    layer_name: str,
) -> torch.Tensor:
    """
    Return the cached ``mom2`` moment for one layer, computing it on a miss.

    :param model: model whose ``config._name_or_path`` (with "/" replaced by
        "_") forms the first half of the cache key; also passed to
        ``layer_stats``.
    :param tok: tokenizer passed to ``layer_stats``.
    :param layer_name: layer or parameter name; every ``.weight`` substring
        is removed before use.
    :return: ``stat.mom2.moment()`` moved to "cuda" and cast to float32.
    """
    module_name = layer_name.replace('.weight', '')
    sanitized_model_name = model.config._name_or_path.replace("/", "_")
    cache_key = (sanitized_model_name, module_name)

    # Fast path: statistics were already collected for this model/layer.
    if cache_key in KK_cache:
        return KK_cache[cache_key]

    print(
        f"Retrieving inverse covariance statistics for {sanitized_model_name} @ {module_name}. "
        f"The result will be cached to avoid repetitive computation."
    )
    collected = layer_stats(
        model,
        tok,
        module_name,
        STATS_DIR,
        'wikipedia',
        to_collect=["mom2"],
        sample_size=100000,
        precision='float32',
    )
    second_moment = collected.mom2.moment().to("cuda").float()
    KK_cache[cache_key] = second_moment
    return second_moment

def compute_loc_loss(lora_model, tok):
    """
    Sum, over every key of ``weights``, the squared entries of
    ``(weights[key] - weights_copy[key]) @ get_KK(model, tok, key)``.

    NOTE(review): ``weights``, ``weights_copy`` and ``model`` are neither
    parameters nor locals here, and no assignment to them is visible in this
    file. Unless they are supplied as module globals from elsewhere (e.g. by
    the star import), calling this function raises ``NameError`` -- confirm
    where they are meant to come from.
    NOTE(review): ``lora_model`` is accepted but never read.

    :param lora_model: unused.
    :param tok: tokenizer forwarded to ``get_KK``.
    :return: the summed penalty (plain ``0`` if ``weights`` is empty).
    """
    loc_loss_list = []
    for key in weights.keys():
        # Current weight vs. its reference copy for this key.
        w, wc = weights[key], weights_copy[key]
        # Cached second-moment statistics for the layer named by `key`.
        kk = get_KK(model, tok, key)
        delta = (w - wc) @ kk
        key_loc_loss = (delta * delta).sum()    # null space: delta.abs().sum()
        loc_loss_list.append(key_loc_loss)
    loc_loss = sum(loc_loss_list)

    return loc_loss

def compute_kl_loss(lora_model, inputs, logits):
    """
    KL penalty between cached full-context predictions and the model's
    current predictions from short sliding windows.

    On the first call after the module global ``first_probs`` has been set to
    ``None``, the log-probabilities derived from ``logits`` are cached (detached)
    per batch row. On every call, each row is cut into windows of ``slide``
    tokens, ``lora_model`` is run on them, and the distribution at the last
    window position is compared against the cached one with
    ``torch.nn.functional.kl_div`` (cached log-probs as ``input``, current
    probabilities as ``target``).

    NOTE(review): the cache is filled only while ``first_probs is None`` and is
    indexed by batch row, so later calls with a *different* batch reuse the
    first batch's entries -- confirm callers only ever pass a single batch.
    NOTE(review): ``first_probs`` must already exist as a module global before
    the first call; it is not initialised in this function.

    :param lora_model: callable model returning an object with ``.logits``.
    :param inputs: mapping with at least ``"input_ids"`` and
        ``"attention_mask"``; every tensor in it is windowed and re-fed.
    :param logits: logits of the full ``inputs`` batch, used only to fill the
        cache.
    :return: mean over batch rows of the per-row mean KL term.
    """
    tau = 1                 # divisor applied to the softmax output
    slide = 5               # window length; also the first scored position
    max_context_ratio = 1   # scales the padded length to obtain `mc` below

    global first_probs
    bs = inputs["input_ids"].shape[0]
    if first_probs is None:
        # Reference distributions: computed once, never back-propagated.
        with torch.no_grad():
            first_probs = []
            for bsid in range(bs):
                # Scored positions are [slide, number_of_attended_tokens - 2).
                l, r = slide, (inputs["attention_mask"][bsid].sum()-2).item()
                mc = int(inputs["attention_mask"][bsid].size(0) * max_context_ratio)
                # With max_context_ratio = 1, mc is the padded length, so for
                # a 0/1 attention mask this condition cannot hold and the
                # else branch is the one taken.
                if slide+mc < r:
                    cur_probs = (torch.softmax(logits[bsid, slide:slide+mc, :], dim=-1) / tau).log().detach().clone()
                    f_inputs = {k: torch.stack([v[bsid, i-slide-mc+1:i+1] for i in range(slide+mc, r)]) for k, v in inputs.items()}
                    f_logits = lora_model(**f_inputs).logits
                    f_probs = (torch.softmax(f_logits, dim=-1) / tau).log().detach().clone()
                    first_probs.append(torch.cat([cur_probs, f_probs[:, -1, :]]))
                else:
                    first_probs.append((torch.softmax(logits[bsid, l:r, :], dim=-1) / tau).log().detach().clone())

    kl_loss_list = []
    for bsid in range(bs):
        l, r = slide, (inputs["attention_mask"][bsid].sum()-2).item()
        if l < r:
            # One window of `slide` tokens ending at each scored position i.
            kl_inputs = {k: torch.stack([v[bsid, i-slide+1:i+1] for i in range(l, r)]) for k, v in inputs.items()}
            kl_logits = lora_model(**kl_inputs).logits
            kl_probs = torch.nn.functional.softmax(kl_logits, dim=-1) / tau
            # Only the last position of each window is compared; gradients
            # flow through kl_probs (the cached side is detached).
            kl_loss_each = torch.nn.functional.kl_div(
                first_probs[bsid],
                kl_probs[:, -1, :],
                reduction='none'
            ).sum(-1)
            # Alternative per-row reductions kept for reference (disabled):
            # kl_loss_bs = kl_loss_each.topk(int(kl_loss_each.size(0)*0.3)).values.mean()
            kl_loss_bs = kl_loss_each.mean()    # base
            # idx1, idx2 = int(kl_loss_each.size(0)//3), int(kl_loss_each.size(0)//3*2)
            # kl_loss_bs = kl_loss_each[:int(kl_loss_each.size(0)//6)].mean()
            # kl_loss_bs = torch.cat((kl_loss_each[:idx1], kl_loss_each[idx2:])).mean()
            kl_loss_list.append(kl_loss_bs)
        else:
            # Row too short to form a window: contributes zero.
            kl_loss_list.append(torch.tensor(0.).cuda())
    kl_loss = sum(kl_loss_list) / len(kl_loss_list)

    return kl_loss

def lora_kvloc(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: LORAHyperParams,
    **kwargs: Any,
) -> torch.nn.Module:
    """
    Train LoRA / AdaLoRA adapters with next-token NLL plus a weighted
    locality penalty from ``compute_loc_loss``.

    ``model`` is wrapped with ``peft.get_peft_model``; per the PEFT
    documentation that call modifies the base model in place, so ``model`` is
    NOT left unchanged by this function.

    The optimised objective is ``nll + loc_ratio * loc_loss``. The locality
    term used to be computed and printed only; it is now part of the backward
    pass and tracked in its own meter, mirroring how ``lora_slide_align``
    treats its KL term.

    NOTE(review): ``compute_loc_loss`` reads ``weights``, ``weights_copy`` and
    ``model`` from outside its own scope; confirm those names are provided
    before relying on this routine.

    :param model: base causal LM to wrap with adapters.
    :param tok: tokenizer used to batch ``request["text"]``.
        NOTE(review): the label construction assumes right padding -- confirm
        ``tok.padding_side``.
    :param requests: list of dicts; only the ``"text"`` entry is read.
    :param hparams: supplies ``lora_type``, ``rank``, ``lora_alpha``,
        ``lora_dropout``, ``layers``, ``target_modules``, ``lr``,
        ``weight_decay``, ``num_steps`` and ``batch_size``.
    :param kwargs: accepted for signature compatibility; ignored.
    :return: the PEFT-wrapped model after training.
    """

    lora_config_dict = {
        'lora': LoraConfig,
        'adalora': AdaLoraConfig
    }

    lora_config = lora_config_dict[hparams.lora_type](
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=hparams.rank,
        lora_alpha=hparams.lora_alpha, lora_dropout=hparams.lora_dropout,
        layers_to_transform=hparams.layers if len(hparams.layers) > 0 else None,
        target_modules=hparams.target_modules
    )
    lora_model = get_peft_model(model, lora_config)
    lora_model.print_trainable_parameters()

    # Only the text is read and str.strip() returns new strings, so the
    # caller's request dicts are never mutated (no defensive copy needed).
    texts = [r['text'].strip() for r in requests]

    # Follow the model's own device instead of assuming "cuda".
    device = next(lora_model.parameters()).device

    # Configure optimizer / gradients
    opt = torch.optim.Adam(
        lora_model.parameters(),
        lr=hparams.lr,
        weight_decay=hparams.weight_decay,
    )

    loc_ratio = 0.5  # weight of the locality penalty in the objective

    loss_meter, loc_meter = AverageMeter(), AverageMeter()
    for it in range(hparams.num_steps):
        print(20 * "=")
        print(f"Epoch: {it}")
        print(20 * "=")
        loss_meter.reset()
        loc_meter.reset()

        for txt in chunks(texts, hparams.batch_size):
            inputs = tok(txt, return_tensors="pt", padding=True).to(device)
            input_ids = inputs["input_ids"]

            # Labels are the inputs shifted left by one; everything else
            # (including padding) stays at the ignore value -100.
            target = torch.full_like(input_ids, -100)
            lengths = inputs["attention_mask"].sum(dim=1).tolist()
            for i, length in enumerate(lengths):
                n_pred = max(length - 1, 0)
                target[i, :n_pred] = input_ids[i, 1:1 + n_pred]

            opt.zero_grad()
            bs = input_ids.shape[0]
            logits = lora_model(**inputs).logits
            log_probs = torch.log_softmax(logits, dim=2)
            # Ignored positions gather index 0; they are masked out below.
            token_log_probs = torch.gather(
                log_probs,
                2,
                torch.where(target != -100, target, 0).unsqueeze(2),
            ).squeeze(2)
            mask = (target != -100).float()
            # clamp: a sequence with <= 1 real token has no label, which
            # would otherwise give 0 / 0 = NaN and poison the batch loss.
            nll_loss_each = -(token_log_probs * mask).sum(1) / mask.sum(1).clamp(min=1)
            loss = nll_loss_each.mean()

            loc_loss = compute_loc_loss(lora_model, tok)

            loss_value, loc_value = loss.item(), loc_loss.item()
            print(f"Batch loss {loss_value}. Locality loss {loc_value}")
            loss_meter.update(loss_value, n=bs)
            loc_meter.update(loc_value, n=bs)

            # Optimise the combined objective so the locality term actually
            # constrains the update.
            final_loss = loss + loc_ratio * loc_loss
            if final_loss.item() >= 1e-2:
                final_loss.backward()
                opt.step()

        print(f"Total loss {loss_meter.avg}. Locality loss {loc_meter.avg}")

        # Early stop is driven by the NLL term alone, as in lora_slide_align.
        if loss_meter.avg < 1e-2:
            break

    return lora_model

def lora_slide_align(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: LORAHyperParams,
    **kwargs: Any,
) -> torch.nn.Module:
    """
    Train LoRA / AdaLoRA adapters with next-token NLL plus a weighted
    sliding-window KL penalty from ``compute_kl_loss``.

    ``model`` is wrapped with ``peft.get_peft_model``; per the PEFT
    documentation that call modifies the base model in place, so ``model`` is
    NOT left unchanged by this function.

    The module global ``first_probs`` is reset to ``None`` at the start so
    ``compute_kl_loss`` rebuilds its reference distributions for this run.

    NOTE(review): ``compute_kl_loss`` fills that cache only on its first call
    and indexes it by batch row, so with more than one batch per epoch the
    later batches are compared against the first batch's reference -- confirm
    this is only used with a single batch.

    :param model: base causal LM to wrap with adapters.
    :param tok: tokenizer used to batch ``request["text"]``.
        NOTE(review): the label construction assumes right padding -- confirm
        ``tok.padding_side``.
    :param requests: list of dicts; only the ``"text"`` entry is read.
    :param hparams: supplies ``lora_type``, ``rank``, ``lora_alpha``,
        ``lora_dropout``, ``layers``, ``target_modules``, ``lr``,
        ``weight_decay``, ``num_steps`` and ``batch_size``.
    :param kwargs: accepted for signature compatibility; ignored.
    :return: the PEFT-wrapped model after training.
    """
    global first_probs

    lora_config_dict = {
        'lora': LoraConfig,
        'adalora': AdaLoraConfig
    }

    lora_config = lora_config_dict[hparams.lora_type](
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=hparams.rank,
        lora_alpha=hparams.lora_alpha, lora_dropout=hparams.lora_dropout,
        layers_to_transform=hparams.layers if len(hparams.layers) > 0 else None,
        target_modules=hparams.target_modules
    )
    lora_model = get_peft_model(model, lora_config)
    lora_model.print_trainable_parameters()

    # Only the text is read and str.strip() returns new strings, so the
    # caller's request dicts are never mutated (no defensive copy needed).
    texts = [r['text'].strip() for r in requests]

    # Follow the model's own device instead of assuming "cuda".
    device = next(lora_model.parameters()).device

    # Configure optimizer / gradients
    opt = torch.optim.Adam(
        lora_model.parameters(),
        lr=hparams.lr,
        weight_decay=hparams.weight_decay,
    )

    kl_ratio = 0.5  # weight of the KL penalty in the objective

    loss_meter, kl_meter = AverageMeter(), AverageMeter()
    # Force compute_kl_loss to rebuild its cached reference distributions.
    first_probs = None
    for it in range(hparams.num_steps):
        print(20 * "=")
        print(f"Epoch: {it}")
        print(20 * "=")
        loss_meter.reset()
        kl_meter.reset()

        for txt in chunks(texts, hparams.batch_size):
            inputs = tok(txt, return_tensors="pt", padding=True).to(device)
            input_ids = inputs["input_ids"]

            # Labels are the inputs shifted left by one; everything else
            # (including padding) stays at the ignore value -100.
            target = torch.full_like(input_ids, -100)
            lengths = inputs["attention_mask"].sum(dim=1).tolist()
            for i, length in enumerate(lengths):
                n_pred = max(length - 1, 0)
                target[i, :n_pred] = input_ids[i, 1:1 + n_pred]

            opt.zero_grad()
            bs = input_ids.shape[0]
            logits = lora_model(**inputs).logits
            log_probs = torch.log_softmax(logits, dim=2)
            # Ignored positions gather index 0; they are masked out below.
            token_log_probs = torch.gather(
                log_probs,
                2,
                torch.where(target != -100, target, 0).unsqueeze(2),
            ).squeeze(2)
            mask = (target != -100).float()
            # clamp: a sequence with <= 1 real token has no label, which
            # would otherwise give 0 / 0 = NaN and poison the batch loss.
            nll_loss_each = -(token_log_probs * mask).sum(1) / mask.sum(1).clamp(min=1)
            loss = nll_loss_each.mean()

            # The base `model` object is passed here, exactly as before.
            # NOTE(review): presumably equivalent to passing lora_model since
            # PEFT injects the adapters into `model` in place -- confirm.
            kl_loss = compute_kl_loss(model, inputs, logits)

            loss_value, kl_value = loss.item(), kl_loss.item()
            print(f"Batch loss {loss_value}. KL loss {kl_value}")
            loss_meter.update(loss_value, n=bs)
            kl_meter.update(kl_value, n=bs)

            final_loss = loss + kl_ratio * kl_loss
            # Skip the update once the combined objective is already tiny.
            if final_loss.item() >= 1e-2:
                final_loss.backward()
                opt.step()

        print(f"Total loss {loss_meter.avg}. KL loss {kl_meter.avg}")

        # Early stop is driven by the NLL term alone (the KL average is
        # deliberately not part of this criterion).
        if loss_meter.avg < 1e-2:
            break

    return lora_model



def chunks(arr, n):
    """Split any iterable into consecutive lists of ``n`` items.

    The final list may be shorter than ``n``; nothing is yielded for an
    empty iterable.
    """
    pending = []
    for item in arr:
        pending.append(item)
        if len(pending) != n:
            continue
        # Batch is full: hand it out and start collecting a fresh one.
        yield pending
        pending = []
    # Emit whatever is left over after the last full batch.
    if pending:
        yield pending

class AverageMeter:
    """Running weighted mean that also remembers the latest value.

    Attributes (all reset to 0 by ``reset``):
        val: most recent value passed to ``update``.
        sum: weighted sum of all values seen since the last reset.
        count: total weight seen since the last reset.
        avg: ``sum / count`` as of the last ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
