import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class MLP(nn.Module):
    """Fully connected network with an activation between consecutive layers.

    Args:
        layers: Widths ``[in, hidden..., out]``; must contain at least two
            entries. One ``nn.Linear`` is built for every adjacent pair.
        activation: Callable applied after every linear layer except the
            last one (defaults to ``torch.nn.functional.gelu``).
    """

    def __init__(self, layers, activation=torch.nn.functional.gelu):
        super(MLP, self).__init__()
        assert len(layers)>=2
        fan_in = layers[:-1]
        fan_out = layers[1:]
        # Layers are created in input-to-output order, so parameter
        # initialisation under a fixed seed is reproducible.
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(fan_in, fan_out)]
        )
        self.activation = activation

    def forward(self, x):
        """Apply every layer in order, activating all outputs but the last."""
        final_index = len(self.layers) - 1
        for index, linear in enumerate(self.layers):
            x = linear(x)
            if index != final_index:
                x = self.activation(x)
        return x

class VAE2d64(nn.Module):
    """Convolutional VAE for single-channel 64x64 images.

    The encoder halves the spatial size three times (64 -> 32 -> 16 -> 8) and
    predicts a 4x8x8 latent mean and log-variance with 1x1 convolutions. The
    decoder doubles it three times (8 -> 16 -> 32 -> 64), maps to one channel
    and applies a sigmoid.
    """

    def __init__(self):
        super(VAE2d64, self).__init__()

        # Encoder: stride-2 3x3 convolutions, each halving height and width.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)   # -> [16, 32, 32]
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)  # -> [32, 16, 16]
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)  # -> [64, 8, 8]

        # 1x1 heads for the latent mean and log-variance.
        self.fc_mu = nn.Conv2d(64, 4, kernel_size=1)       # -> [4, 8, 8]
        self.fc_log_var = nn.Conv2d(64, 4, kernel_size=1)  # -> [4, 8, 8]

        # Decoder: stride-2 transposed convolutions, each doubling height and width.
        upsample = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        self.convT1 = nn.ConvTranspose2d(4, 64, **upsample)   # -> [64, 16, 16]
        self.convT2 = nn.ConvTranspose2d(64, 32, **upsample)  # -> [32, 32, 32]
        self.convT3 = nn.ConvTranspose2d(32, 16, **upsample)  # -> [16, 64, 64]
        self.final_layer = nn.Conv2d(16, 1, kernel_size=1)    # -> [1, 64, 64]

    def encoder(self, x):
        """Return ``(mu, log_var)``, each of shape (batch, 4, 8, 8).

        ``x`` is reshaped to (batch, 1, 64, 64) using its leading dimension as
        the batch size, so each sample must hold 64*64 values.
        """
        h = x.reshape(x.shape[0], 1, 64, 64)
        for conv in (self.conv1, self.conv2, self.conv3):
            h = F.relu(conv(h))
        return self.fc_mu(h), self.fc_log_var(h)

    def sampling(self, mu, log_var):
        """Draw ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = (0.5 * log_var).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def decoder(self, z):
        """Map a (batch, 4, 8, 8) latent to a (batch, 1, 64, 64) sigmoid output."""
        h = z
        for deconv in (self.convT1, self.convT2, self.convT3):
            h = F.relu(deconv(h))
        return torch.sigmoid(self.final_layer(h))

    def forward(self, x):
        """Encode, sample and decode ``x``.

        Returns:
            ``(reconstruction, mu, log_var, z_flat)`` where ``z_flat`` is the
            sampled latent flattened to (batch, 256).
        """
        batch = x.shape[0]
        images = x.reshape(batch, 1, 64, 64)
        mu, log_var = self.encoder(images)
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var, z.reshape(batch, -1)


class VAE2d32(nn.Module):
    """Convolutional VAE for single-channel 32x32 images.

    The encoder halves the spatial size three times (32 -> 16 -> 8 -> 4) and
    predicts a 16x4x4 latent mean and log-variance with 1x1 convolutions. The
    decoder doubles it three times (4 -> 8 -> 16 -> 32), maps to one channel
    and applies a sigmoid.
    """

    def __init__(self):
        super(VAE2d32, self).__init__()

        # Encoder: stride-2 3x3 convolutions, each halving height and width.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)   # -> [16, 16, 16]
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)  # -> [32, 8, 8]
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)  # -> [64, 4, 4]

        # 1x1 heads for the latent mean and log-variance.
        self.fc_mu = nn.Conv2d(64, 16, kernel_size=1)       # -> [16, 4, 4]
        self.fc_log_var = nn.Conv2d(64, 16, kernel_size=1)  # -> [16, 4, 4]

        # Decoder: stride-2 transposed convolutions, each doubling height and width.
        upsample = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        self.convT1 = nn.ConvTranspose2d(16, 64, **upsample)  # -> [64, 8, 8]
        self.convT2 = nn.ConvTranspose2d(64, 32, **upsample)  # -> [32, 16, 16]
        self.convT3 = nn.ConvTranspose2d(32, 16, **upsample)  # -> [16, 32, 32]
        self.final_layer = nn.Conv2d(16, 1, kernel_size=1)    # -> [1, 32, 32]

    def encoder(self, x):
        """Return ``(mu, log_var)``, each of shape (batch, 16, 4, 4).

        ``x`` is reshaped to (batch, 1, 32, 32) using its leading dimension as
        the batch size, so each sample must hold 32*32 values.
        """
        h = x.reshape(x.shape[0], 1, 32, 32)
        for conv in (self.conv1, self.conv2, self.conv3):
            h = F.relu(conv(h))
        return self.fc_mu(h), self.fc_log_var(h)

    def sampling(self, mu, log_var):
        """Draw ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = (0.5 * log_var).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def decoder(self, z):
        """Map a (batch, 16, 4, 4) latent to a (batch, 1, 32, 32) sigmoid output."""
        h = z
        for deconv in (self.convT1, self.convT2, self.convT3):
            h = F.relu(deconv(h))
        return torch.sigmoid(self.final_layer(h))

    def forward(self, x):
        """Encode, sample and decode ``x``.

        Returns:
            ``(reconstruction, mu, log_var, z_flat)`` where ``z_flat`` is the
            sampled latent flattened to (batch, 256).
        """
        batch = x.shape[0]
        images = x.reshape(batch, 1, 32, 32)
        mu, log_var = self.encoder(images)
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var, z.reshape(batch, -1)

