import io
from collections import Counter

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sn
import torch
from PIL import Image
from torch.utils.data import Subset


class SubsetWithTargets(Subset):
    """A ``torch.utils.data.Subset`` that also exposes the selected labels.

    ``self.targets`` holds the parent dataset's ``targets`` restricted to
    ``idxs``, so callers can read the subset's labels directly.
    """

    def __init__(self, dset, idxs):
        super().__init__(dset, idxs)
        targets = dset.targets
        if isinstance(targets, (list, tuple)):
            # Plain Python sequences do not support indexing with an index
            # list, so gather the selected labels one by one.
            self.targets = [targets[int(i)] for i in idxs]
        else:
            # Tensors / ndarrays: advanced indexing keeps the container type.
            self.targets = targets[idxs]


def count_cls(dset: "SubsetWithTargets"):
    """Count how many samples of each class ``dset`` contains.

    Args:
        dset: Dataset exposing a ``targets`` attribute and ``__len__``
            (e.g. a ``SubsetWithTargets``).

    Returns:
        list[tuple]: ``(class_label, count)`` pairs, ordered by each label's
        first occurrence in ``dset.targets``.

    Raises:
        ValueError: If ``dset`` is empty.
    """
    if not dset:
        raise ValueError("empty dset is given!")
    targets = dset.targets
    # Tensor elements hash by identity, so a Counter over a tensor would treat
    # every element as a distinct key. Convert array-likes (tensors, ndarrays)
    # to plain Python scalars first so equal labels are grouped together.
    if hasattr(targets, "tolist"):
        targets = targets.tolist()
    return list(Counter(targets).items())


def plot_class_count(cfg, class_counted):
    """Render a class-count heatmap and return it as a PIL image.

    Args:
        cfg: Config object; reads ``cfg.dset.n_cls`` (number of heatmap
            columns and figure width) and ``cfg.fl.n_c`` (figure height).
        class_counted: Sequence with one entry per heatmap row. Each entry is
            an iterable of ``(class_index, count)`` pairs, such as the output
            of ``count_cls``. Presumably one row per client; confirm against
            callers.

    Returns:
        PIL.Image.Image: The figure as saved by ``fig.savefig`` into an
        in-memory buffer.
    """
    rows = []
    for pairs in class_counted:
        # Classes absent from a row keep a count of 0.
        row = torch.zeros(cfg.dset.n_cls)
        for cls_idx, count in pairs:
            row[cls_idx] = count
        rows.append(row)
    ratio = torch.stack(rows).numpy()

    fig = plt.figure(figsize=(cfg.dset.n_cls, cfg.fl.n_c))
    try:
        sn.heatmap(
            ratio,
            annot=True,
            fmt="g",
            cbar=False,
            cmap="Greys",
            xticklabels=False,
            yticklabels=False,
        )
        buf = io.BytesIO()
        fig.savefig(buf)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed;
        # without this each call leaks a figure.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def convert_run_name(cfg):
    """Build a human-readable run name from the experiment config.

    The name is ``<method><gan.timing>_<fl.combine_from>_norm``, with a
    ``_central`` suffix appended when ``cfg.fl.central`` is truthy.

    Args:
        cfg: Config object; reads ``fl.combine``, ``fl.combine_from``,
            ``fl.central``, ``gkd.is_gkd``, ``gan.timing`` and, for
            ``combine == "em_entropy_soft"``, ``ood.entropy.temp``.

    Returns:
        str: The run name.

    Raises:
        ValueError: If the config does not map to any known method name.
    """
    combine = cfg.fl.combine
    combine_from = cfg.fl.combine_from

    # NOTE(review): "avg" is checked before is_gkd, so a GKD run configured
    # with combine == "avg" is named FedAVG. Confirm this is intended.
    if combine == "avg":
        main_name = "FedAVG"
    elif cfg.gkd.is_gkd:
        main_name = "FedGKD"
    elif combine == "em_entropy_soft":
        main_name = None
        if cfg.ood.entropy.temp != 0:
            main_name = {"pre": "FedD", "avg": "FedAD"}.get(combine_from)
    elif combine == "df":
        main_name = {"avg": "FedDF", "pre": "FedDF_pre"}.get(combine_from)
    else:
        main_name = {
            "prox": "FedProx",
            "gan": "FedGAN",
            "et": "FedET",
            "logit_var": "logit_var",
            "gan_dafkd": "DaFKD",
        }.get(combine)

    if main_name is None:
        raise ValueError(
            f"cannot derive run name for combine={combine!r}, "
            f"combine_from={combine_from!r}"
        )

    gan = cfg.gan.timing or ""
    suffix = "_norm_central" if cfg.fl.central else "_norm"
    return f"{main_name}{gan}_{combine_from}{suffix}"
