from copy import deepcopy
from typing import Tuple, List, Union, Optional

import torch
import torch.nn as nn
from beartype import beartype
from diffusers import UNet2DModel
from jaxtyping import jaxtyped, Float

from conf.model import BackboneParams
from src.Backbones.Unet_cold_diffusion_MultiTime import Unet_Cold_Multi_Domain
from utils.utils import freeze


def adapt_time_steps(t):
    """Collapse a 3-D time tensor to 2-D by merging its last two axes.

    A ``(b, ndom, dim)`` tensor becomes ``(b, ndom * dim)``. ``None`` and
    tensors of any other rank are returned unchanged (the same object).
    """
    if t is None or t.dim() != 3:
        return t
    n_batch, n_dom, width = t.shape
    return t.reshape(n_batch, n_dom * width)


class TimeEmbedding(nn.Module):
    """Timestep embedding borrowed from a diffusers ``UNet2DModel``.

    Applies the model's ``time_proj`` and then its ``time_embedding`` to every
    timestep independently, and returns one embedding vector per entry of the
    (possibly flattened, see ``adapt_time_steps``) time tensor.
    """

    def __init__(self, model: UNet2DModel):
        super().__init__()
        # Both sub-modules are shared with ``model`` (assigned, not copied).
        self.time_proj = model.time_proj
        self.time_embedding = model.time_embedding

    @property
    def dtype(self) -> torch.dtype:
        """dtype of the ``time_embedding`` weights, read at call time.

        Evaluated lazily so that casting this module after construction
        (e.g. ``.half()``) is honoured; a dtype captured once in ``__init__``
        would go stale and make ``forward`` feed ``time_embedding`` inputs of
        the wrong dtype.
        """
        param = next(self.time_embedding.parameters(), None)
        return torch.get_default_dtype() if param is None else param.dtype

    def forward(self, time):
        """Embed a batch of per-domain timesteps.

        Args:
            time: ``(batch, ndom)`` tensor, or ``(batch, ndom, k)`` which is
                first flattened to ``(batch, ndom * k)``.

        Returns:
            Tensor of shape ``(batch, ndom, emb_dim)``, where ``ndom`` is the
            second dimension after flattening.
        """
        t = adapt_time_steps(time)
        batch, ndom = t.shape
        t_emb = self.time_proj(t.flatten())

        # Per the diffusers comment, time_proj has no weights and returns f32
        # tensors, while time_embedding may run in reduced precision, so cast
        # to the embedding's current dtype.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)
        return emb.reshape(batch, ndom, -1)


class Encoder(nn.Module):
    """Down-sampling half of a diffusers ``UNet2DModel``.

    ``conv_in`` and ``down_blocks`` are shared with the given model (assigned,
    not copied).
    """

    def __init__(self, model):
        super().__init__()
        self.conv_in = model.conv_in
        self.down_blocks = model.down_blocks

    @jaxtyped
    @beartype
    def forward(
        self,
        x: Float[torch.Tensor, 'b c h w'],
        t: Float[torch.Tensor, 'b *n_dom dim'],
    ) -> Tuple[
        Float[torch.Tensor, 'b ci hi wi'],        # embedding
        List[Float[torch.Tensor, 'b cj hj wj']],  # residual + skip connections list
    ]:
        """Encode ``x`` and collect the residuals needed by the decoder.

        Args:
            x: input image batch.
            t: time embedding. For a 3-D tensor, axis 1 is squeezed, which only
                has an effect when that axis has size 1.

        Returns:
            The deepest feature map, and the list of residual tensors: the
            ``conv_in`` output followed by every residual from each down block.
        """
        temb = t.squeeze(1) if t.dim() == 3 else t

        # The raw input feeds blocks that expose ``skip_conv``.
        skip = x
        hidden = self.conv_in(x)

        residuals = [hidden]
        for block in self.down_blocks:
            if not hasattr(block, "skip_conv"):
                hidden, block_res = block(hidden_states=hidden, temb=temb)
            else:
                hidden, block_res, skip = block(
                    hidden_states=hidden, temb=temb, skip_sample=skip
                )
            residuals.extend(block_res)

        return hidden, residuals