class VAE(nn.Module):
    """Fully connected variational autoencoder on flattened inputs.

    Args:
        x_dim: Number of input features per sample.
        h_dim1: Width of the first encoder / last decoder hidden layer.
        h_dim2: Width of the second encoder / first decoder hidden layer.
        z_dim: Latent dimension.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()

        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # mean head
        self.fc32 = nn.Linear(h_dim2, z_dim)  # log-variance head
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Return ``(mu, log_var)`` for a (batch, x_dim) input."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)

    def sampling(self, mu, log_var):
        """Draw ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decoder(self, z):
        """Map (batch, z_dim) latents to (batch, x_dim) sigmoid outputs."""
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        # torch.sigmoid is the canonical spelling, also used by the conv VAEs.
        return torch.sigmoid(self.fc6(h))

    def forward(self, x):
        """Encode, sample and decode ``x``.

        Each sample is flattened to ``x_dim`` features. ``x_dim`` is read back
        from ``fc1`` so the flattening always matches the configured input
        width instead of assuming 32*32. ``reshape`` also accepts
        non-contiguous inputs, which ``view`` rejects.

        Returns:
            ``(reconstruction, mu, log_var, z)``.
        """
        mu, log_var = self.encoder(x.reshape(-1, self.fc1.in_features))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var, z



