
from functools import partial
import typing as tp

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from .blocks import (ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2,
                     SelfAttention1d, SkipBlock, expand_to_planes)
from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
from .dit import DiffusionTransformer
from .factory import create_pretransform_from_config
from .pretransforms import Pretransform
from ..inference.generation import generate_diffusion_cond
from .adp import UNetCFG1d, UNet1d


class DiffusionModel(nn.Module):
    """Abstract base for unconditional diffusion networks.

    Concrete subclasses override :meth:`forward`. Any constructor arguments
    are handed straight through to ``nn.Module``.
    """

    def __init__(self, *args, **kwargs):
        super(DiffusionModel, self).__init__(*args, **kwargs)

    def forward(self, x, t, **kwargs):
        """Run the network on ``x`` at diffusion time ``t``; subclasses must override."""
        raise NotImplementedError()


class DiffusionModelWrapper(nn.Module):
    """Container pairing an unconditional diffusion network with its metadata.

    Calls are forwarded verbatim to ``model``. The other attributes record
    the channel count, sample size/rate, minimum input length and an optional
    pretransform supplied by the caller.
    """

    def __init__(
        self,
        model: DiffusionModel,
        io_channels,
        sample_size,
        sample_rate,
        min_input_length,
        pretransform: tp.Optional[Pretransform] = None,
    ):
        super().__init__()

        # Plain metadata attributes.
        self.io_channels, self.sample_size = io_channels, sample_size
        self.sample_rate = sample_rate
        self.min_input_length = min_input_length

        # Submodules (registered in this order: model, then pretransform).
        self.model = model
        # Any falsy pretransform is normalised to None.
        self.pretransform = pretransform or None

    def forward(self, x, t, **kwargs):
        """Delegate directly to the wrapped network."""
        wrapped = self.model
        return wrapped(x, t, **kwargs)


class ConditionedDiffusionModel(nn.Module):
    """Abstract base for diffusion networks that accept conditioning.

    The ``supports_*`` flags advertise which conditioning routes a concrete
    subclass understands. Remaining constructor arguments go to ``nn.Module``.
    """

    def __init__(
        self,
        *args,
        supports_cross_attention: bool = False,
        supports_input_concat: bool = False,
        supports_global_cond: bool = False,
        supports_prepend_cond: bool = False,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        (
            self.supports_cross_attention,
            self.supports_input_concat,
            self.supports_global_cond,
            self.supports_prepend_cond,
        ) = (
            supports_cross_attention,
            supports_input_concat,
            supports_global_cond,
            supports_prepend_cond,
        )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        cross_attn_cond: torch.Tensor = None,
        cross_attn_mask: torch.Tensor = None,
        input_concat_cond: torch.Tensor = None,
        global_embed: torch.Tensor = None,
        prepend_cond: torch.Tensor = None,
        prepend_cond_mask: torch.Tensor = None,
        cfg_scale: float = 1.0,
        cfg_dropout_prob: float = 0.0,
        batch_cfg: bool = False,
        rescale_cfg: bool = False,
        **kwargs
    ):
        """Conditioned forward pass; concrete subclasses must override."""
        raise NotImplementedError()


