import os
import re
import torch
import random
import collections
import numpy as np
import pandas as pd
import torch.nn.functional as F
from tqdm import tqdm
from glob import glob
from datetime import datetime
from prettytable import PrettyTable
from sklearn.metrics import normalized_mutual_info_score

from utils.utils import cluster, cluster_infomap
from utils.evaluation import evaluate
import reidlib.evaluation as reideval


def parse_market(filename):
    """Read the person id and camera id out of a market-style file name.

    The first ``<pid>_c<digit>`` occurrence in *filename* is used, where
    ``<pid>`` is a run of digits and/or ``-`` characters and the camera id is
    a single digit.

    Args:
        filename: File name (without extension) to parse.

    Returns:
        tuple: ``(pid, cid)`` as integers.
    """
    match = re.search(r'([-\d]+)_c(\d)', filename)
    pid_text, cid_text = match.groups()
    return int(pid_text), int(cid_text)


def parse_msmt(filename):
    """Read the person id and camera id out of an msmt-style file name.

    *filename* must consist of exactly three ``_``-separated parts: the
    person id, the camera part (one leading character followed by the camera
    number) and a trailing part that is ignored.

    Args:
        filename: File name (without extension) to parse.

    Returns:
        tuple: ``(pid, cid)`` as integers.
    """
    pid_part, cam_part, _ = filename.split("_")
    # The first character of the camera part is a prefix, not a digit.
    return int(pid_part), int(cam_part[1:])


def parse_veri(filename):
    """Read the identity id and camera id out of a veri-style file name.

    The first ``<pid>_c<cid>`` occurrence in *filename* is used; both ids are
    runs of digits and/or ``-`` characters.

    Args:
        filename: File name (without extension) to parse.

    Returns:
        tuple: ``(pid, cid)`` as integers.
    """
    found = re.search(r'([-\d]+)_c([-\d]+)', filename)
    pid, cid = (int(token) for token in found.groups())
    return pid, cid


def parse_cuhk(filename):
    """Read the person id and camera id out of a cuhk-style file name.

    The first ``<pid>_c<digit>`` occurrence in *filename* is used, where
    ``<pid>`` is a run of digits and/or ``-`` characters and the camera id is
    a single digit.

    Args:
        filename: File name (without extension) to parse.

    Returns:
        tuple: ``(pid, cid)`` as integers.
    """
    hit = re.compile(r'([-\d]+)_c(\d)').search(filename)
    ids = [int(group) for group in hit.groups()]
    return ids[0], ids[1]


def parse_personx(filename):
    """Read the person id and camera id out of a personx-style file name.

    The first ``<pid>_c<cid>`` occurrence in *filename* is used; both ids are
    runs of digits and/or ``-`` characters.

    Args:
        filename: File name (without extension) to parse.

    Returns:
        tuple: ``(pid, cid)`` as integers.
    """
    result = re.search(r'([-\d]+)_c([-\d]+)', filename)
    pid_token, cid_token = result.groups()
    return int(pid_token), int(cid_token)


