import os
import os.path as osp
import csv

import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from analysis_importance import generate_subset_masks


def makedirs(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    ``exist_ok=True`` makes the call idempotent without the check-then-create
    race of testing ``os.path.exists`` first: if another process creates the
    directory between the check and the create, this still succeeds.  If
    ``path`` exists but is not a directory, ``FileExistsError`` is raised
    instead of silently returning.
    """
    os.makedirs(path, exist_ok=True)


if __name__ == '__main__':
    # For every per-sample concept file ("id*") in `concept_load_dir`, read the
    # coalitions listed after its header row, look up each coalition's phi(S)
    # and order in the matching importance table, and write them to
    # `coalitions_<sample_id>.csv` sorted by |phi(S)| in descending order.
    concept_load_dir = "analysis_coalitions"
    phiS_load_dir = "analysis_importance"
    save_dir = "analysis_coalitions"

    # Alternative configuration, kept for reference:
    # trick = "pqa"
    # loss = "l1"
    # qthres = 0.01
    # param = f"-trick-{trick}-loss-{loss}-qthres-{qthres}"
    # reward_way = "gt-log-odds-minus-mean-minus-empty"

    loss = "l1_for_6_10"
    reward_way = "gt-log-odds-minus-mean"
    qthres = 0.4
    lr = 1e-6
    weight = 5
    trick = "pqa"
    lr_way = "a_1"
    qstd = "vN_vEmpty_mean"
    param = f"after_sparsifying-trick-{trick}-loss-{loss}-lr-{lr}-lr-way-{lr_way}-qthres-{qthres}-qstd-{qstd}-weight-{weight}"

    player_columns = [f"player_{k}" for k in range(1, 11)]
    # Header row that marks the start of the coalition list in a concept file.
    concept_header = ['coalition', 'order'] + player_columns
    output_header = ['phiS', 'order'] + player_columns

    makedirs(save_dir)

    # Sorted only so the processing (and log) order is deterministic.
    filenames = sorted(filename for filename in os.listdir(concept_load_dir) if filename.startswith("id"))
    for filename in filenames:
        sample_id = filename.split(".")[0]

        # Importance table: a header row, then one row per coalition S laid out
        # as [phi(S), order, mask_1, ..., mask_n]; mask entries > 0.5 mean the
        # player belongs to S.
        phi_path = osp.join(phiS_load_dir, sample_id, reward_way, param, "info_phi_abs_after.csv")
        phis_S = []
        orders_S = []
        masks_S = []
        with open(phi_path, newline='') as csvfile:
            reader = csv.reader(csvfile)
            next(reader, None)  # skip the header row
            for row in reader:
                if not row:
                    continue  # blank line; would otherwise make the arrays ragged
                phis_S.append(float(row[0]))
                orders_S.append(int(float(row[1])))
                masks_S.append([float(item) > 0.5 for item in row[2:]])
        masks_S = np.array(masks_S, dtype=bool)
        phis_S = np.array(phis_S)
        orders_S = np.array(orders_S)

        chosen_masks = []
        with open(osp.join(concept_load_dir, filename), newline='') as csvfile:
            reader = csv.reader(csvfile)

            found_row = False
            for row in reader:
                if not found_row:
                    # Everything up to and including the header row is preamble;
                    # the rows after it are the chosen coalitions.
                    found_row = row == concept_header
                    continue
                if not row:
                    continue  # blank line (e.g. trailing newline)

                print(row)
                chosen_mask = np.array([float(item) > 0.5 for item in row[2:]], dtype=bool)

                if masks_S.shape[1:] != chosen_mask.shape:
                    raise ValueError(
                        f"{filename}: coalition row has {chosen_mask.size} players, "
                        f"but masks in {phi_path} have shape {masks_S.shape}"
                    )
                matches = np.flatnonzero((masks_S == chosen_mask).all(axis=1))
                # Exactly one match is required; otherwise the phi/order values
                # would be missing or misaligned with the mask in the output row.
                if matches.size != 1:
                    raise ValueError(
                        f"{filename}: coalition {row[0]!r} matched {matches.size} rows "
                        f"in {phi_path}; expected exactly one"
                    )

                # Indexing with the 1-element `matches` array keeps 1-D slices, so
                # the result is a flat float row [phi, order, mask_1, ..., mask_n].
                chosen_masks.append(np.concatenate((phis_S[matches], orders_S[matches], chosen_mask)))

            if not found_row:
                print(f"warning: coalition header not found in {filename}; writing an empty table")

        sorted_chosen_masks = sorted(chosen_masks, key=lambda x: -np.abs(x[0]))
        with open(osp.join(save_dir, f"coalitions_{sample_id}.csv"), 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(output_header)
            writer.writerows(sorted_chosen_masks)



