# Pipeline entry point
import logging
import os
import torch
from torchvision import datasets
import infra.logger as lib_log
import infra.args_config as lib_args
import utils.exp_config as lib_exp
import utils.datasets_config as lib_dscf
import utils.expdata_config as lib_data
from model.vit import ViTForClassfication
from model.mot import ViTGating, MoTTracker
import model.train as lib_tr
import numpy as np

def main_per_round(exp):
    """Run one training round for the current run.

    Re-initializes per-stage state on ``exp`` and then trains for one epoch
    via ``lib_tr.train_epoch``; whatever that call returns is discarded.

    Args:
        exp: Experiment status object shared across the pipeline.
    """
    exp.stage_init()
    lib_tr.train_epoch(exp)

def main_per_run(exp):
    """Execute one full run: reset RNG state, build loaders/models, train per round.

    Args:
        exp: Experiment status object. Reads ``exp.args`` (``nb_runs``,
            ``vit_model``, ``n_expert``), ``exp.ds`` (``vit_param``,
            ``lr_param['Round']``) and ``exp.roundid_for_early_stop``; writes
            ``exp.r.all_models``, ``exp.r.moe_gates``, ``exp.r.moe_tracker``
            and iterates ``exp.r.roundID``.

    Raises:
        ValueError: if ``exp.args.vit_model`` is not ``'attention'``.
    """
    # Fetch the root logger locally (same as main_loop) rather than relying on
    # a module-level ``logger`` that is only bound when run as a script.
    logger = logging.getLogger()
    #####################  Init. Before Each Run  #####################
    logger.warning(f'This is the #.{exp.r.runID} / {exp.args.nb_runs} run.')
    # Reset random seeds to ensure consistency among runs
    exp.rng_states_init()
    exp.random_states_init()
    # Init. loader for this run
    exp.loader_init()
    # Init. models for this run: one separate ViT instance per expert.
    if exp.args.vit_model == 'attention':
        logger.info('Initialize ViT model.')
        exp.r.all_models = [
            ViTForClassfication(exp.ds.vit_param)
            for _ in range(exp.args.n_expert)]
    else:
        raise ValueError(
            f"Unsupported vit_model={exp.args.vit_model!r}; expected 'attention'.")

    # Gating module plus a tracker that wraps the gate and all expert models.
    exp.r.moe_gates = ViTGating(
        config=exp.ds.vit_param)
    exp.r.moe_tracker = MoTTracker(
        moe_gating=exp.r.moe_gates,
        expert_models=exp.r.all_models)

    ########################  Begin each round  #######################
    # The loop variable is stored on exp.r so downstream code can read the
    # current round index.
    for exp.r.roundID in range(0, exp.ds.lr_param['Round']):
        main_per_round(exp)
        # A non-negative roundid_for_early_stop ends the run right after that
        # round has completed.
        if exp.roundid_for_early_stop >= 0 and exp.r.roundID == exp.roundid_for_early_stop:
            logger.warning(f"Early stop at round={exp.roundid_for_early_stop} due to manual configuration.")
            break
    ###############   This is the end of current run   ################
    # Final dump for this run (presumably persists model weights — see ExpStatus.dump_model).
    exp.dump_model(is_final=True)

def main_loop(exp):
    """Perform one-time experiment setup, then execute every run.

    Selects the compute device, seeds RNGs, builds the train/test
    ``ImageFolder`` datasets and the ``ExpData`` wrapper, applies overrides,
    logs a setup report, and finally calls ``main_per_run`` once per run.

    Args:
        exp: Experiment status object. This function populates
            ``exp.r.device``, ``exp.ds``, ``exp.r.trainset``,
            ``exp.r.testset`` and ``exp.d``, and iterates ``exp.r.runID``
            over ``range(exp.args.nb_runs)``.

    Raises:
        RuntimeError: if CUDA is unavailable and
            ``exp.args.override_can_use_cpu`` is falsy.
    """
    np.set_printoptions(suppress=True, precision=3)
    logger = logging.getLogger()
    exp.r.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if not exp.args.override_can_use_cpu and not torch.cuda.is_available():
        msg = f'GPU device not found; device={exp.r.device}'
        logger.error(msg)
        # Raise with a message instead of a bare Exception (whose empty str()
        # made the top-level critical log uninformative). RuntimeError is
        # still an Exception, so existing `except Exception` handlers apply.
        raise RuntimeError(msg + '; enable override_can_use_cpu to run on CPU.')
    exp.rng_seeds_setup()
    exp.rng_states_init()
    exp.ds = lib_dscf.init_dataset(exp)
    # Train/test splits are expected as 'train'/'test' subfolders of the base folder.
    exp.r.trainset = datasets.ImageFolder(
        os.path.join(exp.ds.basefolder_fullpath, 'train'),
        exp.ds.transform_train)
    exp.r.testset = datasets.ImageFolder(
        os.path.join(exp.ds.basefolder_fullpath, 'test'),
        exp.ds.transform_test)
    exp.d = lib_data.ExpData(exp, exp.r.trainset, exp.r.testset)
    # NOTE(review): nothing in this function writes dataset.npz; presumably
    # ExpData does — confirm the message matches what actually happens.
    logger.info(f'Saving dataset file at {exp.path_expfolder}/dataset.npz')
    exp.handle_override()

    logger.info('-------------------------- Begin Setup Report ---------------------------')
    exp.rng_seeds_print()
    exp.print_attributes()
    exp.ds.print_attributes()
    if exp.args.notes != "":
        logger.warning(f'Experiment notes: {exp.args.notes}')
    logger.warning(f'Dataset=[ {exp.args.dataset} ], Earlystop after Round=[ {exp.roundid_for_early_stop} ], N_Expert = [ {exp.args.n_expert} ]')
    logger.info('-------------------------- End of Setup Report --------------------------')
    # runID is stored on exp.r because main_per_run reads it from there.
    for exp.r.runID in range(0, exp.args.nb_runs):
        main_per_run(exp)

if __name__=='__main__':
    # Script entry point: parse CLI args, build the experiment status object,
    # set up file logging inside the experiment folder, then run the pipeline.
    args = lib_args.config_args()
    exp = lib_exp.ExpStatus(args, path_exp_config='main/config/exp_global.yaml')
    lib_log.config_log(os.path.join(exp.path_expfolder, 'exp.log'))
    logger = logging.getLogger()
    logger.info(f'Serial header for this experiment: {exp.exp_header}')
    # NOTE(review): f1/f2 appear to be folder-existence flags computed by
    # ExpStatus (inferred from the log messages below) — confirm.
    if not exp.f1:
        logger.warning(f'Intermediate folder at {exp.path_explog_fullpath} not exist; generating empty folder.')
    if exp.f2:
        logger.warning(f'Result subfolder {exp.path_expfolder} already existed.')
    else:
        logger.info(f'Create result subfolder under {exp.path_expfolder}.')
    if exp.args.override_set_dataset_path != 'none':
        logger.warning(f'override_set_dataset_path triggerred!')
    logger.info(f'Args parsing and exp initialization finished.')
    try:
        main_loop(exp)
    except Exception as e:
        # Keep the full traceback in the experiment log, then exit non-zero so
        # shells and job schedulers see the failure (previously the error was
        # swallowed and the process exited with status 0).
        logger.critical(e, exc_info=True)
        raise SystemExit(1) from e
