""" Script implementing the main functionalities of the woods package

There are two modes of operation:
    - train: trains a model on a given dataset with a given test environment using a given algorithm (objective)
    - eval: evaluates an existing model (loaded from --model_path) on a given dataset with a given test environment using a given algorithm (objective)

Raises:
    NotImplementedError: Some parts of the code are not implemented yet
"""

import os
import json
import time
import random
import argparse
import numpy as np

import torch
from torch import nn, optim
import sys
import copy
from woods import datasets
from woods.models.assign_models import get_model
from woods import hyperparams
from woods.modelparams.assign_model_hparams import get_model_hparams
from woods.objectives.assign_objectives import get_objective_class
from woods.objectives.assign_objectives import OBJECTIVES
from woods import utils
from woods.train import train, get_accuracies
import warnings
warnings.filterwarnings("ignore")

def get_train_domain_validation(records):
    """ Perform train-domain validation model selection on a single training run and return the results

    For every recorded checkpoint, the dataset's performance measure is averaged over the held-out ('_out_')
    split of every training (non-test) environment, and the checkpoint with the highest average is selected:
        best_step = argmax_{step in checkpoints}( mean(train_envs_out_measure) )
    The reported test metrics are those of the test environment's '_in_' split at that checkpoint.

    Args:
        records (dict): Dictionary of records from a single training run. It holds a 'flags' entry
            (with at least 'dataset' and 'test_env'), optionally an 'hparams' entry, and every other
            entry maps a checkpoint step to a dict of metrics. The input is deep-copied and not modified.

    Returns:
        float: validation score (mean train-domain '_out_' measure) of the best checkpoint
        float: test accuracy of the best checkpoint (highest validation score) of the training run
        list: [test F1, test precision, test recall] of the best checkpoint

    Raises:
        ValueError: if no test environment is set (flags['test_env'] is None) or the run has no checkpoint records
        KeyError: if an expected metric key is missing from a checkpoint record
    """

    # Work on a copy so that popping the metadata entries does not mutate the caller's dict
    records = copy.deepcopy(records)

    flags = records.pop('flags')
    # Hyperparameters are not needed for model selection; drop them so only step entries remain
    records.pop('hparams', None)

    test_env = flags['test_env']
    if test_env is None:
        raise ValueError("Train-domain validation requires a test environment, but flags['test_env'] is None")
    if not records:
        raise ValueError("No checkpoint records found for this training run")

    env_name = datasets.get_environments(flags['dataset'])
    meas = datasets.get_performance_measure(flags['dataset'])

    # Validation set: the held-out ('_out_') split of every training environment
    val_keys = [str(e) + '_out_' + meas for i, e in enumerate(env_name) if i != test_env]

    # Test metrics come from the in-split of the test environment.
    # NOTE(review): the test accuracy key is hard-coded to 'acc' while validation uses `meas`;
    # confirm this is intended for datasets whose performance measure is not 'acc'.
    test_prefix = str(env_name[test_env]) + '_in_'

    # NOTE(review): assumes a higher value of `meas` is better (argmax); confirm for loss-like measures.
    val_scores = {step: np.mean([step_dict[k] for k in val_keys]) for step, step_dict in records.items()}
    # max() returns the first step (in record order) among ties
    best_step = max(val_scores, key=val_scores.get)
    best_record = records[best_step]

    other_metrics = [
        best_record[test_prefix + 'f1'],
        best_record[test_prefix + 'precision'],
        best_record[test_prefix + 'recall'],
    ]
    return val_scores[best_step], best_record[test_prefix + 'acc'], other_metrics

