import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Fully connected encoder with Tanh activations.

    Maps a flat vector of ``n_layers * (num_mat + 1)`` features (material
    block plus thickness for every layer) down to ``3 * n_layers`` latent
    features through hidden widths of 18, 15 and 9 times ``n_layers``.

    Args:
        n_layers: Number of layers described by one input vector; every
            width in the network scales with it.
        device: Device the module is moved to before weight initialisation.
        num_mat: Number of material features per layer. Defaults to 7, the
            value that used to be hard-coded.
    """

    def __init__(self, n_layers, device='cpu', num_mat=7):
        super().__init__()

        self.num_mat = num_mat
        self.input_size = n_layers * (num_mat + 1)  # Material + thickness
        self.latent_dim = 3 * n_layers

        # Creation order is kept stable so seeded initialisation is reproducible.
        self.input_layer = nn.Linear(self.input_size, 18 * n_layers)
        self.l1 = nn.Linear(18 * n_layers, 15 * n_layers)
        self.l2 = nn.Linear(15 * n_layers, 9 * n_layers)
        self.l3 = nn.Linear(9 * n_layers, self.latent_dim)

        self.activation = nn.Tanh()

        self.to(device)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every linear weight with Xavier-uniform (biases untouched)."""
        for layer in (self.input_layer, self.l1, self.l2, self.l3):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` (last dimension ``input_size``) into the latent vector.

        The final layer is linear: no activation is applied to the output.
        """
        x = self.activation(self.input_layer(x))
        x = self.activation(self.l1(x))
        x = self.activation(self.l2(x))
        return self.l3(x)
    
class Decoder(nn.Module):
    """Fully connected decoder with one material head per layer plus a thickness head.

    Expands a ``3 * n_layers`` latent vector through hidden widths of 9, 15
    and 18 times ``n_layers`` (Tanh activations). The output concatenates,
    along the last dimension, ``n_layers`` softmax blocks of ``num_mat``
    values each, followed by ``n_layers`` sigmoid thickness values.

    Args:
        n_layers: Number of layers to reconstruct; every width scales with it.
        device: Device the module is moved to before weight initialisation.
        num_mat: Number of material classes per layer. Defaults to 7, the
            value that used to be hard-coded.
    """

    def __init__(self, n_layers, device='cpu', num_mat=7):
        super().__init__()

        self.num_mat = num_mat

        # Creation order is kept stable so seeded initialisation is reproducible.
        self.l1 = nn.Linear(3 * n_layers, 9 * n_layers)
        self.l2 = nn.Linear(9 * n_layers, 15 * n_layers)
        self.l3 = nn.Linear(15 * n_layers, 18 * n_layers)

        # One independent classification head per layer.
        self.material_out = nn.ModuleList(
            [nn.Linear(18 * n_layers, num_mat) for _ in range(n_layers)]
        )
        self.thickness_out = nn.Linear(18 * n_layers, n_layers)

        self.activation = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()
        self.sig = nn.Sigmoid()

        self.to(device)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every linear weight with Xavier-uniform (biases untouched)."""
        for layer in [self.l1, self.l2, self.l3, self.thickness_out] + list(self.material_out):
            nn.init.xavier_uniform_(layer.weight)  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a latent vector into ``[material blocks..., thicknesses]``.

        Returns a tensor whose last dimension has
        ``n_layers * num_mat + n_layers`` entries.
        """
        x = self.activation(self.l1(x))
        x = self.activation(self.l2(x))
        x = self.activation(self.l3(x))

        # NOTE(review): ReLU is applied to the logits before softmax, which
        # floors every logit at 0. Kept as-is; confirm this is intentional.
        materials_out = [self.softmax(self.relu(head(x))) for head in self.material_out]
        thickness_out = self.sig(self.thickness_out(x))

        return torch.cat(materials_out + [thickness_out], dim=-1)
    


    
def train_AE(encoder, decoder, x_train, n_layer, epochs, learning_rate=0.001, device='cpu', L2_WEIGHT=3.0, REG_WEIGHT=1.0, batch_size=256, log=True):
    """Jointly train ``encoder`` and ``decoder`` as an autoencoder with Adam.

    Each row of ``x_train`` is read as per-layer material blocks followed by
    ``n_layer`` thickness values in the last ``n_layer`` columns. The width
    of a material block is inferred from the tensor shape. The per-batch
    loss is::

        mse(thickness) + nll(log(material probs), argmax(target block))
        + REG_WEIGHT * reg_term

    where ``reg_term`` is the negated mean (over the batch) of the summed
    squared material probabilities, divided by ``n_layer``.

    Args:
        encoder: Module mapping a batch to its latent representation.
        decoder: Module mapping the latent back to the input layout; its
            material outputs are treated as probabilities (their log is taken).
        x_train: 2-D tensor of training rows.
        n_layer: Number of layers, i.e. number of trailing thickness columns.
        epochs: Number of passes over ``x_train``.
        learning_rate: Adam learning rate.
        device: Device on which the shuffling permutation is generated.
        L2_WEIGHT: Only used to build the returned name.
            NOTE(review): it is not applied to any loss term; confirm whether
            it was meant to scale the MSE term.
        REG_WEIGHT: Weight of ``reg_term`` in the total loss.
        batch_size: Mini-batch size; the last batch may be smaller.
        log: When true, a run name is built and returned. No file is written
            by this function.

    Returns:
        The run name string when ``log`` is true, otherwise ``None``.
    """
    filename = None
    if log:
        filename = f"AE_LR={learning_rate}_BS={batch_size}_L2W-RW={L2_WEIGHT}-{REG_WEIGHT}"

    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=learning_rate)

    num_samples = x_train.shape[0]
    if num_samples == 0:
        # Nothing to train on; previously this crashed on an unbound `loss`
        # when the per-epoch message was printed.
        return filename

    # Batches are moved to wherever the model lives, so a CPU-resident
    # dataset can feed a model on another device.
    first_param = next(encoder.parameters(), None)
    model_device = first_param.device if first_param is not None else x_train.device

    # NOTE(review): neither module is switched to train mode here; callers
    # are responsible for calling .train() if the modules need it.
    for epoch in range(epochs):
        # Fresh shuffle every epoch. The permutation is generated on `device`
        # and then moved next to the data so that indexing never mixes devices.
        indices = torch.randperm(num_samples, device=device).to(x_train.device)

        for start in range(0, num_samples, batch_size):
            batch = x_train[indices[start:start + batch_size]].to(model_device)
            rows = batch.shape[0] * n_layer  # one row per (sample, layer) pair

            optimizer.zero_grad()
            x_hat = decoder(encoder(batch))

            material_hat = x_hat[:, :-n_layer]
            thickness_hat = x_hat[:, -n_layer:]

            # Negative sum of squared probabilities: minimising it pushes
            # each material distribution towards a single class.
            reg_term = torch.square(material_hat).sum(dim=1).mean().div(n_layer).mul(-1)
            loss_mse = torch.nn.functional.mse_loss(thickness_hat, batch[:, -n_layer:])

            # Infer the material block width from the shape instead of assuming 7.
            probs = material_hat.reshape(rows, -1)
            targets = batch[:, :-n_layer].reshape(rows, -1).argmax(dim=1)
            # Clamp before the log so a probability that underflowed to 0
            # cannot produce -inf losses and NaN gradients.
            log_probs = torch.log(probs.clamp_min(torch.finfo(probs.dtype).tiny))
            loss_ce = torch.nn.functional.nll_loss(log_probs, targets)

            loss = loss_mse + loss_ce + reg_term * REG_WEIGHT

            loss.backward()
            optimizer.step()

        # Reports the total loss of the last batch of the epoch.
        log_message = f"Epoch {epoch} Rec loss: {loss.item()}"
        print(log_message)

    return filename
