import torch
import torch.nn as nn
import torch.nn.functional as thF
from torch.autograd import grad
from torch.autograd.functional import jacobian
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts


class ArgminLayer(nn.Module):
    """
    Differentiable layer returning the context ``c`` that best explains ``y``.

    The forward pass approximately solves

        c* = argmin_c  mse(f_θ(x, c), y) + ctx_lambda * ||c||

    with Adam and a cosine warm-restart learning-rate schedule, starting from
    ``c = 0``. Gradients w.r.t. θ, ``x`` and ``y`` come from implicit
    differentiation: the output is rewritten as ``z = c* - ∇_c F(c*)``, which
    carries a graph to θ, x and y. A backward hook then multiplies the
    incoming gradient by the inverse per-sample Hessian ``(∇²_c F(c*))^{-1}``,
    where ``F(c) = mse(f_θ(x, c), y)``.

    NOTE(review): the implicit correction differentiates ``F`` without the
    ``||c||`` penalty, i.e. it treats ``c*`` as a stationary point of ``F``
    alone; with ``ctx_lambda > 0`` this is an approximation.
    """

    def __init__(self,
                 model: nn.Module,
                 c_dim: int,
                 lr: float = 1e-3,
                 T_0: int = 32,
                 n_iter: int = 100,
                 ctx_lambda: float = 0.0001,
                 **kwargs):
        """
        :param model: called as ``model(ctx=c, us=x)``; its output is fitted to ``y``
        :param c_dim: size of the per-sample context vector ``c``
        :param lr: Adam learning rate of the inner optimisation
        :param T_0: restart period of ``CosineAnnealingWarmRestarts``
        :param n_iter: number of inner optimisation steps
        :param ctx_lambda: weight of the ``||c||`` penalty in the inner objective
        :param kwargs: accepted and ignored
        """
        super().__init__()

        self.model = model  # parametrized func f_θ
        self.c_dim = c_dim

        self.step_size = lr
        self.T_0 = T_0
        self.steps = n_iter
        self.ctx_lambda = ctx_lambda

    def f(self, x, c):
        """
        :param x: [batch, n_ctx, x_dim]
        :param c: [batch, c_dim]
        :return: ``model(ctx=c, us=x)``
        """
        return self.model(ctx=c, us=x)

    def _inner_objective(self, x, y, c):
        """Return ``(data_loss, penalty)`` of the inner problem evaluated at ``c``."""
        loss = thF.mse_loss(self.f(x, c), y)
        penalty = self.ctx_lambda * c.norm()
        return loss, penalty

    def argminF(self, x, y):
        """
        Approximately minimise the inner objective over ``c`` with Adam.

        :param x: model input, passed as ``model(us=x)``
        :param y: target compared with the model output by MSE
        :return: ``(c, info)``; ``c`` is the detached ``[batch, c_dim]`` solution
            and ``info`` holds detached diagnostics of the last inner step
            (or of the initial ``c`` when ``n_iter == 0``)
        """
        penalty_weight = self.ctx_lambda
        loss = penalty = total_loss = None
        with torch.enable_grad():
            c = torch.zeros((x.shape[0], self.c_dim), requires_grad=True, device=x.device)
            opt = torch.optim.Adam([c], lr=self.step_size)
            scheduler = CosineAnnealingWarmRestarts(opt, T_0=self.T_0)
            for _ in range(self.steps):
                opt.zero_grad()
                loss, penalty = self._inner_objective(x, y, c)
                total_loss = loss + penalty
                # Accumulate only into c: a plain backward() would also write
                # into .grad of the model parameters (and of x / y when they
                # require grad), corrupting the caller's outer gradients.
                total_loss.backward(inputs=[c])
                opt.step()
                scheduler.step()

        if total_loss is None:
            # n_iter == 0: no step ran, so report the objective at the initial c.
            with torch.no_grad():
                loss, penalty = self._inner_objective(x, y, c)
                total_loss = loss + penalty

        # Detach diagnostics so they do not keep the inner-loop graph alive.
        info = {
            'inner_tot_loss': total_loss.detach(),
            'inner_ctx_norm': c.detach().norm(),
            'inner_penalty_loss': penalty.detach(),
            'inner_ctx_loss': loss.detach(),
            'inner_penalty_weight': penalty_weight,
        }
        return c.detach(), info

    @staticmethod
    def _per_sample_hessian(F, c):
        """
        Return the per-sample Hessians of ``F`` at ``c`` as ``[batch, c_dim, c_dim]``.

        The gradient is summed over the batch before differentiating again.
        The result equals the diagonal Hessian blocks only if samples do not
        interact through ``F``. NOTE(review): presumably ``model`` treats batch
        rows independently (no batch-norm style coupling); confirm.
        """
        def grad_sum(c_):
            return grad(F(c_), c_, create_graph=True)[0].sum(0)

        # jacobian -> [c_dim, batch, c_dim]; move the batch axis to the front.
        return jacobian(grad_sum, c, vectorize=True).swapaxes(0, 1)

    def forward(self, x, y):
        """
        Return the inner solution with implicit-function-theorem gradients.

        :param x: model input, passed as ``model(us=x)``
        :param y: target the model output is fitted to
        :return: ``(z, info)``; ``z`` ([batch, c_dim]) has the value
            ``c* - ∇_c F(c*)``, and ``info`` is the diagnostics dict of ``argminF``
        """
        def F(c):
            return thF.mse_loss(self.f(x, c), y)

        # c* itself carries no graph; its gradients are supplied implicitly below.
        with torch.no_grad():
            z_star, info = self.argminF(x, y)

        # z = c* - ∇F(c*) is ≈ c* in value, but its graph gives the one-step
        # derivative ∂z/∂θ = -∂(∇F)/∂θ evaluated at c*.
        DzF = jacobian(F, z_star, vectorize=True, create_graph=True)
        z = z_star - DzF

        # Set up the correction whenever backward can reach z (this covers model
        # parameters, not only x). register_hook raises on a tensor that does
        # not require grad, and the Hessian is only needed for the hook.
        if z.requires_grad:
            # Evaluate the Hessian at the optimum c*, matching the one-step term.
            HzF = self._per_sample_hessian(F, z_star)

            def modify_grad(g_out):
                # H is symmetric, so H^{-T} g = H^{-1} g. This turns the one-step
                # gradient into the implicit one: -(H^{-1} g)^T ∂(∇F)/∂θ.
                return torch.linalg.solve(HzF, g_out.unsqueeze(dim=-1)).squeeze(dim=-1)

            z.register_hook(modify_grad)

        return z, info
