"""
Script to evaluate the zero-shot TTM model and TTM models fine-tuned on the
original data, the generated data, and their combination.
The results file is output/TTM_evaluation.csv.
"""

import os
import sys
import sys
sys.path.append(os.getcwd())
import numpy as np
import torch
import pandas as pd
import argparse
import yaml
from pathlib import Path
import json
from utils.evaluation.utils_ttm import (prepare_dataset_for_TTM, 
                       train_ttm, 
                       evaluate_ttm_model)
from tsfm_public import TimeSeriesPreprocessor
from utils.evaluation.utils_evaluation import (initialize_results_file_ttm, save_results)
                              
                              

# Set WANDB_DISABLED before any training code runs.
# NOTE(review): presumably read by the training utilities to turn off
# Weights & Biases logging -- confirm against utils_ttm.train_ttm.
os.environ["WANDB_DISABLED"] = "true"


# Parse config file path from command-line arguments
parser = argparse.ArgumentParser(description="TTM Evaluation")
parser.add_argument(
    "--config",
    type=str,
    default='sources/config/config_ttm.yaml',
    help="Path to the config YAML file",
)
args = parser.parse_args()

# Load parameters from YAML configuration (explicit encoding so the result
# does not depend on the platform's default locale)
with open(args.config, "r", encoding="utf-8") as f:
    parameters = yaml.safe_load(f)

# General run parameters
DATA_NAME = parameters['data_name']
SAVE_DATA = parameters['save_results']
SEED = parameters['seed']

# Data parameters ('data_train_params' must unpack into exactly two values)
train_length, train_samples = parameters['data_train_params']
train_channels = parameters['data_train_channels']

# SDForger parameters
sdforger_llm = parameters['sdforger_llm']
sdforger_augmentation_strategy = parameters['sdforger_augmentation_strategy']
sdforger_train_splitting = parameters['sdforger_train_splitting']
sdforger_minimum_windows_number = parameters['sdforger_minimum_windows_number']
sdforger_minimum_windows_length = parameters['sdforger_minimum_windows_length']
sdforger_embedding_dim = parameters['sdforger_embedding_dim']
sdforger_train_epochs = parameters['sdforger_train_epochs']
sdforger_batch = parameters['sdforger_batch']
sdforger_float_type = parameters['sdforger_float_type']
sdforger_text_template = parameters['sdforger_text_template']
sdforger_min_generations = parameters['sdforger_min_generations']
sdforger_max_generations = parameters['sdforger_max_generations']
sdforger_norms_diversity_threshold = parameters['sdforger_norms_diversity_threshold']
sdforger_variance_explained = parameters['sdforger_variance_explained']
sdforger_embedding_type = parameters['sdforger_embedding_type']
sdforger_permute = parameters['sdforger_permute']
sdforger_init_value = parameters['sdforger_init_value']
sdforger_learning_rate = parameters['sdforger_learning_rate']

# Evaluation parameters: input paths of the original and generated data.
# train_data_path stays a plain string; generated_data_path is a Path at
# module level (the main section re-reads both from evaluation_parameters).
evaluation_parameters = parameters['evaluation']
train_data_path = evaluation_parameters['train_data_path']
generated_data_path = Path(evaluation_parameters['generated_data_path'])

# TTM parameters
TTM_parameters = parameters['TTM']

# Pick the compute device: 'cuda' when torch reports CUDA as available,
# otherwise fall back to 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Main evaluation process
def metric_values(results):
    """Return the averaged metrics appended at the end of every result row.

    Args:
        results: Mapping returned by ``evaluate_ttm_model``; it must contain
            the keys ``avg_rmse``, ``avg_mase``, ``avg_wql`` and ``avg_h1``.

    Returns:
        list: ``[avg_rmse, avg_mase, avg_wql, avg_h1]``, in that order.

    Raises:
        KeyError: If one of the four keys is missing from ``results``.
    """
    return [results[name] for name in ("avg_rmse", "avg_mase", "avg_wql", "avg_h1")]


