# adapted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!
from typing import Dict, List, Optional

import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat

from celldiff.util import instantiate_from_config

def create_activation(name):
    """
    Build the activation module selected by ``name``.
    :param name: one of "relu", "gelu", "prelu", "elu", or None; None yields
                 an identity module.
    :return: a newly constructed nn.Module.
    :raises NotImplementedError: if ``name`` is none of the above.
    """
    if name is None:
        return nn.Identity()

    known = (
        ("relu", nn.ReLU),
        ("gelu", nn.GELU),
        ("prelu", nn.PReLU),
        ("elu", nn.ELU),
    )
    for key, factory in known:
        if name == key:
            return factory()
    raise NotImplementedError(f"{name} is not implemented.")

def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """
    Build a beta (noise variance) schedule as a float64 numpy array.
    :param schedule: name of the schedule:
                     "linear"      - linear interpolation of sqrt(beta), squared;
                     "cosine"      - betas derived from a squared-cosine
                                     cumulative alpha curve, clamped to [0, 0.999];
                     "sqrt_linear" - plain linear interpolation of beta;
                     "sqrt"        - square root of a linear interpolation.
    :param n_timestep: number of betas to produce.
    :param linear_start: first value of the interpolation (unused for "cosine").
    :param linear_end: last value of the interpolation (unused for "cosine").
    :param cosine_s: offset added to the normalized timestep in the "cosine" schedule.
    :return: a numpy array of ``n_timestep`` betas.
    :raises ValueError: if ``schedule`` is not one of the names above.
    """
    if schedule == "linear":
        betas = (
                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
                torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clamp with torch directly: `betas` is a torch tensor here, and
        # np.clip on a tensor only works through NumPy's fallback dispatch.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()


def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """
    Pick the subset of DDPM timesteps visited by a DDIM sampler.
    :param ddim_discr_method: 'uniform' (constant stride of
                              num_ddpm_timesteps // num_ddim_timesteps) or
                              'quad' (squares of evenly spaced values up to
                              sqrt(0.8 * num_ddpm_timesteps)).
    :param num_ddim_timesteps: requested number of DDIM steps. With 'uniform'
                               the returned count can differ when the stride
                               does not divide num_ddpm_timesteps evenly.
    :param num_ddpm_timesteps: number of timesteps of the full process.
    :param verbose: if True, print the selected timesteps.
    :return: an integer numpy array of the selected timesteps, each shifted by one.
    :raises NotImplementedError: for an unknown discretization method.
    """
    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        base_steps = np.asarray([*range(0, num_ddpm_timesteps, stride)])
    elif ddim_discr_method == 'quad':
        upper = np.sqrt(num_ddpm_timesteps * .8)
        grid = np.linspace(0, upper, num_ddim_timesteps)
        base_steps = (grid ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # shift by one to get the final alpha values right (the ones from first
    # scale to data during sampling)
    shifted = base_steps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {shifted}')
    return shifted


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """
    Compute the per-step quantities used by a DDIM sampler.
    :param alphacums: numpy array of cumulative alpha products, indexed by timestep.
    :param ddim_timesteps: integer index array of the timesteps to sample at.
    :param eta: multiplier applied to the sigma formula (0 gives all-zero sigmas).
    :param verbose: if True, print the selected alphas and the resulting sigmas.
    :return: a tuple (sigmas, alphas, alphas_prev).
    """
    # alphas at the selected steps, and at the step preceding each of them
    # (the first entry falls back to alphacums[0])
    alphas = alphacums[ddim_timesteps]
    leading = [alphacums[0]]
    shifted = alphacums[ddim_timesteps[:-1]].tolist()
    alphas_prev = np.asarray(leading + shifted)

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    variance_term = (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
    sigmas = eta * np.sqrt(variance_term)
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a continuous alpha_bar(t) curve into a beta schedule.
    alpha_bar describes the cumulative product of (1-beta) for t in [0, 1];
    each beta is 1 - alpha_bar(t_next) / alpha_bar(t_current), capped at
    max_beta.
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a callable taking t in [0, 1] and returning the
                      cumulative product of (1-beta) up to that point.
    :param max_beta: upper bound on every beta; values below 1 avoid
                     singularities.
    :return: a numpy array with num_diffusion_timesteps entries.
    """
    total = num_diffusion_timesteps

    def beta_at(step):
        start = step / total
        end = (step + 1) / total
        return min(1 - alpha_bar(end) / alpha_bar(start), max_beta)

    return np.array([beta_at(step) for step in range(total)])


def extract_into_tensor(a, t, x_shape):
    """
    Gather entries of `a` at indices `t` and reshape them for broadcasting.
    :param a: tensor to read values from (gathered along its last dimension).
    :param t: index tensor; its first dimension is taken as the batch size.
    :param x_shape: shape of the tensor the result should broadcast against;
                    only its length is used.
    :return: the gathered values with shape (batch, 1, ..., 1), having
             len(x_shape) dimensions in total.
    """
    batch, *_ = t.shape
    picked = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *singleton_dims)


def checkpoint(func, inputs, params, flag):
    """
    Run `func` with optional gradient checkpointing: intermediate activations
    are not cached, which lowers memory use at the price of recomputing the
    forward pass during backward.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, call `func(*inputs)` directly without checkpointing.
    :return: whatever `func` returns.
    """
    if not flag:
        return func(*inputs)

    # CheckpointFunction receives inputs and params as one flat argument
    # list; the count of inputs tells it where to split them again.
    flat_args = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *flat_args)


class CheckpointFunction(torch.autograd.Function):
    """
    Autograd function implementing gradient checkpointing.

    The forward pass runs `run_function` under `torch.no_grad()` so no
    intermediate activations are stored; the backward pass re-runs it with
    gradients enabled and differentiates the recomputed graph.

    Call as `CheckpointFunction.apply(run_function, length, *args)` where the
    first `length` entries of `args` are the positional inputs of
    `run_function` and the remaining entries are parameters it depends on.
    Inputs may be None; such slots are passed through unchanged and receive
    a None gradient.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Keep None inputs in their original positions. Dropping them would
        # shift the remaining arguments of run_function and make the number
        # of returned gradients differ from the number of forward inputs.
        detached = [
            None if x is None else x.detach().requires_grad_(True)
            for x in ctx.input_tensors
        ]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [None if x is None else x.view_as(x) for x in detached]
            output_tensors = ctx.run_function(*shallow_copies)

        # Only real tensors can be differentiated against.
        tensor_inputs = [x for x in detached if x is not None]
        grads = torch.autograd.grad(
            output_tensors,
            tensor_inputs + ctx.input_params,
            output_grads,
            allow_unused=True,
        )

        # Put input gradients back into their original slots (None for None
        # inputs); whatever is left over belongs to the parameters.
        remaining = iter(grads)
        input_grads = tuple(None if x is None else next(remaining) for x in detached)
        param_grads = tuple(remaining)

        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # The two leading Nones are for run_function and length.
        return (None, None) + input_grads + param_grads


def sinusoidal_embedding(pos: torch.Tensor, dim: int, max_period: int) -> torch.Tensor:
    """
    Encode positions with cosine and sine features.
    :param pos: a 1-D Tensor of N positions.
    :param dim: the dimension of the output; when odd, one zero column is
                appended after the cos/sin halves.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor laid out as [cos | sin | optional zero pad].
    """
    half = dim // 2
    exponents = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponents).to(device=pos.device)
    angles = pos[:, None].float() * freqs[None]

    pieces = [torch.cos(angles), torch.sin(angles)]
    if dim % 2:
        # pad to the requested odd width with a zero column
        pieces.append(torch.zeros_like(angles[:, :1]))
    return torch.cat(pieces, dim=-1)


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create timestep embeddings, sinusoidal by default.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoidal encoding and simply
                        repeat each timestep value `dim` times.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)
    return sinusoidal_embedding(timesteps, dim, max_period)


def zero_module(module):
    """
    Set every parameter of `module` to zero in place.
    :param module: the nn.Module to modify.
    :return: the same module, for chaining.
    """
    for param in module.parameters():
        # work on a detached view so the update is not tracked by autograd
        raw = param.detach()
        raw.zero_()
    return module


def scale_module(module, scale):
    """
    Multiply every parameter of `module` by `scale` in place.
    :param module: the nn.Module to modify.
    :param scale: the factor applied to each parameter.
    :return: the same module, for chaining.
    """
    for param in module.parameters():
        # work on a detached view so the update is not tracked by autograd
        raw = param.detach()
        raw.mul_(scale)
    return module


def mean_flat(tensor):
    """
    Average a tensor over every dimension except the first (batch) one.
    :param tensor: the input Tensor.
    :return: the Tensor of means over dimensions 1..end.
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)


def normalization(channels):
    """
    Build the standard normalization layer: a GroupNorm32 with 32 groups.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    num_groups = 32
    return GroupNorm32(num_groups, channels)


# nn.SiLU only exists from PyTorch 1.7 on; this local version keeps
# PyTorch 1.5 supported.
class SiLU(nn.Module):
    """SiLU activation: each element is multiplied by its own sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


class GroupNorm32(nn.GroupNorm):
    """GroupNorm that computes in float32 and casts back to the input dtype."""

    def forward(self, x):
        input_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.type(input_dtype)

def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    :param dims: spatial dimensionality, 1, 2 or 3.
    :param args: positional arguments forwarded to the nn.ConvNd constructor.
    :param kwargs: keyword arguments forwarded to the nn.ConvNd constructor.
    :return: the constructed convolution module.
    :raises ValueError: if `dims` is not 1, 2 or 3.
    """
    conv_classes = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    for rank, conv_cls in enumerate(conv_classes, start=1):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module; all arguments are forwarded to nn.Linear.
    :return: the constructed nn.Linear.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    :param dims: spatial dimensionality, 1, 2 or 3.
    :param args: positional arguments forwarded to the nn.AvgPoolNd constructor.
    :param kwargs: keyword arguments forwarded to the nn.AvgPoolNd constructor.
    :return: the constructed pooling module.
    :raises ValueError: if `dims` is not 1, 2 or 3.
    """
    pool_classes = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)
    for rank, pool_cls in enumerate(pool_classes, start=1):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class HybridConditioner(nn.Module):
    """
    Pair of conditioners, one for concatenation-style conditioning and one
    for cross-attention-style conditioning, each built from its own config
    via `instantiate_from_config`.
    """

    def __init__(self, c_concat_config, c_crossattn_config):
        """
        :param c_concat_config: config passed to instantiate_from_config to
                                build the concat conditioner.
        :param c_crossattn_config: config passed to instantiate_from_config
                                   to build the cross-attention conditioner.
        """
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        """
        Apply each conditioner to its own input.
        :return: a dict with keys 'c_concat' and 'c_crossattn', each mapping
                 to a one-element list holding the conditioner output.
        """
        concat_out = self.concat_conditioner(c_concat)
        crossattn_out = self.crossattn_conditioner(c_crossattn)
        return dict(c_concat=[concat_out], c_crossattn=[crossattn_out])


def noise_like(shape, device, repeat=False):
    """
    Sample standard normal noise of the given shape.
    :param shape: the shape of the returned tensor.
    :param device: the device to sample on.
    :param repeat: if truthy, draw a single sample of shape (1, *shape[1:])
                   and repeat it along the first dimension, so every batch
                   entry holds the same noise.
    :return: a noise Tensor of shape `shape`.
    """
    if not repeat:
        return torch.randn(shape, device=device)

    single = torch.randn((1, *shape[1:]), device=device)
    keep_dims = (1,) * (len(shape) - 1)
    return single.repeat(shape[0], *keep_dims)

def select_pe_encoder(pe):
    """
    Map a positional-encoding name to its encoder class.
    :param pe: one of the aliases
               'sin' / 'sinu' / 'sinusoidal' -> Sinusoidal2dPE,
               'learnable' / 'bin'           -> Learnable2dPE,
               'naive' / 'mlp'               -> NaivePE,
               'lap' / 'graphlap' / 'lappe'  -> GraphLapPE.
    :return: the encoder class (not an instance).
    :raises NotImplementedError: for any other name.
    """
    if pe in ('sin', 'sinu', 'sinusoidal'):
        return Sinusoidal2dPE
    if pe in ('learnable', 'bin'):
        return Learnable2dPE
    if pe in ('naive', 'mlp'):
        return NaivePE
    if pe in ('lap', 'graphlap', 'lappe'):
        return GraphLapPE
    raise NotImplementedError(f'Unsupported positional encoding type: {pe}')

class Sinusoidal2dPE(nn.Module):
    """
    Fixed 2-D sinusoidal positional encoding, looked up from binned coordinates.

    The lookup table holds one row per (x_bin, y_bin) cell, stored at flat
    index ``x_bin * height + y_bin``. The first half of the channels encodes
    the y bin and the second half encodes the x bin, each as interleaved
    sin/cos features. For square grids this gives the same lookup results as
    indexing with ``x_bin * width + y_bin``; for non-square grids it keeps
    every index inside the table and distinct per cell.
    """

    def __init__(self, d_model, height=100, width=100):
        """
        :param d_model: dimension of the model; must be a multiple of 4
        :param height: number of bins along the y coordinate
        :param width: number of bins along the x coordinate
        :raises ValueError: if d_model is not a multiple of 4
        """
        super().__init__()
        if d_model % 4 != 0:
            raise ValueError("Cannot use 2d sin/cos positional encoding with a "
                             "dimension that is not a multiple of 4 (got dim={:d})".format(d_model))
        self.d_model = d_model
        self.height = height
        self.width = width
        self.pe_key = 'coord'
        # Learnable embedding returned when coordinates are flagged as missing.
        self.missing_pe = nn.Parameter(torch.randn(d_model) * 1e-2)

        # Table laid out as [channel, x_bin, y_bin].
        pe = torch.zeros(d_model, width, height)
        # Each coordinate uses half of d_model
        half = int(d_model / 2)
        div_term = torch.exp(torch.arange(0., half, 2) *
                             -(math.log(10000.0) / half))
        pos_x = torch.arange(0., width).unsqueeze(1)
        pos_y = torch.arange(0., height).unsqueeze(1)
        # First half: sin/cos of the y bin, constant along the x axis.
        pe[0:half:2, :, :] = torch.sin(pos_y * div_term).transpose(0, 1).unsqueeze(1).repeat(1, width, 1)
        pe[1:half:2, :, :] = torch.cos(pos_y * div_term).transpose(0, 1).unsqueeze(1).repeat(1, width, 1)
        # Second half: sin/cos of the x bin, constant along the y axis.
        pe[half::2, :, :] = torch.sin(pos_x * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, height)
        pe[half + 1::2, :, :] = torch.cos(pos_x * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, height)
        # Flattening gives one row per cell at index x_bin * height + y_bin.
        self.pe_enc = nn.Embedding.from_pretrained(pe.flatten(1).T)

    def forward(self, coordinates):
        """
        :param coordinates: Tensor whose columns 0 and 1 hold x and y.
                            Presumably normalized to [0, 1] (they are
                            stretched by 2% and scaled by the grid size) -
                            confirm against callers. A value of -1 in the
                            very first entry marks the whole batch as missing.
        :return: a [N x d_model] Tensor of positional encodings.
        """
        if coordinates[0][0] == -1:
            return self.missing_pe.unsqueeze(0).expand(coordinates.shape[0], -1)
        x = coordinates[:, 0]
        y = coordinates[:, 1]
        # Stretch slightly, scale to bin units, truncate, then clamp to the grid.
        x = ((x * 1.02 - 0.01) * self.width).long().clamp(min=0, max=self.width - 1)
        y = ((y * 1.02 - 0.01) * self.height).long().clamp(min=0, max=self.height - 1)
        # Must match the table layout built in __init__.
        pe_input = x * self.height + y
        return self.pe_enc(pe_input)

class Learnable2dPE(nn.Module):
    """
    Learnable 2-D positional encoding: coordinates are binned onto a
    width x height grid and each cell owns one trainable embedding row.
    """

    def __init__(self, d_model, height=100, width=100):
        """
        :param d_model: dimension of the model
        :param height: number of bins along the y coordinate
        :param width: number of bins along the x coordinate
        """
        super().__init__()
        # forward() needs the grid size for binning and clamping.
        self.height = height
        self.width = width
        self.pe_enc = nn.Embedding(height * width, d_model)
        # Learnable embedding returned when coordinates are flagged as missing.
        self.missing_pe = nn.Parameter(torch.randn(d_model) * 1e-2)
        self.pe_key = 'coord'

    def forward(self, coordinates):
        """
        :param coordinates: Tensor whose columns 0 and 1 hold x and y.
                            Presumably normalized to [0, 1] (they are
                            stretched by 2% and scaled by the grid size) -
                            confirm against callers. A value of -1 in the
                            very first entry marks the whole batch as missing.
        :return: a [N x d_model] Tensor of positional encodings.
        """
        if coordinates[0][0] == -1:
            return self.missing_pe.unsqueeze(0).expand(coordinates.shape[0], -1)
        x = coordinates[:, 0]
        y = coordinates[:, 1]
        # Stretch slightly, scale to bin units, truncate, then clamp to the
        # last valid bin (width - 1 / height - 1).
        x = ((x * 1.02 - 0.01) * self.width).long().clamp(min=0, max=self.width - 1)
        y = ((y * 1.02 - 0.01) * self.height).long().clamp(min=0, max=self.height - 1)
        # One row per cell; x * height + y stays below height * width for
        # any grid shape and equals x * width + y when the grid is square.
        pe_input = x * self.height + y
        return self.pe_enc(pe_input)

class NaivePE(nn.Module):
    """
    Positional encoding that feeds raw coordinates through a single
    Linear + PReLU layer.
    """

    def __init__(self, d_model, coord_dim = 2, height=None, width=None):
        """
        :param d_model: dimension of the model
        :param coord_dim: dimension of coordinates
        :param height: placeholder (unused)
        :param width: placeholder (unused)
        """
        super().__init__()
        projection = nn.Linear(coord_dim, d_model)
        activation = nn.PReLU()
        self.pe_enc = nn.Sequential(projection, activation)
        # Learnable embedding returned when coordinates are flagged as missing.
        self.missing_pe = nn.Parameter(torch.randn(d_model) * 1e-2)
        self.pe_key = 'coord'

    def forward(self, coordinates):
        """
        :param coordinates: [N x coord_dim] Tensor; a value of -1 in the very
                            first entry marks the whole batch as missing.
        :return: a [N x d_model] Tensor of positional encodings.
        """
        is_missing = coordinates[0][0] == -1
        if is_missing:
            batch = coordinates.shape[0]
            return self.missing_pe.unsqueeze(0).expand(batch, -1)
        return self.pe_enc(coordinates)

class GraphLapPE(nn.Module):
    """
    Positional encoding computed from k eigenvector components through a
    single Linear + PReLU layer. Every forward pass multiplies each of the k
    columns by a freshly drawn random sign (+1 or -1).
    """

    def __init__(self, d_model, k = 10, height=None, width=None):
        """
        :param d_model: dimension of the model
        :param k: top k (number of eigenvector components per row)
        :param height: placeholder (unused)
        :param width: placeholder (unused)
        """
        super().__init__()
        self.k = k
        projection = nn.Linear(k, d_model)
        activation = nn.PReLU()
        self.pe_enc = nn.Sequential(projection, activation)
        # Learnable embedding returned when the input is flagged as missing.
        self.missing_pe = nn.Parameter(torch.randn(d_model) * 1e-2)
        self.pe_key = 'eigvec'

    def forward(self, eigvec):
        """
        :param eigvec: [N x k] Tensor; a value of -1 in the very first entry
                       marks the whole batch as missing.
        :return: a [N x d_model] Tensor of positional encodings.
        """
        is_missing = eigvec[0][0] == -1
        if is_missing:
            batch = eigvec.shape[0]
            return self.missing_pe.unsqueeze(0).expand(batch, -1)

        # Draw 0/1 per column and map it to -1/+1; the same sign is applied
        # to a column across all rows.
        coin = torch.randint(0, 2, (self.k, ), dtype=torch.float, device=eigvec.device)
        signs = coin[None, :]*2-1
        flipped = eigvec * signs
        return self.pe_enc(flipped)


class MaskedEncoderConditioner(nn.Module):
    """Condition cell embeddings on the number of available features.

    The number of nonzero mask entries per row is encoded (as a raw count, a
    ratio, or a sinusoidal embedding), optionally passed through a 2-layer
    MLP, and then either added to the cell embeddings or concatenated with
    them and projected. If disabled, the original cell embeddings are
    returned unchanged.

    """

    def __init__(
        self,
        dim: int,
        mult: int = 4,
        use_ratio: bool = False,
        use_se: bool = False,
        use_semlp: bool = False,
        concat: bool = False,
        disable: bool = False,
    ):
        """
        :param dim: embedding dimension of the input and output.
        :param mult: hidden width multiplier of the MLP (hidden = dim * mult).
        :param use_ratio: encode the count divided by x.shape[1].
        :param use_se: encode the count with a sinusoidal embedding and add it
                       directly (no MLP).
        :param use_semlp: like use_se, but pass the sinusoidal embedding
                          through the MLP.
        :param concat: concatenate the encoding with x before the MLP instead
                       of adding the MLP output to x.
        :param disable: make forward() a no-op.
        """
        super().__init__()
        assert not (use_ratio and use_se), "Cannot set use_se and use_ratio together"
        assert not (use_se and use_semlp), "Cannot set use_se and use_semlp together"
        assert not (use_se and concat), "Cannot set use_se and concat together"
        self.dim = dim
        self.use_ratio = use_ratio
        self.use_se = use_se or use_semlp
        self.concat = concat
        self.disable = disable

        if disable:
            # no projection is needed when the conditioner is switched off
            return

        if use_se:
            # plain sinusoidal embedding is added as-is
            self.proj = nn.Identity()
            return

        # scalar count/ratio -> 1 input feature; sinusoidal (semlp) -> dim
        in_features = dim if self.use_se else 1
        if concat:
            in_features = in_features + dim
        hidden = dim * mult
        self.proj = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        :param x: cell embeddings (assumed [batch x dim] - TODO confirm).
        :param mask: per-row feature mask, or None to skip conditioning.
        :return: the conditioned embeddings, or x itself when disabled or
                 when mask is None.
        """
        if self.disable or mask is None:
            return x

        # Count the number of denoising input features
        count = mask.bool().sum(1, keepdim=True).float()

        if self.use_ratio:
            cond = count / x.shape[1]
        elif self.use_se:
            cond = sinusoidal_embedding(count.ravel(), dim=self.dim, max_period=x.shape[1] + 1)
        else:
            cond = count

        if self.concat:
            joined = torch.cat((x, cond), dim=-1)
            return self.proj(joined)
        return x + self.proj(cond)


class ConditionEncoderWrapper(nn.Module):
    """
    Wrap a module so it consumes inputs with a length dimension at axis 1:
    the input (plus optional context) is collapsed along axis 1, passed
    through the wrapped module, and axis 1 is restored afterwards.
    """

    def __init__(self, module):
        """
        :param module: the module applied to the collapsed input.
        """
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
        """
        :param x: input Tensor; axis 1 is squeezed when no context is given.
        :param context: optional Tensor concatenated with x along axis 1; the
                        concatenation is then summed over that axis.
        :return: the wrapped module's output with a singleton axis 1 inserted.
        """
        if context is None:
            collapsed = x.squeeze(1)
        else:
            stacked = torch.cat((x, context), dim=1)
            collapsed = stacked.sum(1)
        return self.module(collapsed).unsqueeze(1)
