import cv2
import csv
import torch
import numpy as np
import numpy.ma as ma
import torch.nn.functional as F
from prettytable import PrettyTable
from sklearn.metrics import average_precision_score
from collections import defaultdict


def make_table(column_names, column_values):
    """Build a PrettyTable with one header row and a single data row.

    Args:
        column_names: header labels, one per column.
        column_values: the cell values for the only data row.

    Returns:
        The populated ``PrettyTable`` instance.
    """
    result = PrettyTable()
    result.field_names = column_names
    result.add_row(column_values)
    return result


def show_result(mAP, cmc, cmc_topk):
    """Print mAP and CMC scores as percentages rounded to one decimal.

    Args:
        mAP: mean average precision as a fraction.
        cmc: CMC values, one per entry of ``cmc_topk`` (fractions).
        cmc_topk: rank labels used to name the CMC columns.
    """
    header = ["mAP"] + [f"CMC-{k}" for k in cmc_topk]
    percentages = [round(score * 100, 1) for score in (mAP, *cmc)]
    print(make_table(header, percentages))


def show_result_thresholding(accuracy, f1_score, precision, recall, threshold):
    """Print thresholding metrics, each rounded to four decimals."""
    header = ["Accuracy", "F1 score", "Precision", "Recall", "Threshold"]
    metrics = (accuracy, f1_score, precision, recall, threshold)
    rounded = [round(value, 4) for value in metrics]
    print(make_table(header, rounded))


def show_result_benchmark(mAP, cmc, cmc_topk, accuracy, f1_score, precision, recall, threshold):
    """Print ranking metrics and thresholding metrics in a single table.

    mAP and CMC values are shown as percentages with one decimal; the
    thresholding metrics are shown as-is, rounded to four decimals.
    """
    header = ["mAP"]
    header.extend(f"CMC-{k}" for k in cmc_topk)
    header.extend(["Accuracy", "F1 score", "Precision", "Recall", "Threshold"])
    ranking = [round(value * 100, 1) for value in (mAP, *cmc)]
    thresholding = [round(value, 4) for value in (accuracy, f1_score, precision, recall, threshold)]
    print(make_table(header, ranking + thresholding))


def compute_distance_matrix(x, y, use_cosine=False):
    """Compute pairwise distances between the rows of ``x`` and ``y``.

    Each input is flattened to (num_samples, -1) along its first dimension.

    Args:
        x: tensor whose first dimension indexes the m query samples.
        y: tensor whose first dimension indexes the n gallery samples.
        use_cosine: if True, return negated cosine similarity (so smaller is
            closer); otherwise return squared Euclidean distance.

    Returns:
        (m, n) tensor of distances.
    """
    m, n = x.size(0), y.size(0)
    # reshape (unlike view) also accepts non-contiguous inputs.
    x = x.reshape(m, -1)
    y = y.reshape(n, -1)
    if use_cosine:
        print("Cosine distance is used.")
        x = F.normalize(x)
        y = F.normalize(y)
        return -(x @ y.t())
    x_sq = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n)
    y_sq = torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    dist_mat = x_sq + y_sq
    # ||x||^2 + ||y||^2 - 2 * x.y^T, using the keyword (non-deprecated) addmm_ form.
    dist_mat.addmm_(x, y.t(), beta=1, alpha=-2)
    return dist_mat


def _unique_sample(ids_dict, num):
    mask = np.zeros(num, dtype=np.bool)
    for _, indices in ids_dict.items():
        i = np.random.choice(indices)
        mask[i] = True
    return mask