def sdforger_columns(generated_data_path):
    """Return the ten SDForger-specific columns of a result row.

    The values mix the module-level ``sdforger_*`` configuration settings with
    the contents of ``info_dict.json`` located in the same directory as the
    generated data file.

    Args:
        generated_data_path: Path of the generated (synthetic) data file.

    Returns:
        list: Ten values, in the column order used by the results file.
    """
    augmentation_dict_path = os.path.join(str(Path(generated_data_path).parent), 'info_dict.json')
    with open(augmentation_dict_path, 'r') as f:
        augmentation_info = json.load(f)

    # 'embedding_dim' and 'var_explained' are indexable; only their first
    # element is reported.
    k = augmentation_info['embedding_dim'][0]
    new_samples = augmentation_info['new_samples']
    var_explained = augmentation_info['var_explained'][0]

    return [sdforger_llm, sdforger_learning_rate, sdforger_train_splitting,
            sdforger_permute, sdforger_init_value, sdforger_embedding_type, k,
            new_samples, sdforger_variance_explained, var_explained]


def finetune_and_evaluate(finetuning_dataset, output_model_path, labels_data, tsp,
                          validation_dataset, test_dataset):
    """Fine-tune TTM on ``finetuning_dataset`` and evaluate it on the test set.

    Args:
        finetuning_dataset: Array loaded from a ``.npy`` file, passed to
            ``prepare_dataset_for_TTM`` together with ``labels_data``.
        output_model_path: Directory handed to ``train_ttm`` as output path
            and to ``evaluate_ttm_model`` as the model to evaluate.
        labels_data: Labels array loaded from ``labels.npy``.
        tsp: The ``TimeSeriesPreprocessor`` shared by all datasets.
        validation_dataset: Prepared validation dataset.
        test_dataset: Prepared test dataset.

    Returns:
        The result mapping returned by ``evaluate_ttm_model``.
    """
    training_dataset = prepare_dataset_for_TTM(finetuning_dataset, labels_data, tsp)

    train_ttm(
        training_dataset,
        validation_dataset,
        TTM_parameters['TTM_model'],
        TTM_parameters['TTM_MODEL_REVISION'],
        "mix_channel",  # decoder mode
        tsp,
        output_model_path,
        device=device,
    )

    # NOTE(review): fine-tuned models are evaluated with seed=None while the
    # zero-shot model uses seed=SEED -- confirm this difference is intended.
    return evaluate_ttm_model(test_dataset, output_model_path, tsp, seed=None)


