import os
import os.path as osp
import csv

import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter


def makedirs(path):
    """Create directory *path*, including missing parents, if it does not exist.

    ``exist_ok=True`` makes this race-free: if another process creates the
    directory concurrently, no ``FileExistsError`` is raised.  A non-directory
    file already occupying *path* is still reported as an error instead of
    being silently ignored.
    """
    os.makedirs(path, exist_ok=True)


def compute_mean_strength(masks, interactions):
    """Return the mean absolute interaction strength for each order.

    Args:
        masks: 2-D array with one row per subset.  Summing a row (axis 1)
            gives that subset's order; ``masks.shape[1]`` is ``n_dim``.
        interactions: array indexed along its first axis in the same row
            order as ``masks``.

    Returns:
        1-D float array of length ``n_dim + 1``.  Entry ``k`` is the mean of
        ``|interactions|`` over rows whose order is ``k``.  An order with no
        rows yields ``nan``.  This is the value NumPy would give, but here no
        "Mean of empty slice" ``RuntimeWarning`` is emitted.
    """
    masks = np.asarray(masks)
    abs_interactions = np.abs(np.asarray(interactions))
    n_dim = masks.shape[1]
    orders = np.sum(masks, axis=1).astype(int)

    mean_strengths = np.full(n_dim + 1, np.nan)
    for order in range(n_dim + 1):
        selected = abs_interactions[orders == order]
        # Skip empty groups explicitly; they keep their nan placeholder.
        if selected.size:
            mean_strengths[order] = selected.mean()
    return mean_strengths


def plot_mean_strength_andor(Iand_after_ori, Ior_after_ori, Iand_after_revise, Ior_after_revise,
                             masks, save_path, width=0.35, font_size=None):
    """Plot per-order mean |AND| / |OR| interaction strength (ori vs. revise) and save it.

    For each of the four interaction arrays, ``compute_mean_strength`` gives
    the mean absolute strength per order.  Only orders ``1 .. n_dim - 1`` are
    drawn; order 0 and order ``n_dim`` are skipped.  "ori" curves are dashed
    and "revise" curves are solid.  AND uses colour C0 and OR uses colour C1.

    Args:
        Iand_after_ori, Ior_after_ori: AND / OR interactions of the original
            run, row-aligned with ``masks``.
        Iand_after_revise, Ior_after_revise: AND / OR interactions of the
            revised run, row-aligned with ``masks``.
        masks: 2-D array; ``masks.shape[1]`` is ``n_dim`` and row sums give
            each row's order.
        save_path: output image path passed to ``plt.savefig``.
        width: unused; kept for backward compatibility.
        font_size: tick-label font size (the legend uses ``font_size - 2``).
            ``None`` uses the module-level ``FONT`` when it is defined,
            otherwise 20.
    """
    if font_size is None:
        # FONT is only defined when this file runs as a script; fall back so
        # the function is also usable when imported.
        font_size = globals().get("FONT", 20)

    n_dim = masks.shape[1]

    mean_strength_and_ori = compute_mean_strength(masks, Iand_after_ori)
    mean_strength_or_ori = compute_mean_strength(masks, Ior_after_ori)
    mean_strength_and_revise = compute_mean_strength(masks, Iand_after_revise)
    mean_strength_or_revise = compute_mean_strength(masks, Ior_after_revise)

    orders = np.arange(n_dim + 1)
    shown = slice(1, -1)  # drop order 0 and order n_dim

    fig = plt.figure(figsize=(5, 4))
    plt.plot(orders[shown], mean_strength_and_ori[shown], linestyle='--', marker='o', markersize=3, color='C0', label="AND-ori")
    plt.plot(orders[shown], mean_strength_and_revise[shown], marker='o', markersize=3, color='C0', label="AND-revise")

    plt.plot(orders[shown], mean_strength_or_ori[shown], linestyle='--', marker='o', markersize=3, color='C1', label="OR-ori")
    plt.plot(orders[shown], mean_strength_or_revise[shown], marker='o', markersize=3, color='C1', label="OR-revise")

    # Scale the y ticks to the plotted orders only.  Setting ticks expands the
    # view limits to include them, so counting the unplotted orders 0 / n_dim
    # could squash the visible curves.
    plotted = np.concatenate([
        mean_strength_and_ori[shown], mean_strength_and_revise[shown],
        mean_strength_or_ori[shown], mean_strength_or_revise[shown],
    ])
    plotted = plotted[np.isfinite(plotted)]
    max_strength = plotted.max() if plotted.size else 0.0

    plt.yticks(np.round(np.linspace(0, max_strength, 5), 1), fontproperties='Times New Roman', size=font_size)
    # Every other plotted order; equals [1, 3, 5, 7, 9] when n_dim == 10.
    plt.xticks(list(range(1, n_dim, 2)), fontproperties='Times New Roman', size=font_size)
    plt.legend(loc='upper right', prop={'size': font_size - 2})

    ax = plt.gca()  # handle to the current axes
    # Thicken all four axis borders.
    for side in ('bottom', 'left', 'right', 'top'):
        ax.spines[side].set_linewidth(2)

    plt.tight_layout()  # after the legend so it is included in the layout
    plt.savefig(save_path, bbox_inches='tight', transparent=True)
    plt.close(fig)  # close only this figure, not figures owned by the caller


