""" Main file for training/evaluating the model """
import argparse
import datetime
import logging
import os
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import pandas as pd
from eval_fn import evaluate
import tools.utils as utils
from dataset import ECGDataset
from vit import vit_base

from cosine_annealing_warmup import CosineAnnealingWarmupRestarts
from engine_finetune import train
from tools.utils import RunPhase


# Configure the root logger at import time with INFO level.
# NOTE: once the root logger has a handler, subsequent logging.basicConfig(...)
# calls (without force=True) do nothing, so a level chosen later (e.g. from the
# CLI) has to be applied with logging.getLogger().setLevel(...) instead.
logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)


def _get_test_result_from_best_dict(best_dict,
                                    model,
                                    test_dl,
                                    test_save_dir,
                                    params,
                                    ):
    trn_ckpt_path = best_dict['model_abs_pth']
    assert trn_ckpt_path, "checkpoint path should be not empty string"
    model_saved_name = os.path.basename(trn_ckpt_path)
    logger.info(f"Init testing done for {trn_ckpt_path}")
    _, summary, metrics_string, output, labels, names, _ = evaluate(model,
                                                                    None,
                                                                    test_dl,
                                                                    params,
                                                                    RunPhase.TEST,
                                                                    test_save_dir,
                                                                    model_saved_name)
    logger.info(f"metrics_string: {metrics_string}")
    logger.info("Testing done")
    utils.save_df(params, output, labels, names, '', '', "outputs", os.path.join(test_save_dir, model_saved_name))

    print(summary)


def train_and_evaluate(params,
                       model,
                       trn_ckpt_dir,
                       train_dl,
                       valid_dl,
                       optimizer,
                       loss_fn,
                       scheduler,
                       ):
    """Fine-tune ``model`` for ``params['train']['num_epochs']`` epochs.

    Each epoch is delegated to ``engine_finetune.train``. That function is
    handed the current best-checkpoint record and its return value becomes the
    record for the next epoch.

    Args:
        params: Parsed config dict; ``train.grad_accml_step`` and
            ``train.num_epochs`` are read here.
        model, optimizer, loss_fn, scheduler: Forwarded to ``train`` unchanged.
        trn_ckpt_dir: Checkpoint directory, forwarded to ``train``.
        train_dl, valid_dl: Training / validation DataLoaders.

    Returns:
        The record returned by the last ``train`` call. When no epoch runs,
        this is the initial record (``value`` = +inf, empty ``model_abs_pth``).
    """
    # Initial best-so-far record; presumably updated by ``train`` each epoch.
    record = dict(
        met="valid_loss",
        epoch=0,
        itr=0,
        value=float('inf'),
        model_state_dict=None,
        optimizer_state_dict=None,
        model_abs_pth="",
    )

    n_batches = len(train_dl)
    train_cfg = params['train']
    accum_steps = train_cfg['grad_accml_step']

    logger.info("length of train loader %s", n_batches)
    logger.info("grad accml step size: %s", accum_steps)

    for epoch in range(1, train_cfg['num_epochs'] + 1):
        record = train(params, model, trn_ckpt_dir, train_dl, valid_dl,
                       optimizer, loss_fn, scheduler, epoch, record, accum_steps)
    return record


def _make_loader(params, fname_key, mode, shuffle):
    """Build a DataLoader over the index file named by ``params['dataset'][fname_key]``.

    ``mode`` is forwarded to ``ECGDataset`` (this script uses 'train' or 'test').
    """
    dataset_cfg = params['dataset']
    # NOTE(security): read_pickle unpickles the file; only use trusted index files.
    index_df = pd.read_pickle(os.path.join(dataset_cfg['index_dir'], dataset_cfg[fname_key]))
    dataset = ECGDataset(dataset_cfg['waveform_dir'], index_df, params, mode)
    return DataLoader(dataset,
                      batch_size=params['train']['batch_size'],
                      shuffle=shuffle,
                      num_workers=params['init_set']['num_workers'])


