from __future__ import print_function, absolute_import
import time

import torch
import torch.nn as nn
from torch.nn import functional as F

from .evaluation_metrics import accuracy
from .loss import AR_CELoss, CenterLoss
from .utils.meters import AverageMeter


class Trainer(object):
    """Multi-domain trainer combining cross-entropy, AR and center losses.

    ``train`` consumes one batch from each loader in ``data_loaders`` per
    iteration. Every domain contributes a cross-entropy term on its batch
    targets; for the LAST loader those targets are not the loader's labels but
    pseudo-labels returned by ``memory.get_pred`` on the L2-normalized features
    of that batch.
    """

    def __init__(self, model, num_classes, memory):
        self.model = model
        self.num_cluster = num_classes
        self.memory = memory

        self.criterion_ce = nn.CrossEntropyLoss().cuda()
        self.criterion_ar = AR_CELoss().cuda()
        self.criterion_cr = CenterLoss().cuda()

    def train(self, epoch, data_loaders, optimizer, print_freq=1, train_iters=200, merge=True):
        """Run ``train_iters`` optimization steps over all domains.

        Args:
            epoch: Epoch index; used only in the progress log.
            data_loaders: Non-empty list of loaders exposing ``.next()`` that
                yields ``(images, labels)``. The last loader's labels are
                replaced by memory pseudo-labels.
            optimizer: Optimizer stepping ``self.model``'s parameters.
            print_freq: Print progress every ``print_freq`` iterations.
            train_iters: Number of iterations to run.
            merge: Unused; kept so existing callers keep working.

        Raises:
            ValueError: If ``data_loaders`` is empty.
        """
        num_domains = len(data_loaders)
        if num_domains == 0:
            # Without at least one loader there is no batch to pseudo-label
            # and ``sim_score`` would never be defined.
            raise ValueError('data_loaders must contain at least one loader')

        self.model.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses_ce = AverageMeter()
        losses_ar = AverageMeter()
        losses_cr = AverageMeter()
        min_lrs = AverageMeter()
        max_lrs = AverageMeter()
        precisions = AverageMeter()

        end = time.time()

        for i in range(train_iters):
            features = [None] * num_domains
            cls_outs = [None] * num_domains
            split_targets = [None] * num_domains
            sim_score = None
            # Only loading + host-to-device transfer is counted as data time;
            # the model forward is excluded.
            load_time = 0.0

            for j, loader in enumerate(data_loaders):
                load_start = time.time()
                source_inputs = loader.next()
                s_inputs, targets = self._parse_data(source_inputs)
                load_time += time.time() - load_start

                single_features, single_cls_out = self.model(s_inputs)
                features[j] = single_features
                cls_outs[j] = single_cls_out
                if j == num_domains - 1:
                    # The last domain's labels are discarded and replaced by
                    # the memory's predictions on normalized features.
                    memory_features = F.normalize(single_features)
                    targets, sim_score = self.memory.get_pred(memory_features, get_score=True)

                split_targets[j] = targets

            loss_ce, loss_ar, loss_cr, prec1 = self._multi_forward(features, cls_outs, split_targets, sim_score)

            loss = loss_ce + loss_ar + loss_cr

            data_time.update(load_time)

            losses_ce.update(loss_ce.item())
            losses_ar.update(loss_ar.item())
            losses_cr.update(loss_cr.item())
            precisions.update(prec1)

            min_lr, max_lr = self._lr_range(optimizer)
            min_lrs.update(min_lr)
            max_lrs.update(max_lr)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'Loss_ce {:.3f} ({:.3f})\t'
                      'Loss_ar {:.3f} ({:.3f})\t'
                      'Loss_cr {:.3f} ({:.3f})\t'
                      'Prec {:.2%} ({:.2%})\n'
                      'Min LR {:.10f}  ({:.10f})\t'
                      'Max LR {:.10f}  ({:.10f})'
                      .format(epoch, i + 1, train_iters,
                              batch_time.val, batch_time.avg,
                              data_time.val, data_time.avg,
                              losses_ce.val, losses_ce.avg,
                              losses_ar.val, losses_ar.avg,
                              losses_cr.val, losses_cr.avg,
                              precisions.val, precisions.avg,
                              min_lrs.val, min_lrs.avg,
                              max_lrs.val, max_lrs.avg))

    @staticmethod
    def _lr_range(optimizer):
        """Return ``(min_lr, max_lr)`` across the optimizer's param groups."""
        lrs = [group['lr'] for group in optimizer.param_groups]
        return min(lrs), max(lrs)

    def test(self, data_loader_target):
        """Evaluate the classifier head on ``data_loader_target``.

        Returns:
            The sum of per-batch ``accuracy(..., topk=(1, 2, 3, 4, 5),
            test=True)`` results (flattened to 1-D) divided by the total
            number of samples.

        Raises:
            ValueError: If the loader yields no samples.
        """
        self.model.eval()
        accu_num = 0
        num_samples = 0
        with torch.no_grad():
            for imgs, labels in data_loader_target:
                imgs = imgs.cuda()
                labels = labels.cuda()
                _, outputs = self.model(imgs)
                batch_accu_num = accuracy(outputs, labels, topk=(1, 2, 3, 4, 5), test=True)
                accu_num += torch.tensor(batch_accu_num).reshape(-1)
                num_samples += imgs.shape[0]

        if num_samples == 0:
            raise ValueError('data_loader_target yielded no samples')

        return accu_num / num_samples

    def _parse_data(self, inputs):
        """Split an ``(images, ids)`` batch and move both parts to the GPU."""
        imgs, ids = inputs

        inputs = imgs.cuda()
        targets = ids.cuda()

        return inputs, targets

    def _multi_forward(self, s_features, s_outputs, targets, sim_score):
        """Compute the combined losses for one multi-domain iteration.

        Args:
            s_features: List of per-domain feature tensors (one per loader).
            s_outputs: List of per-domain classifier outputs, same order.
            targets: List of per-domain target tensors, same order; the last
                entry holds memory pseudo-labels.
            sim_score: Similarity scores from ``memory.get_pred`` for the last
                domain, consumed by the AR loss.

        Returns:
            ``(loss_ce, loss_ar, loss_cr, prec)``. ``loss_ce`` is the sum of
            per-domain cross-entropy. ``prec`` is the first entry of
            ``accuracy`` on the last domain, i.e. measured against
            pseudo-labels.

        Side effects:
            Calls ``self.memory.update`` with the concatenated features and
            targets of all domains.
        """
        loss_ce = 0
        for s_output, domain_targets in zip(s_outputs, targets):
            loss_ce += self.criterion_ce(s_output, domain_targets)

        # Precision is reported for the last domain only (must be computed
        # before ``targets`` is rebound to the concatenated tensor below).
        prec_list, = accuracy(s_outputs[-1].data, targets[-1].data)
        prec = prec_list[0]

        features = torch.cat(s_features, dim=0)
        targets = torch.cat(targets, dim=0)
        loss_ar = self.criterion_ar(s_outputs[-1], sim_score)
        loss_cr = self.criterion_cr(features, targets, self.memory.feature_centers)
        self.memory.update(features, targets)
        return loss_ce, loss_ar, loss_cr, prec