"""
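Script to evaluate generated data against the original training data (TSG evaluation).
Settings are read from a YAML config file (--config); the computed metrics are passed to
save_results together with the results file output/TSG_evaluation.csv.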
"""

import os
import sys
import sys
sys.path.append(os.getcwd())
import numpy as np
import torch
import pandas as pd
import argparse
import yaml
from pathlib import Path
import json
from utils.evaluation.TSG_evaluation import tsg_evaluation
from utils.evaluation.utils_evaluation import initialize_results_file_tsg, save_results


# Parse config file path from command-line arguments
parser = argparse.ArgumentParser(description="TSG Evaluation")
parser.add_argument(
    "--config",
    type=str,
    default='sources/config/config.yaml',
    help="Path to the config YAML file",
)
args = parser.parse_args()

# Load parameters from YAML configuration.
# The encoding is passed explicitly so the file is decoded the same way on
# every platform instead of depending on the locale's default encoding.
with open(args.config, "r", encoding="utf-8") as f:
    parameters = yaml.safe_load(f)

# Define parameters from configuration.
# Every key below is required: a missing key raises KeyError at start-up.
DATA_NAME = parameters['data_name']
SAVE_DATA = parameters['save_results']
SEED = parameters['seed']

# data parameters ('data_train_params' must unpack into exactly two values)
train_length, train_samples = parameters['data_train_params']
train_channels = parameters['data_train_channels']

# sdforger parameters
sdforger_llm = parameters['sdforger_llm']
sdforger_augmentation_strategy = parameters['sdforger_augmentation_strategy']
sdforger_train_splitting = parameters['sdforger_train_splitting']
sdforger_minimum_windows_number = parameters['sdforger_minimum_windows_number']
sdforger_minimum_windows_length = parameters['sdforger_minimum_windows_length']
sdforger_embedding_dim = parameters['sdforger_embedding_dim']
sdforger_train_epochs = parameters['sdforger_train_epochs']
sdforger_batch = parameters['sdforger_batch']
sdforger_float_type = parameters['sdforger_float_type']
sdforger_text_template = parameters['sdforger_text_template']
sdforger_min_generations = parameters['sdforger_min_generations']
sdforger_max_generations = parameters['sdforger_max_generations']
sdforger_norms_diversity_threshold = parameters['sdforger_norms_diversity_threshold']
sdforger_variance_explained = parameters['sdforger_variance_explained']
sdforger_embedding_type = parameters['sdforger_embedding_type']
sdforger_permute = parameters['sdforger_permute']
sdforger_init_value = parameters['sdforger_init_value']
sdforger_learning_rate = parameters['sdforger_learning_rate']

# evaluation parameters: input paths of the original and generated data.
# Each name is bound exactly once; generated_data_path is a Path here.
evaluation_parameters = parameters['evaluation']
train_data_path = evaluation_parameters['train_data_path']
generated_data_path = Path(evaluation_parameters['generated_data_path'])

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Main evaluation process
if __name__ == "__main__":
    # Compare one generated dataset with its training data and hand the
    # resulting metrics, together with the run configuration, to save_results.

    # Set the output directory and the results file inside it
    OUTPUT_PATH = 'output'
    os.makedirs(OUTPUT_PATH, exist_ok=True)
    csv_file_path = os.path.join(OUTPUT_PATH, 'TSG_evaluation.csv')
    df_results = initialize_results_file_tsg(csv_file_path)

    # Read the generated-data path as the plain config string: the '/'-based
    # split below needs a str, not a Path.
    generated_data_path = evaluation_parameters['generated_data_path']
    print('generated_path', generated_data_path)

    # The last seven '/'-separated path components form the key of this run.
    key_excel = '/'.join(generated_data_path.split('/')[-7:])
    print('key_excel', key_excel)

    # NOTE(review): a "skip if already evaluated" check on key_excel
    # (row_exists_in_dataframe) used to be sketched here but is disabled, so
    # every invocation runs the evaluation and saves a result.

    # -------------------------------------------------------------------------------------------------------- #
    print('\nGet Original Data')
    train_data_path = evaluation_parameters['train_data_path']
    # Stored axes (0, 1, 2) are reordered to (1, 2, 0) before evaluation.
    original_dataset = np.load(train_data_path).transpose(1, 2, 0)
    print(original_dataset.shape)

    # -------------------------------------------------------------------------------------------------------- #
    print('\nGet Generated Data')
    generated_dataset = np.load(generated_data_path).transpose(1, 2, 0)
    print(generated_dataset.shape)

    # -------------------------------------------------------------------------------------------------------- #
    # Recover run metadata from the JSON files stored next to the generated data
    run_directory = Path(generated_data_path).parent

    with open(run_directory / 'info_dict_preprocessing.json', 'r', encoding='utf-8') as f:
        preprocessing_info = json.load(f)
    train_windows = preprocessing_info['train_windows']
    # NOTE(review): length_windows is read but not written to the results;
    # sdforger_minimum_windows_length is saved in that position instead.
    # Confirm this is intended.
    length_windows = preprocessing_info['length_windows']
    period = preprocessing_info['period']
    overlap = preprocessing_info['overlap']

    with open(run_directory / 'info_dict.json', 'r', encoding='utf-8') as f:
        augmentation_info = json.load(f)
    # Only the first entry of 'embedding_dim' and 'var_explained' is recorded.
    k = augmentation_info['embedding_dim'][0]
    new_samples = augmentation_info['new_samples']
    var_explained = augmentation_info['var_explained'][0]

    # Configuration columns of the results row; the order must stay in sync
    # with the columns expected by the results file.
    values_to_save = [
        key_excel,
        SEED, DATA_NAME, train_channels, sdforger_augmentation_strategy,
        train_length, train_samples, sdforger_minimum_windows_number,
        train_windows, sdforger_minimum_windows_length, overlap, period,
        sdforger_llm, sdforger_learning_rate, sdforger_train_splitting,
        sdforger_permute, sdforger_init_value, sdforger_embedding_type, k,
        new_samples, sdforger_variance_explained, var_explained,
    ]

    print('\nTSG Evaluation')
    result = tsg_evaluation(original_dataset, generated_dataset)

    # Metric columns, appended after the configuration columns in this order.
    metric_names = ['MDD', 'ACD', 'SD', 'KD', 'ED', 'DTW', 'SHAP-RE']
    metric_values = [result[name] for name in metric_names]

    print('\nSaving sdforger')
    save_results(df_results, csv_file_path, values_to_save + metric_values)
