import torch
from torch import optim
from tqdm import tqdm
import numpy as np
import os
import os.path as osp
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.patches import Rectangle
from typing import Callable, List, Tuple, Union
import random
import sys
import math

from interaction_utils import generate_all_masks, generate_subset_masks, generate_reverse_subset_masks, \
    generate_set_with_intersection_masks


def get_reward2Iand_mat(dim):
    '''
    Build the matrix (entries 0, 1, -1) that maps a reward vector to and-interactions (Harsanyi).
    Row S encodes I(S) = \\sum_{L\\subseteq S} (-1)^{s-l} v(L).
    :param dim: the input dimension n
    :return: a float matrix with one row and one column per mask returned by generate_all_masks(dim)
             (2^n * 2^n according to the original documentation)
    '''
    every_mask = torch.BoolTensor(generate_all_masks(dim))
    total, _ = every_mask.shape
    rows = []
    for idx in tqdm(range(total), ncols=100, desc="Generating mask"):
        coalition = every_mask[idx]
        # The helper returns the selected masks together with a flag vector over all masks;
        # presumably the selected masks are the subsets L of S -- verify in interaction_utils.
        subsets, subset_flags = generate_subset_masks(coalition, every_mask)
        positions = torch.nonzero(subset_flags == True, as_tuple=False)
        assert subsets.shape[0] == positions.shape[0]
        # The sign of each entry is (-1)^(|S| - |L|).
        exponents = coalition.sum() - subsets.sum(dim=1)
        coefficients = torch.zeros(total)
        # `positions` is a column of indices, so the values are given the same (k, 1) shape.
        coefficients[positions] = torch.pow(-1., exponents).unsqueeze(1)
        rows.append(coefficients)
    return torch.stack(rows).float()


def get_reward2Ior_mat(dim):
    '''
    Build the matrix (entries 0, 1, -1) that maps a reward vector to or-interactions.
    For a non-empty S the row encodes
        I(S) = -\\sum_{L\\subseteq S} (-1)^{s+(n-l)-n} v(N\\L),
    and for the empty S the row simply copies the reward at S's own index.
    :param dim: the input dimension n
    :return: a float matrix with one row and one column per mask returned by generate_all_masks(dim)
             (2^n * 2^n according to the original documentation)
    '''
    every_mask = torch.BoolTensor(generate_all_masks(dim))
    total, _ = every_mask.shape
    rows = []
    for idx, coalition in enumerate(every_mask):
        coefficients = torch.zeros(total)
        if coalition.sum() == 0:
            # Empty coalition: identity entry only.
            coefficients[idx] = 1.
            rows.append(coefficients)
            continue
        # The helper returns the selected masks together with a flag vector over all masks;
        # presumably the selected masks are the complements N\L for L inside S -- verify in
        # interaction_utils.
        complements, complement_flags = generate_reverse_subset_masks(coalition, every_mask)
        positions = torch.nonzero(complement_flags == True, as_tuple=False)
        assert complements.shape[0] == positions.shape[0]
        exponents = coalition.sum() + complements.sum(dim=1) + dim
        # `positions` is a column of indices, so the values are given the same (k, 1) shape.
        coefficients[positions] = -torch.pow(-1., exponents).unsqueeze(1)
        rows.append(coefficients)
    return torch.stack(rows).float()


def get_Iand2reward_mat(dim):
    '''
    Build the 0/1 matrix that maps and-interactions back to rewards.
    Row S encodes v(S) = \\sum_{L\\subseteq S} I(L): it holds a 1 at every position flagged by
    generate_subset_masks for S and 0 elsewhere.
    :param dim: the input dimension n
    :return: a float matrix with one row and one column per mask returned by generate_all_masks(dim)
    '''
    every_mask = torch.BoolTensor(generate_all_masks(dim))
    total, _ = every_mask.shape
    rows = []
    for coalition in every_mask:
        # Only the flag vector is needed here; the selected masks themselves are discarded.
        _, subset_flags = generate_subset_masks(coalition, every_mask)
        indicator = torch.zeros(total)
        indicator[subset_flags] = 1.
        rows.append(indicator)
    return torch.stack(rows).float()


def get_Ior2reward_mat(dim):
    '''
    Build the 0/1 matrix that maps or-interactions back to rewards.
    Row S holds a 1 at the position(s) flagged for the empty mask and at every position flagged
    by generate_set_with_intersection_masks for S, i.e.
        v(S) = I(\\emptyset) + \\sum_{L: L \\cap S \\neq \\emptyset} I(L)
    (NOTE(review): the original comment wrote "union"; the helper name indicates intersection --
    confirm against interaction_utils).
    :param dim: the input dimension n
    :return: a float matrix with one row and one column per mask returned by generate_all_masks(dim)
    '''
    every_mask = torch.BoolTensor(generate_all_masks(dim))
    total, _ = every_mask.shape
    # Flags selected by the subset helper for the all-False mask; computed once for all rows.
    _, empty_flags = generate_subset_masks(torch.zeros(dim).bool(), every_mask)
    rows = []
    for coalition in every_mask:
        _, touching_flags = generate_set_with_intersection_masks(coalition, every_mask)
        indicator = torch.zeros(total)
        indicator[empty_flags] = 1.
        indicator[touching_flags] = 1.
        rows.append(indicator)
    return torch.stack(rows).float()


def l1_on_given_dim(vector: torch.Tensor, indices: List):
    '''
    Sum of absolute values of the selected entries of a 1-D tensor.
    :param vector: a 1-D tensor (asserted)
    :param indices: positions of the entries that contribute to the sum
    :return: a scalar tensor, \\sum_{i in indices} |vector[i]|
    '''
    assert vector.dim() == 1
    return vector.abs()[indices].sum()


def train_p(rewards, loss_type, lr, niter):
    '''
    Learn the vector p that splits the rewards into an and-part (0.5 * rewards + p) and an
    or-part (0.5 * rewards - p) by minimizing an L1 penalty on the resulting interactions
    with SGD (momentum 0.9).
    :param rewards: reward tensor; its number of elements is expected to be 2^n (n is recovered
        with log2). Assumed to be 1-D -- TODO confirm against the caller.
    :param loss_type: "l1" penalizes every and/or interaction; "l1_on_<ratio>" penalizes only the
        <ratio> fraction of interactions whose magnitude is smallest at p = 0
    :param lr: initial learning rate; it decays log-uniformly down to lr / 10 over the iterations
    :param niter: number of iterations (no parameter update is performed in the last one)
    :return: (p, losses, progresses, progresses_pq) where
        p is the detached optimized tensor,
        losses maps "loss" (and "noise_ratio" for "l1_on_*") to per-iteration floats,
        progresses maps "I_and" / "I_or" to numpy snapshots of the interactions,
        progresses_pq maps "p" to numpy snapshots of p.
        Snapshots are taken at the first iteration and every 1000 iterations.
    '''
    device = rewards.device
    n_dim = int(np.log2(rewards.numel()))
    reward2Iand = get_reward2Iand_mat(n_dim).to(device)
    reward2Ior = get_reward2Ior_mat(n_dim).to(device)

    # Trick: explicitly revise the reward (TODO: encapsulate)
    p = torch.zeros_like(rewards).requires_grad_(True)
    # The learning rate is overwritten at every iteration from eta_list.
    optimizer = optim.SGD([p], lr=0.0, momentum=0.9)

    log_lr = np.log10(lr)
    eta_list = np.logspace(log_lr, log_lr - 1, niter)

    if loss_type == "l1":
        losses = {"loss": []}
    elif loss_type.startswith("l1_on"):
        ratio = float(loss_type.split("_")[-1])
        Iand_p = torch.matmul(reward2Iand, 0.5 * rewards + p)
        Ior_p = torch.matmul(reward2Ior, 0.5 * rewards - p)
        num_noisy_pattern = int(ratio * (Iand_p.shape[0] + Ior_p.shape[0]))
        print("# noisy patterns", num_noisy_pattern)
        # Indices (into the concatenation [I_and, I_or]) of the weakest initial interactions.
        noisy_pattern_indices = torch.argsort(torch.abs(torch.cat([Iand_p, Ior_p]))).tolist()[:num_noisy_pattern]
        losses = {"loss": [], "noise_ratio": []}
    else:
        raise NotImplementedError(f"Loss type {loss_type} unrecognized.")

    progresses = {"I_and": [], "I_or": []}

    # add by zhouhuilin
    progresses_pq = {"p": []}

    pbar = tqdm(range(niter), desc="Optimizing p", ncols=100)
    for it in pbar:
        Iand_p = torch.matmul(reward2Iand, 0.5 * rewards + p)
        Ior_p = torch.matmul(reward2Ior, 0.5 * rewards - p)

        if loss_type == "l1":
            loss = torch.sum(torch.abs(Iand_p)) + torch.sum(torch.abs(Ior_p))  # 02-27: L1 penalty.
            losses["loss"].append(loss.item())
        elif loss_type.startswith("l1_on"):
            I_all = torch.cat([Iand_p, Ior_p])
            loss = l1_on_given_dim(I_all, indices=noisy_pattern_indices)
            losses["loss"].append(loss.item())
            # Share of the total interaction strength carried by the penalized patterns.
            losses["noise_ratio"].append(loss.item() / torch.sum(torch.abs(I_all)).item())
        else:
            raise NotImplementedError(f"Loss type {loss_type} unrecognized.")

        if it + 1 < niter:
            optimizer.zero_grad()
            optimizer.param_groups[0]["lr"] = eta_list[it]
            loss.backward()
            optimizer.step()

        if (it + 1) % 1000 == 0 or it == 0:
            progresses["I_and"].append(Iand_p.detach().cpu().numpy())
            progresses["I_or"].append(Ior_p.detach().cpu().numpy())
            pbar.set_postfix_str(f"loss={loss.item():.4f}")

            # add by zhouhuilin
            # On CPU, Tensor.numpy() shares memory with p, which the optimizer updates in place;
            # copy so that earlier snapshots are not overwritten by later steps.
            progresses_pq["p"].append(p.detach().cpu().numpy().copy())

    return p.detach(), losses, progresses, progresses_pq


