# main.py
import os
import numpy as np
import torch
import random
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
import json
import time

from configs.config import (
    PLOT, SAVE, N, DATA_TRAINING, DATA_TEST, LEFT_LIMIT, RIGHT_LIMIT, OUTPUT_NEURONS,
    EPOCHS, SOFT_CONSTRAINED, WEIGHT_LOSS_SOFT, TRAINING_DIR, FIX_SEED, PROJ_WEIGHTING_OPTION,
    INPUT_DATA_PATH, OUTPUT_DATA_PATH, PARAMS_PATH, SUPERVISED, BATCH_SIZE, PROBLEM, MODEL
)
from src.constraints.functions import get_functions
from src.constraints.constraints import get_constraints
from src.data.data_utils import generate_data, scale_data
from src.models.model import ENFORCE
from src.engines.train import train_model
from src.engines.evaluate import evaluate_model
from src.visualization.plotting import plot_all_results
from src.constraints.opt_problem import Dataloader, NonlinearProgram, NonconvexProgram
from src.models.ssl_loss import SSLLoss, SSLConfig

import torch.autograd.profiler as profiler
PROFILER = False  # Flip to True to print a torch.autograd profiler table for every training run

# One fixed seed for the NumPy, PyTorch and Python RNGs, so that data generation and
# the per-run seeds sampled below are reproducible.
data_seed = 41
np.random.seed(data_seed)
torch.manual_seed(data_seed)
random.seed(data_seed)

# `seeds` is the single int `data_seed` (shared by every run) when FIX_SEED is set,
# otherwise a list with one randomly drawn seed per run (length N).
seeds = data_seed if FIX_SEED else [random.randint(0, 100000) for _ in range(N)]

# Maps the MODEL config value to the training modes it selects (in run order).
_MODEL_TO_MODES = {
    "BOTH": ["constrained", "unconstrained"],
    "ENFORCE": ["constrained"],
    "MLP": ["unconstrained"],
}

# Header of the readme.txt written into every saved run directory; the source of the
# config module used for the run is appended after it.
_README_TEXT = """\
Training Output Data

This directory contains CSV files and plots generated from the training runs.

Files (paths are relative to this directory; <mode> is "constrained" or
"unconstrained", X is the 1-based run index):

- readme.txt: This file, followed by a copy of the configuration used for the run.
- scaling.json: Input/output scaling parameters (mean and std) used for the data.
- model_<timestamp>_<mode>_run<i>.pth: Trained model of each run (i is the 0-based run index).
- losses/<mode>_losses_run_X.csv: Per-epoch losses for run X.
- <mode>_mean_losses.csv: Per-epoch mean and standard deviation of the losses over all runs.
- <mode>_metrics_runs.csv: Test metrics, training time and seed for each run.
- <mode>_mean_metrics.csv: Mean and standard deviation of each metric over all runs.
- test_dataset/<mode>_test_data.csv: Test inputs and true outputs (same across runs).
- test_predictions/<mode>_test_predictions_run_X.csv: Test inputs and predictions for run X
  (constrained mode additionally contains the predictions before projection).

Plots (generated when plotting is enabled):

- losses_comparison.png: Plot comparing losses between constrained and unconstrained modes.
- dataset_distribution.png: Plot of training and test data distributions.
- constrained_prediction_projection.png: Plots of predictions and projections in constrained mode.
- unconstrained_prediction_projection.png: Plots of predictions in unconstrained mode.
- constrained_constraint.png: Constraint plot in constrained mode.
- unconstrained_constraint.png: Constraint plot in unconstrained mode.
- constrained_parity_plots.png: Parity plots in constrained mode.
- unconstrained_parity_plots.png: Parity plots in unconstrained mode.
- parity_plots_comparison.png: Comparison of parity plots between constrained and unconstrained modes.

All data required to reproduce the plots are provided in the CSV files.

Note: The same set of seeds was used for corresponding runs in both modes, ensuring consistent initialization and data shuffling for fair comparison. The seeds used are included in the metrics CSV files for reference.
"""


def _resolve_modes(model):
    """Return the list of training modes selected by the ``MODEL`` config value.

    Raises:
        ValueError: if ``model`` is not 'ENFORCE', 'MLP' or 'BOTH'.
    """
    modes = _MODEL_TO_MODES.get(model)
    if modes is None:
        raise ValueError("Invalid model. Choose 'ENFORCE', 'MLP', or 'BOTH'.")
    return list(modes)


