import torch
from torch import nn
from tqdm import tqdm

from models.clip_ft_utils.hooks import hook_forward_store_inputs
from models.clip_ft_utils.fisher_kfac import KFACComputer

from models.clip_ft_utils.utils import FisherLoader
from models.clip_ft_utils.utils import set_requires_grad_to


def get_split(dataset):
    """Return the data loader iterated by the EK-FAC ``D`` pass.

    This is always ``dataset.train_loader``.
    """
    loader = dataset.train_loader
    return loader


@torch.no_grad()
def hook_backward_ekfac(module, _, grad_output):
    """Full backward hook that accumulates EK-FAC eigenvalue statistics.

    The hook works in four steps:

    1. Build the per-sample gradient ``G_b`` of the layer weight. When
       ``module.compute_bias`` is set, the per-sample bias gradient is
       appended as an extra last column.
    2. Rotate it into the K-FAC eigenbasis as ``UG^T @ G_b @ UA``.
    3. Square it element-wise and sum over the batch.
    4. Add the result to ``module.grad_weight`` using Kahan summation. The
       compensation term is kept in ``module.grad_weight_c``, and both
       accumulators are created on the first call.

    The hook reads these module attributes: ``inputs`` (layer inputs,
    presumably stored by the paired forward hook), ``name``,
    ``compute_bias``, ``UA`` and ``UG``.

    Args:
        module: the hooked layer.
        _: gradients w.r.t. the module inputs (unused).
        grad_output: tuple whose first element is the gradient w.r.t. the
            module output, either 2-D ``(batch, out)`` or a 3-D sequence tensor.
    """
    grad_out = grad_output[0]
    inputs = module.inputs

    is_sequence = grad_out.dim() > 2
    if is_sequence:
        # 'attn.proj' / 'attn.qkv' tensors are consumed as laid out. Every
        # other sequence layer has its first two axes swapped so that the batch
        # axis comes first. NOTE(review): those layers are assumed to be
        # (seq, batch, features) -- confirm against the model.
        if 'attn.proj' not in module.name and 'attn.qkv' not in module.name:
            grad_out = grad_out.permute(1, 0, 2)
            inputs = inputs.permute(1, 0, 2)
        # Per-sample weight gradient, summed over the sequence axis.
        grad_weight = torch.einsum('blo,bli->boi', grad_out, inputs)
    else:
        # Per-sample weight gradient as an outer product.
        grad_weight = torch.einsum('bo,bi->boi', grad_out, inputs)

    if hasattr(module, "bias") and module.compute_bias:
        # Per-sample bias gradient. For 3-D tensors it is summed over the
        # sequence axis; for 2-D tensors it is the output gradient itself.
        grad_bias = grad_out.sum(1) if is_sequence else grad_out
        grad_weight = torch.cat((grad_weight, grad_bias.unsqueeze(2)), dim=2)

    # Project each per-sample gradient onto the eigenbases, then sum the
    # squared coordinates over the batch.
    grad_weight = torch.einsum('ij,bjk->bik', module.UG.T, grad_weight)
    grad_weight = torch.einsum('bij,jk->bik', grad_weight, module.UA)
    grad_weight = grad_weight.pow(2).sum(0)

    # Accumulators live on the module across hook calls (batches / backward passes).
    if not hasattr(module, "grad_weight"):
        module.grad_weight = torch.zeros_like(grad_weight)
        module.grad_weight_c = torch.zeros_like(grad_weight)

    # Kahan (compensated) summation. It limits rounding error when many
    # small per-batch contributions are added to a growing total.
    y_b = grad_weight - module.grad_weight_c
    t_b = module.grad_weight + y_b
    module.grad_weight_c = (t_b - module.grad_weight) - y_b
    module.grad_weight = t_b


