import datetime
import os
import pickle
import sys
import time
from collections import defaultdict
from contextlib import contextmanager

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from scipy import io
# from sklearn.metrics import auc, precision_recall_curve
# from sklearn.metrics.base import _average_binary_score
from torch.autograd import Variable
from torch.nn import Parameter

def to_numpy(tensor):
    """Convert a torch tensor to a numpy array; pass numpy objects through unchanged.

    Args:
        tensor: a torch tensor (on any device, with or without autograd
            history) or any object whose type is defined in the ``numpy`` module.
    Returns:
        A numpy array for tensors, or the input object itself for numpy objects.
    Raises:
        ValueError: for any other input type.
    """
    if torch.is_tensor(tensor):
        # detach() lets tensors with requires_grad=True convert instead of
        # raising RuntimeError; .numpy() additionally needs a CPU tensor.
        return tensor.detach().cpu().numpy()
    elif type(tensor).__module__ != 'numpy':
        raise ValueError("Cannot convert {} to numpy array"
                         .format(type(tensor)))
    return tensor


def to_torch(ndarray):
    """Return *ndarray* as a torch tensor.

    Tensors are returned as-is; objects from the ``numpy`` module go through
    ``torch.from_numpy`` (sharing memory). Anything else raises ValueError.
    """
    if torch.is_tensor(ndarray):
        return ndarray
    if type(ndarray).__module__ == 'numpy':
        return torch.from_numpy(ndarray)
    raise ValueError("Cannot convert {} to torch tensor"
                     .format(type(ndarray)))
    
########################################################################################################


def to_scalar(vt):
    """Return the first element of a pytorch Variable or Tensor as a numpy scalar.

    The input is moved to the CPU, converted to numpy and flattened, so a
    tensor of shape (1,) yields a scalar rather than a shape-(1,) array.
    Raises TypeError for any non-Variable, non-tensor input.
    """
    if isinstance(vt, Variable):
        flat = vt.data.cpu().numpy().ravel()
    elif torch.is_tensor(vt):
        flat = vt.cpu().numpy().ravel()
    else:
        raise TypeError('Input should be a variable or tensor')
    return flat[0]
########################################################################################################


def load_pickle(path):
    """Deserialize and return the object stored in the pickle file at *path*.

    Raises AssertionError if *path* does not exist.

    NOTE(review): pickle.load can execute arbitrary code embedded in the file;
    only call this on files from a trusted source.
    """
    assert os.path.exists(path), 'This pickle path is not exist!'
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def save_pickle(obj, path):
    """Pickle *obj* to *path* with protocol 2, creating the parent directory first."""
    parent_dir = os.path.dirname(os.path.abspath(path))
    may_make_dir(parent_dir)
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle, protocol=2)


def str2bool(string):
    """Return True iff *string*, lower-cased, is one of "yes", "true", "t" or "1"."""
    truthy = ("yes", "true", "t", "1")
    normalized = string.lower()
    return normalized in truthy


def tight_float_str(x, fmt='{:.4f}'):
    """Format *x* with *fmt* and trim redundant trailing zeros of the fraction.

    With the default format: 2.0 -> '2', 0.25 -> '0.25'.

    Only the digits after the last '.' are trimmed, and only when that tail
    is purely numeric (or empty). Outputs without a fractional part
    ('{:.0f}' -> '100', '{:,}' -> '1,000') or with an exponent
    ('{:.2e}' -> '1.00e+00') are returned unchanged instead of losing
    significant zeros.
    """
    text = fmt.format(x)
    _, dot, frac = text.rpartition('.')
    # No decimal point, or a non-numeric tail (exponent, unit suffix, ...):
    # stripping '0' here would eat significant digits.
    if not dot or (frac and not frac.isdigit()):
        return text
    return text.rstrip('0').rstrip('.')


def time_str(fmt=None):
    """Return the current local date/time formatted with *fmt*.

    When *fmt* is None the pattern '%Y-%m-%d_%H:%M:%S' is used.
    """
    pattern = '%Y-%m-%d_%H:%M:%S' if fmt is None else fmt
    now = datetime.datetime.today()
    return now.strftime(pattern)


def strong_checkdir(path):
    """Check that *path* exists on disk.

    Args:
        path: file or directory path (str or os.PathLike).
    Returns:
        True if the path exists.
    Raises:
        FileNotFoundError: if the path does not exist. It subclasses
            Exception, so existing ``except Exception`` handlers still match.
    """
    if not os.path.exists(path):
        # format() instead of '+' so os.PathLike paths produce this message
        # rather than a TypeError from str + Path concatenation.
        raise FileNotFoundError(
            'Please check the {}, the path do not exist!'.format(path))
    return True