def train_p_q(rewards, masks, weight, loss_type, lr, niter, qbound, reward2Iand=None, reward2Ior=None):
    '''
    Jointly learn p (the and/or split) and q (a bounded perturbation of the rewards) by
    minimizing an L1 penalty on the and/or interactions with SGD (momentum 0.9).
    The and-part is 0.5 * (rewards + q) + p and the or-part is 0.5 * (rewards + q) - p.
    :param rewards: reward tensor; its number of elements is expected to be 2^n (n is recovered
        with log2). Assumed to be 1-D -- TODO confirm against the caller.
    :param masks: array summed along axis 1 with numpy to get the order of each pattern;
        only used by the "l1_for_*" loss
    :param weight: weight of the penalty on the selected orders; only used by "l1_for_*"
    :param loss_type: "l1" (all interactions), "l1_on_<ratio>" (the <ratio> fraction of
        interactions that are weakest at p = q = 0) or "l1_for_<lower>_<upper>"
        (0.1 * all interactions + weight * interactions whose order is in [lower, upper])
    :param lr: initial learning rate; it decays log-uniformly down to lr / 10 over the iterations
    :param niter: number of iterations (no parameter update is performed in the last one)
    :param qbound: element-wise bound on |q|; a tensor, or a Python scalar that is converted
        to a tensor on q's device
    :param reward2Iand: optional precomputed reward -> and-interaction matrix
    :param reward2Ior: optional precomputed reward -> or-interaction matrix
    :return: (p, q, losses, progresses, progresses_pq); snapshots in progresses / progresses_pq
        are taken at the first iteration and every 1000 iterations
    '''
    device = rewards.device
    n_dim = int(np.log2(rewards.numel()))
    # Build each missing matrix independently, so that passing only one of them works.
    if reward2Iand is None:
        reward2Iand = get_reward2Iand_mat(n_dim).to(device)
    if reward2Ior is None:
        reward2Ior = get_reward2Ior_mat(n_dim).to(device)

    log_lr = np.log10(lr)
    eta_list = np.logspace(log_lr, log_lr - 1, niter)

    # Trick: explicitly revise the reward (TODO: encapsulate)
    p = torch.zeros_like(rewards).requires_grad_(True)
    q = torch.zeros_like(rewards).requires_grad_(True)
    # The learning rate is overwritten at every iteration from eta_list.
    optimizer = optim.SGD([p, q], lr=0.0, momentum=0.9)

    # The binary torch.min / torch.max used for clamping only accept tensors.
    if not torch.is_tensor(qbound):
        qbound = torch.as_tensor(qbound, dtype=q.dtype, device=q.device)

    if loss_type == "l1":
        losses = {"loss": []}
    elif loss_type.startswith("l1_on"):
        ratio = float(loss_type.split("_")[-1])
        Iand_p = torch.matmul(reward2Iand, 0.5 * rewards + p)
        Ior_p = torch.matmul(reward2Ior, 0.5 * rewards - p)
        num_noisy_pattern = int(ratio * (Iand_p.shape[0] + Ior_p.shape[0]))
        print("# noisy patterns", num_noisy_pattern)
        # Indices (into the concatenation [I_and, I_or]) of the weakest initial interactions.
        noisy_pattern_indices = torch.argsort(torch.abs(torch.cat([Iand_p, Ior_p]))).tolist()[:num_noisy_pattern]
        losses = {"loss": [], "noise_ratio": []}
    elif loss_type.startswith("l1_for"):
        lower, upper = int(loss_type.split("_")[-2]), int(loss_type.split("_")[-1])
        i_orders = np.sum(masks, axis=1).astype(int)
        # Indices of the patterns whose order lies in [lower, upper], grouped by order.
        noisy_pattern_indices = []
        for i_order in range(lower, upper + 1):
            noisy_pattern_indices.extend(list(np.where(i_orders == i_order)[0]))
        num_noisy_pattern = len(noisy_pattern_indices)
        print("# noisy patterns", num_noisy_pattern)
        losses = {"loss": [], "noise_ratio": []}
    else:
        raise NotImplementedError(f"Loss type {loss_type} unrecognized.")
    progresses = {"I_and": [], "I_or": []}

    # add by zhouhuilin
    progresses_pq = {"p": [], "q": []}

    pbar = tqdm(range(niter), desc="Optimizing pq", ncols=100)
    for it in pbar:

        # Project q back into [-qbound, qbound] before it is used.
        q.data = torch.max(torch.min(q.data, qbound), -qbound)
        Iand_p = torch.matmul(reward2Iand, 0.5 * (rewards + q) + p)
        Ior_p = torch.matmul(reward2Ior, 0.5 * (rewards + q) - p)

        if loss_type == "l1":
            loss = torch.sum(torch.abs(Iand_p)) + torch.sum(torch.abs(Ior_p))  # 02-27: L1 penalty.
            losses["loss"].append(loss.item())
        elif loss_type.startswith("l1_on"):
            I_all = torch.cat([Iand_p, Ior_p])
            loss = l1_on_given_dim(I_all, indices=noisy_pattern_indices)
            losses["loss"].append(loss.item())
            losses["noise_ratio"].append(loss.item() / torch.sum(torch.abs(I_all)).item())
        elif loss_type.startswith("l1_for"):
            loss_high_orders = l1_on_given_dim(Iand_p, indices=noisy_pattern_indices) + \
                               l1_on_given_dim(Ior_p, indices=noisy_pattern_indices)
            loss_all_orders = torch.sum(torch.abs(Iand_p)) + torch.sum(torch.abs(Ior_p))
            loss = 0.1 * loss_all_orders + weight * loss_high_orders
            losses["loss"].append(loss.item())
            losses["noise_ratio"].append(
                loss_high_orders.item() / torch.sum(torch.abs(torch.cat([Iand_p, Ior_p]))).item())
        else:
            raise NotImplementedError(f"Loss type {loss_type} unrecognized.")

        if it + 1 < niter:
            optimizer.zero_grad()
            for param_group in optimizer.param_groups:
                param_group['lr'] = eta_list[it]
            loss.backward()
            optimizer.step()

        if (it + 1) % 1000 == 0 or it == 0:
            progresses["I_and"].append(Iand_p.detach().cpu().numpy())
            progresses["I_or"].append(Ior_p.detach().cpu().numpy())
            pbar.set_postfix_str(f"loss={loss.item():.4f}")

            # add by zhouhuilin
            # On CPU, Tensor.numpy() shares memory with the parameter, which the optimizer
            # updates in place; copy so that earlier snapshots are not overwritten.
            progresses_pq["p"].append(p.detach().cpu().numpy().copy())
            progresses_pq["q"].append(q.detach().cpu().numpy().copy())

    return p.detach(), q.detach(), losses, progresses, progresses_pq


