from copy import deepcopy
from typing import Any, Dict, List, Tuple

import torch
from pprint import pprint
from collections import defaultdict
import math
from transformers import AutoModelForCausalLM, AutoTokenizer

from util import nethook
from .pte_hparams import PTEHyperParams
from .ours_plug import plug, batch_plug


def apply_pte_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: PTEHyperParams,
    copy=False,
    return_orig_weights=False,
    **kwargs: Any,
) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]:
    """
    Computes the weight deltas for `requests` via `pte_ours` and adds them to
    the model's parameters in place.

    :param copy: If true, the edit is applied to a deep copy and the model that
        was passed in is left untouched. The caller owns the copy and is
        responsible for freeing it.
    :param return_orig_weights: If true, the second return value holds a clone
        of every edited parameter taken just before its delta was added;
        otherwise that dict stays empty.
    :return: (1) the edited model, (2) the pre-edit weights keyed by parameter
        name (see `return_orig_weights`)
    """

    target_model = deepcopy(model) if copy else model

    deltas = pte_ours(target_model, tok, requests, hparams)

    saved_weights = {}
    with torch.no_grad():
        for param_name, delta in deltas.items():
            param = nethook.get_parameter(target_model, param_name)

            # Snapshot first so the saved tensor is the un-edited weight.
            if return_orig_weights and param_name not in saved_weights:
                saved_weights[param_name] = param.detach().clone()

            param[...] += delta

    print(f"New weights successfully inserted into {list(deltas.keys())}")

    return target_model, saved_weights

def ours_batch(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: PTEHyperParams,
    **kwargs: Any,
) -> Dict[str, Tuple[torch.Tensor]]:
    """
    Optimises the selected weights on all request texts, tokenized as one
    batch, and returns the resulting weight deltas.

    Each step hands the batch to the three `batch_plug` loss helpers and then
    steps the optimizer. No `backward()` is issued in this function, so the
    helpers are presumably the ones populating gradients -- TODO confirm.

    Invariant: the values of the edited weights are restored before returning.
    The `requires_grad` flags assigned below are not restored.

    :param requests: each request is read through its `texts` list of strings
    :return: parameter name -> (trained weight - original weight)
    """

    # Gather every stripped text of every request into one flat list
    requests = deepcopy(requests)
    texts = []
    for req in requests:
        texts.extend(t.strip() for t in req['texts'])

    # Pick the parameters whose name contains one of the rewrite-module names
    weights = {}
    for param_name, param in model.named_parameters():
        if any(
            hparams.rewrite_module_tmp.format(layer) in param_name
            for layer in hparams.layers
        ):
            weights[param_name] = param

    # Snapshot the current values so they can be put back afterwards
    originals = {}
    for param_name, param in weights.items():
        originals[param_name] = param.detach().clone()
    print(f"Weights to be updated: {list(weights.keys())}")

    # Optimizer over the selected parameters; freeze everything else
    optimizer = torch.optim.Adam(
        list(weights.values()),
        lr=hparams.lr,
        weight_decay=hparams.weight_decay,
    )
    for param_name, param in model.named_parameters():
        param.requires_grad = param_name in weights

    print(50 * "=")
    for it in range(hparams.num_steps):
        batch = tok(texts, return_tensors="pt", padding=True).to("cuda")
        token_ids = batch["input_ids"]
        attn = batch["attention_mask"]

        # Next-token labels: row i gets tokens 1..n-1 of its own input, where n
        # is the attention-mask sum; every other position stays at -100.
        # NOTE(review): this lines up with the real tokens only if padding is
        # on the right -- confirm the tokenizer's padding side.
        labels = torch.tensor(-100, device="cuda").repeat(len(texts), *token_ids.shape[1:])
        for row in range(len(texts)):
            n_tok = attn[row].sum()
            labels[row, :n_tok - 1] = token_ids[row, 1:n_tok]

        optimizer.zero_grad()
        pt_loss = batch_plug.compute_pt_loss(model, batch, labels, hparams.batch_size)
        kl_loss = batch_plug.compute_kl_loss(model, batch)
        loc_loss = batch_plug.compute_loc_loss(model, tok, weights, originals)
        optimizer.step()

        print(f"Epoch {it:02d}: PT loss {pt_loss.item()}. KL loss {kl_loss.item()}. Locality loss {loc_loss.item()}")
        # Stop early once the PT loss is small enough
        if pt_loss.item() < 1e-2:
            break

    print(20 * "=")
    deltas = {}
    for param_name, param in weights.items():
        deltas[param_name] = (param - originals[param_name]).detach()

    # Put the original values back into the model
    with torch.no_grad():
        for param_name, param in weights.items():
            param[...] = originals[param_name]

    print(f"Deltas successfully computed for {list(weights.keys())}")

    return deltas

