import os.path
import argparse
from and_or_harsanyi import *
from and_or_harsanyi_utils import *
from interaction_utils import *


def plot_Vs_increasing(Vs, save_path, save_name, fluctuation, standard=None):
    """Plot ``Vs`` sorted in increasing order with a shaded +/- ``fluctuation`` band.

    The curve is ``Vs`` sorted ascending. The shaded region spans
    ``[Vs - fluctuation, Vs + fluctuation]``. The figure is saved as
    ``<save_path>/<save_name>.png``; ``save_path`` is created if missing.

    Args:
        Vs: 1-D numpy array of values to plot.
        save_path: directory the image is written into.
        save_name: image file name without the ``.png`` extension.
        fluctuation: half-width of the band. Either a scalar (broadcast to every
            point) or an array with one entry per element of ``Vs``. A
            per-element array is reordered together with ``Vs``, so each band
            stays attached to its own value after sorting.
        standard: unused; kept only for backward compatibility with callers.
    """
    os.makedirs(save_path, exist_ok=True)

    Vs = np.asarray(Vs)
    order = np.argsort(Vs)
    Vs = Vs[order]
    fluctuation = np.asarray(fluctuation)
    if fluctuation.shape == order.shape:
        # A per-element band must follow the same permutation as Vs; otherwise
        # each band would be drawn around a different (sorted) value.
        fluctuation = fluctuation[order]
    x = np.arange(len(Vs))

    plt.figure(figsize=(8, 4))
    try:
        plt.plot(x, Vs)
        plt.fill_between(x, Vs - fluctuation, Vs + fluctuation, color='lightblue', alpha=0.5)

        plt.yticks(fontproperties='Times New Roman', size=30)
        plt.xticks(fontproperties='Times New Roman', size=30)

        ax = plt.gca()  # handle of the current axes
        # Only draw a legend when some artist carries a label. Calling legend()
        # with no labelled artists emits a warning and renders an empty box.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            plt.legend(prop={'family': 'Times New Roman', 'size': 28}, loc="center right",
                       bbox_to_anchor=(1, 0.5))

        # Thicken all four axis spines.
        for side in ("bottom", "left", "right", "top"):
            ax.spines[side].set_linewidth(2)

        plt.tight_layout()
        plt.savefig(os.path.join(save_path, f"{save_name}.png"), bbox_inches='tight', transparent=True)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close("all")


# Command-line options for sparsifying previously computed AND-OR Harsanyi
# interactions stored under --save_dir.
parser = argparse.ArgumentParser(description="sparsify and-or harsanyi")
parser.add_argument('--device', default=0, type=int,
                    help="set the device (the value is passed to torch's .to(), e.g. 0 for cuda:0).")
parser.add_argument('--reward_way', default="gt-log-odds-minus-mean", type=str,
                    help="the way for calculating the rewards. "
                         "choose from: gt-log-odds, gt-log-odds-minus-mean, gt-log-odds-minus-mean-minus-empty, none")

# Only "pq" and "pqa" are supported: the per-sample loop loads rewards for
# exactly these two tricks and has no rewards source for any other value.
parser.add_argument("--sparsify-trick", default="pqa", type=str, choices=("pq", "pqa"),
                    help="the trick to sparsify and-or interactions: pq | pqa")
parser.add_argument("--sparsify-loss", default="l1_for_6_10", type=str,
                    help="use which type of loss to sparsify and or interactions. "
                         "choose from: l1, l1_for_5_10, l1_for_6_10")
parser.add_argument("--sparsify-weight", default=5, type=int, help="the weight for high-order loss")
parser.add_argument("--sparsify-qthres", default=0.4, type=float,
                    help="the threshold to bound the magnitude of q: q in [-thres*std, thres*std]")
parser.add_argument("--sparsify-qstd", default="vN_vEmpty_mean", type=str,
                    help="the standard to bound the magnitude of q: q in [-thres*std, thres*std] "
                         "choose from: vS, vS-v0, vN-v0, maxIs_mean, vN_vEmpty_mean, none")
parser.add_argument("--sparsify-lr", default=1e-6, type=float,
                    help="the learning rate to learn (p and q, a)")
parser.add_argument("--sparsify-lr-way", default="a_1", type=str,
                    help="how the learning rate of a is adjusted, in the form '<name>_<int>' "
                         "(e.g. 'a_1'); the integer after '_' is passed to the sparsifier as alr")
