import os
import os.path as osp
import random

import numpy as np
import torch

from and_or_harsanyi_utils import *
from interaction_utils import *


class AndHarsanyi(object):
    """Compute AND (Harsanyi) interactions of a model output over all player subsets.

    ``attribute()`` evaluates the model on every masked input produced by
    ``calculate_all_subset_outputs`` and maps the resulting rewards to AND
    interactions with the matrix built by ``get_reward2Iand_mat``.
    """

    def __init__(
            self,
            model,
            selected_dim,
            x,
            baseline,
            y,
            grid_width=None,
            calc_bs=None,
            players=None,
            background=None
    ):
        """
        Args:
            model: callable evaluated directly on ``baseline`` and ``x``.
            selected_dim: forwarded to ``get_reward`` to select the reward.
            x: input tensor; with ``grid_width`` set it must be 4-dimensional.
            baseline: input used for absent players (assumed same shape as
                ``x`` -- TODO confirm).
            y: target forwarded to ``get_reward`` as ``gt``.
            grid_width: patch side length for image data, ``None`` otherwise.
            calc_bs: batch size forwarded to ``calculate_all_subset_outputs``.
            players: optional custom player list (changes the player count
                only for image data).
            background: players that are always present; ``None`` -> ``[]``.
        """
        self.model = model
        self.selected_dim = selected_dim
        self.input = x
        self.target = y
        self.baseline = baseline
        self.device = x.device

        self.grid_width = grid_width  # image data only
        self.calc_bs = calc_bs
        self.players = players  # optional custom players
        # Players that are always present; the default is the empty set.
        self.background = [] if background is None else background

        if grid_width is None:
            # Non-image data: one player per entry along the first axis.
            self.n_dim = self.input.shape[0]
        else:
            assert len(x.shape) == 4
            _, _, height, width = x.shape
            # Partial patches at the right/bottom border still count as a grid cell.
            n_grids = int(np.ceil(height / grid_width)) * int(np.ceil(width / grid_width))
            self.n_dim = n_grids if players is None else len(players)

        print("[AndHarsanyi] Generating v->I^and matrix:")
        self.reward2Iand = get_reward2Iand_mat(self.n_dim).to(self.device)
        print("[AndHarsanyi] Finish.")

        with torch.no_grad():
            self.output_empty = model(self.baseline)
            self.output_N = model(self.input)
        self.v_N = get_reward(self.output_N, self.selected_dim, gt=y)
        self.v_empty = get_reward(self.output_empty, self.selected_dim, gt=y)

    def attribute(self):
        """Evaluate all subsets and store ``masks``, ``rewards`` and ``Iand``."""
        with torch.no_grad():
            self.masks, subset_outputs = calculate_all_subset_outputs(
                self.model, self.input, self.baseline,
                grid_width=self.grid_width, calc_bs=self.calc_bs,
                all_players=self.players, background=self.background
            )
        self.rewards = get_reward(subset_outputs, self.selected_dim, gt=self.target)
        self.Iand = torch.matmul(self.reward2Iand, self.rewards)

    def save(self, save_folder):
        """Write ``rewards.npy``, ``masks.npy`` and ``Iand.npy`` into ``save_folder``.

        Requires ``attribute()`` to have been called first.
        """
        os.makedirs(save_folder, exist_ok=True)
        masks_np = self.masks.cpu().numpy()
        rewards_np = self.rewards.cpu().numpy()
        for file_name, array in (("rewards.npy", rewards_np), ("masks.npy", masks_np)):
            np.save(osp.join(save_folder, file_name), array)
        np.save(osp.join(save_folder, "Iand.npy"), self.Iand.cpu().numpy())

    def get_interaction(self):
        """Return the AND interaction vector computed by ``attribute()``."""
        return self.Iand


