from itertools import product
from pathlib import Path
from pprint import pprint
import colorsys
import json
import pickle as pkl

from einops import rearrange, repeat
from sklearn.decomposition import PCA
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch as pt
import torch.nn.functional as ptnf

from object_centric_bench.datum import DataLoader
from object_centric_bench.datum.utils import draw_segmentation_np
from object_centric_bench.learn import MetricWrap, AverageLog
from object_centric_bench.model import ModelWrap2
from object_centric_bench.utils import Config, build_from_config


def generate_spectrum_colors(num_color):
    """Return ``num_color`` fully saturated colors evenly spaced in hue.

    Color k has hue ``k / num_color`` with saturation and value fixed at 1;
    each RGB channel is scaled to [0, 255] and truncated to an integer.

    :param num_color: how many colors to produce.
    :return: uint8 array in shape (n,c=3).
    """
    hues = (k / float(num_color) for k in range(num_color))
    unit_rgbs = (colorsys.hsv_to_rgb(hue, 1.0, 1.0) for hue in hues)
    table = [[int(255 * channel) for channel in rgb] for rgb in unit_rgbs]
    return np.array(table, dtype="uint8")  # (n,c=3)


def visualiz_codemap(zidx: np.ndarray, groups=(64, 64), dim=1):
    """Render a grouped code-index map as a color image, one panel per group.

    zidx: integer array in shape (g,h,w); only the first len(groups) entries
        along axis 0 are used, and zidx[i] is expected to lie in [0, groups[i]).
    groups: codebook size of each group.
    dim: axis along which the per-group (h,w) panels are concatenated; 1 puts
        them side by side, 0 stacks them vertically.
    image: returned uint8 array in shape (h,w*len(groups),c=3) for dim=1 or
        (h*len(groups),w,c=3) for dim=0.

    Each group's indices are rescaled by prod(groups)/groups[i], so every panel
    spans the whole spectrum of prod(groups) colors.
    """
    num_code = np.prod(groups)
    spectrum = generate_spectrum_colors(num_code)  # (n,c=3)
    panels = [
        (zidx[i, :, :] * num_code / g).astype("int") for i, g in enumerate(groups)
    ]
    image = spectrum[np.concatenate(panels, dim)]
    return image.astype("uint8")


def allocate_color_by_reference(tensor, reference, scale=None):
    """Color each row of ``tensor`` by its cosine distance to ``reference``.

    tensor: np.ndarray, shape=(n,c)
    reference: np.ndarray, shape=(c,)
    scale: optional multiplier applied to the normalized distance, which is
        then clipped to [0, 1]; skipped when falsy.
    return: np.ndarray uint8, shape=(n,c=3). Row k is picked from a spectrum of
        n colors at the position given by row k's normalized distance.
    """
    assert tensor.ndim == 2 and reference.ndim == 1
    dist = 1 - ptnf.cosine_similarity(
        pt.from_numpy(tensor), pt.from_numpy(reference[None, :]), 1
    )
    dist = dist.numpy()
    dist /= 2  # cosine distance lies in [0, 2]; map it to [0, 1]
    if scale:
        dist *= scale
        dist = dist.clip(0, 1)
    num = tensor.shape[0]
    colors = generate_spectrum_colors(num)
    # A distance of exactly 1 rounds to ``num``, one past the last color, and
    # float error can push it slightly outside [0, 1]; clamp to a valid index.
    idx = np.round(dist * num).astype("int").clip(0, num - 1)
    return colors[idx]


def tuple_zidx_to_scalar_zidx(tzidx: pt.Tensor, groups: list):
    """Collapse per-group code indices into one mixed-radix scalar index.

    Group 0 is the least-significant digit, i.e. the result is
    tzidx[:,0] + tzidx[:,1]*groups[0] + tzidx[:,2]*groups[0]*groups[1] + ...
    Only slicing, ``*`` and ``+=`` are used on ``tzidx``.

    tzidx: in shape (b,g,h,w)
    szidx: in shape (b,h,w)
    """
    combined = 0
    place_value = 1
    for group_pos, radix in enumerate(groups):
        digit = tzidx[:, group_pos, :, :]
        combined += digit * place_value
        place_value *= radix
    return combined