class ScoreNetwork(nn.Module):
    """Small U-Net-style score model on 32x32 single-channel inputs.

    The input channels are ``x`` (reshaped to 1x32x32) and ``t`` (broadcast
    to 1x32x32). When ``u=True`` there is a third channel, produced by an MLP
    branch that maps the condition vector to 1024 values.

    Args:
        condition_dim: Width of the condition vector fed to the branch MLP.
        chs: Channel widths of the three encoder stages.
        u: Build the condition branch. The first convolution expects 3 input
            channels when True and 2 (just ``x`` and ``t``) when False.
    """

    def __init__(self, condition_dim=10, chs=[32, 64, 128], u=True):
        super().__init__()
        self.condition_dim = condition_dim
        self.chs = chs
        if u:
            # Branch output of 1024 = 32*32 values is reshaped into one channel.
            self.layers = [self.condition_dim, 200, 200, 200, 200, 200, 1024]
            self.branch =  MLP(self.layers)
        # x and t always contribute one channel each; the condition adds a third.
        # Hard-coding 3 here made every u=False model unusable.
        in_channels = 3 if u else 2

        self._convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, chs[0], kernel_size=3, padding=1),
                nn.LogSigmoid(),
            ),
            nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(chs[0], chs[1], kernel_size=3, padding=1),
                nn.LogSigmoid(),
            ),
            nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(chs[1], chs[2], kernel_size=3, padding=1),
                nn.LogSigmoid(),
            ),
        ])
        # Decoder stages 1 and 2 receive the previous output concatenated with
        # the matching encoder activation, hence the doubled input channels.
        self._tconvs = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(chs[2], chs[1], kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.LogSigmoid(),
            ),
            nn.Sequential(
                nn.ConvTranspose2d(chs[1] * 2, chs[0], kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.LogSigmoid(),
            ),
            nn.Sequential(
                nn.Conv2d(chs[0] * 2, chs[0], kernel_size=3, padding=1),
                nn.LogSigmoid(),
                nn.Conv2d(chs[0], 1, kernel_size=3, padding=1),
            ),
        ])

    def forward(self, x: torch.Tensor, t: torch.Tensor, u=None) -> torch.Tensor:
        """Evaluate the score.

        Args:
            x: Tensor whose last dimension holds 32*32 = 1024 values.
            t: Tensor with the same leading dimensions as ``x`` and a trailing
                dimension of size 1 (it is expanded to 1x32x32).
            u: Optional condition vector; requires the model to have been
                built with ``u=True``.

        Returns:
            Tensor with the leading dimensions of ``x`` and 1024 trailing values.
        """
        x2 = torch.reshape(x, (*x.shape[:-1], 1, 32, 32))
        tt = t[..., None, None].expand(*t.shape[:-1], 1, 32, 32)
        if u is not None:
            uu = self.branch(u).reshape(-1, 1, 32, 32)
            signal = torch.cat((x2, tt, uu), dim=-3)
        else:
            signal = torch.cat((x2, tt), dim=-3)
        # Encoder: keep all but the deepest activation for the skip connections.
        signals = []
        for i, conv in enumerate(self._convs):
            signal = conv(signal)
            if i < len(self._convs) - 1:
                signals.append(signal)

        # Decoder: stage i > 0 concatenates the skip from the mirrored encoder stage.
        for i, tconv in enumerate(self._tconvs):
            if i == 0:
                signal = tconv(signal)
            else:
                signal = torch.cat((signal, signals[-i]), dim=-3)
                signal = tconv(signal)
        signal = torch.reshape(signal, (*signal.shape[:-3], -1))
        return signal
    