# def train_p_q_a(rewards, a, rewards_mean_ids, rewards_mean_k, v_empty_, masks, loss_type, lr, niter, qbound,
#                 reward2Iand=None, reward2Ior=None):
#     device = rewards.device
#     n_dim = int(np.log2(rewards.numel()))
#     if reward2Iand is None:
#         reward2Iand = get_reward2Iand_mat(n_dim).to(device)
#         reward2Ior = get_reward2Ior_mat(n_dim).to(device)
#
#     log_lr = np.log10(lr)
#     eta_list = np.logspace(log_lr, log_lr - 1, niter)
#
#     # Trick: explicitly revise the reward (TODO: encapsulate)
#     p = torch.zeros_like(rewards).requires_grad_(True)
#     q = torch.zeros_like(rewards).requires_grad_(True)
#     a = a.requires_grad_(True)
#     optimizer = optim.SGD([p, q, a], lr=0.0, momentum=0.9)
#     # optimizer = optim.SGD([
#     #     {'params': [p], 'lr': 0.0, 'momentum': 0.9},
#     #     {'params': [q], 'lr': 0.0, 'momentum': 0.9}])
#
#     if loss_type == "l1":
#         losses = {"loss": []}
#     elif loss_type.startswith("l1_for"):
#         lower, upper = int(loss_type.split("_")[-2]), int(loss_type.split("_")[-1])
#         i_orders = np.sum(masks, axis=1).astype(int)
#         noisy_pattern_indices = []
#         for i_order in range(lower, upper + 1):
#             indices_order = i_orders == i_order
#             indices_order = np.where(indices_order)[0]
#             noisy_pattern_indices.extend(list(indices_order))
#         num_noisy_pattern = len(noisy_pattern_indices)
#         print("# noisy patterns", num_noisy_pattern)
#         losses = {"loss": [], "noise_ratio": []}
#     else:
#         raise NotImplementedError(f"Loss type {loss_type} unrecognized.")
#     progresses = {"I_and": [], "I_or": []}
#
#     # add by zhouhuilin
#     progresses_pqa = {"p": [], "q": [], "a": []}
#
#     pbar = tqdm(range(niter), desc="Optimizing pqa", ncols=100)
#     for it in pbar:
#
#         # q.data = torch.clamp(q.data, -qbound, qbound)
#         q.data = torch.max(torch.min(q.data, qbound), -qbound)
#         Iand_p = torch.matmul(reward2Iand, 0.5 * ((rewards - a - v_empty_) + q) + p)
#         Ior_p = torch.matmul(reward2Ior, 0.5 * ((rewards - a - v_empty_) + q) - p)
#
#         if loss_type == "l1":
#             loss = torch.sum(torch.abs(Iand_p)) + torch.sum(torch.abs(Ior_p))  # 02-27: L1 penalty.
#             losses["loss"].append(loss.item())
#         elif loss_type.startswith("l1_for"):
#             loss = l1_on_given_dim(Iand_p, indices=noisy_pattern_indices) + \
#                    l1_on_given_dim(Ior_p, indices=noisy_pattern_indices)
#             losses["loss"].append(loss.item())
#             losses["noise_ratio"].append(loss.item() / torch.sum(torch.abs(torch.cat([Iand_p, Ior_p]))).item())
#         else:
#             raise NotImplementedError(f"Loss type {loss_type} unrecognized.")
#
#         if it + 1 < niter:
#             optimizer.zero_grad()
#             for i, param in enumerate(optimizer.param_groups):
#                 param["lr"] = eta_list[it]
#             loss.backward()
#             optimizer.step()
#
#         if (it + 1) % 1000 == 0 or it == 0:
#             progresses["I_and"].append(Iand_p.detach().cpu().numpy())
#             progresses["I_or"].append(Ior_p.detach().cpu().numpy())
#             pbar.set_postfix_str(f"loss={loss.item():.4f}")
#
#             # add by zhouhuilin
#             progresses_pqa["p"].append(p.detach().cpu().numpy())
#             progresses_pqa["q"].append(q.detach().cpu().numpy())
#             progresses_pqa["a"].append(a.detach().cpu().numpy())
#
#     return p.detach(), q.detach(), a.detach(), losses, progresses, progresses_pqa