def adaptive_clustering_of_image_feature(tensor, n=None, viz=True, flag=3):
    """Cluster the pixels of a feature map, optionally returning a color image.

    tensor: np.ndarray in shape (c,h,w); each pixel's c-dim feature is one
        sample for the clustering algorithm.
    n: passed to KMeans as its first positional argument (n_clusters); only
        used when flag=3.
    viz: if False, return the (h,w) label map. If True, return an (h,w,3)
        uint8 image where each pixel gets its cluster's color: the cluster
        centers are projected onto 3 PCA components and min-max normalized
        (over all components jointly) to [0, 255].
    flag: 1 = HDBSCAN, 2 = MeanShift (too slow), 3 = KMeans.

    Raises NotImplementedError for any other ``flag``.

    NOTE(review): the viz branch reads ``cluster_centers_``, which KMeans and
    MeanShift expose; the hdbscan estimator presumably does not, so flag=1
    with viz=True likely fails -- confirm. PCA(n_components=3) also needs at
    least 3 clusters and 3 feature channels.
    """
    _, h, w = tensor.shape
    tensor = rearrange(tensor, "c h w -> (h w) c")

    # Estimators are imported per branch so that e.g. KMeans does not require
    # the optional hdbscan package to be installed.
    if flag == 1:
        from hdbscan import HDBSCAN

        clustering = HDBSCAN(
            min_cluster_size=5, min_samples=5, cluster_selection_epsilon=0.5
        )
        # clustering = DBSCAN(eps=0.5, min_samples=5)  # sklearn.cluster.DBSCAN
    elif flag == 2:  # too slow
        from sklearn.cluster import MeanShift

        clustering = MeanShift(bandwidth=1)
    elif flag == 3:
        from sklearn.cluster import KMeans

        clustering = KMeans(n)
    else:
        # The original raised a plain str, which is itself a TypeError in Py3.
        raise NotImplementedError(f"unsupported clustering flag: {flag}")

    clustering.fit(tensor)
    cluster = clustering.labels_  # + 1
    cluster = rearrange(cluster, "(h w) -> h w", h=h)
    if not viz:
        return cluster

    pca = PCA(n_components=3)
    color0 = pca.fit_transform(clustering.cluster_centers_)
    cmin = color0.min()
    cmax = color0.max()
    span = cmax - cmin
    if span == 0:  # degenerate projection: avoid 0/0 -> NaN, emit black
        span = 1
    color0 = ((color0 - cmin) / span * 255).astype("uint8")
    return color0[cluster]


def generate_pca_colormap(feature, n_component=4):
    """Color a feature map by projecting every pixel onto PCA components.

    feature: np.ndarray, shape=(h,w,c)
    n_component: number of PCA components fitted over all h*w pixels; the LAST
        min(3, n_component) of them become the color channels.
    return: np.ndarray uint8, shape=(h,w,min(3,n_component)), min-max
        normalized over the whole image to [0, 255].
    """
    h, w, _ = feature.shape
    flat = rearrange(feature, "h w c -> (h w) c")
    pca = PCA(n_components=n_component)
    # NOTE(review): keeps the last three components; the original author
    # questioned whether the leading three ([:, :3]) were intended.
    component = pca.fit_transform(flat)[:, -3:]
    component = rearrange(component, "(h w) c -> h w c", h=h, w=w)
    cmin = component.min()
    cmax = component.max()
    span = cmax - cmin
    if span == 0:  # constant projection: avoid 0/0 -> NaN, emit black
        span = 1
    color = ((component - cmin) / span * 255).astype("uint8")
    return color