class ScoreNetwork256(nn.Module):
    """Small U-Net-style score model on 256-value latents viewed as 4x8x8.

    The input is the channel concatenation of ``x``, ``t`` and ``u``, each 4x8x8,
    which is why the first convolution expects 12 input channels.
    """

    def __init__(self, chs=[32, 64, 128]):
        super().__init__()
        self.chs = chs
        self._convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(12, chs[0], kernel_size=3, padding=1),
                nn.LogSigmoid(),
            ),
            nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(chs[0], chs[1], kernel_size=3, padding=1),
                nn.LogSigmoid(),
            ),
            nn.Sequential(
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(chs[1], chs[2], kernel_size=3, padding=1),
                nn.LogSigmoid(),
            ),
        ])
        # Decoder stages 1 and 2 receive the previous output concatenated with
        # the matching encoder activation, hence the doubled input channels.
        self._tconvs = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(chs[2], chs[1], kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.LogSigmoid(),
            ),
            nn.Sequential(
                nn.ConvTranspose2d(chs[1] * 2, chs[0], kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.LogSigmoid(),
            ),
            nn.Sequential(
                nn.Conv2d(chs[0] * 2, chs[0], kernel_size=3, padding=1),
                nn.LogSigmoid(),
                nn.Conv2d(chs[0], 1, kernel_size=3, padding=1),
            ),
        ])

    def forward(self, x: torch.Tensor, t: torch.Tensor, u=None) -> torch.Tensor:
        """Evaluate the score.

        Args:
            x: (batch, 256) or any shape holding 4*8*8 values per sample.
            t: Either one time shared by the whole batch (a single element) or
                one time per sample (``batch`` elements, e.g. (B,) or (B, 1)).
            u: Conditioning holding 4*8*8 values per sample.
                NOTE(review): with ``u=None`` only 8 channels reach the first
                convolution, which expects 12, so that path currently fails.

        Returns:
            (batch, 64) tensor.
            NOTE(review): the final conv emits one 8x8 channel, so the output
            has 64 values per sample while ``x`` has 256. Confirm this is intended.
        """
        batch_size = x.shape[0]
        x = x.reshape(batch_size, 4, 8, 8)
        # Broadcast t to a full 4x8x8 map per sample. Tiling the flattened t
        # would lay B per-sample times out along the last axis, producing
        # (B, 4, 8, 8*B) and breaking the concatenation whenever B > 1.
        t = t.reshape(-1, 1, 1, 1).expand(batch_size, 4, 8, 8)
        if u is not None:
            u = u.reshape(-1, 4, 8, 8)
            signal = torch.cat((x, t, u), dim=-3)
        else:
            signal = torch.cat((x, t), dim=-3)
        # Encoder: keep all but the deepest activation for the skip connections.
        signals = []
        for i, conv in enumerate(self._convs):
            signal = conv(signal)
            if i < len(self._convs) - 1:
                signals.append(signal)

        # Decoder: stage i > 0 concatenates the skip from the mirrored encoder stage.
        for i, tconv in enumerate(self._tconvs):
            if i == 0:
                signal = tconv(signal)
            else:
                signal = torch.cat((signal, signals[-i]), dim=-3)
                signal = tconv(signal)
        signal = signal.reshape(batch_size,-1)
        return signal
    
class ScoreNetwork_z(nn.Module):
    """LogSigmoid MLP score model on latent vectors, optionally conditioned on ``u``.

    The input is ``cat(x, t[, u or embed(u)])``. It is projected to
    ``hidden_dims[0]``, widened to ``hidden_dims[2]``, narrowed back to
    ``hidden_dims[0]`` and mapped to ``input_dim`` outputs. There are no skip
    connections despite the down/up naming.

    Args:
        input_dim: Width of ``x`` and of the output.
        u_dim: Width of the conditioning vector ``u``.
        hidden_dims: First three entries give the hidden widths.
        u: Whether a conditioning input is expected at all.
        condition: If True (and ``u``), embed ``u`` with an MLP to 20 features
            before concatenation; otherwise concatenate raw ``u``.
    """

    def __init__(self, input_dim=10, u_dim=10, hidden_dims=[64, 128, 256], u=True, condition=False):
        super().__init__()
        self.z_dim = input_dim
        self.condition = condition
        if u:
            self.u_dim = u_dim
            if condition:
                u_embed_dim = 20
                self.initial_dim = input_dim + 1 + u_embed_dim
                # forward() views u as (-1, u_dim) before embedding, so the
                # embedding MLP must accept u_dim features. A hard-coded input
                # width of 1 only worked when u_dim == 1.
                self.u_MLP = MLP([u_dim, 200, 200, 200, 200, 200, u_embed_dim])
            else:
                self.initial_dim = input_dim + 1 + u_dim
        else:
            self.initial_dim = input_dim + 1
        self.fc_in = nn.Linear(self.initial_dim, hidden_dims[0])

        self.downsamples = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dims[0], hidden_dims[1]),
                nn.LogSigmoid(),
            ),
            nn.Sequential(
                nn.Linear(hidden_dims[1], hidden_dims[2]),
                nn.LogSigmoid(),
            ),
        ])

        self.upsamples = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dims[2], hidden_dims[1]),
                nn.LogSigmoid(),
            ),
            nn.Sequential(
                nn.Linear(hidden_dims[1], hidden_dims[0]),
                nn.LogSigmoid(),
            ),
            nn.Sequential(
                nn.Linear(hidden_dims[0], input_dim),
            ),
        ])

    def forward(self, x: torch.Tensor, t: torch.Tensor, u=None) -> torch.Tensor:
        """Evaluate the score on ``cat(x, t[, u])`` along the last dimension.

        ``x`` and ``t`` must share leading dimensions; ``u`` is viewed as
        (-1, u_dim). Calling with ``u=None`` only works for models built with
        ``u=False``, because ``fc_in`` is sized for the conditioned input otherwise.
        """
        if u is not None:
            u = u.view(-1, self.u_dim)
            if self.condition:
                u_embed = self.u_MLP(u)
                x = torch.cat((x, t, u_embed), dim=-1)
            else:
                x = torch.cat((x, t, u), dim=-1)
        else:
            x = torch.cat((x, t), dim=-1)
        x = self.fc_in(x)

        for downsample in self.downsamples:
            x = downsample(x)

        for upsample in self.upsamples:
            x = upsample(x)

        return x