def parse_dataset(dir_path, dataset):
    """Build the ``(img_path, pid, cid)`` sample list for one image folder.

    Every ``*.jpg``, ``*.jpeg`` and ``*.png`` file directly inside *dir_path*
    is parsed with the file-name parser that belongs to *dataset*.  Files
    whose parsed person id is ``-1`` are skipped, and camera ids are shifted
    so that they start from 0.

    Args:
        dir_path: Directory that holds the image files.
        dataset: One of ``"market"``, ``"msmt"``, ``"veri"``, ``"cuhk"``,
            ``"personx"`` or ``"clonedperson"``.

    Returns:
        list: ``(img_path, pid, cid)`` tuples ordered by ``img_path``.

    Raises:
        ValueError: If *dataset* is not one of the supported names.
        AssertionError: If a "market", "msmt" or "veri" id lies outside the
            range checked for that dataset.
        KeyError: If a "personx" camera id is missing from the remap table.
    """
    supported = ("market", "msmt", "veri", "cuhk", "personx", "clonedperson")
    if dataset not in supported:
        # Fail early with a readable message instead of an unbound-variable
        # error on the first image.
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {', '.join(supported)}"
        )

    # Loop-invariant lookup tables.
    personx_cid_old2new = {3: 1, 4: 2, 8: 3, 10: 4, 11: 5, 12: 6}
    clonedperson_camera_offset = [0, 0, 0, 4, 4, 8, 12, 12, 12, 12, 16, 16, 20]

    img_paths = []
    for pattern in ("*.jpg", "*.jpeg", "*.png"):
        img_paths.extend(glob(os.path.join(dir_path, pattern)))
    # Sorting alone fixes the order; no shuffling is needed beforehand.
    img_paths.sort()

    outs = []
    for img_path in img_paths:
        filename = os.path.splitext(os.path.basename(img_path))[0]
        if dataset == "market":
            pid, cid = parse_market(filename)
            if pid == -1:
                continue  # junk images are just ignored
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= cid <= 6
            cid -= 1  # index starts from 0
        elif dataset == "msmt":
            pid, cid = parse_msmt(filename)
            if pid == -1:
                continue  # junk images are just ignored
            assert 1 <= cid <= 15
            cid -= 1  # index starts from 0
        elif dataset == "veri":
            pid, cid = parse_veri(filename)
            if pid == -1:
                continue  # junk images are just ignored
            assert 0 <= pid <= 776  # pid == 0 means background
            assert 1 <= cid <= 20
            cid -= 1  # index starts from 0
        elif dataset == "cuhk":
            pid, cid = parse_cuhk(filename)
            if pid == -1:
                continue  # junk images are just ignored
            cid -= 1  # index starts from 0
        elif dataset == "personx":
            pid, cid = parse_personx(filename)
            if pid == -1:
                continue  # junk images are just ignored
            # Map the sparse original camera ids onto 1..6 first.
            cid = personx_cid_old2new[cid]
            cid -= 1  # index starts from 0
        else:  # dataset == "clonedperson"
            # NOTE(review): parse_clonedperson is not defined in the part of
            # this file that was reviewed -- confirm it exists before using
            # this branch.
            pid, sid, cid = parse_clonedperson(filename)
            if pid == -1:
                continue
            cid = clonedperson_camera_offset[sid] + cid  # make it starting from 0
        outs.append((img_path, pid, cid))
    return outs


def split_samples(samples):
    """Turn an iterable of ``(path, pid, cid)`` triples into three lists.

    Args:
        samples: Iterable of 3-element samples.

    Returns:
        tuple: ``(paths, pids, cids)``, each a list in the input order.
    """
    columns = ([], [], [])
    for sample in samples:
        path, pid, cid = sample
        for column, value in zip(columns, (path, pid, cid)):
            column.append(value)
    paths, pids, cids = columns
    return paths, pids, cids


def compute_cid2feature(features, cids):
    """Group *features* by camera id and stack every group into one tensor.

    Args:
        features: Iterable of tensors, paired element-wise with *cids*.
        cids: Camera id of each feature.

    Returns:
        dict: Camera id -> tensor of that camera's features stacked along a
        new leading dimension.  Keys appear in first-seen order.
    """
    grouped = {}
    for feat, cam in zip(features, cids):
        grouped.setdefault(cam, []).append(feat)
    return {cam: torch.stack(members, dim=0) for cam, members in grouped.items()}


