import torch
import torch.nn as nn
import numpy as np
from .attention import SelfAttention
import torch.nn.functional as F
from .utils import *

class AttentionBlock(nn.Module):
    """Residual self-attention applied across the spatial positions of a feature map.

    The input is group-normalised, flattened so that every spatial position
    becomes one sequence element, passed through ``SelfAttention``, reshaped
    back to its original layout and added onto the untouched input.
    """

    def __init__(self, channels):
        """Create the normalisation and attention layers.

        Args:
            channels: Channel count of the incoming feature map. It is used as
                the ``GroupNorm`` channel count (with 32 groups) and as the
                second argument of ``SelfAttention``.
        """
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, channels)
        # NOTE(review): the leading 1 is presumably the number of heads --
        # confirm against SelfAttention's signature.
        self.attention = SelfAttention(1, channels)

    def forward(self, x):
        """Apply normalisation + attention and add the skip connection.

        Args:
            x: Tensor whose shape unpacks as ``(n, c, h, w)``.

        Returns:
            Tensor reshaped to ``(n, c, h, w)`` with the input added to it.
        """
        skip = x
        normed = self.groupnorm(x)

        batch, chans, height, width = normed.shape
        # (n, c, h, w) -> (n, c, h*w) -> (n, h*w, c): one token per position.
        tokens = normed.view((batch, chans, height * width)).transpose(-1, -2)
        attended = self.attention(tokens)
        # (n, h*w, c) -> (n, c, h*w) -> (n, c, h, w): back to the image layout.
        out = attended.transpose(-1, -2).view((batch, chans, height, width))

        # In-place add, as in the original, so no extra tensor is allocated.
        out += skip
        return out