def pte_batch(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: PTEHyperParams,
    **kwargs: Any,
) -> Dict[str, Tuple[torch.Tensor]]:
    """
    Fine-tunes the selected weights on every text of every request with a
    next-token NLL loss and returns the resulting weight deltas.

    Gradients are accumulated over all mini-batches of an epoch (each batch
    contributes `sum(per-text NLL) / len(texts)`), followed by a single Adam
    step per epoch. Batches whose mean loss is already below 1e-2 contribute no
    gradient. Training stops early once the epoch-average loss drops below 1e-2.

    Invariant: the values of the edited weights are restored before returning,
    also when training raises. The `requires_grad` flags assigned below are not
    restored.

    :param requests: each request is read through its `texts` list of strings
    :param hparams: reads `layers`, `rewrite_module_tmp`, `lr`, `weight_decay`,
        `num_steps` and `batch_size`
    :return: parameter name -> (trained weight - original weight)
    """

    # Only the texts are read (and `strip` returns new strings), so the
    # requests do not need to be copied.
    texts = [t.strip() for r in requests for t in r['texts']]

    # Retrieve weights that user desires to change
    weights = {
        n: p
        for n, p in model.named_parameters()
        for layer in hparams.layers
        if hparams.rewrite_module_tmp.format(layer) in n
    }
    # Save old weights for future restoration
    weights_copy = {k: v.detach().clone() for k, v in weights.items()}
    print(f"Weights to be updated: {list(weights.keys())}")
    print(f'len(texts) = {len(texts)}')

    # Configure optimizer / gradients
    opt = torch.optim.Adam(
        [v for _, v in weights.items()],
        lr=hparams.lr,
        weight_decay=hparams.weight_decay,
    )
    for name, w in model.named_parameters():
        w.requires_grad = name in weights

    # Run on the device the model lives on instead of assuming "cuda".
    device = next(model.parameters()).device

    loss_meter = AverageMeter()
    try:
        for it in range(hparams.num_steps):
            print(20 * "=")
            print(f"Epoch: {it}")
            print(20 * "=")
            loss_meter.reset()
            # Zeroed once per epoch: gradients accumulate across the batches.
            opt.zero_grad()

            for txt in chunks(texts, hparams.batch_size):
                inputs = tok(txt, return_tensors="pt", padding=True).to(device)

                # Next-token labels: row i gets tokens 1..n-1 of its own input,
                # where n is the attention-mask sum; the rest stays at -100.
                # NOTE(review): this lines up with the real tokens only if
                # padding is on the right -- confirm the tokenizer's setting.
                target = torch.tensor(-100, device=device).repeat(
                    len(txt), *inputs["input_ids"].shape[1:]
                )
                for i in range(len(txt)):
                    n_tok = int(inputs["attention_mask"][i].sum())
                    # Rows with fewer than two tokens have nothing to predict;
                    # leaving them at -100 also avoids the shape-mismatch that
                    # `target[i, :-1] = ...` would hit when n_tok == 0.
                    if n_tok > 1:
                        target[i, : n_tok - 1] = inputs["input_ids"][i, 1:n_tok]

                bs = inputs["input_ids"].shape[0]
                logits = model(**inputs).logits
                probs = torch.log_softmax(logits, dim=2)
                # Log-probability of each label; ignored positions gather
                # index 0 and are zeroed out by the mask right after.
                loss = torch.gather(
                    probs,
                    2,
                    torch.where(target != -100, target, 0).unsqueeze(2),
                ).squeeze(2)
                mask = (target != -100).float()
                # Clamp the token count so a row without any label yields a
                # loss of 0 instead of 0/0 = NaN, which would poison the meter
                # and make the early-stop comparison below always false.
                nll_loss_each = -(loss * mask).sum(1) / mask.sum(1).clamp(min=1)
                loss = nll_loss_each.mean()

                print(f"Batch loss {loss.item()}")
                loss_meter.update(loss.item(), n=bs)

                if loss.item() >= 1e-2:
                    # Normalise by the total number of texts so the accumulated
                    # gradient is that of the mean per-text loss.
                    (nll_loss_each.sum()/len(texts)).backward()

            opt.step()
            print(f"Total loss {loss_meter.avg}")
            if loss_meter.avg < 1e-2:
                break

        deltas = {k: (weights[k] - weights_copy[k]).detach() for k in weights}
    finally:
        # Restore state of original model, even if training failed midway
        with torch.no_grad():
            for k, v in weights.items():
                v[...] = weights_copy[k]

    print(f"Deltas successfully computed for {list(weights.keys())}")

    return deltas