def compute_camera_means_scales(features, cids):
    """Compute the per-camera mean and standard deviation of the features.

    Features are grouped by camera id with ``compute_cid2feature`` and
    ``torch.std_mean`` is taken over the leading dimension of each group.

    Args:
        features: Iterable of feature tensors, paired with *cids*.
        cids: Camera id of each feature.

    Returns:
        tuple: ``(cid2mean, cid2scale)`` dicts keyed by camera id.
    """
    # NOTE(review): with torch's default settings a camera that has a single
    # sample presumably gets a NaN scale -- confirm whether that can happen.
    cid2mean, cid2scale = {}, {}
    for cam, cam_feats in compute_cid2feature(features, cids).items():
        cid2scale[cam], cid2mean[cam] = torch.std_mean(cam_feats, dim=0)
    return cid2mean, cid2scale


def debias(features, cids, cid2mean, cid2scale, eps=1e-9):
    """Standardize every feature with its camera's statistics, then normalize.

    Each feature has its camera mean subtracted, is divided by the camera
    scale plus *eps*, and is finally passed through ``F.normalize``.

    Args:
        features: Iterable of feature tensors, paired with *cids*.
        cids: Camera id of each feature.
        cid2mean: Camera id -> mean tensor.
        cid2scale: Camera id -> scale tensor.
        eps: Constant added to the scale before dividing.

    Returns:
        torch.Tensor: The debiased features stacked along a new first
        dimension, in input order.
    """
    rows = []
    for feat, cam in zip(features, cids):
        centered = feat - cid2mean[cam]
        scaled = centered / (cid2scale[cam] + eps)
        rows.append(F.normalize(scaled, dim=0))
    return torch.stack(rows)


def generate_cluster_features(labels, features):
    """Average the features of every cluster into one center per cluster.

    Samples labelled ``-1`` are left out.

    Args:
        labels: Cluster label of each feature.
        features: Iterable of feature tensors, paired with *labels*.

    Returns:
        torch.Tensor: One mean feature per cluster, stacked in ascending
        label order.
    """
    members = {}
    for label, feat in zip(labels, features):
        if label == -1:
            continue
        members.setdefault(label, []).append(feat)

    means = []
    for label in sorted(members):
        means.append(torch.stack(members[label], dim=0).mean(0))
    return torch.stack(means, dim=0)


def print_clustering_result(labels_pseudo, labels_gt, cams, outlier_label=-1, simple=False):
    """Print a summary of a clustering result and return the cluster count.

    Two normalized-mutual-information scores are computed over the inliers
    (samples whose pseudo label differs from *outlier_label*): one against
    the camera ids (printed as "Camera correlation") and one against the
    ground-truth ids (printed as "Accuracy").

    Args:
        labels_pseudo: Cluster label of every sample; any array-like that
            ``np.asarray`` accepts.
        labels_gt: Ground-truth id of every sample.
        cams: Camera id of every sample.
        outlier_label: Label that marks unclustered samples.
        simple: If true, print one space-separated line of percentages and
            counts instead of a table.

    Returns:
        int: Number of distinct clusters, not counting the outlier label.
    """
    # Coerce up front so the element-wise comparison below also works for
    # plain sequences, not only for ndarrays.
    labels_pseudo = np.asarray(labels_pseudo)
    inlier_mask = labels_pseudo != outlier_label

    has_outlier = not bool(inlier_mask.all())
    cluster_num = int(np.unique(labels_pseudo).size) - (1 if has_outlier else 0)

    labels_inlier = labels_pseudo[inlier_mask]
    cids_inlier = np.asarray(cams)[inlier_mask]
    pids_inlier = np.asarray(labels_gt)[inlier_mask]
    cid_score = normalized_mutual_info_score(cids_inlier, labels_inlier)
    pid_score = normalized_mutual_info_score(pids_inlier, labels_inlier)

    img_num = len(labels_pseudo)
    inlier_num = int(np.count_nonzero(inlier_mask))
    inlier_rate = round(100 * inlier_num / img_num, 1)

    if not simple:
        table = PrettyTable()
        table.field_names = ["Camera correlation", "Accuracy", "# Clusters", "# Inliers"]
        table.add_row([round(cid_score, 4), round(pid_score, 4), cluster_num, f"{inlier_num} ({inlier_rate})%"])
        print(table)
    else:
        camera_correlation = round(100*cid_score, 1)
        accuracy = round(100*pid_score, 1)
        print(camera_correlation, accuracy, inlier_rate, inlier_num, cluster_num)
    return cluster_num


