import os
import sys
import time
import datetime
import argparse
import logging
import os.path as osp
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch import distributed as dist
#from apex import amp

from configs.default_img import get_img_config
from configs.default_vid import get_vid_config
from data import build_dataloader
from models import build_model
from losses import build_losses
from tools.utils import save_checkpoint, set_seed, get_logger
from train import train_cal
from test import test, test_prcc

'''MADE import'''
from loss import make_loss
from torch.cuda import amp
from solver import make_optimizer
from solver.scheduler_factory import create_scheduler


# Dataset names that parse_option() configures through get_vid_config();
# every other dataset name goes through get_img_config().
VID_DATASET = ['ccvid']


def parse_option():
    """Parse the command line and turn it into a config object.

    Arguments that the parser does not recognise are ignored
    (``parse_known_args``), so extra flags on the command line do not
    abort the run.

    Returns:
        The object returned by ``get_vid_config(args)`` when ``--dataset``
        is listed in ``VID_DATASET``, otherwise the object returned by
        ``get_img_config(args)``.
    """
    cli = argparse.ArgumentParser(
        description='Train clothes-changing re-id model with clothes-based adversarial loss'
    )

    # (flag, add_argument keyword arguments), registered in declaration order.
    option_specs = (
        ('--cfg', dict(type=str, required=True, metavar="FILE", help='path to config file')),
        # Datasets
        ('--root', dict(type=str, help="your root path to data directory")),
        ('--dataset', dict(type=str, help="ltcc, prcc, vcclothes, ccvid, last, deepchange")),
        # Miscs
        ('--output', dict(type=str, help="your output path to save model and logs")),
        ('--resume', dict(type=str, metavar='PATH')),
        ('--amp', dict(action='store_true', help="automatic mixed precision")),
        ('--eval', dict(action='store_true', help="evaluation only")),
        ('--tag', dict(type=str, help='tag for log file')),
        ('--gpu', dict(default='0,1', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES')),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    known_args, _extra = cli.parse_known_args()

    # Video datasets have their own default configuration.
    make_config = get_vid_config if known_args.dataset in VID_DATASET else get_img_config
    return make_config(known_args)


def main(config):
    """Build data, model and optimisers, then evaluate or train as configured.

    The previous version hard-coded ``test_mode = True`` and a personal
    checkpoint path, which made the training loop unreachable and ignored
    ``--eval`` / ``--resume``. Behaviour is now driven by the config only:

    * ``config.MODEL.RESUME`` (if set) is the checkpoint passed to
      ``model.load_param``.
    * ``config.EVAL_MODE`` selects evaluation-only; otherwise the model is
      trained from ``config.TRAIN.START_EPOCH`` to ``config.TRAIN.MAX_EPOCH``.

    Requirements visible from this function: ``torch.distributed`` must
    already be initialised (``dist.get_rank()`` is called) and a
    module-level ``logger`` must exist.

    Args:
        config: Configuration object as returned by ``parse_option``. Fields
            read here: ``DATA.DATASET``, ``MODEL.RESUME``, ``EVAL_MODE``,
            ``TRAIN.START_EPOCH``, ``TRAIN.MAX_EPOCH``, ``TEST.START_EVAL``,
            ``TEST.EVAL_STEP`` and ``OUTPUT``.

    Returns:
        None.
    """
    is_prcc = config.DATA.DATASET == 'prcc'

    # Build dataloader. PRCC has two query sets, hence the different tuple shape.
    if is_prcc:
        (trainloader, queryloader_same, queryloader_diff, galleryloader,
         dataset, train_sampler, _, _) = build_dataloader(config)
    else:
        (trainloader, queryloader, galleryloader,
         dataset, train_sampler, _) = build_dataloader(config)

    # Build model
    model = build_model(config, dataset.num_train_pids, dataset.num_train_clothes)

    # MADE loss function and optimizer setting
    loss_func, center_criterion = make_loss(config, num_classes=dataset.num_train_pids)
    optimizer, optimizer_center = make_optimizer(config, model, center_criterion)
    scheduler = create_scheduler(config, optimizer)

    start_epoch = config.TRAIN.START_EPOCH
    if config.MODEL.RESUME:
        checkpoint_path = config.MODEL.RESUME
        logger.info("Loading checkpoint from '{}'".format(checkpoint_path))
        model.load_param(checkpoint_path)
    elif config.EVAL_MODE:
        # Not fatal: build_model may already have loaded weights of its own.
        logger.info("Evaluation requested without --resume; evaluating the model as built")

    # NOTE(review): the global rank doubles as the CUDA device index, which
    # only holds when all processes run on a single node.
    local_rank = dist.get_rank()
    torch.cuda.set_device(local_rank)
    model = model.cuda(local_rank)

    model = nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

    def run_evaluation():
        """Run the dataset-appropriate test routine without tracking gradients."""
        with torch.no_grad():
            if is_prcc:
                return test_prcc(model, queryloader_same, queryloader_diff, galleryloader, dataset)
            return test(config, model, queryloader, galleryloader, dataset)

    if config.EVAL_MODE:
        logger.info("Evaluate only")
        run_evaluation()
        return

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    logger.info("==> Start training")
    scaler = amp.GradScaler()

    for epoch in range(start_epoch, config.TRAIN.MAX_EPOCH):
        # The scheduler is stepped with the explicit epoch index before training it.
        scheduler.step(epoch)
        # Tell the sampler which epoch is starting.
        train_sampler.set_epoch(epoch)
        start_train_time = time.time()
        train_cal(config, epoch, model, loss_func, center_criterion, optimizer,
                  optimizer_center, trainloader, scaler, scheduler)
        train_time += round(time.time() - start_train_time)

        finished_epochs = epoch + 1
        on_eval_step = (
            finished_epochs > config.TEST.START_EVAL
            and config.TEST.EVAL_STEP > 0
            and finished_epochs % config.TEST.EVAL_STEP == 0
        )
        is_last_epoch = finished_epochs == config.TRAIN.MAX_EPOCH
        # Always evaluate after the final epoch, even if it is off the eval schedule.
        if not (on_eval_step or is_last_epoch):
            continue

        logger.info("==> Test")
        torch.cuda.empty_cache()
        rank1 = run_evaluation()
        torch.cuda.empty_cache()

        is_best = rank1 > best_rank1
        if is_best:
            best_rank1 = rank1
            best_epoch = finished_epochs

        # Only rank 0 writes checkpoints; unwrap DDP so keys carry no 'module.' prefix.
        if local_rank == 0:
            save_checkpoint({
                'model_state_dict': model.module.state_dict(),
                'rank1': rank1,
                'epoch': epoch,
            }, is_best, osp.join(config.OUTPUT, 'checkpoint_ep' + str(finished_epochs) + '.pth.tar'))

    logger.info("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    logger.info("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
    

if __name__ == '__main__':
    # Script entry point: configure the process, set up logging, then run main().
    config = parse_option()

    # Restrict the visible GPUs before the process group is initialised.
    os.environ['CUDA_VISIBLE_DEVICES'] = config.GPU

    # Join the distributed job; connection details are read from the
    # environment ('env://').
    dist.init_process_group(backend="nccl", init_method='env://')
    local_rank = dist.get_rank()

    # Offset the seed by the rank so each process gets a different seed.
    set_seed(config.SEED + local_rank)

    # Evaluation runs and training runs log to different files.
    log_filename = 'log_test.log' if config.EVAL_MODE else 'log_train_.log'
    logger = get_logger(osp.join(config.OUTPUT, log_filename), local_rank, 'reid')

    # Record the full configuration at the top of the log.
    logger.info("Config:\n-----------------------------------------")
    logger.info(config)
    logger.info("-----------------------------------------")

    main(config)