def _write_readme(output_dir):
    """Write ``readme.txt`` into ``output_dir``: the file description plus the config source."""
    import importlib

    # Read the config module that was actually imported, instead of a cwd-relative
    # path, so this also works when the script is launched from another directory.
    config_path = importlib.import_module("configs.config").__file__
    with open(config_path, "r", encoding="utf-8") as file:
        config_content = file.read()

    with open(os.path.join(output_dir, "readme.txt"), "w", encoding="utf-8") as f:
        f.write(_README_TEXT + "\n\n" + config_content)


def _to_jsonable(value):
    """Convert NumPy scalars/arrays to plain Python types so ``json.dump`` accepts them."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    return value


def _mean_std_over_runs(losses_runs, key):
    """Return the per-epoch mean and std of ``key`` across runs.

    ``losses_runs`` holds one list of per-epoch records per run; each record is
    indexed by ``key``. All runs must have the same number of epochs.
    NOTE: this is NumPy's population std (ddof=0), whereas the metrics summary uses
    pandas' sample std (ddof=1).
    """
    values = np.array([[epoch_losses[key] for epoch_losses in run_losses] for run_losses in losses_runs])
    return np.mean(values, axis=0), np.std(values, axis=0)


def _save_metrics(metrics_runs, run_seeds, output_dir, mode):
    """Write per-run metrics (with their seeds) and a mean/std summary for ``mode`` to CSV."""
    df_metrics = pd.DataFrame(metrics_runs)
    df_metrics['seed'] = run_seeds  # int (FIX_SEED) broadcasts; otherwise one seed per run
    df_metrics.to_csv(os.path.join(output_dir, f"{mode}_metrics_runs.csv"), index=False)

    # Summarize the metrics only (the seed column is deliberately not included).
    metrics_frame = pd.DataFrame(metrics_runs)
    mean_metrics = metrics_frame.mean()
    std_metrics = metrics_frame.std()
    pd.DataFrame({
        'metric': mean_metrics.index,
        'mean': mean_metrics.values,
        'std': std_metrics.values,
    }).to_csv(os.path.join(output_dir, f"{mode}_mean_metrics.csv"), index=False)


def _save_losses(losses_runs, constrained, mode, output_dir, losses_dir):
    """Write each run's loss history and the per-epoch mean/std summary for ``mode``.

    Constrained runs record several loss components (plus objective values when not
    SUPERVISED); unconstrained runs record only ``loss_unconstrained``.
    """
    epochs = np.arange(1, EPOCHS + 1)

    for run_idx, losses in enumerate(losses_runs, start=1):
        df_losses = pd.DataFrame(losses)
        df_losses['epoch'] = epochs
        df_losses.to_csv(os.path.join(losses_dir, f"{mode}_losses_run_{run_idx}.csv"), index=False)

    if constrained:
        loss_keys = [
            'loss_data_after_projection',
            'loss_data_before_projection',
            'loss_displacement',
            'projection_iterations',
        ]
        if not SUPERVISED:
            loss_keys += ['objective_value_optimization', 'objective_value_prediction']
    else:
        loss_keys = ['loss_unconstrained']

    # Columns: epoch, then mean_<key>, std_<key> for every key in order.
    summary = {'epoch': epochs}
    for key in loss_keys:
        summary[f'mean_{key}'], summary[f'std_{key}'] = _mean_std_over_runs(losses_runs, key)
    pd.DataFrame(summary).to_csv(os.path.join(output_dir, f"{mode}_mean_losses.csv"), index=False)


def _save_test_predictions(mode, constrained, test_inputs, test_outputs, test_predictions_list,
                           test_prediction_before_projection_list, output_dir, test_set_dir):
    """Write the test set once and the (unscaled) predictions of every run for ``mode``."""
    # The test data is identical across runs, so it is written once per mode.
    df_test_inputs = pd.DataFrame(test_inputs, columns=[f'x{i+1}' for i in range(test_inputs.shape[1])])
    df_test_outputs = pd.DataFrame(test_outputs, columns=[f'y{i+1}_true' for i in range(test_outputs.shape[1])])
    pd.concat([df_test_inputs, df_test_outputs], axis=1).to_csv(
        os.path.join(test_set_dir, f"{mode}_test_data.csv"), index=False)

    predictions_dir = os.path.join(output_dir, "test_predictions")
    for run_idx, predictions in enumerate(test_predictions_list):
        df_predictions = pd.DataFrame(predictions, columns=[f'y{i+1}_pred' for i in range(predictions.shape[1])])
        frames = [df_test_inputs, df_predictions]
        if constrained:
            predictions_before = test_prediction_before_projection_list[run_idx]
            frames.append(pd.DataFrame(
                predictions_before,
                columns=[f'y{i+1}_pred_before' for i in range(predictions_before.shape[1])]))
        pd.concat(frames, axis=1).to_csv(
            os.path.join(predictions_dir, f"{mode}_test_predictions_run_{run_idx+1}.csv"), index=False)


def main():
    """Run the configured experiment end to end.

    Resolves the modes to train from ``MODEL``, generates or loads the data for
    ``PROBLEM``, trains and evaluates ``N`` seeded runs per mode, and, depending on
    ``SAVE`` / ``PLOT``, writes models, metrics, losses and predictions to a
    timestamped run directory and plots the results.

    Raises:
        ValueError: if ``MODEL`` is invalid, or if self-supervised training
            (``SUPERVISED = False``) is requested for a problem that has no
            optimization program to build the SSL loss from.
    """
    opt_checks = ['nonconvex_linear', 'nonconvex_nonlinear']

    modes_to_run = _resolve_modes(MODEL)
    if not SUPERVISED and PROBLEM not in opt_checks:
        # The SSL loss is only built for optimization problems (below); fail before
        # creating any output instead of crashing in the first training run.
        raise ValueError(
            f"SUPERVISED=False requires PROBLEM to be one of {opt_checks}, got {PROBLEM!r}.")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using {device} device")

    # Update output directory
    training_dir = os.path.join(TRAINING_DIR, PROBLEM)
    if not SUPERVISED:
        if PROBLEM == 'nonconvex_linear':
            training_dir = os.path.join(training_dir, 'nonconvex')
        elif PROBLEM == 'nonconvex_nonlinear':
            training_dir = os.path.join(training_dir, 'nonlinear')

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(training_dir, f"run_{timestamp}")
    losses_dir = os.path.join(output_dir, "losses")
    test_set_dir = os.path.join(output_dir, "test_dataset")
    if SAVE:
        os.makedirs(output_dir, exist_ok=True)
        print(f"Output directory: {output_dir}")
        for sub_dir in (losses_dir, test_set_dir,
                        os.path.join(output_dir, "plots"),
                        os.path.join(output_dir, "test_predictions")):
            os.makedirs(sub_dir, exist_ok=True)
        _write_readme(output_dir)

    # Get data, constraints and (for optimization problems) the Jacobian and SSL loss.
    jac = None
    ssl_loss_function = None
    if PROBLEM not in opt_checks:
        functions = get_functions()
        c = get_constraints()

        # Generate data (np.ndarray) --> here data can also be loaded from a file
        train_inputs, train_outputs = generate_data(
            functions, DATA_TRAINING, LEFT_LIMIT, RIGHT_LIMIT)
        test_inputs, test_outputs = generate_data(
            functions, DATA_TEST, LEFT_LIMIT, RIGHT_LIMIT)
    else:
        data = Dataloader(
            input_path=INPUT_DATA_PATH,
            output_path=OUTPUT_DATA_PATH,
            t_v_t_ratio=(10., 1., 1.)
        ).get_data()
        train_inputs = data["train_inputs"]
        train_outputs = data["train_outputs"]
        test_inputs = data["test_inputs"]
        test_outputs = data["test_outputs"]

        # Load constraints
        program_classes = {
            'nonconvex_linear': NonconvexProgram,
            'nonconvex_nonlinear': NonlinearProgram,
        }
        opt_prob = program_classes[PROBLEM](params_path=PARAMS_PATH)
        c = opt_prob.eq_constraints
        jac = opt_prob.jacobian

        # SSL Loss
        ssl_loss_config = SSLConfig(
            soft_constrained=SOFT_CONSTRAINED,
            weight_loss_soft=WEIGHT_LOSS_SOFT
        )
        ssl_loss_function = SSLLoss(
            config=ssl_loss_config,
            opt_prob=opt_prob
        ).to(device=device)

        # Output indices to plot (at most 6).
        functions = range(min(OUTPUT_NEURONS, 6))

    # Scale data
    train_inputs_scaled, train_outputs_scaled, test_inputs_scaled, test_outputs_scaled, scaling_params = scale_data(
        train_inputs, train_outputs, test_inputs, test_outputs)

    if SAVE:
        scaling_params_converted = {key: _to_jsonable(value) for key, value in scaling_params.items()}
        with open(os.path.join(output_dir, "scaling.json"), 'w') as json_file:
            json.dump(scaling_params_converted, json_file, indent=4)

    # Prepare scaling tensors
    scaling_input = (
        torch.tensor(scaling_params['input_mean'], dtype=torch.float32, device=device),
        torch.tensor(scaling_params['input_std'], dtype=torch.float32, device=device)
    )
    scaling_output = (
        torch.tensor(scaling_params['output_mean'], dtype=torch.float32, device=device),
        torch.tensor(scaling_params['output_std'], dtype=torch.float32, device=device)
    )

    # Prepare data tensors
    train_inputs_tensor = torch.tensor(train_inputs_scaled, dtype=torch.float32, device=device)
    train_outputs_tensor = torch.tensor(train_outputs_scaled, dtype=torch.float32, device=device)
    test_inputs_tensor = torch.tensor(test_inputs_scaled, dtype=torch.float32, device=device)
    test_outputs_tensor = torch.tensor(test_outputs_scaled, dtype=torch.float32, device=device)

    # Per-mode results consumed by the plotting routine.
    losses_runs_dict = {}
    metrics_runs_dict = {}
    test_inputs_list_dict = {}
    test_outputs_list_dict = {}
    test_predictions_list_dict = {}
    test_prediction_before_projection_list_dict = {}

    for mode in modes_to_run:
        print(f"\n--- Running mode: {mode.upper()} ---\n")
        constrained = mode == 'constrained'

        metrics_runs = []
        losses_runs = []
        test_predictions_list = []
        test_prediction_before_projection_list = []
        test_inputs_list = []
        test_outputs_list = []

        for run in range(N):
            # `seeds` is a single int when FIX_SEED is set, else one seed per run.
            random_seed = seeds if FIX_SEED else seeds[run]
            print(f"Run {run+1}/{N}, Seed={random_seed}")

            model = ENFORCE(
                scaling_input=scaling_input,
                scaling_output=scaling_output,
                c=c,
                constrained=constrained,
                weighting_option=PROJ_WEIGHTING_OPTION,
                random_seed=random_seed,
                ssl_loss=ssl_loss_function if not SUPERVISED else None,
                jac=jac,
            ).to(device=device)

            # Train model (perf_counter is the clock intended for measuring intervals).
            start_tr = time.perf_counter()
            if PROFILER:
                with profiler.profile(with_stack=True, profile_memory=True) as prof:
                    model = train_model(model, train_inputs_tensor, train_outputs_tensor, batch_size=BATCH_SIZE, random_seed=random_seed)
                end_tr = time.perf_counter()
                print(prof.key_averages(group_by_input_shape=True).table(sort_by='self_cpu_time_total', row_limit=20))
            else:
                model = train_model(model, train_inputs_tensor, train_outputs_tensor, batch_size=BATCH_SIZE, random_seed=random_seed)
                end_tr = time.perf_counter()

            if SAVE:
                model_path = os.path.join(output_dir, f"model_{timestamp}_{mode}_run{run}.pth")
                torch.save(model, model_path)

            # Evaluate model
            metrics, test_predictions, test_prediction_before_projection = evaluate_model(
                model, test_inputs_tensor, test_outputs_tensor, scaling_params, c)
            metrics.update({"training_time": end_tr - start_tr})
            metrics_runs.append(metrics)
            test_predictions_list.append(test_predictions)
            if constrained:
                test_prediction_before_projection_list.append(test_prediction_before_projection)

            test_inputs_list.append(test_inputs)
            test_outputs_list.append(test_outputs)
            losses_runs.append(model.losses)

        losses_runs_dict[mode] = losses_runs
        metrics_runs_dict[mode] = metrics_runs
        test_inputs_list_dict[mode] = test_inputs_list
        test_outputs_list_dict[mode] = test_outputs_list
        test_predictions_list_dict[mode] = test_predictions_list
        if constrained:
            test_prediction_before_projection_list_dict[mode] = test_prediction_before_projection_list

        if SAVE:
            _save_metrics(metrics_runs, seeds, output_dir, mode)
            _save_losses(losses_runs, constrained, mode, output_dir, losses_dir)
            _save_test_predictions(
                mode, constrained, test_inputs, test_outputs,
                test_predictions_list, test_prediction_before_projection_list,
                output_dir, test_set_dir,
            )

    # After all modes are run, perform plotting.
    # NOTE(review): output_dir only exists on disk when SAVE is True; confirm
    # plot_all_results does not write into it when SAVE is False.
    if PLOT:
        plot_all_results(
            losses_runs_dict, metrics_runs_dict,
            test_inputs_list_dict, test_outputs_list_dict,
            test_predictions_list_dict, test_prediction_before_projection_list_dict,
            output_dir, functions, train_inputs, train_outputs,
            test_inputs, test_outputs
        )
        plt.show()
    

if __name__ == "__main__":
    # Launch the experiment only when executed as a script, not when imported.
    main()