class TimeEmbProjCelebA3D(nn.Module):
    """Linear projection of per-domain time embeddings flattened across domains.

    Used as a drop-in ``time_emb_proj`` when the embedding carries one vector
    per domain: ``(batch, ndom, dim)`` is flattened to ``(batch, ndom * dim)``
    and then projected to ``out_dim``.
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # in_dim has to equal ndom * dim of the tensors given to forward.
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        n_batch, n_dom, width = x.shape
        flat = x.reshape([n_batch, n_dom * width])
        return self.linear(flat)


class Middle(nn.Module):
    """Bottleneck (``mid_block``) of a diffusers ``UNet2DModel`` adapted to multi-domain time embeddings.

    The pretrained mid block (e.g. from CelebA) expects one time embedding per
    sample, whereas the multi-domain bottleneck receives one embedding per
    domain. Each resnet's ``time_emb_proj`` is therefore replaced by a
    ``TimeEmbProjCelebA3D``. That layer flattens ``(batch, n_domains, dim)``
    to ``(batch, n_domains * dim)`` and projects it to the original output
    width.

    NOTE: the replacement layers are freshly initialised, so the pretrained
    ``time_emb_proj`` weights of the mid block are discarded. ``mid_block`` is
    shared with ``model`` and is modified in place.
    """

    def __init__(self, model, n_domains: int = 3):
        """
        Args:
            model: source ``UNet2DModel``; its ``mid_block`` is reused.
            n_domains: number of per-domain embeddings concatenated in the
                bottleneck. The default of 3 matches the previous hard-coded
                value.
        """
        super().__init__()
        self.mid_block = model.mid_block

        for resnet in self.mid_block.resnets:
            old_proj = getattr(resnet, "time_emb_proj", None)
            if old_proj is None:
                # Resnet built without a time embedding: nothing to adapt.
                continue
            new_proj = TimeEmbProjCelebA3D(
                old_proj.in_features * n_domains, old_proj.out_features
            )
            # Keep the replacement on the same device/dtype as the layer it replaces.
            resnet.time_emb_proj = new_proj.to(
                device=old_proj.weight.device, dtype=old_proj.weight.dtype
            )

    @jaxtyped
    @beartype
    def forward(
        self,
        x: Float[torch.Tensor, 'b c h w'],
        t: Float[torch.Tensor, 'b *n_dom dim'],
    ) -> Float[torch.Tensor, 'b ci hi wi']:
        """Run the mid block on ``x`` with the per-domain time embedding ``t``.

        ``t`` is passed through unchanged; the replaced ``time_emb_proj``
        layers flatten it and so expect a 3-D tensor whose last two axes
        multiply to their ``in_dim``.
        """
        emb = t

        # 4. mid
        return self.mid_block(x, emb)


class Decoder(nn.Module):
    """Up-sampling half and output head of a diffusers ``UNet2DModel``.

    ``up_blocks``, ``conv_norm_out``, ``conv_act``, ``conv_out`` and ``config``
    are shared with the given model (assigned, not copied).
    """

    def __init__(self, model):
        super().__init__()
        self.up_blocks = model.up_blocks
        self.conv_norm_out = model.conv_norm_out
        self.conv_act = model.conv_act
        self.conv_out = model.conv_out
        self.config = model.config

    @jaxtyped
    @beartype
    def forward(
        self,
        x: Float[torch.Tensor, 'b c h w'],
        h: List[Float[torch.Tensor, 'b cl hl wl']],
        t: Float[torch.Tensor, 'b *n_dom dim'],
    ) -> Float[torch.Tensor, 'b ci hi wi']:
        """Decode bottleneck features using the encoder's residuals.

        Args:
            x: bottleneck feature map.
            h: residual tensors from the matching encoder. They are consumed
                from the end, ``len(block.resnets)`` tensors per up block.
            t: time embedding; ``squeeze(dim=1)`` is applied once before the
                up blocks.

        Returns:
            The output of ``conv_out`` (plus the final skip sample, if any).

        Raises:
            NotImplementedError: if the source model uses a ``"fourier"`` time
                embedding. The output rescaling for fourier embeddings divides by
                the raw timesteps, but this decoder only receives the time
                *embedding*.
        """
        if self.config.time_embedding_type == "fourier":
            # Fourier rescaling divides by raw timesteps, which this decoder
            # does not receive. Reshaping the embedding to (b, 1, 1, 1) and
            # dividing by it would be meaningless (and crashes in practice).
            raise NotImplementedError(
                "Decoder does not support time_embedding_type='fourier': the "
                "output rescaling needs raw timesteps, but only the time "
                "embedding is passed in."
            )

        # Squeeze once so every up block, skip-style ones included, receives
        # the same embedding. Skip blocks must not get a (b, 1, dim) tensor.
        emb = t.squeeze(dim=1)

        # 5. up
        skip_sample = None
        for upsample_block in self.up_blocks:
            n_res = len(upsample_block.resnets)
            res_samples = h[-n_res:]
            h = h[:-n_res]

            if hasattr(upsample_block, "skip_conv"):
                x, skip_sample = upsample_block(x, res_samples, emb, skip_sample)
            else:
                x = upsample_block(x, res_samples, emb)

        # 6. post-process
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)

        if skip_sample is not None:
            x = x + skip_sample

        return x


def getmodel(config, params: str, dom_i: int, original: Unet_Cold_Multi_Domain) -> UNet2DModel:
    """Build one ``UNet2DModel`` for domain ``dom_i`` and optionally load weights.

    Args:
        config: keyword arguments for the ``UNet2DModel`` constructor.
        params: weight source. Supported values:
            ``'non'`` keeps the freshly initialised model;
            ``'path_<file>'`` loads a state dict from ``<file>`` into the
            (possibly channel-adapted) model;
            ``'link_<id>'`` returns ``UNet2DModel.from_pretrained(<id>)``.
        dom_i: domain index. For 1 and 2, ``conv_in``/``conv_out`` are replaced
            so their channel counts match ``original``'s encoder ``init_conv``
            and decoder ``final_conv`` for that domain.
        original: the multi-domain model whose channel counts are matched.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``params`` has none of the supported forms.
    """
    # Built unconditionally (as before), so the RNG consumed by random init
    # does not depend on ``params``.
    model = UNet2DModel(**config)
    if dom_i in (1, 2):
        in_channel_encoder = original.encoders.encoders[dom_i].init_conv.in_channels
        out_channel_decoder = original.decoders.decoders[dom_i].final_conv.out_channels
        # Use the UNet's own widths instead of a hard-coded 128
        # (both equal block_out_channels[0] in diffusers' UNet2DModel).
        conv_in_width = model.conv_in.out_channels
        conv_out_width = model.conv_out.in_channels

        model.conv_in = nn.Conv2d(
            in_channel_encoder, conv_in_width, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        model.conv_out = nn.Conv2d(
            conv_out_width, out_channel_decoder, kernel_size=(1, 1), stride=(1, 1)
        )

    if params == 'non':
        return model
    if params.startswith('path_'):
        # torch.load unpickles: only point this at trusted checkpoints.
        # map_location='cpu' lets GPU-saved checkpoints load on any machine;
        # load_state_dict copies the values into the model's own tensors.
        state = torch.load(params[len('path_'):], map_location='cpu')
        model.load_state_dict(state)
        return model  # without this the function fell through and returned None
    if params.startswith('link_'):
        # NOTE(review): the pretrained model is returned as-is. ``config`` is
        # ignored and the per-domain conv_in/conv_out adaptation above is
        # discarded for dom_i 1/2. Confirm this is intended.
        return UNet2DModel.from_pretrained(params[len('link_'):])
    raise ValueError(f'Unknown param {params=}')


def adapt_model(
    original: Unet_Cold_Multi_Domain,
    gref_model: str,
    params: BackboneParams,
) -> nn.Module:
    """Swap the sub-networks of ``original`` for pieces of diffusers UNets.

    One ``UNet2DModel`` is built per component from the config of the
    reference checkpoint ``gref_model``. Each one is initialised according to
    the matching entry of ``params`` (see ``getmodel``). The following parts of
    ``original`` are then replaced in place:
    - the time embedding;
    - the three encoders (image, sketch, mask);
    - the bottleneck;
    - the three decoders.

    Returns:
        ``original`` itself, after modification.
    """
    reference = UNet2DModel.from_pretrained(gref_model)
    unet_config = dict(reference.config)
    # Drop the load-path entry before reusing the dict as constructor kwargs.
    unet_config.pop('_name_or_path')

    def build(source, dom_i):
        return getmodel(unet_config, source, dom_i, original)

    # The time embedding is installed before any other component is built.
    original.time_mlp = TimeEmbedding(build(params.pretrained_time_embedding, 0))

    # Build order (and hence RNG use for random init): encoders 0..2,
    # decoders 0..2, then the bottleneck.
    domains = range(3)  # 0: image, 1: sketch, 2: mask
    new_encoders = [Encoder(build(params.pretrained_encoder[d], d)) for d in domains]
    new_decoders = [Decoder(build(params.pretrained_decoder[d], d)) for d in domains]
    new_middle = Middle(build(params.pretrained_bottleneck, 0))

    for d, encoder in enumerate(new_encoders):
        original.encoders.encoders[d] = encoder
    original.middle = new_middle
    for d, decoder in enumerate(new_decoders):
        original.decoders.decoders[d] = decoder

    return original


def freeze_model(model: Unet_Cold_Multi_Domain, params: BackboneParams) -> None:
    """Freeze the sub-networks of ``model`` selected by the flags in ``params``.

    Encoders and decoders are paired positionally with ``params.freeze_encoder``
    and ``params.freeze_decoder``. ``zip`` stops at the shorter sequence, so
    modules without a flag stay untouched. The time embedding and the
    bottleneck each have a single boolean flag.
    """
    def freeze_flagged(modules, flags):
        for module, flagged in zip(modules, flags):
            if flagged:
                freeze(module)

    freeze_flagged(model.encoders.encoders, params.freeze_encoder)
    freeze_flagged(model.decoders.decoders, params.freeze_decoder)

    if params.freeze_time_embedding:
        freeze(model.time_mlp)
    if params.freeze_bottleneck:
        freeze(model.middle)


def main():
    """Smoke check: build the Encoder/Middle/Decoder wrappers from a pretrained UNet.

    Downloads ``google/ddpm-ema-celebahq-256`` via ``from_pretrained`` and
    prints the parameter count of each wrapper.
    """
    model_id = "google/ddpm-ema-celebahq-256"
    model_src = UNet2DModel.from_pretrained(model_id)

    # NOTE: all three wrappers share sub-modules with model_src, and Middle
    # replaces time_emb_proj inside model_src.mid_block in place.
    parts = {
        'encoder': Encoder(model_src),
        'middle': Middle(model_src),
        'decoder': Decoder(model_src),
    }
    for name, part in parts.items():
        n_params = sum(p.numel() for p in part.parameters())
        print(f'{name}: {n_params:,} parameters')


if __name__ == '__main__':
    main()