def compute_cmc(distmat, query_ids=None, gallery_ids=None,
                query_cams=None, gallery_cams=None, topk=100,
                separate_camera_set=False,
                single_gallery_shot=False,
                first_match_break=False):
    """Compute the Cumulative Matching Characteristics (CMC) curve.

    Args:
        distmat: (num_query, num_gallery) distance matrix, as a tensor or
            ndarray. Rows are sorted ascending, so smaller means closer.
        query_ids, gallery_ids: identity labels; default to ``arange``.
        query_cams, gallery_cams: camera labels. They default to all-0 for
            queries and all-1 for the gallery, so no pair shares a camera.
        topk: length of the returned curve.
        separate_camera_set: additionally drop every gallery entry from the
            query's camera, whatever its identity.
        single_gallery_shot: sample one gallery entry per identity and
            average over 10 random draws.
        first_match_break: count only the first correct match of each query.

    Returns:
        ndarray of length ``topk``. It is the cumulative sum of per-rank hit
        weights, divided by the number of valid queries.

    Raises:
        RuntimeError: if no query has a correct match among its valid
            gallery entries.
    """
    distmat = to_numpy(distmat)
    m, n = distmat.shape
    # Fill up default values
    if query_ids is None:
        query_ids = np.arange(m)
    if gallery_ids is None:
        gallery_ids = np.arange(n)
    if query_cams is None:
        query_cams = np.zeros(m).astype(np.int32)
    if gallery_cams is None:
        gallery_cams = np.ones(n).astype(np.int32)
    # Ensure numpy array
    query_ids = np.asarray(query_ids)
    gallery_ids = np.asarray(gallery_ids)
    query_cams = np.asarray(query_cams)
    gallery_cams = np.asarray(gallery_cams)
    # Sort and find correct matches
    indices = np.argsort(distmat, axis=1)
    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
    # Compute CMC for each query
    ret = np.zeros(topk)
    num_valid_queries = 0
    for i in range(m):
        # Filter out gallery entries with the same id AND the same camera
        valid = ((gallery_ids[indices[i]] != query_ids[i]) |
                 (gallery_cams[indices[i]] != query_cams[i]))
        if separate_camera_set:
            # Filter out samples from same camera
            valid &= (gallery_cams[indices[i]] != query_cams[i])
        if not np.any(matches[i, valid]):
            continue
        if single_gallery_shot:
            repeat = 10
            gids = gallery_ids[indices[i][valid]]
            inds = np.where(valid)[0]
            # id -> positions (in the ranked list) where that id appears
            ids_dict = defaultdict(list)
            for j, x in zip(inds, gids):
                ids_dict[x].append(j)
        else:
            repeat = 1
        for _ in range(repeat):
            if single_gallery_shot:
                # Randomly choose one instance for each id
                sampled = (valid & _unique_sample(ids_dict, len(valid)))
                index = np.nonzero(matches[i, sampled])[0]
            else:
                index = np.nonzero(matches[i, valid])[0]
            delta = 1. / (len(index) * repeat)
            # k is the position of the j-th correct match, so k - j is the
            # number of non-matches ranked ahead of it.
            for j, k in enumerate(index):
                if k - j >= topk:
                    break
                if first_match_break:
                    # Divide by repeat so each query contributes at most 1
                    # in total, even across the random draws.
                    ret[k - j] += 1. / repeat
                    break
                ret[k - j] += delta
        num_valid_queries += 1
    if num_valid_queries == 0:
        raise RuntimeError("No valid query")
    return ret.cumsum() / num_valid_queries


def mean_ap(distmat, query_ids=None, gallery_ids=None,
            query_cams=None, gallery_cams=None):
    """Mean average precision over all queries with at least one valid match.

    For each query, gallery entries with both the same id and the same camera
    are excluded. The average precision is then scored with sklearn's
    ``average_precision_score``, using negated distances as scores.

    Args:
        distmat: (num_query, num_gallery) distance matrix, as a tensor or ndarray.
        query_ids, gallery_ids: identity labels; default to ``arange``.
        query_cams, gallery_cams: camera labels; they default to all-0
            (query) and all-1 (gallery).

    Returns:
        The mean AP as a Python float.

    Raises:
        RuntimeError: if no query has a valid correct match.
    """
    distmat = to_numpy(distmat)
    num_query, num_gallery = distmat.shape

    query_ids = np.asarray(np.arange(num_query) if query_ids is None else query_ids)
    gallery_ids = np.asarray(np.arange(num_gallery) if gallery_ids is None else gallery_ids)
    query_cams = np.asarray(
        np.zeros(num_query).astype(np.int32) if query_cams is None else query_cams)
    gallery_cams = np.asarray(
        np.ones(num_gallery).astype(np.int32) if gallery_cams is None else gallery_cams)

    order = np.argsort(distmat, axis=1)
    hits = gallery_ids[order] == query_ids[:, np.newaxis]

    ap_list = []
    for row, (ranked, row_hits) in enumerate(zip(order, hits)):
        keep = ((gallery_ids[ranked] != query_ids[row]) |
                (gallery_cams[ranked] != query_cams[row]))
        labels = row_hits[keep]
        if not labels.any():
            continue
        scores = -distmat[row][ranked][keep]
        ap_list.append(average_precision_score(labels, scores))

    if not ap_list:
        raise RuntimeError("No valid query")
    return float(np.mean(ap_list))


