import torch
import torch.nn as nn
import torch.nn.functional as F


class iDSPN(nn.Module):
    """Predict a set by running an inner optimization loop on an objective.

    The set is obtained by minimizing ``objective`` with respect to an initial
    set; the loop and its implicit backward pass live in ``ImplicitSolver``.

    Args:
        objective: module used as the inner-loop loss. Its parameters are
            forwarded to ``ImplicitSolver`` so they can receive gradients.
        optim_f: optimizer factory, called by ``ImplicitSolver`` as
            ``optim_f([set])``.
        optim_iters: number of inner optimization steps.
        set_channels: feature size of each set element.
        set_size: number of elements per set.
        use_starting_set: if True, learn one shared initial set (an
            ``nn.Parameter`` initialized to ``0.1 * randn``). Otherwise fresh
            noise is sampled on every forward call.
        grad_clip: per-sample max gradient norm for the inner steps, or None
            to disable clipping.
        projection: optional callable applied to the set values during the
            inner loop (e.g. to keep features in a feasible region).
    """

    def __init__(self, objective, optim_f, optim_iters, set_channels, set_size, use_starting_set=False, grad_clip=None, projection=None):
        super().__init__()
        self.objective = objective
        self.optim_f = optim_f
        self.iters = optim_iters
        self.set_size = set_size
        self.set_channels = set_channels
        self.grad_clip = grad_clip
        self.projection = projection
        self.starting_set = (
            nn.Parameter(0.1 * torch.randn(1, self.set_size, self.set_channels))
            if use_starting_set
            else None
        )

    # The inner loop differentiates the objective, so gradients must be on
    # even when the caller is inside torch.no_grad() (e.g. during evaluation).
    @torch.enable_grad()
    def forward(self, z, set=None):
        """Optimize a set towards the target representation ``z``.

        Args:
            z: target representation; ``z.size(0)`` is used as the batch size.
            set: optional initial set of shape
                ``(batch, set_size, set_channels)``. If omitted, the learned
                starting set (broadcast over the batch) or fresh noise is used.

        Returns:
            Whatever ``ImplicitSolver.apply`` returns: the optimized set and
            the objective's gradient at the last step.
        """
        if set is None:
            batch_size = z.size(0)
            if self.starting_set is None:
                set = 0.1 * torch.randn(batch_size, self.set_size, self.set_channels, device=z.device)
            else:
                set = self.starting_set.expand(batch_size, -1, -1)
        objective_params = self.objective.parameters()
        return ImplicitSolver.apply(
            self.optim_f, self.iters, self.objective, self.grad_clip, self.projection,
            z, set, *objective_params,
        )


class Objective(nn.Module):
    """Base class for the inner-loop loss minimized by iDSPN.

    Subclasses implement ``forward(target_repr, set, reference_set=None)``
    and return a scalar loss. The loss must be a scalar because
    ``ImplicitSolver`` differentiates it with ``torch.autograd.grad`` without
    ``grad_outputs``.
    """

    def forward(self, target_repr, set, reference_set=None):
        raise NotImplementedError()


class MSEObjective(Objective):
    """Half squared error between ``encoder(set)`` and a target representation.

    Args:
        encoder: callable mapping a set to a representation with the same
            shape as ``target_repr``.
        regularized: if True, add ``0.1 *`` a half squared-error term pulling
            ``set`` towards ``reference_set``.
    """

    def __init__(self, encoder, regularized=False):
        super().__init__()
        self.encoder = encoder
        self.regularized = regularized

    def forward(self, target_repr, set, reference_set=None):
        """Return the scalar objective value for ``set``.

        Both terms are summed over dim 0 and averaged over the remaining
        elements. If dim 0 is the batch axis (iDSPN builds batch-first sets;
        TODO confirm for the encoder output), each sample's gradient does not
        shrink as the batch grows.

        Raises:
            ValueError: if ``regularized`` is True but ``reference_set`` is None.
        """
        # Fail fast with a clear message instead of an opaque error from F.mse_loss.
        if self.regularized and reference_set is None:
            raise ValueError("MSEObjective(regularized=True) requires a reference_set")

        # compute representation of current set
        predicted_repr = self.encoder(set)
        # how well does the representation match the target
        repr_loss = 0.5 * F.mse_loss(
            predicted_repr, target_repr, reduction='none'
        ).sum(dim=0).mean()

        if self.regularized:
            regularizer = 0.5 * F.mse_loss(set, reference_set, reduction='none').sum(dim=0).mean()
            repr_loss = repr_loss + 0.1 * regularizer

        return repr_loss


