from abc import ABC, abstractmethod
from pathlib import Path

import numpy as np
import torch
from joblib import Parallel, delayed
from tqdm import tqdm

from graphsmodel.training import train_subset
from graphsmodel.utils import get_margin_incorrect_vectorized


class NodeInfluence(ABC):
    """Base class for experiments that retrain on subsets of training nodes.

    For every ``k`` in ``1..top_k`` and every metric in the results dict, a
    subclass-defined subset is built from the metric's ranking (``argsrt``).
    A model is retrained on it via :meth:`compute_subset_res`, and the
    fraction of selected nodes with a positive margin is recorded.

    Subclasses must implement :meth:`compute_subset_res` and set
    ``self.filename``, the stem of the ``.pt`` file written by
    :meth:`run_experiment`.
    """

    def __init__(self, cfg, data, top_k, train_signal_value, n_samples, results_dir):
        """Store the experiment configuration.

        Args:
            cfg: Training configuration; ``cfg.n_jobs`` sets the number of
                joblib workers used by :meth:`run_experiment`.
            data: Dataset object; ``data.y`` is used as the labels when
                ``train_signal_value`` is truthy.
            top_k: Largest subset size ``k`` evaluated (inclusive).
            train_signal_value: Selects which tuple layout
                :meth:`compute_subset_res` returns (see
                :meth:`compute_perf_on_subset`).
            n_samples: Length of the boolean subset mask built by subclasses.
            results_dir: Directory the results file is saved into; a ``str``
                or a path object supporting ``/``.
        """
        self.cfg = cfg
        self.data = data
        self.top_k = top_k
        self.train_signal_value = train_signal_value
        self.results_dir = results_dir
        self.n_samples = n_samples

    @abstractmethod
    def compute_subset_res(self, argsrt, k):
        """Retrain on the subset derived from the first ``k`` entries of ``argsrt``.

        Must return a 4-tuple whose last item is the logits when
        ``train_signal_value`` is truthy. Otherwise it must return a 3-tuple
        ``(_, logits, labels)``.
        """

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a host numpy array if it is a torch tensor, else unchanged."""
        if isinstance(value, torch.Tensor):
            # Tensors that require grad or live on an accelerator cannot be
            # consumed by numpy directly.
            return value.detach().cpu().numpy()
        return value

    def compute_perf_on_subset(self, k, metric_name, argsrt, mask):
        """Retrain on the ``k``-th subset and score it on the nodes in ``mask``.

        Args:
            k: Number of ranked nodes taken from ``argsrt``.
            metric_name: Key of the ranking being evaluated; it is returned
                unchanged so parallel results can be routed back.
            argsrt: Node indices in the metric's ranking order.
            mask: Selection of evaluation nodes, applied along axis 1 of the
                margins.

        Returns:
            ``(metric_name, perf)``. ``perf`` is the fraction of selected,
            finite margins that are strictly positive, reduced over the last
            axis as a numpy masked-array result.
        """
        if self.train_signal_value:
            _, _, _, logits = self.compute_subset_res(argsrt, k)
            y = self.data.y
        else:
            _, logits, y = self.compute_subset_res(argsrt, k)

        margins = get_margin_incorrect_vectorized(logits.unsqueeze(0), y).mean(1)
        # NaN/inf margins are masked so they are excluded from the mean below.
        margins_ma = np.ma.masked_invalid(self._to_numpy(margins))
        perf_margins = margins_ma[:, self._to_numpy(mask)]
        perf = (perf_margins > 0).mean(-1)
        return metric_name, perf

    def _results_path(self):
        """Return the path of the results file.

        Raises:
            AttributeError: If the subclass did not set ``self.filename``.
        """
        filename = getattr(self, "filename", None)
        if filename is None:
            raise AttributeError(
                f"{type(self).__name__} must set `self.filename` before running"
            )
        results_dir = self.results_dir
        if isinstance(results_dir, str):
            results_dir = Path(results_dir)
        return results_dir / f"{filename}.pt"

    def run_experiment(self, results_dict):
        """Evaluate every metric for every ``k`` in parallel and save the results.

        Args:
            results_dict: Mapping ``metric_name -> dict`` where each inner dict
                holds ``"argsrt"`` and ``"mask"``. Per-``k`` scores are
                appended to its ``"perf"`` list, which is created if missing.

        Returns:
            The same ``results_dict``, mutated in place and also written to
            ``results_dir / f"{filename}.pt"`` with ``torch.save``.
        """
        # Resolve the output path before the expensive retraining, so a
        # misconfigured subclass fails immediately rather than after all jobs.
        results_path = self._results_path()

        results = Parallel(n_jobs=self.cfg.n_jobs)(
            delayed(self.compute_perf_on_subset)(
                k, metric_name, metric_dict["argsrt"], metric_dict["mask"]
            )
            # The progress bar tracks job dispatch rather than job completion.
            for k in tqdm(range(1, self.top_k + 1))
            for metric_name, metric_dict in results_dict.items()
        )
        # joblib returns results in submission order, so each metric's
        # scores are appended in ascending k.
        for metric_name, perf in results:
            results_dict[metric_name].setdefault("perf", []).append(perf)

        torch.save(results_dict, results_path)
        return results_dict


class MostInfluentialRemoval(NodeInfluence):
    """Retrain after removing the first ``k`` nodes of a ranking.

    Every sample starts enabled in the subset mask. The ``k`` leading indices
    of ``argsrt`` (presumably the most influential nodes) are then disabled.
    """

    def __init__(
        self,
        cfg,
        data,
        top_k,
        train_signal_value,
        n_samples,
        results_dir,
        filename_prefix="",
    ):
        """Configure the experiment and derive ``self.filename``.

        The stem is ``most_influential_removal_results``. It is preceded by
        ``train_signal_`` when ``train_signal_value`` is truthy, and by
        ``"<filename_prefix>_"`` when a non-empty prefix is given.
        """
        super().__init__(cfg, data, top_k, train_signal_value, n_samples, results_dir)
        stem = "most_influential_removal_results"
        if self.train_signal_value:
            stem = "train_signal_" + stem
        self.filename = (
            filename_prefix + "_" + stem if filename_prefix != "" else stem
        )

    def subset_topk_removal(self, subset, argsrt, k):
        """Disable the first ``k`` ranked nodes in ``subset`` (in place) and retrain."""
        removed = argsrt[:k].tolist()
        subset[removed] = False
        result = train_subset(
            subset_idx=k,
            subset=subset,
            cfg=self.cfg,
            data=self.data,
            logits_on_data=self.train_signal_value,
        )
        return result

    def compute_subset_res(self, argsrt, k):
        """Retrain on all nodes except the first ``k`` entries of ``argsrt``."""
        full_mask = np.ones(self.n_samples, dtype=bool)
        return self.subset_topk_removal(full_mask, argsrt, k)


class MostInfluentialAddition(NodeInfluence):
    """Retrain on only the first ``k`` nodes of a ranking.

    The subset mask starts all-``False``. The ``k`` leading indices of
    ``argsrt`` (presumably the most influential nodes) are then enabled.
    """

    def __init__(
        self,
        cfg,
        data,
        top_k,
        train_signal_value,
        n_samples,
        results_dir,
        filename_prefix="",
    ):
        """Configure the experiment and derive ``self.filename``.

        The stem is ``most_influential_addition_results``. It is preceded by
        ``train_signal_`` when ``train_signal_value`` is truthy, and by
        ``"<filename_prefix>_"`` when a non-empty ``filename_prefix`` is
        given. The prefix option mirrors ``MostInfluentialRemoval``; its
        default ``""`` keeps the previous filename.
        """
        super().__init__(cfg, data, top_k, train_signal_value, n_samples, results_dir)
        self.filename = "most_influential_addition_results"
        if self.train_signal_value:
            self.filename = "train_signal_" + self.filename
        if filename_prefix != "":
            self.filename = filename_prefix + "_" + self.filename

    def subset_topk_addition(self, subset, argsrt, k):
        """Enable the first ``k`` ranked nodes in ``subset`` (in place) and retrain."""
        subset[argsrt[:k].tolist()] = True
        return train_subset(
            subset_idx=k,
            subset=subset,
            cfg=self.cfg,
            data=self.data,
            logits_on_data=self.train_signal_value,
        )

    def compute_subset_res(self, argsrt, k):
        """Retrain on exactly the first ``k`` entries of ``argsrt``."""
        return self.subset_topk_addition(
            np.zeros(self.n_samples, dtype=bool), argsrt, k
        )


class LeastInfluentialAddition(NodeInfluence):
    """Retrain on only the first ``k`` nodes of a ranking.

    The subset mask starts all-``False``. The ``k`` leading indices of
    ``argsrt`` are then enabled. Presumably the caller passes a ranking
    ordered from least to most influential; verify against the caller.
    """

    def __init__(
        self,
        cfg,
        data,
        top_k,
        train_signal_value,
        n_samples,
        results_dir,
        filename_prefix="",
    ):
        """Configure the experiment and derive ``self.filename``.

        The stem is ``least_influential_addition_results``. It is preceded by
        ``train_signal_`` when ``train_signal_value`` is truthy, and by
        ``"<filename_prefix>_"`` when a non-empty ``filename_prefix`` is
        given. The prefix option mirrors ``MostInfluentialRemoval``; its
        default ``""`` keeps the previous filename.
        """
        super().__init__(cfg, data, top_k, train_signal_value, n_samples, results_dir)
        self.filename = "least_influential_addition_results"
        if self.train_signal_value:
            self.filename = "train_signal_" + self.filename
        if filename_prefix != "":
            self.filename = filename_prefix + "_" + self.filename

    def subset_topk_addition(self, subset, argsrt, k):
        """Enable the first ``k`` ranked nodes in ``subset`` (in place) and retrain."""
        subset[argsrt[:k].tolist()] = True
        return train_subset(
            subset_idx=k,
            subset=subset,
            cfg=self.cfg,
            data=self.data,
            logits_on_data=self.train_signal_value,
        )

    def compute_subset_res(self, argsrt, k):
        """Retrain on exactly the first ``k`` entries of ``argsrt``."""
        return self.subset_topk_addition(
            np.zeros(self.n_samples, dtype=bool), argsrt, k
        )


class LeastInfluentialRemoval(NodeInfluence):
    """Retrain after removing the first ``k`` nodes of a ranking.

    Every sample starts enabled in the subset mask. The ``k`` leading indices
    of ``argsrt`` are then disabled. Presumably the caller passes a ranking
    ordered from least to most influential; verify against the caller.
    """

    def __init__(
        self,
        cfg,
        data,
        top_k,
        train_signal_value,
        n_samples,
        results_dir,
        filename_prefix="",
    ):
        """Configure the experiment and derive ``self.filename``.

        The stem is ``least_influential_removal_results``. It is preceded by
        ``train_signal_`` when ``train_signal_value`` is truthy, and by
        ``"<filename_prefix>_"`` when a non-empty ``filename_prefix`` is
        given. The prefix option mirrors ``MostInfluentialRemoval``; its
        default ``""`` keeps the previous filename.
        """
        super().__init__(cfg, data, top_k, train_signal_value, n_samples, results_dir)
        self.filename = "least_influential_removal_results"
        if self.train_signal_value:
            self.filename = "train_signal_" + self.filename
        if filename_prefix != "":
            self.filename = filename_prefix + "_" + self.filename

    def subset_topk_removal(self, subset, argsrt, k):
        """Disable the first ``k`` ranked nodes in ``subset`` (in place) and retrain."""
        subset[argsrt[:k].tolist()] = False
        return train_subset(
            subset_idx=k,
            subset=subset,
            cfg=self.cfg,
            data=self.data,
            logits_on_data=self.train_signal_value,
        )

    def compute_subset_res(self, argsrt, k):
        """Retrain on all nodes except the first ``k`` entries of ``argsrt``."""
        return self.subset_topk_removal(np.ones(self.n_samples, dtype=bool), argsrt, k)