def train_p_q_a(rewards, alr, weight, rewards_mean_ids, rewards_mean_k, masks, loss_type, lr, niter, qbound,
                reward2Iand=None, reward2Ior=None):
    '''
    Jointly learn p (the and/or split), q (a bounded perturbation of the rewards) and one scalar
    offset a_k per entry of rewards_mean_k, by minimizing an L1 penalty on the and/or
    interactions with SGD (momentum 0.9). Every a_k has its own learning-rate schedule.
    The and-part is 0.5 * ((rewards - a[rewards_mean_ids]) + q) + p and the or-part is
    0.5 * ((rewards - a[rewards_mean_ids]) + q) - p.
    :param rewards: reward tensor; its number of elements is expected to be 2^n (n is recovered
        with log2). Assumed to be 1-D -- TODO confirm against the caller.
    :param alr: divisor applied (together with a binomial coefficient) to the learning rate of a_k
    :param weight: weight of the penalty on the selected orders; only used by "l1_for_*"
    :param rewards_mean_ids: index tensor passed to torch.index_select; selects, for every reward,
        which a_k is subtracted from it
    :param rewards_mean_k: tensor with the initial values of the offsets a_k.
        It is updated IN PLACE with the optimized values.
    :param masks: array summed along axis 1 with numpy to get the order of each pattern;
        only used by the "l1_for_*" loss
    :param loss_type: "l1" (all interactions) or "l1_for_<lower>_<upper>"
        (0.1 * all interactions + weight * interactions whose order is in [lower, upper])
    :param lr: initial learning rate; it decays log-uniformly down to lr / 10 over the iterations
    :param niter: number of iterations (no parameter update is performed in the last one)
    :param qbound: element-wise bound on |q|; a tensor, or a Python scalar that is converted
        to a tensor on q's device
    :param reward2Iand: optional precomputed reward -> and-interaction matrix
    :param reward2Ior: optional precomputed reward -> or-interaction matrix
    :return: (p, q, rewards_mean_k, losses, progresses, progresses_pqa); snapshots in
        progresses / progresses_pqa are taken at the first iteration and every 1000 iterations
    '''
    device = rewards.device
    n_dim = int(np.log2(rewards.numel()))
    # Build each missing matrix independently, so that passing only one of them works.
    if reward2Iand is None:
        reward2Iand = get_reward2Iand_mat(n_dim).to(device)
    if reward2Ior is None:
        reward2Ior = get_reward2Ior_mat(n_dim).to(device)

    log_lr = np.log10(lr)
    eta_list = np.logspace(log_lr, log_lr - 1, niter)

    # Per-offset schedules: the base schedule divided by (binomial coefficient * alr).
    n_offsets = len(rewards_mean_k)
    half = n_offsets // 2
    lr_list = []
    for k in range(n_offsets):
        combination_count = math.comb(half, k if k <= half else k - half)
        lr_list.append(eta_list / (combination_count * alr))

    p = torch.zeros_like(rewards).requires_grad_(True)
    q = torch.zeros_like(rewards).requires_grad_(True)
    # One independent leaf tensor per offset, so that each can live in its own parameter group.
    a_list = [torch.tensor([a_k], requires_grad=True) for a_k in rewards_mean_k]

    # Parameter groups: 0 -> p, 1 -> q, 2 + k -> a_k. Learning rates are set at every iteration.
    parameters = [{'params': [p], 'lr': 0.0, 'momentum': 0.9},
                  {'params': [q], 'lr': 0.0, 'momentum': 0.9}]
    for a_k in a_list:
        parameters.append({'params': [a_k], 'lr': 0.0, 'momentum': 0.9})
    optimizer = optim.SGD(parameters)

    # The binary torch.min / torch.max used for clamping only accept tensors.
    if not torch.is_tensor(qbound):
        qbound = torch.as_tensor(qbound, dtype=q.dtype, device=q.device)

    def _write_back_offsets():
        # Copy the current a_k values into rewards_mean_k without recording autograd history.
        with torch.no_grad():
            for k, a_k in enumerate(a_list):
                rewards_mean_k[k] = a_k

    if loss_type == "l1":
        losses = {"loss": []}
    elif loss_type.startswith("l1_for"):
        lower, upper = int(loss_type.split("_")[-2]), int(loss_type.split("_")[-1])
        i_orders = np.sum(masks, axis=1).astype(int)
        # Indices of the patterns whose order lies in [lower, upper], grouped by order.
        noisy_pattern_indices = []
        for i_order in range(lower, upper + 1):
            noisy_pattern_indices.extend(list(np.where(i_orders == i_order)[0]))
        num_noisy_pattern = len(noisy_pattern_indices)
        print("# noisy patterns", num_noisy_pattern)
        losses = {"loss": [], "noise_ratio": []}
    else:
        raise NotImplementedError(f"Loss type {loss_type} unrecognized.")
    progresses = {"I_and": [], "I_or": []}

    # add by zhouhuilin
    progresses_pqa = {"p": [], "q": [], "rewards_mean_k": []}

    pbar = tqdm(range(niter), desc="Optimizing pqa", ncols=100)
    for it in pbar:

        # Project q back into [-qbound, qbound] before it is used.
        q.data = torch.max(torch.min(q.data, qbound), -qbound)
        # Expand the offsets to one value per reward (rebuilt every iteration to keep the graph).
        a_vector = torch.index_select(torch.stack(a_list).to(device), 0, rewards_mean_ids).squeeze()
        Iand_p = torch.matmul(reward2Iand,
                              0.5 * ((rewards - a_vector) + q) + p)
        Ior_p = torch.matmul(reward2Ior,
                             0.5 * ((rewards - a_vector) + q) - p)

        if loss_type == "l1":
            loss = torch.sum(torch.abs(Iand_p)) + torch.sum(torch.abs(Ior_p))  # 02-27: L1 penalty.
            losses["loss"].append(loss.item())
        elif loss_type.startswith("l1_for"):
            loss_high_orders = l1_on_given_dim(Iand_p, indices=noisy_pattern_indices) + \
                               l1_on_given_dim(Ior_p, indices=noisy_pattern_indices)
            loss_all_orders = torch.sum(torch.abs(Iand_p)) + torch.sum(torch.abs(Ior_p))
            loss = 0.1 * loss_all_orders + weight * loss_high_orders
            losses["loss"].append(loss.item())
            losses["noise_ratio"].append(
                loss_high_orders.item() / torch.sum(torch.abs(torch.cat([Iand_p, Ior_p]))).item())
        else:
            raise NotImplementedError(f"Loss type {loss_type} unrecognized.")

        if it + 1 < niter:
            optimizer.zero_grad()
            for i, param in enumerate(optimizer.param_groups):
                if i == 0 or i == 1:
                    param["lr"] = eta_list[it]
                else:
                    param["lr"] = lr_list[i - 2][it]
            loss.backward()
            optimizer.step()

        if (it + 1) % 1000 == 0 or it == 0:
            progresses["I_and"].append(Iand_p.detach().cpu().numpy())
            progresses["I_or"].append(Ior_p.detach().cpu().numpy())
            pbar.set_postfix_str(f"loss={loss.item():.4f}")

            # add by zhouhuilin
            # On CPU, Tensor.numpy() shares memory with the tensor, and both p and
            # rewards_mean_k are modified in place later; copy so that earlier snapshots
            # are not overwritten.
            progresses_pqa["p"].append(p.detach().cpu().numpy().copy())
            progresses_pqa["q"].append(q.detach().cpu().numpy().copy())
            _write_back_offsets()
            progresses_pqa["rewards_mean_k"].append(rewards_mean_k.detach().cpu().numpy().copy())

    # Make sure the returned offsets are the final ones, not the last logged snapshot
    # (the two differ whenever niter is not a multiple of 1000).
    _write_back_offsets()

    return p.detach(), q.detach(), rewards_mean_k.detach(), losses, progresses, progresses_pqa


