"""
Script to augment data for Chronos. Train, Train augmented, and test data are saved in
'/FM for longitudinal data/data/chronos'.
"""
import os
import sys
sys.path.append(os.getcwd())
import numpy as np
import yaml
import argparse
import json
import shutil
import torch
from pathlib import Path

from utils.generals import set_seed
from utils.augmentation.utils_preprocess_data import upload_and_preprocess_data
from utils.augmentation.utils_data_augmentation import make_json_serializable, plot_generated_data
from utils.augmentation.sdforger_augmentation import sdforger_augmentation


# Parse config file path from command-line arguments.
# parse_known_args (rather than parse_args) collects unrecognised CLI flags in `unknown`
# instead of making argparse exit with an error.
# NOTE: the default config path is relative, so it resolves against the current working directory.
parser = argparse.ArgumentParser(description="Data Augmentation")
parser.add_argument("--config", type=str, default='sources/config/config.yaml', help="Path to the config YAML file")
args, unknown = parser.parse_known_args()

# Load parameters from YAML configuration. safe_load only constructs plain Python objects
# (no arbitrary object instantiation). The explicit encoding makes reading the config
# independent of the platform's locale-dependent default encoding.
with open(args.config, "r", encoding="utf-8") as f:
    parameters = yaml.safe_load(f)

# Define parameters from configuration (a missing key raises KeyError here, early).
DATA_NAME = parameters['data_name']
SAVE_DATA = parameters['save_results']
SEED = parameters['seed']
# data parameters: 'data_train_params' must be a two-element sequence [length, samples]
train_length, train_samples = parameters['data_train_params']
train_channels = parameters['data_train_channels']
# sdforger parameters
sdforger_llm = parameters['sdforger_llm']
sdforger_augmentation_strategy = parameters['sdforger_augmentation_strategy']
sdforger_train_splitting = parameters['sdforger_train_splitting']
sdforger_minimum_windows_number = parameters['sdforger_minimum_windows_number']
sdforger_minimum_windows_length = parameters['sdforger_minimum_windows_length']
sdforger_embedding_dim = parameters['sdforger_embedding_dim']
sdforger_train_epochs = parameters['sdforger_train_epochs']
sdforger_batch = parameters['sdforger_batch']
sdforger_float_type = parameters['sdforger_float_type']
sdforger_text_template = parameters['sdforger_text_template']
sdforger_min_generations = parameters['sdforger_min_generations']
sdforger_max_generations = parameters['sdforger_max_generations']
sdforger_norms_diversity_threshold = parameters['sdforger_norms_diversity_threshold']
sdforger_variance_explained = parameters['sdforger_variance_explained']
sdforger_embedding_type = parameters['sdforger_embedding_type']
sdforger_permute = parameters['sdforger_permute']
sdforger_init_value = parameters['sdforger_init_value']
# Kept as read from YAML. PyYAML's YAML 1.1 float rule requires a dot, so a value written
# like `1e-4` arrives as a string; convert with float() before use.
sdforger_learning_rate = parameters['sdforger_learning_rate']
# evaluation parameters
evaluation_parameters = parameters['evaluation']
create_train_val_test = parameters['create_train_val_test']
train_data_path = evaluation_parameters['train_data_path']
generated_data_path = Path(evaluation_parameters['generated_data_path'])

def _select_device():
    """Return the torch device string used for SDForger training/generation.

    Preference order: 'mps' on macOS, but only when this torch build reports the MPS
    backend as available; then 'cuda' when a CUDA device is available; otherwise 'cpu'.
    Checking availability avoids selecting 'mps' on Macs (or torch builds) that lack
    it, which would otherwise fail later when tensors are moved to the device.
    """
    # getattr guards against older torch builds that have no torch.backends.mps at all.
    mps_backend = getattr(torch.backends, 'mps', None)
    if sys.platform == 'darwin' and mps_backend is not None and mps_backend.is_available():
        return 'mps'
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'


device = _select_device()

