import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from packaging import version
from PIL import Image


def act_str_to_act_cfg_dict(act: str, negative_slope: float = 0.1):
    """Translate an activation name into an activation config dictionary.

    Parameters
    ----------
    act
        Name of the activation, or None.
    negative_slope
        Slope stored in the config when the leaky-ReLU spelling is requested.

    Returns
    -------
    cfg
        ``{"type": "leakyrelu", "negative_slope": ...}`` for "leaky"/"leakyrelu",
        ``{"type": "identity"}`` for None, otherwise ``{"type": act}``.
    """
    if act in ("leaky", "leakyrelu"):
        return {"type": "leakyrelu", "negative_slope": negative_slope}
    if act is None:
        return {"type": "identity"}
    # Any other name is passed through untouched.
    return {"type": act}

def get_activation(act, inplace=False):
    """Build an activation layer from its name.

    Parameters
    ----------
    act
        Name of the activation. None yields a plain identity function and any
        non-string object is returned unchanged. Both 'leaky' and 'leakyrelu'
        select the leaky ReLU, matching the spellings accepted by
        `act_str_to_act_cfg_dict`.
    inplace
        Whether to perform inplace activation. Only forwarded to the
        leaky ReLU and ELU layers.

    Returns
    -------
    activation_layer
        The activation

    Raises
    ------
    NotImplementedError
        If `act` is a string that is not one of the supported names.
    """
    if act is None:
        return lambda x: x
    if not isinstance(act, str):
        # Already a callable/module supplied by the caller.
        return act

    # Factories are lazy so only the requested layer is instantiated.
    factories = {
        'leaky': lambda: nn.LeakyReLU(0.1, inplace=inplace),
        'leakyrelu': lambda: nn.LeakyReLU(0.1, inplace=inplace),
        'identity': nn.Identity,
        'elu': lambda: nn.ELU(inplace=inplace),
        'gelu': nn.GELU,
        # NOTE(review): `inplace` is not forwarded to ReLU; kept as in the
        # previous behavior so existing callers are unaffected.
        'relu': nn.ReLU,
        'sigmoid': nn.Sigmoid,
        'tanh': nn.Tanh,
        'softrelu': nn.Softplus,
        'softplus': nn.Softplus,
        'softsign': nn.Softsign,
    }
    try:
        factory = factories[act]
    except KeyError:
        raise NotImplementedError('act="{}" is not supported. '
                                  'Try to include it if you can find that in '
                                  'https://pytorch.org/docs/stable/nn.html'.format(act)) from None
    return factory()

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Proposed in "[NeurIPS2019] Root Mean Square Layer Normalization".
    The input is divided by the RMS computed over its last dimension (or over
    a leading fraction of it) and multiplied by a learnable ``scale``.
    """

    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
        Parameters
        ----------
        d
            model size
        p
            partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
        eps
            epsilon value, default 1e-8
        bias
            whether use bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super().__init__()

        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        # Learnable gain; the optional shift is only created when requested.
        self.register_parameter("scale", nn.Parameter(torch.ones(d)))
        if self.bias:
            self.register_parameter("offset", nn.Parameter(torch.zeros(d)))

    def forward(self, x):
        """Normalize `x` along its last dimension by its (partial) RMS."""
        use_full_norm = self.p < 0. or self.p > 1.
        if not use_full_norm:
            # Partial RMSNorm: estimate the RMS from the first int(d * p) features.
            d_x = int(self.d * self.p)
            leading = x.split([d_x, self.d - d_x], dim=-1)[0]
            norm_x = leading.norm(2, dim=-1, keepdim=True)
        else:
            d_x = self.d
            norm_x = x.norm(2, dim=-1, keepdim=True)

        # RMS = L2 norm / sqrt(number of elements used).
        inv_sqrt_d = d_x ** (-1. / 2)
        rms_x = norm_x * inv_sqrt_d
        x_normed = x / (rms_x + self.eps)

        out = self.scale * x_normed
        if self.bias:
            out = out + self.offset
        return out

def get_norm_layer(normalization: str = 'layer_norm',
                   axis: int = -1,
                   epsilon: float = 1e-5,
                   in_channels: int = 0, **kwargs):
    """Create a normalization layer from its name.

    Parameters
    ----------
    normalization
        'layer_norm', 'rms_norm', or None (which yields an identity layer).
    axis
        The axis to normalize; only -1 is accepted.
    epsilon
        The epsilon of the normalization layer
    in_channels
        Size of the normalized (last) dimension. Must be positive for
        'layer_norm'.
    **kwargs
        Extra keyword arguments forwarded to the layer constructor.

    Returns
    -------
    norm_layer
        The normalization layer

    Raises
    ------
    NotImplementedError
        If `normalization` is an unknown name or is neither a str nor None.
    """
    if normalization is None:
        return nn.Identity()
    if not isinstance(normalization, str):
        raise NotImplementedError('The type of normalization must be str')

    if normalization == 'layer_norm':
        assert in_channels > 0
        assert axis == -1
        return nn.LayerNorm(normalized_shape=in_channels, eps=epsilon, **kwargs)
    if normalization == 'rms_norm':
        assert axis == -1
        return RMSNorm(d=in_channels, eps=epsilon, **kwargs)
    raise NotImplementedError('normalization={} is not supported'.format(normalization))


def _generalize_padding(x, pad_t, pad_h, pad_w, padding_type, t_pad_left=False):
    """Grow the T, H and W axes of a (B, T, H, W, C) tensor.

    Parameters
    ----------
    x
        Shape (B, T, H, W, C)
    pad_t
        Amount added to the T axis.
    pad_h
        Amount added to the H axis.
    pad_w
        Amount added to the W axis.
    padding_type
        One of 'zeros', 'ignore', 'nearest'. 'nearest' resizes with
        `F.interpolate`; the other two pad with `F.pad`.
    t_pad_left
        If True the temporal padding is placed before the first frame
        instead of after the last one (only used by the `F.pad` path).

    Returns
    -------
    out
        The result after padding the x. Shape will be (B, T + pad_t, H + pad_h, W + pad_w, C)
    """
    if all(amount == 0 for amount in (pad_t, pad_h, pad_w)):
        return x

    assert padding_type in ['zeros', 'ignore', 'nearest']
    _, T, H, W, _ = x.shape

    if padding_type == 'nearest':
        # interpolate expects channels-first, so move C to axis 1 and back.
        channels_first = x.permute(0, 4, 1, 2, 3)
        resized = F.interpolate(channels_first, size=(T + pad_t, H + pad_h, W + pad_w))
        return resized.permute(0, 2, 3, 4, 1)

    # F.pad lists (before, after) pairs starting from the last axis: C, W, H, T.
    t_pair = (pad_t, 0) if t_pad_left else (0, pad_t)
    return F.pad(x, (0, 0, 0, pad_w, 0, pad_h) + t_pair)


def _generalize_unpadding(x, pad_t, pad_h, pad_w, padding_type):
    """Shrink the T, H and W axes of a (B, T, H, W, C) tensor.

    Parameters
    ----------
    x
        Shape (B, T, H, W, C)
    pad_t
        Amount removed from the T axis.
    pad_h
        Amount removed from the H axis.
    pad_w
        Amount removed from the W axis.
    padding_type
        One of 'zeros', 'ignore', 'nearest'. 'nearest' resizes with
        `F.interpolate`; the other two crop the trailing entries.

    Returns
    -------
    out
        Shape (B, T - pad_t, H - pad_h, W - pad_w, C)
    """
    assert padding_type in ['zeros', 'ignore', 'nearest']
    _, T, H, W, _ = x.shape
    if all(amount == 0 for amount in (pad_t, pad_h, pad_w)):
        return x

    out_t, out_h, out_w = T - pad_t, H - pad_h, W - pad_w
    if padding_type == 'nearest':
        # interpolate expects channels-first, so move C to axis 1 and back.
        channels_first = x.permute(0, 4, 1, 2, 3)
        resized = F.interpolate(channels_first, size=(out_t, out_h, out_w))
        return resized.permute(0, 2, 3, 4, 1)

    # Keep only the leading entries along each of the T, H, W axes.
    return x[:, :out_t, :out_h, :out_w, :].contiguous()

def apply_initialization(m,
                         linear_mode="0",
                         conv_mode="0",
                         norm_mode="0"):
    """Initialize the parameters of a single module in place.

    Parameters
    ----------
    m
        The module to initialize. Linear, 2D/3D (transposed) convolution,
        LayerNorm, GroupNorm and Embedding modules are handled; any other
        module is left untouched.
    linear_mode
        "0": Kaiming normal with fan_in / linear gain.
        "1": Kaiming normal with fan_out / leaky_relu gain (a=0.1).
    conv_mode
        "0": Kaiming normal with fan_out / leaky_relu gain (a=0.1).
    norm_mode
        "0": weight set to ones and bias to zeros when the layer is affine.

    Raises
    ------
    NotImplementedError
        If the mode for the matching module type is not one listed above.
    """
    conv_types = (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)

    def _zero_bias_if_present(module):
        if getattr(module, 'bias', None) is not None:
            nn.init.zeros_(module.bias)

    def _reset_affine(module):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)

    if isinstance(m, nn.Linear):
        if linear_mode == "0":
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity="linear")
        elif linear_mode == "1":
            nn.init.kaiming_normal_(m.weight, a=0.1, mode='fan_out', nonlinearity="leaky_relu")
        else:
            raise NotImplementedError
        _zero_bias_if_present(m)
    elif isinstance(m, conv_types):
        if conv_mode != "0":
            raise NotImplementedError
        nn.init.kaiming_normal_(m.weight, a=0.1, mode='fan_out', nonlinearity="leaky_relu")
        _zero_bias_if_present(m)
    elif isinstance(m, nn.LayerNorm):
        if norm_mode != "0":
            raise NotImplementedError
        if m.elementwise_affine:
            _reset_affine(m)
    elif isinstance(m, nn.GroupNorm):
        if norm_mode != "0":
            raise NotImplementedError
        if m.affine:
            _reset_affine(m)
    elif isinstance(m, nn.Embedding):
        # Embedding tables get a truncated normal with a small std.
        nn.init.trunc_normal_(m.weight.data, std=0.02)

# Use torch's own SequentialLR when the installed torch is >= 1.10.0;
# otherwise define a local fallback class under the same name.
if version.parse(torch.__version__) >= version.parse('1.10.0'):
    from torch.optim.lr_scheduler import SequentialLR
else:
    from torch.optim.lr_scheduler import _LRScheduler
    from bisect import bisect_right

    class SequentialLR(_LRScheduler):
        """Receives the list of schedulers that is expected to be called sequentially during
        optimization process and milestone points that provides exact intervals to reflect
        which scheduler is supposed to be called at a given epoch.

        Args:
            schedulers (list): List of chained schedulers.
            milestones (list): List of integers that reflects milestone points.

        Example:
            >>> # Assuming optimizer uses lr = 1. for all groups
            >>> # lr = 0.1     if epoch == 0
            >>> # lr = 0.1     if epoch == 1
            >>> # lr = 0.9     if epoch == 2
            >>> # lr = 0.81    if epoch == 3
            >>> # lr = 0.729   if epoch == 4
            >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
            >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
            >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
            >>> for epoch in range(100):
            >>>     train(...)
            >>>     validate(...)
            >>>     scheduler.step()
        """

        # NOTE: `_LRScheduler.__init__` is not called here; only the four
        # attributes assigned below are set. `verbose` is accepted but unused.
        def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
            # Every wrapped scheduler must drive the same optimizer as the first one.
            for scheduler_idx in range(1, len(schedulers)):
                if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
                    raise ValueError(
                        "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
                        "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
                    )
            # N schedulers need exactly N - 1 switch-over points.
            if (len(milestones) != len(schedulers) - 1):
                raise ValueError(
                    "Sequential Schedulers expects number of schedulers provided to be one more "
                    "than the number of milestone points, but got number of schedulers {} and the "
                    "number of milestones to be equal to {}".format(len(schedulers), len(milestones))
                )
            self.optimizer = optimizer
            self._schedulers = schedulers
            self._milestones = milestones
            # With the default last_epoch=-1 the counter starts at 0.
            self.last_epoch = last_epoch + 1

        # Advance the epoch counter and step the scheduler whose interval
        # contains the new epoch.
        def step(self):
            self.last_epoch += 1
            # Number of milestones <= last_epoch, i.e. the index of the active scheduler.
            idx = bisect_right(self._milestones, self.last_epoch)
            if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
                # Landed exactly on a milestone: the newly active scheduler
                # is stepped with an explicit epoch of 0.
                self._schedulers[idx].step(0)
            else:
                self._schedulers[idx].step()

        def state_dict(self):
            """Returns the state of the scheduler as a :class:`dict`.

            It contains an entry for every variable in self.__dict__ which
            is not the optimizer.
            The wrapped scheduler states will also be saved.
            """
            state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
            # Wrapped schedulers are stored as their own state dicts, in order.
            state_dict['_schedulers'] = [None] * len(self._schedulers)

            for idx, s in enumerate(self._schedulers):
                state_dict['_schedulers'][idx] = s.state_dict()

            return state_dict

        def load_state_dict(self, state_dict):
            """Loads the schedulers state.

            Args:
                state_dict (dict): scheduler state. Should be an object returned
                    from a call to :meth:`state_dict`.
            """
            # Pull the nested states out so they do not overwrite the live
            # scheduler objects held in self._schedulers.
            _schedulers = state_dict.pop('_schedulers')
            self.__dict__.update(state_dict)
            # Restore state_dict keys in order to prevent side effects
            # https://github.com/pytorch/pytorch/issues/32756
            state_dict['_schedulers'] = _schedulers

            for idx, s in enumerate(_schedulers):
                self._schedulers[idx].load_state_dict(s)

def warmup_lambda(warmup_steps, min_lr_ratio=0.1):
    """Build a linear warm-up multiplier function.

    Parameters
    ----------
    warmup_steps
        Number of steps over which the multiplier grows linearly. A
        non-positive value disables warm-up (the multiplier is always 1.0).
    min_lr_ratio
        Multiplier returned at step 0.

    Returns
    -------
    ret_lambda
        Function mapping a step index to a multiplier: it rises linearly from
        `min_lr_ratio` at step 0 to 1.0 at `warmup_steps` and is 1.0 afterwards.
    """
    def ret_lambda(epoch):
        # Without this guard warmup_steps == 0 divides by zero at epoch 0.
        if warmup_steps <= 0:
            return 1.0
        if epoch <= warmup_steps:
            return min_lr_ratio + (1.0 - min_lr_ratio) * epoch / warmup_steps
        return 1.0
    return ret_lambda

def get_parameter_names(model, forbidden_layer_types):
    """
    Collect the dotted names of the parameters of `model`, leaving out every
    parameter that lives inside a child whose type is in `forbidden_layer_types`.

    Borrowed from https://github.com/huggingface/transformers/blob/623b4f7c63f60cce917677ee704d6c93ee960b4b/src/transformers/trainer_pt_utils.py#L996
    """
    names = []
    for child_name, child in model.named_children():
        nested = get_parameter_names(child, forbidden_layer_types)
        # A forbidden child contributes nothing, including its own sub-modules.
        if not nested or isinstance(child, tuple(forbidden_layer_types)):
            continue
        names.extend(f"{child_name}.{inner}" for inner in nested)
    # Parameters registered directly on this module are not part of any child.
    names.extend(model._parameters.keys())
    return names

def layout_to_in_out_slice(layout, in_len, out_len=None):
    """Build index lists that split a sequence along the "T" axis of `layout`.

    Parameters
    ----------
    layout
        Layout string with one character per axis; the position of "T"
        selects the axis that is sliced.
    in_len
        Number of leading steps selected by the input slice.
    out_len
        If None the output slice covers everything from `in_len` onward,
        otherwise it covers the last `out_len` steps.

    Returns
    -------
    in_slice, out_slice
        Two lists of `slice` objects, one entry per axis of `layout`.
    """
    t_axis = layout.find("T")
    full = slice(None, None)

    in_slice = [full for _ in layout]
    out_slice = [full for _ in layout]

    in_slice[t_axis] = slice(None, in_len)
    out_start = in_len if out_len is None else -out_len
    out_slice[t_axis] = slice(out_start, None)
    return in_slice, out_slice

def save_gif(single_seq, fname):
    """Write the images in `single_seq` to `fname` as one multi-frame file.

    Each image is cast to float32, multiplied by 255, loaded as a mode 'F'
    PIL image and converted to 8-bit grayscale ("L") before saving.
    """
    frames = []
    for frame in single_seq:
        scaled = frame.astype(np.float32) * 255
        frames.append(Image.fromarray(scaled, 'F').convert("L"))
    # The first frame owns the file; the remaining ones are appended to it.
    first_frame, remaining = frames[0], frames[1:]
    first_frame.save(fname, save_all=True, append_images=remaining)

def change_layout_np(data, in_layout='NHWT', out_layout='NHWT', ret_contiguous=False):
    """Convert a numpy array between supported layouts.

    Parameters
    ----------
    data
        Input array laid out as `in_layout`.
    in_layout
        'NHWT' (returned as is) or 'NTHWC' (channel 0 is selected and the
        time axis is moved to the end).
    out_layout
        Only 'NHWT' is supported.
    ret_contiguous
        If True, the result is returned as a C-contiguous array.

    Returns
    -------
    data
        The array in 'NHWT' layout.

    Raises
    ------
    NotImplementedError
        If `in_layout` or `out_layout` is not supported.
    """
    # first convert to 'NHWT'
    if in_layout == 'NHWT':
        pass
    elif in_layout == 'NTHWC':
        # Keep channel 0 only, then reorder (N, T, H, W) -> (N, H, W, T).
        data = np.transpose(data[:, :, :, :, 0], axes=(0, 2, 3, 1))
    else:
        raise NotImplementedError

    if out_layout != 'NHWT':
        raise NotImplementedError

    if ret_contiguous:
        # ndarray has no `ascontiguousarray` method; use the numpy function.
        data = np.ascontiguousarray(data)
    return data

def save_example_vis_results(
        save_dir, save_prefix,
        in_seq, target_seq,
        pred_seq, label,
        layout='NHWT', idx=0,
        plot_stride=1, fs=10, norm="none"):
    """Save a grid figure comparing inputs, targets and predictions.

    The figure has one row for the inputs, one for the targets and one per
    prediction; each column is one plotted time step. It is written to
    ``<save_dir>/<save_prefix>.png``.

    Parameters
    ----------
    save_dir
        Directory the figure is written to.
    save_prefix
        File name of the figure without the ".png" extension.
    in_seq: np.array
        float value 0-1
    target_seq: np.array
        float value 0-1
    pred_seq:   np.array
        float value 0-1. May also be a list of arrays, in which case `label`
        must be a list of the same length.
    label
        Row label of the prediction, or a list of labels (one per prediction).
    layout
        Layout of the given arrays; passed to `change_layout_np` as `in_layout`.
    idx
        Index of the sample in the batch that is plotted.
    plot_stride
        Plot every `plot_stride`-th time step.
    fs
        Font size of the labels.
    norm
        "none" plots the values unchanged, "to255" multiplies them by 255.

    Raises
    ------
    NotImplementedError
        If `norm` is not "none" or "to255".
    """
    import os
    from matplotlib import pyplot as plt

    in_seq = change_layout_np(in_seq, in_layout=layout).astype(np.float32)
    target_seq = change_layout_np(target_seq, in_layout=layout).astype(np.float32)
    if isinstance(pred_seq, list):
        pred_seq_list = [change_layout_np(ele, in_layout=layout).astype(np.float32)
                         for ele in pred_seq]
        assert isinstance(label, list) and len(label) == len(pred_seq)
        # Previously left undefined on this path, which raised NameError below.
        label_list = label
    else:
        pred_seq_list = [change_layout_np(pred_seq, in_layout=layout).astype(np.float32), ]
        label_list = [label, ]
    fig_path = os.path.join(save_dir, f'{save_prefix}.png')

    if norm == "none":
        norm = {'scale': 1.0,
                'shift': 0.0}
    elif norm == "to255":
        norm = {'scale': 255,
                'shift': 0}
    else:
        raise NotImplementedError

    in_len = in_seq.shape[-1]
    out_len = target_seq.shape[-1]
    max_len = max(in_len, out_len)
    ncols = (max_len - 1) // plot_stride + 1
    # squeeze=False keeps `ax` two-dimensional even when there is one column.
    fig, ax = plt.subplots(nrows=2 + len(pred_seq_list),
                           ncols=ncols,
                           figsize=(24, 8),
                           squeeze=False)

    def _draw_row(row_axes, frames, seq_len):
        # Draw every plot_stride-th frame of an (H, W, T) stack; blank the rest.
        for i in range(0, max_len, plot_stride):
            cell = row_axes[i // plot_stride]
            if i < seq_len:
                cell.imshow(frames[:, :, i], cmap='gray')
            else:
                cell.axis('off')

    def _rescale(seq):
        return seq[idx] * norm['scale'] + norm['shift']

    ax[0][0].set_ylabel('Inputs\n', fontsize=fs)
    _draw_row(ax[0], _rescale(in_seq), in_len)

    ax[1][0].set_ylabel('Target\n', fontsize=fs)
    _draw_row(ax[1], _rescale(target_seq), out_len)

    # Plot model predictions
    for k, (pred, pred_label) in enumerate(zip(pred_seq_list, label_list)):
        _draw_row(ax[2 + k], _rescale(pred), out_len)
        ax[2 + k][0].set_ylabel(pred_label, fontsize=fs)

    # Step captions go under the last row only.
    for i in range(0, max_len, plot_stride):
        if i < out_len:
            ax[-1][i // plot_stride].set_title(f"step {int(i + plot_stride)}", y=-0.25, fontsize=fs)

    for cell in ax.flat:
        cell.xaxis.set_ticks([])
        cell.yaxis.set_ticks([])

    plt.subplots_adjust(hspace=0.05, wspace=0.05)
    plt.savefig(fig_path)
    plt.close(fig)
