from functools import reduce
import math

import torch
from torch.optim.optimizer import Optimizer


class CustomAdam(Optimizer):
    """
    Configurable Adam variant for parameters followed by a normalization
    layer (and therefore invariant to their scale).

    Each parameter tensor is split into "channels" (the dims listed in
    `channel_dims`) and "groups" (all remaining dims). A channel whose
    group holds more than one element is treated as a point on a sphere
    (the "spheric case"); the optional flags below then adapt the Adam
    moments to that geometry. When every dim is a channel dim (the
    default), groups have size one, none of the flags has any effect and
    the update is the classic element-wise Adam update.

    Because not every parameter is followed by a normalization layer,
    PyTorch parameter groups should be used to enable the options only
    where they apply.

    Example:
        >>> par_groups = [{'params': model.conv_params(), 'channel_wise': True},
        >>>               {'params': model.other_params()}]
        >>> optimizer = CustomAdam(par_groups, lr=0.01, betas=(0.9, 0.9999))
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    Arguments:
        params (list of dict or iterator): either an iterator of parameters
            or a list of parameter groups, i.e. dicts with a 'params' key
            plus any optimizer option to override for that group,
        lr (float): learning rate,
        betas (tuple(float)): decay factors of the first and second moment,
        eps (float): small constant added to denominators for numerical
            stability,
        weight_decay (float): weight decay coefficient. Added to the
            gradient (L2 regularization) unless `decouple` applies,
        channel_dims (list of int): indices of the dims that are normalized
            independently. Default (None) resolves to every dim, which
            corresponds to classic Adam, unless `channel_wise` is set,
        channel_wise (bool): if True and channel_dims is None, use [0],
            the usual output-channel dim of a conv weight,
        standardize (bool): spheric case only. Use the squared gradient
            norm of each channel for the second moment instead of the
            element-wise squared gradient; unless `constrain` is set the
            step size is multiplied by sqrt(group size),
        transport (bool): spheric case only. After the step, transport
            the first moment from the previous to the new direction of
            each channel,
        rescale (bool): spheric case only. After the step, rescale the
            first moment by (previous norm / new norm) and the second
            moment by the square of that ratio,
        decouple (bool): spheric case only. Apply weight decay directly on
            the weights (multiplying them by 1 - lr * weight_decay)
            instead of adding it to the gradient,
        constrain (bool): spheric case only. After the step, divide each
            channel by its norm.
    """
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        channel_dims=None,  # For customize the dimension for group of param
        channel_wise=False,  # For default conv followed by BN invariance
        standardize=False,
        transport=False,
        rescale=False,
        decouple=False,
        constrain=False
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            channel_dims=channel_dims,
            channel_wise=channel_wise,
            standardize=standardize,
            transport=transport,
            rescale=rescale,
            decouple=decouple,
            constrain=constrain,
        )
        super(CustomAdam, self).__init__(params, defaults)

    def step(self, closure=None):
        """
        Performs a single optimization step.

        Arguments:
            closure (callable, optional): re-evaluates the model and
                returns the loss. It is called once, with gradients
                enabled, before the parameters are updated.

        Returns:
            The value returned by `closure`, or None when no closure is
            given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            eps = group['eps']
            weight_decay = group['weight_decay']
            b1, b2 = group['betas']

            for p in group['params']:

                # Get grad of params to update
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients'
                    )

                # Get state
                state = self.state[p]

                # State initialization if needed
                if not state:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    # Resolve which dims are channels for this parameter
                    shape = p.data.shape
                    channel_dims = group['channel_dims']
                    if channel_dims is None:
                        if group['channel_wise']:
                            # Classic meaning of channels
                            channel_dims = [0]
                        else:
                            # Element wise: every element is a channel,
                            # which corresponds to the classic Adam update
                            channel_dims = list(range(len(shape)))
                    state['channel_dims'] = channel_dims
                    state['shape'] = shape

                # Start by incrementing the step counter
                state['step'] += 1

                # Per-channel scalar product over the group dims. It is
                # rebuilt at every step rather than stored in the state so
                # that the state holds no function object.
                dot_ope = self.get_dot_operator(
                    state['channel_dims'], state['shape']
                )

                # A group of more than one element (not element-wise)
                # means we are in the spheric case
                spheric_case = (dot_ope.dim > 1)
                # Transport, rescale and constrain all need the weights
                # from before the step
                step_manip = (
                    spheric_case and
                    (group['rescale'] or group['transport'] or group['constrain'])
                )

                if weight_decay != 0.:
                    if spheric_case and group['decouple']:
                        # Apply WD decoupled from the rest of the step
                        p.data.mul_(1 - lr * weight_decay)
                    else:
                        # Out-of-place on purpose: the L2 term must not be
                        # written back into p.grad, which belongs to the
                        # caller.
                        grad = grad.add(p.data, alpha=weight_decay)

                # Retrieve moments
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Update moments
                if spheric_case and group['standardize']:
                    # One value per channel, broadcast over the group dims
                    grad_sq = dot_ope(grad, grad)
                else:
                    grad_sq = grad * grad

                exp_avg.mul_(b1).add_(grad, alpha=1 - b1)
                exp_avg_sq.mul_(b2).add_(grad_sq, alpha=1 - b2)

                # Compute bias correction
                bias_correction1 = 1 - b1 ** state['step']
                bias_correction2 = 1 - b2 ** state['step']

                # Compute actual nom and denom for the step
                nom = exp_avg / bias_correction1
                denom = (
                    exp_avg_sq.sqrt() / math.sqrt(bias_correction2)
                ).add_(eps)

                # Keep a copy of the weights as they are before the step
                if step_manip:
                    prev_data = p.data.clone().detach()

                # Take the step on the weights
                step_size = lr
                if spheric_case and group['standardize'] and not group['constrain']:
                    step_size *= math.sqrt(dot_ope.dim)
                p.data.addcdiv_(nom, denom, value=-step_size)

                # We are on a sphere: adapt the moments and the weights
                if step_manip:
                    # p.data is only read until the final normalization,
                    # so no copy of the new weights is needed
                    new_data = p.data

                    prev_norm_sq = dot_ope(prev_data, prev_data)
                    prev_norm = torch.sqrt(prev_norm_sq)

                    new_norm_sq = dot_ope(new_data, new_data)
                    new_norm = torch.sqrt(new_norm_sq)

                    if group['transport']:
                        # With u1, u2 the previous and new unit directions
                        # and m the first moment:
                        #   T(m) = T(m - <m,u1>u1)
                        #        = (u1 ^ m) ^ u2
                        #        = <u1,u2> m - <m,u2> u1
                        # which is orthogonal to u2 by construction.
                        prev_unit = prev_data / (prev_norm + eps)
                        new_unit = new_data / (new_norm + eps)
                        scal_u1_u2 = dot_ope(prev_unit, new_unit)
                        # Computed before exp_avg is modified in place
                        scal_m_u2 = dot_ope(exp_avg, new_unit)
                        (
                            exp_avg
                            .mul_(scal_u1_u2)
                            .sub_(scal_m_u2 * prev_unit)
                        )

                    if group['rescale']:
                        # m <- (r_prev / r_new) m
                        # v <- (r_prev / r_new)^2 v
                        (
                            exp_avg
                            .mul_(prev_norm)
                            .div_(new_norm + eps)
                        )
                        (
                            exp_avg_sq
                            .mul_(prev_norm_sq)
                            .div_(new_norm_sq + eps)
                        )

                    if group['constrain']:
                        # Normalize new weights
                        p.data.div_(new_norm + eps)

        return loss

    @staticmethod
    def get_dot_operator(channel_dims, shape):
        """
        Build the per-channel scalar product used by `step`.

        Arguments:
            channel_dims (sequence of int): dims of `shape` that index the
                channels. Negative indices are accepted.
            shape (sequence of int): shape of the tensors the operator
                will be applied to.

        Returns:
            A function `f(tensor1, tensor2)` that multiplies its arguments
            element-wise and sums the product over every dim that is not a
            channel dim. Summed dims are kept with size 1, so the result
            broadcasts against tensors of shape `shape`. The function has
            a `dim` attribute holding the number of elements summed per
            channel.

        Raises:
            IndexError: if a channel dim is out of range for `shape`.
        """
        ndim = len(shape)

        # Normalize negative indices and reject out-of-range ones
        channel_set = set()
        for d in channel_dims:
            if not -ndim <= d < ndim:
                raise IndexError(
                    "channel dim {} is out of range for shape {}".format(
                        d, tuple(shape)
                    )
                )
            channel_set.add(d % ndim)

        # Remaining dims hold the group of elements of each channel
        grp_dims = tuple(d for d in range(ndim) if d not in channel_set)
        grp_size = reduce(
            lambda x, y: x * y, (shape[d] for d in grp_dims), 1
        )

        def scalar_product(tensor1, tensor2):
            prod = tensor1 * tensor2
            if not grp_dims:
                # Every dim is a channel: nothing to sum. Handled apart
                # because torch.sum with an empty dim list reduces over
                # all dims.
                return prod
            return prod.sum(dim=grp_dims, keepdim=True)

        scalar_product.dim = grp_size
        return scalar_product


class AdamA(CustomAdam):
    """
    CustomAdam preset that enables `standardize` only.

    `transport`, `rescale`, `decouple` and `constrain` are forced to
    False; every other argument is forwarded unchanged to CustomAdam.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, channel_dims=None, channel_wise=False):
        # Flags that define this variant
        variant = {
            'standardize': True,
            'transport': False,
            'rescale': False,
            'decouple': False,
            'constrain': False,
        }
        super().__init__(
            params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
            channel_dims=channel_dims, channel_wise=channel_wise, **variant
        )


