import numpy as np
import os
import torch
import random
import pickle
import hashlib
from tqdm import tqdm
from data_providers.al_sampler import \
    ImagenetALDataProvider, CIFAR10ALDataProvider, CIFAR100ALDataProvider, CALTECH101ALDataProvider,\
    MNISTALDataProvider, FMNISTALDataProvider, KMNISTALDataProvider, SVHNALDataProvider, \
    EMNISTDIGALDataProvider, EMNISTLETALDataProvider, CELEBADataProvider, IMAGENET1KDataProvider, \
    INATURALIST21SUPERDataProvider


class structure:
    """Plain attribute container pre-populated with default option values."""

    def __init__(self):
        # NOTE(review): the names suggest training-loop options (gradient
        # clipping, logging frequency, EMA, LR schedule) -- confirm against
        # the code that consumes this object.
        vars(self).update(
            clip_grad_norm=None,
            print_freq=5,
            model_ema_steps=0,
            lr_warmup_epochs=0,
            step_size=20,
        )


def get_data_provider_by_name(dataset_name: str):
    """Return the data-provider class whose ``name()`` equals ``dataset_name``.

    Candidates are checked in a fixed order and the first match is returned.

    Raises:
        NotImplementedError: if no known provider reports that name.
    """
    candidates = (
        ImagenetALDataProvider,
        CIFAR10ALDataProvider,
        CIFAR100ALDataProvider,
        MNISTALDataProvider,
        FMNISTALDataProvider,
        KMNISTALDataProvider,
        SVHNALDataProvider,
        CALTECH101ALDataProvider,
        EMNISTDIGALDataProvider,
        EMNISTLETALDataProvider,
        CELEBADataProvider,
        IMAGENET1KDataProvider,
        INATURALIST21SUPERDataProvider,
    )
    for provider_cls in candidates:
        if dataset_name == provider_cls.name():
            return provider_cls
    raise NotImplementedError


def init_seeds(seed=0):
    """Seed the torch, numpy and stdlib RNGs with ``seed`` and set the cuDNN flags.

    Also calls ``torch.cuda.empty_cache()`` before seeding.
    """
    torch.cuda.empty_cache()

    # Same seed for every generator this module relies on, in a fixed order.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # https://pytorch.org/docs/stable/notes/randomness.html
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True)


def get_class_centers(all_features, all_labels):
    """Group sample indices by label and average the features of each group.

    Returns:
        A pair ``(class_data_ids, centers)`` of dicts keyed by the unique
        values of ``all_labels``. ``class_data_ids[label]`` is the tuple
        produced by ``np.where(all_labels == label)``; ``centers[label]`` is
        the mean over axis 0 of the features selected by that tuple.
    """
    class_data_ids = {
        label: np.where(all_labels == label) for label in np.unique(all_labels)
    }
    centers = {
        label: np.mean(all_features[rows], axis=0)
        for label, rows in class_data_ids.items()
    }
    return class_data_ids, centers


def get_cand_text(fea_dict:dict, cand_num:int=5):
    """
    Return up to ``cand_num`` text keys, highest score first.

    fea_dict: key: txt_description, value: (feature, improving, similarity, iter)
    cand_num: number of returned text

    The score used for ranking is element 1 of each value. When the dict
    holds no more than ``cand_num`` entries, all keys are returned in dict
    order without ranking.
    """
    if len(fea_dict) <= cand_num:
        return list(fea_dict)

    texts = list(fea_dict)
    scores = [entry[1] for entry in fea_dict.values()]
    # argsort is ascending; reverse it so the best score comes first.
    ranking = torch.argsort(torch.as_tensor(scores)).numpy()[::-1]
    return [texts[ranking[pos]] for pos in range(cand_num)]