def train_p_q_a_ori(rewards, alr, weight, rewards_mean_ids, rewards_mean_k, masks, loss_type, lr, niter, qbound,
                reward2Iand=None, reward2Ior=None):
    '''
    Jointly learn p (the and/or split), q (a bounded perturbation of the rewards) and the offsets
    rewards_mean_k, by minimizing an L1 penalty on the and/or interactions with SGD
    (momentum 0.9). All parameters share one learning-rate schedule.
    The and-part is 0.5 * ((rewards - rewards_mean_k[rewards_mean_ids]) + q) + p and the or-part
    is 0.5 * ((rewards - rewards_mean_k[rewards_mean_ids]) + q) - p.
    :param rewards: reward tensor; its number of elements is expected to be 2^n (n is recovered
        with log2). Assumed to be 1-D -- TODO confirm against the caller.
    :param alr: unused by this function
    :param weight: weight of the penalty on the selected orders; only used by "l1_for_*"
    :param rewards_mean_ids: indices into rewards_mean_k, one per reward
    :param rewards_mean_k: tensor of offsets; it is switched to requires_grad and optimized
        IN PLACE
    :param masks: array summed along axis 1 with numpy to get the order of each pattern;
        only used by the "l1_for_*" loss
    :param loss_type: "l1" (all interactions) or "l1_for_<lower>_<upper>"
        (0.1 * all interactions + weight * interactions whose order is in [lower, upper])
    :param lr: initial learning rate; it decays log-uniformly down to lr / 10 over the iterations
    :param niter: number of iterations (no parameter update is performed in the last one)
    :param qbound: element-wise bound on |q|; a tensor, or a Python scalar that is converted
        to a tensor on q's device
    :param reward2Iand: optional precomputed reward -> and-interaction matrix
    :param reward2Ior: optional precomputed reward -> or-interaction matrix
    :return: (p, q, rewards_mean_k, losses, progresses, progresses_pqa); snapshots in
        progresses / progresses_pqa are taken at the first iteration and every 1000 iterations
    '''
    device = rewards.device
    n_dim = int(np.log2(rewards.numel()))
    # Build each missing matrix independently, so that passing only one of them works.
    if reward2Iand is None:
        reward2Iand = get_reward2Iand_mat(n_dim).to(device)
    if reward2Ior is None:
        reward2Ior = get_reward2Ior_mat(n_dim).to(device)

    log_lr = np.log10(lr)
    eta_list = np.logspace(log_lr, log_lr - 1, niter)

    # Trick: explicitly revise the reward (TODO: encapsulate)
    p = torch.zeros_like(rewards).requires_grad_(True)
    q = torch.zeros_like(rewards).requires_grad_(True)
    rewards_mean_k = rewards_mean_k.requires_grad_(True)
    # The learning rate is overwritten at every iteration from eta_list.
    optimizer = optim.SGD([p, q, rewards_mean_k], lr=0.0, momentum=0.9)

    # The binary torch.min / torch.max used for clamping only accept tensors.
    if not torch.is_tensor(qbound):
        qbound = torch.as_tensor(qbound, dtype=q.dtype, device=q.device)

    if loss_type == "l1":
        losses = {"loss": []}
    elif loss_type.startswith("l1_for"):
        lower, upper = int(loss_type.split("_")[-2]), int(loss_type.split("_")[-1])
        i_orders = np.sum(masks, axis=1).astype(int)
        # Indices of the patterns whose order lies in [lower, upper], grouped by order.
        noisy_pattern_indices = []
        for i_order in range(lower, upper + 1):
            noisy_pattern_indices.extend(list(np.where(i_orders == i_order)[0]))
        num_noisy_pattern = len(noisy_pattern_indices)
        print("# noisy patterns", num_noisy_pattern)
        losses = {"loss": [], "noise_ratio": []}
    else:
        raise NotImplementedError(f"Loss type {loss_type} unrecognized.")
    progresses = {"I_and": [], "I_or": []}

    # add by zhouhuilin
    progresses_pqa = {"p": [], "q": [], "rewards_mean_k": []}

    pbar = tqdm(range(niter), desc="Optimizing pqa", ncols=100)
    for it in pbar:

        # Project q back into [-qbound, qbound] before it is used.
        q.data = torch.max(torch.min(q.data, qbound), -qbound)
        a_vector = rewards_mean_k[rewards_mean_ids]
        Iand_p = torch.matmul(reward2Iand,
                              0.5 * ((rewards - a_vector) + q) + p)
        Ior_p = torch.matmul(reward2Ior,
                             0.5 * ((rewards - a_vector) + q) - p)

        if loss_type == "l1":
            loss = torch.sum(torch.abs(Iand_p)) + torch.sum(torch.abs(Ior_p))  # 02-27: L1 penalty.
            losses["loss"].append(loss.item())
        elif loss_type.startswith("l1_for"):
            loss_high_orders = l1_on_given_dim(Iand_p, indices=noisy_pattern_indices) + \
                               l1_on_given_dim(Ior_p, indices=noisy_pattern_indices)
            loss_all_orders = torch.sum(torch.abs(Iand_p)) + torch.sum(torch.abs(Ior_p))
            loss = 0.1 * loss_all_orders + weight * loss_high_orders
            losses["loss"].append(loss.item())
            losses["noise_ratio"].append(
                loss_high_orders.item() / torch.sum(torch.abs(torch.cat([Iand_p, Ior_p]))).item())
        else:
            raise NotImplementedError(f"Loss type {loss_type} unrecognized.")

        if it + 1 < niter:
            optimizer.zero_grad()
            for param_group in optimizer.param_groups:
                param_group["lr"] = eta_list[it]
            loss.backward()
            optimizer.step()

        if (it + 1) % 1000 == 0 or it == 0:
            progresses["I_and"].append(Iand_p.detach().cpu().numpy())
            progresses["I_or"].append(Ior_p.detach().cpu().numpy())
            pbar.set_postfix_str(f"loss={loss.item():.4f}")

            # add by zhouhuilin
            # On CPU, Tensor.numpy() shares memory with the parameter, which the optimizer
            # updates in place; copy so that earlier snapshots are not overwritten.
            progresses_pqa["p"].append(p.detach().cpu().numpy().copy())
            progresses_pqa["q"].append(q.detach().cpu().numpy().copy())
            progresses_pqa["rewards_mean_k"].append(rewards_mean_k.detach().cpu().numpy().copy())

    return p.detach(), q.detach(), rewards_mean_k.detach(), losses, progresses, progresses_pqa


