import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau


class ForwardSimulator(nn.Module):
    """Fully-connected network mapping a layer/material encoding (or a VAE
    latent vector) to a 2001-dimensional output.

    Architecture: input -> 420 -> 640 -> 2001 -> 2001, with
    ``LeakyReLU(negative_slope=0.2)`` applied after every linear layer,
    including the last one.
    """

    def __init__(
        self,
        num_layer_material,
        num_materials,
        latent_dim=None,  # Needed when the simulator is used from the vae latent space
    ):
        """
        Args:
            num_layer_material: number of layers; stored as ``n_layer`` and used
                to size the default input.
            num_materials: number of candidate materials per layer.
            latent_dim: if given, the first layer takes inputs of this width
                instead of ``num_layer_material * (1 + num_materials)``.
        """
        super().__init__()

        self.n_layer = num_layer_material
        # Default input width: one value per layer plus one value per
        # (layer, material) pair (presumably a thickness + material one-hot
        # encoding — confirm against the data pipeline).
        self.n = num_layer_material + (num_layer_material * num_materials)

        self.act_function = nn.LeakyReLU(negative_slope=0.2)

        in_features = self.n if latent_dim is None else latent_dim
        self.l1 = nn.Linear(in_features, 420)
        self.l2 = nn.Linear(420, 640)
        self.l3 = nn.Linear(640, 2001)
        self.l4 = nn.Linear(2001, 2001)

    def forward(self, x):
        """Apply the four linear layers, each followed by the LeakyReLU."""
        x = self.act_function(self.l1(x))
        x = self.act_function(self.l2(x))
        x = self.act_function(self.l3(x))
        x = self.act_function(self.l4(x))
        return x

    @staticmethod
    def _sample_mse(outputs, target):
        """Return each sample's mean squared error (mean taken over dim 1)."""
        return (outputs - target).square().mean(dim=1)

    def _evaluate(self, loader):
        """Return the mean per-sample MSE over every sample yielded by *loader*.

        Runs in eval mode without gradients and restores the previous mode.
        """
        was_training = self.training
        self.eval()
        total = 0.0
        count = 0
        try:
            with torch.no_grad():
                for data, target in loader:
                    total += self._sample_mse(self(data), target).sum().item()
                    count += len(target)
        finally:
            self.train(was_training)
        return total / count if count else 0.0

    def train_model(self, x_train, y_train, x_val=None, y_val=None, num_epochs=50, batch_size=128, learning_rate=0.01, scheduler_threshold=0.001, log_file=False):
        """Train the network with Adam + ReduceLROnPlateau on a per-sample MSE loss.

        Args:
            x_train: training inputs, paired element-wise with ``y_train``
                along the first dimension.
            y_train: training targets.
            x_val: optional validation inputs.
            y_val: optional validation targets. Validation runs only when both
                ``x_val`` and ``y_val`` are given and non-empty.
            num_epochs: number of passes over the training set.
            batch_size: mini-batch size. An incomplete trailing training batch
                is dropped whenever at least one full batch exists; all
                validation samples are always evaluated.
            learning_rate: initial Adam learning rate (weight decay is 1e-3).
            scheduler_threshold: ``threshold`` passed to ReduceLROnPlateau.
            log_file: if True, write one line per epoch to
                ``<txt_file_name>.txt`` in the current working directory.

        Returns:
            ``((train_loss, val_loss), txt_file_name, (last_mse, last_val_mse))``.
            The loss lists hold one float per epoch (mean per-sample MSE).
            ``txt_file_name`` is the log file name without ``.txt``, or ``None``
            when not logging. The last tuple holds the final epoch's losses.
            Validation losses are 0.0 when no validation data is used.

        Raises:
            ValueError: if the training set is empty, or if inputs and targets
                have different lengths.
        """
        if len(x_train) != len(y_train):
            raise ValueError(
                f"x_train and y_train lengths differ: {len(x_train)} != {len(y_train)}"
            )
        if len(x_train) == 0:
            raise ValueError("x_train is empty")
        has_val = x_val is not None and y_val is not None
        if has_val and len(x_val) != len(y_val):
            raise ValueError(
                f"x_val and y_val lengths differ: {len(x_val)} != {len(y_val)}"
            )
        has_val = has_val and len(x_val) > 0

        self.train()
        txt_file = None
        txt_file_name = None

        if log_file:
            test_shape = 0 if x_val is None else x_val.shape[0]
            txt_file_name = f"Nf_{x_train.shape[0]}train_{test_shape}val_{self.n_layer}matlay_{num_epochs}e_{batch_size}b_{learning_rate}lr"
            txt_file = open(f"{txt_file_name}.txt", "w")

        try:
            optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate, weight_decay=1e-3)
            scheduler = ReduceLROnPlateau(optimizer, "min", patience=10, threshold=scheduler_threshold)

            train_loss = []
            val_loss = []

            train_pairs = list(zip(x_train, y_train))
            train_dl = torch.utils.data.DataLoader(
                train_pairs,  # type: ignore
                batch_size=batch_size,
                shuffle=True,
                # Train on full batches only, unless that would leave no batch.
                drop_last=len(train_pairs) >= batch_size,
            )

            test_dl = None
            if has_val:
                # No drop_last: every validation sample must be evaluated.
                test_dl = torch.utils.data.DataLoader(
                    list(zip(x_val, y_val)), batch_size=batch_size  # type: ignore
                )

            last_mse = 0.0
            last_val_mse = 0.0
            for epoch in range(num_epochs):
                loss_sum = 0.0
                n_seen = 0
                for data, target in train_dl:
                    optimizer.zero_grad()
                    outputs = self(data)
                    # Batch sum of per-sample MSEs (gradient scale grows with batch size).
                    loss_value = self._sample_mse(outputs, target).sum()
                    loss_value.backward()
                    optimizer.step()

                    loss_sum += loss_value.item()
                    n_seen += len(target)
                # Normalize by the samples actually used, not len(loader) * batch_size.
                epoch_loss = loss_sum / n_seen

                epoch_test_loss = 0.0
                if test_dl is not None:
                    epoch_test_loss = self._evaluate(test_dl)

                # Without validation data the val loss is a constant 0.0, which
                # the scheduler would read as a plateau and keep decaying the LR;
                # monitor the training loss instead in that case.
                scheduler.step(epoch_test_loss if test_dl is not None else epoch_loss)

                train_loss.append(epoch_loss)
                val_loss.append(epoch_test_loss)

                if txt_file is not None:
                    print(f"Epoch [{epoch + 1}/{num_epochs}], MSE: {epoch_loss:.6f}, Val MSE: {epoch_test_loss:.6f}", file=txt_file)
                    txt_file.flush()

                last_mse = epoch_loss
                last_val_mse = epoch_test_loss
        finally:
            if txt_file is not None:
                txt_file.close()

        return (train_loss, val_loss), txt_file_name, (last_mse, last_val_mse)