def get_highest_score_fea(fea_dict:dict):
    """
    Return the feature with the highest score, together with that score.

    fea_dict: key: txt_description, value: (feature, improving, similarity, iter)

    Returns a ``(feature, score)`` pair taken from elements 0 and 1 of the
    best-scoring value. On ties the entry that comes first in dict order wins.
    Scores are compared directly, so negative scores are handled like any
    other value (no sentinel lower bound).

    Raises IndexError if ``fea_dict`` is empty.
    """
    if not fea_dict:
        raise IndexError("fea_dict is empty; there is no feature to select")
    # max() keeps the first maximal element, matching a strict ">" scan.
    best = max(fea_dict.values(), key=lambda entry: entry[1])
    return best[0], best[1]


def get_features(model, dataloader, return_class_center=False, device="cuda:0"):
    """Encode every batch with ``model.encode_image`` and collect normalised features.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    Each feature vector is divided by its norm along the last dimension.

    Args:
        model: object providing ``eval()`` and ``encode_image(images)``. It is
            not moved by this function, so it must already live on ``device``.
        dataloader: iterable yielding ``(images, labels)`` batches of tensors.
        return_class_center: when true, additionally return the result of
            ``get_class_centers`` computed on the collected arrays.
        device: device the image batches are moved to before encoding. The
            default keeps the previously hard-coded ``"cuda:0"``.

    Returns:
        ``(features, labels)`` as numpy arrays, or
        ``(features, labels, get_class_centers(features, labels))`` when
        ``return_class_center`` is true.
    """
    feature_batches = []
    label_batches = []

    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(dataloader):
            features = model.encode_image(images.to(device))
            features = features / features.norm(dim=-1, keepdim=True)

            # Move each batch off the accelerator right away so that device
            # memory does not grow with the size of the dataset.
            feature_batches.append(features.cpu())
            label_batches.append(labels)

    all_fea = torch.cat(feature_batches, dim=0).numpy()
    all_lab = torch.cat(label_batches).cpu().numpy()
    if return_class_center:
        return all_fea, all_lab, get_class_centers(all_fea, all_lab)
    return all_fea, all_lab


def get_labs_from_dataloader(dataloader):
    """Iterate the whole dataloader and return all labels as one numpy array.

    Each batch is unpacked as ``(_, labels)``; ``labels.numpy()`` is called on
    every batch and the results are concatenated in iteration order.
    """
    return np.concatenate([labs.numpy() for _, labs in tqdm(dataloader)])


def check_all_class(dataloader, nclasses):
    """Tell whether the dataloader contains exactly ``nclasses`` distinct labels.

    Returns True on a match; otherwise prints the observed count and returns
    False.
    """
    observed = len(np.unique(get_labs_from_dataloader(dataloader)))
    if observed != nclasses:
        print(f"Class number is {observed}, not equal to {nclasses}")
        return False
    return True


def load_gen_fea(save_root, iter):
    """Load the pickled generated features and labels of iterations 1..``iter``.

    For every iteration ``k`` the files ``<save_root>/<k>/gen_fea_all.pkl`` and
    ``<save_root>/<k>/gen_label_all.pkl`` are unpickled. With a single
    iteration the two loaded objects are returned as they are; otherwise the
    per-iteration objects are joined with ``torch.cat`` along dim 0.

    NOTE: ``pickle.load`` can execute arbitrary code -- only point
    ``save_root`` at trusted data.
    """
    assert iter > 0

    def _unpickle(step, filename):
        with open(os.path.join(save_root, str(step), filename), "rb") as handle:
            return pickle.load(handle)

    fea_chunks, lab_chunks = [], []
    for step in np.arange(1, iter + 1):
        fea_chunks.append(_unpickle(step, "gen_fea_all.pkl"))
        lab_chunks.append(_unpickle(step, "gen_label_all.pkl"))

    if len(fea_chunks) != 1:
        return torch.cat(fea_chunks, dim=0), torch.cat(lab_chunks, dim=0)
    return fea_chunks[0], lab_chunks[0]