def main(cfg_file_path="./configs/finetune.yaml"):
    """Fine-tune a pre-trained ViT model, then evaluate the result on the test split.

    Steps:
        1. Load the config and set up GPU and seeds.
        2. Build the model and load the pre-trained weights.
        3. Optionally freeze the backbone (``finetune.method == 'linear_eval'``).
        4. Replace the classification head.
        5. Build the optimizer, scheduler and data loaders.
        6. Train, then test.

    Args:
        cfg_file_path: Path to the YAML config file.

    Raises:
        FileNotFoundError: If the pre-trained checkpoint path does not exist.
    """
    params = utils.load_yaml(cfg_file_path)
    if params['init_set']['cur_time'] == "":
        params['init_set']['cur_time'] = datetime.datetime.now().strftime('%y%m%d%H%M%S')

    utils.set_gpu(params)
    utils.set_seed(params)

    # --- model -----------------------------------------------------------
    seq_len = int(params['preproc']['signal_len_cut_sec'] * params['preproc']['sample_rate_to'])
    # Work on a copy so popping 'vit_args' does not mutate the loaded config.
    model_args = dict(params['init_set']['model_cfg'])
    # ``or {}`` also covers an empty ``vit_args:`` entry (parsed as None).
    vit_args = model_args.pop('vit_args', None) or {}
    model_args = {**model_args, **vit_args}

    model = vit_base(num_leads=params['finetune']['num_leads'],
                     seq_len=seq_len,
                     **model_args)

    # --- save paths ------------------------------------------------------
    log_path = params["init_set"]["log_dir"]
    utils.make_save_dir(log_path)

    current_time = params['init_set']['cur_time']
    run_name = f"finetune_{current_time}_{params['init_set']['model_cfg']['model_name']}"
    save_best_models_dir = os.path.join(log_path, f"{run_name}.pth")
    utils.make_save_dir(save_best_models_dir)
    tail_str = f"test_result_{current_time}"
    save_test_results_dir = os.path.join(save_best_models_dir, tail_str)

    # --- load pre-trained weights ----------------------------------------
    checkpoint_path = params['finetune']['load_prtrn_model_abs_path']
    # Explicit raise rather than ``assert``: asserts are stripped under ``python -O``.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Pre trained model checkpoint path should exist: {checkpoint_path}")
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path,
                            map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])

    if params['finetune']['method'] == 'linear_eval':
        # Linear probing: freeze every parameter; only the new head is re-enabled below.
        for param in model.parameters():
            param.requires_grad = False
    model.reset_head(num_classes=len(params['preproc']['task']))
    model.head.weight.requires_grad = True
    model.head.bias.requires_grad = True

    # --- device placement --------------------------------------------------
    if params['cuda']:
        model = model.cuda()

    # str() so an integer YAML value (e.g. ``gpu: 0``) is accepted as well.
    gpus = str(params['device_config']['gpu']).split(',')
    if len(gpus) > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(len(gpus))))

    # --- optimizer / scheduler ---------------------------------------------
    # The optimizer is built only after the head has been replaced, the backbone
    # (optionally) frozen and the model moved to its device. reset_head() may
    # create new parameter tensors, and an optimizer built earlier would never
    # update them.
    learning_rate = params['train']['optimizer']['learning_rate']
    weight_decay = params['train']['optimizer']['weight_decay']
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=learning_rate, weight_decay=weight_decay)

    scheduler_cfg = params['train']['lr_scheduler']
    scheduler = CosineAnnealingWarmupRestarts(optimizer,
                                              first_cycle_steps=scheduler_cfg['first_cycle_steps'],
                                              cycle_mult=scheduler_cfg['cycle_mult'],
                                              max_lr=scheduler_cfg['cos_max_lr'],
                                              min_lr=scheduler_cfg['cos_min_lr'],
                                              warmup_steps=scheduler_cfg['warmup_steps'],
                                              gamma=scheduler_cfg['gamma'])

    loss_fn = F.cross_entropy

    # --- data loaders ------------------------------------------------------
    train_loader = _make_loader(params, 'train_fname', 'train', shuffle=True)
    val_loader = _make_loader(params, 'val_fname', 'test', shuffle=False)
    test_loader = _make_loader(params, 'test_fname', 'test', shuffle=False)

    logger.info("Init finetune done")

    best_dict = train_and_evaluate(params,
                                   model,
                                   save_best_models_dir,
                                   train_loader,
                                   val_loader,
                                   optimizer,
                                   loss_fn,
                                   scheduler,
                                   )
    logger.info("Finetuing done")
    _get_test_result_from_best_dict(best_dict,
                                    model,
                                    test_loader,
                                    save_test_results_dir,
                                    params,
                                    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg_file', default='configs/params.yaml', type=str,
                        help="Config file name under the config directory.")
    parser.add_argument('-log', '--loglevel', default='info',
                        help='Provide logging level. Example --loglevel debug, default=info')
    # NOTE(review): --override_yaml and --train_test_mode are parsed but never used below.
    parser.add_argument('--override_yaml', default=None, help='override yaml file that modifies params')
    parser.add_argument('--train_test_mode', default=None, help='train test mode: train, test, both, finetune')
    args = parser.parse_args()
    # If the root logger is already configured (e.g. by a module-level
    # basicConfig call), basicConfig() is a no-op. Set the root level explicitly
    # so --loglevel actually takes effect.
    logging.basicConfig(level=args.loglevel.upper())
    logging.getLogger().setLevel(args.loglevel.upper())

    main(args.cfg_file)
