import os
import os.path as osp
import csv
from tqdm import tqdm
import torch

import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


def makedirs(path):
    """Create directory ``path`` (and any missing parents) if it does not exist yet.

    Uses ``exist_ok=True`` instead of checking ``os.path.exists`` first; the
    check-then-create pattern races if another process creates the directory
    between the check and the ``os.makedirs`` call.

    :param path: directory path to create.
    :raises FileExistsError: if ``path`` already exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)


def generate_subset_masks(set_mask, all_masks):
    '''
    Select every row T of [all_masks] that contains the coalition S = [set_mask].

    A row T is kept when S is a subset of T, i.e. for every player either T
    includes the player or S does not. Despite the function's name, the rows
    returned are the *supersets* of S (the "bigger sets L").

    :param set_mask: mask of the coalition S; must broadcast to the shape of
        [all_masks] via ``expand_as``.
    :param all_masks: mask tensor with one candidate coalition T per row.
    :return: (the rows of [all_masks] that contain S,
              a boolean selector over the rows of [all_masks])
    '''
    broadcast_S = set_mask.expand_as(all_masks)
    # Per player: True when T has the player or S lacks it.
    player_ok = all_masks.logical_or(broadcast_S.logical_not())
    contains_S = player_ok.all(dim=1)
    return all_masks[contains_S], contains_S


def build_phi_matrix(masks, epsilon=1e-7):
    """Build the square matrix that maps interaction values I(T) to phi(S).

    Row i corresponds to the coalition S = masks[i] and column j to T = masks[j]
    (e.g. 1024 x 1024 for 10 players). Entry (i, j) is |S| / (|T| + epsilon)
    when S is a subset of T and 0 otherwise, so ``build_phi_matrix(masks) @ I``
    evaluates, for every S at once,

        phi(S) = sum_{T superset of S} |S| / |T| * I(T)

    :param masks: 2-D tensor with one coalition per row (nonzero = player present).
    :param epsilon: added to |T| so that the empty coalition does not divide by zero.
    :return: float32 tensor of shape (n_masks, n_masks).
    """
    present = masks.bool()
    # contains[i, j] is True iff masks[i] is a subset of masks[j]:
    # for every player, T has it or S lacks it.
    contains = torch.logical_or(
        present.unsqueeze(0), torch.logical_not(present.unsqueeze(1))
    ).all(dim=2)
    sizes = present.sum(dim=1)
    ratio = sizes.unsqueeze(1) / (sizes.unsqueeze(0) + epsilon)
    return torch.where(contains, ratio, torch.zeros_like(ratio)).float()


def apply_phi_matrix(phi_mat, interactions):
    """Return ``phi_mat @ interactions`` as a NumPy array.

    ``phi_mat`` is cast to the dtype of ``interactions`` first, so float64
    interaction files do not make ``torch.matmul`` fail on a dtype mismatch.
    """
    return torch.matmul(phi_mat.to(interactions.dtype), interactions).numpy()


def write_phi_csv(path, phi, masks, value_name, min_order=2):
    """Write coalitions ranked by |phi| (largest first) to a CSV file.

    Only coalitions with at least ``min_order`` players are written. Each data
    row is ``[phi, order, player_1, ..., player_n]``, where ``order`` is the
    number of players in the coalition and the player columns are 0/1 flags.

    :param path: output CSV path (overwritten).
    :param phi: 1-D array of phi values aligned with the rows of ``masks``.
    :param masks: 2-D NumPy array with one coalition per row.
    :param value_name: header label for the phi column.
    :param min_order: smallest coalition size to include.
    """
    n_players = masks.shape[1]
    header = [value_name, "order"] + [f"player_{k}" for k in range(1, n_players + 1)]
    membership = masks.astype(bool)
    orders = membership.sum(axis=1)
    with open(path, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(header)
        for idx in np.abs(phi).argsort()[::-1]:
            if orders[idx] >= min_order:
                # Write order / player flags as ints (not coerced to float by concatenation).
                writer.writerow(
                    [float(phi[idx]), int(orders[idx])] + membership[idx].astype(int).tolist()
                )


if __name__ == '__main__':
    load_dir = "eval_andor"
    save_dir = "analysis_importance"
    makedirs(save_dir)

    # Alternative experiment configurations, kept for reference:
    # trick = "pqa"
    # loss = "l1"
    # qthres = 0.01
    # param = f"-trick-{trick}-loss-{loss}-qthres-{qthres}"
    # reward_way = "gt-log-odds-minus-mean-minus-empty"

    # loss = "l1"
    # qthres = 0.4
    # lr = 1e-6
    # trick = "pq"
    # qstd = "vN_vEmpty_mean"
    # param = f"after_sparsifying-trick-{trick}-loss-{loss}-lr-{lr}-qthres-{qthres}-qstd-{qstd}"
    # reward_way = "gt-log-odds"

    loss = "l1_for_6_10"
    reward_way = "gt-log-odds-minus-mean"
    qthres = 0.4
    lr = 1e-6
    weight = 5
    trick = "pqa"
    lr_way = "a_1"
    qstd = "vN_vEmpty_mean"
    param = f"after_sparsifying-trick-{trick}-loss-{loss}-lr-{lr}-lr-way-{lr_way}-qthres-{qthres}-qstd-{qstd}-weight-{weight}"

    sample_ids = sorted(name for name in os.listdir(load_dir) if name.startswith("id"))

    for sample_id in tqdm(sample_ids, ncols=100, desc="Computing phi"):
        # Build paths from sample_id directly; re-splitting a joined path on "/"
        # breaks on platforms whose path separator is not "/".
        folder = osp.join(load_dir, sample_id, reward_way)
        save_folder = osp.join(save_dir, sample_id, reward_way, param)
        makedirs(save_folder)

        before_dir = osp.join(folder, "before_sparsify")
        after_dir = osp.join(folder, param)
        masks = torch.tensor(np.load(osp.join(before_dir, "masks.npy")))
        Iand_before = torch.tensor(np.load(osp.join(before_dir, "Iand.npy")))
        Ior_before = torch.tensor(np.load(osp.join(before_dir, "Ior.npy")))
        Iand_after = torch.tensor(np.load(osp.join(after_dir, "Iand.npy")))
        Ior_after = torch.tensor(np.load(osp.join(after_dir, "Ior.npy")))

        phi_mat = build_phi_matrix(masks)

        phi_after_sparsify = apply_phi_matrix(phi_mat, Iand_after) + apply_phi_matrix(phi_mat, Ior_after)
        # NOTE(review): "before" interactions are averaged ((Iand + Ior) / 2) while
        # "after" ones are summed (Iand + Ior). Kept as in the original; confirm the
        # factor of 2 is intended before comparing the two outputs.
        phi_before_sparsify = apply_phi_matrix(phi_mat, (Iand_before + Ior_before) / 2)

        np.save(osp.join(save_folder, "phi_after_sparsify.npy"), phi_after_sparsify)
        np.save(osp.join(save_folder, "phi_before_sparsify.npy"), phi_before_sparsify)

        masks_np = masks.numpy()
        write_phi_csv(osp.join(save_folder, "info_phi_abs_before.csv"),
                      phi_before_sparsify, masks_np, "phi_abs_before")
        write_phi_csv(osp.join(save_folder, "info_phi_abs_after.csv"),
                      phi_after_sparsify, masks_np, "phi_abs_after")