def to_numpy(tensor):
    """Return ``tensor`` as a NumPy object.

    Torch tensors are detached from the autograd graph, moved to CPU, and
    converted. Objects whose type lives in the ``numpy`` module are returned
    unchanged.

    Raises:
        ValueError: for any other input type.
    """
    if torch.is_tensor(tensor):
        # detach() so tensors that require grad can be converted too.
        return tensor.detach().cpu().numpy()
    if type(tensor).__module__ != 'numpy':
        raise ValueError("Cannot convert {} to numpy array"
                         .format(type(tensor)))
    return tensor


def to_torch(ndarray):
    """Return ``ndarray`` as a torch tensor.

    NumPy arrays go through ``torch.from_numpy``, which shares memory with
    the array. Inputs that are already tensors are returned unchanged.

    Raises:
        ValueError: for any other input type.
    """
    if type(ndarray).__module__ == 'numpy':
        return torch.from_numpy(ndarray)
    if torch.is_tensor(ndarray):
        return ndarray
    raise ValueError("Cannot convert {} to torch tensor"
                     .format(type(ndarray)))


def read_csv(path):
    """Read the CSV file at ``path`` and return its rows as lists of strings.

    The file is opened with ``newline=""``, as the ``csv`` module requires.
    This lets the reader handle line endings itself, so quoted fields that
    contain newlines are preserved exactly.
    """
    with open(path, "r", newline="") as f:
        return list(csv.reader(f))


def resize_with_pad(image, new_shape, padding_color=(0, 0, 0)):
    """Resize ``image`` to fit inside ``new_shape`` keeping aspect ratio, then pad.

    Args:
        image: image array; ``shape[1]`` is taken as the width and
            ``shape[0]`` as the height.
        new_shape: target size as (width, height), matching the
            ``cv2.resize`` dsize convention.
        padding_color: border value passed to ``cv2.copyMakeBorder``.

    Returns:
        The resized image, padded evenly on both sides. When the padding is
        odd, the extra pixel goes to the bottom/right.
    """
    original_shape = (image.shape[1], image.shape[0])
    # Use the tighter per-axis ratio so neither side overshoots the target.
    # This equals max(new)/max(orig) whenever that ratio already fitted.
    ratio = min(float(new_shape[0]) / original_shape[0],
                float(new_shape[1]) / original_shape[1])
    new_size = tuple(int(x * ratio) for x in original_shape)
    image = cv2.resize(image, new_size, interpolation=cv2.INTER_LANCZOS4)
    delta_w = new_shape[0] - new_size[0]
    delta_h = new_shape[1] - new_size[1]
    top, bottom = delta_h // 2, delta_h - (delta_h // 2)
    left, right = delta_w // 2, delta_w - (delta_w // 2)
    image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=padding_color)
    return image


def compute_accuracy(
    dist_mat, query_pids, query_cids, gallery_pids, gallery_cids,
    ignore_same_cid=False
):
    """Rank-1 accuracy of each query's nearest gallery entry.

    Args:
        dist_mat: (num_query, num_gallery) distance matrix; smaller is closer.
        query_pids, gallery_pids: identity labels (array-like).
        query_cids, gallery_cids: camera labels. They are only used when
            ``ignore_same_cid`` is True.
        ignore_same_cid: if True, delegate to ``compute_cmc`` with topk=1.
            That excludes gallery entries sharing both the query's pid and
            cid, and skips queries left without a correct match.

    Returns:
        Fraction of queries whose top-ranked gallery entry has the same pid.
    """
    if ignore_same_cid:
        cmc_config = dict(
            separate_camera_set=False,
            single_gallery_shot=False,
            first_match_break=True,
            topk=1
        )
        cmc_scores = compute_cmc(dist_mat, query_pids, gallery_pids, query_cids, gallery_cids, **cmc_config)
        return cmc_scores[0]
    # asarray so list inputs work for fancy indexing and element-wise compare,
    # matching what compute_cmc already accepts.
    query_pids = np.asarray(query_pids)
    gallery_pids = np.asarray(gallery_pids)
    top_1_indices = np.argmin(dist_mat, axis=1)
    return np.mean(query_pids == gallery_pids[top_1_indices])