def register_hooks(name, module, forward=True, backward=True,
                   forward_hooks_dict=None, bacward_hooks_dict=None):
    """Tag ``module`` with ``name`` and attach hooks chosen by layer kind.

    The hook to install is looked up by key in the supplied dictionaries. The
    first matching rule wins:

    * ``'lin_proj'`` in ``name``: ``hook_forward_nosequence`` /
      ``hook_backward_nosequence``.
    * ``nn.Linear`` or ``NonDynamicallyQuantizableLinear``: ``hook_forward`` /
      ``hook_backward``.
    * ``nn.LayerNorm``: ``hook_forward_layer_norm`` /
      ``hook_backward_layer_norm``.
    * ``'cls_token'`` in ``name``: ``hook_forward_layer_norm`` (forward) and
      ``hook_backward_cls_token`` (backward).

    A module matching no rule receives no hook in that direction. Handles are
    stored as ``module.forward_handle`` and ``module.backward_handle``.
    The keyword ``bacward_hooks_dict`` keeps its historical spelling for
    caller compatibility.
    """
    module.name = name
    linear_like = isinstance(
        module, (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear))

    if forward:
        assert forward_hooks_dict is not None
        if 'lin_proj' in name:
            fwd_key = 'hook_forward_nosequence'
        elif linear_like:
            fwd_key = 'hook_forward'
        elif isinstance(module, nn.LayerNorm) or 'cls_token' in name:
            fwd_key = 'hook_forward_layer_norm'
        else:
            fwd_key = None
        if fwd_key is not None:
            fwd_hook = forward_hooks_dict[fwd_key]
            module.forward_handle = module.register_forward_hook(fwd_hook)  # type: ignore

    if backward:
        assert bacward_hooks_dict is not None
        if 'lin_proj' in name:
            bwd_key = 'hook_backward_nosequence'
        elif linear_like:
            bwd_key = 'hook_backward'
        elif isinstance(module, nn.LayerNorm):
            bwd_key = 'hook_backward_layer_norm'
        elif 'cls_token' in name:
            bwd_key = 'hook_backward_cls_token'
        else:
            bwd_key = None
        if bwd_key is not None:
            bwd_hook = bacward_hooks_dict[bwd_key]
            module.backward_handle = module.register_full_backward_hook(bwd_hook)  # type: ignore