class ConditionedDiffusionModelWrapper(nn.Module):
    """
    A diffusion model that takes in conditioning.

    Bundles a conditioned network with its conditioner, optional pretransform
    and the lists of conditioning ids that route each conditioner output to a
    particular conditioning input of the network (cross-attention, global,
    input-concat or prepend).
    """

    def __init__(
            self,
            model: ConditionedDiffusionModel,
            conditioner: MultiConditioner,
            io_channels,
            sample_rate,
            min_input_length: int,
            diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
            pretransform: tp.Optional[Pretransform] = None,
            cross_attn_cond_ids: tp.Optional[tp.List[str]] = None,
            global_cond_ids: tp.Optional[tp.List[str]] = None,
            input_concat_ids: tp.Optional[tp.List[str]] = None,
            prepend_cond_ids: tp.Optional[tp.List[str]] = None,
    ):
        """
        Args:
            model: the conditioned diffusion network.
            conditioner: produces the conditioning tensors keyed by id.
            io_channels: number of channels the network reads/writes.
            sample_rate: sample rate of the modelled audio.
            min_input_length: minimum input length accepted by the network.
            diffusion_objective: training/sampling objective name.
            pretransform: optional pretransform module.
            cross_attn_cond_ids / global_cond_ids / input_concat_ids /
            prepend_cond_ids: conditioning keys routed to each input. ``None``
                means "no ids" and yields a fresh empty list per instance
                (a shared mutable ``[]`` default would alias across instances).
        """
        super().__init__()

        self.model = model
        self.conditioner = conditioner
        self.io_channels = io_channels
        self.sample_rate = sample_rate
        self.diffusion_objective = diffusion_objective
        self.pretransform = pretransform
        self.cross_attn_cond_ids = [] if cross_attn_cond_ids is None else cross_attn_cond_ids
        self.global_cond_ids = [] if global_cond_ids is None else global_cond_ids
        self.input_concat_ids = [] if input_concat_ids is None else input_concat_ids
        self.prepend_cond_ids = [] if prepend_cond_ids is None else prepend_cond_ids
        self.min_input_length = min_input_length

    def get_conditioning_inputs(
        self,
        conditioning_tensors: tp.Dict[str, tp.Any],
        negative=False
    ):
        """Assemble network keyword arguments from conditioner outputs.

        Each value in ``conditioning_tensors`` is indexed as a pair whose first
        item is the conditioning tensor and whose second item is its mask.

        Args:
            conditioning_tensors: conditioner outputs keyed by conditioning id.
            negative: when True, return ``negative_``-prefixed keys. That dict
                carries no prepend conditioning.

        Returns:
            A dict of keyword arguments for the wrapped network. An entry is
            ``None`` when no ids are configured for its route.
        """
        cross_attention_input = None
        cross_attention_masks = None
        global_cond = None
        input_concat_cond = None
        prepend_cond = None
        prepend_cond_mask = None

        if len(self.cross_attn_cond_ids) > 0:
            # Concatenate all cross-attention inputs over the sequence dimension
            # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
            cross_attention_input = []
            cross_attention_masks = []

            for key in self.cross_attn_cond_ids:
                cross_attn_in, cross_attn_mask = conditioning_tensors[key]

                # Add sequence dimension if it's not there
                if len(cross_attn_in.shape) == 2:
                    cross_attn_in = cross_attn_in.unsqueeze(1)
                    cross_attn_mask = cross_attn_mask.unsqueeze(1)

                cross_attention_input.append(cross_attn_in)
                cross_attention_masks.append(cross_attn_mask)

            cross_attention_input = torch.cat(cross_attention_input, dim=1)
            cross_attention_masks = torch.cat(cross_attention_masks, dim=1)

        if len(self.global_cond_ids) > 0:
            # Concatenate all global conditioning inputs over the channel dimension
            # Assumes that the global conditioning inputs are of shape (batch, channels)
            global_conds = [conditioning_tensors[key][0] for key in self.global_cond_ids]

            # Concatenate over the channel dimension
            global_cond = torch.cat(global_conds, dim=-1)

            # Drop a singleton sequence axis if the inputs were 3-D.
            if len(global_cond.shape) == 3:
                global_cond = global_cond.squeeze(1)

        if len(self.input_concat_ids) > 0:
            # Concatenate all input concat conditioning inputs over the channel dimension
            # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
            input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)

        if len(self.prepend_cond_ids) > 0:
            # Concatenate all prepend conditioning inputs over the sequence dimension
            # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
            prepend_conds = []
            prepend_cond_masks = []

            for key in self.prepend_cond_ids:
                prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
                prepend_conds.append(prepend_cond_input)
                prepend_cond_masks.append(prepend_cond_mask)

            prepend_cond = torch.cat(prepend_conds, dim=1)
            prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)

        if negative:
            return {
                "negative_cross_attn_cond": cross_attention_input,
                "negative_cross_attn_mask": cross_attention_masks,
                "negative_global_cond": global_cond,
                "negative_input_concat_cond": input_concat_cond
            }
        else:
            return {
                "cross_attn_cond": cross_attention_input,
                "cross_attn_mask": cross_attention_masks,
                "global_cond": global_cond,
                "input_concat_cond": input_concat_cond,
                "prepend_cond": prepend_cond,
                "prepend_cond_mask": prepend_cond_mask
            }

    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
        """Run the network with conditioning assembled from ``cond``."""
        return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)

    def generate(self, *args, **kwargs):
        """Sample from the model via ``generate_diffusion_cond``."""
        return generate_diffusion_cond(self, *args, **kwargs)