def compute_f1_score(
    dist_mat, query_pids, gallery_pids,
    iter_max=40, threshold_min=0.2, threshold_unit=0.01
):
    """Sweep distance thresholds and keep the one with the best F1.

    The thresholds tried are ``threshold_min + threshold_unit * i`` for
    ``i`` in ``range(iter_max)``. Each is scored with
    ``get_reid_f1_score``. Only a strictly higher F1 replaces the current
    best, so ties keep the smallest threshold.

    Returns:
        (f1_score, precision, recall, threshold) for the best threshold. If
        no score beats -1 (e.g. ``iter_max`` is 0 or every F1 is NaN), the
        result is ``(-1, None, None, None)``.
    """
    best = (-1, None, None, None)
    for step in range(iter_max):
        candidate = threshold_min + threshold_unit * step
        f1, prec, rec = get_reid_f1_score(
            dist_mat, query_pids, gallery_pids, candidate)
        if f1 > best[0]:
            best = (f1, prec, rec, candidate)
    return best


def get_reid_f1_score(dist_mat, query_pids, gallery_pids, threshold):
    """Score "distance < threshold" as a same-identity prediction.

    Every (query, gallery) pair is one binary decision. It is predicted
    positive when ``dist_mat[i, j] < threshold``, and ground-truth positive
    when ``query_pids[i] == gallery_pids[j]``.

    Args:
        dist_mat: (num_query, num_gallery) distance matrix.
        query_pids: identity label per query row (array-like).
        gallery_pids: identity label per gallery column (array-like).
        threshold: distances strictly below this count as a match.

    Returns:
        (f1_score, precision, recall). A ratio whose denominator is zero is
        reported as 0.0 instead of NaN.
    """
    pred = np.asarray(dist_mat) < threshold
    # Broadcast to a (num_query, num_gallery) identity-match matrix.
    # np.asarray ensures Python lists compare element-wise rather than as a
    # single list-vs-scalar bool.
    gt = np.asarray(query_pids)[:, np.newaxis] == np.asarray(gallery_pids)[np.newaxis, :]
    tp = int(np.sum(gt & pred))
    fp = int(np.sum(~gt & pred))
    fn = int(np.sum(gt & ~pred))
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    if precision + recall > 0:
        f1_score = 2 * (precision * recall) / (precision + recall)
    else:
        f1_score = 0.0
    return f1_score, precision, recall


def compute_f1_score_ma(
    dist_mat, query_pids, gallery_pids,
    iter_max=40, threshold_min=0.2, threshold_unit=0.01
):
    """Threshold sweep like ``compute_f1_score``, scored by ``get_reid_f1_score_ma``.

    The thresholds tried are ``threshold_min + threshold_unit * i`` for
    ``i`` in ``range(iter_max)``. Only a strictly higher F1 replaces the
    current best.

    Returns:
        (f1_score, precision, recall, threshold) for the best threshold. If
        nothing beats the initial -1, the result is ``(-1, None, None, None)``.
    """
    best = (-1, None, None, None)
    for step in range(iter_max):
        candidate = threshold_min + threshold_unit * step
        f1, prec, rec = get_reid_f1_score_ma(
            dist_mat, query_pids, gallery_pids, candidate)
        if f1 > best[0]:
            best = (f1, prec, rec, candidate)
    return best


def get_reid_f1_score_ma(dist_mat, query_pids, gallery_pids, threshold):
    """F1 / precision / recall with pairwise ground truth built via ``zip``.

    Unlike ``get_reid_f1_score``, this compares ``query_pids[i]`` with
    ``gallery_pids[i]`` pair by pair. ``zip`` truncates to the shorter list,
    so the ground truth is 1-D. Presumably ``dist_mat`` holds one distance
    per (query, gallery) pair; TODO confirm against callers.

    There is no guard against zero denominators.

    Returns:
        (f1_score, precision, recall).
    """
    predicted = dist_mat < threshold
    same_id = ma.array([q == g for q, g in zip(query_pids, gallery_pids)])
    true_pos = ma.sum(same_id & predicted)
    false_pos = ma.sum(~same_id & predicted)
    false_neg = ma.sum(same_id & ~predicted)
    precision = true_pos / (true_pos + false_pos)
    recall = true_pos / (true_pos + false_neg)
    return 2 * (precision * recall) / (precision + recall), precision, recall