if __name__ == '__main__':

    ## Args
    parser = argparse.ArgumentParser(description='Train a model on a dataset with an objective and test on a test_env')
    # Main mode: 'train' trains (and optionally saves) a model, 'eval' evaluates weights loaded from --model_path
    parser.add_argument('mode', choices=['train', 'eval'])
    # Dataset arguments
    parser.add_argument('--test_env', type=int, default = None)
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--holdout_fraction', type=float, default=0.2)
    # Setup arguments
    parser.add_argument('--objective', type=str, choices=OBJECTIVES)
    # Hyperparameters arguments
    parser.add_argument('--model_name', type=str, required=True, help='chooses a basemodel')
    parser.add_argument('--sample_hparams', action='store_true')
    parser.add_argument('--hparams_seed', type=int, default=0, help='Seed for random hparams (Is not used if sample_hparams is not true')
    parser.add_argument('--trial_seed', type=int, default=0, help='Trial number for seeding split_dataset and random_hparams.')
    parser.add_argument('--seed', type=int, default=0, help='Seed for everything else')
    # Directory arguments
    parser.add_argument('--data_path', type=str, default='../data/')
    parser.add_argument('--save_path', type=str, default='./results/')
    parser.add_argument('--download', action='store_true')
    # Model evaluation arguments
    parser.add_argument('--save', action='store_true')
    parser.add_argument('--model_path', type=str, default=None)
    # training device
    parser.add_argument('--device', type=str, default='0')

    # for FEDNet
    parser.add_argument('--alpha', type=float, default=0.1)
    parser.add_argument('--freq_type', type=str, default='fft')
    parser.add_argument('--constraint_type', type=str, default='contrast', choices=['contrast', 'cross'])
    parser.add_argument('--temperature', type=float, default=0.07)


    flags = parser.parse_args()

    # Device definition: --device is used as the CUDA device index; fall back to CPU without CUDA
    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(flags.device))
        print("Using CUDA device ", device)
    else:
        device = torch.device("cpu")


    print('Flags:')
    for k, v in sorted(vars(flags).items()):
        print("\t{}: {}".format(k, v))

    ## Making job ID and checking if done
    job_name = utils.get_job_name(vars(flags))
    print("job_name:", job_name)
    log_file = os.path.join(flags.save_path, 'logs', job_name+'.json')
    assert isinstance(flags.test_env, int) or flags.test_env is None, "Invalid test environment"
    if flags.mode == 'train':
        # Refuse to re-run a job whose log file has already been saved
        assert not os.path.isfile(log_file), "\n*********************************\n*** Job Already ran and saved ***\n*********************************\n"

    ## Getting hparams
    hparam_sampling_seed = utils.seed_hash(flags.hparams_seed, flags.trial_seed)
    training_hparams = hyperparams.get_training_hparams(flags.dataset, hparam_sampling_seed, flags.sample_hparams)
    training_hparams['device'] = device

    # The objective reuses the optimisation settings of the training hparams
    objective_hparams = hyperparams.get_objective_hparams(flags.objective, hparam_sampling_seed, flags.sample_hparams)
    objective_hparams['device'] = device
    objective_hparams['batch_size'] = training_hparams['batch_size']
    objective_hparams['weight_decay'] = training_hparams['weight_decay']
    objective_hparams['lr'] = training_hparams['lr']

    # Model hparams, extended with the FEDNet-specific command-line options
    model_hparams = get_model_hparams(flags.dataset, flags.model_name)
    model_hparams['device'] = device
    model_hparams['model_path'] = flags.model_path
    model_hparams['alpha'] = flags.alpha
    model_hparams['freq_type'] = flags.freq_type
    model_hparams['temperature'] = flags.temperature
    model_hparams['constraint_type'] = flags.constraint_type

    all_hparam_dicts = (training_hparams, model_hparams, objective_hparams)
    print('HParams:')
    for hparam_dict in all_hparam_dicts:
        for k, v in sorted(hparam_dict.items()):
            print('\t{}: {}'.format(k, v))

    ## Make dataset
    dataset_class = datasets.get_dataset_class(flags.dataset)
    dataset = dataset_class(flags, training_hparams)
    # NOTE(review): the returned loaders are unused below; the call is kept in case it has side effects
    _, in_loaders = dataset.get_train_loaders()

    # Make some checks about the dataset
    if datasets.num_environments(flags.dataset) == 1:
        assert flags.objective == 'ERM', "Dataset has only one environment, cannot compute multi-environment penalties"

    ## Setting global seed
    random.seed(flags.seed)
    np.random.seed(flags.seed)
    torch.manual_seed(flags.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    ## Initialize a model to train
    model = get_model(dataset, model_hparams)
    print("Number of parameters = ", sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Define training aid
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(parameters, lr=training_hparams['lr'], weight_decay=training_hparams['weight_decay'])

    ## Initialize some Objective
    objective_class = get_objective_class(flags.objective)
    objective = objective_class(model, dataset, optimizer, objective_hparams)

    ## Do the thing
    model.to(device)
    if flags.mode == 'train':

        model, record, table = train(flags, training_hparams, model, objective, dataset, device)

        # json.dump cannot serialize torch.device objects, so drop them before logging the hparams
        del training_hparams['device']
        del model_hparams['device']
        del objective_hparams['device']
        hparams = {}
        for hparam_dict in all_hparam_dicts:
            hparams.update(hparam_dict)
        record['hparams'] = hparams
        record['flags'] = vars(flags)
        best_val, best_acc, best_other_metric = get_train_domain_validation(record)
        print("Best validation accuracy: {:.4f}".format(best_val), "Best test accuracy: {:.4f}".format(best_acc))
        print("Best test F1: {:.4f}".format(best_other_metric[0]), "Best test Precision: {:.4f}".format(best_other_metric[1]), "Best test Recall: {:.4f}".format(best_other_metric[2]))


        ## Save stuff: JSON log, model weights and a human-readable summary
        if flags.save:
            os.makedirs(os.path.join(flags.save_path, 'logs'), exist_ok=True)
            with open(log_file, 'w') as f:
                json.dump(record, f)
            os.makedirs(os.path.join(flags.save_path, 'models'), exist_ok=True)
            torch.save(model.state_dict(), os.path.join(flags.save_path, 'models', job_name+'.pt'))
            os.makedirs(os.path.join(flags.save_path, 'outputs'), exist_ok=True)
            with open(os.path.join(flags.save_path, 'outputs', job_name+'.txt'), 'w') as f:
                f.write('HParams:\n')
                for hparam_dict in all_hparam_dicts:
                    for k, v in sorted(hparam_dict.items()):
                        f.write('\t{}: {}\n'.format(k, v))
                # format() (rather than '+') also tolerates a missing --objective (None)
                job_id = 'Training {} on {} (H={}, T={})'.format(flags.objective, flags.dataset, flags.hparams_seed, flags.trial_seed)
                f.write(table.get_string(title=job_id, border=True, hrule=0))
                f.write("\n\nBest validation accuracy: {:.4f}\nBest test accuracy: {:.4f}".format(best_val, best_acc))
                f.write("\nBest test F1: {:.4f}\nBest test Precision: {:.4f}\nBest test Recall: {:.4f}".format(best_other_metric[0], best_other_metric[1], best_other_metric[2]))


    elif flags.mode == 'eval':
        # Load the weights
        assert flags.model_path is not None, "You must give the model_path in order to evaluate a model"
        # map_location lets weights saved on a GPU be loaded on a CPU-only machine (and vice versa).
        # NOTE: torch.load unpickles the file; only load model files from trusted sources.
        model.load_state_dict(torch.load(flags.model_path, map_location=device))

        # Get accuracies
        val_start = time.time()
        record = get_accuracies(objective, dataset, device)
        val_time = time.time() - val_start

        # Report the evaluation results
        print("Evaluation time: {:.2f}s".format(val_time))
        if isinstance(record, dict):
            for k, v in record.items():
                print("\t{}: {}".format(k, v))
        else:
            print(record)