class ResidualBlock(nn.Module):
    """Two (GroupNorm -> SiLU -> 3x3 Conv2d) stages wrapped by a skip connection."""

    def __init__(self, in_channels, out_channels):
        """Create both convolution stages and the skip-path projection.

        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by the block.
        """
        super().__init__()
        self.groupnorm_1 = nn.GroupNorm(32, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.groupnorm_2 = nn.GroupNorm(32, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        # The skip path needs a 1x1 projection only when the channel count
        # changes; otherwise the input is passed through untouched.
        needs_projection = in_channels != out_channels
        self.residual_layer = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x):
        """Run both stages on ``x`` and add the (possibly projected) input.

        Args:
            x: Feature map with ``in_channels`` channels.

        Returns:
            Feature map with ``out_channels`` channels; the 3x3 convolutions
            use padding 1, so the spatial size is kept.
        """
        stages = (
            (self.groupnorm_1, self.conv_1),
            (self.groupnorm_2, self.conv_2),
        )
        hidden = x
        for norm, conv in stages:
            hidden = conv(F.silu(norm(hidden)))
        return hidden + self.residual_layer(x)




class Conv1dEncoder(nn.Module):
    """Encode a set of policy parameters into a sampled latent vector.

    Each of the three weight tensors goes through its own stack of strided
    1-D convolutions, the bias vector goes through a small MLP, and the four
    embeddings are fused by a linear layer into a mean and a log-variance
    from which the latent is sampled.
    """

    def __init__(self, bias_shapes, hidden_dim=128, latent_dim=64, obs_dim=64, action_dim=7):
        """
        overcooked: obs_dim = 96, action_dim = 6
        assistive_gym: obs_dim = 64, action_dim = 17

        Args:
            bias_shapes: Sizes of the bias vectors; their sum is the input
                width of the bias MLP.
            hidden_dim: Channel count inside the conv stacks and width of the
                bias MLP. It feeds ``GroupNorm(32, hidden_dim)``.
            latent_dim: Size of the returned latent, mean and log-variance.
            obs_dim: Number of input features of the linear layer that ends
                the first weight encoder. That layer sits directly after
                ``Flatten``, so ``int(latent_dim / 16)`` channels times the
                length left after two stride-2 convolutions must equal
                ``obs_dim``.
            action_dim: Input channel count of the third weight encoder.
        """
        super().__init__()
        self.bias_len = np.sum(bias_shapes)

        def conv_block(in_channels, out_channels, kernel_size, stride, normalize=True):
            # Conv1d -> optional GroupNorm -> SiLU, returned as a list so it
            # can be splatted into an nn.Sequential.
            layers = [nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=1)]
            if normalize:
                layers.append(nn.GroupNorm(32, out_channels))
            layers.append(nn.SiLU())
            return layers

        def weight_encoder(in_channels, flat_features):
            # Two stride-2 conv blocks, a channel-reducing conv, then a linear
            # projection of the flattened result to the latent width. The layer
            # order matches the original hand-written stacks, so state_dict
            # keys are unchanged.
            return nn.Sequential(
                *conv_block(in_channels, hidden_dim, kernel_size=3, stride=2),
                *conv_block(hidden_dim, hidden_dim, kernel_size=3, stride=2),
                nn.Conv1d(hidden_dim, int(latent_dim / 16), kernel_size=3, stride=1, padding=1),
                nn.Flatten(),
                nn.Linear(flat_features, latent_dim),
            )

        # The first two encoders are hard-wired to 64 input channels.
        self.fc_encoder_1 = weight_encoder(64, obs_dim)
        # NOTE(review): 4 * 16 matches latent_dim=64 with an input length of 64
        # (4 channels x 16 positions); other settings need a different value.
        self.fc_encoder_2 = weight_encoder(64, 4 * 16)
        self.fc_encoder_3 = weight_encoder(action_dim, 4 * 16)

        self.bias_encoder = nn.Sequential(
            nn.Linear(self.bias_len, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, latent_dim)
        )
        self.fc = nn.Linear(latent_dim * 4, latent_dim * 2)

    @staticmethod
    def _squeeze_keep_batch(weights):
        """Drop singleton dimensions without losing a batch axis of size 1.

        A bare ``squeeze()`` also removes the batch axis when the batch holds a
        single sample, which leaves a 2-D tensor that the conv stacks cannot
        process. In that case the leading axis is put back.
        """
        squeezed = weights.squeeze()
        if squeezed.dim() == 2:
            squeezed = squeezed.unsqueeze(0)
        return squeezed

    def forward(self, params):
        """Encode ``params`` and draw a latent sample.

        Args:
            params: Mapping with the keys ``'fc_1_weight'``, ``'fc_2_weight'``,
                ``'fc_3_weight'`` and ``'bias'``. Each weight tensor is
                assumed to be ``(batch, channels, length)`` once singleton
                dimensions are removed -- confirm against the data loader.

        Returns:
            Tuple ``(latent, mean, log_variance)``. ``latent`` is the sample
            ``mean + stdev * noise`` multiplied by 0.18215, and
            ``log_variance`` is clamped to ``[-30, 20]``.

        Note:
            Noise is drawn on every call, in eval mode as well. The
            ``BatchNorm1d`` inside the bias encoder still needs more than one
            sample per batch while the module is in training mode.
        """
        fc_weights_1 = self._squeeze_keep_batch(params['fc_1_weight'])
        fc_weights_2 = self._squeeze_keep_batch(params['fc_2_weight'])
        fc_weights_3 = self._squeeze_keep_batch(params['fc_3_weight'])
        bias = params["bias"]

        fc_latent_1 = self.fc_encoder_1(fc_weights_1)
        fc_latent_2 = self.fc_encoder_2(fc_weights_2)
        fc_latent_3 = self.fc_encoder_3(fc_weights_3)
        bias_latent = self.bias_encoder(bias)

        concated_latents = torch.concat([fc_latent_1, fc_latent_2, fc_latent_3, bias_latent], dim=-1)
        latent = self.fc(concated_latents)  # latent_dim * 2

        # First half of the fused vector is the mean, second half the log-variance.
        mean, log_variance = torch.chunk(latent, 2, dim=-1)
        log_variance = torch.clamp(log_variance, -30, 20)  # TODO: set the proper clamp range
        variance = log_variance.exp()
        stdev = variance.sqrt()

        # Reparameterisation: sample = mean + stdev * N(0, I).
        noise = torch.randn_like(mean)
        latent = mean + stdev * noise

        # Fixed scale factor applied to the sampled latent.
        latent *= 0.18215
        return latent, mean, log_variance


class HyperDecoder(nn.Module):
    """Decoder incorporating GHN to reconstruct policy parameters from the latent representation."""

    def __init__(self, hyper_actor):
        """Wrap ``hyper_actor`` and expose its optimizer.

        Args:
            hyper_actor: Object providing ``optimizer``,
                ``re_query_uniform_weights`` and ``current_model``.
        """
        super().__init__()
        self.policy = hyper_actor
        self.policy_optim = hyper_actor.optimizer

    def forward(self, z):
        """Re-query the policy weights from ``z`` and collect their parameters.

        Args:
            z: Latent passed to ``re_query_uniform_weights``.

        Returns:
            List with one ``params_from_modules`` result for every model in
            ``self.policy.current_model``.
        """
        self.policy.re_query_uniform_weights(z)
        # NOTE(review): indices 0, 2 and 4 presumably select the parametrised
        # layers of each classifier -- confirm against the policy definition.
        return [
            params_from_modules(list(net.classifier.children()), idx=[0, 2, 4])
            for net in self.policy.current_model
        ]


