import os
import argparse

# Select GPU 0 unless the caller already chose a device. This has to happen
# before anything touches CUDA (the project modules imported below and the
# torch.cuda.is_available() probe): the CUDA runtime reads
# CUDA_VISIBLE_DEVICES only once, when it is initialized, so setting it
# afterwards has no effect. setdefault keeps an externally provided value.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

from logging import WARNING
import warnings
warnings.filterwarnings("ignore")

import psutil
# import ray

# from logger import Logger
from utils import set_seed, save_metrics_params, update_params_from_cmdline, save_settings_to_json

import warcraft_shortest_path.data_utils as warcraft_shortest_path_data
import warcraft_shortest_path.trainers as warcraft_shortest_path_trainers

import torch
# Report whether a CUDA device is visible to this process.
print("torch.cuda.is_available() = ")
print( torch.cuda.is_available() )

import pickle
import pandas as pd
import numpy as np


# Registry mapping a problem_type to its dataset loader. main() calls the
# selected loader with **loader_params and unpacks the result as
# (train_iterator, test_iterator, metadata).
dataset_loaders = dict(
    warcraft_shortest_path=warcraft_shortest_path_data.load_dataset,
    # Presumably all tasks share a single map here (per the function name).
    # Alternative loader: warcraft_shortest_path_data.load_dataset_multi
    warcraft_shortest_path_multi=warcraft_shortest_path_data.load_dataset_multi_same_map,
)

# Registry mapping a problem_type to a factory that, given a trainer name,
# returns the trainer class to instantiate. Both problem types currently use
# the same factory.
trainer_loaders = {
    problem: warcraft_shortest_path_trainers.get_trainer
    for problem in ("warcraft_shortest_path", "warcraft_shortest_path_multi")
}

# Top-level configuration keys that verify_top_level_params insists on.
required_top_level_params = [
    "model_dir", "seed", "loader_params", "problem_type", "trainer_name",
    "trainer_params", "num_epochs", "evaluate_every", "save_visualizations",
]
# Top-level configuration keys that are accepted but not mandatory.
optional_top_level_params = [
    "num_cpus", "use_ray", "default_json", "id", "fast_mode",
    "fast_forward_training",
]

def verify_top_level_params(**kwargs):
    """Validate the top-level keys of a run configuration.

    Raises:
        ValueError: for the first keyword (in the order given) that is neither
            in ``required_top_level_params`` nor ``optional_top_level_params``,
            or, if all keywords are known, for the first entry of
            ``required_top_level_params`` that is missing.

    NOTE(review): main() in this file does not call this function, and main
    reads keys ("index", "ray_params") that are in neither list.
    """
    known = set(required_top_level_params) | set(optional_top_level_params)
    for name in kwargs:
        if name not in known:
            raise ValueError("Unknown top_level argument: {}".format(name))

    missing = [name for name in required_top_level_params if name not in kwargs]
    if missing:
        raise ValueError("Missing required argument: {}".format(missing[0]))


def _results_root_dir():
    """Return the base directory that holds the ``p/`` and ``csv/`` outputs.

    Nothing in this file defines ``ROOT_DIR``; a module-level ``ROOT_DIR`` is
    used if one has been set, otherwise the directory containing this script.
    """
    root_dir = globals().get("ROOT_DIR")
    if root_dir is None:
        root_dir = os.path.dirname(os.path.abspath(__file__))
    return root_dir


def _stack_species_lens(per_species_lens):
    """np.stack the collected per-species path lengths.

    np.stack raises on an empty sequence (e.g. num_epochs == 0), so an empty
    array is returned in that case.
    """
    if not per_species_lens:
        return np.empty((0,))
    return np.stack(per_species_lens)