class AdamAB(CustomAdam):
    """
    CustomAdam preset that enables `standardize` and `transport`.

    `rescale`, `decouple` and `constrain` are forced to False; every
    other argument is forwarded unchanged to CustomAdam.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, channel_dims=None, channel_wise=False):
        # Flags that define this variant
        variant = {
            'standardize': True,
            'transport': True,
            'rescale': False,
            'decouple': False,
            'constrain': False,
        }
        super().__init__(
            params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
            channel_dims=channel_dims, channel_wise=channel_wise, **variant
        )


class AdamABC(CustomAdam):
    """
    CustomAdam preset that enables `standardize`, `transport` and
    `rescale`.

    `decouple` and `constrain` are forced to False; every other argument
    is forwarded unchanged to CustomAdam.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, channel_dims=None, channel_wise=False):
        # Flags that define this variant
        variant = {
            'standardize': True,
            'transport': True,
            'rescale': True,
            'decouple': False,
            'constrain': False,
        }
        super().__init__(
            params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
            channel_dims=channel_dims, channel_wise=channel_wise, **variant
        )


class AdamW(CustomAdam):
    """
    CustomAdam preset that enables `decouple` only.

    `standardize`, `transport`, `rescale` and `constrain` are forced to
    False; every other argument is forwarded unchanged to CustomAdam.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, channel_dims=None, channel_wise=False):
        # Flags that define this variant
        variant = {
            'standardize': False,
            'transport': False,
            'rescale': False,
            'decouple': True,
            'constrain': False,
        }
        super().__init__(
            params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
            channel_dims=channel_dims, channel_wise=channel_wise, **variant
        )


class AdamG(CustomAdam):
    """
    CustomAdam preset that enables `standardize`, `transport` and
    `constrain`.

    `rescale` and `decouple` are forced to False; every other argument is
    forwarded unchanged to CustomAdam.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, channel_dims=None, channel_wise=False):
        # Flags that define this variant
        variant = {
            'standardize': True,
            'transport': True,
            'rescale': False,
            'decouple': False,
            'constrain': True,
        }
        super().__init__(
            params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
            channel_dims=channel_dims, channel_wise=channel_wise, **variant
        )
