import torch
# from timm.optim import create_optimizer
from timm.utils import NativeScaler
import torch.optim as optim
# from timm.optim.optim_factory import add_weight_decay

class EMSA_optimizer():
    """Block-wise updater driven by per-block Hamiltonians and co-states.

    The network is treated as an ordered list of blocks. A forward pass
    (``compute_x``) records the state entering/leaving every block, a backward
    sweep (``compute_p``) produces one co-state per recorded state, and
    ``update_model`` then takes a few Adam steps per block on an augmented
    Hamiltonian objective.

    NOTE(review): the name suggests the "extended method of successive
    approximations" (E-MSA); confirm against the accompanying paper/docs.

    Attributes read from ``args``: ``emsa_rho``, ``delta``, ``model``,
    ``device``, ``emsa_iter``, ``h_plot`` and ``emsa_subopt_lr``.
    """

    # Tensor rank -> einsum equation computing one inner product per sample
    # (everything except the leading batch axis is contracted).
    _INNER_PRODUCT_EQUATIONS = {
        2: "bd, bd -> b",
        3: "bnd, bnd -> b",
        4: "bchw, bchw -> b",
    }

    def __init__(self, all_blocks=None, args=None) -> None:
        """Store the blocks and hyper-parameters.

        Args:
            all_blocks: Iterable of callables (the network split into
                consecutive blocks). Must not be ``None`` in practice.
            args: Namespace-like object holding the hyper-parameters listed in
                the class docstring. Must not be ``None`` in practice.
        """
        self.rho = args.emsa_rho
        self.delta = args.delta
        self.args = args
        self.all_blocks = all_blocks
        # One gradient scaler per block. They are not used by ``update_model``
        # at the moment (the AMP path is disabled) but are kept for callers.
        self.scaler_list = [
            torch.cuda.amp.GradScaler(init_scale=1.0) for _ in all_blocks
        ]

    def compute_x(self, x):
        """Run ``x`` through every block and record all intermediate states.

        For ``args.model == 'resnet26'`` the output of the second-to-last
        block is average-pooled with kernel 8 and reshaped to ``(-1, 640)``;
        for ``'resnet50'`` it is flattened from dim 1. Both happen before the
        state is recorded.

        Returns:
            ``(x, x_list)`` where ``x`` is the final output and ``x_list`` has
            ``len(all_blocks) + 1`` entries, starting with the input itself.
        """
        x_list = [x]
        blocks = self.all_blocks
        head_input_index = len(blocks) - 2
        for i, block in enumerate(blocks):
            x = block(x)
            if i == head_input_index:
                if self.args.model == 'resnet26':
                    x = torch.nn.functional.avg_pool2d(x, 8)
                    x = x.view(-1, 640)
                if self.args.model == 'resnet50':
                    x = torch.flatten(x, 1)
            x_list.append(x)
        return x, x_list

    @staticmethod
    def _batch_inner(p, g):
        """Return the per-sample inner product of ``p`` (detached) and ``g``.

        ``p`` is detached and cloned so gradients only flow through ``g``.

        Raises:
            ValueError: if ``p`` is not a 2-D, 3-D or 4-D tensor.
        """
        rank = len(p.shape)
        equation = EMSA_optimizer._INNER_PRODUCT_EQUATIONS.get(rank)
        if equation is None:
            raise ValueError(
                f"co-state must be a 2-D, 3-D or 4-D tensor, got rank {rank}"
            )
        return torch.einsum(equation, p.detach().clone(), g)

    @staticmethod
    def compute_p(loss, x_array):
        """Compute the co-states by a backward sweep over the recorded states.

        The terminal co-state is ``-d loss / d x_array[-1]``. Each earlier one
        is the gradient of ``<p_{i+1}, x_{i+1}>`` with respect to
        ``x_array[i]``, i.e. the previous co-state pulled back through block
        ``i``. For a purely sequential graph this matches
        ``-d loss / d x_array[i]`` by the chain rule.

        Args:
            loss: Scalar tensor whose graph contains every entry of
                ``x_array``.
            x_array: States as returned by ``compute_x``; every entry must be
                part of the autograd graph (including the input).

        Returns:
            List of detached tensors, ``p_list[i]`` pairing with
            ``x_array[i]``.

        Raises:
            ValueError: if a co-state is not a 2-D, 3-D or 4-D tensor.
        """
        # TODO: Check what form of L(v) can be used
        terminal = torch.autograd.grad(loss, x_array[-1], retain_graph=True)[0]
        p_previous = -terminal.detach().clone()
        p_list = [p_previous.detach().clone()]

        # Walk from the last block to the first; the graph is retained because
        # every step differentiates through part of the same forward pass.
        for i in reversed(range(len(x_array) - 1)):
            H = EMSA_optimizer._batch_inner(p_previous, x_array[i + 1])
            p = torch.autograd.grad(H.sum(), x_array[i], retain_graph=True)[0]
            p_list.append(p.detach().clone())
            p_previous = p

        p_list.reverse()
        return p_list

    @staticmethod
    def compute_H(x, p, block):
        """Evaluate the per-sample Hamiltonian ``<p, block(x)>``.

        Returns:
            ``(H, g)`` where ``g = block(x)`` and ``H`` has one entry per
            sample. ``p`` is detached, so ``H`` only depends on ``x`` and the
            block's parameters.

        Raises:
            ValueError: if ``p`` is not a 2-D, 3-D or 4-D tensor.
        """
        g = block(x)
        H = EMSA_optimizer._batch_inner(p, g)
        return H, g

    @staticmethod
    def multiply(a, b):
        """Return the ``(m, n)`` matrix of pairwise contractions of ``a`` and ``b``.

        The leading axis of each operand is kept; the remaining (ellipsis)
        axes are omitted from the output subscripts and left to
        ``torch.einsum`` to reduce.
        """
        return torch.einsum("m..., n... -> mn", a, b)

    @staticmethod
    def _set_requires_grad(optimizer, flag):
        """Set ``requires_grad`` to ``flag`` on every parameter of ``optimizer``."""
        for param_group in optimizer.param_groups:
            for param in param_group["params"]:
                param.requires_grad = flag

    def update_model(self, x_array, p_list, model, layerwise, wandb=None):
        """Update the blocks in place by minimising the augmented Hamiltonian.

        For every trainable block ``i`` (or only block ``layerwise`` when it
        is not ``-1``) this runs ``args.emsa_iter`` Adam steps on

            sum_b( -H_b
                   + rho/2 * mean((x_{i+1} - block(x_i))^2)
                   + rho/2 * mean((p_i - dH/dx_i)^2) )
            + delta * sum_params ||param||_2

        with ``H = <p_{i+1}, block(x_i)>``. States and co-states are detached,
        so only the current block's parameters receive gradients.

        Args:
            x_array: States from ``compute_x``.
            p_list: Co-states from ``compute_p``.
            model: Full model; used to re-enable gradients afterwards.
            layerwise: Index of the single block to update, or ``-1`` for all.
            wandb: Optional logger exposing ``log(dict)``. When ``args.h_plot``
                is set but no logger is given, logging is skipped.

        NOTE(review): on exit ``requires_grad`` is forced to ``True`` for
        *every* parameter of ``model``, including any that were frozen before
        the call. This mirrors the long-standing behaviour; confirm it is
        intended when fine-tuning with frozen layers.
        """
        optimizer_list = self.create_optimizers(model, layerwise)
        active_optimizers = [opt for opt in optimizer_list if opt is not None]

        try:
            for i, optimizer in enumerate(optimizer_list):
                if (layerwise != -1 and i != layerwise) or optimizer is None:
                    continue

                # Start from clean gradients everywhere.
                for other in active_optimizers:
                    other.zero_grad()

                # Only the current block may receive gradients.
                self._set_requires_grad(optimizer, True)
                for other in active_optimizers:
                    if other is not optimizer:
                        self._set_requires_grad(other, False)

                for _ in range(self.args.emsa_iter):
                    optimizer.zero_grad()
                    x = x_array[i].detach().clone()
                    x.requires_grad = True
                    H, g = self.compute_H(
                        x, p_list[i + 1].detach().clone(), self.all_blocks[i]
                    )
                    # create_graph=True: the co-state residual below is
                    # differentiated w.r.t. the block parameters.
                    dH_dx = torch.autograd.grad(
                        H.sum(), x,
                        create_graph=True, retain_graph=True, allow_unused=True,
                    )[0]
                    loss_H = - H

                    # Sum of (non-squared) L2 norms of the block's parameters.
                    loss_L2 = torch.tensor(0.)
                    loss_L2 = loss_L2.to(self.args.device, non_blocking=True)
                    for param_group in optimizer.param_groups:
                        for param in param_group["params"]:
                            loss_L2 += torch.norm(param, p=2)

                    state_residual = x_array[i + 1].detach().clone() - g
                    costate_residual = p_list[i].detach().clone() - dH_dx
                    loss_rho = 1/2 * self.rho * self.norm_square(state_residual)
                    loss_rho += 1/2 * self.rho * self.norm_square(costate_residual)

                    loss = (loss_H + loss_rho).sum() + self.delta * loss_L2

                    # Author notes: consider logging loss_H and loss_L2
                    # separately; E-MSA is useful for the first few iterations.
                    if self.args.h_plot and wandb is not None:
                        wandb.log({"Hamiltonian": loss_H.sum().item(),
                                   "H_aug": loss.item()})

                    loss.backward()
                    optimizer.step()
        finally:
            # Enable grad of all layers for the next batch, even if an update
            # step failed part-way through.
            for param in model.parameters():
                param.requires_grad = True

    def create_optimizers(self, model, layerwise):
        """Build one optimizer slot per block (``None`` where nothing is trained).

        A slot is ``None`` when the block is not the selected one (``layerwise``
        other than ``-1``), when ``args.model == 'resnet26'`` and the block
        index is 3, or when the block has no parameters.

        Args:
            model: Unused; kept for interface compatibility.
            layerwise: Index of the single block to train, or ``-1`` for all.

        Returns:
            List aligned with ``self.all_blocks``.
        """
        optimizer_list = []
        for i, block in enumerate(self.all_blocks):
            not_selected = layerwise != -1 and i != layerwise
            excluded = self.args.model == 'resnet26' and i == 3
            if not_selected or excluded or not any(True for _ in block.parameters()):
                optimizer_list.append(None)
            else:
                optimizer_list.append(self.create_optimizer(self.args, block))
        return optimizer_list

    @staticmethod
    def create_optimizer(args, model):
        """Return an Adam optimizer over ``model.parameters()``.

        Uses ``args.emsa_subopt_lr`` as learning rate; all other Adam settings
        (including ``eps``) are left at their defaults.
        """
        return optim.Adam(model.parameters(), lr=args.emsa_subopt_lr)

    @staticmethod
    def norm_square(input, p="fro"):
        """Return the per-sample *mean* of squared entries of ``input``.

        Despite the name this averages (rather than sums) over every axis
        except the first.

        Args:
            input: 2-D, 3-D or 4-D tensor with the sample axis first.
            p: Only ``"fro"`` is supported, and it is only checked for 4-D
                input (kept as-is for compatibility).

        Raises:
            ValueError: for an unsupported rank, or for 4-D input with
                ``p != "fro"``.
        """
        rank = len(input.shape)
        if rank not in (2, 3, 4):
            raise ValueError(
                f"norm_square expects a 2-D, 3-D or 4-D tensor, got rank {rank}"
            )
        if rank == 4 and p != "fro":
            raise ValueError(f"unsupported norm type for 4-D input: {p!r}")
        return (input ** 2).mean(dim=list(range(1, rank)))
