import itertools
import json
import random

import numpy as np
import pandas as pn
import torch

import blahut_arimoto as ba

# In-distribution datasets handled by the experiments.
in_dataset_list = [
    "cifar10",
    "cifar100",
    "imagenet",
]

# Backbone models evaluated per in-distribution dataset.
models_cifar10 = [
    "resnet18",
    "resnet34",
]
models_cifar100 = [
    "resnet18",
    "resnet34",
]
models_imagenet = [
    "resnet18.tv_in1k",
    "resnet34.tv_in1k",
    "resnet50.tv_in1k",
    "resnet101.tv_in1k",
    "vit_base_patch16_224.augreg_in21k_ft_in1k",
]

# OOD datasets used with the CIFAR in-distribution sets (order is preserved).
ood_dataset_list_cifar = [
    "cifar100",
    "svhn",
    "isun",
    "lsun_c",
    "lsun_r",
    "tiny_imagenet_c",
    "tiny_imagenet_r",
    "textures",
    "places365",
    "uniform",
    "gaussian",
]

# OOD datasets used with the ImageNet in-distribution set (order is preserved).
ood_dataset_list_imagenet = [
    "inaturalist_clean",
    "species_clean",
    "places_clean",
    "openimage_o_clean",
    "ssb_easy",
    "textures_clean",
    "ninco",
    "ssb_hard",
]


def get_prep_scores(scores_path1, drop_methods=("gradnorm", "vim", "dice")):
    """Load a scores CSV and clean it for downstream analysis.

    Steps: keep only the last row for each ``(model, method)`` pair, remove
    rows whose ``method`` is in *drop_methods*, drop any row containing a
    missing value, and sort by ``model`` then ``method``.

    Args:
        scores_path1: anything ``pandas.read_csv`` accepts (path or buffer).
            The CSV must contain at least ``model`` and ``method`` columns.
        drop_methods: list-like of method names to exclude. The default is
            an immutable tuple so it cannot be shared/mutated across calls.

    Returns:
        The filtered and sorted ``DataFrame``.
    """
    scores = pn.read_csv(scores_path1)
    # Keep only the most recent entry for each (model, method) pair.
    scores = scores.drop_duplicates(subset=["model", "method"], keep="last")
    # Remove the excluded methods (by default gradnorm, vim and dice).
    scores = scores[~scores["method"].isin(drop_methods)]
    # Any row with a missing value is discarded.
    scores = scores.dropna()
    return scores.sort_values(by=["model", "method"])


def get_device():
    """Return ``cuda:0`` when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")


def initialize_seeds(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: seed value applied to every random number generator.
    """
    # Python's and NumPy's global generators.
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    # PyTorch generators (global, then the CUDA one; the latter is a no-op
    # when CUDA is unavailable).
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def create_exp_dict(methods_sc, methods_ood):
    """Enumerate every (selective-classification, OOD) method pairing.

    Returns:
        dict mapping a running integer id (0, 1, ...) to
        ``{"sc": <sc method>, "ood": <ood method>}``, in the order produced by
        ``itertools.product(methods_sc, methods_ood)``.
    """
    pairs = itertools.product(methods_sc, methods_ood)
    return {
        exp_id: {"sc": sc_method, "ood": ood_method}
        for exp_id, (sc_method, ood_method) in enumerate(pairs)
    }


def remove_element_from_list(l, element):
    """Return a new list of the items of *l* that are not equal to *element*.

    The input is not modified; every occurrence of *element* is removed.
    """
    return list(filter(lambda item: item != element, l))


