import torch
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer

#from my_modules import Linear, Conv2d

import torch.distributed as dist


class KFAC(Optimizer):
    """K-FAC gradient preconditioner for Linear layers inside adapter modules.

    The optimizer does not update parameters itself: ``step`` rewrites the
    ``.grad`` of every registered weight/bias with its K-FAC preconditioned
    value, using Kronecker factors estimated from the layer inputs (``xxt``)
    and from the gradients w.r.t. the layer outputs (``ggt``). The factors
    can optionally be averaged across workers with ``torch.distributed``.
    """

    # Class names of the container modules whose ``Linear`` descendants are
    # preconditioned. Matching is done on the class *name*, not the type.
    _CONTAINER_CLASSES = ('AdapterLayer', 'BiasLayer', 'LowRankLinear')

    def __init__(self, net, eps, sua=False, pi=False, update_freq=1,
                 alpha=1.0, constraint_norm=False, distributed=True, world_size=-1):
        """Registers hooks and parameter groups for the supported layers.

        Only modules whose class name is ``Linear`` and that live inside a
        module named ``AdapterLayer``, ``BiasLayer`` or ``LowRankLinear`` are
        registered; every other layer is silently skipped.

        Args:
            net (torch.nn.Module): Network to precondition.
            eps (float): Tikhonov regularization parameter for the inverses.
            sua (bool): Applies SUA approximation (Conv2d groups only).
            pi (bool): Computes pi correction for Tikhonov regularization.
            update_freq (int): Perform inverses every update_freq updates.
            alpha (float): Running average parameter (if == 1, no r. ave.).
            constraint_norm (bool): Scale the gradients by the squared
                fisher norm.
            distributed (bool): Average the covariance factors across
                workers with ``torch.distributed.all_reduce``.
            world_size (int): Number of workers to average over. When it is
                not positive, the size of the default process group is used.
        """
        self.eps = eps
        self.sua = sua
        self.pi = pi
        self.update_freq = update_freq
        self.alpha = alpha
        self.constraint_norm = constraint_norm
        self.params = []
        self._fwd_handles = []
        self._bwd_handles = []
        self._iteration_counter = 0
        self.distributed = distributed
        self.world_size = world_size
        # Both ``modules()`` traversals are recursive, so a Linear that sits
        # inside nested containers is reached several times. Registering it
        # more than once would duplicate its hooks and make Optimizer reject
        # the parameter groups, hence the bookkeeping set.
        registered = set()
        for mod in net.modules():
            if mod.__class__.__name__ not in self._CONTAINER_CLASSES:
                continue
            for sub_mod in mod.modules():
                sub_mod_class = sub_mod.__class__.__name__
                if sub_mod_class != 'Linear' or sub_mod in registered:
                    continue
                registered.add(sub_mod)
                handle = sub_mod.register_forward_pre_hook(self._save_input)
                self._fwd_handles.append(handle)
                handle = sub_mod.register_backward_hook(self._save_grad_output)
                self._bwd_handles.append(handle)
                params = [sub_mod.weight]
                if sub_mod.bias is not None:
                    params.append(sub_mod.bias)
                self.params.append({'params': params, 'mod': sub_mod,
                                    'layer_type': sub_mod_class})
        super(KFAC, self).__init__(self.params, {})

    def step(self, update_stats=True, update_params=True):
        """Performs one step of preconditioning.

        Args:
            update_stats (bool): Refresh the covariance factors (and, every
                ``update_freq`` calls, their inverses) from the activations
                and output gradients saved by the hooks.
            update_params (bool): Replace the gradients of the registered
                parameters with their preconditioned values.
        """
        fisher_norm = 0.
        for group in self.param_groups:
            # Getting parameters
            if len(group['params']) == 2:
                weight, bias = group['params']
            else:
                weight = group['params'][0]
                bias = None
            state = self.state[weight]
            # Update covariances and inverses
            if update_stats:
                if self._iteration_counter % self.update_freq == 0:
                    self._compute_covs(group, state)
                    ixxt, iggt = self._inv_covs(state['xxt'], state['ggt'],
                                                state['num_locations'])
                    state['ixxt'] = ixxt
                    state['iggt'] = iggt
                elif self.alpha != 1:
                    # Between inversions the running averages still have to
                    # be fed; without averaging the update would be wasted.
                    self._compute_covs(group, state)
            if update_params:
                # Preconditioning
                gw, gb = self._precond(weight, bias, group, state)
                # Updating gradients
                if self.constraint_norm:
                    fisher_norm += (weight.grad * gw).sum()
                weight.grad.data = gw
                if bias is not None:
                    if self.constraint_norm:
                        fisher_norm += (bias.grad * gb).sum()
                    bias.grad.data = gb
            # Cleaning: drop the tensors saved by the hooks for this pass
            mod_state = self.state[group['mod']]
            mod_state.pop('x', None)
            mod_state.pop('gy', None)
        # Eventually scale the norm of the gradients
        if update_params and self.constraint_norm:
            scale = (1. / fisher_norm) ** 0.5
            for group in self.param_groups:
                for param in group['params']:
                    param.grad.data *= scale
        if update_stats:
            self._iteration_counter += 1

    def _save_input(self, mod, i):
        """Forward pre-hook: saves the layer input to compute its covariance."""
        if mod.training:
            # Only the values are needed (they were read through ``.data``
            # before), so the saved tensor is detached from the graph.
            self.state[mod]['x'] = i[0].detach()

    def _save_grad_output(self, mod, grad_input, grad_output):
        """Backward hook: saves the grad on the layer output for the covariance."""
        if mod.training:
            # Scaled by the size of the first dimension -- presumably to undo
            # a mean reduction over the batch in the loss; confirm with caller.
            gy = grad_output[0].detach()
            self.state[mod]['gy'] = gy * gy.size(0)

    def _precond(self, weight, bias, group, state):
        """Applies preconditioning.

        Returns:
            tuple: ``(gw, gb)``, the preconditioned weight gradient (same
            shape as ``weight.grad``) and bias gradient (``None`` when
            ``bias`` is ``None``).
        """
        if group['layer_type'] == 'Conv2d' and self.sua:
            return self._precond_sua(weight, bias, group, state)
        ixxt = state['ixxt']
        iggt = state['iggt']
        g = weight.grad.data
        s = g.shape
        if group['layer_type'] == 'Conv2d':
            g = g.contiguous().view(s[0], s[1]*s[2]*s[3])
        if bias is not None:
            # The bias gradient becomes an extra last column, matching the
            # row of ones appended to the inputs in _compute_covs.
            gb = bias.grad.data
            g = torch.cat([g, gb.view(gb.shape[0], 1)], dim=1)
        g = torch.mm(torch.mm(iggt, g), ixxt)
        if group['layer_type'] == 'Conv2d':
            g /= state['num_locations']
        if bias is not None:
            gb = g[:, -1].contiguous().view(*bias.shape)
            g = g[:, :-1]
        else:
            gb = None
        g = g.contiguous().view(*s)
        return g, gb

    def _precond_sua(self, weight, bias, group, state):
        """Preconditioning for KFAC SUA (Conv2d groups).

        Returns:
            tuple: ``(gw, gb)`` as in ``_precond``.
        """
        ixxt = state['ixxt']
        iggt = state['iggt']
        g = weight.grad.data
        s = g.shape
        g = g.permute(1, 0, 2, 3).contiguous()
        if bias is not None:
            # Broadcast the bias gradient over the kernel window and treat it
            # as one more input channel.
            gb = bias.grad.view(1, -1, 1, 1).expand(1, -1, s[2], s[3])
            g = torch.cat([g, gb], dim=0)
        g = torch.mm(ixxt, g.contiguous().view(-1, s[0]*s[2]*s[3]))
        g = g.view(-1, s[0], s[2], s[3]).permute(1, 0, 2, 3).contiguous()
        g = torch.mm(iggt, g.view(s[0], -1)).view(s[0], -1, s[2], s[3])
        g /= state['num_locations']
        if bias is not None:
            # The bias entry is read back at the centre of the kernel window.
            gb = g[:, -1, s[2]//2, s[3]//2]
            g = g[:, :-1]
        else:
            gb = None
        return g, gb

    def _resolve_world_size(self):
        """Returns the number of workers the covariances are averaged over.

        Raises:
            RuntimeError: If ``world_size`` was not given and no default
                process group is initialised to infer it from.
        """
        if self.world_size > 0:
            return self.world_size
        if dist.is_available() and dist.is_initialized():
            return dist.get_world_size()
        raise RuntimeError(
            "KFAC was created with distributed=True but world_size is not "
            "set and torch.distributed is not initialised; pass world_size "
            "or use distributed=False.")

    def _compute_covs(self, group, state):
        """Computes the covariances.

        Updates ``state['xxt']``, ``state['ggt']`` and
        ``state['num_locations']`` from the tensors saved by the hooks, then
        averages both factors across workers when ``distributed`` is set.

        Raises:
            KeyError: If the hooks saved no input / output gradient for the
                module (no forward+backward pass in training mode).
            RuntimeError: If distributed averaging is requested but the
                world size cannot be determined.
        """
        mod = group['mod']
        mod_state = self.state[mod]
        if 'x' not in mod_state or 'gy' not in mod_state:
            raise KeyError(
                "no saved input/grad_output for module {!r}; run a forward "
                "and backward pass in training mode before step()".format(mod))
        x = mod_state['x']
        gy = mod_state['gy']
        # Computation of xxt
        if group['layer_type'] == 'Conv2d':
            if not self.sua:
                x = F.unfold(x, mod.kernel_size, padding=mod.padding,
                             stride=mod.stride)
            else:
                x = x.view(x.shape[0], x.shape[1], -1)
            x = x.data.permute(1, 0, 2).contiguous().view(x.shape[1], -1)
        else:
            # Handle >2 dimensions: fold every leading dimension into one.
            # reshape (not view) so non-contiguous inputs are accepted.
            x = x.detach().reshape(-1, x.size(-1)).t()
        # NOTE(review): a Conv2d group with bias gets no row of ones here,
        # although _precond appends the bias gradient -- confirm before
        # registering Conv2d layers.
        if hasattr(mod, "bias") and mod.bias is not None and group['layer_type'] == 'Linear':
            ones = torch.ones_like(x[:1])
            x = torch.cat([x, ones], dim=0)
        if self._iteration_counter == 0:
            state['xxt'] = torch.mm(x, x.t()) / float(x.shape[1])
        else:
            state['xxt'].addmm_(mat1=x, mat2=x.t(),
                                beta=(1. - self.alpha),
                                alpha=self.alpha / float(x.shape[1]))

        # Computation of ggt
        if group['layer_type'] == 'Conv2d':
            gy = gy.data.permute(1, 0, 2, 3)
            state['num_locations'] = gy.shape[2] * gy.shape[3]
            gy = gy.contiguous().view(gy.shape[0], -1)
        else:
            # Handle >2 dimensions, same folding as for the inputs.
            gy = gy.detach().reshape(-1, gy.size(-1)).t()
            state['num_locations'] = 1
        if self._iteration_counter == 0:
            state['ggt'] = torch.mm(gy, gy.t()) / float(gy.shape[1])
        else:
            state['ggt'].addmm_(mat1=gy, mat2=gy.t(),
                                beta=(1. - self.alpha),
                                alpha=self.alpha / float(gy.shape[1]))

        # Synchronize xxt and ggt: replace each by its mean over the workers
        if self.distributed:
            world_size = self._resolve_world_size()
            if world_size > 1:
                for key in ('xxt', 'ggt'):
                    dist.all_reduce(tensor=state[key], op=dist.ReduceOp.SUM)
                    state[key].div_(world_size)

    def _inv_covs(self, xxt, ggt, num_locations):
        """Inverses the covariances.

        Adds ``sqrt(eps * pi)`` to the diagonal of ``xxt`` and
        ``sqrt(eps / pi)`` to the diagonal of ``ggt`` (with
        ``eps = self.eps / num_locations``) before inverting.

        Returns:
            tuple: ``(ixxt, iggt)``, the inverses of the damped factors.
        """
        # Computes pi
        pi = 1.0
        if self.pi:
            tx = torch.trace(xxt) * ggt.shape[0]
            tg = torch.trace(ggt) * xxt.shape[0]
            pi = (tx / tg)
        # Regularizes and inverse
        eps = self.eps / num_locations
        diag_xxt = xxt.new_ones(xxt.shape[0]) * (eps * pi) ** 0.5
        diag_ggt = ggt.new_ones(ggt.shape[0]) * (eps / pi) ** 0.5
        ixxt = (xxt + torch.diag(diag_xxt)).inverse()
        iggt = (ggt + torch.diag(diag_ggt)).inverse()
        return ixxt, iggt

    def __del__(self):
        """Removes the hooks registered on the network."""
        # getattr guards against a partially constructed instance.
        handles = getattr(self, '_fwd_handles', []) + getattr(self, '_bwd_handles', [])
        for handle in handles:
            handle.remove()
