import torch
import torch.nn as nn
from typing import Literal
from torch.optim.lr_scheduler import ReduceLROnPlateau

def Nf_layer_config(conf_type, n):
    """Return the widths of the three dense layers for a given configuration.

    Args:
        conf_type: Configuration id. ``1`` scales every width with ``n``
            (``2n``, ``4n``, ``5n``); ``2`` uses ``int(1.4 * n)`` followed by
            the fixed widths ``50`` and ``60``.
        n: Base width the configuration is derived from.

    Returns:
        A list with the three layer widths, in order.

    Raises:
        ValueError: If ``conf_type`` is not one of the known configurations.
    """
    if conf_type == 1:
        return [n*2, n*4, n*5]

    if conf_type == 2:
        return [int(n*1.4), 50, 60]

    # ValueError is still an Exception, so callers catching the old generic
    # Exception keep working while new callers can be more specific.
    raise ValueError("Invalid layer config provided")


class ForwardSimulator(nn.Module):
    def __init__(
            self, 
            num_layer_material, 
            layer_config = 1, 
            latent_dim = None,
            activation: Literal['relu', 'tanh'] = 'relu', 
            initialization: Literal['glorot_normal', 'random_normal'] = 'glorot_normal'
        ):
        super().__init__()

        self.n_layer = num_layer_material
        self.n = num_layer_material + (num_layer_material * 5)
        self.layer_config = layer_config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if activation == 'relu':
            self.act_function = nn.LeakyReLU()
        if activation == 'tanh':
            self.act_function = nn.Tanh()

        if initialization == 'glorot_normal':
            self.init_function = nn.init.xavier_normal_
        if initialization == 'random_normal':
            self.init_function = nn.init.normal_

        self.__construct_model__(latent_dim)

    # Building block for 1 out of the 4 convolutional heads of the simulator
    def __conv__block__(self, dense_dim):
        # Keep track of the tensor dimension
        filter_size = 4

        dense_dim = dense_dim# - (filter_size - 1)
        c1 = nn.Conv1d(1, 64, filter_size, padding='same')
        #self.init_function(c1.weight)

        dense_dim = int(dense_dim / 2)
        p1 = nn.MaxPool1d(2)

        dense_dim = dense_dim# - (filter_size - 1)
        c2 = nn.Conv1d(64, 128, filter_size, padding='same')
        #self.init_function(c2.weight)

        dense_dim = int(dense_dim / 2)
        p2 = nn.MaxPool1d(2)

        f = nn.Flatten()
        d = nn.Linear(128 * dense_dim, 600)
        #self.init_function(d.weight)

        return [c1, p1, c2, p2, f, d]

 
    def __construct_model__(self, latent_dim = None):
        layer_config = Nf_layer_config(self.layer_config, self.n)

        # Construct the first 3 dense layers of the model based on the chosen configuration
        self.dense_block = nn.ModuleList([
            nn.Linear(self.n if latent_dim == None else latent_dim, layer_config[0]),
            nn.Linear(layer_config[0], layer_config[1]),
            nn.Linear(layer_config[1], layer_config[2]),
        ])

        # Apply initialization function to layers
        for l in self.dense_block:
            self.init_function(l.weight) # type: ignore

        # 4 parallel heads for reflectance p,s and transmittance p,s
        # Each head has 600 output neuros, which contains 3x200 points, one for each angle
        self.rs = nn.ModuleList(self.__conv__block__(layer_config[2]))
        self.rp = nn.ModuleList(self.__conv__block__(layer_config[2]))
        self.ts = nn.ModuleList(self.__conv__block__(layer_config[2]))
        self.tp = nn.ModuleList(self.__conv__block__(layer_config[2]))


    def forward(self, x):
        # From input to output of third dense layer
        # dense_block is a list of Linear layers
        for layer in self.dense_block:
            x = self.act_function(layer(x))

        # From BATCH_SIZE x NUM_NEURON to BATCH_SIZE x 1 x NUM_NEURON
        # Needed for 1d Convolution 
        x = torch.unsqueeze(x, 1)

        # From output of dense block, to first convolutional layer
        rs = self.act_function(self.rs[0](x))
        rp = self.act_function(self.rp[0](x))
        ts = self.act_function(self.ts[0](x))
        tp = self.act_function(self.tp[0](x))
    
        for l in range(1, len(self.rs) - 1):
            # Apply the l-th layer to rs, rp... until the Flatten layer
            rs = self.act_function(self.rs[l](rs))
            rp = self.act_function(self.rp[l](rp))
            ts = self.act_function(self.ts[l](ts))
            tp = self.act_function(self.tp[l](tp))

        # Apply the last dense layer, with the activation function
        rs = self.rs[-1](rs)
        rp = self.rp[-1](rp)
        ts = self.ts[-1](ts)
        tp = self.tp[-1](tp)

        return torch.cat([rs, rp, ts, tp], -1)
    

    
    def get_loss(self, y_output, y_expected):
        # Get the loss for all 12 heads in batch i
        losses = []
        loss_fn = torch.nn.MSELoss()
        start = 0
        end = 200
        while end <= 2400:
            slice_y = y_output[:, start:end]
            slice_y_exp = y_expected[:, start:end]
            
            loss = loss_fn(slice_y, slice_y_exp).item()
            losses.append(loss)
            
            start += 200
            end += 200

            
        return losses

    def train_model(self, x_train, y_train, x_val=None, y_val=None, num_epochs=50, batch_size=128, learning_rate=0.01, scheduler_threshold=0.001, log_file=False):
        txt_file = None
        txt_file_name = None
        if log_file:
            test_shape = 0 if x_val is None else x_val.shape[0]
            txt_file_name = f"Nf_{x_train.shape[0]}train_{test_shape}val_{self.n_layer}matlay_{num_epochs}e_{batch_size}b_{learning_rate}lr"
            txt_file = open(f"{txt_file_name}.txt", "w")

        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        scheduler = ReduceLROnPlateau(optimizer, "min", patience=10, threshold=scheduler_threshold)
        
        train_loss = []
        test_loss = []
        
        train_heads_losses = []
        test_heads_losses = []

        train_dl = torch.utils.data.DataLoader(
            list(zip(x_train, y_train)), batch_size=batch_size, shuffle=True, drop_last=False # type: ignore
        )

        test_dl = None
        if x_val is not None and y_val is not None:
            test_dl = torch.utils.data.DataLoader(
                list(zip(x_val, y_val)), batch_size=batch_size, drop_last=False # type: ignore
            )
        
        for epoch in range(num_epochs):
            epoch_loss = 0.0
            epoch_test_loss = 0.0
            epoch_mse = 0.0
            epoch_test_mse = 0.0
            epoch_rmse = 0.0
            epoch_test_rmse = 0.0
            
            train_heads_losses.append([0.0 for _ in range(12)])
            test_heads_losses.append([0.0 for _ in range(12)])

            for data, target in train_dl:
                optimizer.zero_grad()
                # This returns (rs, rp, ts, tp)
                # Each head has length 3x200
                outputs = self(data)

                loss_value = torch.norm(outputs - target, p=1)
                mse = nn.functional.mse_loss(outputs, target)
                rmse = (outputs - target).square().mean().sqrt()
                
                # Compute the individual loss for every head of the model
                head_losses = self.get_loss(outputs, target)
                for i in range(0, 12):
                    train_heads_losses[-1][i] += head_losses[i]
                
                loss_value.backward()
                optimizer.step()

                epoch_loss += loss_value.item()
                epoch_mse += mse.item()
                epoch_rmse += rmse.item()

            if test_dl != None:
                with torch.no_grad():
                    for data2, target2 in test_dl:
                        output_cat = self(data2)
                        epoch_test_loss += torch.norm(output_cat - target2, p=1).item()
                        epoch_test_mse += nn.functional.mse_loss(output_cat, target2).item()
                        epoch_test_rmse += (output_cat - target2).square().mean().sqrt()
                        
                        head_losses = self.get_loss(output_cat, target2)
                        for i in range(0, 12):
                            test_heads_losses[-1][i] += head_losses[i]


            epoch_mse /= len(train_dl)
            epoch_loss /= len(train_dl)
            epoch_rmse /= len(train_dl)

            if test_dl != None:
                epoch_test_mse /= len(test_dl)
                epoch_test_loss /= len(test_dl)
                epoch_test_rmse /= len(test_dl)
            
            
            scheduler.step(epoch_test_mse)

            train_loss.append(epoch_loss)  
            test_loss.append(epoch_test_loss)
        
            print(f"Epoch [{epoch + 1}/{num_epochs}],      Loss: {epoch_loss:0f},       MSE: {epoch_mse:6f},       RMSE: {epoch_rmse:6f},       lr: {scheduler.get_last_lr()}")

            if test_dl != None:
                print(f"{' ' * len(str(epoch))}               Val loss: {epoch_test_loss:0f},   Val MSE: {epoch_test_mse:6f},   Val RMSE: {epoch_test_rmse:6f}")

            print()


            if txt_file != None:
                print(f"Epoch [{epoch + 1}/{num_epochs}],      Loss: {epoch_loss:0f},       MSE: {epoch_mse:6f},       RMSE: {epoch_rmse:6f},       lr: {scheduler.get_last_lr()}", file=txt_file)

                if test_dl != None:
                    print(f"{' ' * len(str(epoch))}                Val loss: {epoch_test_loss:0f},  Val MSE: {epoch_test_mse:6f},  Val RMSE: {epoch_test_rmse:6f}", file=txt_file)

                print(file=txt_file)

                txt_file.flush()


            
        # Compute the mean (on the number of batches) for every loss
        for i in range(0, num_epochs):
            for j in range(12):
                train_heads_losses[i][j] /= len(train_dl)
                
                if test_dl != None:
                    test_heads_losses[i][j] /= len(test_dl)
                    
        train_heads_losses = torch.tensor(train_heads_losses, dtype=torch.float32)
        test_heads_losses = torch.tensor(test_heads_losses, dtype=torch.float32)

        return (train_loss, test_loss, train_heads_losses, test_heads_losses), txt_file_name if txt_file != None else None
    