def assert_equivalence_label_idx(labels_array_ood, datasets_idx_ood, labels_array_sc, datasets_idx_sc, labels_array_ba, datasets_idx_ba):
    """Verify that the OOD, SC and BA label / dataset-index arrays agree.

    The OOD arrays are the reference; the SC and BA arrays must match them
    exactly (same shape and same values).

    Raises:
        AssertionError: "labels do not match" or "datasets_idx do not match"
            on the first mismatch found.
    """
    # Explicit raise instead of `assert` so the check is not stripped under
    # `python -O`. np.array_equal reports a shape mismatch as "not equal"
    # instead of broadcasting (false pass) or raising a numpy error.
    checks = (
        (labels_array_ood, labels_array_sc, "labels do not match"),
        (labels_array_ood, labels_array_ba, "labels do not match"),
        (datasets_idx_ood, datasets_idx_sc, "datasets_idx do not match"),
        (datasets_idx_ood, datasets_idx_ba, "datasets_idx do not match"),
    )
    for reference, other, message in checks:
        if not np.array_equal(reference, other):
            raise AssertionError(message)


def add_ba(scores_array, labels_array, datasets_idx, device):
    """Append a Blahut-Arimoto weighted score column to the per-method arrays.

    Scores are turned into p-values with ``ba.p_value_fn``. The reference
    sample is the first 1000 rows whose ``datasets_idx[:, 0]`` equals 0
    (presumably the in-distribution samples — TODO confirm). Weights from
    ``ba.blahut_arimoto`` are then fitted on all rows at once (zero-shot).
    The weighted sum of the p-values becomes a new last column.

    Returns:
        ``(scores_array, labels_array, datasets_idx)``, each with one extra
        column. The new label / index columns are copies of column 1.
    """
    reference_rows = scores_array[datasets_idx[:, 0] == 0][:1000]
    p_values = ba.p_value_fn(scores_array, X=reference_rows)

    p_tensor = torch.from_numpy(p_values).float()
    weights = ba.blahut_arimoto(p_tensor, device=device, threshold=1e-4, verbose=False).numpy()
    combined = np.sum(weights * p_values, 1)

    def _with_extra_column(matrix, column):
        # Attach *column* as the right-most column of *matrix*.
        return np.concatenate([matrix, column.reshape(-1, 1)], 1)

    # NOTE(review): the new label/idx columns reuse column 1, which assumes
    # every method column carries identical labels/indices — confirm.
    return (
        _with_extra_column(scores_array, combined),
        _with_extra_column(labels_array, labels_array[:, 1]),
        _with_extra_column(datasets_idx, datasets_idx[:, 1]),
    )


def add_baseline(scores_array, labels_array, datasets_idx, device):
    """Append a uniformly weighted p-value score column (baseline combiner).

    Scores are turned into p-values with ``ba.p_value_fn``. The reference
    sample is the first 1000 rows whose ``datasets_idx[:, 0]`` equals 0. The
    new score is the p-values weighted equally (``1 / n_methods`` each) and
    summed per row. *device* is accepted for signature parity with ``add_ba``
    but is not used.

    Returns:
        ``(scores_array, labels_array, datasets_idx)``, each with one extra
        column. The new label / index columns are copies of column 1.
    """
    reference_rows = scores_array[datasets_idx[:, 0] == 0][:1000]
    p_values = ba.p_value_fn(scores_array, X=reference_rows)

    # Uniform (not optimised) weights: one equal share per method column.
    n_methods = p_values.shape[1]
    uniform_weights = np.full((1, n_methods), 1 / n_methods)
    combined = np.sum(uniform_weights * p_values, 1)

    def _with_extra_column(matrix, column):
        # Attach *column* as the right-most column of *matrix*.
        return np.concatenate([matrix, column.reshape(-1, 1)], 1)

    return (
        _with_extra_column(scores_array, combined),
        _with_extra_column(labels_array, labels_array[:, 1]),
        _with_extra_column(datasets_idx, datasets_idx[:, 1]),
    )


