import os
import cv2
import json
import shutil
import torch
import torch.nn.functional as F
import numpy as np
from typing import *
from time import time
from sklearn.cluster import DBSCAN
from .faiss_rerank import compute_jaccard_distance
from .infomap_utils import cluster_by_infomap, get_dist_nbr


def read_json(path):
    """Parse the JSON file at *path* and return the decoded Python object."""
    with open(path, "r") as handle:
        return json.load(handle)


def write_json(out, path):
    """Serialize *out* to *path* as JSON with 4-space indentation, overwriting the file."""
    with open(path, "w") as handle:
        json.dump(out, handle, indent=4)


def resize_with_pad(image: np.ndarray, new_shape: Tuple[int, int], padding_color: Tuple[int, ...] = (0, 0, 0)):
    """Resize *image* to fit inside *new_shape* (aspect ratio kept), then pad.

    Args:
        image: Image array indexed as ``image.shape[0]`` rows by
            ``image.shape[1]`` columns (e.g. the output of ``cv2.imread``).
        new_shape: Target size as ``(width, height)``, the same order as
            ``cv2.resize``'s ``dsize``.
        padding_color: Border value passed to ``cv2.copyMakeBorder``.

    Returns:
        Image with ``new_shape[1]`` rows and ``new_shape[0]`` columns. The
        resized content is centered; an odd leftover pixel goes to the
        bottom/right border.
    """
    original_shape = (image.shape[1], image.shape[0])  # (width, height)
    # Scale by the more constraining axis so neither side overshoots the
    # target. A max(new)/max(orig) ratio overshoots for non-square targets and
    # yields negative padding. For square targets the two formulas agree.
    ratio = min(float(new_shape[0]) / original_shape[0],
                float(new_shape[1]) / original_shape[1])
    # Keep at least 1 px per side; cv2.resize rejects a zero-sized dsize.
    new_size = tuple(max(1, int(x * ratio)) for x in original_shape)
    image = cv2.resize(image, new_size, interpolation=cv2.INTER_LANCZOS4)
    delta_w = new_shape[0] - new_size[0]
    delta_h = new_shape[1] - new_size[1]
    top, bottom = delta_h // 2, delta_h - (delta_h // 2)
    left, right = delta_w // 2, delta_w - (delta_w // 2)
    image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=padding_color)
    return image


def make_directories_for_clusters(img_paths, labels, out_dir_path):
    """Copy each image into a folder named after its cluster label.

    Destination is ``<out_dir_path>/clusters/<label zero-padded to width 6>/``;
    folders are created on demand and files are copied with ``shutil.copy``
    (original file names kept). Paths and labels are paired with ``zip``, so
    extra items in the longer sequence are ignored.
    """
    clusters_root = os.path.join(out_dir_path, "clusters")
    # Grid visualization (visualize_clusters_pseudo) is currently not run here.
    for src, cluster_id in zip(img_paths, labels):
        dst_dir = os.path.join(clusters_root, f"{cluster_id:06}")
        os.makedirs(dst_dir, exist_ok=True)
        shutil.copy(src, dst_dir)


def visualize_clusters_pseudo(img_paths, labels, out_dir_path, row_num=6, h=128, w=128):
    """Write JPEG grids showing the member images of each cluster.

    Each output file ``{i:06}.jpg`` holds up to ``row_num`` clusters, one per
    row. Every row shows that cluster's images resized/padded to ``h`` x ``w``
    pixels, followed by black tiles up to the size of the largest cluster.

    Args:
        img_paths: Paths of images readable by ``cv2.imread``.
        labels: Cluster label for each path (paired with ``zip``).
        out_dir_path: Output directory (created if missing).
        row_num: Number of clusters (rows) per output image.
        h: Tile height in pixels.
        w: Tile width in pixels.

    Raises:
        FileNotFoundError: If an image cannot be read by ``cv2.imread``.
    """
    print("visualize_clusters_pseudo")
    os.makedirs(out_dir_path, exist_ok=True)
    label2path = dict()
    for path, label in zip(img_paths, labels):
        label2path.setdefault(label, []).append(path)
    if not label2path:
        return  # nothing to draw; max() below would fail on an empty sequence
    cluster_size_max = max(len(paths) for paths in label2path.values())
    cluster_labels = list(label2path)
    # uint8 to match cv2 images; float zeros would upcast the whole grid.
    blank = np.zeros((h, w, 3), dtype=np.uint8)
    # Ceil-divide so a final, partially filled page of clusters is still written.
    num_pages = (len(cluster_labels) + row_num - 1) // row_num
    for i in range(num_pages):
        img_rows = []
        for label in cluster_labels[row_num * i: row_num * (i + 1)]:
            tiles = []
            for path in label2path[label]:
                img = cv2.imread(path)
                if img is None:
                    raise FileNotFoundError(f"Could not read image: {path}")
                # resize_with_pad takes (width, height) and returns h rows x w cols,
                # matching the shape of `blank`.
                tiles.append(resize_with_pad(img, (w, h)))
            tiles.extend([blank] * (cluster_size_max - len(tiles)))
            img_rows.append(np.concatenate(tiles, axis=1))
        img_rows = np.concatenate(img_rows, axis=0)
        out_path = os.path.join(out_dir_path, f"{i:06}.jpg")
        cv2.imwrite(out_path, img_rows)

        print("visualize_clusters_pseudo", out_path, img_rows.shape)