class EKFAComputer(nn.Module):
    """Compute EK-FAC (eigenvalue-corrected K-FAC) factors for ``net.visual_encoder``.

    ``compute`` works in three steps:

    1. Obtain the Kronecker factors ``aaT`` / ``ggT``, either from a
       :class:`KFACComputer` or from a :class:`FisherLoader`.
    2. Take their left singular vectors ``UA`` / ``UG``.
    3. Run a backward-hook pass over the training data that accumulates the
       squared per-sample gradients projected onto those bases. Divided by
       the number of examples, these give ``D``.
    """

    def __init__(self, device: torch.device, debug_mode,
                 fisher_loader: FisherLoader = None, train_percent: float = 1.0,
                 num_samples_expectation: int = 0):
        """Store the configuration and build the underlying K-FAC computer.

        Args:
            device: device that input batches are moved to.
            debug_mode: when truthy, only the first two batches are used in the
                D pass (the flag is also forwarded to ``KFACComputer``).
            fisher_loader: optional source of precomputed K-FAC factors. When
                given, ``load_kfac(task_id)`` is used instead of ``KFACComputer``.
            train_percent: fraction of the training loader to use, in ``(0, 1]``.
            num_samples_expectation: if ``> 0``, the number of random Gaussian
                output projections back-propagated per batch. Otherwise one
                backward pass is run per (batch-summed) output feature.
        """
        super().__init__()

        assert 0 < train_percent <= 1.0

        self.device = device
        self.debug_mode = debug_mode
        self.train_percent = train_percent

        self.fisher_kfac_loader = fisher_loader
        # Incremented at the start of every ``compute`` call, so the first task is 0.
        self.current_task = -1
        self.num_samples_expectation = num_samples_expectation

        self.kfac_computer = KFACComputer(device, debug_mode, train_percent, num_samples_expectation)

    def to_be_fishered(self, name, module, all_param_finetuned):
        """Return True for linear / attention modules whose weight or bias is fine-tuned."""
        if not isinstance(module, (nn.Linear,
                                   nn.modules.linear.NonDynamicallyQuantizableLinear,
                                   nn.MultiheadAttention)):
            return False
        return f"{name}.weight" in all_param_finetuned \
            or f"{name}.bias" in all_param_finetuned

    def to_be_fishered_layer_norm(self, name, module, all_param_finetuned):
        """Return True for ``nn.LayerNorm`` modules whose weight or bias is fine-tuned."""
        if not isinstance(module, nn.LayerNorm):
            return False
        return f"{name}.weight" in all_param_finetuned \
            or f"{name}.bias" in all_param_finetuned

    def compute(self, net, head, delta_w_names, dataset, use_head=False):
        """Compute the EK-FAC bases and diagonal for the next task.

        Args:
            net: model whose ``visual_encoder`` is analysed.
            head: module applied to the normalised features when ``use_head``.
            delta_w_names: names (relative to ``net.visual_encoder``) of the
                fine-tuned parameters.
            dataset: object exposing ``train_loader``. Images are read as
                ``data[0]`` from each batch.
            use_head: whether to pass the features through ``head``.

        Returns:
            A tuple ``(UA, UG, D, ffT, num_of_examples)``:

            * ``UA`` / ``UG`` map each factor key to the left singular vectors
              of the normalised ``aaT`` / ``ggT``.
            * ``D`` maps ``"<module>.weight"`` to the accumulated squared
              projected gradients divided by ``num_of_examples``.
            * ``ffT`` is passed through unchanged from the K-FAC source.

        Note:
            The ``aaT`` / ``ggT`` dictionaries returned by the K-FAC source are
            emptied in place. Hooks, ``requires_grad`` flags and the encoder's
            train/eval mode are restored even if the D pass raises.
        """
        self.current_task += 1

        if self.fisher_kfac_loader is None:
            ggT, aaT, ffT, num_of_examples_ggT, num_of_examples_aaT = \
                self.kfac_computer.compute(net, head, delta_w_names, dataset, use_head)
        else:
            ggT, aaT, ffT, num_of_examples_ggT, num_of_examples_aaT = \
                self.fisher_kfac_loader.load_kfac(self.current_task)

        # Validate before touching any network state.
        assert num_of_examples_ggT == num_of_examples_aaT
        assert ggT.keys() == aaT.keys()
        num_of_examples = num_of_examples_ggT

        all_param_finetuned = set(delta_w_names)
        num_of_batches = int(self.train_percent * len(dataset.train_loader))

        UA, UG = self._compute_eigenbases(aaT, ggT, num_of_examples)

        encoder = net.visual_encoder
        orig_mode = encoder.training
        encoder.eval()

        hooked = []
        try:
            set_requires_grad_to(encoder, delta_w_names, True)

            # lr=0: the optimizer is only used to zero parameter gradients.
            fake_optim = torch.optim.SGD(
                params=[p for (n, p) in encoder.named_parameters() if n in delta_w_names],
                lr=0.0,
            )

            self._attach_hooks(encoder, all_param_finetuned, UA, UG, hooked)
            self._accumulate(net, head, dataset, fake_optim, num_of_batches, use_head)
            fake_optim.zero_grad()

            D = {
                f"{name}.weight": module.grad_weight / num_of_examples
                for name, module in hooked
                if f"{name}.weight" in all_param_finetuned
            }
        finally:
            self._detach_hooks(hooked)
            set_requires_grad_to(encoder, delta_w_names, False)
            encoder.train(orig_mode)

        return UA, UG, D, ffT, num_of_examples

    @staticmethod
    def _compute_eigenbases(aaT, ggT, num_of_examples):
        """Return ``(UA, UG)``: left singular vectors of the normalised K-FAC factors.

        Entries are popped from ``aaT`` / ``ggT`` as they are processed to
        release memory, so both dictionaries are empty afterwards.
        """
        UA, UG = {}, {}
        for k in tqdm(list(aaT.keys()), desc='SVD computation'):
            # The SVD runs in float64; only the float32 singular vectors are kept.
            UA[k] = torch.linalg.svd(aaT.pop(k).double() / num_of_examples)[0].float()
            UG[k] = torch.linalg.svd(ggT.pop(k).double() / num_of_examples)[0].float()
        return UA, UG

    def _attach_hooks(self, encoder, all_param_finetuned, UA, UG, hooked):
        """Register the EK-FAC hooks on every module selected by ``to_be_fishered``.

        Each ``(name, module)`` pair is appended to ``hooked`` before its
        attributes are set. This lets ``_detach_hooks`` clean up even if this
        method fails part-way, e.g. on a missing ``UA`` / ``UG`` key.
        """
        forward_hooks_dict = {
            'hook_forward': hook_forward_store_inputs,
            'hook_forward_nosequence': hook_forward_store_inputs,
        }
        backward_hooks_dict = {
            'hook_backward': hook_backward_ekfac,
            'hook_backward_nosequence': hook_backward_ekfac,
        }
        for name, module in encoder.named_modules():
            if not self.to_be_fishered(name, module, all_param_finetuned):
                continue
            hooked.append((name, module))
            module.compute_bias = f"{name}.bias" in all_param_finetuned
            module.UA = UA[f"{name}.weight"]
            module.UG = UG[f"{name}.weight"]
            register_hooks(name, module, forward=True, backward=True,
                           forward_hooks_dict=forward_hooks_dict,
                           bacward_hooks_dict=backward_hooks_dict)

    def _accumulate(self, net, head, dataset, fake_optim, num_of_batches, use_head):
        """Run the forward/backward passes during which the backward hooks accumulate D."""
        # Multiplying the input by a tensor that requires grad puts the input
        # itself into the autograd graph. This is presumably so that full
        # backward hooks also fire on the earliest layers.
        fake_param = torch.tensor([1.], requires_grad=True).to(self.device)

        loader = get_split(dataset)
        for i, data in tqdm(enumerate(loader), total=len(loader), desc='D computation'):

            if self.debug_mode and i > 1:
                break

            # NOTE(review): `>` lets batch index `num_of_batches` through, so
            # num_of_batches + 1 batches are processed. It is left unchanged
            # because D is normalised by the K-FAC example count, which should
            # come from the same batch selection. Confirm against KFACComputer.
            if i > num_of_batches:
                break

            x = data[0].to(self.device)

            features = net.visual_encoder(x * fake_param)
            features = features / features.norm(dim=-1, keepdim=True)

            if use_head:
                features = head(features)

            if self.num_samples_expectation > 0:
                # Back-propagate `num_samples_expectation` random Gaussian
                # projections of the output, reusing the graph between them.
                for s in range(self.num_samples_expectation):
                    (features * torch.randn_like(features)).sum().backward(
                        retain_graph=s < self.num_samples_expectation - 1)
            else:
                # One backward pass per batch-summed output feature.
                features = features.sum(0)
                for cnt_class, feat in enumerate(features):
                    fake_optim.zero_grad()
                    feat.backward(retain_graph=cnt_class < features.shape[0] - 1)

    @staticmethod
    def _detach_hooks(hooked):
        """Remove the hooks and per-module scratch attributes added for the D pass."""
        for _, module in hooked:
            # Remove BOTH hooks. A leftover forward hook would keep running
            # `hook_forward_store_inputs` on every later forward pass.
            for handle_attr in ('forward_handle', 'backward_handle'):
                handle = getattr(module, handle_attr, None)
                if handle is not None:
                    handle.remove()
            # `inputs` is read by `hook_backward_ekfac` and is presumably set by
            # the forward hook. Dropping it releases the cached activations.
            # Deletion is conditional because a hook may never have fired.
            for attr in ('compute_bias', 'UA', 'UG', 'grad_weight', 'grad_weight_c',
                         'inputs', 'forward_handle', 'backward_handle'):
                if hasattr(module, attr):
                    delattr(module, attr)

