import torch
import torch.nn as nn

def kl_divergence(mu, log_var):
    """Return the KL divergence of N(mu, exp(log_var)) from a standard normal.

    The closed-form per-element term ``-0.5 * (1 + log_var - mu^2 - exp(log_var))``
    is averaged over *every* element of the inputs (no axis is singled out),
    so the result is a scalar tensor.
    """
    # Same operation order as the textbook expression, written with tensor methods.
    elementwise = 1 + log_var - mu.pow(2) - log_var.exp()
    return -0.5 * elementwise.mean()


class VAE(torch.nn.Module):
    """Variational autoencoder built around a caller-supplied encoder/decoder pair.

    The encoder output is passed through two linear heads that produce the mean
    and the log-variance of a diagonal Gaussian. A latent sample drawn from that
    Gaussian is what the decoder (and, during training, the simulator) receives.
    """

    def __init__(self, encoder, decoder, device):
        """Move the sub-modules to ``device`` and create the mu / log-var heads.

        Args:
            encoder: Module exposing a ``latent_dim`` attribute. Its output is
                fed to two ``latent_dim -> latent_dim`` linear layers.
            decoder: Module applied to the latent sample.
            device: Device on which all sub-modules are placed.
        """
        super().__init__()

        self.device = device
        self.encoder = encoder.to(device)
        self.decoder = decoder.to(device)
        self.latent_dim = encoder.latent_dim

        self.mu_layer = nn.Linear(self.latent_dim, self.latent_dim).to(device)
        self.var_layer = nn.Linear(self.latent_dim, self.latent_dim).to(device)

    def get_parameters(self):
        """Return optimizer parameter groups for the encoder, decoder and both heads."""
        return [
            {"params": self.encoder.parameters()},
            {"params": self.decoder.parameters()},
            {"params": self.mu_layer.parameters()},
            {"params": self.var_layer.parameters()},
        ]

    def encode(self, inputs):
        """Encode ``inputs`` and return ``(mu, log_var, epsilon)``.

        ``epsilon`` is the output of the sampling layer, i.e. the latent sample
        ``mu + std * noise`` with ``noise`` drawn from a standard normal.
        """
        encoded = self.encoder(inputs)

        mu = self.mu_layer(encoded)
        log_var = self.var_layer(encoded)
        # exp(0.5 * log_var) is the standard deviation, not the variance.
        std = torch.exp(0.5 * log_var)

        # randn_like samples directly on mu's device and dtype, avoiding a
        # host-side allocation followed by a device copy.
        epsilon = torch.randn_like(mu) * std + mu

        return mu, log_var, epsilon

    def decode(self, inputs):
        """Decode ``inputs`` (the latent sample produced by :meth:`encode`)."""
        return self.decoder(inputs)

    def forward(self, inputs):
        """Encode ``inputs``, keep only the latent sample, and decode it."""
        return self.decode(self.encode(inputs)[2])

    def _batch_losses(self, simulator, batch_x, batch_y):
        """Return the (reconstruction, simulator, KL) losses for a single batch."""
        mu, log_var, epsilon = self.encode(batch_x)
        x_hat = self.decode(epsilon)

        loss_l2 = (x_hat - batch_x).square().mean()
        loss_sim = (simulator(epsilon) - batch_y).square().mean()
        loss_kl = kl_divergence(mu, log_var)
        return loss_l2, loss_sim, loss_kl

    def _validation_losses(self, simulator, val_data, batch_size):
        """Return the batch-averaged (reconstruction, simulator, KL) losses on ``val_data``."""
        x_val, y_val = val_data
        n_samples_val = x_val.shape[0]
        # The permutation only affects how samples are grouped into batches.
        test_indices = torch.randperm(n_samples_val, device=self.device)

        l2_total = 0.0
        sim_total = 0.0
        kl_total = 0.0
        num_batches = 0

        with torch.no_grad():
            for start in range(0, n_samples_val, batch_size):
                batch_indices = test_indices[start:start + batch_size]
                loss_l2, loss_sim, loss_kl = self._batch_losses(
                    simulator, x_val[batch_indices], y_val[batch_indices]
                )

                l2_total += loss_l2.item()
                sim_total += loss_sim.item()
                kl_total += loss_kl.item()
                num_batches += 1

        return l2_total / num_batches, sim_total / num_batches, kl_total / num_batches

    def train_model(
        self,
        simulator,
        train_data: tuple[torch.Tensor, torch.Tensor],
        val_data=None,
        epochs=100,
        learning_rate=0.001,
        batch_size=256,
        log=True,
        L2_WEIGHT=1.0,
        SIM_WEIGHT=1.0,
        KL_WEIGHT=1.0,
    ):
        """Jointly train the VAE and ``simulator`` with Adam.

        Each batch is optimised on the weighted sum of three terms: the mean
        squared reconstruction error, the mean squared error between
        ``simulator(latent sample)`` and the targets, and the KL divergence.

        Args:
            simulator: Module called on the latent sample; its parameters are
                added to the optimizer alongside the VAE's own.
            train_data: ``(x_train, y_train)`` tensors indexed along dim 0.
                NOTE(review): batches are fed to the model as-is, so the data is
                presumably expected to already live on ``self.device`` - confirm.
            val_data: Optional ``(x_val, y_val)`` pair evaluated after every epoch.
            epochs: Number of passes over ``train_data``.
            learning_rate: Adam learning rate.
            batch_size: Number of samples per batch.
            log: When true, per-epoch losses are also written to a file named
                after the hyper-parameters, in the current working directory.
            L2_WEIGHT: Weight of the reconstruction loss.
            SIM_WEIGHT: Weight of the simulator loss.
            KL_WEIGHT: Weight of the KL loss.

        Returns:
            The log file name, or ``None`` when ``log`` is false.
        """
        filename = f"VAE_LR={learning_rate}_BS={batch_size}_E={epochs}" if log else None

        params = self.get_parameters()
        params.append({"params": simulator.parameters()})
        optimizer = torch.optim.Adam(params, lr=learning_rate)  # type: ignore

        x_train, y_train = train_data
        num_samples = x_train.shape[0]

        log_file = open(filename, "w", encoding="utf-8") if log else None
        try:
            for epoch in range(epochs):
                # Shuffle indices for every epoch
                indices = torch.randperm(num_samples, device=self.device)

                l2_loss_epoch = 0.0
                sim_loss_epoch = 0.0
                kl_loss_epoch = 0.0
                num_batches = 0

                for start in range(0, num_samples, batch_size):
                    batch_indices = indices[start:start + batch_size]
                    batch_x = x_train[batch_indices]
                    batch_y = y_train[batch_indices]

                    optimizer.zero_grad()

                    loss_l2, loss_sim, loss_kl = self._batch_losses(simulator, batch_x, batch_y)
                    output = loss_l2 * L2_WEIGHT + loss_sim * SIM_WEIGHT + loss_kl * KL_WEIGHT

                    output.backward()
                    optimizer.step()

                    l2_loss_epoch += loss_l2.item()
                    sim_loss_epoch += loss_sim.item()
                    kl_loss_epoch += loss_kl.item()
                    num_batches += 1

                # Every batch loss is already a mean over the batch; the values
                # reported per epoch are those means averaged over the batches.
                l2_loss_epoch /= num_batches
                sim_loss_epoch /= num_batches
                kl_loss_epoch /= num_batches

                print(f"Epoch [{epoch+1}/{epochs}], Rec loss: {l2_loss_epoch:.5f}, Sim loss: {sim_loss_epoch:.5f}, KL loss: {kl_loss_epoch:.5f}", end='')

                # Stays None when there is no validation set, so the logging
                # code below never touches unbound validation variables.
                val_losses = None
                if val_data is not None:
                    val_losses = self._validation_losses(simulator, val_data, batch_size)
                    val_loss_l2_epoch, val_loss_sim_epoch, val_loss_kl_epoch = val_losses

                    print(f"   ---   Val Rec loss: {val_loss_l2_epoch:.5f}, Val sim loss: {val_loss_sim_epoch:.5f}, Val KL loss: {val_loss_kl_epoch:.5f}", end='')

                print()

                if log_file is not None:
                    print(f"Epoch [{epoch+1}/{epochs}], Rec loss: {l2_loss_epoch:.4f}, Sim loss: {sim_loss_epoch:.4f}, KL loss: {kl_loss_epoch:.4f}", end='', file=log_file)

                    if val_losses is not None:
                        val_loss_l2_epoch, val_loss_sim_epoch, val_loss_kl_epoch = val_losses
                        print(f"   ---   Val Rec loss: {val_loss_l2_epoch:.4f}, Val sim loss: {val_loss_sim_epoch:.4f}, Val KL loss: {val_loss_kl_epoch:.4f}", end='', file=log_file)

                    print(file=log_file)
                    # Flush per epoch so progress survives an interrupted run.
                    log_file.flush()
        finally:
            # Close the log even when training raises or is interrupted.
            if log_file is not None:
                log_file.close()

        return filename
