import argparse
from copy import deepcopy
import yaml
from typing import Dict, Tuple
from utils import set_seed
import torch
import os, pickle
from utils import check_or_make_folder, set_seed
from diffrax_neural_cde import trainer
#from flow_match_neural_cde import trainer
import sys

# Module-level compute device: the first CUDA GPU when one is visible to
# torch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

def train(params: Dict):
    """Prepare the calibration data and hand it to ``trainer``.

    Seeds the run, makes sure ``./data`` exists and switches the working
    directory into it, derives ``params['dataset_name']`` from
    ``params['data_path']``, runs the ``DataProcessor`` pipeline and finally
    calls ``trainer`` in training mode unless ``params['load_model']`` is
    truthy.

    Args:
        params: Run configuration. This function reads ``data_path`` and
            ``load_model`` and writes ``dataset_name``; the whole dict is
            also passed to ``set_seed``, ``DataProcessor`` and ``trainer``.

    Side effects:
        Changes the process working directory to ``./data`` and prints the
        dataset name and the full parameter dict.
    """
    set_seed(params)
    check_or_make_folder("./data")
    os.chdir("./data")
    torch.cuda.empty_cache()

    # Dataset name = file name of data_path up to its first ".".
    file_name = params['data_path'].split("/")[-1]
    params['dataset_name'] = file_name.split(".")[0]
    print(params['dataset_name'])
    print(params)

    # Imported here, i.e. after the working directory has been changed.
    from prepare_data_node import DataProcessor

    processor = DataProcessor(params)
    # Per the original notes, get_data and data_tuples also update params.
    raw_data = processor.get_data()
    dataset = processor.data_tuples(raw_data)

    # Context is built once before normalisation without time and once
    # afterwards with time, so the time channel stays unnormalized.
    dataset = processor.build_context(dataset, add_time=False)
    dataset = processor.normalize_calib_data(dataset)
    dataset = processor.build_context(dataset, add_time=True)
    dataset = processor.np_to_jnp(dataset)
    #dataset = processor.spline_coeffs(dataset)

    # Train from scratch unless an existing model was requested.
    trainer(params, dataset, train=not params['load_model'])
    
def _str_to_bool(value):
    """Convert a command-line token such as ``true``/``False``/``1`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty token (including ``"False"``) becomes ``True``. This
    converter interprets the text instead.

    Args:
        value: The raw token (or an actual ``bool``, returned unchanged).

    Returns:
        The boolean the token denotes.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got %r" % (value,))


def _build_parser():
    """Create the argument parser holding every training option."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--yaml_file', type=str, default=None)
    parser.add_argument('--data_path', type=str, default=None)
    # Boolean options accept an explicit value ("--load_model False") or the
    # bare flag ("--load_model"), which means True.
    parser.add_argument('--load_model', type=_str_to_bool, nargs='?',
                        const=True, default=False)
    parser.add_argument('--data_type', type=str, default=None)
    parser.add_argument('--ode_name', type=str, default=None)
    # horizon options
    parser.add_argument('--train_horizon', type=int, default=None)
    parser.add_argument('--val_horizon', type=int, default=None)
    parser.add_argument('--test_horizon', type=int, default=None)
    parser.add_argument('--batch_size', type=int, default=None)
    parser.add_argument('--delta_t', type=float, default=None)
    # NOTE(review): nsample is parsed as float -- confirm this is intended.
    parser.add_argument('--nsample', type=float, default=None)
    parser.add_argument('--steer', type=int, default=None)
    parser.add_argument('--iter_correction', type=int, default=0)
    # cde and decoder architectures
    parser.add_argument('--cde_nodes', type=int, default=None)
    parser.add_argument('--cde_layers', type=int, default=None)
    parser.add_argument('--decoder_nodes', type=int, default=None)
    parser.add_argument('--decoder_layers', type=int, default=None)
    parser.add_argument('--hidden_channels', type=int, default=None)
    parser.add_argument('--epochs', type=int, default=None)
    parser.add_argument('--lr', type=float, default=None)
    # context options
    parser.add_argument('--past_ts_ctxt', type=int, default=0) # How many timesteps ID as context
    parser.add_argument('--past_feat_ctxt', type=int, default=1) # How many features as context
    parser.add_argument('--init_cond_ctxt', type=_str_to_bool, nargs='?',
                        const=True, default=False)
    return parser


def main():
    """Parse the command line, apply optional YAML overrides and run ``train``.

    When ``--yaml_file`` is given, the file must contain a top-level ``args``
    mapping. Every key in it that matches a known command-line option
    replaces that option's value (including values passed explicitly on the
    command line); unknown keys are ignored.
    """
    params = vars(_build_parser().parse_args())

    if params['yaml_file']:
        with open(params['yaml_file'], 'r') as f:
            # NOTE(review): FullLoader is kept to preserve behaviour; prefer
            # yaml.safe_load if config files may come from untrusted sources.
            yaml_config = yaml.load(f, Loader=yaml.FullLoader)
        for key, value in yaml_config['args'].items():
            if key in params:
                params[key] = value

    train(params)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()