class ScoreNetwork_z256(nn.Module):
    """ReLU MLP score model on latents conditioned on ``t`` and ``u``.

    The input ``cat(x, t, u)`` (width ``input_dim + 1 + u_dim``) is projected
    to ``hidden_dims[0]``, passed through ``hidden_dims[1..4]`` and back to
    ``hidden_dims[0]``, then mapped to ``input_dim`` outputs. There are no skip
    connections despite the down/up naming. Exactly the first five entries of
    ``hidden_dims`` are used; fewer raise ``IndexError``.
    """

    def __init__(self, input_dim=256, u_dim=256, hidden_dims=[512, 1024, 1024, 1024, 1024]):
        super().__init__()
        self.z_dim = input_dim
        self.u_dim = u_dim
        self.initial_dim = input_dim + 1 + u_dim

        self.fc_in = nn.Linear(self.initial_dim, hidden_dims[0])

        widths = [hidden_dims[i] for i in range(5)]

        def _dense(n_in, n_out):
            # One Linear + ReLU stage (Sequential keeps the ".0." state_dict keys).
            return nn.Sequential(nn.Linear(n_in, n_out), nn.ReLU())

        # widths[0] -> widths[1] -> ... -> widths[4]
        self.downsamples = nn.ModuleList(
            [_dense(a, b) for a, b in zip(widths[:-1], widths[1:])]
        )
        # widths[4] -> ... -> widths[0], then a linear read-out to input_dim.
        descending = widths[::-1]
        self.upsamples = nn.ModuleList(
            [_dense(a, b) for a, b in zip(descending[:-1], descending[1:])]
            + [nn.Sequential(nn.Linear(widths[0], input_dim))]
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, u=None) -> torch.Tensor:
        """Evaluate the score on ``cat(x, t[, u])`` along the last dimension.

        ``u`` is viewed as (-1, u_dim). Without ``u`` the concatenation is
        narrower than ``fc_in`` expects (unless ``u_dim == 0``).
        """
        if u is None:
            parts = (x, t)
        else:
            parts = (x, t, u.view(-1, self.u_dim))
        h = self.fc_in(torch.cat(parts, dim=-1))
        for stage in [*self.downsamples, *self.upsamples]:
            h = stage(h)
        return h
    

        