class ImplicitSolver(torch.autograd.Function):
    """Run an inner optimization over a set; backprop through its last step implicitly.

    forward minimizes ``objective_function`` w.r.t. a set with a user-supplied
    optimizer. backward differentiates the objective's gradient at the final
    iterate w.r.t. ``target_repr``, ``starting_set`` and ``*params``. The
    inverse Hessian is approximated by the identity (a conjugate-gradient
    alternative is left commented out).
    """

    @staticmethod
    def forward(ctx, optim_f, iters, objective_function, grad_clip, projection, target_repr, starting_set, *params):
        """Optimize the set and keep the last-step gradient's graph for backward.

        Args:
            optim_f: factory called as ``optim_f([set])``; the result must
                provide ``step()`` and ``zero_grad()``.
            iters: number of optimizer steps. The loop runs ``iters - 1`` times,
                followed by one separate final step, so ``max(iters, 1)`` steps
                are taken in total.
            objective_function: called as
                ``objective_function(target_repr, set, starting_set)``. It must
                return a scalar, since it is differentiated without
                ``grad_outputs``.
            grad_clip: ``max_norm`` passed to ``clip_gradient``; None disables clipping.
            projection: optional callable applied to the set values after each
                step, or None.
            target_repr, starting_set, *params: tensors the objective depends on.
                Gradients for these are produced in backward.

        Returns:
            ``(set, set_grad)``: the set after the final optimizer step, and a
            copy of the objective's gradient at the pre-final-step iterate.
            With a projection, that gradient is replaced by
            ``set0 - projection(set0 - set_grad)``.
        """
        # make sure that all parameters passed are used in the computation graph
        # otherwise, you have to set allow_unused=True in the autograd.grad call in backwards

        # if regularization is used in the objective, assumes that the set to start with is the one to regularize with
        # this doesn't hold when dspn iters are split into multiple forwards and the intention is to regularize wrt idspn.starting_set
        set = starting_set
        if projection is not None:
            # NOTE(review): `set` is the very tensor object passed as starting_set, so this rebinds
            # starting_set.data in place. The projected values are therefore also what the objective
            # receives as its reference set and what is saved for backward, and the caller's tensor
            # object is mutated.
            set.data = projection(set.data)
        # Work on a fresh leaf so the optimizer's in-place updates do not touch the input.
        set = set.clone().detach().requires_grad_(True)

        optimizer = optim_f([set])
        with torch.enable_grad():
            # iterate n - 1 steps
            for i in range(iters - 1):
                loss = objective_function(target_repr, set, starting_set)
                # Plain gradient (no graph) is written into .grad for the optimizer to consume.
                set.grad, = torch.autograd.grad(loss, set)
                set.grad = clip_gradient(set.grad, max_norm=grad_clip)
                optimizer.step()
                optimizer.zero_grad()
                if projection is not None:
                    set.data = projection(set.data)
            # iterate last step
            # we don't want the optimizer to override our set with in-place modifications, so we do this one separately
            set0 = set.clone().detach().requires_grad_(True)
            loss = objective_function(target_repr, set0, starting_set)
            # create_graph=True keeps set_grad differentiable w.r.t. target_repr, starting_set
            # (through the objective's reference argument) and the objective's parameters.
            set_grad, = torch.autograd.grad(loss, set0, create_graph=True)
            # The optimizer gets a (possibly clipped) copy, so the graph-carrying set_grad is not modified.
            set.grad = clip_gradient(set_grad.clone(), max_norm=grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            if projection is not None:
                set.data = projection(set.data)
                set_grad = set0 - projection(set0 - set_grad)

        # set_grad is an intermediate with a graph attached; backward differentiates through it.
        ctx.save_for_backward(set0, set_grad, target_repr, starting_set, *params)
        return set, set_grad.clone()

    @staticmethod
    def backward(ctx, output_grad, set_grad_grad):
        """Return gradients for every forward input (None for the 5 non-tensor ones).

        Args:
            output_grad: incoming gradient w.r.t. the returned set.
            set_grad_grad: incoming gradient w.r.t. the returned set_grad.
        """
        set, set_grad, target_repr, starting_set, *params = ctx.saved_tensors

        # Implicit differentiation at a stationary point gives
        #   d(set*)/d(theta) = -H^{-1} d(grad_set L)/d(theta).
        # With H^{-1} approximated by the identity, the set's incoming gradient enters with a minus
        # sign. The returned set_grad depends on theta directly through set_grad, so set_grad_grad
        # enters with a plus sign.
        total_grad = set_grad_grad - output_grad

        inputs = [target_repr, starting_set] + params  # important to have the same order as given to forward
        # only need to differentiate wrt inputs that actually require grads
        # ([5:] skips optim_f, iters, objective_function, grad_clip, projection)
        inputs_to_differentiate = [(input, i) for i, (input, needs_grad) in enumerate(zip(inputs, ctx.needs_input_grad[5:])) if needs_grad]

        with torch.enable_grad():
            # u = conjugate_gradient(set_grad, total_grad, set)  # use this line instead for conjugate gradient approach
            u = total_grad
            # vector-Jacobian product u^T d(set_grad)/d(input) for each selected input
            # in certain cases with a starting set, the following line needs retain_graph=True added
            grads = torch.autograd.grad(set_grad, [x[0] for x in inputs_to_differentiate], u)

        # always need to return something for all inputs, so put the grads back in their corresponding position
        padded_grads = [None for _ in range(len(inputs))]
        for g, (_, i) in zip(grads, inputs_to_differentiate):
            padded_grads[i] = g

        return None, None, None, None, None, *padded_grads


def clip_gradient(grads, norm_type=2., max_norm=2.):
    """Rescale each sample's gradient so its norm does not exceed ``max_norm``.

    The norm is taken over every dimension except dim 0, so each entry along
    dim 0 is clipped independently. Samples whose norm is already below the
    limit keep a scale of 1. A ``1e-6`` is added to the norm to avoid division
    by zero.

    Args:
        grads: gradient tensor; dim 0 indexes samples.
        norm_type: order of the norm.
        max_norm: per-sample norm limit, or None to return ``grads`` unchanged
            (the same object).

    Returns:
        The rescaled gradient tensor.
    """
    if max_norm is not None:
        per_sample_dims = list(range(1, grads.ndim))
        norms = grads.detach().norm(norm_type, dim=per_sample_dims, keepdim=True)
        scale = torch.clamp(max_norm / (norms + 1e-6), min=0., max=1.)
        grads = grads * scale
    return grads


def hvp(set_grad, x, set, preconditioner=None):
    """Multiply ``x`` by the Jacobian of ``set_grad`` w.r.t. ``set``.

    If ``set_grad`` is the gradient of a loss built with ``create_graph=True``,
    this is a Hessian-vector product ``H x``. The graph is retained so the
    function can be called repeatedly (e.g. from ``conjugate_gradient``).

    Args:
        set_grad: tensor with a graph back to ``set``.
        x: vector with the shape of ``set_grad``.
        set: tensor to differentiate with respect to.
        preconditioner: if None, return ``H x``. Otherwise return the damped
            product ``H x / preconditioner + x``, i.e. ``(H / preconditioner + I) x``.
    """
    (product,) = torch.autograd.grad(set_grad, set, grad_outputs=x, retain_graph=True)
    if preconditioner is not None:
        product = product / preconditioner + x
    return product


def conjugate_gradient(in_grad, outer_grad, set, cg_iters=3, preconditioner=100):
    """Approximately solve ``A x = outer_grad`` with batched conjugate gradient.

    ``A v`` is ``hvp(in_grad, v, set, preconditioner)``, i.e.
    ``(H / preconditioner + I) v``, or plain ``H v`` when ``preconditioner``
    is None. ``H`` is the Jacobian of ``in_grad`` w.r.t. ``set``. The initial
    guess is ``outer_grad`` itself, and exactly ``cg_iters`` iterations run
    (no convergence test).

    All tensors are 3-D (the inner products use the einsum pattern
    ``'nsc'``). Each index along dim 0 gets its own step sizes.

    NOTE(review): both inner products are clamped to at least 1e-37 to avoid
    0/0. If ``A`` is not positive definite, a negative ``p^T A p`` is also
    clamped, which produces a very large step.
    """
    def batch_dot(a, b):
        return torch.einsum('nsc, nsc -> n', a, b).clamp(min=1e-37)

    def apply_op(v):
        return hvp(in_grad, v, set, preconditioner=preconditioner)

    solution = outer_grad.clone().detach()
    residual = outer_grad.clone().detach() - apply_op(solution)
    direction = residual.clone().detach()
    for _ in range(cg_iters):
        op_direction = apply_op(direction)
        residual_sq = batch_dot(residual, residual)
        step = (residual_sq / batch_dot(direction, op_direction))[:, None, None]
        solution = solution + step * direction
        next_residual = residual - step * op_direction
        momentum = (batch_dot(next_residual, next_residual) / residual_sq)[:, None, None]
        direction = next_residual + momentum * direction
        residual = next_residual.clone().detach()
    return solution


class ProjectSimplex(torch.autograd.Function):
    """Differentiable Euclidean projection onto the probability simplex along axis 2.

    Forward maps every slice ``x[b, s, :]`` to the closest point of
    ``{p : p >= 0, sum(p) = 1}`` via ``projection_unit_simplex``. Backward
    applies the projection's Jacobian ``diag(s) - s s^T / |s|`` (``s`` =
    support indicator) through ``projection_unit_simplex_jvp``. That matrix is
    symmetric, so the JVP also serves as the VJP.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.dim = 2  # simplex axis; fixed for this Function
        projected = projection_unit_simplex(x, dim=ctx.dim)
        ctx.save_for_backward(x, projected)
        return projected

    @staticmethod
    def backward(ctx, x_proj_grad):
        saved_input, saved_output = ctx.saved_tensors
        return projection_unit_simplex_jvp(saved_input, saved_output, x_proj_grad, ctx.dim)


def unsqueeze_like(x, target, match_dim):
    """Reshape ``x`` so it broadcasts against ``target`` along ``match_dim``.

    The result has ``target.ndim`` dimensions: size 1 everywhere except
    ``match_dim``, which holds all of ``x``'s elements (``reshape`` with -1).
    ``match_dim`` may be negative. It indexes a Python list, so out-of-range
    values raise ``IndexError``.
    """
    broadcast_shape = [1 for _ in range(target.ndim)]
    broadcast_shape[match_dim] = -1
    return x.reshape(broadcast_shape)


def batched_idx(idx, dim):
    """Build a 3-slot advanced index that gathers one entry per (b, j) of a 3-D tensor.

    ``idx`` is a 2-D integer tensor. For a 3-D tensor ``t``, ``t[tuple(result)]``
    gathers, in row-major order over ``(b, j)``:

    - axis 2: ``t[b, j, idx[b, j]]``
    - axis 1: ``t[b, idx[b, j], j]``

    Negative ``dim`` is interpreted relative to a 3-D tensor (``-1 -> 2``,
    ``-2 -> 1``).

    Args:
        idx: 2-D integer tensor of positions along ``dim``.
        dim: the axis that ``idx`` indexes into (1, 2, -2 or -1).

    Returns:
        A list of three 1-D index tensors, each of length ``idx.numel()``.

    Raises:
        ValueError: if ``dim`` does not refer to axis 1 or 2.
    """
    # The index has exactly three slots, so the target is 3-D; normalize negatives accordingly.
    # Without this, dim=-2 overwrote its own slot and produced [bid, sid, None].
    axis = dim + 3 if dim < 0 else dim
    if axis not in (1, 2):
        raise ValueError(f"batched_idx: dim must refer to axis 1 or 2 of a 3-D tensor, got {dim}")

    n_rows, n_cols = idx.size(0), idx.size(1)
    row_ids = torch.arange(n_rows, device=idx.device).repeat_interleave(n_cols)
    col_ids = torch.arange(n_cols, device=idx.device).repeat(n_rows)
    if axis == 2:
        return [row_ids, col_ids, idx.flatten()]
    return [row_ids, idx.flatten(), col_ids]


def projection_unit_simplex(x, dim):
    """Euclidean projection of every slice of ``x`` along ``dim`` onto the unit simplex.

    The target set is ``{p : p >= 0, sum(p) = 1}``. The function uses the
    sort-based rule:

    1. Sort descending to get ``u``.
    2. Find the number ``k`` of leading entries with
       ``u_k - (cumsum(u)_k - 1) / k > 0``.
    3. Compute the threshold ``tau = (cumsum(u)_k - 1) / k``.
    4. Return ``relu(x - tau)``.

    ``x`` must be 3-D and ``dim`` must address axis 1 or 2, because the
    threshold lookup uses ``batched_idx``.
    """
    radius = 1.0
    size = x.shape[dim]
    sorted_desc = torch.sort(x, dim=dim, descending=True).values
    shifted_cumsum = sorted_desc.cumsum(dim=dim) - radius
    ranks = unsqueeze_like(torch.arange(1, size + 1, device=x.device), shifted_cumsum, dim)
    active = (sorted_desc - shifted_cumsum / ranks) > 0
    # At least the first sorted entry is always active (u_1 - (u_1 - 1) / 1 = 1 > 0).
    n_active = active.count_nonzero(dim=dim)
    picked = shifted_cumsum[batched_idx(n_active - 1, dim=dim)].reshape(n_active.shape)
    tau = picked / n_active.to(x.dtype)
    return (x - tau.unsqueeze(dim)).relu()


def projection_unit_simplex_jvp(x, x_proj, x_proj_grad, dim):
    """Apply the unit-simplex projection's Jacobian to ``x_proj_grad``.

    With ``s`` the support indicator ``x_proj > 0`` along ``dim``, the result
    is ``s * v - mean_over_support(s * v) * s``, i.e.
    ``(diag(s) - s s^T / |s|) v``. The averaging divides by the support size,
    so each slice must have non-empty support (true for a valid projection,
    which sums to 1). ``x`` is accepted for interface symmetry but not used.
    """
    on_support = x_proj > 0
    support_size = torch.count_nonzero(on_support, dim=dim).unsqueeze(dim)
    mask = on_support.to(x_proj_grad.dtype)
    masked_grad = mask * x_proj_grad
    support_mean = masked_grad.sum(dim=dim, keepdim=True) / support_size
    return masked_grad - support_mean * mask


def clevr_project(x, dim=2):
    """Project the feature groups of ``x`` along ``dim`` onto their feasible sets.

    The feature axis is split into consecutive groups of sizes 3, 2, 8, 3, 2, 1
    (19 features in total; ``Tensor.split`` raises if the axis has another
    size). The first and last groups pass through unchanged. Each of the four
    middle groups is projected onto the probability simplex with
    ``ProjectSimplex``. Presumably the groups are continuous coordinates,
    one-hot CLEVR attributes and a final scalar; TODO confirm against the
    dataset encoding.

    Args:
        x: 3-D tensor of set features.
        dim: feature axis (default 2).

    Returns:
        Tensor of the same shape as ``x``.
    """
    def identity(t):
        return t

    groups = [
        (3, identity),
        (2, ProjectSimplex.apply),
        (8, ProjectSimplex.apply),
        (3, ProjectSimplex.apply),
        (2, ProjectSimplex.apply),
        (1, identity),
    ]
    sizes = [size for size, _ in groups]
    # ProjectSimplex always normalizes along axis 2. Move the feature axis
    # there (a no-op for the default dim=2) so that the split and the
    # projection act on the same axis, then move it back.
    aligned = x.movedim(dim, 2)
    chunks = aligned.split(sizes, dim=2)
    projected = torch.cat([fn(chunk) for chunk, (_, fn) in zip(chunks, groups)], dim=2)
    return projected.movedim(2, dim)