parser.add_argument("--sparsify-niter", default=50000, type=int,
                    help="how many iterations to optimize (p and q)")

parser.add_argument("--save_dir", default="eval_andor", type=str,
                    help="root directory holding one sub-directory per sample plus rewards_mean.npy")
args = parser.parse_args()


# Dataset-level statistics over every per-sample sub-directory of save_dir,
# computed from the "before_sparsify" outputs and later handed to the sparsifier:
#   max_strength_mean -- mean |I| over the strongest 1% of AND interactions and
#                        the strongest 1% of OR interactions of each sample
#                        (at least one of each, so fewer than 100 interactions
#                        no longer produce an empty selection and a NaN mean);
#   vN_vEmpty_mean    -- mean |v_N - v_empty| over samples.
# Both are wrapped in 0-d numpy arrays so torch.from_numpy can consume them.
TOP_STRENGTH_RATIO = 0.01

max_strengths = []
vN_vEmpty = []
for filename in sorted(os.listdir(args.save_dir)):
    if not os.path.isdir(os.path.join(args.save_dir, filename)):
        continue
    load_folder = os.path.join(args.save_dir, filename, args.reward_way, "before_sparsify")

    top_strengths = []
    for npy_name in ("Iand.npy", "Ior.npy"):
        # |I| sorted in decreasing order; the top-k size is taken from this
        # vector's own length (previously the OR slice used the AND length).
        strength = -np.sort(-np.abs(np.load(os.path.join(load_folder, npy_name))))
        k = max(1, int(strength.shape[0] * TOP_STRENGTH_RATIO))
        top_strengths.append(strength[:k])
    max_strengths.append(np.concatenate(top_strengths))

    vN = np.load(os.path.join(load_folder, "v_N.npy"))
    vEmpty = np.load(os.path.join(load_folder, "v_empty.npy"))
    vN_vEmpty.append(np.abs(vN - vEmpty))

# Concatenate rather than np.array(...): samples with different interaction
# counts would otherwise form a ragged array. Every selected strength is
# weighted equally, matching the original mean when all lengths agree.
all_top_strengths = np.concatenate(max_strengths) if max_strengths else np.empty(0)
max_strength_mean = np.array(np.mean(all_top_strengths))
print("max_strength_mean: ", max_strength_mean)

vN_vEmpty_mean = np.array(np.mean(np.array(vN_vEmpty)))
print("vN_vEmpty: ", vN_vEmpty_mean)


# Shared table stored directly in save_dir (not per sample), loaded once and
# cast to float32. save_dir already exists here: it was listed with
# os.listdir above, and the file is read from it. Creating it would therefore
# be a no-op, so no os.makedirs call is needed.
rewards_mean = np.load(os.path.join(args.save_dir, "rewards_mean.npy")).astype(np.float32)


# Validate the options the per-sample loop depends on before doing any work.
if args.sparsify_trick not in ("pq", "pqa"):
    # Only these two tricks have a rewards source below; any other value would
    # otherwise fail later with a NameError on `rewards`.
    raise ValueError(f"unsupported --sparsify-trick {args.sparsify_trick!r}; expected 'pq' or 'pqa'")
try:
    # --sparsify-lr-way has the form "<name>_<int>"; the integer is passed as `alr`.
    alr = int(args.sparsify_lr_way.split("_")[1])
except (IndexError, ValueError) as err:
    raise ValueError(
        f"--sparsify-lr-way must look like 'a_<int>', got {args.sparsify_lr_way!r}"
    ) from err

# Output sub-directory name encoding every sparsification hyper-parameter.
save_subdir = (
    f"after_sparsifying-trick-{args.sparsify_trick}-loss-{args.sparsify_loss}"
    f"-lr-{args.sparsify_lr}-lr-way-{args.sparsify_lr_way}-qthres-{args.sparsify_qthres}"
    f"-qstd-{args.sparsify_qstd}-weight-{args.sparsify_weight}"
)