class ScoreNetwork_z_only(nn.Module):
    """Unconditional LogSigmoid MLP score model on latent vectors.

    The input is ``cat(x, t)`` of width ``input_dim + 1``. Hidden widths go
    ``hidden_dims[0] -> [1] -> [2] -> [1] -> [0]`` and the output has
    ``input_dim`` features. :meth:`forward` accepts ``u`` but ignores it.
    """

    def __init__(self, input_dim=10, hidden_dims=[64, 128, 256]):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.z_dim = input_dim
        self.initial_dim = input_dim + 1

        self.fc_in = nn.Linear(self.initial_dim, hidden_dims[0])

        narrow, middle, wide = hidden_dims[0], hidden_dims[1], hidden_dims[2]
        self.downsamples = nn.ModuleList(
            [nn.Sequential(nn.Linear(a, b), nn.LogSigmoid())
             for a, b in ((narrow, middle), (middle, wide))]
        )
        # The last stage is a bare linear read-out (no activation).
        self.upsamples = nn.ModuleList(
            [nn.Sequential(nn.Linear(a, b), nn.LogSigmoid())
             for a, b in ((wide, middle), (middle, narrow))]
            + [nn.Sequential(nn.Linear(narrow, input_dim))]
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, u=None) -> torch.Tensor:
        """Evaluate the score on ``cat(x, t)`` along the last dimension; ``u`` is unused."""
        h = self.fc_in(torch.cat((x, t), dim=-1))
        for stage in [*self.downsamples, *self.upsamples]:
            h = stage(h)
        return h

class UNet1D(nn.Module):
    """Plain ReLU MLP mapping ``cat(x, t)`` (11 features in total) to 10 outputs.

    Despite the name there are no skip connections. The encoder widens
    11 -> 64 -> 128 -> 256 and the decoder narrows 256 -> 128 -> 64 -> 10.
    Presumably ``x`` has 10 features and ``t`` has 1 (TODO confirm with callers).
    """

    def __init__(self):
        super(UNet1D, self).__init__()
        # Linear layers sit at Sequential indices 0, 2 and 4, with ReLUs in
        # between; these indices appear in the state_dict keys.
        self.encoder = nn.Sequential(
            nn.Linear(11, 64), nn.ReLU(),
            nn.Linear(64, 128), nn.ReLU(),
            nn.Linear(128, 256), nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 10),
        )

    def forward(self, x, t, u=None):
        """Return ``decoder(encoder(cat([x, t], -1)))``; ``u`` is ignored."""
        features = torch.cat([x, t], dim=-1)
        return self.decoder(self.encoder(features))
    
class LatentScoreOperator(nn.Module):
    """Container pairing a fully connected ``VAE`` with a latent ``ScoreNetwork_z``.

    ``forward`` is a no-op; callers presumably use ``self.vae`` and
    ``self.scorenet`` directly (TODO confirm against the training code).

    Args:
        x_dim, h_dim1, h_dim2, z_dim: Passed to ``VAE``.
        num_examples, num_samples: Stored unchanged as attributes.
        u_dim, u, condition: Passed to ``ScoreNetwork_z``.
        sigma_train: If True, register a learnable scalar ``self.sigma``
            initialised to 15.0.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim, num_examples, num_samples, u_dim=10, u=True, condition=False, sigma_train=False):
        super().__init__()
        self.vae = VAE(x_dim, h_dim1, h_dim2, z_dim)
        # The score network operates on the VAE latent, so its input width
        # must equal z_dim. Relying on the default of 10 broke any other z_dim.
        self.scorenet = ScoreNetwork_z(input_dim=z_dim, u_dim=u_dim, u=u, condition=condition)
        self.num_examples = num_examples
        self.num_samples = num_samples
        self.z_dim = z_dim
        self.sigma_train = sigma_train
        if sigma_train:
            self.sigma = torch.nn.parameter.Parameter(torch.tensor(15.0))

    def forward(self):
        """Placeholder: returns ``None``."""
        return

class LatentScoreOperator2D(nn.Module):
    """Container pairing a ``VAE2d64`` with a ``ScoreNetwork_z256`` latent score model.

    ``forward`` is a no-op; callers presumably use ``self.vae`` and
    ``self.scorenet`` directly (TODO confirm against the training code).

    Args:
        num_examples, num_samples: Stored unchanged as attributes.
        u_dim: Conditioning width passed to the score network. It is also
            stored as ``z_dim``.
    """

    def __init__(self, num_examples, num_samples, u_dim=256):
        super().__init__()
        # Keep the VAE before the score network: construction order determines
        # the seeded parameter initialisation and the state_dict ordering.
        self.vae = VAE2d64()
        self.scorenet = ScoreNetwork_z256(u_dim=u_dim)
        self.z_dim = u_dim
        self.num_samples = num_samples
        self.num_examples = num_examples

    def forward(self):
        """Placeholder: returns ``None``."""
        return None