class Decoder(nn.Module):
    """Directly reconstruct policy parameters from the latent representation without GHN.

    The latent is expanded to four 256-wide chunks. Three of them are decoded
    into the flattened weight matrices described by ``fc_shapes``, and the
    fourth into the concatenated bias vector.
    """

    def __init__(self, bias_shapes, fc_shapes, latent_dim=64, batch_size=32):
        """Build the shared expansion layer and the four decoding heads.

        Args:
            bias_shapes: Sizes of the bias vectors; their sum is the width of
                the reconstructed ``'bias'`` output.
            fc_shapes: Three 2-D shapes, one per weight matrix, indexed as
                ``fc_shapes[i][0]`` and ``fc_shapes[i][1]``.
            latent_dim: Width of the latent fed to ``forward``.
            batch_size: Kept as an attribute for backward compatibility;
                ``forward`` now reads the batch size from its input instead.
        """
        super().__init__()
        self.batch_size = batch_size
        self.bias_shapes = bias_shapes
        self.fc_shapes = fc_shapes
        self.bias_len = np.sum(bias_shapes)
        self.fc_weights_len_1 = fc_shapes[0][0] * fc_shapes[0][1]
        self.fc_weights_len_2 = fc_shapes[1][0] * fc_shapes[1][1]
        self.fc_weights_len_3 = fc_shapes[2][0] * fc_shapes[2][1]

        # Shared expansion; the output is split into four 256-wide chunks.
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 256 * 4),
            nn.BatchNorm1d(256 * 4),
            nn.SiLU()
        )

        # Heads are created in the original order so state_dict keys and
        # parameter initialisation order are unchanged.
        self.fc_decoder_1 = self._weight_decoder(self.fc_weights_len_1)
        self.fc_decoder_2 = self._weight_decoder(self.fc_weights_len_2)
        self.fc_decoder_3 = self._weight_decoder(self.fc_weights_len_3)

        # Unlike the weight heads, the bias head ends in a plain Linear layer.
        self.bias_decoder = nn.Sequential(
            nn.Linear(256, 1024),
            nn.BatchNorm1d(1024),
            nn.SiLU(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.SiLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.SiLU(),
            nn.Linear(256, self.bias_len)
        )

    @staticmethod
    def _weight_decoder(out_features):
        """Return the 256 -> 512 -> 1024 -> ``out_features`` MLP used by each weight head."""
        # NOTE(review): the head ends in BatchNorm1d + SiLU, so reconstructed
        # weights pass through a final non-linearity -- confirm this is intended.
        return nn.Sequential(
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.SiLU(),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.SiLU(),
            nn.Linear(1024, out_features),
            nn.BatchNorm1d(out_features),
            nn.SiLU()
        )

    def forward(self, latent):
        """Decode ``latent`` into weight matrices and a bias vector.

        Args:
            latent: Tensor of shape ``(batch, latent_dim)``.

        Returns:
            Dict with ``'fc_1_weight'``, ``'fc_2_weight'`` and
            ``'fc_3_weight'`` of shape ``(batch, *fc_shapes[i])``, and
            ``'bias'`` of shape ``(batch, bias_len)``.
        """
        # Use the actual batch size so partial batches and inference with a
        # different batch size reshape correctly.
        batch = latent.shape[0]

        concated_latents = self.fc(latent)
        fc_latent_1, fc_latent_2, fc_latent_3, bias_latent = torch.chunk(concated_latents, 4, dim=-1)
        fc_weights_1 = self.fc_decoder_1(fc_latent_1).reshape(batch, *self.fc_shapes[0])
        fc_weights_2 = self.fc_decoder_2(fc_latent_2).reshape(batch, *self.fc_shapes[1])
        fc_weights_3 = self.fc_decoder_3(fc_latent_3).reshape(batch, *self.fc_shapes[2])
        bias = self.bias_decoder(bias_latent)

        recon_params = {'fc_1_weight': fc_weights_1, 'fc_2_weight': fc_weights_2, 'fc_3_weight': fc_weights_3, 'bias': bias}
        return recon_params
    

class VAE(nn.Module):
    """Thin wrapper that chains an encoder and a decoder."""

    def __init__(self, encoder=None, decoder=None):
        """Store the two halves of the autoencoder.

        Args:
            encoder: Callable returning ``(latent, mean, logvar)``.
            decoder: Callable mapping a latent to reconstructed parameters.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        """Encode ``x``, decode the sampled latent and return the result.

        Returns:
            Tuple ``(params, mean, logvar)``: the decoder output together with
            the mean and log-variance produced by the encoder.
        """
        z, mu, log_var = self.encoder(x)
        return self.decoder(z), mu, log_var