def exp_clustering(features, cids, pids, eps=0.6, use_jaccard=True):
    """Cluster *features* with ``cluster`` and print a table summary.

    A timestamped directory under ``outs/`` is created as a side effect; the
    function itself writes nothing into it.

    Args:
        features: Features handed to ``cluster``.
        cids: Camera id of each feature (used for the summary only).
        pids: Ground-truth id of each feature (used for the summary only).
        eps: Forwarded to ``cluster``.
        use_jaccard: Forwarded to ``cluster``.

    Returns:
        The labels returned by ``cluster``.
    """
    timestamp = datetime.now().strftime('%Y-%m-%d_%H_%M_%S')
    os.makedirs(f"outs/{timestamp}", exist_ok=True)
    pseudo_labels = cluster(features, eps=eps, use_jaccard=use_jaccard)
    print_clustering_result(pseudo_labels, pids, cids, outlier_label=-1)
    return pseudo_labels


def exp_clustering_infomap(features, cids, pids, eps=0.5):
    """Cluster *features* with ``cluster_infomap`` and print a one-line summary.

    A timestamped directory under ``outs/`` is created as a side effect; the
    function itself writes nothing into it.

    Args:
        features: Features handed to ``cluster_infomap``.
        cids: Camera id of each feature (used for the summary only).
        pids: Ground-truth id of each feature (used for the summary only).
        eps: Forwarded to ``cluster_infomap``.

    Returns:
        The labels returned by ``cluster_infomap``.
    """
    timestamp = datetime.now().strftime('%Y-%m-%d_%H_%M_%S')
    os.makedirs(f"outs/{timestamp}", exist_ok=True)
    pseudo_labels = cluster_infomap(features, eps=eps)
    print_clustering_result(pseudo_labels, pids, cids, outlier_label=-1, simple=True)
    return pseudo_labels


def generate_cluster_weights_by_camera_entropy_binary(labels, cids):
    """Give each cluster weight 0 if it has one camera id, else weight 1.

    Samples labelled ``-1`` are ignored.

    Args:
        labels: Cluster label of each sample.
        cids: Camera id of each sample.

    Returns:
        torch.Tensor: One weight per cluster, in ascending label order.
    """
    cams_per_cluster = {}
    for label, cam in zip(labels, cids):
        if label == -1:
            continue
        cams_per_cluster.setdefault(label, set()).add(cam)

    weights = [
        0. if len(cams_per_cluster[label]) == 1 else 1.
        for label in sorted(cams_per_cluster)
    ]
    return torch.tensor(weights)


def prepare_test_data(model, preprocessor, dir_path, dataset="market"):
    """Load one image folder, extract its features and camera statistics.

    The folder is parsed with ``parse_dataset``, features are extracted with
    ``reideval.preprocessing.extract_features_from_paths`` and a small table
    with the number of images, distinct person ids and distinct camera ids
    is printed.

    Args:
        model: Passed through to the feature extractor.
        preprocessor: Passed through to the feature extractor.
        dir_path: Directory that holds the images.
        dataset: Dataset name understood by ``parse_dataset``.

    Returns:
        tuple: ``(features, pids, cids, cid2mean, cid2scale)``.
    """
    paths, pids, cids = split_samples(parse_dataset(dir_path, dataset))
    features = reideval.preprocessing.extract_features_from_paths(model, preprocessor, paths)
    cid2mean, cid2scale = compute_camera_means_scales(features, cids)

    summary = PrettyTable()
    summary.field_names = ["# Data", "# PID", "# CID"]
    summary.add_row([len(paths), len(set(pids)), len(set(cids))])
    print(summary)
    return features, pids, cids, cid2mean, cid2scale