# Sorted so the processing order and the printed indices are deterministic.
for idx, filename in enumerate(sorted(os.listdir(args.save_dir))):
    sample_dir = os.path.join(args.save_dir, filename)
    if not os.path.isdir(sample_dir):
        continue  # skip plain files in save_dir, such as rewards_mean.npy

    # NOTE: there is no "already done" check; a sample whose save_folder exists
    # is recomputed and its info.txt is overwritten.
    print(idx, filename)

    save_folder = os.path.join(sample_dir, args.reward_way, save_subdir)
    verbose_folder = os.path.join(save_folder, "sparsify_verbose")
    load_folder = os.path.join(sample_dir, args.reward_way, "before_sparsify")

    masks = np.load(os.path.join(load_folder, "masks.npy"))
    v_N = np.load(os.path.join(load_folder, "v_N.npy"))
    v_Empty = np.load(os.path.join(load_folder, "v_empty.npy"))
    rewards_mean_ids = np.load(os.path.join(load_folder, "rewards_mean_ids.npy"))

    reward2Iand = np.load(os.path.join(load_folder, "reward2Iand.npy"))
    reward2Ior = np.load(os.path.join(load_folder, "reward2Ior.npy"))

    Iand = np.load(os.path.join(load_folder, "Iand.npy"))
    Ior = np.load(os.path.join(load_folder, "Ior.npy"))

    if args.sparsify_trick == "pqa":
        # "pqa" reads the rewards saved under <sample>/without_minus_mean.
        rewards = np.load(os.path.join(sample_dir, "without_minus_mean", "rewards.npy"))
    else:  # "pq" (validated above)
        rewards = np.load(os.path.join(load_folder, "rewards.npy"))
        print(args.sparsify_trick)

    calculator = {
        "rewards": torch.from_numpy(rewards).to(args.device),
        "rewards_mean_ids": torch.from_numpy(rewards_mean_ids).to(torch.int64).to(args.device),
        "rewards_mean_k": torch.from_numpy(rewards_mean).to(args.device),
        "masks": masks,
        "reward2Iand": torch.from_numpy(reward2Iand).to(args.device),
        "reward2Ior": torch.from_numpy(reward2Ior).to(args.device),
        "v_N": torch.from_numpy(v_N).to(args.device),
        "v_empty": torch.from_numpy(v_Empty).to(args.device),
        "idx": filename[-4:],
        "max_strength_mean": torch.from_numpy(max_strength_mean).to(args.device),
        "vN_vEmpty_mean": torch.from_numpy(vN_vEmpty_mean).to(args.device),
    }

    sparsifier = AndOrHarsanyiSparsifier(
        calculator=calculator, trick=args.sparsify_trick,
        loss=args.sparsify_loss, qthres=args.sparsify_qthres,
        qstd=args.sparsify_qstd, lr=args.sparsify_lr,
        niter=args.sparsify_niter, weight=args.sparsify_weight,
        alr=alr,
    )

    sparsifier.sparsify(verbose_folder=verbose_folder)
    Iand_s, Ior_s = sparsifier.get_interaction()
    sparsifier.save(save_folder=save_folder)

    if args.sparsify_trick == "pqa":
        # Rewards minus the learned offsets `a`, gathered via rewards_mean_ids
        # (already int64, see the calculator above). Computed once, plotted twice.
        shifted_rewards = (
            calculator["rewards"] - sparsifier.a[calculator["rewards_mean_ids"]]
        ).cpu().numpy()
        plot_Vs_increasing(shifted_rewards,
                           save_path=verbose_folder,
                           save_name="plot_vS_increasing_q",
                           fluctuation=sparsifier.q.cpu().numpy(),
                           standard=None)
        plot_Vs_increasing(shifted_rewards,
                           save_path=verbose_folder,
                           save_name="plot_vS_increasing_q_bound",
                           fluctuation=sparsifier.q_bound.cpu().numpy(),
                           standard=None)
    elif args.sparsify_trick == "pq":
        plot_Vs_increasing(calculator["rewards"].cpu().numpy(),
                           save_path=verbose_folder,
                           save_name="plot_vS_increasing",
                           fluctuation=sparsifier.q_bound.cpu().numpy(),
                           standard=None)

    with open(os.path.join(save_folder, "info.txt"), 'w') as f:
        f.write("\n[Before Sparsifying]\n")
        # NOTE(review): the "before" total halves both sums while the "after"
        # total does not -- confirm this asymmetry is intended.
        f.write(f"\tSum of I^and and I^or: {np.sum(Iand) / 2 + np.sum(Ior) / 2}\n")
        f.write("\n[After Sparsifying]\n")
        f.write(f"\tSum of I^and and I^or: {torch.sum(Iand_s) + torch.sum(Ior_s)}\n")
        f.write(f"\n[v_N]: \t{v_N}\n")
        f.write(f"\n[v_empty]: \t{v_Empty}\n")
        f.write(f"\n[v_N - v_empty]: \t{v_N - v_Empty}\n")