def pte_base(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: PTEHyperParams,
    **kwargs: Any,
) -> Dict[str, Tuple[torch.Tensor]]:
    """
    Fine-tunes the selected weights on each request's `text` with a next-token
    NLL loss and returns the resulting weight deltas.

    One Adam step is taken per mini-batch, skipped when the batch's mean loss
    is already below 1e-2. When `hparams.norm_constraint` is a float, every
    weight is clamped element-wise to within that distance of its original
    value after each batch. Training stops early once the epoch-average loss
    drops below 1e-2.

    Invariant: the values of the edited weights are restored before returning,
    also when training raises. The `requires_grad` flags assigned below are not
    restored.

    :param requests: each request is read through its `text` string
    :param hparams: reads `layers`, `rewrite_module_tmp`, `lr`, `weight_decay`,
        `num_steps`, `batch_size` and `norm_constraint`
    :return: parameter name -> (trained weight - original weight)
    """

    # Only the text is read (and `strip` returns a new string), so the
    # requests do not need to be copied.
    texts = [r['text'].strip() for r in requests]

    # Retrieve weights that user desires to change
    weights = {
        n: p
        for n, p in model.named_parameters()
        for layer in hparams.layers
        if hparams.rewrite_module_tmp.format(layer) in n
    }
    # Save old weights for future restoration
    weights_copy = {k: v.detach().clone() for k, v in weights.items()}
    print(f"Weights to be updated: {list(weights.keys())}")

    # Configure optimizer / gradients
    opt = torch.optim.Adam(
        [v for _, v in weights.items()],
        lr=hparams.lr,
        weight_decay=hparams.weight_decay,
    )
    for name, w in model.named_parameters():
        w.requires_grad = name in weights

    # Run on the device the model lives on instead of assuming "cuda".
    device = next(model.parameters()).device

    # Update loop: intervene at layers simultaneously
    loss_meter = AverageMeter()
    try:
        for it in range(hparams.num_steps):
            print(20 * "=")
            print(f"Epoch: {it}")
            print(20 * "=")
            loss_meter.reset()

            for txt in chunks(texts, hparams.batch_size):
                inputs = tok(txt, return_tensors="pt", padding=True).to(device)

                # Next-token labels: row i gets tokens 1..n-1 of its own input,
                # where n is the attention-mask sum; the rest stays at -100.
                # NOTE(review): this lines up with the real tokens only if
                # padding is on the right -- confirm the tokenizer's setting.
                target = torch.tensor(-100, device=device).repeat(
                    len(txt), *inputs["input_ids"].shape[1:]
                )
                for i in range(len(txt)):
                    n_tok = int(inputs["attention_mask"][i].sum())
                    # Rows with fewer than two tokens have nothing to predict;
                    # leaving them at -100 also avoids the shape-mismatch that
                    # `target[i, :-1] = ...` would hit when n_tok == 0.
                    if n_tok > 1:
                        target[i, : n_tok - 1] = inputs["input_ids"][i, 1:n_tok]

                opt.zero_grad()
                bs = inputs["input_ids"].shape[0]
                logits = model(**inputs).logits
                probs = torch.log_softmax(logits, dim=2)
                # Log-probability of each label; ignored positions gather
                # index 0 and are zeroed out by the mask right after.
                loss = torch.gather(
                    probs,
                    2,
                    torch.where(target != -100, target, 0).unsqueeze(2),
                ).squeeze(2)
                mask = (target != -100).float()
                # Clamp the token count so a row without any label yields a
                # loss of 0 instead of 0/0 = NaN, which would poison the meter
                # and make the early-stop comparison below always false.
                nll_loss_each = -(loss * mask).sum(1) / mask.sum(1).clamp(min=1)
                loss = nll_loss_each.mean()

                print(f"Batch loss {loss.item()}")
                loss_meter.update(loss.item(), n=bs)

                if loss.item() >= 1e-2:
                    loss.backward()
                    opt.step()

                # NOTE(review): the exact-type check means an integer
                # constraint (e.g. 1 instead of 1.0) is silently ignored.
                if type(hparams.norm_constraint) is float:
                    eps = hparams.norm_constraint
                    with torch.no_grad():
                        for k, v in weights.items():
                            v[...] = torch.clamp(
                                v, min=weights_copy[k] - eps, max=weights_copy[k] + eps
                            )

            print(f"Total loss {loss_meter.avg}")

            if loss_meter.avg < 1e-2:
                break

        deltas = {k: (weights[k] - weights_copy[k]).detach() for k in weights}
    finally:
        # Restore state of original model, even if training failed midway
        with torch.no_grad():
            for k, v in weights.items():
                v[...] = weights_copy[k]

    print(f"Deltas successfully computed for {list(weights.keys())}")

    return deltas