@pt.no_grad()
def val_epoch(cfg, dataset_v, model, loss_fn, metric_fn, callback_v):
    """Run one evaluation pass over ``dataset_v`` and dump per-batch images.

    Every callback in ``callback_v`` is driven through the hooks
    ``before_epoch`` / ``before_step`` / ``after_forward`` / ``after_step`` /
    ``after_epoch`` with the keyword arguments collected in ``pack``: this
    function's arguments and the local ``cv2_resize_nearest`` helper, plus
    ``epoch``, ``step``, ``batch``, ``output``, ``loss`` and ``metric``.

    For sample 0 of each batch, three PNGs are written into the directory named
    by the module-level global ``item_name`` (the caller must set it):
    ``{step:04d}-0000-i0.png`` (the de-normalized input) and
    ``...-is0.png`` / ``...-is1.png`` (``draw_segmentation_np`` applied to the
    input with ``batch["segment"]`` / ``output["segment"]``).

    Returns ``(log_info, zidxs, slotss)``. ``log_info`` is ``mean()`` of the
    last callback whose type is exactly ``AverageLog``, or ``None`` if there is
    none. Nothing in this function currently appends to ``zidxs`` / ``slotss``,
    so both come back as empty lists.
    """
    cv2_resize_nearest = lambda i, x: cv2.resize(
        i, None, fx=x, fy=x, interpolation=cv2.INTER_NEAREST_EXACT
    )
    # Snapshot of the locals defined so far; it becomes the kwargs of every
    # callback hook, so no other local may be defined above this line.
    pack = locals()
    pack["epoch"] = 0
    model.eval()
    for cb in callback_v:
        cb.before_epoch(**pack)

    zidxs = []
    slotss = []

    for step, batch in enumerate(dataset_v):  # tqdm
        pack["step"] = step
        pack["batch"] = batch
        for cb in callback_v:
            cb.before_step(**pack)

        with pt.autocast("cuda", enabled=True):
            output = model(batch)

            pack["output"] = output
            for cb in callback_v:
                cb.after_forward(**pack)

            loss = loss_fn(output, batch)
        metric = metric_fn(output, batch)  # in autocast may cause inf

        ### <<< save img0, seg0, imgz and seg1
        i = 0  # only the first sample of every batch is visualized
        img_key = "image" if "image" in batch else "video"
        # undo the x*127.5+127.5 normalization (presumably inputs in [-1, 1])
        img0 = (
            (batch[img_key][i] * 127.5 + 127.5).clip(0, 255).to(pt.uint8).cpu().numpy()
        )
        img0 = rearrange(img0, "c h w -> h w c")
        # NOTE(review): cv2.imwrite treats channels as BGR; if the dataset
        # yields RGB the saved colors are swapped -- confirm.

        save_dir = Path(item_name)  # module-level global set by the caller
        # parents/exist_ok: tolerate nested names and pre-existing directories
        save_dir.mkdir(parents=True, exist_ok=True)
        save_path = f"{save_dir}/{step:04d}-{i:04d}-"
        cv2.imwrite(save_path + "i0.png", img0)

        seg0 = batch["segment"][i].cpu().numpy()  # (h,w)
        seg1 = output["segment"][i].cpu().numpy()
        imgs0 = draw_segmentation_np(img0, seg0)
        imgs1 = draw_segmentation_np(img0, seg1)
        cv2.imwrite(save_path + "is0.png", imgs0)
        cv2.imwrite(save_path + "is1.png", imgs1)

        # Disabled multi-scale code-map visualization, kept as a string
        # literal for reference; it is never executed.
        """for scale in range(3):
            zidx = output[f"zidx{scale}"][i].cpu().numpy()  # (g,h,w)
            quant = output[f"quant{scale}"][i].cpu().numpy()  # (c,h,w)
            fact = 4 * 2**scale
            imgz = cv2_resize_nearest(  # TODO XXX TODO XXX TODO XXX TODO XXX groups
                visualiz_codemap(zidx, groups=[64, 64], dim=0), fact
            )
            imgqc = cv2_resize_nearest(
                adaptive_clustering_of_image_feature(quant, cfg.max_num, True), fact
            )
            assert zidx.shape[0] == 2
            quant_si = (
                model.m.codebook[0](pt.from_numpy(zidx[0, :, :]).cuda()).cpu().numpy()
            )
            quant_sv = (
                model.m.codebook[scale + 1](pt.from_numpy(zidx[1, :, :]).cuda())
                .cpu()
                .numpy()
            )
            imgqcsi = cv2_resize_nearest(
                adaptive_clustering_of_image_feature(
                    quant_si.transpose(2, 0, 1), cfg.max_num
                ),
                fact,
            )
            imgqcsv = cv2_resize_nearest(
                adaptive_clustering_of_image_feature(
                    quant_sv.transpose(2, 0, 1), cfg.max_num
                ),
                fact,
            )

            cv2.imwrite(save_path + f"{scale}iz.png", imgz)
            cv2.imwrite(save_path + f"{scale}iqc.png", imgqc)
            cv2.imwrite(save_path + f"{scale}iqcsi.png", imgqcsi)
            cv2.imwrite(save_path + f"{scale}iqcsv.png", imgqcsv)"""

        ### >>>

        pack["loss"] = loss
        pack["metric"] = metric
        for cb in callback_v:
            cb.after_step(**pack)

    for cb in callback_v:
        cb.after_epoch(**pack)

    if zidxs:
        zidxs = np.concatenate(zidxs, 0)
    if slotss:
        slotss = np.concatenate(slotss, 0)

    # Defined explicitly: without an AverageLog callback the original raised
    # UnboundLocalError at the return statement.
    log_info = None
    for cb in callback_v:
        if type(cb) is AverageLog:  # exact type match, as before
            log_info = cb.mean()
    return log_info, zidxs, slotss