def train_p_q_baseline(
        f_baseline2rewards: Callable,
        f_masksbaseline2rewards: Callable,
        baseline_init: torch.Tensor,
        loss_type: str,
        partition_lr: float,
        baseline_lr: float,
        niter: int,
        interval: int,  # TODO: optimize pq for [interval] iterations, and then optimize the baseline value
        qbound: Union[float, torch.Tensor],
        baseline_min: Union[float, torch.Tensor],
        baseline_max: Union[float, torch.Tensor],
        reward2Iand: torch.Tensor,
        reward2Ior: torch.Tensor,
        calc_bs: int
):
    '''
    Alternate between optimizing p / q (plain SGD, no momentum) and the baseline value.
    p and q are updated at every iteration; every [interval] iterations (except the last one) the
    baseline takes one gradient step, is clamped to [baseline_min, baseline_max], and the rewards
    are recomputed from the new baseline.
    :param f_baseline2rewards: called as f(baseline, with_grad=False); returns (masks, rewards)
    :param f_masksbaseline2rewards: called as f(masks_batch, baseline, with_grad=True); returns a
        pair whose second element is the rewards of the batch, differentiable w.r.t. baseline
    :param baseline_init: initial baseline value (cloned, the input is not modified)
    :param loss_type: "l1" or "l1_on_<ratio>". NOTE: the baseline step is only implemented for
        "l1"; with "l1_on_*" a NotImplementedError is raised at the first baseline step.
    :param partition_lr: initial learning rate of p / q; decays log-uniformly to a tenth
    :param baseline_lr: initial learning rate of the baseline; decays log-uniformly to a tenth
    :param niter: number of iterations (p / q are not updated in the last one)
    :param interval: number of p / q iterations between two baseline steps
    :param qbound: element-wise bound on |q| (float or tensor)
    :param baseline_min: lower bound of the baseline (float or tensor)
    :param baseline_max: upper bound of the baseline (float or tensor)
    :param reward2Iand: reward -> and-interaction matrix
    :param reward2Ior: reward -> or-interaction matrix
    :param calc_bs: number of masks evaluated per batch when computing the baseline gradient
    :return: (p, q, baseline, losses, progresses)
    '''
    masks, rewards = f_baseline2rewards(baseline_init, with_grad=False)

    log_partition_lr = np.log10(partition_lr)
    log_baseline_lr = np.log10(baseline_lr)
    eta_pq_list = np.logspace(log_partition_lr, log_partition_lr - 1, niter)
    eta_baseline_list = np.logspace(log_baseline_lr, log_baseline_lr - 1, niter)

    # Trick: explicitly revise the reward (TODO: encapsulate)
    p = torch.zeros_like(rewards).requires_grad_(True)
    q = torch.zeros_like(rewards).requires_grad_(True)
    # The learning rate is overwritten at every iteration from eta_pq_list.
    optimizer_pq = optim.SGD([p, q], lr=0.0, momentum=0.0)

    baseline = baseline_init.clone().requires_grad_(True)

    # The signature allows float bounds, but the binary torch.min / torch.max used for clamping
    # only accept tensors; convert scalars once up front.
    if not torch.is_tensor(qbound):
        qbound = torch.as_tensor(qbound, dtype=q.dtype, device=q.device)
    if not torch.is_tensor(baseline_min):
        baseline_min = torch.as_tensor(baseline_min, dtype=baseline.dtype, device=baseline.device)
    if not torch.is_tensor(baseline_max):
        baseline_max = torch.as_tensor(baseline_max, dtype=baseline.dtype, device=baseline.device)

    if loss_type == "l1":
        losses = {"loss": []}
    elif loss_type.startswith("l1_on"):
        ratio = float(loss_type.split("_")[-1])
        Iand_p = torch.matmul(reward2Iand, 0.5 * rewards + p)
        Ior_p = torch.matmul(reward2Ior, 0.5 * rewards - p)
        num_noisy_pattern = int(ratio * (Iand_p.shape[0] + Ior_p.shape[0]))
        print("# noisy patterns", num_noisy_pattern)
        # Indices (into the concatenation [I_and, I_or]) of the weakest initial interactions.
        noisy_pattern_indices = torch.argsort(torch.abs(torch.cat([Iand_p, Ior_p]))).tolist()[:num_noisy_pattern]
        losses = {"loss": [], "noise_ratio": []}
    else:
        raise NotImplementedError(f"Loss type {loss_type} unrecognized.")
    progresses = {"I_and": [], "I_or": [], "I_all": [], "baseline": []}

    pbar = tqdm(range(niter), desc="Optimizing p|q|b", ncols=100)
    for it in pbar:

        Iand_p = torch.matmul(reward2Iand, 0.5 * (rewards + q) + p)
        Ior_p = torch.matmul(reward2Ior, 0.5 * (rewards + q) - p)

        if loss_type == "l1":
            loss = torch.sum(torch.abs(Iand_p)) + torch.sum(torch.abs(Ior_p))  # 02-27: L1 penalty.
            losses["loss"].append(loss.item())
        elif loss_type.startswith("l1_on"):
            I_all = torch.cat([Iand_p, Ior_p])
            loss = l1_on_given_dim(I_all, indices=noisy_pattern_indices)
            losses["loss"].append(loss.item())
            losses["noise_ratio"].append(loss.item() / torch.sum(torch.abs(I_all)).item())
        else:
            raise NotImplementedError(f"Loss type {loss_type} unrecognized.")

        if it + 1 < niter:
            optimizer_pq.zero_grad()
            optimizer_pq.param_groups[0]["lr"] = eta_pq_list[it]
            loss.backward()
            optimizer_pq.step()

        # Project q back into [-qbound, qbound] after the update.
        q.data = torch.max(torch.min(q.data, qbound), -qbound)

        if (it + 1) % interval == 0 and it != niter - 1:
            # [reward_coefs] is defined s.t. [loss = rewards * reward_coefs]
            if loss_type == "l1":
                with torch.no_grad():
                    reward_coefs = 0.5 * torch.matmul(torch.sign(Iand_p), reward2Iand) + \
                                   0.5 * torch.matmul(torch.sign(Ior_p), reward2Ior)
            else:
                raise NotImplementedError(f"Loss type {loss_type} unrecognized.")

            indices = list(range(masks.shape[0]))

            # Accumulate d(loss)/d(baseline) over batches of masks to bound the memory use.
            baseline_grad = 0.

            for batch_id in range(int(np.ceil(len(indices) / calc_bs))):
                batch_indices = indices[batch_id*calc_bs : batch_id*calc_bs+calc_bs]
                _, rewards = f_masksbaseline2rewards(masks[batch_indices], baseline, with_grad=True)

                loss_baseline = torch.matmul(reward_coefs[batch_indices], rewards)

                grad = torch.autograd.grad(loss_baseline, baseline)[0]
                baseline_grad += grad.data.clone()

            baseline.data = baseline.data - float(eta_baseline_list[it]) * baseline_grad
            # NOTE(review): the clone/detach/requires_grad_ chain looks redundant when assigning
            # to .data -- confirm before simplifying.
            baseline.data = torch.max(torch.min(baseline.data, baseline_max), baseline_min)\
                                 .clone().detach().requires_grad_(True).float()

            # update the rewards, note that rewards only change after the optimization of baseline values
            masks, rewards = f_baseline2rewards(baseline, with_grad=False)
            baseline_numpy = baseline.data.clone().cpu().numpy()
            progresses["baseline"].append(baseline_numpy)

        # Log roughly 200 snapshots over the run, plus the first and the last iteration.
        if (it + 1) % (max(niter // 200, 1)) == 0 or it == 0 or it == niter - 1:
            I_and_numpy = Iand_p.detach().cpu().numpy()
            I_or_numpy = Ior_p.detach().cpu().numpy()
            I_all_numpy = np.concatenate([I_and_numpy, I_or_numpy])
            progresses["I_and"].append(I_and_numpy)
            progresses["I_or"].append(I_or_numpy)
            progresses["I_all"].append(I_all_numpy)
            pbar.set_postfix_str(f"loss={loss.item():.4f}")

    return p.detach(), q.detach(), baseline.detach(), losses, progresses


def train_baseline(
        f_baseline2rewards: Callable,
        f_masksbaseline2rewards: Callable,
        baseline_init: torch.Tensor,
        loss_type: str,
        baseline_lr: float,
        niter: int,
        baseline_min: Union[float, torch.Tensor],
        baseline_max: Union[float, torch.Tensor],
        reward2Iand: torch.Tensor,
        calc_bs: int
):
    '''
    Optimize the baseline value by gradient descent on the L1 norm of the and-interactions.
    At every iteration the gradient w.r.t. the baseline is accumulated over batches of masks,
    the baseline takes one step, is clamped to [baseline_min, baseline_max], and the rewards and
    and-interactions are recomputed.
    :param f_baseline2rewards: called as f(baseline, with_grad=False); returns (masks, rewards)
    :param f_masksbaseline2rewards: called as f(masks_batch, baseline, with_grad=True); returns a
        pair whose second element is the rewards of the batch, differentiable w.r.t. baseline
    :param baseline_init: initial baseline value (cloned, the input is not modified)
    :param loss_type: only "l1" is supported; anything else raises NotImplementedError
    :param baseline_lr: initial learning rate; decays log-uniformly to a tenth over the iterations
    :param niter: number of baseline updates
    :param baseline_min: lower bound of the baseline (float or tensor)
    :param baseline_max: upper bound of the baseline (float or tensor)
    :param reward2Iand: reward -> and-interaction matrix
    :param calc_bs: number of masks evaluated per batch when computing the baseline gradient
    :return: (baseline, losses, progresses); progresses["I_and"] and progresses["baseline"] both
        hold the initial state followed by one entry per iteration, losses["loss"] holds one
        entry per iteration (the loss evaluated before that iteration's update)
    '''
    masks, rewards = f_baseline2rewards(baseline_init, with_grad=False)
    Iand = reward2Iand @ rewards

    log_baseline_lr = np.log10(baseline_lr)
    eta_baseline_list = np.logspace(log_baseline_lr, log_baseline_lr - 1, niter)

    baseline = baseline_init.clone().requires_grad_(True)

    # The signature allows float bounds, but the binary torch.min / torch.max used for clamping
    # only accept tensors; convert scalars once up front.
    if not torch.is_tensor(baseline_min):
        baseline_min = torch.as_tensor(baseline_min, dtype=baseline.dtype, device=baseline.device)
    if not torch.is_tensor(baseline_max):
        baseline_max = torch.as_tensor(baseline_max, dtype=baseline.dtype, device=baseline.device)

    if loss_type == "l1":
        losses = {"loss": []}
    else:
        raise NotImplementedError(f"Loss type {loss_type} unrecognized.")
    progresses = {
        "I_and": [Iand.data.clone().cpu().numpy()],
        "baseline": [baseline.data.clone().cpu().numpy()]
    }

    for it in range(niter):

        # [reward_coefs] is defined s.t. [loss = rewards * reward_coefs]
        if loss_type == "l1":
            with torch.no_grad():
                reward_coefs = torch.matmul(torch.sign(Iand), reward2Iand)
        else:
            raise NotImplementedError(f"Loss type {loss_type} unrecognized.")

        indices = list(range(masks.shape[0]))

        # Accumulate the loss and d(loss)/d(baseline) over batches of masks.
        baseline_grad = 0.
        total_loss_baseline = 0.

        for batch_id in tqdm(range(int(np.ceil(len(indices) / calc_bs))), desc="Optimizing b", ncols=100):
            batch_indices = indices[batch_id*calc_bs : batch_id*calc_bs+calc_bs]
            _, rewards = f_masksbaseline2rewards(masks[batch_indices], baseline, with_grad=True)

            loss_baseline = torch.matmul(reward_coefs[batch_indices], rewards)

            grad = torch.autograd.grad(loss_baseline, baseline)[0]
            baseline_grad += grad.data.clone()
            total_loss_baseline += loss_baseline.item()

        baseline.data = baseline.data - float(eta_baseline_list[it]) * baseline_grad
        # NOTE(review): the clone/detach/requires_grad_ chain looks redundant when assigning
        # to .data -- confirm before simplifying.
        baseline.data = torch.max(torch.min(baseline.data, baseline_max), baseline_min)\
                             .clone().detach().requires_grad_(True).float()

        # update the rewards, note that rewards change after the optimization of baseline values
        masks, rewards = f_baseline2rewards(baseline, with_grad=False)
        Iand = reward2Iand @ rewards

        # Record exactly one snapshot per iteration so that "I_and" stays aligned with "baseline".
        progresses["I_and"].append(Iand.detach().cpu().numpy())
        progresses["baseline"].append(baseline.data.clone().cpu().numpy())
        losses["loss"].append(total_loss_baseline)

        print(f"[{it}/{niter}] loss: {total_loss_baseline:.4f}")

    return baseline.detach(), losses, progresses


# ==================================================
#               FOR  VISUALIZATION
# ==================================================

def plot_simple_line_chart(data, xlabel, ylabel, title, save_folder, save_name, X=None):
    '''
    Plot a single curve and save it as "<save_folder>/<save_name>.png" (dpi=200).
    :param data: the y-values of the curve
    :param xlabel: label of the x-axis
    :param ylabel: label of the y-axis
    :param title: title of the figure
    :param save_folder: output folder, created if it does not exist
    :param save_name: file name without the ".png" extension
    :param X: the x-values; defaults to 0, 1, ..., len(data)-1
    :return: None
    '''
    os.makedirs(save_folder, exist_ok=True)
    x_values = np.arange(len(data)) if X is None else X

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.set_title(title)
    ax.plot(x_values, data)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    fig.tight_layout()

    target = osp.join(save_folder, f"{save_name}.png")
    fig.savefig(target, dpi=200)
    # NOTE: closes every open pyplot figure, not only the one created here
    plt.close("all")


def visualize_pattern_interaction(
        coalition_masks: np.ndarray,
        interactions: np.ndarray,
        attributes: List,
        title: str = None,
        save_path="test.png"
):
    '''
    Draw a two-panel figure describing a set of interaction patterns and save it.
    Upper panel: |interactions| plotted against the pattern index.
    Lower panel: a (pattern x attribute) grid of translucent rectangles -- gray if the
    attribute is not in the pattern, red if it is and the pattern's interaction is > 0,
    blue if it is and the interaction is <= 0.
    :param coalition_masks: one mask per pattern; coalition_masks[i][j] is truthy iff attribute j belongs to pattern i
    :param interactions: interactions[i] is the interaction value of pattern i
    :param attributes: attribute labels, used as y-tick labels of the lower panel
    :param title: optional title shown above the upper panel
    :param save_path: path of the saved figure
    :return: None
    '''
    n_patterns = len(coalition_masks)
    n_attributes = len(attributes)
    pattern_pos = np.arange(n_patterns)
    attribute_pos = np.arange(n_attributes)

    plt.figure(figsize=(15, 6))

    # ---------------- lower panel: which attributes form each pattern ----------------
    ax_attribute = plt.subplot(2, 1, 2)
    plt.xticks(pattern_pos, [])
    plt.yticks(attribute_pos, attributes)
    ax_attribute.set_xlim(pattern_pos.min() - 0.5, pattern_pos.max() + 0.5)
    ax_attribute.set_ylim(attribute_pos.min() - 0.5, attribute_pos.max() + 0.5)
    ax_attribute.set_xlabel("pattern")
    ax_attribute.set_ylabel("attribute")

    # colour lookup, keyed by "is the attribute selected?"
    palette = {
        False: 'gray',
        True: {'pos': 'red', 'neg': 'blue'},
    }
    cell_w, cell_h = 0.8, 0.9

    for col, coalition in enumerate(coalition_masks):
        for row in range(n_attributes):
            in_pattern = coalition[row]
            if in_pattern:
                # selected attributes are coloured by the sign of the pattern's interaction
                sign_key = 'pos' if interactions[col] > 0 else 'neg'
                cell_color = palette[in_pattern][sign_key]
            else:
                cell_color = palette[in_pattern]
            ax_attribute.add_patch(Rectangle(
                xy=(col - cell_w / 2, row - cell_h / 2),
                width=cell_w, height=cell_h,
                edgecolor=None,
                facecolor=cell_color,
                alpha=0.5
            ))

    # ---------------- upper panel: interaction strength per pattern ----------------
    ax_eval = plt.subplot(2, 1, 1, sharex=ax_attribute)
    if title is not None:
        ax_eval.set_title(title)
    ax_eval.set_ylabel("interaction strength")
    plt.setp(ax_eval.get_xticklabels(), visible=False)
    for side in ('right', 'top'):
        ax_eval.spines[side].set_visible(False)
    ax_eval.plot(np.arange(n_patterns), np.abs(interactions))
    ax_eval.hlines(y=0, xmin=0, xmax=n_patterns, linestyles='dotted', colors='red')

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close("all")


def generate_colorbar(ax, cmap_name, x_range, loc, title=""):
    '''
    generate a (fake) colorbar in a matplotlib plot: an inset axes showing a horizontal
    gradient of the colormap, with one tick at each end of the range
    :param ax: the axes that hosts the inset colorbar
    :param cmap_name: name of the colormap, resolved via plt.get_cmap
    :param x_range: (first, last) integer values covered by the colorbar; the gradient has
                    last - first + 1 steps
    :param loc: bounds of the inset, passed unchanged to ax.inset_axes
    :param title: title displayed above the colorbar
    :return: None
    '''
    first, last = x_range[0], x_range[1]
    length = last - first + 1
    bar_ax = ax.inset_axes(loc)
    bar_ax.set_title(title)
    # two identical rows so the gradient renders as a thin 2-D image
    dummy = np.vstack([np.linspace(0, 1, length)] * 2)
    # Place the pixel centres at first, ..., last so that the ticks set below line up with
    # the ends of the bar. With first == 0 this equals imshow's default extent for
    # origin='upper'; without it, a range not starting at 0 puts the ticks off the image.
    bar_ax.imshow(
        dummy, aspect='auto', cmap=plt.get_cmap(cmap_name),
        extent=(first - 0.5, last + 0.5, 1.5, -0.5)
    )
    bar_ax.set_yticks([])
    bar_ax.set_xticks(x_range)


def plot_interaction_progress(interaction, save_path, order_cfg="descending", title=""):
    '''
    Plot one curve of interaction values per snapshot (coloured by snapshot index with the
    viridis colormap) and save the figure to save_path (dpi=200).
    :param interaction: a 1-D array of I(S), or a list of such arrays (one per iteration)
    :param save_path: path of the saved figure
    :param order_cfg: "descending" -- every curve is sorted by its own values (descending);
                      "first" -- every curve uses the descending order of the first snapshot
    :param title: title of the figure
    :return: None
    :raises NotImplementedError: if order_cfg is neither "descending" nor "first"
    '''
    # validate before any figure is created, so that an invalid option leaks no figure
    if order_cfg not in ("descending", "first"):
        raise NotImplementedError(f"Unrecognized order configuration {order_cfg}.")

    if not isinstance(interaction, list):
        interaction = [interaction]
    n_snapshots = len(interaction)

    order_first = np.argsort(-interaction[0])

    plt.figure(figsize=(8, 6))
    plt.title(title)

    cmap_name = 'viridis'
    # plt.get_cmap instead of matplotlib.cm.get_cmap: the latter was removed in
    # Matplotlib 3.9, the former accepts the same (name, lut) arguments.
    colors = plt.get_cmap(cmap_name, n_snapshots)(np.arange(n_snapshots))

    for i, item in enumerate(interaction):
        X = np.arange(1, item.shape[0] + 1)
        plt.hlines(0, 0, X.shape[0], linestyles="dotted", colors="red")
        label = f"iter {i+1}" if n_snapshots > 1 else None
        order = order_first if order_cfg == "first" else np.argsort(-item)
        plt.plot(X, item[order], label=label, color=colors[i])

    # the x-axis label reflects the ordering that was actually applied
    if order_cfg == "descending":
        plt.xlabel("patterns (with I(S) descending)")
    else:
        plt.xlabel("patterns (ordered by I(S) of the first iteration, descending)")
    plt.ylabel("I(S)")
    # if label is not None: plt.legend()
    plt.tight_layout()
    ax = plt.gca()
    generate_colorbar(
        ax, cmap_name,
        x_range=(0, n_snapshots - 1),
        loc=[0.58, 0.9, 0.4, 0.03],
        title="iteration"
    )
    plt.savefig(save_path, dpi=200)
    plt.close("all")


def plot_pq_progress(pq, chosen_concept_idx, save_path, pq_type, record_interval=1000):
    '''
    Plot how selected columns of pq evolve over the recorded snapshots and save the figure
    as "<save_path>/<pq_type>_progress.png" (dpi=200).
    :param pq: sequence of snapshots; converted to an array indexed as pq[snapshot, concept]
    :param chosen_concept_idx: the column indices to plot, one curve per index
    :param save_path: output folder, created if it does not exist
    :param pq_type: prefix of the output file name
    :param record_interval: x-axis spacing between consecutive snapshots (default 1000, the
                            value that used to be hard-coded; presumably the number of
                            iterations between two records -- confirm against the caller)
    :return: None
    '''
    pq = np.array(pq)
    # consistent with the other plotting helpers: make sure the target folder exists
    # (an empty save_path means the current directory, which needs no creation)
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    # every curve shares the same x-axis: one point per recorded snapshot
    X = np.arange(0, pq.shape[0] * record_interval, record_interval)

    plt.figure(figsize=(8, 6))
    for idx in chosen_concept_idx:
        plt.plot(X, pq[:, idx], label=f"concept {idx}")
    plt.xticks(fontsize=22)
    plt.yticks(fontsize=22)
    plt.legend()
    plt.tight_layout()
    plt.savefig(osp.join(save_path, f"{pq_type}_progress.png"), dpi=200)
    plt.close("all")


def plot_a_progress(a, save_path, pq_type, record_interval=1000):
    '''
    Plot how every column of a evolves over the recorded snapshots and save the figure as
    "<save_path>/<pq_type>_progress.png" (dpi=200).
    :param a: sequence of snapshots; converted to an array indexed as a[snapshot, k].
              A 1-D sequence (one scalar per snapshot) is treated as a single column.
    :param save_path: output folder, created if it does not exist
    :param pq_type: prefix of the output file name
    :param record_interval: x-axis spacing between consecutive snapshots (default 1000, the
                            value that used to be hard-coded; presumably the number of
                            iterations between two records -- confirm against the caller)
    :return: None
    '''
    a = np.array(a)
    if a.ndim == 1:
        # a single tracked coefficient: view it as one column instead of failing on a.shape[1]
        a = a[:, None]
    # consistent with the other plotting helpers: make sure the target folder exists
    # (an empty save_path means the current directory, which needs no creation)
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    # every curve shares the same x-axis: one point per recorded snapshot
    X = np.arange(0, a.shape[0] * record_interval, record_interval)

    plt.figure(figsize=(8, 6))
    for i in range(a.shape[1]):
        plt.plot(X, a[:, i], label=f"a_{i}")
    plt.xticks(fontsize=22)
    plt.yticks(fontsize=22)
    plt.legend()
    plt.tight_layout()
    plt.savefig(osp.join(save_path, f"{pq_type}_progress.png"), dpi=200)
    plt.close("all")


def plot_interaction_strength_ratio(interaction, save_path, title="Relationship between # patterns & explain-ratio"):
    '''
    Plot the cumulative share of total interaction strength |I(S)| covered by the strongest
    patterns, and mark the first position at which the share reaches 0.7 / 0.8 / 0.9 / 0.95.
    The figure is saved to save_path (dpi=200, transparent background).
    :param interaction: 1-D array of interaction values I(S)
    :param save_path: path of the saved figure
    :param title: title of the figure
    :return: None
    '''
    strength = np.abs(interaction)
    strength = strength[np.argsort(-strength)]
    total_strength = strength.sum()
    # all-zero (or empty) input: skip the 0/0 division that would turn the curve into NaNs
    if total_strength > 0:
        strength = strength / total_strength
    cum_strength = np.cumsum(strength)

    plt.figure()
    plt.plot(np.arange(len(interaction)), cum_strength)

    for thres in [0.7, 0.8, 0.9, 0.95]:
        plt.hlines(y=thres, xmin=0, xmax=len(interaction)-1, linestyles="dashed", colors="red")
        reached = np.where(cum_strength >= thres)[0]
        if reached.size == 0:
            # the threshold is never reached, so there is no point to mark
            continue
        idx = reached[0]  # 0-based position in the descending order
        plt.scatter(idx, cum_strength[idx], c="red")
        plt.annotate(f"{idx}", (idx, cum_strength[idx]), zorder=5)

    plt.title(title)
    plt.xlabel(r"# of patterns $S$")
    plt.ylabel(r"ratio")
    plt.tight_layout()
    plt.savefig(save_path, dpi=200, transparent=True)
    plt.close("all")


def plot_interaction_strength_descending(interaction, save_path, title="interaction strength (descending)", standard=None):
    '''
    Plot the interaction strengths |I(S)| sorted in descending order and save the figure to
    save_path (dpi=200, transparent background).
    :param interaction: 1-D array of interaction values I(S)
    :param save_path: path of the saved figure
    :param title: title of the figure
    :param standard: optional reference value; if given, dashed lines are drawn at
                     1.0 / 0.1 / 0.05 / 0.01 times this value, and the first position whose
                     strength is at or below each line is marked and annotated
    :return: None
    '''
    strength = np.abs(interaction)
    strength = strength[np.argsort(-strength)]

    plt.figure()
    plt.plot(np.arange(len(interaction)), strength)
    if standard is not None:
        for r in [1.0, 0.1, 0.05, 0.01]:
            level = r * standard
            plt.hlines(y=level, xmin=0, xmax=len(interaction)-1, linestyles="dashed", colors="red")
            below = np.where(strength <= level)[0]
            if below.size == 0:
                # every pattern is stronger than this level, so there is no point to mark
                continue
            idx = below[0]  # 0-based position in the descending order
            plt.scatter(idx, strength[idx], c="red")
            plt.annotate(f"{idx}", (idx, strength[idx]), zorder=5)
    plt.title(title)
    plt.xlabel(r"# of patterns $S$")
    plt.ylabel(r"$|I(S)|$")
    plt.tight_layout()
    plt.savefig(save_path, dpi=200, transparent=True)
    plt.close("all")


def denormalize_image(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    '''
    Undo a per-channel normalization, i.e. compute image * std + mean.
    :param image: the normalized image; mean / std are reshaped to (-1, 1, 1) and broadcast
                  against it, so the channel axis is expected third from the last
    :param mean: per-channel mean used in the normalization
    :param std: per-channel standard deviation used in the normalization
    :return: the de-normalized image
    '''
    per_channel = (-1, 1, 1)
    offset = np.array(mean).reshape(per_channel)
    scale = np.array(std).reshape(per_channel)
    return image * scale + offset


def plot_image(image, save_folder, save_name):
    '''
    Save an image as "<save_folder>/<save_name>.png" without axes, on a transparent background.
    :param image: array whose axes are permuted with transpose(1, 2, 0) before display
                  (i.e. channel-first input); values are clipped to [0, 1]
    :param save_folder: output folder, created if it does not exist
    :param save_name: file name without the ".png" extension
    :return: None
    '''
    os.makedirs(save_folder, exist_ok=True)

    channels_last = image.transpose(1, 2, 0)
    displayable = channels_last.clip(0, 1)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(displayable)
    ax.axis("off")

    target = osp.join(save_folder, f"{save_name}.png")
    fig.savefig(target, transparent=True)
    # NOTE: closes every open pyplot figure, not only the one created here
    plt.close("all")
