from copy import deepcopy
from typing import Any, Dict, List, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from util import nethook

from peft import get_peft_model, LoraConfig, AdaLoraConfig, TaskType
from .lora_hparams import LORAHyperParams
from PTE.ours_plug import plug, batch_plug


def apply_lora_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: LORAHyperParams,
    copy=False,
    return_orig_weights=False,
    **kwargs: Any,
) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]:
    """
    Train a LoRA adapter for `requests` and return it together with `model`.

    The work is delegated to `lora_base(model, tok, requests, hparams)`.

    Returns a 2-tuple `(lora_model, model)`: the object produced by
    `lora_base`, followed by the very `model` object that was passed in
    (not a dict of saved weights, despite the return annotation).

    `copy`, `return_orig_weights` and `kwargs` are accepted but never read
    by this implementation.

    NOTE(review): PEFT documents `get_peft_model` as modifying the given
    model in place, so the second tuple element is presumably no longer the
    pristine pre-edit model -- confirm before relying on it for restoring.
    """
    edited_model = lora_base(model, tok, requests, hparams)
    # Second element is the caller's own model object, handed back as-is.
    return edited_model, model

def ours_base(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: LORAHyperParams,
    **kwargs: Any,
) -> torch.nn.Module:
    """
    Attach a LoRA/AdaLoRA adapter to `model` and run the optimisation loop
    driven by the losses from `PTE.ours_plug` (`plug` / `batch_plug`).

    Args:
        model: causal LM to adapt. It is passed to `get_peft_model`; PEFT
            documents that call as modifying the model in place, so the
            caller's model should not be assumed unchanged afterwards.
        tok: tokenizer used to encode the request texts (with padding).
        requests: list of dicts; only the ``'text'`` entry of each is read.
        hparams: supplies ``lora_type`` ('lora' or 'adalora'), ``rank``,
            ``lora_alpha``, ``lora_dropout``, ``layers``, ``target_modules``,
            ``lr``, ``weight_decay``, ``num_steps`` and ``batch_size``.
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        The PEFT-wrapped model returned by `get_peft_model`.

    Raises:
        KeyError: if ``hparams.lora_type`` is not 'lora' or 'adalora'.
    """

    lora_config_dict = {
        'lora': LoraConfig,
        'adalora': AdaLoraConfig
    }

    lora_config = lora_config_dict[hparams.lora_type](
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=hparams.rank,
        lora_alpha=hparams.lora_alpha, lora_dropout=hparams.lora_dropout,
        # An empty layer list means "do not restrict layers".
        layers_to_transform=hparams.layers if len(hparams.layers) > 0 else None,
        target_modules=hparams.target_modules
    )
    lora_model = get_peft_model(model, lora_config)
    lora_model.print_trainable_parameters()

    # Only the text of each request is needed; strings are immutable, so no
    # defensive copy of `requests` is required.
    texts = [r['text'].strip() for r in requests]

    # Put the batches on the model's own device instead of a hard-coded
    # "cuda", so CPU-only runs and non-default GPUs work as well.
    device = next(lora_model.parameters()).device

    # Configure optimizer / gradients
    opt = torch.optim.Adam(
        lora_model.parameters(),
        lr=hparams.lr,
        weight_decay=hparams.weight_decay,
    )

    loss_meter, kl_meter = AverageMeter(), AverageMeter()
    for it in range(hparams.num_steps):
        print(20 * "=")
        print(f"Epoch: {it}")
        print(20 * "=")
        loss_meter.reset()
        kl_meter.reset()

        for txt in chunks(texts, hparams.batch_size):
            inputs = tok(txt, return_tensors="pt", padding=True).to(device)
            input_ids = inputs["input_ids"]
            attn = inputs["attention_mask"].bool()

            # Next-token targets: position t is labelled with token t+1 when
            # both positions hold real tokens, otherwise -100. Using the
            # attention mask (rather than "the first n positions") keeps the
            # labels aligned for left- as well as right-padded batches and
            # tolerates rows without any real token.
            target = torch.full_like(input_ids, -100)
            valid = attn[:, :-1] & attn[:, 1:]
            target[:, :-1] = input_ids[:, 1:].masked_fill(~valid, -100)

            opt.zero_grad()
            bs = input_ids.shape[0]
            if len(txt) == 1:
                loss, kl_loss = plug.forward(model, tok, inputs, target, compute_loc_loss=False)
            else:
                # NOTE(review): the configured batch size is passed here even
                # when the final chunk is smaller -- confirm that this is what
                # `compute_pt_loss` expects.
                loss = batch_plug.compute_pt_loss(model, inputs, target, hparams.batch_size)
                kl_loss = batch_plug.compute_kl_loss(model, inputs)

            print(f"Batch loss {loss.item()}. KL loss {kl_loss.item()}.")
            loss_meter.update(loss.item(), n=bs)
            kl_meter.update(kl_loss.item(), n=bs)
            # NOTE(review): no `loss.backward()` is visible in this function.
            # Unless the `plug` / `batch_plug` helpers run the backward pass
            # themselves, this step sees no gradients and updates nothing --
            # confirm against PTE.ours_plug.
            opt.step()

        print(f"Total loss {loss_meter.avg}. KL loss {kl_meter.avg}")

        # Early stop once the epoch-average loss is small enough.
        if loss_meter.avg < 1e-2:
            break

    return lora_model