def evaluate_clustering_result_infomap(dataset2dir_path, model, preprocessor, dataset_name, epss, out_dir_path):
    """Write clustering bias curves for the train and test splits as CSV.

    For each split the curve from ``make_clustering_bias_curve`` is computed
    twice, on the original features and on the camera-debiased features, and
    saved as ``train_original.csv``, ``train_debiased.csv``,
    ``test_original.csv`` and ``test_debiased.csv`` inside *out_dir_path*.
    The test split is the concatenation of the query and gallery folders.

    Args:
        dataset2dir_path: Mapping with the keys ``"<dataset_name>_train"``,
            ``"<dataset_name>_query"`` and ``"<dataset_name>_test"``.
        model: Passed through to the feature extractor.
        preprocessor: Passed through to the feature extractor.
        dataset_name: Dataset name understood by ``parse_dataset``.
        epss: Iterable of eps values to sweep.
        out_dir_path: Output directory; created if missing.
    """
    print("Evaluate the clustering results. The results are saved at ", out_dir_path)
    os.makedirs(out_dir_path, exist_ok=True)

    # ---- Train split ----
    train_dir = dataset2dir_path[f"{dataset_name}_train"]
    train_paths, train_pids, train_cids = split_samples(parse_dataset(train_dir, dataset_name))
    train_feats = reideval.preprocessing.extract_features_from_paths(model, preprocessor, train_paths)
    train_mean, train_scale = compute_camera_means_scales(train_feats, train_cids)
    train_feats_deb = debias(train_feats, train_cids, train_mean, train_scale)

    train_curve_ori = make_clustering_bias_curve(train_feats, train_cids, train_pids, epss)
    train_curve_deb = make_clustering_bias_curve(train_feats_deb, train_cids, train_pids, epss)
    train_curve_ori.to_csv(os.path.join(out_dir_path, "train_original.csv"), index=False)
    train_curve_deb.to_csv(os.path.join(out_dir_path, "train_debiased.csv"), index=False)

    # ---- Test split (query + gallery) ----
    query_dir = dataset2dir_path[f"{dataset_name}_query"]
    gallery_dir = dataset2dir_path[f"{dataset_name}_test"]
    q_feats, q_pids, q_cids, _, _ = prepare_test_data(model, preprocessor, query_dir, dataset_name)
    g_feats, g_pids, g_cids, _, _ = prepare_test_data(model, preprocessor, gallery_dir, dataset_name)
    all_feats = torch.cat([q_feats, g_feats])
    all_cids = q_cids + g_cids
    all_pids = q_pids + g_pids
    all_mean, all_scale = compute_camera_means_scales(all_feats, all_cids)
    all_feats_deb = debias(all_feats, all_cids, all_mean, all_scale)

    test_curve_ori = make_clustering_bias_curve(all_feats, all_cids, all_pids, epss)
    test_curve_deb = make_clustering_bias_curve(all_feats_deb, all_cids, all_pids, epss)
    test_curve_ori.to_csv(os.path.join(out_dir_path, "test_original.csv"), index=False)
    test_curve_deb.to_csv(os.path.join(out_dir_path, "test_debiased.csv"), index=False)