def _build_results_dump(params, trainer, eval_results, history, stopping_epoch):
    """Assemble the run summary dictionary that main() pickles.

    Combines run configuration, the validation metrics accumulated during
    training (``history``), the final test metrics (``eval_results``) and the
    trainer's timing information. Key order matches the printed summary.
    """
    results_dump = {
        "seed": params.seed,
        "trainer_name": params.trainer_name,
        "val_path_lens": history["val_path_lens"],
        "val_owa_path_lens": history["val_owa_path_lens"],
        "val_path_lens_diff": history["val_path_lens_diff"],
        "val_path_lens_species": _stack_species_lens(history["val_path_lens_species"]),
        "val_path_lens_gini": history["val_path_lens_gini"],
        "val_regret": history["val_regret"],
        "best_path_lens": eval_results['path_lens'],
        "best_path_lens_diff": eval_results['path_lens_diff'],
        "gini_path_lens": eval_results['gini_path_lens'],
        "owa_path_lens": eval_results['owa_path_lens'],
        "regret": eval_results['regret'],
        "owa_optimal": trainer.owa_true_path_lens,
    }

    # One entry per task, copied from the test-set evaluation.
    for task in range(params.loader_params.n_task):
        key = 'path_lens_species_{}'.format(task)
        results_dump[key] = eval_results[key]

    results_dump["beta"] = params.trainer_params.beta
    results_dump["index"] = params.index
    results_dump['problem_type'] = params.problem_type
    results_dump["species"] = params.loader_params.task_idx
    results_dump["nspecies"] = params.loader_params.n_task
    results_dump['training_time'] = trainer.training_time
    results_dump['batch_time'] = history["batch_time"]
    results_dump['training_interval'] = history["training_interval"]
    results_dump['stopping_epoch'] = stopping_epoch
    results_dump['num_sample'] = params.loader_params.num_sample
    results_dump['owa_weight'] = params.trainer_params.owa_weight
    return results_dump


def _write_outputs(params, results_dump, eval_results):
    """Persist a finished run.

    Pickles ``results_dump`` to ``<root>/p/<file_name>`` and writes the test
    metrics (plus trainer name and seed) as a one-row CSV to
    ``<root>/csv/<stem>.csv``. Both names are derived from the run
    configuration. Missing output directories are created.
    """
    version_suffix = 'ver6'
    file_name = 'wcsp_pickle_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}.p'.format(
        params.problem_type,
        params.trainer_name,
        params.loader_params.n_task,
        params.loader_params.task_idx,
        params.trainer_params.beta,
        params.trainer_params.owa_weight,
        params.loader_params.num_sample,
        params.index,
        params.seed,
        version_suffix,
    )
    root_dir = _results_root_dir()

    pickle_dir = os.path.join(root_dir, 'p')
    os.makedirs(pickle_dir, exist_ok=True)
    out_fp = os.path.join(pickle_dir, file_name)
    print('out_fp', out_fp)
    print('ROOT_DIR', root_dir)
    # `with` guarantees the file is flushed and closed.
    with open(out_fp, 'wb') as fh:
        pickle.dump(results_dump, fh)
    print('save results files to: ', out_fp)

    # Copy so the caller's eval_results is not mutated.
    csv_outs = dict(eval_results)
    csv_outs["Trainer_Name"] = params.trainer_name
    csv_outs["seed"] = params.seed
    # Wrapping each value in a one-element list yields a single-row DataFrame.
    df_outs = pd.DataFrame.from_dict({k: [v] for k, v in csv_outs.items()})
    csv_dir = os.path.join(root_dir, 'csv')
    os.makedirs(csv_dir, exist_ok=True)
    # splitext strips the whole ".p" extension; slicing off one character
    # would leave a stray "." and produce "...ver6..csv".
    csv_stem = os.path.splitext(file_name)[0]
    df_outs.to_csv(os.path.join(csv_dir, csv_stem + '.csv'))