class UNetCFG1DWrapper(ConditionedDiffusionModel):
    """Adapter exposing ``UNetCFG1d`` through the conditioned-model interface.

    Advertises cross-attention, global and input-concat conditioning. Maps
    them to ``embedding`` / ``features`` / ``channels_list`` respectively.
    """

    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)

        self.model = UNetCFG1d(*args, **kwargs)

        # Halve every parameter of the freshly-initialised network.
        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(
        self,
        x,
        t,
        cross_attn_cond=None,
        cross_attn_mask=None,
        input_concat_cond=None,
        global_cond=None,
        cfg_scale=1.0,
        cfg_dropout_prob: float = 0.0,
        batch_cfg: bool = False,
        rescale_cfg: bool = False,
        negative_cross_attn_cond=None,
        negative_cross_attn_mask=None,
        negative_global_cond=None,
        negative_input_concat_cond=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        **kwargs
    ):
        """Forward to ``UNetCFG1d``.

        ``negative_global_cond``, ``negative_input_concat_cond`` and the
        prepend arguments are accepted for interface compatibility but are
        not passed on.
        """
        channels_list = None
        # Compare against None: calling bool() on a multi-element tensor raises.
        if input_concat_cond is not None:
            channels_list = [input_concat_cond]

        outputs = self.model(
            x,
            t,
            embedding=cross_attn_cond,
            embedding_mask=cross_attn_mask,
            features=global_cond,
            channels_list=channels_list,
            embedding_scale=cfg_scale,
            embedding_mask_proba=cfg_dropout_prob,
            batch_cfg=batch_cfg,
            rescale_cfg=rescale_cfg,
            negative_embedding=negative_cross_attn_cond,
            negative_embedding_mask=negative_cross_attn_mask,
            **kwargs)

        return outputs


class UNet1DCondWrapper(ConditionedDiffusionModel):
    """Adapter exposing ``UNet1d`` through the conditioned-model interface.

    Supports global conditioning (as ``features``) and input-concat
    conditioning (as ``channels_list``). Cross-attention is not supported.
    """

    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)

        self.model = UNet1d(*args, **kwargs)

        # Halve every parameter of the freshly-initialised network.
        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(
        self,
        x,
        t,
        input_concat_cond=None,
        global_cond=None,
        cross_attn_cond=None,
        cross_attn_mask=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        cfg_scale=1.0,
        cfg_dropout_prob: float = 0.0,
        batch_cfg: bool = False,
        rescale_cfg: bool = False,
        negative_cross_attn_cond=None,
        negative_cross_attn_mask=None,
        negative_global_cond=None,
        negative_input_concat_cond=None,
        **kwargs
    ):
        """Forward to ``UNet1d``.

        The CFG, cross-attention, prepend and negative arguments are accepted
        for interface compatibility but are not passed on.
        """
        channels_list = None
        # Compare against None: calling bool() on a multi-element tensor raises.
        if input_concat_cond is not None:
            # Interpolate input_concat_cond to the same length as x
            if input_concat_cond.shape[2] != x.shape[2]:
                input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')

            channels_list = [input_concat_cond]

        outputs = self.model(
            x,
            t,
            features=global_cond,
            channels_list=channels_list,
            **kwargs)

        return outputs


class UNet1DUncondWrapper(DiffusionModel):
    """Unconditional diffusion model backed by ``UNet1d``.

    ``in_channels`` is passed to ``UNet1d`` and also recorded as
    ``io_channels``.
    """

    def __init__(
        self,
        in_channels,
        *args,
        **kwargs
    ):
        super().__init__()

        self.model = UNet1d(*args, in_channels=in_channels, **kwargs)
        self.io_channels = in_channels

        # Halve every parameter of the freshly-initialised network.
        with torch.no_grad():
            for weight in self.model.parameters():
                weight.mul_(0.5)

    def forward(self, x, t, **kwargs):
        """Delegate straight to the wrapped ``UNet1d``."""
        net = self.model
        return net(x, t, **kwargs)


