import argparse
import yaml
from typing import Dict, Tuple
from prepare_data import DataProcessor
from diffrax_neural_ode import trainer
import torch
import os
from utils import check_or_make_folder, set_seed
import pickle
import sys

def _checkpoint_folder(params: Dict) -> str:
    """Return the run folder whose name encodes the seed, dataset and hyperparameters.

    NOTE(review): this presumably has to match, byte for byte, the folder name
    the trainer uses when it saves a run — confirm against ``diffrax_neural_ode``.
    """
    return (
        f"./seed{params['seed']}_{params['dataset_name']}"
        f"_{params['epochs']}epochs_{params['lr']}lr"
        f"_{params['num_nodes']}nodes_{params['num_layers']}layers"
        f"_{params['train_horizon']}trHz_{params['val_horizon']}valHz"
        f"_{params['interp_horizon']}intHz_{params['batch_size']}bs"
        f"_{params['delta_t']}delT_{params['n_sample']}nsample"
    )


def train(params: Dict):
    """Prepare the dataset, then either train a new model or evaluate a saved one.

    The second ``/``-separated component of ``params['data_path']`` is used as
    the output directory: it is created if needed and the process ``chdir``s
    into it, so every later relative path is resolved from there.

    Args:
        params: Run configuration (CLI arguments merged with YAML overrides).
            Mutated in place: ``'ode_name'`` and ``'dataset_name'`` are derived
            from ``data_path``, and ``DataProcessor.get_data`` is noted as
            updating it as well. When ``params['load_model']`` is truthy, the
            params stored with the saved run are used for evaluation instead.

    Raises:
        ValueError: If ``data_path`` is missing/empty or its second
            ``/``-separated component is missing or empty.
    """
    print(params)

    # Derive names from the data path, e.g. "data/lorenz/train.pkl" gives
    # ode_name "lorenz" and dataset_name "train".
    data_path = params.get('data_path')
    parts = data_path.split("/") if data_path else []
    if len(parts) < 2 or not parts[1]:
        raise ValueError(
            "data_path must have a non-empty second '/'-separated component "
            f"(used as the output directory), got {data_path!r}"
        )
    data_dir = parts[1]
    params['ode_name'] = data_dir
    params['dataset_name'] = parts[-1].split(".")[0]

    # From here on, relative paths (including the checkpoint folder) resolve inside data_dir.
    check_or_make_folder(f"./{data_dir}")
    os.chdir(f"./{data_dir}")
    torch.cuda.empty_cache()
    set_seed(params)

    # Prepare data, then train a fresh model or evaluate a saved one.
    data_proc = DataProcessor(params)
    dataset = data_proc.get_data(load_model=params['load_model'])  # updates params dict
    if not params['load_model']:
        trainer(params, dataset, train=True)
        return

    # Evaluate with the params stored alongside the saved run (presumably the
    # configuration it was trained with) rather than the current CLI/YAML ones.
    folder = _checkpoint_folder(params)
    # SECURITY: pickle.load can execute arbitrary code; only load run folders you trust.
    with open(os.path.join(folder, "data_dict.pkl"), "rb") as f:
        data_dict = pickle.load(f)
    params = data_dict['params']
    trainer(params, dataset, train=False)
    
def _str_to_bool(value: str) -> bool:
    """Parse a command-line boolean such as "true"/"false", "yes"/"no" or "1"/"0"."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse CLI arguments, apply optional YAML overrides, and run ``train``.

    Values under the top-level ``args`` mapping of ``--yaml_file`` override the
    matching command-line arguments, including ones passed explicitly. Keys that
    are not known arguments are reported on stderr and ignored.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--yaml_file', type=str, default=None)
    parser.add_argument('--data_path', type=str, default=None)
    # `type=bool` would turn any non-empty string (including "False") into True,
    # so parse the text explicitly; a bare `--load_model` also means True.
    parser.add_argument('--load_model', type=_str_to_bool, nargs='?', const=True, default=False)
    parser.add_argument('--train_val_ratio', type=float, default=0.8)
    # define train and val horizon
    parser.add_argument('--ode_name', type=str)
    parser.add_argument('--train_horizon', type=int)
    parser.add_argument('--val_horizon', type=int)
    parser.add_argument('--interp_horizon', type=int)
    parser.add_argument('--n_sample', type=float, default=None)
    # dataset info
    parser.add_argument('--delta_t', type=float)
    # func params
    parser.add_argument('--epochs', type=int, default=None)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--num_nodes', type=int, default=200)
    parser.add_argument('--num_layers', type=int, default=3)

    params = vars(parser.parse_args())

    yaml_path = params['yaml_file']
    if yaml_path:
        # NOTE: FullLoader is kept for compatibility with existing configs;
        # switch to yaml.safe_load if configs may come from untrusted sources.
        with open(yaml_path, 'r') as f:
            yaml_config = yaml.load(f, Loader=yaml.FullLoader)
        # An empty file loads as None; a missing or empty `args` section means no overrides.
        overrides = (yaml_config or {}).get('args') or {}
        for key, value in overrides.items():
            if key in params:
                params[key] = value
            else:
                print(f"Warning: ignoring unknown config key {key!r} in {yaml_path}", file=sys.stderr)

    train(params)

if __name__ == "__main__":
    # Script entry point: parse the arguments and start training.
    main()