import numpy as np
import torch
import torch.nn as nn
from demo.helper import SUPPORTED_TIMESTEP_EMBEDDING


class ResidualBlock(nn.Module):
    """Fully connected residual block modulated by a conditioning embedding.

    The input is projected to ``output_dim``, the projected embedding is added
    to it, the sum goes through a LayerNorm + two-layer MLP, and a skip
    connection from the raw input is added to the result.

    Args:
        input_dim: Feature width of ``x``.
        output_dim: Feature width of the block output.
        cond_dim: Feature width of the conditioning embedding ``emb``.
    """

    def __init__(self, input_dim, output_dim, cond_dim):
        super().__init__()

        # Main path: lift the input to the output width.
        self.fc1 = nn.Linear(input_dim, output_dim)

        # Conditioning path: two-layer MLP mapping the embedding to output width.
        emb_layers = [
            nn.Linear(cond_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
        ]
        self.fc_emb = nn.Sequential(*emb_layers)

        # Post-fusion path: normalise, then a two-layer MLP.
        post_layers = [
            nn.LayerNorm(output_dim),
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
        ]
        self.fc2 = nn.Sequential(*post_layers)

        # Skip path: a learned projection is only needed when the widths differ.
        if input_dim == output_dim:
            self.residual = nn.Identity()
        else:
            self.residual = nn.Linear(input_dim, output_dim)

    def forward(self, x, emb):
        """Return ``fc2(fc1(x) + fc_emb(emb)) + residual(x)``."""
        fused = self.fc1(x) + self.fc_emb(emb)
        return self.fc2(fused) + self.residual(x)


class SimpleUNetMLP(nn.Module):
    """Plain MLP that maps ``(x, t, cond)`` to a tensor shaped like ``x``.

    The timestep embedding and the condition are each projected to
    ``hidden_dim // 2`` features and summed; the input is projected to
    ``hidden_dim // 2`` features as well. Both halves are concatenated and fed
    through a three-layer GELU MLP followed by a linear output head.

    Because the concatenated width is ``2 * (hidden_dim // 2)`` while the MLP
    expects ``hidden_dim`` features, ``args.model.hidden_dim`` should be even.

    Args:
        args: Config object. Reads ``args.device``, ``args.action_dim``,
            ``args.model.latent_dim``, ``args.model.hidden_dim`` and
            ``args.model.timestep_emb_type``.
        state_dim: Feature width of the conditioning input. Required, even
            though the signature keeps a ``None`` default.

    Raises:
        TypeError: If ``state_dim`` is ``None``.
    """

    def __init__(self, args, state_dim=None):
        super(SimpleUNetMLP, self).__init__()
        if state_dim is None:
            # nn.Linear would fail on None anyway; fail early with a clear message.
            raise TypeError("SimpleUNetMLP requires an integer state_dim, got None")

        # args.device is only honoured when CUDA is available; otherwise CPU.
        self.device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

        # params
        self.time_emb_dim = args.model.latent_dim
        self.input_dim = args.action_dim
        self.output_dim = args.action_dim
        self.hidden_dim = args.model.hidden_dim
        self.cond_dim = state_dim

        # Timestep embedding looked up by name in the project registry.
        # Presumably it emits ``time_emb_dim`` features, since time_layer
        # consumes that width -- confirm against demo.helper.
        self.map_noise = SUPPORTED_TIMESTEP_EMBEDDING[args.model.timestep_emb_type](self.time_emb_dim)
        self.cond_layer = nn.Linear(state_dim, self.hidden_dim // 2)
        self.time_layer = nn.Linear(self.time_emb_dim, self.hidden_dim // 2)

        self.first_layer = nn.Linear(self.input_dim, self.hidden_dim // 2)
        self.mid_layer = nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim),
                                       nn.GELU(),
                                       nn.Linear(self.hidden_dim, self.hidden_dim),
                                       nn.GELU(),
                                       nn.Linear(self.hidden_dim, self.hidden_dim),
                                       nn.GELU())

        self.final_layer = nn.Linear(self.hidden_dim, self.output_dim)

        # Move the parameters only now: Module.to() affects registered
        # submodules, so calling it before the layers exist moves nothing.
        self.to(self.device)

    def forward(self, x, t, cond=None):
        """Run the network.

        Args:
            x: Input tensor of shape ``(batch, action_dim)``.
            t: Timestep tensor passed to the timestep embedding.
            cond: Optional conditioning tensor of shape ``(batch, state_dim)``.
                When ``None`` an all-zero condition is used.

        Returns:
            Tensor of shape ``(batch, action_dim)``.
        """
        # Keep every input on the model's device, as UNetMLP.forward does.
        x = x.to(self.device)
        t = t.to(self.device)
        if cond is not None:
            cond = cond.to(self.device)
        else:
            cond = torch.zeros((x.shape[0], self.cond_dim), device=self.device)

        # Fuse time and condition into one half-width feature vector.
        t = self.time_layer(self.map_noise(t)) + self.cond_layer(cond)
        x = self.first_layer(x)
        x = torch.cat([x, t], dim=1)
        x = self.mid_layer(x)

        return self.final_layer(x)


class UNetMLP(nn.Module):
    """U-Net shaped MLP built from embedding-conditioned residual blocks.

    A joint embedding is formed from the timestep embedding and the condition
    and passed to every residual block. The down path widens the features
    stage by stage (``hidden_dim * m`` for each ``m`` in ``ch_mul``), two
    middle blocks run at the widest width, and the up path narrows again while
    concatenating the matching down-path activations as skip connections.

    Args:
        args: Config object. Reads ``args.device``, ``args.action_dim``,
            ``args.model.latent_dim``, ``args.model.hidden_dim`` and
            ``args.model.timestep_emb_type``.
        state_dim: Feature width of the conditioning input. Required, even
            though the signature keeps a ``None`` default.
        ch_mul: Width multipliers, one per down stage.

    Raises:
        TypeError: If ``state_dim`` is ``None``.
    """

    def __init__(self, args, state_dim=None, ch_mul=(1, 2, 4)):  # e.g. (1, 2, 4, 8)
        super(UNetMLP, self).__init__()
        if state_dim is None:
            # nn.Linear would fail on None anyway; fail early with a clear message.
            raise TypeError("UNetMLP requires an integer state_dim, got None")

        # args.device is only honoured when CUDA is available; otherwise CPU.
        self.device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

        # params
        self.timestep_emb_type = args.model.timestep_emb_type
        self.input_dim = args.action_dim
        self.output_dim = args.action_dim
        self.cond_dim = state_dim

        self.time_emb_dim = args.model.latent_dim
        self.hidden_dim = args.model.hidden_dim
        t_dim = self.hidden_dim * 2

        # model layer
        # Timestep embedding looked up by name in the project registry.
        # Presumably it emits ``time_emb_dim`` features, since time_layer
        # consumes that width -- confirm against demo.helper.
        self.map_noise = SUPPORTED_TIMESTEP_EMBEDDING[self.timestep_emb_type](self.time_emb_dim)
        self.time_layer = nn.Linear(self.time_emb_dim, t_dim)

        self.cond_layer = nn.Sequential(
            nn.Linear(self.cond_dim, t_dim),
            nn.ReLU(),
            nn.Linear(t_dim, t_dim),
            nn.ReLU(),
            nn.Linear(t_dim, t_dim)
        )
        self.emb_layer = nn.Linear(t_dim, t_dim)

        # blocks: consecutive (in, out) widths for each stage
        dims = [self.input_dim] + [self.hidden_dim * m for m in ch_mul]
        in_out = list(zip(dims[:-1], dims[1:]))

        # Downsampling blocks
        self.downblocks = nn.ModuleList([])
        for ind, (in_dim, out_dim) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)

            self.downblocks.append(nn.ModuleList([
                ResidualBlock(in_dim, out_dim, t_dim),
                ResidualBlock(out_dim, out_dim, t_dim),
                nn.Linear(out_dim, out_dim) if not is_last else nn.Identity()
            ]))

        # Middle blocks
        mid_dim = dims[-1]
        self.middleblocks1 = ResidualBlock(mid_dim, mid_dim, t_dim)
        self.middleblocks2 = ResidualBlock(mid_dim, mid_dim, t_dim)

        # Upsampling blocks
        # NOTE: the previous code guarded the trailing Linear with
        # ``ind >= len(in_out) - 1``, but the index over ``in_out[1:]`` never
        # reaches that value, so every up stage always received a Linear.
        # That layout is kept on purpose so existing checkpoints still load.
        self.upblocks = nn.ModuleList([])
        for in_dim, out_dim in reversed(in_out[1:]):
            self.upblocks.append(nn.ModuleList([
                # Input is the current features concatenated with the skip.
                ResidualBlock(out_dim * 2, in_dim, t_dim),
                ResidualBlock(in_dim, in_dim, t_dim),
                nn.Linear(in_dim, in_dim)
            ]))

        # Output layer
        # The up path ends at the first stage's width; this equals hidden_dim
        # for the default ch_mul and stays correct when ch_mul[0] != 1.
        final_dim = dims[1] if len(dims) > 1 else dims[0]
        self.out = nn.Sequential(
            nn.LayerNorm(final_dim),
            nn.ReLU(),
            nn.Linear(final_dim, self.output_dim),
        )

        # Move the parameters only now: Module.to() affects registered
        # submodules, so calling it before the layers exist moves nothing.
        self.to(self.device)

    def forward(self, x, t, cond=None):
        """Run the network.

        Args:
            x: Input tensor of shape ``(batch, action_dim)``.
            t: Timestep tensor passed to the timestep embedding.
            cond: Optional conditioning tensor of shape ``(batch, state_dim)``.
                When ``None`` an all-zero condition is used.

        Returns:
            Tensor of shape ``(batch, action_dim)``.
        """
        x = x.to(self.device)
        t = t.to(self.device)
        if cond is not None:
            cond = cond.to(self.device)
        else:
            cond = torch.zeros((x.shape[0], self.cond_dim), device=self.device)

        # Joint time + condition embedding shared by every residual block.
        emb = self.time_layer(self.map_noise(t)) + self.cond_layer(cond)
        emb = self.emb_layer(emb)

        hs = []
        # Downsampling: remember each stage's activation for the skip path.
        for fc1, fc2, downsample in self.downblocks:
            x = fc1(x, emb)
            x = fc2(x, emb)
            hs.append(x)
            x = downsample(x)

        # Middle blocks
        x = self.middleblocks1(x, emb)
        x = self.middleblocks2(x, emb)

        # Upsampling: consume skips from the deepest stage outwards. There is
        # one up stage fewer than down stages, so the first skip stays unused.
        for fc1, fc2, upsample in self.upblocks:
            x = torch.cat([x, hs.pop()], dim=1)
            x = fc1(x, emb)
            x = fc2(x, emb)
            x = upsample(x)

        return self.out(x)