
import argparse
import os

import torch
import torch.nn as nn

from cbml_benchmark.config import cfg
from cbml_benchmark.data import build_data
from cbml_benchmark.engine.trainer import do_train, do_test
from cbml_benchmark.losses import build_loss,build_aux_loss, build_trans_loss
from cbml_benchmark.modeling import build_model
from cbml_benchmark.solver import build_lr_scheduler, build_optimizer
from cbml_benchmark.utils.logger import setup_logger
from cbml_benchmark.utils.checkpoint import Checkpointer

import faiss
import numpy as np


def train(cfg):
    logger = setup_logger(name='Train', level=cfg.LOGGER.LEVEL)
    logger.info(cfg)
    model = build_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    train_loader = build_data(cfg, is_train=0)
    val_loader = build_data(cfg, is_train=1)
    train_val_loader = build_data(cfg, is_train=2)

    criterion = build_loss(cfg)
    criterion_aux = None
    if cfg.LOSSES.NAME_AUX is not '':
        criterion_aux = build_aux_loss(cfg)

    trans_criterion = build_trans_loss(cfg)

    loss_param = None
    if cfg.LOSSES.NAME == 'softtriple_loss' or cfg.LOSSES.NAME == 'proxynca_loss' or cfg.LOSSES.NAME == 'center_loss' or cfg.LOSSES.NAME == 'adv_loss' or cfg.LOSSES.NAME == 'proxy_anchor_loss':
        loss_param = criterion
    if cfg.LOSSES.NAME_AUX == 'softtriple_loss' or cfg.LOSSES.NAME_AUX == 'proxynca_loss' or cfg.LOSSES.NAME_AUX == 'center_loss' or cfg.LOSSES.NAME_AUX == 'adv_loss' or cfg.LOSSES.NAME_AUX == 'proxy_anchor_loss':
        loss_param = criterion_aux
    if cfg.LOSSES.TRANS_NAME == 'softtriple_loss' or cfg.LOSSES.TRANS_NAME == 'proxynca_loss' or cfg.LOSSES.TRANS_NAME == 'center_loss' or cfg.LOSSES.TRANS_NAME == 'adv_loss' or cfg.LOSSES.TRANS_NAME == 'proxy_anchor_loss':
        loss_param = criterion_aux

    optimizer = build_optimizer(cfg, model,loss_param=loss_param)
    scheduler = build_lr_scheduler(cfg, optimizer)

    logger.info(train_loader.dataset)
    logger.info(val_loader.dataset)
    logger.info(train_val_loader.dataset)

    arguments = dict()
    arguments["iteration"] = 0

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    checkpointer = Checkpointer(model, optimizer, scheduler, cfg.SAVE_DIR)

    do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        train_val_loader,
        optimizer,
        scheduler,
        criterion,
        trans_criterion,
        criterion_aux,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        logger
    )

def test(cfg):
    logger = setup_logger(name='Train', level=cfg.LOGGER.LEVEL)
    logger.info(cfg)
    model = build_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    val_loader = build_data(cfg, is_train=False)
    logger.info(val_loader.dataset)

    do_test(
        model,
        val_loader,
        logger
    )


def parse_args():
    """
  Parse input arguments
  """
    parser = argparse.ArgumentParser(description='Train a retrieval network')
    parser.add_argument(
        '--cfg',
        dest='cfg_file',
        help='config file',
        default=None,
        type=str)
    parser.add_argument(
        '--phase',
        dest='train_test',
        help='train or test',
        default='train',
        type=str)
    # parser.add_argument('--local_rank',type=int)
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    # local_rank = args.local_rank
    # print(local_rank)
    # torch.cuda.set_device(local_rank)
    # torch.distributed.init_process_group(backend='nccl',init_method='env://')
    cfg.merge_from_file(args.cfg_file)
    if args.train_test == 'train':
        train(cfg)
    else:
        test(cfg)
