#!/usr/bin/env python3
"""
major actions here: fine-tune the features and evaluate different settings
"""
import os
import torch
import warnings

import numpy as np
import random

from time import sleep
from random import randint

import src.utils.logging as logging
from src.configs.config import get_cfg
from src.data import loader as data_loader
from src.engine.evaluator import Evaluator
from src.engine.trainer import Trainer
from src.models.build_model import build_model
from src.utils.file_io import PathManager
from tensorboardX import SummaryWriter

from launch import default_argument_parser, logging_train_setup
# Install a blanket "ignore" filter so Python warnings are not shown by this script.
warnings.filterwarnings("ignore")
# Unconditionally overwrite CUDA_VISIBLE_DEVICES with "1"; any value inherited
# from the launching shell or scheduler is replaced.
# NOTE(review): presumably this pins the job to the GPU with index 1 — confirm
# that is intended on every machine this script is launched on.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

def setup(args, lr, wd, l1, l2, flag):
    """
    Create the config for one run and reserve a fresh output directory for it.

    Args:
        args: parsed command-line arguments; ``args.config_file`` and
            ``args.opts`` are merged into the default config.
        lr: base learning rate as searched. The value stored in
            ``cfg.SOLVER.BASE_LR`` is rescaled to ``lr / 256 * batch_size``;
            the unscaled value is used in the output folder name.
        wd: weight decay, stored in ``cfg.SOLVER.WEIGHT_DECAY``.
        l1: only used here to name the output folder.
        l2: only used here to name the output folder.
        flag: falsy for a search run (directory ``run<count>``), truthy for
            the final run (directory ``runfinal``).

    Returns:
        The frozen config with ``OUTPUT_DIR`` pointing at the newly created
        run directory.

    Raises:
        ValueError: if no free run directory is available, i.e. all
            ``cfg.RUN_N_TIMES`` search slots (or ``runfinal``) already exist.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    # setup dist
    #cfg.DIST_INIT_PATH = "tcp://{}:12399".format(os.environ["SLURMD_NODENAME"])

    # setup output dir
    # output_dir / data_name / feature_name / lr_wd_l1_l2 / run1
    output_dir = cfg.OUTPUT_DIR
    cfg.SOLVER.WEIGHT_DECAY = wd
    output_folder = os.path.join(
        cfg.DATA.NAME, cfg.DATA.FEATURE, f"lr{lr}_wd{wd}_l1{l1}_l2{l2}")
    # The folder name keeps the searched lr; the solver gets it scaled
    # linearly with the batch size (reference batch size 256).
    cfg.SOLVER.BASE_LR = lr / 256 * cfg.DATA.BATCH_SIZE

    # Candidate run directories, in the order they are tried. The final run
    # has a single fixed name, so it is probed once instead of sleeping
    # RUN_N_TIMES times on the very same path.
    if cfg.RUN_N_TIMES < 1:
        run_names = []
    elif flag:
        run_names = ["runfinal"]
    else:
        run_names = [f"run{count}" for count in range(1, cfg.RUN_N_TIMES + 1)]

    for run_name in run_names:
        output_path = os.path.join(output_dir, output_folder, run_name)
        # pause for a random time, so concurrent process with same setting won't interfere with each other. # noqa
        sleep(randint(3, 30))
        if not PathManager.exists(output_path):
            PathManager.mkdirs(output_path)
            cfg.OUTPUT_DIR = output_path
            break
    else:
        # Reached only when no candidate directory was free.
        raise ValueError(
            f"Already run {cfg.RUN_N_TIMES} times for {output_folder}, no need to run more")

    cfg.freeze()
    return cfg


def get_loaders(cfg, logger, final):
    """
    Build the train, validation and test data loaders for one run.

    Args:
        cfg: run config, passed to every loader constructor;
            ``cfg.DATA.NO_TEST`` disables the test loader.
        logger: logger used for progress messages.
        final: when equal to ``True`` the train+val loader is used for
            training, otherwise the plain train loader.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``; ``test_loader`` is
        ``None`` when ``cfg.DATA.NO_TEST`` is set.
    """
    logger.info("Loading training data (final training data for vtab)...")
    # Equality with True (not plain truthiness) is kept on purpose so the
    # selection matches the historical behaviour for non-bool arguments.
    build_train = (
        data_loader.construct_trainval_loader
        if final == True  # noqa: E712
        else data_loader.construct_train_loader
    )
    train_loader = build_train(cfg)

    logger.info("Loading validation data...")
    # not really needed for vtab
    val_loader = data_loader.construct_val_loader(cfg)

    logger.info("Loading test data...")
    if not cfg.DATA.NO_TEST:
        return train_loader, val_loader, data_loader.construct_test_loader(cfg)

    logger.info("...no test data is constructed")
    return train_loader, val_loader, None


def train(cfg, args, lr, wd, l1, l2):
    """
    Train one hyper-parameter configuration on the train split.

    Args:
        cfg: frozen run config (as returned by ``setup``).
        args: parsed command-line arguments, forwarded to the logging setup.
        lr: searched learning rate; used here only to name the TensorBoard
            run and the checkpoint file.
        wd: searched weight decay; used here only for naming.
        l1: forwarded to the ``Trainer`` and used for naming.
        l2: forwarded to the ``Trainer`` and used for naming.

    Returns:
        The value returned by ``trainer.train_classifier`` (the caller treats
        it as the best accuracy of this configuration).
    """
    # clear up residual cache from previous runs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # main training / eval actions here

    # fix the seed for reproducibility
    if cfg.SEED is not None:
        torch.manual_seed(cfg.SEED)
        np.random.seed(cfg.SEED)
        # NOTE(review): Python's `random` is seeded with the literal 3 rather
        # than cfg.SEED (trainfinal uses 0) — confirm this is intentional.
        random.seed(3)

    # setup training env including loggers
    logging_train_setup(args, cfg)
    logger = logging.get_logger("ELSE")

    train_loader, val_loader, test_loader = get_loaders(cfg, logger, False)
    logger.info("Constructing models...")
    model, cur_device = build_model(cfg)

    logger.info("Setting up Evalutator...")
    evaluator = Evaluator()
    logger.info("Setting up Trainer...")
    # The fifth positional argument is a boolean flag that is always True here.
    trainer = Trainer(cfg, model, evaluator, cur_device, True, l1, l2)

    # One tag shared by the TensorBoard run name and the checkpoint file name,
    # e.g. (0.01, 0.001, 0.0, 1e-05) -> "0_01_0_001_0_0_1e-05".
    tag = "_".join(str(value).replace(".", "_") for value in (lr, wd, l1, l2))
    # NOTE(review): the 'save/caltech101' directory is not created here —
    # confirm the Trainer creates it before writing the checkpoint.
    output_vit = "save/caltech101/checkpoint_vit" + tag + ".pth"
    total = 100  # passed through to the trainer unchanged

    writer = SummaryWriter("caltech101_ViT_" + tag)
    try:
        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('number of params (M): %.2f' % (n_parameters / 1.e6))

        best_acc = trainer.train_classifier(
            train_loader, val_loader, test_loader, writer, output_vit, total)

        # With zero training epochs, still run one evaluation on the test split.
        if cfg.SOLVER.TOTAL_EPOCH == 0:
            trainer.eval_classifier(test_loader, "test", 0)
    finally:
        # Close the event writer even when training or evaluation fails.
        writer.close()
    return best_acc

def trainfinal(cfg, args, lr, wd, l1, l2):
    """
    Run the final training with the selected hyper-parameters.

    Differs from the search run in that the training loader is the train+val
    loader and ``trainer.train_final`` is called; nothing is returned.

    Args:
        cfg: frozen run config (as returned by ``setup``).
        args: parsed command-line arguments, forwarded to the logging setup.
        lr: selected learning rate; used here only to name the TensorBoard
            run and the checkpoint file.
        wd: selected weight decay; used here only for naming.
        l1: forwarded to the ``Trainer`` and used for naming.
        l2: forwarded to the ``Trainer`` and used for naming.
    """
    # clear up residual cache from previous runs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # main training / eval actions here

    # fix the seed for reproducibility
    if cfg.SEED is not None:
        torch.manual_seed(cfg.SEED)
        np.random.seed(cfg.SEED)
        # Python's own RNG is seeded with a fixed 0, independent of cfg.SEED.
        random.seed(0)

    # setup training env including loggers
    logging_train_setup(args, cfg)
    logger = logging.get_logger("ELSE")

    train_loader, val_loader, test_loader = get_loaders(cfg, logger, True)
    logger.info("Constructing models...")
    model, cur_device = build_model(cfg)

    logger.info("Setting up Evalutator...")
    evaluator = Evaluator()
    logger.info("Setting up Trainer...")
    # The fifth positional argument is a boolean flag that is always True here.
    trainer = Trainer(cfg, model, evaluator, cur_device, True, l1, l2)

    # One tag shared by the TensorBoard run name and the checkpoint file name,
    # e.g. (0.01, 0.001, 0.0, 1e-05) -> "0_01_0_001_0_0_1e-05".
    tag = "_".join(str(value).replace(".", "_") for value in (lr, wd, l1, l2))
    # NOTE(review): the 'save/caltech101' directory is not created here —
    # confirm the Trainer creates it before writing the checkpoint.
    output_vit = "save/caltech101/checkpoint5_vit" + tag + ".pth"
    total = 100  # passed through to the trainer unchanged

    writer = SummaryWriter("caltech101_ViT_" + tag)
    try:
        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('number of params (M): %.2f' % (n_parameters / 1.e6))

        trainer.train_final(
            train_loader, val_loader, test_loader, writer, output_vit, total)

        # With zero training epochs, still run one evaluation on the test split.
        if cfg.SOLVER.TOTAL_EPOCH == 0:
            trainer.eval_classifier(test_loader, "test", 0)
    finally:
        # Close the event writer even when training or evaluation fails.
        writer.close()

def main(args):
    """
    Grid-search the hyper-parameters, then run the final training.

    Every combination of the four ranges below (6 * 4 * 4 * 5 = 480) is
    trained once via ``setup`` + ``train``; the combination with the highest
    returned accuracy is then retrained with ``setup(..., True)`` +
    ``trainfinal``.

    Args:
        args: parsed command-line arguments, forwarded to ``setup``,
            ``train`` and ``trainfinal``.

    Raises:
        RuntimeError: if no combination scored above the initial best of 0,
            so there is nothing to run the final training with.
        ValueError: propagated from ``setup`` when a combination has no free
            run directory left.
    """
    from itertools import product

    # hyper-parameters
    lr_range = [0.01, 0.05, 0.1, 0.25, 0.5, 1.25]
    wd_range = [0.01, 0.001, 0.0001, 0.00001]
    l1_range = [0.0, 0.00001, 0.0001, 0.001]
    l2_range = [0.01, 0.001, 0.0001, 0.00001, 0.0]

    best = 0
    final = []
    # product() iterates with the rightmost range varying fastest, i.e. the
    # same order as four nested loops over lr, wd, l1, l2.
    for lr, wd, l1, l2 in product(lr_range, wd_range, l1_range, l2_range):
        print("hyper:", lr, wd, l1, l2)
        print("Final:", final)
        print("Best:", best)
        # set up cfg and args
        cfg = setup(args, lr, wd, l1, l2, False)

        # Perform training.
        acc = train(cfg, args, lr, wd, l1, l2)
        if acc > best:
            final = [lr, wd, l1, l2]
            best = acc

    if not final:
        # Previously this surfaced as an opaque IndexError on final[0].
        raise RuntimeError(
            "Hyper-parameter search finished without any configuration "
            "scoring above 0; cannot select settings for the final run")

    lr, wd, l1, l2 = final
    cfg = setup(args, lr, wd, l1, l2, True)
    trainfinal(cfg, args, lr, wd, l1, l2)


if __name__ == '__main__':
    # Parse the command line with the project's argument parser and start the sweep.
    main(default_argument_parser().parse_args())
