import torch
import argparse

from utils import set_all_seeds
from utils import load_data_brownian
from utils import load_data_stable

from train_sde import train_boolode

from models.sde_nn import PerturbSDE
from models.sde_nn import PerturbSDE_monotonic
import logging
from datetime import datetime
import os
import json
from constants import *

def str2bool(v):
    """Interpret a command-line token as a boolean (argparse ``type=`` helper).

    Booleans are passed through unchanged. Strings are matched
    case-insensitively against a fixed set of truthy/falsy spellings.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v
    token = v.lower()
    if token in {'yes', 'true', 't', 'y', '1'}:
        return True
    if token in {'no', 'false', 'f', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def parse_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with every option below; all are required.
        Boolean flags accept the spellings understood by ``str2bool``.
    """
    parser = argparse.ArgumentParser(
        description="Train a perturbation SDE model on simulated GRN data."
    )
    parser.add_argument('--seed', type=int, required=True, help="Select Seed.")
    parser.add_argument('--latent_dim', type=int, required=True, help="Select Latent Dimensions.")
    parser.add_argument('--train_mode', type=str, required=True, help="Select Train Mode.")
    parser.add_argument('--intervention', type=str, required=True, help="Select Intervention Type.")
    parser.add_argument('--sde_version', type=str, required=True, help="Select ODE Version.")
    parser.add_argument('--L1', type=str2bool, required=True, help="L1 Penalization (True/False).")
    parser.add_argument('--multi_interv', type=str2bool, required=True, help="Multiple Interventions (True/False).")
    parser.add_argument('--data', type=str, required=True, help="Select Dataset.")
    parser.add_argument('--var_perturb', type=str2bool, required=True, help="Variable Perturbation (True/False).")
    return parser.parse_args(argv)


def main(
    interv_file,
    data_file,
    max_epoch=200,
    patience=10,
    train_mode='diffusion',
    max_time=16,
    t_span=50,
    t_flow=1.,
    std=0.3,
    latent_dim=200,
    knock_out=False,
    sde_version='original',
    L1=False,
    multi_interv=True,
    perfect_interv=False,
    output_dir=None,
    run_name=None,
):
    """Train a perturbation SDE on one simulated dataset and save the model.

    Args:
        interv_file: Path to the intervention JSON file. It is passed to
            ``load_data_stable`` and also loaded here into ``interv_dict``.
        data_file: Path to the dataset passed to ``load_data_stable``.
        max_epoch, patience, max_time, t_span, t_flow: Forwarded unchanged
            to ``train_boolode``.
        train_mode: ``'diffusion'`` builds a noise dict with
            ``load_data_brownian``. ``'shooting_method'`` uses an empty dict.
        std: Standard deviation passed to ``load_data_brownian`` (mean 0).
            Only used when ``train_mode == 'diffusion'``.
        latent_dim: Latent size passed to the SDE model constructor.
        knock_out, L1, perfect_interv: Forwarded to ``train_boolode``.
        sde_version: ``'SDE'`` selects ``PerturbSDE``; ``'SDE_monotonic'``
            selects ``PerturbSDE_monotonic``. The default ``'original'``
            matches neither, so callers must pass one explicitly.
        multi_interv: Accepted for call compatibility; currently unused.
        output_dir: Directory passed to ``train_boolode`` and used for the
            saved model. Defaults to the module-level ``new_folder_path``
            set by the script entry point.
        run_name: Base filename (without ``.pth``) for the saved model.
            Defaults to the module-level ``current_time``.

    Raises:
        ValueError: If ``train_mode`` or ``sde_version`` is not supported.
            This is checked before any data is loaded.
    """
    # Validate up front: previously an unsupported value left a local unbound
    # and crashed with UnboundLocalError only after the data was loaded.
    if train_mode not in ('diffusion', 'shooting_method'):
        raise ValueError(f"Unknown train_mode: {train_mode!r}")
    if sde_version not in ('SDE', 'SDE_monotonic'):
        raise ValueError(f"Unknown sde_version: {sde_version!r}")

    logging.info("GRN Analysis. Simulated Data")

    initial_dist, validation_dict, train_dict, _, _ = load_data_stable(data_file, interv_file)

    if train_mode == 'diffusion':
        mean = 0
        brownian_dict, _ = load_data_brownian(train_dict, mean, std)
    else:
        # shooting_method does not use pre-sampled noise.
        brownian_dict = {}

    with open(interv_file, 'r', encoding='utf-8') as json_file:
        interv_dict = json.load(json_file)

    # First dimension of initial_dist is used as the model's batch size.
    batch_size = initial_dist.shape[0]

    sde_class = PerturbSDE if sde_version == 'SDE' else PerturbSDE_monotonic
    sde_model = sde_class(batch_size, initial_dist.shape[1], latent_dim=latent_dim).to(device).type(dtype)

    optimizer = torch.optim.Adam(sde_model.parameters(), amsgrad=True)

    # Fall back to the globals set by the script entry point. Resolve them
    # before training so a missing global fails fast, not after training.
    if output_dir is None:
        output_dir = new_folder_path
    if run_name is None:
        run_name = current_time

    train_boolode(output_dir, interv_dict, initial_dist, brownian_dict, train_dict, validation_dict, sde_model, optimizer, t_flow=t_flow, t_span=t_span, max_epoch=max_epoch, patience=patience, max_time=max_time, knock_out=knock_out, L1=L1, perfect_interv=perfect_interv)

    model_save_path = os.path.join(output_dir, run_name + '.pth')

    torch.save(sde_model, model_save_path)