def may_make_dir(path):
    """Create directory *path* (and any missing parents) if it does not exist yet.

    Args:
        path: a dir, or result of `os.path.dirname(os.path.abspath(file_path))`.
            `None` and `''` are ignored.
    Note:
        `os.path.exists('')` returns `False`, while `os.path.exists('.')` returns `True`!
        If *path* already exists (as a directory or a file) nothing is done.
    """
    # Explicit membership test: `if path is None or '':` would never match ''.
    if path in [None, '']:
        return
    if os.path.exists(path):
        return
    try:
        os.makedirs(path)
    except FileExistsError:
        # Another process/thread created it between the check and makedirs;
        # the end state is the same, so treat it as success.
        pass
        

def find_index(seq, item):
    """Return the index of the first element of *seq* equal to *item*, or -1 if absent."""
    matches = (idx for idx, candidate in enumerate(seq) if item == candidate)
    return next(matches, -1)


def fliplr(img):
    """Return a copy of numpy array *img* reversed along axis 3.

    The input must have at least 4 dimensions. For an N*C*H*W image batch this
    is presumably the horizontal (width) flip — confirm the layout at call sites.
    The input array is not modified.
    """
    as_tensor = torch.from_numpy(img)
    mirrored = torch.flip(as_tensor, dims=[3])
    return mirrored.numpy()


def print_test_progess(tester, total_batches):
    '''Print the progress of extracting feature, once every 10 steps.

    When ``tester.step % 10 == 0`` a line
    ``"<step>/<total> batches done, +<since last print>s, total <elapsed>s"``
    is printed. After the first print, the previous line is erased with ANSI
    escapes first, so the output updates in place.

    Args:
        tester: object exposing ``step``, ``printed``, ``last_time``, ``st`` and
            ``data_loader.prefetcher.{dataset_size, batch_size}``; ``printed``
            and ``last_time`` are updated in place.
        total_batches: ignored — the total is always recomputed from the
            prefetcher (the argument is kept for call-site compatibility).
    '''
    prefetcher = tester.data_loader.prefetcher
    # Ceiling division: a trailing partial batch counts once, and an exact
    # multiple no longer reports a phantom extra batch. Assumes the loader
    # keeps the last partial batch (no drop_last) — TODO confirm.
    total_batches = -(-prefetcher.dataset_size // prefetcher.batch_size)
    if tester.step % 10 != 0:
        return
    if tester.printed:
        # Move the cursor up one line and clear it.
        sys.stdout.write("\033[F\033[K")
    else:
        tester.printed = True
    # Sample the clock once so both deltas and last_time agree.
    now = time.time()
    print('{}/{} batches done, +{:.2f}s, total {:.2f}s'
          .format(tester.step, total_batches,
                  now - tester.last_time, now - tester.st))
    tester.last_time = now




def calculate_single_cos_affinity(feature_1, feature_2, smooth=True):
    '''Row-wise cosine affinity between two equally shaped feature batches.

    NOTE:
    feature_1: B*C
    feature_2: B*C
    return: B (one affinity per row pair; the last dim is summed out)

    Args:
        smooth: if True, map the cosine similarity from [-1, 1] to [0, 1]
            via 0.5 * cos + 0.5; otherwise return the raw cosine similarity.
            The result is not clamped.
    Raises:
        AssertionError: if the two inputs have different shapes.
    '''
    assert feature_1.shape == feature_2.shape, "The two features\' shape should be the same!"
    normal_data_f1 = F.normalize(feature_1, p=2, dim=-1)
    normal_data_f2 = F.normalize(feature_2, p=2, dim=-1)
    # Dot product of unit vectors along the last dim == cosine similarity.
    cos_affin = torch.sum(normal_data_f1 * normal_data_f2, dim=-1)
    if smooth:
        return 0.5 * cos_affin + 0.5
    return cos_affin
    
def calculate_cos_affinity(feature_1, feature_2, smooth=True):
    '''Pairwise cosine affinity matrix between two feature sets.

    NOTE:
    feature_1: B*C
    feature_2: N*C
    return: B*N

    With smooth=True the similarity is mapped to [0, 1] (0.5 * cos + 0.5) and
    clamped to [eps, 1 - eps]; otherwise the raw cosine is clamped to
    [-1 + eps, 1 - eps], with eps = 1e-8.
    '''
    eps = 1e-8
    unit_1 = F.normalize(feature_1, p=2, dim=-1)
    unit_2 = F.normalize(feature_2, p=2, dim=-1)
    similarity = torch.mm(unit_1, unit_2.t())
    if not smooth:
        return torch.clamp(similarity, -1 + eps, 1 - eps)
    shifted = 0.5 * similarity + 0.5
    return torch.clamp(shifted, 0 + eps, 1 - eps)

def norm_2(vector):
    """Scale *vector* to unit L2 norm along its last dimension (``F.normalize`` default eps)."""
    unit = F.normalize(vector, 2, -1)
    return unit