def main(args):
    """Train, evaluate and persist results for one multitask shortest-path run.

    1. Merge the command-line ``args`` into the configuration via
       ``update_params_from_cmdline`` and save the settings to ``model_dir``.
    2. Optionally start Ray, seed RNGs, build data iterators and the trainer.
    3. Train for up to ``num_epochs`` epochs (stopping early once
       ``trainer.early_stopping`` is set), accumulating the validation
       metrics returned by ``trainer.train_epoch()``.
    4. Save the model weights to ``combres1.pt`` in the current working
       directory, evaluate on the test split, optionally log visualizations.
    5. Write a results pickle and a metrics CSV (see ``_write_outputs``).

    Args:
        args: ``argparse.Namespace`` produced by the command-line parser.
    """
    params = update_params_from_cmdline(args, verbose=True)
    os.makedirs(params.model_dir, exist_ok=True)
    save_settings_to_json(params, params.model_dir)

    num_cpus = params.get("num_cpus", psutil.cpu_count(logical=True))
    use_ray = params.get("use_ray", False)
    if use_ray:
        # The module-level `import ray` is commented out; import it here so
        # that use_ray=True does not fail with NameError.
        import ray
        ray.init(
            num_cpus=num_cpus,
            logging_level=WARNING,
            ignore_reinit_error=True,
            redis_max_memory=10 ** 9,
            log_to_driver=False,
            **params.get("ray_params", {})
        )

    set_seed(params.seed)

    dataset_loader = dataset_loaders[params.problem_type]
    train_iterator, test_iterator, metadata = dataset_loader(**params.loader_params)

    trainer_class = trainer_loaders[params.problem_type](params.trainer_name)
    trainer = trainer_class(
        train_iterator=train_iterator,
        test_iterator=test_iterator,
        metadata=metadata,
        fast_mode=params.get("fast_mode", False),
        **params.trainer_params
    )

    # Metrics accumulated across epochs; keys match the results-dump keys.
    history = {
        "batch_time": [],
        "val_path_lens": [],
        "val_owa_path_lens": [],
        "val_path_lens_diff": [],
        "val_path_lens_species": [],
        "val_path_lens_gini": [],
        "val_regret": [],
        "training_interval": [],
    }
    # Number of epochs actually run (smaller than num_epochs on early stop).
    stopping_epoch = 0
    for epoch in range(params.num_epochs):
        train_results = trainer.train_epoch()
        stopping_epoch = epoch + 1
        history["batch_time"].append(train_results['batch_time'])
        history["val_path_lens"].extend(train_results["val_path_lens_list"])
        history["val_owa_path_lens"].extend(train_results["val_owa_path_lens_list"])
        history["val_path_lens_diff"].extend(train_results["val_path_lens_diff_list"])
        history["val_path_lens_species"].extend(train_results["val_path_lens_species_list"])
        history["val_path_lens_gini"].extend(train_results["val_path_lens_gini_list"])
        history["val_regret"].extend(train_results["val_regret"])
        history["training_interval"].extend(train_results['training_interval'])
        if trainer.early_stopping:
            break

    torch.save(trainer.model.state_dict(), "combres1.pt")
    print('model saved')

    eval_results = trainer.evaluate(is_test=True)
    print(eval_results)
    print('epoch_val_regret', history["val_regret"])
    if params.save_visualizations:
        print("Saving visualization images")
        trainer.log_visualization()

    print(params.trainer_name)

    results_dump = _build_results_dump(params, trainer, eval_results, history, stopping_epoch)
    print('total training time', trainer.training_time)
    print('results_dump', results_dump)

    _write_outputs(params, results_dump, eval_results)

    if use_ray:
        ray.shutdown()


if __name__ == "__main__":
    # Command-line interface. The parsed namespace is passed to main(), which
    # forwards it to update_params_from_cmdline to build the run config.
    cli = argparse.ArgumentParser(description='Multitask training script')

    # Model parameters.
    cli.add_argument('--seed', type=int, default=1111, help='random seed')
    # NOTE(review): default 1111 for a task count looks copied from --seed.
    cli.add_argument('--n_task', type=int, default=1111, help='number of task')
    cli.add_argument(
        '--task_idx', type=int, default=0,
        help='corresponding species index,task_idx=-1 meaning use same data for all species for sanity check',
    )
    # NOTE(review): argparse applies `type` only to string defaults, so the
    # default beta stays the int 10 while an explicit `--beta 10` gives 10.0;
    # the two format differently in result file names.
    cli.add_argument('--beta', type=float, default=10, help="OWA's smoothing parameter")
    cli.add_argument('--config_file', type=str, help='fp of configuration file ')
    cli.add_argument(
        '--trainer_name', type=str, help='name of training method',
        choices=[
            'DijkstraMultiDescent', 'DijkstraMultiOWADescent', 'DijkstraMultiOWADescent2',
            'DijkstraMultiGradNormDescent', 'DijkstraOWADescent', 'DijkstraDescent',
            'BaselineMulti',
        ],
    )
    cli.add_argument(
        '--model_name', type=str, help='name of NN architecture',
        choices=['PartialResNetMTL', 'PartialResNet', 'CombResnet18', 'MLPMTL'],
    )
    cli.add_argument(
        '--owa_weight', default='gini', type=str, help='way to initualize OWA weight',
        choices=['one', 'gini', 'ginisquare'],
    )
    cli.add_argument('--num_sample', default=500, type=int, help='number of training sample')
    # As with --beta, the default 0 stays an int (type=str is not applied).
    cli.add_argument('--index', default=0, type=str, help='name of experience')

    main(cli.parse_args())