import os
import time
from collections import defaultdict

import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional as F

from .models.utils.load_parameters import model_load_state_dict
from .utils import fliplr, print_test_progess, may_make_dir


class Phase_Tester(object):
    """Extract features for a task with a backbone restored from a checkpoint.

    Args:
        backbone: model whose forward pass returns a 2-element result; the
            first element is used as the embedding feature.
        data_loader: batch source. It must provide ``task_init(task)``,
            ``next_batch()`` and a ``prefetcher`` exposing ``dataset_size``
            and ``batch_size``.
        TVT: callable applied to the float tensor built from each NumPy image
            batch before it is passed to the backbone (presumably a device
            transfer -- confirm with the caller).
        weight_file: path of the checkpoint loaded at the start of ``test``.

    Methods:
        _calculate_cos_affinity(feature_1, feature_2):
            pairwise cosine affinity rescaled to [0, 1]
        _get_feature(images):
            embedding of one image batch, computed in eval mode without
            gradient tracking
        _test_init(task):
            reset progress state, load the checkpoint, initialise the loader
        test(task, verbose=True):
            return (feats, labels) for the whole image set of ``task``
    """

    def __init__(self, backbone, data_loader, TVT, weight_file):
        self.backbone = backbone
        self.data_loader = data_loader
        self.TVT = TVT
        self.weight_file = weight_file

    @staticmethod
    def _num_batches(dataset_size, batch_size):
        """Return the number of batches needed to cover ``dataset_size`` items.

        Uses ceiling division so that a dataset whose size is an exact
        multiple of ``batch_size`` is not reported as having one extra batch.
        """
        return (dataset_size + batch_size - 1) // batch_size

    def _calculate_cos_affinity(self, feature_1, feature_2):
        """Return the pairwise cosine affinity between two feature sets.

        Args:
            feature_1: tensor of shape [B, C]
            feature_2: tensor of shape [N, C]
        Returns:
            Tensor of shape [B, N]; cosine similarity mapped from [-1, 1]
            to [0, 1] via ``0.5 * cos + 0.5``.
        """
        # L2-normalise each row so the matrix product below is a cosine.
        normal_data_f1 = F.normalize(feature_1, p=2, dim=-1)
        normal_data_f2 = F.normalize(feature_2, p=2, dim=-1)
        return 0.5 * torch.mm(normal_data_f1, normal_data_f2.t()) + 0.5

    def _get_feature(self, images):
        """Return the backbone embedding for one batch of images.

        Args:
            images: NumPy array holding one image batch.
        Returns:
            The first element of the backbone's output, detached from the
            autograd graph (it is produced under ``torch.no_grad()``).

        The backbone's previous train/eval mode is restored before
        returning, including when the forward pass raises.
        """
        was_training = self.backbone.training
        # Eval mode: BN layers use their running statistics and dropout is
        # disabled.
        self.backbone.eval()
        try:
            # No gradients are needed for feature extraction. Without this,
            # every batch output would keep its autograd graph alive while
            # ``test`` accumulates the features.
            with torch.no_grad():
                images = self.TVT(torch.from_numpy(images).float())
                embedding_feat, _ = self.backbone(images)
        finally:
            self.backbone.train(was_training)
        return embedding_feat

    def _test_init(self, task):
        """Reset progress state, load the checkpoint and initialise the loader.

        Args:
            task: forwarded unchanged to ``data_loader.task_init``.
        Raises:
            AssertionError: if ``weight_file`` does not exist.
        """
        assert os.path.exists(self.weight_file), 'The weight file for tester doesn\'t exist!'
        # Progress/timing state; this object is handed to print_test_progess
        # from ``test``.
        self.done = False
        self.step = 0
        self.printed = False
        self.st = time.time()
        self.last_time = time.time()
        print(40*'-')
        print('Testing from '+str(self.weight_file))
        print(40*'-')

        # Return every storage as deserialised instead of moving it to the
        # device it was saved from.
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints.
        map_location = (lambda storage, loc: storage)
        ckpt = torch.load(self.weight_file, map_location=map_location)
        model_load_state_dict(self.backbone, ckpt['mod_state_dicts'][0])
        self.data_loader.task_init(task)

    def test(self, task, verbose=True):
        """Extract the features of the whole image set of ``task``.

        Args:
            task: task identifier forwarded to ``_test_init``.
            verbose: whether to print the progress of extracting features.
        Returns:
            feats: tensor made by concatenating the per-batch embeddings
                along dim 0 (presumably shape [N, C]).
            labels: long tensor made by concatenating the per-batch labels.
        """
        self._test_init(task)
        feats, labels = [], []
        prefetcher = self.data_loader.prefetcher
        total_batches = self._num_batches(prefetcher.dataset_size,
                                          prefetcher.batch_size)
        while not self.done:
            # The loader yields a 5-tuple; only the images (1st), the labels
            # (3rd) and the "done" flag (5th) are used here.
            ims_, _, labels_, _, self.done = self.data_loader.next_batch()
            feats.append(self._get_feature(ims_))
            labels.append(torch.from_numpy(labels_).long())
            self.step += 1
            if verbose:
                print_test_progess(self, total_batches=total_batches)

        return torch.cat(feats), torch.cat(labels)