import logging
import numpy as np
import os
import pickle
import time
import torch
import torch.nn as nn
import cv2
from utils.meter import AverageMeter
from utils.metrics import R1_mAP, R1_mAP_eval
from utils.visualize import crop_plot_region, crop_plot_stn
import warnings

# Suppress every Python warning for the whole process as soon as this module
# is imported (an "ignore" filter for the base ``Warning`` category).
warnings.simplefilter("ignore")

def do_train(cfg,
             model,
             center_criterion,
             train_loader,
             val_loader,
             optimizer,
             optimizer_center,
             scheduler,
             loss_fn,
             num_query):
    """Train ``model`` on ``train_loader``, checkpointing and evaluating periodically.

    For each of ``cfg.SOLVER.MAX_EPOCHS`` epochs, this function:

    * steps ``scheduler`` and runs one pass over ``train_loader``;
    * logs loss/accuracy every ``cfg.SOLVER.LOG_PERIOD`` iterations;
    * saves ``model.state_dict()`` to ``<cfg.OUTPUT_DIR>/<epoch>.pth`` every
      ``cfg.SOLVER.CHECKPOINT_PERIOD`` epochs;
    * evaluates mAP / CMC on ``val_loader`` at epoch 1 and every
      ``cfg.SOLVER.EVAL_PERIOD`` epochs. Evaluation is skipped when
      ``cfg.DATASETS.NAMES == 'city'``.

    Args:
        cfg: Config node (SOLVER, MODEL, TEST, DATASETS and OUTPUT_DIR are read).
        model: Network. In train mode it returns ``(score, feat)``; in eval
            mode it returns the feature tensor handed to the evaluator. It must
            expose ``scale1_net`` / ``scale2_net`` when the matching
            ``cfg.MODEL.USE_SCALE1`` / ``USE_SCALE2`` flag is set.
        center_criterion: Module whose parameter gradients are rescaled before
            ``optimizer_center.step()`` when center loss is enabled.
        train_loader: Yields ``(train_img, origin_img, vid, kpt, camid)`` batches.
        val_loader: Yields ``(img, vid, camid, trackid, img_path)`` batches.
        optimizer: Optimizer stepped after every backward pass.
        optimizer_center: Optimizer stepped only when
            ``'center' in cfg.MODEL.METRIC_LOSS_TYPE``.
        scheduler: LR scheduler, stepped once at the start of every epoch.
        loss_fn: Callable ``loss_fn(score, feat, vid)`` returning a scalar loss.
        num_query: Forwarded to ``R1_mAP_eval``. Presumably the number of query
            samples in ``val_loader``; confirm against the evaluator.
    """
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    epochs = cfg.SOLVER.MAX_EPOCHS
    device = "cuda"

    logger = logging.getLogger("reid_baseline.train")
    logger.info('start training')

    model.to(device)

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    evaluator = R1_mAP_eval(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)

    # Freeze backbone stages while the sub-networks are still direct
    # attributes of ``model`` (i.e. before the optional DataParallel wrap).
    if cfg.MODEL.USE_SCALE1:
        model.scale1_net._freeze_stages()
    if cfg.MODEL.USE_SCALE2:
        model.scale2_net._freeze_stages()
    logger.info('Freezing the stages number:{}'.format(cfg.MODEL.FROZEN))

    if torch.cuda.device_count() > 1:
        print('Using {} GPUs for training'.format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    use_center_loss = 'center' in cfg.MODEL.METRIC_LOSS_TYPE

    for epoch in range(1, epochs + 1):
        start_time = time.time()
        loss_meter.reset()
        acc_meter.reset()
        evaluator.reset()
        # NOTE(review): the scheduler is stepped before the epoch's optimizer
        # steps (pre-1.1 PyTorch ordering). Kept as-is because the configured
        # LR schedule depends on it.
        scheduler.step()
        model.train()

        num_batches = 0
        # origin_img, kpt and camid are not used by this training step.
        for n_iter, (train_img, origin_img, vid, kpt, camid) in enumerate(train_loader):
            num_batches = n_iter + 1
            optimizer.zero_grad()
            optimizer_center.zero_grad()
            train_img = train_img.to(device)
            vid = vid.to(device)

            score, feat = model(train_img)
            loss = loss_fn(score, feat, vid)

            loss.backward()
            optimizer.step()
            if use_center_loss:
                # Divide the center gradients by CENTER_LOSS_WEIGHT. Presumably
                # this undoes a weight applied inside loss_fn (confirm), so the
                # centers update at the unweighted rate.
                for param in center_criterion.parameters():
                    param.grad.data *= (1. / cfg.SOLVER.CENTER_LOSS_WEIGHT)
                optimizer_center.step()

            acc = (score.max(1)[1] == vid).float().mean()
            loss_meter.update(loss.item(), train_img.shape[0])
            # Store a Python float so the meter does not keep GPU tensors alive.
            acc_meter.update(acc.item(), 1)

            if num_batches % log_period == 0:
                logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                            .format(epoch, num_batches, len(train_loader),
                                    loss_meter.avg, acc_meter.avg, scheduler.get_lr()[0]))

        if num_batches:
            time_per_batch = (time.time() - start_time) / num_batches
            logger.info("Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]"
                        .format(epoch, time_per_batch, train_loader.batch_size / time_per_batch))
        else:
            logger.warning("Epoch {}: train_loader yielded no batches".format(epoch))

        if epoch % checkpoint_period == 0:
            # NOTE(review): under nn.DataParallel the saved keys carry a
            # "module." prefix; checkpoint loaders must account for that.
            torch.save(model.state_dict(), os.path.join(cfg.OUTPUT_DIR, '%i.pth' % epoch))

        if (epoch == 1 or epoch % eval_period == 0) and cfg.DATASETS.NAMES != 'city':
            model.eval()
            with torch.no_grad():
                for img, vid, camid, trackid, img_path in val_loader:
                    img = img.to(device)
                    feat = model(img)
                    if cfg.TEST.FLIP_FEATS == 'on':
                        # Flip test-time augmentation: add the features of the
                        # image mirrored along dim 3 (width, assuming NCHW).
                        # Summing directly works for any eval-mode feature size.
                        feat = feat + model(torch.flip(img, dims=[3]))
                    evaluator.update((feat, vid, camid))

            cmc, mAP = evaluator.compute()[:2]
            logger.info("Validation Results - Epoch: {}".format(epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

def do_inference(cfg,
                 model,
                 val_loader,
                 num_query):
    """Extract features for ``val_loader``, then either evaluate them or export results.

    When ``cfg.TEST.EVAL`` is truthy, features go to ``R1_mAP_eval`` and mAP
    plus CMC rank-1/5/10 are logged.

    Otherwise, features go to ``R1_mAP`` together with track ids and image
    paths, and its ``compute`` is called with ``<cfg.OUTPUT_DIR>/result.txt``.
    That path is presumably where the ranking is written; confirm in
    ``utils.metrics``.

    Args:
        cfg: Config node (TEST and OUTPUT_DIR are read).
        model: Network whose eval-mode output is the feature tensor given to
            the evaluator.
        val_loader: Yields ``(img, pid, camid, trackid, img_path)`` batches.
        num_query: Forwarded to the evaluator. Presumably the number of query
            samples in ``val_loader``; confirm.
    """
    device = "cuda"
    logger = logging.getLogger("reid_baseline.test")
    logger.info("Enter inferencing")
    if cfg.TEST.EVAL:
        evaluator = R1_mAP_eval(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM,
                                reranking=cfg.TEST.RE_RANKING)
    else:
        evaluator = R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM,
                           reranking=cfg.TEST.RE_RANKING, reranking_track=cfg.TEST.RE_RANKING_TRACK)
    evaluator.reset()

    model.to(device)
    if torch.cuda.device_count() > 1:
        print('Using {} GPUs for inference'.format(torch.cuda.device_count()))
        model = nn.DataParallel(model)
    model.eval()

    with torch.no_grad():
        for img, pid, camid, trackid, img_path in val_loader:
            img = img.to(device)
            feat = model(img)
            if cfg.TEST.FLIP_FEATS == 'on':
                # Flip test-time augmentation: add the features of the image
                # mirrored along dim 3 (width, assuming NCHW). Summing directly
                # works for whatever feature size the model produces.
                feat = feat + model(torch.flip(img, dims=[3]))

            if cfg.TEST.EVAL:
                evaluator.update((feat, pid, camid))
            else:
                evaluator.update((feat, pid, camid, trackid, img_path))

    if cfg.TEST.EVAL:
        cmc, mAP = evaluator.compute()[:2]
        logger.info("Validation Results ")
        logger.info("mAP: {:.1%}".format(mAP))
        for r in [1, 5, 10]:
            logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
    else:
        # Called for its side effect on the result file; the return value is unused.
        evaluator.compute(os.path.join(cfg.OUTPUT_DIR, "result.txt"))
        print('over')