def lora_base(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: LORAHyperParams,
    **kwargs: Any,
) -> torch.nn.Module:
    """
    Attach a LoRA/AdaLoRA adapter to `model` and fine-tune it on the request
    texts with a next-token negative log-likelihood objective.

    Args:
        model: causal LM to adapt. It is passed to `get_peft_model`; PEFT
            documents that call as modifying the model in place, so the
            caller's model should not be assumed unchanged afterwards.
        tok: tokenizer used to encode the request texts (with padding).
        requests: list of dicts; only the ``'text'`` entry of each is read.
        hparams: supplies ``lora_type`` ('lora' or 'adalora'), ``rank``,
            ``lora_alpha``, ``lora_dropout``, ``layers``, ``target_modules``,
            ``lr``, ``weight_decay``, ``num_steps`` and ``batch_size``.
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        The PEFT-wrapped model returned by `get_peft_model`, after training.

    Raises:
        KeyError: if ``hparams.lora_type`` is not 'lora' or 'adalora'.
    """

    lora_config_dict = {
        'lora': LoraConfig,
        'adalora': AdaLoraConfig
    }

    lora_config = lora_config_dict[hparams.lora_type](
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=hparams.rank,
        lora_alpha=hparams.lora_alpha, lora_dropout=hparams.lora_dropout,
        # An empty layer list means "do not restrict layers".
        layers_to_transform=hparams.layers if len(hparams.layers) > 0 else None,
        target_modules=hparams.target_modules
    )
    lora_model = get_peft_model(model, lora_config)
    lora_model.print_trainable_parameters()

    # Only the text of each request is needed; strings are immutable, so no
    # defensive copy of `requests` is required.
    texts = [r['text'].strip() for r in requests]

    # Put the batches on the model's own device instead of a hard-coded
    # "cuda", so CPU-only runs and non-default GPUs work as well.
    device = next(lora_model.parameters()).device

    # Configure optimizer / gradients
    opt = torch.optim.Adam(
        lora_model.parameters(),
        lr=hparams.lr,
        weight_decay=hparams.weight_decay,
    )

    loss_meter = AverageMeter()
    for it in range(hparams.num_steps):
        print(20 * "=")
        print(f"Epoch: {it}")
        print(20 * "=")
        loss_meter.reset()

        for txt in chunks(texts, hparams.batch_size):
            inputs = tok(txt, return_tensors="pt", padding=True).to(device)
            input_ids = inputs["input_ids"]
            attn = inputs["attention_mask"].bool()

            # Next-token targets: position t is labelled with token t+1 when
            # both positions hold real tokens, otherwise -100. Using the
            # attention mask (rather than "the first n positions") keeps the
            # labels aligned for left- as well as right-padded batches and
            # tolerates rows without any real token.
            target = torch.full_like(input_ids, -100)
            valid = attn[:, :-1] & attn[:, 1:]
            target[:, :-1] = input_ids[:, 1:].masked_fill(~valid, -100)

            opt.zero_grad()
            bs = input_ids.shape[0]
            logits = lora_model(**inputs).logits
            log_probs = torch.log_softmax(logits, dim=2)

            # Gather the log-probability of each target token. Ignored
            # positions are pointed at index 0 so the gather stays in range;
            # their contribution is zeroed by the mask below.
            is_target = target != -100
            token_log_probs = torch.gather(
                log_probs,
                2,
                target.masked_fill(~is_target, 0).unsqueeze(2),
            ).squeeze(2)
            mask = is_target.float()
            # Per-sequence mean NLL. The clamp prevents 0/0 = NaN for a
            # sequence without any target token (e.g. a one-token text); a
            # NaN loss would otherwise poison every adapter weight on
            # backward.
            nll_loss_each = -(token_log_probs * mask).sum(1) / mask.sum(1).clamp(min=1)
            loss = nll_loss_each.mean()

            loss_value = loss.item()
            print(f"Batch loss {loss_value}")
            loss_meter.update(loss_value, n=bs)

            # Skip the update for batches that are already fitted.
            if loss_value >= 1e-3:
                loss.backward()
                opt.step()

        print(f"Total loss {loss_meter.avg}")

        # Early stop once the epoch-average loss is small enough.
        if loss_meter.avg < 1e-3:
            break

    return lora_model

def chunks(arr, n):
    """
    Lazily split the iterable `arr` into lists of `n` consecutive items.

    Every yielded list has exactly `n` items, except possibly the last one,
    which holds whatever remains (and is omitted when nothing remains).
    """
    pending = []
    for item in arr:
        pending.append(item)
        if len(pending) != n:
            continue
        # A full group is ready: emit it and start collecting a fresh one.
        yield pending
        pending = []
    # Emit the trailing partial group, if any items are left over.
    if pending:
        yield pending

class AverageMeter:
    """Tracks the latest value and the weighted running mean of a series."""

    def __init__(self):
        # Start from the same cleared state that `reset` produces.
        self.reset()

    def reset(self):
        """Clear the latest value, running mean, weighted sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as the latest value, counted with weight `n`."""
        weighted = val * n
        self.val = val
        self.count += n
        self.sum += weighted
        # Running mean over everything recorded since the last reset.
        self.avg = self.sum / self.count
