import torch
from torch import optim
import numpy as np
import random
from configs.config import EPOCHS, LEARNING_RATE
from src.models.model import ENFORCE

def train_model(
        model: "ENFORCE",
        train_inputs_tensor: torch.Tensor,
        train_outputs_tensor: torch.Tensor,
        batch_size: int,
        random_seed: int
        ) -> torch.nn.Module:
    """ Trains the model with mini-batch Adam using the provided data.

        The number of epochs and the learning rate are read from the
        module-level EPOCHS and LEARNING_RATE configuration values.
        One summary dict per epoch is appended to ``model.losses``.

        Parameters:
        model (ENFORCE): Model to train. This function uses its
            ``parameters()``, ``train()``, ``loss(inputs, outputs)`` (which must
            return a 7-tuple: total loss, data loss after projection,
            displacement loss, data loss before projection, projected
            prediction, raw prediction, projection iterations),
            ``unscale(inputs, outputs)``, ``c(x, y)``, ``constrained`` and
            ``losses``. If ``ssl_loss(x, y)`` is available and succeeds, the
            objective values it returns are tracked instead of the data losses.
        train_inputs_tensor (torch.Tensor): Inputs for training.
        train_outputs_tensor (torch.Tensor): Outputs for training.
        batch_size (int): Number of samples per mini-batch (the last batch of
            an epoch may be smaller).
        random_seed (int): Random seed for reproducibility.

        Returns:
        torch.nn.Module: Trained model (the same object that was passed in).

        Raises:
        ValueError: If batch_size is not positive or the training set is empty.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    num_samples = train_inputs_tensor.size(0)
    if num_samples == 0:
        raise ValueError("train_inputs_tensor is empty; there is nothing to train on")

    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed_all(random_seed)

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    epochs = EPOCHS
    # Value left on the model if the loop below never runs (EPOCHS == 0).
    model.epoch = 1
    # Becomes True (and stays True) once model.ssl_loss has been evaluated
    # successfully for a batch, i.e. the model learns an optimization problem.
    opt_problem = False
    num_batches = (num_samples + batch_size - 1) // batch_size  # Ceiling division

    for epoch in range(epochs):
        # model.epoch is the 0-based index of the running epoch while the
        # batches are processed, so model.loss() can read it.
        model.epoch = epoch

        # Shuffle the training data at the beginning of each epoch.
        # The seed changes each epoch so the shuffles differ but stay deterministic.
        torch.manual_seed(random_seed + epoch)
        permutation = torch.randperm(num_samples)
        train_inputs_shuffled = train_inputs_tensor[permutation]
        train_outputs_shuffled = train_outputs_tensor[permutation]

        model.train()

        total_loss = 0.0
        total_loss_displacement = 0.0
        total_loss_data_after_projection = 0.0
        total_loss_data_before_projection = 0.0
        constraint_residuals_avg = 0.0
        constraint_residuals_max = 0.0
        objective_value_optimization = 0.0
        objective_value_prediction = 0.0

        for start_idx in range(0, num_samples, batch_size):
            # Slicing past the end is clamped, so the last batch may be short.
            end_idx = start_idx + batch_size
            batch_inputs = train_inputs_shuffled[start_idx:end_idx]
            batch_outputs = train_outputs_shuffled[start_idx:end_idx]

            optimizer.zero_grad()
            (loss, loss_data_after_projection, loss_displacement,
             loss_data_before_projection, ytilde, yhat, proj_iter) = model.loss(batch_inputs, batch_outputs)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_loss_displacement += loss_displacement.item()

            x_unscaled, y_unscaled = model.unscale(batch_inputs, batch_outputs)
            _, ytilde_unscaled = model.unscale(batch_inputs, ytilde)
            _, yhat_unscaled = model.unscale(batch_inputs, yhat)

            constraint_residuals = torch.abs(model.c(x_unscaled, ytilde_unscaled))
            constraint_residuals_avg += torch.mean(constraint_residuals).item()
            constraint_residuals_max += torch.max(constraint_residuals).item()

            # Evaluate every SSL objective before touching the running totals:
            # if any call fails, nothing has been accumulated yet and the
            # regression fallback cannot double count this batch.
            try:  # Learn optimization problem with SSL
                obj_value_opt = model.ssl_loss(x_unscaled, y_unscaled).item()
                obj_value_pred = model.ssl_loss(x_unscaled, ytilde_unscaled).item()
                obj_value_before = model.ssl_loss(x_unscaled, yhat_unscaled).item()
            except Exception:  # Regression problem
                # `Exception` (not a bare except) so Ctrl-C / SystemExit still
                # stop the training.
                total_loss_data_after_projection += loss_data_after_projection.item()
                total_loss_data_before_projection += loss_data_before_projection.item()
            else:
                objective_value_optimization += obj_value_opt
                objective_value_prediction += obj_value_pred
                total_loss_data_after_projection += obj_value_pred
                total_loss_data_before_projection += obj_value_before
                opt_problem = True

        # Average metrics for this epoch (mean over batches).
        avg_loss = total_loss / num_batches
        avg_loss_displacement = total_loss_displacement / num_batches
        avg_loss_data_after_projection = total_loss_data_after_projection / num_batches
        avg_loss_data_before_projection = total_loss_data_before_projection / num_batches
        avg_constraint_residuals_avg = constraint_residuals_avg / num_batches
        # NOTE: mean over batches of the per-batch maximum, not a global maximum.
        avg_constraint_residuals_max = constraint_residuals_max / num_batches
        avg_objective_value_optimization = objective_value_optimization / num_batches
        avg_objective_value_prediction = objective_value_prediction / num_batches

        # Progress report: every epoch for optimization problems,
        # every 100th epoch otherwise.
        summary = (
            f'Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.2e}, '
            f'LossDispl:{avg_loss_displacement:.2e}, '
            f'EqMeanResidual: {avg_constraint_residuals_avg:.2e}, '
            f'EqMaxResidual: {avg_constraint_residuals_max:.2e}'
        )
        if opt_problem:
            print(
                summary
                + f', ObjValueOpt: {avg_objective_value_optimization:.2e}'
                + f', ObjValuePred: {avg_objective_value_prediction:.2e}'
            )
        elif (epoch + 1) % 100 == 0:
            print(summary)

        if not model.constrained:
            # Unconstrained case, only one loss
            model.losses.append({'loss_unconstrained': avg_loss_data_after_projection})
        else:
            epoch_record = {
                'loss_data_after_projection': avg_loss_data_after_projection,
                'loss_displacement': avg_loss_displacement,
                'loss_data_before_projection': avg_loss_data_before_projection,
                # Value reported by the last batch of the epoch.
                'projection_iterations': proj_iter
            }
            if opt_problem:
                epoch_record['objective_value_optimization'] = avg_objective_value_optimization
                epoch_record['objective_value_prediction'] = avg_objective_value_prediction
            model.losses.append(epoch_record)

        # After an epoch completes, model.epoch holds the number of finished epochs.
        model.epoch = epoch + 1

    return model