def get_model_scores_labels_idx_w_ba(scores, model_name, device):
    """Build per-sample label/score/index matrices for one model and append BA scores.

    Each selected row of *scores* is one method. Its ``labels``, ``scores`` and
    ``idx`` cells are JSON-encoded lists that are decoded and stacked side by
    side. The result has one row per sample (all in-d and all ood) and one
    column per method. A Blahut-Arimoto column is then added via ``add_ba``.

    Raises:
        ValueError: if *scores* has no row with ``model == model_name``.
    """
    df = scores.query("model == @model_name")
    if df.empty:
        # Without this guard numpy fails later with a cryptic
        # "need at least one array to concatenate".
        raise ValueError(f"no scores found for model {model_name!r}")

    def _stack_json_column(column):
        # Decode every JSON cell into an (n_samples, 1) array and place the
        # methods next to each other along axis 1.
        return np.concatenate(df[column].map(lambda cell: np.array(json.loads(cell)).reshape(-1, 1)).values, 1)

    labels_array = _stack_json_column("labels")
    scores_array = _stack_json_column("scores")
    datasets_idx = _stack_json_column("idx")

    return add_ba(scores_array, labels_array, datasets_idx, device)


def get_model_scores_labels_idx_w_baseline(scores, model_name, device):
    """Build per-sample label/score/index matrices for one model and append baseline scores.

    Each selected row of *scores* is one method. Its ``labels``, ``scores`` and
    ``idx`` cells are JSON-encoded lists that are decoded and stacked side by
    side. The result has one row per sample (all in-d and all ood) and one
    column per method. A uniformly weighted column is then added via
    ``add_baseline``.

    Raises:
        ValueError: if *scores* has no row with ``model == model_name``.
    """
    df = scores.query("model == @model_name")
    if df.empty:
        # Without this guard numpy fails later with a cryptic
        # "need at least one array to concatenate".
        raise ValueError(f"no scores found for model {model_name!r}")

    def _stack_json_column(column):
        # Decode every JSON cell into an (n_samples, 1) array and place the
        # methods next to each other along axis 1.
        return np.concatenate(df[column].map(lambda cell: np.array(json.loads(cell)).reshape(-1, 1)).values, 1)

    labels_array = _stack_json_column("labels")
    scores_array = _stack_json_column("scores")
    datasets_idx = _stack_json_column("idx")

    return add_baseline(scores_array, labels_array, datasets_idx, device)


def c_in(lbd, pi_in_star):
    """In-distribution cost term: ``lbd * pi_in_star``."""
    cost = lbd * pi_in_star
    return cost


def c_out(lbd, pi_in_star, c_fn):
    """Out-of-distribution cost term: ``c_fn - lbd * (1 - pi_in_star)``."""
    penalty = lbd * (1 - pi_in_star)
    return c_fn - penalty


def theta(scores_ood, eps):
    """Return ``-1 / scores_ood`` with exact zeros replaced by *eps* first.

    The substitution avoids division by zero; the result is a numpy array
    (0-d for scalar input).
    """
    is_zero = scores_ood == 0
    denominator = np.where(is_zero, eps, scores_ood)
    return -1 / denominator


def t_bb(c_in, c_out):
    """Decision threshold: ``1 - 2 * c_in - c_out``."""
    doubled_in = 2 * c_in
    return 1 - doubled_in - c_out


def compute_r_bb(scores_sc, scores_ood, eps, c_in_fix, c_out_fix):
    """Combine SC and OOD scores into a single rejection score.

    Computes ``(1 - c_in_fix - c_out_fix) * scores_sc + c_out_fix * theta(scores_ood, eps)``,
    where ``theta`` maps OOD scores to ``-1 / score`` (zeros replaced by *eps*).
    """
    sc_weight = 1 - c_in_fix - c_out_fix
    ood_term = c_out_fix * theta(scores_ood, eps)
    return sc_weight * scores_sc + ood_term


def r_bb(scores_sc, scores_ood, lbd, eps, pi_in_star, c_fn):
    """Return a boolean mask: True where the combined score is below the threshold.

    The costs ``c_in`` / ``c_out`` are derived from *lbd*, *pi_in_star* and
    *c_fn*. The combined score from ``compute_r_bb`` is then compared against
    ``t_bb`` of those same costs.
    """
    cost_in = c_in(lbd=lbd, pi_in_star=pi_in_star)
    cost_out = c_out(lbd=lbd, pi_in_star=pi_in_star, c_fn=c_fn)
    combined = compute_r_bb(scores_sc=scores_sc, scores_ood=scores_ood, eps=eps, c_in_fix=cost_in, c_out_fix=cost_out)
    threshold = t_bb(cost_in, cost_out)
    return combined < threshold