def main_eval_single(
    cfg_file=None,
    ckpt_file=None,
    data_dir="/media/GeneralZ/Storage/Static/datasets",
):
    """Evaluate one config, optionally with a checkpoint, on its validation set.

    cfg_file: path to a ``.py`` config read with ``Config.fromfile``; required
        despite the ``None`` default.
    ckpt_file: checkpoint passed to ``model.load``; when falsy, no checkpoint
        is loaded.
    data_dir: directory assigned as ``base_dir`` of both ``cfg.dataset_t`` and
        ``cfg.dataset_v``.

    Returns the ``log_info`` produced by ``val_epoch``. Non-empty code indices
    / slots returned by ``val_epoch`` are saved to ``zidx.<cfg>.npy`` /
    ``slots.<cfg>.npy`` in the working directory.
    """
    if cfg_file is None:
        raise TypeError("main_eval_single() requires cfg_file")
    pt.backends.cudnn.benchmark = True

    cfg_file = Path(cfg_file)
    data_path = Path(data_dir)
    # Path(None) raised TypeError and every Path is truthy, so a checkpoint was
    # previously mandatory; keep None meaning "no checkpoint".
    ckpt_file = Path(ckpt_file) if ckpt_file else None

    assert cfg_file.name.endswith(".py")
    assert cfg_file.is_file()
    cfg_tag = cfg_file.name[:-3]  # file name without the ".py" suffix
    cfg = Config.fromfile(cfg_file)

    # cfg.metric_fn.update(
    #     dict(  # XXX XXX XXX XXX XXX
    #         mbo=dict(
    #             metric=dict(
    #                 type="mBO", num_pd=cfg.max_num, num_gt=cfg.max_num, fg=False
    #             ),
    #             map=dict(output=dict(input="segment"), batch=dict(target="segment")),
    #         ),
    #         # mbo_fg=dict(
    #         #     metric=dict(
    #         #         type="mBO", num_pd=cfg.max_num, num_gt=cfg.max_num, fg=True
    #         #     ),
    #         #     map=dict(output=dict(input="segment"), batch=dict(target="segment")),
    #         # ),
    #         miou=dict(
    #             metric=dict(
    #                 type="HungarianMIoU",
    #                 num_pd=cfg.max_num,
    #                 num_gt=cfg.max_num,
    #                 fg=False,
    #             ),
    #             map=dict(output=dict(input="segment"), batch=dict(target="segment")),
    #         ),
    #         # miou_fg=dict(
    #         #     metric=dict(
    #         #         type="HungarianMIoU",
    #         #         num_pd=cfg.max_num,
    #         #         num_gt=cfg.max_num,
    #         #         fg=True,
    #         #     ),
    #         #     map=dict(output=dict(input="segment"), batch=dict(target="segment")),
    #         # ),
    #     )
    # )

    ## datum init

    cfg.dataset_t.base_dir = cfg.dataset_v.base_dir = data_path
    # # TODO XXX TODO XXX TODO XXX TODO XXX TODO XXX clevrtex_ood
    # cfg.dataset_v.data_file = "clevrtex/ood.lmdb"
    # cfg.dataset_v.split = "ood"

    dataset_v = build_from_config(cfg.dataset_v)
    dataload_v = DataLoader(
        dataset_v,
        2,  # cfg.batch_size_v,  # TODO XXX TODO XXX TODO XXX TODO XXX // 2
        num_workers=cfg.num_work,
        shuffle=False,
        pin_memory=True,
    )

    ## model init

    if "MultiScale" in cfg.model.type:
        cfg.groups = [64, 64]  # TODO XXX TODO XXX TODO XXX TODO XXX
    model = build_from_config(cfg.model)

    model = ModelWrap2(model, cfg.model_imap, cfg.model_omap)
    model = model.cuda()
    if ckpt_file is not None:
        model.load(ckpt_file, None, verbose=False)
    if cfg.freez:
        model.freez(cfg.freez, verbose=False)

    """from object_centric_bench.model.ocl import Codebook

    cb = model.m.codebook
    if isinstance(cb, pt.nn.ModuleList):
        codes = pt.cat(
            [
                repeat(cb[0].templat.weight, "(n 1) c -> (n m) c", m=64),
                repeat(cb[1].templat.weight, "(1 n) c -> (m n) c", m=64),
            ],
            1,
        )
    else:
        codes = cb.templat.weight
    dist = Codebook.euclidean_distance(codes, codes)
    print(f'{dist.mean().item():.2f}, {dist.min().item():.2f}, {dist.max().item():.2f}')"""

    ## learn init

    loss_fn = MetricWrap(**build_from_config(cfg.loss_fn))
    metric_fn = MetricWrap(**build_from_config(cfg.metric_fn))

    # evaluation only: drop checkpoint saving and keep AverageLog off disk
    cfg.callback_v = [c for c in cfg.callback_v if c.type != "SaveModel"]
    for cb in cfg.callback_v:
        if cb.type == "AverageLog":
            cb.log_file = None
    callback_v = build_from_config(cfg.callback_v)

    ## do eval

    log_info, zidxs, slotss = val_epoch(
        cfg, dataload_v, model, loss_fn, metric_fn, callback_v
    )
    if zidxs is not None and len(zidxs) > 0:
        if hasattr(cfg, "groups"):
            zidxs = tuple_zidx_to_scalar_zidx(zidxs, cfg.groups)
        # NOTE(review): uint16 silently wraps codes >= 65536 (fine for 64*64)
        np.save(f"zidx.{cfg_tag}.npy", zidxs.astype("uint16").ravel())
    if slotss is not None and len(slotss) > 0:
        np.save(f"slots.{cfg_tag}.npy", slotss)
    return log_info