class AndOrHarsanyi(object):
    """Compute AND and OR Harsanyi interactions of a model output over all player subsets.

    ``attribute()`` evaluates the model on every masked input produced by
    ``calculate_all_subset_outputs``. It then maps the rewards linearly to
    AND and OR interactions using the matrices from ``get_reward2Iand_mat`` and
    ``get_reward2Ior_mat``.
    """

    def __init__(
            self,
            model,
            selected_dim,
            x,
            baseline,
            y,
            grid_width=None,
            calc_bs=None,
            players=None,
            background=None
    ):
        """
        Args:
            model: callable evaluated on ``x[None, ...]`` and
                ``baseline[None, ...]`` (a leading batch axis is added here).
            selected_dim: forwarded to ``get_reward`` to select the reward.
            x: a single unbatched input. With ``grid_width`` set, it must be
                3-dimensional (unpacked as channel, height, width).
            baseline: input used for absent players (assumed same shape as
                ``x`` -- TODO confirm).
            y: target forwarded to ``get_reward`` as ``gt``.
            grid_width: patch side length for image data, ``None`` otherwise.
            calc_bs: batch size forwarded to ``calculate_all_subset_outputs``.
            players: optional custom player list (changes the player count
                only for image data).
            background: players that are always present. ``None`` means the
                empty set ``[]``, the same default that ``AndHarsanyi`` uses.
        """
        self.model = model
        self.selected_dim = selected_dim
        self.input = x
        self.target = y
        self.baseline = baseline
        self.device = x.device

        self.grid_width = grid_width  # image data only
        self.calc_bs = calc_bs
        self.players = players  # optional custom players
        # Players that are always present. Normalise None to the documented
        # default (the empty set) so both Harsanyi classes forward the same value.
        self.background = [] if background is None else background

        if grid_width is not None:  # image data
            assert len(x.shape) == 3
            image_channel, image_height, image_width = x.shape
            # Partial patches at the right/bottom border still count as a grid cell.
            grid_num_h = int(np.ceil(image_height / grid_width))
            grid_num_w = int(np.ceil(image_width / grid_width))
            grid_num = grid_num_h * grid_num_w
            self.n_dim = grid_num if players is None else len(players)
        else:
            # Non-image data: one player per entry along the first axis.
            self.n_dim = self.input.shape[0]

        print("[AndOrHarsanyi] Generating v->I^and matrix:")
        self.reward2Iand = get_reward2Iand_mat(self.n_dim).to(self.device)
        print("[AndOrHarsanyi] Generating v->I^or matrix:")
        self.reward2Ior = get_reward2Ior_mat(self.n_dim).to(self.device)
        print("[AndOrHarsanyi] Finish.")

        with torch.no_grad():
            self.output_empty = model(self.baseline[None, ...])
            self.output_N = model(self.input[None, ...])
        self.v_N = get_reward(self.output_N, self.selected_dim, gt=y)
        self.v_empty = get_reward(self.output_empty, self.selected_dim, gt=y)

    def attribute(self):
        """Evaluate all subsets and store ``masks``, ``rewards``, ``Iand`` and ``Ior``."""
        with torch.no_grad():
            self.masks, outputs = calculate_all_subset_outputs(
                self.model, self.input, self.baseline,
                grid_width=self.grid_width, calc_bs=self.calc_bs,
                all_players=self.players, background=self.background
            )
        self.rewards = get_reward(outputs, self.selected_dim, gt=self.target)
        self.Iand = torch.matmul(self.reward2Iand, self.rewards)
        self.Ior = torch.matmul(self.reward2Ior, self.rewards)

    def save(self, save_folder):
        """Write the rewards, masks, I^and and I^or arrays into ``save_folder``.

        The files are ``rewards.npy``, ``masks.npy``, ``Iand.npy`` and
        ``Ior.npy``. Requires ``attribute()`` to have been called first.
        """
        os.makedirs(save_folder, exist_ok=True)
        masks = self.masks.cpu().numpy()
        rewards = self.rewards.cpu().numpy()
        np.save(osp.join(save_folder, "rewards.npy"), rewards)
        np.save(osp.join(save_folder, "masks.npy"), masks)
        Iand = self.Iand.cpu().numpy()
        Ior = self.Ior.cpu().numpy()
        np.save(osp.join(save_folder, "Iand.npy"), Iand)
        np.save(osp.join(save_folder, "Ior.npy"), Ior)

    def get_interaction(self):
        """Return ``(0.5 * I^and, 0.5 * I^or)``.

        Both transforms are linear, so this equals the interactions obtained
        when each reward is split equally between the AND and the OR part.
        """
        return 0.5 * self.Iand, 0.5 * self.Ior

    def get_and_interaction(self):
        """Return the full (unhalved) AND interaction vector."""
        return self.Iand

    def get_or_interaction(self):
        """Return the full (unhalved) OR interaction vector."""
        return self.Ior