# Main evaluation process
if __name__ == "__main__":

    # -------------------------------------------------------------------- #
    # Output location and results file

    OUTPUT_PATH = os.path.join('output')
    os.makedirs(OUTPUT_PATH, exist_ok=True)

    csv_file_path = os.path.join(str(OUTPUT_PATH), 'TTM_evaluation.csv')
    df_results = initialize_results_file_ttm(csv_file_path)

    # -------------------------------------------------------------------- #
    # Shared objects: paths, labels, preprocessor, validation and test sets

    train_data_path = evaluation_parameters['train_data_path']
    generated_data_path = evaluation_parameters['generated_data_path']
    data_dir = Path(train_data_path).parent
    generated_dir = Path(generated_data_path).parent

    labels_data = np.load(os.path.join(data_dir, 'labels.npy'))

    # Create the TSP from the column_specifiers
    tsp = TimeSeriesPreprocessor(
        **TTM_parameters['TTM_column_specifiers'],
        context_length=TTM_parameters['context_length'],
        prediction_length=TTM_parameters['forecast_length'],
        scaling=True,
        encode_categorical=False,
    )

    val_data_path = os.path.join(data_dir, 'val_data.npy')
    validation_dataset = prepare_dataset_for_TTM(np.load(val_data_path), labels_data, tsp)

    test_data_path = os.path.join(data_dir, 'test_data.npy')
    test_dataset = prepare_dataset_for_TTM(np.load(test_data_path), labels_data, tsp)

    # Identifiers written in the first column of the zero-shot and combined
    # rows. NOTE(review): 'zeroshot' is appended without a separator, unlike
    # '-combinaison'; kept as is so existing result files stay comparable.
    zeroshot_key = str(data_dir) + 'zeroshot'
    combinaison_train_gen_key = str(generated_dir) + '-combinaison'

    # Preprocessing information stored next to the train/val/test data
    preprocessing_dict_path = os.path.join(str(data_dir), 'info_dict_preprocessing.json')
    with open(preprocessing_dict_path, 'r') as f:
        preprocessing_info = json.load(f)
    train_windows = preprocessing_info['train_windows']
    period = preprocessing_info['period']
    overlap = preprocessing_info['overlap']

    # Columns shared by every row (placed right after the row identifier)
    shared_columns = [SEED, DATA_NAME, train_channels, sdforger_augmentation_strategy,
                      train_length, train_samples, sdforger_minimum_windows_number,
                      train_windows, sdforger_minimum_windows_length, overlap, period]
    # Rows without synthetic data leave the ten SDForger columns empty
    no_sdforger_columns = [None] * 10

    # -------------------------------------------------------------------- #
    # 1. Zero-shot: evaluate the pretrained model without fine-tuning

    results = evaluate_ttm_model(test_dataset, TTM_parameters['TTM_model'], tsp, seed=SEED)

    print('\nSaving zeroshot')
    save_results(df_results, csv_file_path,
                 [zeroshot_key] + shared_columns + no_sdforger_columns + metric_values(results))

    # -------------------------------------------------------------------- #
    # 2. Fine-tuning on the original training data only

    print('\nGet Finetuning Data')
    finetuning_dataset = np.load(train_data_path)
    print(finetuning_dataset.shape)

    results = finetune_and_evaluate(
        finetuning_dataset,
        os.path.join(data_dir, 'TTM_finetuned_model_on_ori/'),
        labels_data, tsp, validation_dataset, test_dataset,
    )

    print('\nSaving finetuning on original data')
    save_results(df_results, csv_file_path,
                 [train_data_path] + shared_columns + no_sdforger_columns + metric_values(results))

    # -------------------------------------------------------------------- #
    # 3. Fine-tuning on the generated (synthetic) data only

    print('\nGet Finetuning Data')
    finetuning_dataset = np.load(generated_data_path)
    print(finetuning_dataset.shape)

    results = finetune_and_evaluate(
        finetuning_dataset,
        os.path.join(generated_dir, 'TTM_finetuned_model_on_syn/'),
        labels_data, tsp, validation_dataset, test_dataset,
    )

    # info_dict.json is read after training, as before
    generated_columns = sdforger_columns(generated_data_path)

    print('\nSaving sdforger')
    save_results(df_results, csv_file_path,
                 [generated_data_path] + shared_columns + generated_columns + metric_values(results))

    # -------------------------------------------------------------------- #
    # 4. Fine-tuning on original + generated data combined

    print('\nGet Finetuning Data')
    finetuning_dataset_train = np.load(train_data_path)
    finetuning_dataset_syn = np.load(generated_data_path)
    print('train_data', finetuning_dataset_train.shape)
    print('syn_data', finetuning_dataset_syn.shape)
    # NOTE(review): concatenation is along axis 1 -- presumably the sample
    # axis of the stored arrays; confirm against the data layout.
    finetuning_dataset = np.concatenate([finetuning_dataset_train, finetuning_dataset_syn], axis=1)
    print(finetuning_dataset.shape)

    results = finetune_and_evaluate(
        finetuning_dataset,
        os.path.join(generated_dir, 'TTM_finetuned_model_on_ori_and_syn/'),
        labels_data, tsp, validation_dataset, test_dataset,
    )

    combined_columns = sdforger_columns(generated_data_path)

    print('\nSaving sdforger')
    save_results(df_results, csv_file_path,
                 [combinaison_train_gen_key] + shared_columns + combined_columns + metric_values(results))