class DAU1DCondWrapper(ConditionedDiffusionModel):
    """Conditioned-model adapter for ``DiffusionAttnUnet1D``.

    Only input-concat conditioning is used. It is fed to the network's
    ``cond`` argument; every other conditioning argument is accepted and
    ignored.
    """

    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(
            supports_input_concat=True,
            supports_cross_attention=False,
            supports_global_cond=False,
        )

        self.model = DiffusionAttnUnet1D(*args, **kwargs)

        # Halve every parameter of the freshly-initialised network.
        with torch.no_grad():
            for weight in self.model.parameters():
                weight.mul_(0.5)

    def forward(
        self,
        x,
        t,
        input_concat_cond=None,
        cross_attn_cond=None,
        cross_attn_mask=None,
        global_cond=None,
        cfg_scale=1.0,
        cfg_dropout_prob: float = 0.0,
        batch_cfg: bool = False,
        rescale_cfg: bool = False,
        negative_cross_attn_cond=None,
        negative_cross_attn_mask=None,
        negative_global_cond=None,
        negative_input_concat_cond=None,
        prepend_cond=None,
        **kwargs
    ):
        """Run the U-Net with ``input_concat_cond`` as its conditioning signal."""
        concat_signal = input_concat_cond
        return self.model(x, t, cond=concat_signal)


class DiffusionAttnUnet1D(nn.Module):
    """1-D U-Net with optional self-attention, built from nested ``SkipBlock``s.

    The network is assembled innermost-first. Each level ``i`` wraps the
    previously-built block in down/conv/attention layers and their mirrored
    up path. Levels with ``i >= depth - n_attn_layers`` get self-attention.
    """

    def __init__(
        self,
        io_channels=2,
        depth=14,
        n_attn_layers=6,
        channels=[128, 128, 256, 256] + [512] * 10,
        cond_dim=0,
        cond_noise_aug=False,
        kernel_size=5,
        learned_resample=False,
        strides=[2] * 13,
        conv_bias=True,
        use_snake=False
    ):
        """
        Args:
            io_channels: channels of the input/output signal.
            depth: number of U-Net levels; indexes ``channels``/``strides``.
            n_attn_layers: how many of the innermost levels use self-attention.
            channels: per-level channel widths (read, never mutated).
            cond_dim: channels of the optional concatenated conditioning.
            cond_noise_aug: if True, noise-augment the conditioning and append
                an embedding of the noise level.
            kernel_size, conv_bias, use_snake: forwarded to ``ResConvBlock``.
            learned_resample: use learned (``*_2``) resampling layers.
            strides: per-level resampling strides (read, never mutated).

        Raises:
            ValueError: if a stride > 2 is requested without learned resampling.
        """
        super().__init__()

        self.cond_noise_aug = cond_noise_aug
        self.io_channels = io_channels

        if self.cond_noise_aug:
            self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

        self.timestep_embed = FourierFeatures(1, 16)

        attn_layer = depth - n_attn_layers
        # Level 1 (outermost) has no resampling; prepend a placeholder stride.
        strides = [1] + strides
        conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias=conv_bias, use_snake=use_snake)
        block = nn.Identity()

        # Build from the innermost level outward; each iteration wraps `block`.
        for i in range(depth, 0, -1):
            c = channels[i - 1]
            stride = strides[i - 1]

            if stride > 2 and not learned_resample:
                raise ValueError("Must have stride 2 without learned resampling")

            if i > 1:
                c_prev = channels[i - 2]
                add_attn = i >= attn_layer and n_attn_layers > 0
                block = SkipBlock(
                    Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
                    conv_block(c_prev, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    block,
                    # The innermost level has no inner skip output to concatenate.
                    conv_block(c * 2 if i != depth else c, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c_prev),
                    SelfAttention1d(c_prev, c_prev // 32) if add_attn else nn.Identity(),
                    Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
                )
            else:
                # Input = signal + conditioning + timestep embedding (16 ch),
                # plus a second 16-ch noise-level embedding under noise aug.
                cond_embed_dim = 16 if not self.cond_noise_aug else 32
                block = nn.Sequential(
                    conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
                    conv_block(c, c, c),
                    conv_block(c, c, c),
                    block,
                    conv_block(c * 2, c, c),
                    conv_block(c, c, c),
                    conv_block(c, c, io_channels, is_last=True),
                )

        self.net = block

        # Halve every parameter of the freshly-initialised network.
        with torch.no_grad():
            for param in self.net.parameters():
                param *= 0.5

    def forward(self, x, t, cond=None, cond_aug_scale=None):
        """Run the U-Net.

        Args:
            x: input signal; indexed up to axis 2 and concatenated on axis 1
                (presumably (batch, channels, length) — TODO confirm).
            t: diffusion times, one per batch item (embedded via ``t[:, None]``).
            cond: optional conditioning, linearly resized along axis 2 to match
                ``x`` and concatenated on the channel axis.
            cond_aug_scale: fixed noise level for conditioning augmentation;
                when None a Sobol-sampled level per item is used.
        """
        timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
        inputs = [x, timestep_embed]

        # Compare against None: calling bool() on a multi-element tensor raises.
        if cond is not None:
            if cond.shape[2] != x.shape[2]:
                cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)

            if self.cond_noise_aug:
                # Get a random number between 0 and 1, uniformly sampled
                if cond_aug_scale is None:
                    aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
                else:
                    aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)

                # Add noise to the conditioning signal
                cond = cond + torch.randn_like(cond) * aug_level[:, None, None]

                # Get embedding for noise cond level, reusing timestamp_embed
                aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)

                inputs.append(aug_level_embed)

            inputs.append(cond)

        outputs = self.net(torch.cat(inputs, dim=1))

        return outputs


