import os
from itertools import permutations

import pandas as pd

from run_benchmark import evaluate_performance, parse_args, MODEL_CLASSES, COMPLEX_MODELS, ANALYTICAL_MODELS
from utils.data_utils import get_dataloaders
from utils.training_utils import (
    create_log_dir,
    get_device,
    get_dist_grid_codes,
    setup_pytorch,
    test,
)

def run_cc_ablation(args):
    # Argument parsing and validation
    if args.load_model_dir:
        assert args.model.upper() != 'ALL', "When loading a model, please specify a single model type, not 'ALL'."

    data_dir = args.data_dir
    batch_size = args.batch_size
    epochs = args.epochs
    save_results = args.save_results
    plot = args.plot
    save_model = args.save_model
    eval_only = args.eval_only
    load_model_dir = args.load_model_dir
    load_model_name = args.load_model_name
    
    # Set up training, logging, and experiment cases
    setup_pytorch()
    
    log_dir = None
    if save_results or plot or save_model:
        # Create a new log directory for each model
        log_dir = create_log_dir()
    
    grids_to_compare = get_dist_grid_codes(scenario=1)
    test_cases = list(permutations(grids_to_compare, 2)) # One-to-One generalization scenarios

    # Set up results tracking
    if save_results and log_dir:
        results_file = os.path.join(log_dir, 'cc_results_summary.csv')
        column_names = [
            'model',
            'training_grid',
            'testing_grid',
            'rmse_vm_pu',
            'rmse_va_degree',
            'mape_vm_pu',
            'mape_va_degree',
            'best_val_loss',
            'corresponding_train_loss',
            'total_epochs',
            'train_time'
        ]
        # Create a DataFrame for the results
        pd.DataFrame(columns=column_names).to_csv(results_file)
        print(f'\nResults will be saved to: {results_file}', flush=True)

    models_to_evaluate = [MODEL_CLASSES[args.model]] if args.model.upper() != 'ALL' else list(MODEL_CLASSES.values())

    test_case_counter = 0
    # Run evaluations
    for training_grid, testing_grid in test_cases:
        # Get data loaders

        need_real_valued_data = len(set(models_to_evaluate) - set(COMPLEX_MODELS)) > 0
        need_complex_valued_data = len(set(models_to_evaluate) & set(COMPLEX_MODELS)) > 0

        # If no real models are being evaluated, skip loading real data
        if need_real_valued_data:
            loader_train_real, loader_val_real, loader_test_real = get_dataloaders(
                data_dir, [training_grid], testing_grid, batch_size=batch_size
            )

        # If no complex models are being evaluated, skip loading complex data
        if need_complex_valued_data:
            loader_train_complex, loader_val_complex, loader_test_complex = get_dataloaders(
                data_dir, [training_grid], testing_grid, batch_size=batch_size, complex=True
            )
        # Keep track of results
        results = []

        test_case_counter += 1
        print('\n###################################################', flush=True)
        print('###################################################', flush=True)
        print(f'\nStarting Cross-Context Experiment {test_case_counter} / {len(test_cases)}', flush=True)
        print(f'\tTraining grid: {training_grid}', flush=True)
        print(f'\tTesting grid: {testing_grid}', flush=True)
        print('\n###################################################', flush=True)
        print('###################################################', flush=True)

        for model in models_to_evaluate:
            # Use complex data loaders for complex models
            if model in COMPLEX_MODELS:
                loader_train, loader_val, loader_test = loader_train_complex, loader_val_complex, loader_test_complex
            else:
                loader_train, loader_val, loader_test = loader_train_real, loader_val_real, loader_test_real
            print('\n--------------------------------------------------', flush=True)
            print(f'\nEvaluating model: {model.__name__} | Training grid: {training_grid} | Testing grid: {testing_grid}', flush=True)
            # Train and test model
            if model in ANALYTICAL_MODELS:
                rmse_vm, rmse_va, mape_vm, mape_va = test(model(), get_device(), loader_test)
                best_val_loss, corresponding_train_loss, total_epochs, train_time = 0, 0, 0, 0
            else:
                rmse_vm, rmse_va, mape_vm, mape_va, best_val_loss, corresponding_train_loss, total_epochs, train_time = \
                    evaluate_performance(model_class=model,
                                        loader_train=loader_train,
                                        loader_val=loader_val,
                                        loader_test=loader_test,
                                        epochs=epochs,
                                        log_dir=log_dir,
                                        plot=plot,
                                        save_model=save_model,
                                        eval_only=eval_only,
                                        load_model_dir=load_model_dir,
                                        model_load_experiment_id=f"{load_model_name}_{training_grid}_{testing_grid}",
                                        experiment_id=f"{model.__name__}_{training_grid}_{testing_grid}")
            
            results.append(
                (
                    model.__name__,
                    training_grid,
                    testing_grid,
                    rmse_vm,
                    rmse_va,
                    mape_vm,
                    mape_va,
                    best_val_loss,
                    corresponding_train_loss,
                    total_epochs,
                    train_time
                )
            )
            print(f'\nCompleted evaluation for model: {model.__name__}', flush=True)
            stats = f'time (s): {train_time}\n\trmse_vm: {rmse_vm}\n\trmse_va: {rmse_va}\n\tmape_vm: {mape_vm}\n\tmape_va: {mape_va}\n\tbest_val_loss: {best_val_loss}\n\tcorresponding_train_loss: {corresponding_train_loss}\n\ttotal_epochs: {total_epochs}'
            print(stats, flush=True)

        if save_results and log_dir:
            # Create a DataFrame for the results
            results_df = pd.DataFrame(results, columns=column_names)
            
            # Append to existing results file after each test case
            assert(results_file is not None), "results_file should not be None if save_results is True"
            results_df.to_csv(results_file, mode='a', index=True, header=False)
            print(f'\nAppended results to: {results_file}', flush=True)

        print('\n==================================================', flush=True)

    print('\nAll evaluations completed.', flush=True)

# Script entry point: parse the command-line arguments and hand them
# directly to the cross-context ablation runner.
if __name__ == '__main__':
    run_cc_ablation(parse_args())