class AndOrHarsanyiSparsifier(object):
    """Learn a sparse split of subset rewards into AND and OR Harsanyi interactions.

    The learnable parameters are trained by ``train_p`` / ``train_p_q`` /
    ``train_p_q_a``. They are then combined with the rewards in
    ``_calculate_interaction``:

    * ``trick="p"``   -- AND part ``0.5*v + p``, OR part ``0.5*v - p``;
    * ``trick="pq"``  -- ``0.5*(v + q) +/- p``, where ``q`` is trained with
      bound ``qbound`` (see ``_init_q_bound``);
    * ``trick="pqa"`` -- ``0.5*(v - a[ids] + q) +/- p``, where ``ids`` is
      ``calculator["rewards_mean_ids"]``.

    ``calculator`` is a dict of precomputed tensors. Which keys are read depends
    on ``trick`` and ``qstd``. Examples are ``"rewards"``, ``"masks"``,
    ``"reward2Iand"``, ``"reward2Ior"``, ``"v_N"``, ``"v_empty"``,
    ``"rewards_mean_ids"``, ``"rewards_mean_k"`` and ``"idx"``.
    """

    def __init__(
            self,
            calculator: dict,
            trick: str,
            loss: str,
            qthres: float,
            qstd: str,
            lr: float,
            niter: int,
            weight: int,
            alr: int,
    ):
        self.calculator = calculator
        self.trick = trick  # "p", "pq" or "pqa"
        self.loss = loss  # forwarded as loss_type to the train_* helpers
        self.qthres = qthres  # relative (or absolute, if qstd == "none") bound on q
        self.qstd = qstd  # which quantity q's bound is relative to
        self.lr = lr
        self.niter = niter
        self.weight = weight
        self.alr = alr

        self.p = None
        self.q = None
        self.a = None
        self.q_bound = None
        self.standard = None

    def _init_q_bound(self):
        """Set ``self.q_bound`` (and ``self.standard``) according to ``self.qstd``.

        With ``qstd == "none"``, the bound is ``qthres`` itself. Otherwise it is
        ``qthres * |standard|``, where ``standard`` is derived from the
        calculator as selected by ``qstd``.

        Raises:
            NotImplementedError: for an unknown ``qstd``.
        """
        self.standard = None
        if self.qstd == "none":
            self.q_bound = self.qthres
            return

        if self.qstd == "vS":
            standard = self.calculator["rewards"].clone()
        elif self.qstd == "vS-v0":
            standard = self.calculator["rewards"] - self.calculator["v_empty"]
        elif self.qstd == "vN":
            standard = self.calculator["v_N"].clone()
        elif self.qstd == "vN-v0":
            standard = self.calculator["v_N"] - self.calculator["v_empty"]
        elif self.qstd == "maxvS":
            standard, _ = torch.max(torch.abs(self.calculator["rewards"]), dim=0)
        elif self.qstd == "maxvS-v0":
            standard = torch.max(torch.abs(self.calculator["rewards"] - self.calculator["v_empty"]), dim=0)[0]
        elif self.qstd == "maxIs_mean":
            standard = self.calculator["max_strength_mean"].clone()
        elif self.qstd == "vN_vEmpty_mean":
            standard = self.calculator["vN_vEmpty_mean"].clone()
        else:
            raise NotImplementedError(f"Invalid qstd: {self.qstd}")

        self.standard = torch.abs(standard)
        self.q_bound = self.qthres * self.standard

    def sparsify(self, verbose_folder=None):
        """Run the optimisation selected by ``self.trick`` and compute I^and / I^or.

        Args:
            verbose_folder: if given, loss curves, progress plots and a
                ``log.txt`` summary are written to this folder.

        Raises:
            NotImplementedError: if ``self.trick`` is not "p", "pq" or "pqa".
        """
        if self.trick == "p":
            # The fourth output was previously bound to an unused name, which
            # made the verbose path below fail with a NameError.
            p, losses, progresses, progresses_param = train_p(
                rewards=self.calculator["rewards"],
                loss_type=self.loss,
                lr=self.lr,
                niter=self.niter
            )
            self.p = p.clone()
        elif self.trick == "pq":
            self._init_q_bound()
            p, q, losses, progresses, progresses_param = train_p_q(
                rewards=self.calculator["rewards"],
                masks=self.calculator["masks"],
                weight=self.weight,
                loss_type=self.loss,
                lr=self.lr,
                niter=self.niter,
                qbound=self.q_bound,
                reward2Iand=self.calculator["reward2Iand"],
                reward2Ior=self.calculator["reward2Ior"]
            )
            self.p = p.clone()
            self.q = q.clone()
        elif self.trick == "pqa":
            self._init_q_bound()
            p, q, a, losses, progresses, progresses_param = train_p_q_a(
                rewards=self.calculator["rewards"],
                alr=self.alr,
                weight=self.weight,
                rewards_mean_ids=self.calculator["rewards_mean_ids"],
                rewards_mean_k=self.calculator["rewards_mean_k"],
                masks=self.calculator["masks"],
                loss_type=self.loss,
                lr=self.lr,
                niter=self.niter,
                qbound=self.q_bound,
                reward2Iand=self.calculator["reward2Iand"],
                reward2Ior=self.calculator["reward2Ior"]
            )
            self.p = p.clone()
            self.q = q.clone()
            self.a = a.clone()
        else:
            raise NotImplementedError(f"Invalid trick: {self.trick}")

        self._calculate_interaction()

        if verbose_folder is None:
            return

        self._plot_progress(verbose_folder, losses, progresses, progresses_param)
        self._write_log(verbose_folder, losses)

    @staticmethod
    def _sample_concepts(num_concepts, seed, k=10):
        """Return ``min(k, num_concepts)`` distinct indices drawn reproducibly from ``seed``.

        A private ``random.Random(seed)`` draws the same indices as
        ``random.seed(seed)`` followed by ``random.sample``, but leaves the
        global RNG untouched. It also no longer fails when there are fewer
        than ``k`` concepts.
        """
        return random.Random(seed).sample(range(num_concepts), min(k, num_concepts))

    def _plot_progress(self, verbose_folder, losses, progresses, progresses_param):
        """Plot loss curves, interaction progress and the p/q/a trajectories."""
        for name, curve in losses.items():
            plot_simple_line_chart(
                data=curve, xlabel="iteration", ylabel=f"{name}", title="",
                save_folder=verbose_folder, save_name=f"{name}_curve_optimize_p_q_a"
            )
        for name, progress in progresses.items():
            plot_interaction_progress(
                interaction=progress, save_path=osp.join(verbose_folder, f"{name}_progress_optimize_p_q.png"),
                order_cfg="descending", title=f"{name} progress during optimization"
            )

        # (Originally added by zhouhuilin) Randomly pick up to 10 concepts and
        # plot how their p/q values evolve during training.
        chosen_concept_idx = self._sample_concepts(
            self.calculator["rewards"].shape[0], int(self.calculator["idx"])
        )
        for name, trajectory in progresses_param.items():
            if name in ("p", "q"):
                plot_pq_progress(
                    pq=trajectory, chosen_concept_idx=chosen_concept_idx, save_path=verbose_folder,
                    pq_type=f"{name}")
            else:
                plot_a_progress(a=trajectory, save_path=verbose_folder, pq_type=f"{name}")

    def _write_log(self, verbose_folder, losses):
        """Write a plain-text summary of the configuration and results to ``log.txt``."""
        os.makedirs(verbose_folder, exist_ok=True)
        sum_and = torch.sum(self.Iand)
        sum_or = torch.sum(self.Ior)
        # Parenthesised so .item() applies to the whole sum (a Python float).
        abs_total = (torch.sum(torch.abs(self.Iand)) + torch.sum(torch.abs(self.Ior))).item()
        with open(osp.join(verbose_folder, "log.txt"), "w") as f:
            f.write(f"trick: {self.trick} | loss: {self.loss} | lr: {self.lr} | niter: {self.niter}\n")
            f.write(f"for [q] -- threshold: {self.qthres} | standard: {self.qstd}\n")
            # q_bound is None for trick "p", a plain number when qstd == "none",
            # and a tensor otherwise; only tensors have .numel().
            if self.q_bound is not None and (not torch.is_tensor(self.q_bound) or self.q_bound.numel() < 20):
                f.write(f"\t[q] bound: {self.q_bound}\n")
            f.write(f"\tSum of I^and and I^or: {sum_and + sum_or}\n")
            f.write(f"\tSum of I^and: {sum_and}\n")
            f.write(f"\tSum of I^or: {sum_or}\n")
            f.write(f"\t|I^and|+|I^or|: {abs_total}\n")
            f.write("\tDuring optimizing,\n")
            for name, history in losses.items():
                f.write(f"\t\t{name}: {history[0]} -> {history[-1]}\n")

    def _calculate_interaction(self):
        """Map the learned reward split to detached ``self.Iand`` / ``self.Ior``.

        Raises:
            NotImplementedError: if ``self.trick`` is not "p", "pq" or "pqa".
        """
        rewards = self.calculator["rewards"]
        # `shared` is the part of the reward split equally between AND and OR;
        # p is then added to the AND side and subtracted from the OR side.
        if self.trick == "p":
            shared = 0.5 * rewards
        elif self.trick == "pq":
            shared = 0.5 * (rewards + self.q)
        elif self.trick == "pqa":
            offsets = self.a[self.calculator["rewards_mean_ids"].to(torch.int64)]
            shared = 0.5 * ((rewards - offsets) + self.q)
        else:
            raise NotImplementedError(f"Invalid trick: {self.trick}")

        self.Iand = torch.matmul(self.calculator["reward2Iand"], shared + self.p).detach()
        self.Ior = torch.matmul(self.calculator["reward2Ior"], shared - self.p).detach()

    def save(self, save_folder):
        """Save I^and, I^or, p and, when learned, q (``q.npy``) and a (``a_k.npy``)."""
        os.makedirs(save_folder, exist_ok=True)
        Iand = self.Iand.cpu().numpy()
        Ior = self.Ior.cpu().numpy()
        np.save(osp.join(save_folder, "Iand.npy"), Iand)
        np.save(osp.join(save_folder, "Ior.npy"), Ior)
        p = self.p.cpu().numpy()
        np.save(osp.join(save_folder, "p.npy"), p)
        if self.q is not None:
            q = self.q.cpu().numpy()
            np.save(osp.join(save_folder, "q.npy"), q)
        if self.a is not None:
            a = self.a.cpu().numpy()
            np.save(osp.join(save_folder, "a_k.npy"), a)

    def get_interaction(self):
        """Return the sparsified ``(I^and, I^or)`` pair."""
        return self.Iand, self.Ior