class DiTWrapper(ConditionedDiffusionModel):
    """Conditioned-model adapter for ``DiffusionTransformer``.

    Conditioning arguments are renamed where needed (``cross_attn_mask`` ->
    ``cross_attn_cond_mask``, ``global_cond`` -> ``global_embed``) and
    forwarded. Only batched CFG is supported.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            supports_global_cond=False,
            supports_input_concat=False,
            supports_cross_attention=True,
        )
        self.model = DiffusionTransformer(*args, **kwargs)

        # Halve every parameter of the freshly-initialised network.
        with torch.no_grad():
            for weight in self.model.parameters():
                weight.mul_(0.5)

    def forward(
        self,
        x,
        t,
        cross_attn_cond=None,
        cross_attn_mask=None,
        negative_cross_attn_cond=None,
        negative_cross_attn_mask=None,
        input_concat_cond=None,
        negative_input_concat_cond=None,
        global_cond=None,
        negative_global_cond=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        cfg_scale=1.0,
        cfg_dropout_prob: float = 0.0,
        batch_cfg: bool = True,
        rescale_cfg: bool = False,
        scale_phi: float = 0.0,
        **kwargs
    ):
        """Forward to the transformer; ``batch_cfg`` must be True."""
        assert batch_cfg, "batch_cfg must be True for DiTWrapper"

        forwarded = dict(
            cross_attn_cond=cross_attn_cond,
            cross_attn_cond_mask=cross_attn_mask,
            negative_cross_attn_cond=negative_cross_attn_cond,
            negative_cross_attn_mask=negative_cross_attn_mask,
            input_concat_cond=input_concat_cond,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            cfg_scale=cfg_scale,
            cfg_dropout_prob=cfg_dropout_prob,
            scale_phi=scale_phi,
            global_embed=global_cond,
        )
        return self.model(x, t, **forwarded, **kwargs)


class DiTUncondWrapper(DiffusionModel):
    """Unconditional diffusion model backed by ``DiffusionTransformer``.

    ``in_channels`` is passed to the transformer as ``io_channels`` and is
    also recorded on this wrapper.
    """

    def __init__(
        self,
        in_channels,
        *args,
        **kwargs
    ):
        super().__init__()

        self.model = DiffusionTransformer(*args, io_channels=in_channels, **kwargs)
        self.io_channels = in_channels

        # Halve every parameter of the freshly-initialised network.
        with torch.no_grad():
            for weight in self.model.parameters():
                weight.mul_(0.5)

    def forward(self, x, t, **kwargs):
        """Delegate straight to the wrapped transformer."""
        net = self.model
        return net(x, t, **kwargs)


def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
    """Build an unconditional diffusion model from a config dict.

    Args:
        config: top-level config. Reads ``config["model"]`` (which must have a
            ``"type"`` and may have ``"config"`` constructor kwargs and a
            ``"pretransform"`` config), plus ``config["sample_size"]`` and
            ``config["sample_rate"]``.

    Returns:
        A ``DiffusionModelWrapper`` around the constructed network.

    Raises:
        ValueError: if ``sample_size`` or ``sample_rate`` is missing.
        NotImplementedError: if the model type is not recognised.
    """
    diffusion_uncond_config = config["model"]
    model_type = diffusion_uncond_config['type']
    diffusion_config = diffusion_uncond_config.get('config', {})
    pretransform = diffusion_uncond_config.get("pretransform", None)

    # `dict.get` must be called, not subscripted.
    sample_size = config.get("sample_size")
    if sample_size is None:
        raise ValueError("Must specify sample_size in config")
    sample_rate = config.get("sample_rate")
    if sample_rate is None:
        raise ValueError("Must specify sample_rate in config")

    min_input_length = 1

    if pretransform:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)
        # Inputs must be at least one pretransform frame long.
        min_input_length = pretransform.downsampling_ratio

    if model_type == 'DAU1d':
        model = DiffusionAttnUnet1D(**diffusion_config)
    elif model_type == "adp_uncond_1d":
        model = UNet1DUncondWrapper(**diffusion_config)
    elif model_type == "dit":
        model = DiTUncondWrapper(**diffusion_config)
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return DiffusionModelWrapper(
        model,
        io_channels=model.io_channels,
        sample_size=sample_size,
        sample_rate=sample_rate,
        pretransform=pretransform,
        min_input_length=min_input_length
    )


def _select_cond_wrapper(model_type, model_config, diffusion_objective):
    """Return ``(wrapper_class, extra_kwargs)`` for a conditional ``model_type``.

    Raises:
        NotImplementedError: if ``model_type`` or, for ``"diffusion_prior"``,
            ``model_config["prior_type"]`` is not recognised.
    """
    if model_type in ("diffusion_cond", "diffusion_cond_inpaint"):
        return ConditionedDiffusionModelWrapper, {"diffusion_objective": diffusion_objective}

    if model_type == "diffusion_prior":
        prior_type = model_config["prior_type"]
        if prior_type == "mono_stereo":
            from .diffusion_prior import MonoToStereoDiffusionPrior
            return MonoToStereoDiffusionPrior, {}
        raise NotImplementedError(f"Unknown prior type: {prior_type}")

    raise NotImplementedError(f"Unknown conditional model type: {model_type}")


def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
    """Build a conditioned diffusion model (or diffusion prior) from a config.

    Args:
        config: top-level config. Reads ``config["model_type"]``,
            ``config["sample_rate"]`` and ``config["model"]``. The latter must
            contain ``"diffusion"`` (with ``"type"`` and ``"config"``) and
            ``"io_channels"``. It may also contain ``"conditioning"``,
            ``"pretransform"`` and, for priors, ``"prior_type"``.

    Returns:
        An instance of the wrapper class selected by ``model_type``.

    Raises:
        NotImplementedError: for an unknown diffusion network type, model
            type, or prior type.
    """
    model_config = config["model"]
    model_type = config["model_type"]
    diffusion_config = model_config['diffusion']
    diffusion_model_type = diffusion_config['type']
    diffusion_model_config = diffusion_config['config']
    diffusion_objective = diffusion_config.get('diffusion_objective', 'v')

    # Resolve the wrapper first so an unsupported model_type fails before any
    # (potentially large) network is constructed.
    wrapper_fn, extra_kwargs = _select_cond_wrapper(model_type, model_config, diffusion_objective)

    if diffusion_model_type == 'adp_cfg_1d':
        diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'adp_1d':
        diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'dit':
        diffusion_model = DiTWrapper(**diffusion_model_config)
    else:
        raise NotImplementedError(f'Unknown model type: {diffusion_model_type}')

    io_channels = model_config['io_channels']
    sample_rate = config['sample_rate']
    conditioning_config = model_config.get('conditioning', None)

    conditioner = None
    if conditioning_config:
        conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)

    cross_attn_cond_ids = diffusion_config.get('cross_attention_cond_ids', [])
    global_cond_ids = diffusion_config.get('global_cond_ids', [])
    input_concat_ids = diffusion_config.get('input_concat_ids', [])
    prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])

    pretransform = model_config.get("pretransform", None)

    if pretransform:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1

    # The network's own downsampling further constrains the minimum length.
    if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
        min_input_length *= np.prod(diffusion_model_config["factors"])
    elif diffusion_model_type == "dit":
        min_input_length *= diffusion_model.model.patch_size

    return wrapper_fn(
        diffusion_model,
        conditioner,
        min_input_length=min_input_length,
        sample_rate=sample_rate,
        cross_attn_cond_ids=cross_attn_cond_ids,
        global_cond_ids=global_cond_ids,
        input_concat_ids=input_concat_ids,
        prepend_cond_ids=prepend_cond_ids,
        pretransform=pretransform,
        io_channels=io_channels,
        **extra_kwargs
    )