def main_eval_multi_transfer():
    """Cross-evaluate configs against checkpoint directories and tabulate.

    Each (config, checkpoint-dir) pair is evaluated when both names split on
    "-" into 2 parts, or into 3 parts with the same third part. Every
    ``*.pth`` in the directory except ``best.pth`` is evaluated, and each
    metric is averaged over those checkpoints.

    Writes two pickles to the working directory:
    - ``main_eval_multi_transfer-table0.pkl``: {(cfg_name, ckpt_dir_name):
      array of mean metrics}.
    - ``main_eval_multi_transfer-table1.pkl``: the same entries minus the
      in-domain baseline (cfg_name, cfg_name); entries without a baseline
      are omitted.
    """
    cfg_files = [f"config/{name}" for name in []]
    ckpt_dirs = [f"archive.v8-ln-movi/{name}" for name in []]

    table0 = {}
    keys0 = None

    for cfg_file, ckpt_dir in product(cfg_files, ckpt_dirs):
        cfg_file = Path(cfg_file)
        ckpt_dir = Path(ckpt_dir)

        k1 = cfg_file.name[:-3]
        k2 = ckpt_dir.name

        parts1 = k1.split("-")
        parts2 = k2.split("-")
        same_len2 = len(parts1) == len(parts2) == 2
        same_len3 = len(parts1) == len(parts2) == 3 and parts1[2] == parts2[2]
        if not (same_len2 or same_len3):
            continue
        print("#" * 10, cfg_file, "#" * 10)

        # Alternative: only evaluate the best checkpoint, i.e.
        # infos = [main_eval_single(cfg_file, ckpt_dir / "best.pth")]
        infos = [
            main_eval_single(cfg_file, ckpt_file)
            for ckpt_file in sorted(ckpt_dir.glob("*.pth"))
            if ckpt_file.name != "best.pth"
        ]
        if not infos:
            continue

        if keys0 is None:
            keys0 = list(infos[0].keys())
        # metric name -> values across checkpoints, in keys0 order
        merge = {key: [info[key] for info in infos] for key in keys0}
        table0[(k1, k2)] = np.array([np.mean(v, axis=0) for v in merge.values()])

    pprint(table0)
    with open("main_eval_multi_transfer-table0.pkl", "wb") as f:
        pkl.dump(table0, f)

    table1 = {}
    for (k1, k2), value in table0.items():
        baseline = table0.get((k1, k1))
        if baseline is None:
            # No in-domain result for this config. The original silently
            # reused the previous iteration's baseline (or hit
            # UnboundLocalError on the first key).
            continue
        table1[(k1, k2)] = value - baseline

    pprint(table1)
    with open("main_eval_multi_transfer-table1.pkl", "wb") as f:
        pkl.dump(table1, f)


