import argparse
from copy import deepcopy
import yaml
from typing import Dict, Tuple
from utils import set_seed
import torch
import os, pickle
from utils import check_or_make_folder, set_seed
from diffrax_neural_cde import trainer
#from flow_match_neural_cde import trainer
import sys

# Run on the first CUDA GPU when torch reports one as available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def train(params: Dict):
    """Prepare the calibration data and hand it to ``trainer``.

    Steps, in order:

    1. Seed via ``set_seed(params)``, ensure ``./data`` exists (through
       ``check_or_make_folder``) and make it the working directory.
    2. Derive ``params['dataset_name']`` and ``params['n_sample']`` from
       ``params['data_path']`` and validate ``params['nsample']``.
    3. Run the ``DataProcessor`` pipeline: load, tuple-ise, build context,
       normalise, rebuild context with time, convert to JAX, compute spline
       coefficients.
    4. Call ``trainer`` with ``train=False`` when ``params['load_model']`` is
       truthy, otherwise with ``train=True``.

    Args:
        params: Run configuration. It is mutated in place: ``dataset_name``
            and ``n_sample`` are written here, and (per the original author's
            notes) ``DataProcessor`` may update further entries.

    Raises:
        ValueError: If ``params['nsample']`` is not ``<= params['n_sample']``,
            or if the ``n_sample`` token in the path is not a valid float.
    """
    set_seed(params)
    check_or_make_folder("./data")
    # NOTE: the process working directory changes here and is never restored.
    os.chdir("./data")
    torch.cuda.empty_cache()

    data_path = params['data_path']
    # File name without anything after its first ".".
    params['dataset_name'] = data_path.split("/")[-1].split(".")[0]
    print(params['dataset_name'])
    # n_sample is encoded in the parent directory name: take the last
    # "_"-separated token and drop its trailing 7 characters.
    # NOTE(review): presumably a fixed 7-character suffix -- confirm the
    # directory naming scheme. Any CLI/YAML value of n_sample is overwritten.
    params['n_sample'] = float(data_path.split("/")[-2].split("_")[-1][:-7])
    print(params)
    print(f"The n_sample is: {params['n_sample']}")
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if not params['nsample'] <= params['n_sample']:
        raise ValueError("The nsample should be lower than n_sample")

    # Kept as a function-level import, after the chdir, as in the original.
    from prepare_data_node import DataProcessor

    data_proc = DataProcessor(params)
    data_loaded = data_proc.get_data()  # update params dict
    calib_data = data_proc.data_tuples(data_loaded)  # update params dict
    calib_data = data_proc.build_context(calib_data, add_time=False)

    # Only the normalisation step depends on the ODE; the rest is shared.
    if params['ode_name'] == 'gmat':
        calib_data = data_proc.normalize_var_length_data(calib_data)
        # NOTE: an earlier experiment added N(0, 0.1) Gaussian noise to
        # calib_data[2].X_ctx and calib_data[3].X_ctx at this point; it is
        # disabled.
    else:
        calib_data = data_proc.normalize_calib_data(calib_data)

    calib_data = data_proc.build_context(calib_data, add_time=True)  # unnormalized time
    calib_data = data_proc.torch_to_jax(calib_data)
    calib_data = data_proc.spline_coeffs(calib_data)

    # load_model truthy -> trainer is called with train=False.
    trainer(params, calib_data, train=not params['load_model'])
    
def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(token)``, which is ``True``
    for every non-empty string, including ``"False"``. This converter reads
    the text instead. The empty string maps to ``False``, matching what
    ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "on", "1"):
        return True
    if text in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse command-line options, apply YAML overrides, and call ``train``.

    When ``--yaml_file`` is given, every key under the file's top-level
    ``args`` mapping that matches a known option replaces the parsed value.
    YAML therefore takes precedence over the command line. Keys that match
    no option are ignored.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--yaml_file', type=str, default=None)
    parser.add_argument('--data_path', type=str, default=None)
    # Accepts `--load_model`, `--load_model True` and `--load_model False`.
    parser.add_argument('--load_model', type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--data_type', type=str, default=None)
    parser.add_argument('--ode_name', type=str, default=None)
    # horizon options
    parser.add_argument('--train_horizon', type=int, default=None)
    parser.add_argument('--val_horizon', type=int, default=None)
    parser.add_argument('--interp_horizon', type=int, default=None)
    parser.add_argument('--batch_size', type=int, default=None)
    parser.add_argument('--delta_t', type=float, default=None)
    parser.add_argument('--n_sample', type=float, default=None)
    parser.add_argument('--nsample', type=float, default=None)
    parser.add_argument('--steer', type=int, default=None)
    # cde and decoder architectures
    parser.add_argument('--cde_nodes', type=int, default=None)
    parser.add_argument('--cde_layers', type=int, default=None)
    parser.add_argument('--decoder_nodes', type=int, default=None)
    parser.add_argument('--decoder_layers', type=int, default=None)
    parser.add_argument('--hidden_channels', type=int, default=None)
    parser.add_argument('--epochs', type=int, default=None)
    parser.add_argument('--lr', type=float, default=None)
    # context options
    parser.add_argument('--past_ts_ctxt', type=int, default=0)  # How many timesteps ID as context
    parser.add_argument('--past_feat_ctxt', type=int, default=1)  # How many features as context
    parser.add_argument('--init_cond_ctxt', type=_str2bool, nargs='?', const=True, default=False)

    args = parser.parse_args()
    params = vars(args)

    if params['yaml_file']:
        # NOTE(review): FullLoader is kept for compatibility with existing
        # config files; prefer yaml.safe_load if configs can come from an
        # untrusted source.
        with open(params['yaml_file'], 'r') as f:
            yaml_config = yaml.load(f, Loader=yaml.FullLoader)
        # Only known options are overridden.
        for key, value in yaml_config['args'].items():
            if key in params:
                params[key] = value

    train(params)

# Script entry point: run the CLI only when this file is executed directly.
if __name__ == "__main__":
    main()