import os
import os.path as osp
import csv

import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from analysis_importance import generate_subset_masks


def makedirs(path):
    """Create directory ``path`` and any missing parents; no-op if it already exists.

    ``exist_ok=True`` avoids the check-then-create race: if another process
    creates the directory between an existence check and the create call,
    no error is raised. A ``FileExistsError`` is still raised if ``path``
    exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)


def select_salient_concepts(masks, I_and, I_or, threshold_tau):
    """Pick the salient AND/OR interactions and format them as output rows.

    AND and OR interactions are defined on the same coalition masks, so the
    masks are stacked twice to line up with ``concatenate((I_and, I_or))``.
    An interaction is salient when ``|I| > threshold_tau * max|I|``.

    Args:
        masks: 2-D array with one row per coalition; row ``k`` is the
            coalition of ``I_and[k]`` and ``I_or[k]``.
        I_and: 1-D AND interactions, one per row of ``masks``.
        I_or: 1-D OR interactions, one per row of ``masks``.
        threshold_tau: fraction of the maximum interaction strength that a
            concept's strength must exceed to be kept.

    Returns:
        ``(rows, num_total)``. ``rows`` is an array whose columns are
        ``[interaction, order, player_1, ..., player_n]``, sorted by
        descending ``|interaction|`` (ties keep AND-then-OR input order).
        ``num_total`` is the number of candidate interactions
        (``2 * len(masks)``).
    """
    masks_andor = np.concatenate((masks, masks), axis=0)
    I_andor = np.concatenate((I_and, I_or), axis=0)
    num_total = len(I_andor)
    if num_total == 0:
        return np.empty((0, masks_andor.shape[-1] + 2)), 0

    strength = np.abs(I_andor)
    # nanmax ignores NaN entries when locating the strongest interaction.
    threshold = np.nanmax(strength) * threshold_tau
    chosen = strength > threshold

    salient_I = I_andor[chosen]
    salient_masks = masks_andor[chosen]
    # Order of a concept = number of players in its coalition.
    orders = salient_masks.sum(axis=1)
    rows = np.column_stack((salient_I, orders, salient_masks))
    # Stable sort so equal-strength concepts keep their input order.
    return rows[np.argsort(-np.abs(salient_I), kind="stable")], num_total


def write_salient_concepts(path, rows, num_players):
    """Write salient-concept ``rows`` to a CSV file at ``path``.

    Layout: a ``salient_concepts_num`` title row, the number of concepts, a
    header ``interaction, order, player_1 .. player_<num_players>``, then one
    line per row of ``rows``.
    """
    header = ["interaction", "order"] + [f"player_{p}" for p in range(1, num_players + 1)]
    with open(path, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["salient_concepts_num"])
        writer.writerow([len(rows)])
        writer.writerow(header)
        # tolist() hands csv plain Python scalars rather than NumPy scalars.
        writer.writerows(np.asarray(rows).tolist())


if __name__ == '__main__':
    andor_load_dir = "eval_andor"
    save_dir = "analysis_andor_chosen_S"
    makedirs(save_dir)

    # Alternative configurations (commented out):
    # trick = "pqa"
    # loss = "l1"
    # qthres = 0.01
    # param = f"-trick-{trick}-loss-{loss}-qthres-{qthres}"
    # reward_way = "gt-log-odds-minus-mean-minus-empty"

    # loss = "l1"
    # qthres = 0.4
    # lr = 1e-6
    # trick = "pq"
    # qstd = "vN_vEmpty_mean"
    # param = f"after_sparsifying-trick-{trick}-loss-{loss}-lr-{lr}-qthres-{qthres}-qstd-{qstd}"
    # reward_way = "gt-log-odds"

    loss = "l1_for_6_10"
    reward_way = "gt-log-odds-minus-mean"
    qthres = 0.4
    lr = 1e-6
    weight = 5
    trick = "pqa"
    lr_way = "a_1"
    qstd = "vN_vEmpty_mean"
    param = f"after_sparsifying-trick-{trick}-loss-{loss}-lr-{lr}-lr-way-{lr_way}-qthres-{qthres}-qstd-{qstd}-weight-{weight}"
    # A concept is kept when |I| > threshold_tau * max|I| within its sample.
    threshold_tau = 0.15

    # Sorted so samples are processed and logged in a deterministic order.
    filenames = sorted(name for name in os.listdir(andor_load_dir) if name.startswith("id"))
    for filename in filenames:
        sample_id = filename.split(".")[0]

        save_folder = osp.join(save_dir, sample_id, reward_way, param, f"threshold_{threshold_tau}")
        makedirs(save_folder)

        andor_folder = osp.join(andor_load_dir, sample_id, reward_way)
        masks = np.load(osp.join(andor_folder, "before_sparsify", "masks.npy"))
        Iand_after = np.load(osp.join(andor_folder, param, "Iand.npy"))
        Ior_after = np.load(osp.join(andor_folder, param, "Ior.npy"))

        rows, num_total = select_salient_concepts(masks, Iand_after, Ior_after, threshold_tau)
        if num_total == 0:
            print(filename, "no interactions found; skipped")
            continue
        # Second value: fraction of candidate interactions dropped as non-salient.
        print(filename, len(rows), 1 - len(rows) / num_total)

        write_salient_concepts(osp.join(save_folder, "salient_concepts.csv"), rows, masks.shape[1])