def pte_ours(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: PTEHyperParams,
    **kwargs: Any,
) -> Dict[str, Tuple[torch.Tensor]]:
    """
    Optimises the selected weights on each request's `text`, mini-batch by
    mini-batch, and returns the resulting weight deltas.

    The three losses of a batch come from `plug.forward`; the optimizer is
    stepped right after. No `backward()` is issued in this function, so
    `plug.forward` is presumably the one populating gradients -- TODO confirm.
    When `hparams.norm_constraint` is a float, every weight is clamped
    element-wise to within that distance of its original value after each
    batch. Training stops early once the epoch-average loss drops below 1e-2.

    Invariant: the values of the edited weights are restored before returning.
    The `requires_grad` flags assigned below are not restored.

    :param requests: each request is read through its `text` string
    :return: parameter name -> (trained weight - original weight)
    """

    # One stripped text per request
    requests = deepcopy(requests)
    texts = [req['text'].strip() for req in requests]

    # Pick the parameters whose name contains one of the rewrite-module names
    weights = {}
    for param_name, param in model.named_parameters():
        if any(
            hparams.rewrite_module_tmp.format(layer) in param_name
            for layer in hparams.layers
        ):
            weights[param_name] = param

    # Snapshot the current values so they can be put back afterwards
    originals = {}
    for param_name, param in weights.items():
        originals[param_name] = param.detach().clone()
    print(f"Weights to be updated: {list(weights.keys())}")

    # Optimizer over the selected parameters; freeze everything else
    optimizer = torch.optim.Adam(
        list(weights.values()),
        lr=hparams.lr,
        weight_decay=hparams.weight_decay,
    )
    for param_name, param in model.named_parameters():
        param.requires_grad = param_name in weights

    loss_meter = AverageMeter()
    kl_meter = AverageMeter()
    loc_meter = AverageMeter()
    for it in range(hparams.num_steps):
        print(20 * "=")
        print(f"Epoch: {it}")
        print(20 * "=")
        for meter in (loss_meter, kl_meter, loc_meter):
            meter.reset()

        for batch_texts in chunks(texts, hparams.batch_size):
            batch = tok(batch_texts, return_tensors="pt", padding=True).to("cuda")
            token_ids = batch["input_ids"]
            attn = batch["attention_mask"]

            # Next-token labels: row i gets tokens 1..n-1 of its own input,
            # where n is the attention-mask sum; the rest stays at -100.
            # NOTE(review): this lines up with the real tokens only if padding
            # is on the right -- confirm the tokenizer's padding side.
            labels = torch.tensor(-100, device="cuda").repeat(len(batch_texts), *token_ids.shape[1:])
            for row in range(len(batch_texts)):
                n_tok = attn[row].sum()
                labels[row, :n_tok - 1] = token_ids[row, 1:n_tok]

            optimizer.zero_grad()
            n_rows = token_ids.shape[0]
            loss, kl_loss, loc_loss = plug.forward(model, tok, batch, labels, weights, originals)

            print(f"Batch loss {loss.item()}. KL loss {kl_loss.item()}. Locality loss {loc_loss.item()}")
            loss_meter.update(loss.item(), n=n_rows)
            kl_meter.update(kl_loss.item(), n=n_rows)
            loc_meter.update(loc_loss.item(), n=n_rows)
            optimizer.step()

            # The norm constraint only applies when it is configured as a float
            if type(hparams.norm_constraint) is not float:
                continue
            eps = hparams.norm_constraint
            with torch.no_grad():
                for param_name, param in weights.items():
                    lower = originals[param_name] - eps
                    upper = originals[param_name] + eps
                    param[...] = torch.clamp(param, min=lower, max=upper)

        print(f"Total loss {loss_meter.avg}. KL loss {kl_meter.avg}. Locality loss {loc_meter.avg}")

        if loss_meter.avg < 1e-2:
            break

    deltas = {}
    for param_name, param in weights.items():
        deltas[param_name] = (param - originals[param_name]).detach()

    # Put the original values back into the model
    with torch.no_grad():
        for param_name, param in weights.items():
            param[...] = originals[param_name]

    print(f"Deltas successfully computed for {list(weights.keys())}")

    return deltas

def chunks(arr, n):
    """Split any iterable into consecutive lists of n items.

    Lists are yielded lazily, in order; the last one holds whatever is left
    over and may be shorter than n. Nothing is yielded for an empty iterable.
    """
    pending = []
    for item in arr:
        pending.append(item)
        if len(pending) != n:
            continue
        # The buffer is full: hand it out and start a fresh one.
        yield pending
        pending = []
    if pending:
        yield pending

class AverageMeter:
    """Tracks the latest value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        # val: value passed to the most recent update
        # avg: weighted mean of all values since the last reset
        # sum: weighted sum of all values since the last reset
        # count: total weight seen since the last reset
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` with weight `n` (e.g. a batch mean and its batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. a first update with n=0) used to raise
        # ZeroDivisionError; report the reset value of 0 in that case.
        self.avg = self.sum / self.count if self.count else 0