def get_lbd_range(pi_in_star, c_fn, num_points: int = 1000):
    """Return the grid values of ``lbd`` for which both costs lie in [0, 1].

    The grid is ``num_points`` evenly spaced values from 0 up to
    ``1 / pi_in_star``; that upper end is the ``lbd`` at which
    ``c_in = lbd * pi_in_star`` reaches 1. Grid points where ``c_in`` or
    ``c_out`` falls outside [0, 1] are dropped.

    Returns:
        1-D numpy array of the admissible ``lbd`` values, in increasing order.
    """
    grid = np.linspace(0, 1 / pi_in_star, num_points).tolist()
    in_costs = np.array([c_in(lam, pi_in_star) for lam in grid])
    out_costs = np.array([c_out(lam, pi_in_star, c_fn) for lam in grid])

    admissible = (in_costs >= 0) & (in_costs <= 1) & (out_costs >= 0) & (out_costs <= 1)
    return np.array(grid)[admissible]


def get_scod_stats(scores_sc, scores_ood, lbd_range, y_scod, eps, pi_in_star, c_fn):
    """Sweep ``lbd`` over *lbd_range* and collect ROC and risk/coverage statistics.

    For each ``lbd`` the decision ``r_bb(...)`` is converted to 0/1 and
    compared with the labels *y_scod* (1 = positive):

    * TPR = #(decision 1 & label 1) / sum(y_scod)
    * FPR = #(decision 1 & label 0) / #(label 0)
    * for rows with decision 0 ("accepted"), when at least one exists:
      risk = #(accepted & label 1) / #accepted and
      coverage = #accepted / len(y_scod).

    Returns:
        ``(tprs, fprs, risks, risks_num, risks_den, coverages, lbds)``.
        *fprs* is sorted ascending and *tprs* is reordered to match.
        *risks*, *risks_num*, *risks_den* and *coverages* start with a 0
        entry. *lbds* only holds the lambdas that had accepted rows.
    """
    tprs, fprs = [], []
    # Leading 0 sentinels, presumably the origin of the risk/coverage curve.
    risks, risks_num, risks_den, coverages = [0], [0], [0], [0]
    # NOTE(review): lbds has no leading sentinel, so it is one element shorter
    # than risks/coverages — confirm callers account for the offset.
    lbds = []

    for lbd in lbd_range:
        decision = r_bb(scores_sc=scores_sc, scores_ood=scores_ood, lbd=lbd, eps=eps, pi_in_star=pi_in_star, c_fn=c_fn).astype(int)

        true_pos = (decision == 1) & (y_scod == 1)
        tprs.append(true_pos.astype(int).sum() / float(y_scod.sum()))

        false_pos = (decision == 1) & (y_scod == 0)
        fprs.append(false_pos.astype(int).sum() / float((y_scod == 0).sum()))

        accepted = decision == 0
        n_accepted = accepted.sum()
        if n_accepted == 0:
            # Nothing accepted: risk/coverage are undefined for this lambda.
            continue

        accepted_pos = (accepted & (y_scod == 1)).astype(int).sum()
        risks_num.append(accepted_pos)
        risks_den.append(n_accepted)
        risks.append(accepted_pos / float(n_accepted))
        coverages.append(n_accepted / float(len(y_scod)))
        lbds.append(lbd)

    # Order the ROC points by FPR (ties broken by TPR), keeping TPRs aligned.
    ordered = sorted(zip(fprs, tprs))
    tprs = [tpr for _, tpr in ordered]
    fprs = sorted(fprs)

    return tprs, fprs, risks, risks_num, risks_den, coverages, lbds