def main_eval_multi_zidx_simi_hist(save_summary=False):
    """Evaluate a fixed list of models, then plot histograms of saved code indices.

    For each item, the global ``item_name`` is set to the item (``val_epoch``
    writes its images into a directory of that name), and ``main_eval_single``
    runs on ``config-ms-c4/{item}.py`` with ``archive-ms-c4/{item}/best.pth``.

    save_summary: if True, append the line "item,m1*100,m2*100,..." (metric
        values formatted with 2 decimals) to
        ``main_eval_multi_zidx_simi_hist.txt``. Off by default.

    Afterwards every ``zidx.*.npy`` in the working directory (the file name
    pattern ``main_eval_single`` saves code indices under) is histogrammed
    into 100 bins. Two files share each subplot, and each curve's label carries
    the coefficient of variation (std / mean) of its bin counts.
    """
    global item_name  # read by val_epoch as its output directory
    for item in [
        "slotdiffuz_r_vqvae-coco",
        # "slotdiffuz_r_vqvae-coco-ms",
        "slotdiffuz_r_vqvae-voc",
        # "slotdiffuz_r_vqvae-voc-ms",
    ]:
        item_name = item
        log_info = main_eval_single(
            cfg_file=f"config-ms-c4/{item}.py",
            ckpt_file=f"archive-ms-c4/{item}/best.pth",
        )
        if not save_summary:
            continue
        values = ",".join(f"{v * 100:.2f}" for v in log_info.values())
        with open("main_eval_multi_zidx_simi_hist.txt", "a", encoding="utf-8") as f:
            f.write(f"{item},{values}\n")

    # only code-index dumps; slots.*.npy files are not code indices
    zidx_files = sorted(Path(".").glob("zidx.*.npy"))
    num_bin = 100
    per_axis = 2  # curves drawn into each subplot
    ncols = 3
    # at least a 2x3 grid; add rows instead of raising IndexError beyond 12 files
    num_axes = max(2 * ncols, -(-len(zidx_files) // per_axis))
    nrows = -(-num_axes // ncols)
    _, axs = plt.subplots(nrows, ncols)
    axs = axs.flatten()
    for i, zidx_file in enumerate(zidx_files):
        print(zidx_file)
        zidxs = np.load(zidx_file)
        hist, bins = np.histogram(zidxs, num_bin)
        xs = (bins[1:] + bins[:-1]) / 2  # bin centers
        cv = np.std(hist) / np.mean(hist)  # coefficient of variation
        axs[i // per_axis].plot(xs, hist, label=zidx_file.name[:-4] + f" {cv:.4f}")
    for ax in axs:
        ax.legend()
    plt.show()


if __name__ == "__main__":
    # Exactly one entry point is active; alternatives kept for reference:
    #   main_eval_single(cfg_file=..., ckpt_file=...)
    #   main_eval_multi_transfer()
    main_eval_multi_zidx_simi_hist()