if __name__ == '__main__':
    # Script entry point: parse CLI options, validate the intervention type,
    # create a timestamped run folder with a log file, then call main().

    args = parse_args()

    seed_num = args.seed
    train_mode = args.train_mode
    intervention = args.intervention
    sde_version = args.sde_version
    L1 = args.L1
    multi_interv = args.multi_interv
    data = args.data
    latent_dim = args.latent_dim
    # NOTE(review): var_perturb is parsed and logged but never passed to main().
    var_perturb = args.var_perturb

    knock_out = intervention == 'knock_out'

    # Only over-expression interventions are implemented. Reject anything else
    # before the run folder/log file are created, so failed invocations do not
    # leave empty run directories behind.
    # NOTE(review): this also rejects 'knock_out', so knock_out is always
    # False by the time main() is called.
    if intervention == 'over_expression_perfect_interv':
        perfect_interv = True
    elif intervention == 'over_expression_imperfect_interv':
        perfect_interv = False
    else:
        raise ValueError(f"Intervention not implemented: {intervention}.")

    set_all_seeds(seed_num)

    base_dir = '.'
    log_dir_name = 'models_and_logs'
    log_dir_path = os.path.join(base_dir, log_dir_name)

    # new_folder_path and current_time are module globals that main() reads
    # when it is not given output_dir/run_name explicitly.
    # NOTE(review): this timestamp contains ':' which is invalid in Windows paths.
    current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    new_folder_path = os.path.join(log_dir_path, current_time)
    os.makedirs(new_folder_path, exist_ok=True)

    log_file_path = os.path.join(new_folder_path, current_time + '.log')

    logging.basicConfig(filename=log_file_path,
                        encoding='utf-8',
                        level=logging.INFO,
                        format='%(message)s')
    logging.info(current_time)

    logging.info('Intervention: ' + intervention)
    logging.info('ODE Version: ' + sde_version)
    logging.info('L1: ' + str(L1))
    logging.info('train_mode: ' + train_mode)
    logging.info('Variable Perturbation: ' + str(var_perturb))

    # Dataset and intervention map share a base name under root_dir_random_DAGs
    # (provided by `constants`).
    data_file = root_dir_random_DAGs + data + '.h5ad'
    interv_file = root_dir_random_DAGs + data + '.json'

    # Training hyperparameters for this experiment.
    t_span = 10
    patience = 10
    max_epoch = 150
    max_time = 22

    std = 0.3
    t_flow = 0.1

    logging.info("Latent dimensions: " + str(latent_dim) + ".")
    logging.info('Dataset: ' + data)
    logging.info('Perfect Interv: ' + str(perfect_interv))
    logging.info('Max epoch: ' + str(max_epoch))

    main(interv_file, data_file, max_epoch=max_epoch, patience=patience, max_time=max_time, t_span=t_span, t_flow=t_flow, std=std, latent_dim=latent_dim, train_mode=train_mode, knock_out=knock_out, sde_version=sde_version, L1=L1, multi_interv=multi_interv, perfect_interv=perfect_interv)