@torch.no_grad()
def compute_pairwise_distance(x, y, use_cosine=False):
    """Return the (m, n) matrix of distances between the rows of *x* and *y*.

    Both inputs are flattened with ``view(size(0), -1)``, so any trailing
    dimensions are treated as the feature vector.

    Args:
        x: Tensor with ``m`` samples along dim 0.
        y: Tensor with ``n`` samples along dim 0 and the same feature size.
        use_cosine: If True, rows are L2-normalized and the result is the
            negated cosine similarity (values in [-1, 1]). Otherwise the
            squared Euclidean distance is returned.

    Returns:
        Tensor of shape (m, n).
    """
    m, n = x.size(0), y.size(0)
    x = x.view(m, -1)
    y = y.view(n, -1)
    if use_cosine:
        print("Cosine distance is used.")
        x = F.normalize(x)
        y = F.normalize(y)
        # NOTE(review): this can be negative; consumers that require
        # non-negative distances may reject it.
        return -(x @ y.t())
    # ||x||^2 + ||y||^2 - 2 * x.y  (the sum allocates a fresh tensor, safe to modify in place)
    dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
        torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # Keyword beta/alpha: the positional addmm_(beta, alpha, mat1, mat2) overload
    # is deprecated and rejected by current PyTorch releases.
    dist_m.addmm_(x, y.t(), beta=1, alpha=-2)
    # Cancellation in the expansion can leave tiny negatives; squared distances are >= 0.
    return dist_m.clamp_(min=0)


@torch.no_grad()
def cluster(features, 
            eps=0.6,
            use_jaccard=True, k1=30, k2=6, 
            use_coinse=False, min_samples=4):
    """Cluster *features* with DBSCAN on a precomputed distance matrix.

    Args:
        features: Feature tensor passed to the distance function.
        eps: DBSCAN neighbourhood radius, in the units of the chosen distance.
        use_jaccard: If True, distances come from ``compute_jaccard_distance``
            with ``k1``/``k2``. Otherwise they come from ``compute_pairwise_distance``.
        k1: Forwarded to ``compute_jaccard_distance``.
        k2: Forwarded to ``compute_jaccard_distance``.
        use_coinse: (sic; name kept for backward compatibility) use cosine
            instead of squared Euclidean distance when ``use_jaccard`` is False.
        min_samples: DBSCAN ``min_samples``; defaults to the previous fixed value 4.

    Returns:
        Label array from ``DBSCAN.fit_predict`` (scikit-learn marks noise as -1).
    """
    clustering = DBSCAN(eps=eps, min_samples=min_samples, metric='precomputed', n_jobs=-1)

    # Compute distance matrix
    if use_jaccard:
        dist_mat = compute_jaccard_distance(features, k1=k1, k2=k2)
    else:
        # compute_pairwise_distance's keyword is `use_cosine`; `use_coinse=` raises TypeError.
        dist_mat = compute_pairwise_distance(features, features, use_cosine=use_coinse)
    if isinstance(dist_mat, torch.Tensor):
        # sklearn needs a host array; a CUDA tensor cannot be converted implicitly.
        dist_mat = dist_mat.cpu().numpy()

    # Cluster the features
    print("Clustering features...", end="\r")
    start_time = time()
    labels = clustering.fit_predict(dist_mat)
    end_time = time()
    print(f"Clustering finished ({round(end_time - start_time)}s)")
    return labels

import io
import sys
import contextlib


@contextlib.contextmanager
def silence():
    """Discard everything printed to ``sys.stdout`` inside the ``with`` block.

    Output is sent to a throwaway ``io.StringIO``. The previous stream is
    restored on exit, including when the body raises. Yields ``None``.
    """
    with contextlib.redirect_stdout(io.StringIO()):
        yield


@torch.no_grad()
def cluster_infomap(features, eps=0.5, k1=15, k2=4, knn_method='faiss-gpu'):
    """Cluster L2-normalized features with Infomap over a k-NN graph.

    Args:
        features: Tensor normalized along dim 1 (presumably (N, D) — confirm
            against callers), then moved to CPU numpy.
        eps: Forwarded as ``min_sim`` to ``cluster_by_infomap``.
        k1: Forwarded as ``k`` (neighbour count) to ``get_dist_nbr``.
        k2: Forwarded as ``cluster_num`` to ``cluster_by_infomap``. The exact
            meaning is defined in infomap_utils; verify there.
        knn_method: Backend name forwarded to ``get_dist_nbr``. The default
            keeps the original ``'faiss-gpu'`` behaviour. Other accepted
            values depend on ``get_dist_nbr``.

    Returns:
        numpy array of labels cast to ``np.intp``.
    """
    features = F.normalize(features, dim=1).cpu().numpy()
    feat_dists, feat_nbrs = get_dist_nbr(features=features, k=k1, knn_method=knn_method)
    # Cluster the features
    labels = cluster_by_infomap(feat_nbrs, feat_dists, min_sim=eps, cluster_num=k2, verbose=False)
    labels = labels.astype(np.intp)
    return labels