def make_clustering_bias_curve(feats, cids, pids, epss, outlier_label=-1):
    """Sweep eps values and tabulate clustering quality for each one.

    For every eps, *feats* is clustered with ``cluster_infomap`` and the
    result is scored over the inliers (samples whose label differs from
    *outlier_label*) with normalized mutual information against the camera
    ids ("bias") and against the ground-truth ids ("accuracy").

    Args:
        feats: Features handed to ``cluster_infomap``.
        cids: Camera id of every sample.
        pids: Ground-truth id of every sample.
        epss: Iterable of eps values to try.
        outlier_label: Label that marks unclustered samples.

    Returns:
        pandas.DataFrame: One row per eps that clustered successfully, with
        the columns ``bias`` and ``accuracy`` (percent, 1 decimal),
        ``inlier_ratio`` (percent, 2 decimals), ``inlier_num``,
        ``cluster_num`` and ``eps`` (rounded to 2 decimals).
    """
    # The id arrays do not depend on eps, so convert them only once.
    cids_arr = np.asarray(cids)
    pids_arr = np.asarray(pids)

    columns = {
        "bias": [],
        "accuracy": [],
        "inlier_ratio": [],
        "inlier_num": [],
        "cluster_num": [],
        "eps": [],
    }
    for eps in tqdm(epss):
        try:
            labels_pseudo = cluster_infomap(feats, eps=eps)
        except Exception as err:
            # Best effort: a failing eps is skipped, but it is reported, and
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            tqdm.write(f"[make_clustering_bias_curve] eps={eps}: clustering failed ({err!r}); skipped")
            continue

        # Coerce so the element-wise comparison also works for plain
        # sequences.
        labels_pseudo = np.asarray(labels_pseudo)
        inlier_mask = labels_pseudo != outlier_label
        has_outlier = not bool(inlier_mask.all())
        cluster_num = int(np.unique(labels_pseudo).size) - (1 if has_outlier else 0)

        labels_inlier = labels_pseudo[inlier_mask]
        cid_score = normalized_mutual_info_score(cids_arr[inlier_mask], labels_inlier)
        pid_score = normalized_mutual_info_score(pids_arr[inlier_mask], labels_inlier)

        img_num = len(labels_pseudo)
        inlier_num = int(np.count_nonzero(inlier_mask))

        columns["bias"].append(round(100*cid_score, 1))
        columns["accuracy"].append(round(100*pid_score, 1))
        columns["inlier_ratio"].append(round(100 * inlier_num / img_num, 2))
        columns["inlier_num"].append(inlier_num)
        columns["cluster_num"].append(cluster_num)
        columns["eps"].append(round(eps, 2))
    return pd.DataFrame(columns)


def evaluate_reid(dataset2dir_path, model, preprocessor, dataset_name, name):
    """Run ``evaluate`` on the original and on the debiased features.

    Query and gallery features are loaded with ``prepare_test_data``.  The
    camera statistics used for debiasing are computed over the query and
    gallery features together.

    Args:
        dataset2dir_path: Mapping with the keys ``"<dataset_name>_query"``
            and ``"<dataset_name>_test"``.
        model: Passed through to the feature extractor.
        preprocessor: Passed through to the feature extractor.
        dataset_name: Dataset name understood by ``parse_dataset``.
        name: Label that is only used in the printed header line.
    """
    print("Evaluate the ReID performance:", name, dataset_name)
    query_dir = dataset2dir_path[f"{dataset_name}_query"]
    gallery_dir = dataset2dir_path[f"{dataset_name}_test"]
    q_feats, q_pids, q_cids, _, _ = prepare_test_data(model, preprocessor, query_dir, dataset_name)
    g_feats, g_pids, g_cids, _, _ = prepare_test_data(model, preprocessor, gallery_dir, dataset_name)

    # Camera statistics come from query and gallery combined.
    joint_feats = torch.cat([q_feats, g_feats])
    joint_cids = q_cids + g_cids
    cam_mean, cam_scale = compute_camera_means_scales(joint_feats, joint_cids)
    q_feats_deb = debias(q_feats, q_cids, cam_mean, cam_scale)
    g_feats_deb = debias(g_feats, g_cids, cam_mean, cam_scale)

    runs = (
        ("[Original features]", q_feats, g_feats),
        ("[Debiased features]", q_feats_deb, g_feats_deb),
    )
    for title, query_feats, gallery_feats in runs:
        print(title)
        evaluate(query_feats, q_cids, q_pids, gallery_feats, g_cids, g_pids)