def _setting_or_global(value, name):
    """Return *value*, or the module-level variable *name* when *value* is None."""
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"{name!r} was not passed and is not defined at module level"
        ) from None


def get_strength_mean_all(folders, save_dir=None, param_ori=None, param_revise=None,
                          reward_way_ori=None, reward_way_revise=None):
    """Plot the per-order AND/OR strength comparison for every sample folder.

    For each folder, this loads ``masks.npy`` and the original / revised
    ``Iand.npy`` / ``Ior.npy`` arrays, then writes ``<sample_id>.png`` into
    ``save_dir/param_revise``.

    Args:
        folders: iterable of per-sample directories; the last path component
            is used as the sample id.
        save_dir, param_ori, param_revise, reward_way_ori, reward_way_revise:
            path components for inputs and output.  Any argument left as
            ``None`` is read from the module-level variable of the same name,
            as set by this file's ``__main__`` section.

    Raises:
        NameError: a setting was neither passed nor defined at module level.
    """
    save_dir = _setting_or_global(save_dir, "save_dir")
    param_ori = _setting_or_global(param_ori, "param_ori")
    param_revise = _setting_or_global(param_revise, "param_revise")
    reward_way_ori = _setting_or_global(reward_way_ori, "reward_way_ori")
    reward_way_revise = _setting_or_global(reward_way_revise, "reward_way_revise")

    # The output folder is the same for every sample; create it once.
    save_folder = osp.join(save_dir, param_revise)
    makedirs(save_folder)

    for i, folder in enumerate(folders):
        # Separator-agnostic and tolerant of a trailing slash.
        sample_id = osp.basename(osp.normpath(folder))
        print(i, sample_id)

        ori_dir = osp.join(folder, reward_way_ori)
        revise_dir = osp.join(folder, reward_way_revise, param_revise)

        masks = np.load(osp.join(ori_dir, "before_sparsify", "masks.npy"))
        Iand_after_ori = np.load(osp.join(ori_dir, param_ori, "Iand.npy"))
        Ior_after_ori = np.load(osp.join(ori_dir, param_ori, "Ior.npy"))
        Iand_after_revise = np.load(osp.join(revise_dir, "Iand.npy"))
        Ior_after_revise = np.load(osp.join(revise_dir, "Ior.npy"))

        plot_mean_strength_andor(Iand_after_ori, Ior_after_ori, Iand_after_revise, Ior_after_revise, masks,
                                 save_path=osp.join(save_folder, f"{sample_id}.png"))


if __name__ == '__main__':
    load_dir = "eval_andor"
    save_dir = "analysis_andor_paper_strength"
    makedirs(save_dir)

    # Tick-label font size; read as a module-level global by the plotting code.
    FONT = 20

    # Settings identifying the original ("ori") sparsified interactions.
    loss = "l1"
    qthres = 0.4
    lr = 1e-6
    trick = "pq"
    qstd = "vN_vEmpty_mean"
    param_ori = f"after_sparsifying-trick-{trick}-loss-{loss}-lr-{lr}-qthres-{qthres}-qstd-{qstd}"
    reward_way_ori = "gt-log-odds"

    # Settings identifying the revised sparsified interactions.
    loss = "l1_for_6_10"
    qthres = 0.4
    lr = 1e-6
    weight = 5
    trick = "pqa"
    lr_way = "a_1"
    qstd = "vN_vEmpty_mean"
    param_revise = f"after_sparsifying-trick-{trick}-loss-{loss}-lr-{lr}-lr-way-{lr_way}-qthres-{qthres}-qstd-{qstd}-weight-{weight}"
    reward_way_revise = "gt-log-odds-minus-mean"

    # Per-sample result directories are named "id..."; ignore stray files.
    sample_ids = sorted(
        name for name in os.listdir(load_dir)
        if name.startswith("id") and osp.isdir(osp.join(load_dir, name))
    )

    get_strength_mean_all([osp.join(load_dir, sample_id) for sample_id in sample_ids])




