import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class AddBias(nn.Module):
    """Module that adds a learnable per-channel bias to its input.

    The bias is stored as a parameter of shape ``(C, 1)`` built from the 1-D
    tensor passed to the constructor.
    """

    def __init__(self, bias):
        super().__init__()
        self._bias = nn.Parameter(bias.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the bias, broadcast along dimension 1.

        A 1-D input receives the bias as a flat ``(C,)`` vector.  For any
        other rank the bias is reshaped to ``(1, C, 1, ..., 1)`` so that it
        lines up with dimension 1 of ``x``.
        """
        channels = self._bias.size(0)
        rank = x.dim()
        if rank == 1:
            return x + self._bias.view(channels)

        # Singleton everywhere except the channel axis (dimension 1).
        target_shape = [1] * rank
        target_shape[1] = channels
        return x + self._bias.view(target_shape)



def _extract_patches(x, kernel_size, stride, padding):
    if padding[0] + padding[1] > 0:
        x = F.pad(x, (padding[1], padding[1], padding[0],
                      padding[0])).data  # Actually check dims
    x = x.unfold(2, kernel_size[0], stride[0])
    x = x.unfold(3, kernel_size[1], stride[1])
    x = x.transpose_(1, 2).transpose_(2, 3).contiguous()
    x = x.view(
        x.size(0), x.size(1), x.size(2),
        x.size(3) * x.size(4) * x.size(5))
    return x


def compute_cov_a(a, classname, layer_info, fast_cnn):
    """Compute the batch-averaged second-moment matrix of a layer's input.

    Args:
        a: input tensor of the layer; its first dimension is the batch.
        classname: class name of the layer. ``'Conv2d'`` and ``'AddBias'``
            get special treatment; any other name uses ``a`` as given.
        layer_info: for ``'Conv2d'``, the tuple
            ``(kernel_size, stride, padding)`` forwarded to
            ``_extract_patches``; ignored otherwise.
        fast_cnn: for ``'Conv2d'``, if true average the patches over all
            spatial locations of each sample instead of keeping one row per
            location.

    Returns:
        The matrix ``a.T @ a / batch_size`` computed on the (possibly
        transformed) 2-D activations.

    The caller's tensor is never modified.
    """
    batch_size = a.size(0)

    if classname == 'Conv2d':
        patches = _extract_patches(a, *layer_info)
        if fast_cnn:
            a = patches.view(patches.size(0), -1, patches.size(-1)).mean(1)
        else:
            out_h, out_w = patches.size(1), patches.size(2)
            # Out-of-place division: the patch tensor can share storage with
            # the caller's activations when ``contiguous()`` was a no-op.
            a = patches.view(-1, patches.size(-1)).div(out_h).div(out_w)
    elif classname == 'AddBias':
        # The bias "input" is a constant 1 per sample; create it with the
        # same dtype and on the same device as the real input.
        a = torch.ones(a.size(0), 1, dtype=a.dtype, device=a.device)

    return a.t() @ (a / batch_size)


def compute_cov_g(g, classname, layer_info, fast_cnn):
    """Compute the second-moment matrix of a layer's output gradient.

    Args:
        g: gradient with respect to the layer output; its first dimension is
            the batch.
        classname: class name of the layer. ``'Conv2d'`` and ``'AddBias'``
            get special treatment; any other name uses ``g`` as given.
        layer_info: unused; kept for signature parity with ``compute_cov_a``.
        fast_cnn: for ``'Conv2d'``, if true sum the gradient over spatial
            locations instead of keeping one row per location.

    Returns:
        ``g_.T @ g_ / rows`` where ``g_`` is the (possibly transformed) 2-D
        gradient multiplied by the batch size and ``rows`` is its number of
        rows.

    The caller's tensor is never modified.
    """
    batch_size = g.size(0)

    if classname == 'Conv2d':
        if fast_cnn:
            g = g.view(g.size(0), g.size(1), -1).sum(-1)
        else:
            # (B, C, H, W) -> (B, H, W, C)
            g = g.permute(0, 2, 3, 1).contiguous()
            out_h, out_w = g.size(1), g.size(2)
            # Out-of-place scaling: with a single channel ``contiguous()``
            # returns a tensor that shares storage with the caller's
            # gradient, which an in-place multiply would overwrite.
            g = g.view(-1, g.size(-1)).mul(out_h).mul(out_w)
    elif classname == 'AddBias':
        g = g.view(g.size(0), g.size(1), -1).sum(-1)

    g_ = g * batch_size
    return g_.t() @ (g_ / g.size(0))


def update_running_stat(aa, m_aa, momentum):
    """Update an exponential moving average in place.

    Sets ``m_aa <- momentum * m_aa + (1 - momentum) * aa``.

    Args:
        aa: new statistic; left unchanged.
        m_aa: running statistic, overwritten in place.
        momentum: weight given to the previous running value.

    The scaled in-place add allocates no temporaries, and unlike dividing by
    ``1 - momentum`` it stays defined when ``momentum == 1``.
    """
    m_aa.mul_(momentum).add_(aa, alpha=1 - momentum)


class SplitBias(nn.Module):
    """Wrap a module so that its bias is applied by a separate ``AddBias``.

    The wrapped module's ``bias`` attribute is set to ``None`` and its former
    values become the parameter of the ``add_bias`` sub-layer.
    """

    def __init__(self, module):
        super().__init__()
        # Capture the bias values before removing the bias from the module.
        bias_layer = AddBias(module.bias.data)
        module.bias = None
        self.module = module
        self.add_bias = bias_layer

    def forward(self, input):
        """Apply the bias-free module, then add the bias."""
        return self.add_bias(self.module(input))



class KFACOptimizer(optim.Optimizer):
    """Precondition layer gradients with covariance factors, then run SGD.

    For every ``Linear``, ``Conv2d`` and ``AddBias`` module of the model the
    optimizer keeps running averages of the input covariance (``m_aa``) and
    of the output-gradient covariance (``m_gg``), collected through module
    hooks.  ``step`` rescales each such layer's gradient in the eigenbases of
    the two factors and then delegates the parameter update to an inner
    ``optim.SGD`` with momentum.

    Constructing the optimizer modifies ``model`` in place: every child
    module that has a non-``None`` ``bias`` is replaced by a ``SplitBias``
    wrapper.

    Output-gradient statistics are only accumulated while ``acc_stats`` is
    true; it starts as ``False`` and is meant to be toggled by the caller.

    Args:
        model: the ``nn.Module`` to optimize.
        lr: step size; the inner SGD runs with ``lr * (1 - momentum)``.
        momentum: momentum of the inner SGD.
        stat_decay: decay of the running covariance averages.
        kl_clip: bound used by the optional trust-region rescaling.
        damping: constant added to the eigenvalue products before division.
        weight_decay: L2 coefficient added to gradients and to the damping.
        fast_cnn: forwarded to ``compute_cov_a`` / ``compute_cov_g``.
        Ts: input statistics are collected when ``steps % Ts == 0``.
        Tf: eigendecompositions are refreshed when ``steps % Tf == 0``.
        use_trust_region: if true, scale all gradients by
            ``min(1, sqrt(kl_clip / sum(v * grad * lr**2)))``.
    """

    def __init__(self,
                 model,
                 lr=0.25,
                 momentum=0.9,
                 stat_decay=0.99,
                 kl_clip=0.001,
                 damping=1e-2,
                 weight_decay=0,
                 fast_cnn=False,
                 Ts=1,
                 Tf=10,
                 use_trust_region=False,):
        defaults = dict()

        def split_bias(module):
            # Recursively wrap every biased child so its bias becomes a
            # separate AddBias layer.
            for mname, child in module.named_children():
                if hasattr(child, 'bias') and child.bias is not None:
                    module._modules[mname] = SplitBias(child)
                else:
                    split_bias(child)

        split_bias(model)

        super(KFACOptimizer, self).__init__(model.parameters(), defaults)

        self.known_modules = {'Linear', 'Conv2d', 'AddBias'}

        self.modules = []
        self.grad_outputs = {}

        # Read by the backward hook; must exist before any backward pass.
        self.acc_stats = False

        self.model = model
        self._prepare_model()

        self.steps = 0

        self.m_aa, self.m_gg = {}, {}
        self.Q_a, self.Q_g = {}, {}
        self.d_a, self.d_g = {}, {}

        self.momentum = momentum
        self.stat_decay = stat_decay

        self.lr = lr
        self.kl_clip = kl_clip
        self.damping = damping
        self.weight_decay = weight_decay
        self.use_trust_region = use_trust_region

        self.fast_cnn = fast_cnn

        self.Ts = Ts
        self.Tf = Tf

        self.optim = optim.SGD(
            model.parameters(),
            lr=self.lr * (1 - self.momentum),
            momentum=self.momentum)

    @staticmethod
    def _describe(module):
        """Return ``(classname, layer_info)`` for a hooked module."""
        classname = module.__class__.__name__
        layer_info = None
        if classname == 'Conv2d':
            layer_info = (module.kernel_size, module.stride, module.padding)
        return classname, layer_info

    def _save_input(self, module, input):
        """Forward pre-hook: fold the input covariance into ``m_aa``."""
        if not (torch.is_grad_enabled() and self.steps % self.Ts == 0):
            return

        classname, layer_info = self._describe(module)
        aa = compute_cov_a(input[0].data, classname, layer_info,
                           self.fast_cnn)

        # Initialize buffers (also for modules first seen after step 0).
        if self.steps == 0 or module not in self.m_aa:
            self.m_aa[module] = aa.clone()

        update_running_stat(aa, self.m_aa[module], self.stat_decay)

    def _save_grad_output(self, module, grad_input, grad_output):
        """Backward hook: fold the output-gradient covariance into ``m_gg``."""
        # Accumulate statistics for Fisher matrices
        if not self.acc_stats:
            return

        classname, layer_info = self._describe(module)
        gg = compute_cov_g(grad_output[0].data, classname, layer_info,
                           self.fast_cnn)

        # Initialize buffers (also for modules first seen after step 0).
        if self.steps == 0 or module not in self.m_gg:
            self.m_gg[module] = gg.clone()

        update_running_stat(gg, self.m_gg[module], self.stat_decay)

    def _prepare_model(self):
        """Register the statistics hooks on every supported module."""
        for module in self.model.modules():
            classname = module.__class__.__name__
            if classname in self.known_modules:
                assert not ((classname in ['Linear', 'Conv2d']) and module.bias is not None), \
                                    "You must have a bias as a separate layer"

                self.modules.append(module)
                module.register_forward_pre_hook(self._save_input)
                module.register_full_backward_hook(self._save_grad_output)

    def step(self):
        """Precondition the gradients and apply one inner SGD step."""
        if self.weight_decay > 0:
            for p in self.model.parameters():
                if p.grad is None:
                    continue
                p.grad.data.add_(p.data, alpha=self.weight_decay)

        la = self.damping + self.weight_decay

        updates = {}
        for m in self.modules:
            assert len(list(m.parameters())) == 1, "Can handle only one parameter at the moment"
            classname = m.__class__.__name__
            p = next(m.parameters())
            if p.grad is None:
                continue

            # Refresh on schedule, and also when this module has no
            # decomposition yet (e.g. it had no gradient on a Tf boundary).
            if self.steps % self.Tf == 0 or m not in self.Q_a:
                self.d_a[m], self.Q_a[m] = torch.linalg.eigh(self.m_aa[m], UPLO='L')
                self.d_g[m], self.Q_g[m] = torch.linalg.eigh(self.m_gg[m], UPLO='L')
                # Zero out eigenvalues that are not above 1e-6.
                self.d_a[m].mul_((self.d_a[m] > 1e-6).float())
                self.d_g[m].mul_((self.d_g[m] > 1e-6).float())

            if classname == 'Conv2d':
                p_grad_mat = p.grad.data.view(p.grad.data.size(0), -1)
            else:
                p_grad_mat = p.grad.data

            # Rotate into the eigenbases, divide by the damped eigenvalue
            # products, and rotate back.
            v1 = self.Q_g[m].t() @ p_grad_mat @ self.Q_a[m]
            v2 = v1 / (self.d_g[m].unsqueeze(1) * self.d_a[m].unsqueeze(0) + la)
            v = self.Q_g[m] @ v2 @ self.Q_a[m].t()

            updates[p] = v.view(p.grad.data.size())

        nu = 1.0
        if self.use_trust_region:
            vg_sum = None
            for p, v in updates.items():
                term = (v * p.grad.data * self.lr * self.lr).sum()
                vg_sum = term if vg_sum is None else vg_sum + term
            if vg_sum is not None:
                vg = vg_sum.item()
                if math.isfinite(vg) and vg > 0:
                    nu = min(1.0, math.sqrt(self.kl_clip / vg))

        for p, v in updates.items():
            p.grad.data.copy_(v)
            if nu != 1.0:
                p.grad.data.mul_(nu)

        # Gradients that were not preconditioned get the same scaling.
        if nu != 1.0:
            for p in self.model.parameters():
                if p.grad is not None and p not in updates:
                    p.grad.data.mul_(nu)

        self.optim.step()
        self.steps += 1
