import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Fully connected encoder that maps a flattened stack description to a latent code.

    The expected input width is ``n_layers * (n_materials + 1)`` (material
    entries plus one thickness value per layer, per the size computation
    below).  The network narrows through hidden widths of 18, 15 and 9 times
    ``n_layers`` and emits a latent vector of width ``3 * n_layers``.

    Args:
        n_layers: Number of layers in the described stack; scales every width.
        n_materials: Number of material entries per layer.
        device: Device the module's parameters are moved to on construction.
    """

    def __init__(self, n_layers, n_materials, device='cpu'):
        super().__init__()

        # Per layer: n_materials material entries + 1 thickness entry.
        self.input_size = n_layers * (n_materials + 1)

        wide = 18 * n_layers
        mid = 15 * n_layers
        narrow = 9 * n_layers
        latent = 3 * n_layers

        # Attribute names and creation order are kept stable so state_dict
        # keys and seeded initialisation stay reproducible.
        self.input_layer = nn.Linear(self.input_size, wide)
        self.l1 = nn.Linear(wide, mid)
        self.l2 = nn.Linear(mid, narrow)
        self.l3 = nn.Linear(narrow, latent)

        self.activation = nn.Tanh()

        self.to(device)
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every linear weight matrix with Xavier-uniform values.

        Biases keep the default ``nn.Linear`` initialisation.
        """
        for linear in (self.input_layer, self.l1, self.l2, self.l3):
            nn.init.xavier_uniform_(linear.weight)

    def forward(self, x):
        """Encode ``x`` and return the latent vector (last dim ``3 * n_layers``).

        Tanh follows every linear layer except the final one, so the latent
        code itself is unbounded.
        """
        hidden = x
        for linear in (self.input_layer, self.l1, self.l2):
            hidden = self.activation(linear(hidden))

        return self.l3(hidden)
    