if __name__ == "__main__":

    set_seed(SEED)
    info_dict = {}

    train_data_path = evaluation_parameters['train_data_path']
    data_output_folder = Path(train_data_path).parent
    print('data_output_folder', data_output_folder)
    try:
        if os.path.isfile(data_output_folder) or os.path.islink(data_output_folder):
            os.unlink(data_output_folder)
        elif os.path.isdir(data_output_folder):
            shutil.rmtree(data_output_folder)
    except Exception as e:
        print('Failed to delete %s. Reason: %s' % (data_output_folder, e))
    os.makedirs(data_output_folder, exist_ok=True)


    # -------------------------------------------------------------------------------------------------------- #
    # DATA PREPROCESSING

    # Preprocess and save data
    preprocessed_train_data, period, overlap, preprocessed_val_data, preprocessed_test_data = (
        upload_and_preprocess_data(
            DATA_NAME, train_channels, train_length,
            train_samples, sdforger_augmentation_strategy, sdforger_minimum_windows_length,
            sdforger_minimum_windows_number, sdforger_train_splitting, create_train_val_test))
    
    np.save(train_data_path, np.stack(preprocessed_train_data))
    print("Preprocessed train data saved ! Of shape ", preprocessed_train_data.shape)
    
    if create_train_val_test:
        np.save(os.path.join(Path(train_data_path).parent, 'val_data.npy'), np.stack(preprocessed_val_data))
        print("Preprocessed val data saved ! Of shape ", preprocessed_val_data.shape)
        np.save(os.path.join(Path(train_data_path).parent, 'test_data.npy'), np.stack(preprocessed_test_data))
        print("Preprocessed test data saved !, Of shape ", preprocessed_test_data.shape)

    # Save labels
    np.save(os.path.join(data_output_folder, 'labels.npy'), np.array(train_channels))
    # Save info_dict on preprocessing
    info_dict['train_windows'] = preprocessed_train_data.shape[1]
    info_dict['length_windows'] = preprocessed_train_data.shape[2]
    info_dict['period'] = period
    info_dict['overlap'] = overlap
    dict_path = (os.path.join(data_output_folder, f'info_dict_preprocessing.json'))
    info_dict = make_json_serializable(info_dict)
    with open(dict_path, 'w') as f:
        json.dump(info_dict, f)


    # -------------------------------------------------------------------------------------------------------- #
    # DATA AUGMENTATION

    set_seed(SEED)
    info_dict = {}
    
    data_output_folder = generated_data_path.parent
    os.makedirs(data_output_folder, exist_ok=True)

    # Get train data
    preprocessed_train_data = np.load(train_data_path).tolist()
    
    # Augment data
    print('SDForger Augmentation')
    new_data, data_embeddings, new_data_embeddings, embedding_dim, var_explained = sdforger_augmentation(
        preprocessed_train_data, sdforger_min_generations, sdforger_max_generations,
        sdforger_norms_diversity_threshold, sdforger_embedding_dim,
        sdforger_variance_explained, sdforger_embedding_type, sdforger_train_epochs, sdforger_llm, sdforger_text_template,
        sdforger_permute, sdforger_init_value, float(sdforger_learning_rate), sdforger_batch, sdforger_float_type,
        device)

    # in case we have different samples and different points for channels and windows
    min_n_samples = min([d.shape[0] for d in new_data])
    min_n_points = min([d.shape[1] for d in new_data])
    new_data = np.array([d[:min_n_samples, :min_n_points] for d in new_data])

    if SAVE_DATA:

        plot_generated_data(data_output_folder, preprocessed_train_data, new_data)
        

        # Save generated data
        np.save(os.path.join(generated_data_path), new_data) 

        # Save augmentation information in dictionary
        info_dict['embedding_dim'] = list(embedding_dim)
        info_dict['var_explained'] = list(var_explained)
        info_dict['new_samples'] = min_n_samples
        info_dict['train_windows'] = len(preprocessed_train_data[0])
        dict_path = (os.path.join(data_output_folder, f'info_dict.json'))
        info_dict = make_json_serializable(info_dict)
        with open(dict_path, 'w') as f:
            json.dump(info_dict, f)
        
        # # Save embeddings
        # np.savez(os.path.join(data_output_folder, f'new_data_embeddings{suffix}.npz'), *new_data_embeddings)
        # np.savez(os.path.join(data_output_folder, f'data_embeddings{suffix}.npz'), *data_embeddings)


    