class Decoder(nn.Module):
    """Fully connected decoder that expands a latent code back into a stack description.

    The latent input has width ``3 * n_layers``.  It is widened through hidden
    layers of 9, 15 and 18 times ``n_layers`` and then fed to separate output
    heads: one ``n_materials``-wide head per layer plus a single thickness
    head of width ``n_layers``.

    Args:
        n_layers: Number of layers in the described stack; scales every width.
        n_materials: Number of material classes predicted per layer.
        device: Device the module's parameters are moved to on construction.
    """

    def __init__(self, n_layers, n_materials, device='cpu'):
        super().__init__()

        latent = 3 * n_layers
        narrow = 9 * n_layers
        mid = 15 * n_layers
        wide = 18 * n_layers

        # Attribute names and creation order are kept stable so state_dict
        # keys and seeded initialisation stay reproducible.
        self.l1 = nn.Linear(latent, narrow)
        self.l2 = nn.Linear(narrow, mid)
        self.l3 = nn.Linear(mid, wide)

        # One classification head per layer, then a shared thickness head.
        material_heads = []
        for _ in range(n_layers):
            material_heads.append(nn.Linear(wide, n_materials))
        self.material_out = nn.ModuleList(material_heads)
        self.thickness_out = nn.Linear(wide, n_layers)

        self.activation = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()

        self.to(device)
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every linear weight matrix with Xavier-uniform values.

        The trunk and thickness head are handled first, followed by the
        material heads.  Biases keep the default ``nn.Linear`` initialisation.
        """
        linears = [self.l1, self.l2, self.l3, self.thickness_out, *self.material_out]
        for linear in linears:
            nn.init.xavier_uniform_(linear.weight)  # type: ignore

    def forward(self, x):
        """Decode latent ``x`` into ``[material probabilities | thicknesses]``.

        Returns a tensor whose last dimension holds, in order, the softmax
        output of each material head (``n_layers`` groups of ``n_materials``)
        followed by ``n_layers`` tanh-activated thickness values.
        """
        hidden = x
        for linear in (self.l1, self.l2, self.l3):
            hidden = self.activation(linear(hidden))

        # ReLU is applied to each head's logits before the softmax.
        pieces = [self.softmax(self.relu(head(hidden))) for head in self.material_out]
        pieces.append(self.activation(self.thickness_out(hidden)))

        return torch.cat(pieces, dim=-1)
    

def train_AE(encoder, decoder, x_train, n_layer, epochs, learning_rate=0.001, device='cpu', L2_WEIGHT=3, REG_WEIGHT=1, batch_size=256, log=True):
    """Jointly train ``encoder`` and ``decoder`` as an autoencoder using Adam.

    Every row of ``x_train`` is split into a leading material block and the
    trailing ``n_layer`` thickness values.  The material block is viewed as
    ``n_layer`` groups of equal width; the index of the largest entry in each
    group is the classification target (so a one-hot encoding is presumably
    expected -- confirm against the data pipeline).

    The per-batch loss is the sum of:
      * MSE between reconstructed and true thickness values,
      * negative log-likelihood of the reconstructed material probabilities,
      * ``REG_WEIGHT`` times a negative mean-squared-probability term, which
        rewards confident (peaked) material distributions.

    Args:
        encoder: Module mapping a batch of rows to latent codes.
        decoder: Module mapping latent codes back to rows laid out like
            ``x_train``; its material block must contain probabilities.
        x_train: 2-D tensor of training rows.
        n_layer: Number of layers, i.e. number of trailing thickness columns.
        epochs: Number of passes over ``x_train``.
        learning_rate: Adam learning rate.
        device: Device the models and each mini-batch are moved to.
        L2_WEIGHT: Only recorded in the log filename.
            NOTE(review): this weight is never applied to any loss term; if
            it was meant to scale the thickness MSE, that needs wiring up.
        REG_WEIGHT: Multiplier for the regularisation term.
        batch_size: Mini-batch size.
        log: When true, progress lines are appended to ``<filename>.txt``.

    Returns:
        The log filename stem (without ``.txt``) when ``log`` is true,
        otherwise ``None``.

    Raises:
        ValueError: If ``x_train`` is empty, or its width cannot be split
            into ``n_layer`` equal material groups plus ``n_layer``
            thickness columns.
    """
    num_samples = x_train.shape[0]
    if num_samples == 0:
        raise ValueError("x_train must contain at least one sample")

    # Derive the number of material classes from the data layout instead of
    # assuming a fixed count.
    material_width = x_train.shape[1] - n_layer
    if n_layer <= 0 or material_width <= 0 or material_width % n_layer != 0:
        raise ValueError(
            f"x_train width {x_train.shape[1]} cannot be split into {n_layer} "
            f"material groups plus {n_layer} thickness columns"
        )
    n_materials = material_width // n_layer

    filename = None
    if log:
        filename = f"AE_LR={learning_rate}_BS={batch_size}_L2W-RW={L2_WEIGHT}-{REG_WEIGHT}"

    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=learning_rate)

    encoder.train()
    decoder.train()

    encoder.to(device)
    decoder.to(device)

    for epoch in range(epochs):
        # Fresh shuffle every epoch.  The permutation lives on the same device
        # as x_train because it indexes x_train *before* the batch is moved to
        # the training device.
        indices = torch.randperm(num_samples, device=x_train.device)

        for start in range(0, num_samples, batch_size):
            batch_indices = indices[start:start + batch_size]
            batch = x_train[batch_indices].to(device, dtype=torch.float32)

            optimizer.zero_grad()
            x_hat = decoder(encoder(batch))

            material_probs = x_hat[:, :-n_layer]

            # Negative sign: minimising this term maximises the squared
            # probabilities, pushing each distribution towards one class.
            reg_term = torch.square(material_probs).sum(dim=1).mean().div(n_layer).mul(-1)
            loss_mse = torch.nn.functional.mse_loss(x_hat[:, -n_layer:], batch[:, -n_layer:])

            # Clamp before the log so an underflowed probability of exactly 0
            # cannot produce -inf / NaN in the loss or its gradient.
            tiny = torch.finfo(material_probs.dtype).tiny
            log_probs = torch.log(torch.clamp(material_probs.reshape(-1, n_materials), min=tiny))
            targets = batch[:, :-n_layer].reshape(-1, n_materials).argmax(dim=1)
            loss_ce = torch.nn.functional.nll_loss(log_probs, targets)

            output = loss_mse + loss_ce + reg_term * REG_WEIGHT
            output.backward()
            optimizer.step()

        if epoch % 5 == 0:
            # Reports the loss of the last mini-batch of the epoch.
            log_message = f"Epoch {epoch} Loss: {output.item()}"

            print(log_message)
            if filename is not None:
                with open(f"{filename}.txt", "a", encoding="utf-8") as f:
                    f.write(log_message + "\n")

    return filename
