from abc import ABC, abstractmethod
from typing import Generator, NamedTuple

import itertools
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import shapiq
import torch
from matplotlib import pyplot as plt
from torch.nn import functional as F
from tqdm import tqdm
from utils.bits import pack_more_bits, unpack_more_bits
from utils.segmentations import minimal_bounding_box
from typing import Literal

import wandb

import logging
log = logging.getLogger(__name__)

# Allowed values for the masking-strategy option.
# NOTE(review): not referenced anywhere else in the visible part of this file.
MASK_STRATEGY = Literal["original", "attention"]
# Allowed values for the background-handling option. In get_player_names,
# "combine" labels the last player "Background" and "concat" appends " (BG)"
# to the players from bg_start_idx onwards; the other values leave the
# default "Region i" names untouched.
BG_STRATEGY = Literal["ignore", "combine", "concat", "bgonly"]


class BreakdownItem(NamedTuple):
    """One step of a greedy prediction breakdown (see shnapCoalitionEvaluator.breakdown)."""
    # 0-based index of the region added at this step; None for the starting
    # entry, which corresponds to the empty coalition.
    region_id: int | None
    # Value of the selected output once this region has been added.
    value: float
    # Change of the value relative to the previous step (0 for the starting entry).
    delta: float


@torch.no_grad()
def batched(iterator: Generator, batch_size: int = 16, drop_last: bool = False) -> Generator[tuple[torch.Tensor, ...], None, None]:
    """
    Regroup a stream of tensor tuples into tuples of stacked tensors.

    Every item produced by ``iterator`` is a tuple of tensors. The k-th
    tensors of ``batch_size`` consecutive items are stacked along a new
    leading dimension, and one tuple of stacked tensors is yielded per batch.

    Args:
        iterator: Source of tuples of tensors; all tuples are expected to
            have the same length.
        batch_size: Number of items collected before a batch is emitted.
        drop_last: When True, a trailing batch with fewer than ``batch_size``
            items is discarded instead of being yielded.

    Yields:
        Tuples of stacked tensors, one entry per position of the input tuples.
    """
    columns = []
    for sample in iterator:
        # Lazily create one accumulator list per tuple position.
        if not columns:
            columns = [[] for _ in range(len(sample))]

        for position, element in enumerate(sample):
            columns[position].append(element)

        if len(columns[0]) != batch_size:
            continue

        yield tuple([torch.stack(column) for column in columns])
        columns = []

    if drop_last:
        return

    # Flush whatever is left over as a final, smaller batch.
    if columns and columns[0]:
        yield tuple([torch.stack(column) for column in columns])


def get_player_names(n_players: int, bg_strategy: BG_STRATEGY, bg_start_idx: int = float("inf")) -> list[str]:
    """
    Build display names for the players (regions) of a game.

    Args:
        n_players: Number of players; names default to "Region 0" ... "Region n-1".
        bg_strategy: Background handling. "combine" renames the last player
            to "Background"; "concat" appends " (BG)" to every player from
            ``bg_start_idx`` onwards; any other value keeps the default names.
        bg_start_idx: Index of the first background player, only used for
            "concat". The default (infinity) means no player is background.

    Returns:
        A list of ``n_players`` names.
    """
    player_names = [f"Region {i}" for i in range(n_players)]
    if not player_names:
        # Nothing to rename; also avoids indexing [-1] on an empty list.
        return player_names

    if bg_strategy == "combine":
        player_names[-1] = "Background"
    elif bg_strategy == "concat":
        # Slice bounds must be integers, but the default is float("inf").
        # Anything at or past the end means "no background players".
        first_bg = n_players if bg_start_idx >= n_players else int(bg_start_idx)
        player_names[first_bg:] = [
            f"{name} (BG)" for name in player_names[first_bg:]
        ]
    return player_names


class shnapCoalitionEvaluator():
    """
    Evaluates and caches classifier outputs for coalitions of image regions.

    A coalition is a 0/1 vector with one entry per region ("player"). Regions
    whose bit is 1 are taken from the original image, all remaining voxels
    come from the inpainted image. The classifier output of every evaluated
    coalition is stored in ``self.cache`` at the row given by
    ``pack_more_bits(coalition)``; ``self.cache_mask`` marks which rows hold
    real results.
    """

    def __init__(
            self,
            classifier: torch.nn.Module,
            original: torch.Tensor,
            inpaint: torch.Tensor,
            region_seg: torch.Tensor,
            batch_size: int = 1,
            bg_start_idx: int = float("inf"),
            bg_strategy: BG_STRATEGY = "ignore"
        ):
        """
        Args:
            classifier: Model exposing ``num_years`` and ``forward_all_years``.
            original: Original image, (1, 1, W, H, D).
            inpaint: Inpainted image, (1, 1, W, H, D).
            region_seg: (W, H, D) label map; label ``i`` (1-based) marks
                player ``i - 1`` and 0 is treated as "no region".
            batch_size: Number of coalitions sent to the classifier at once.
            bg_start_idx: Stored for consumers such as shnapGame; not used here.
            bg_strategy: Stored for consumers such as shnapGame; not used here.
        """
        self.classifier = classifier
        self.original = original.detach() # (1, 1, W, H, D)
        self.inpaint = inpaint.detach() # (1, 1, W, H, D)
        self.region_seg = region_seg.detach() # (W, H, D) Multi-class segmentation mask of regions
        self.batch_size = batch_size
        self.bg_start_idx = bg_start_idx
        self.bg_strategy = bg_strategy

        self.n_players = self.region_seg.max().int().item() # Number of regions
        self.n_values = self.classifier.num_years

        # One row per possible coalition; memory grows as 2**n_players.
        # Both tensors are created with torch.zeros on the default device.
        self.cache_mask = torch.zeros(2**self.n_players)
        self.cache = torch.zeros(2**self.n_players, self.n_values)

        # bounding_boxes[p] belongs to the region labelled p + 1.
        self.bounding_boxes = [
            minimal_bounding_box(self.region_seg == i) for i in range(1, self.n_players + 1)
        ]

    @torch.no_grad()
    def create_mask(self, coalition: np.ndarray) -> torch.Tensor:
        """
        Build the boolean voxel mask for one coalition.

        True means "use the original image", False means "use the inpaint".
        Only voxels carrying the label of a player that is part of the
        coalition are set, so an overlapping bounding box can no longer pull
        in voxels of a region that is absent from the coalition.

        Args:
            coalition: 1D 0/1 array with one entry per player.

        Returns:
            Boolean tensor with the shape of ``region_seg``.
        """
        mask = torch.zeros_like(self.region_seg, dtype=torch.bool)
        for player, bit in enumerate(coalition):
            if bit != 1:
                continue
            bbox = self.bounding_boxes[player]
            # Players are 0-based while region labels are 1-based. OR-ing keeps
            # voxels that an earlier player already switched on.
            mask[bbox] |= self.region_seg[bbox] == (player + 1)
        return mask

    @torch.no_grad()
    def prepare_input(self, coalition: np.ndarray):
        """
        Compose the classifier input for one coalition.

        Returns:
            Tensor of shape (1, 3, W, H, D): the inpainted image with the
            coalition's regions restored from the original, repeated over
            three channels.
        """
        assert (
            coalition.shape[0] == self.n_players
        ), "Coalition must have the same number of players as the evaluator"
        mask = self.create_mask(coalition)
        mask = mask.unsqueeze(0).unsqueeze(0)  # (1, 1, W, H, D)
        input_tensor = self.inpaint.clone()
        input_tensor[mask] = self.original[mask]  # (1, 1, W, H, D)
        input_tensor = input_tensor.repeat(1, 3, 1, 1, 1) # Repeat to match classifier input shape
        return input_tensor

    @torch.no_grad()
    def image_iterator(self, coalitions: np.ndarray):
        """
        Yield ``(cache_index, image)`` pairs for every row of ``coalitions``.

        The image has its leading batch dimension removed because ``batched``
        stacks the items again later.
        """
        assert coalitions.ndim == 2, "Coalitions must be a 2D array"
        coalitions_idxs = pack_more_bits(coalitions).astype(np.int64)
        coalitions_idxs = torch.from_numpy(coalitions_idxs)
        for idx, coalition in zip(coalitions_idxs, coalitions):
            input_img = self.prepare_input(coalition)
            yield idx, input_img.squeeze(0) # it will be batched later


    @torch.inference_mode()
    def compute_values(self, coalitions: np.ndarray):
        """
        Evaluate the classifier for all coalitions that are not cached yet.

        Results are written into ``self.cache`` and flagged in
        ``self.cache_mask``. Nothing is returned.

        Args:
            coalitions: (N, n_players) 0/1 array.
        """
        coalition_idxs = torch.from_numpy(pack_more_bits(coalitions).astype(np.int64))
        not_cached = (self.cache_mask[coalition_idxs] == 0).cpu().numpy()
        active_coalitions = coalitions[not_cached]
        if len(active_coalitions) == 0:
            return None
        assert active_coalitions.shape[1] == self.n_players, "Coalitions must have the same number of players as the evaluator"

        # Ceil division: the trailing partial batch is evaluated as well.
        n_batches = (len(active_coalitions) + self.batch_size - 1) // self.batch_size
        batches = batched(self.image_iterator(active_coalitions), batch_size=self.batch_size)
        for batch_idx, batch_imgs in tqdm(batches, total=n_batches, desc="V-values computation"):
            results = self.classifier.forward_all_years(batch_imgs)
            # The classifier may run on another device / dtype than the cache.
            self.cache[batch_idx] = torch.as_tensor(
                results, dtype=self.cache.dtype, device=self.cache.device
            )
            self.cache_mask[batch_idx] = 1


    @torch.inference_mode()
    def get_values(self, coalitions: np.ndarray, value_idx: int) -> np.ndarray:
        """
        Get the values for the given coalitions.
        If the values are not cached, compute them and cache them.

        Args:
            coalitions: (N, n_players) array; cast to int before use.
            value_idx: Column of the classifier output to return.

        Returns:
            (N,) numpy array with the selected output of every coalition.
        """
        coalitions = coalitions.astype(int)
        assert coalitions.shape[1] == self.n_players, "Coalitions must have the same number of players as the evaluator"
        assert value_idx < self.n_values, "Value index must be less than the number of values"
        self.compute_values(coalitions)

        indices = pack_more_bits(coalitions).astype(np.int64)
        indices = torch.from_numpy(indices)
        values = self.cache[indices, value_idx].cpu().numpy()
        return values


    @torch.inference_mode()
    def get_predictions(self) -> np.ndarray:
        """
        Get the predictions for the marginal coalitions.

        Returns:
            (n_players + 1, n_values) array: cache row 0 first, followed by
            the cache rows ``2**(n_players - 1 - i)`` for ``i = 0..n_players-1``.
        """
        # NOTE(review): presumably pack_more_bits stores player 0 in the most
        # significant bit, so row i + 1 is "only player i restored" — confirm.
        marginal_indices = [0] + [
            2**(self.n_players - 1 - i) for i in range(self.n_players)
        ]
        marginal_indices = np.array(marginal_indices, dtype=np.uint32)
        coalitions = unpack_more_bits(marginal_indices, self.n_players)
        self.compute_values(coalitions)
        marginal_indices = marginal_indices.astype(np.int64) # For torch
        return self.cache[marginal_indices, :].cpu().numpy()


    @torch.inference_mode()
    def get_all_values_as_df(self):
        """
        Returns a DataFrame with all cached values.

        Columns: "Index" (cache row), one "Region_i" column per player holding
        "Original" or "Inpaint", then "Base risk" followed by "Year_1",
        "Year_2", ... with the cached classifier outputs. Index and value
        columns keep their numeric dtype.
        """
        # flatten() instead of squeeze(): with exactly one cached entry,
        # squeeze() yields a 0-d tensor whose tolist() is a bare int.
        cached = self.cache_mask.nonzero(as_tuple=False).flatten()
        indices = np.sort(cached.cpu().numpy()).astype(np.uint32)
        coalitions = unpack_more_bits(indices, self.n_players).astype(np.int64)

        indices_torch = torch.from_numpy(indices.astype(np.int64))
        values = self.cache[indices_torch].cpu().numpy()

        coalition_columns = [f"Region_{i}" for i in range(self.n_players)]
        values_columns = [f"Year_{i}" for i in range(self.n_values)]
        values_columns[0] = "Base risk"

        # Assemble per-group frames so that numbers are not coerced to strings
        # by stacking them next to the "Original"/"Inpaint" labels.
        df = pd.concat(
            [
                pd.DataFrame({"Index": indices.astype(np.int64)}),
                pd.DataFrame(
                    np.where(coalitions == 1, "Original", "Inpaint"),
                    columns=coalition_columns,
                ),
                pd.DataFrame(values, columns=values_columns),
            ],
            axis=1,
        )
        return df


    @torch.inference_mode()
    def sv_quality(self, sv_values: np.ndarray, baseline: np.ndarray) -> np.ndarray:
        """
        Compute the quality (R^2) of the additive Shapley approximation.

        For every coalition S the approximation is
        ``baseline + sum of sv_values over the players in S``. It is compared
        with the cached classifier outputs, using only coalitions that have
        actually been evaluated (``cache_mask`` set).

        sv_values: (n_players, n_values) array of Shapley values
        baseline: (n_values,) array of baseline values

        Returns:
            (n_values,) array with one R^2 per output. Entries are nan/inf
            when nothing is cached or the cached outputs have zero variance.
        """
        if self.n_players <= 1:
            return np.array([1.0] * self.n_values)

        device, dtype = self.cache.device, self.cache.dtype
        sv = torch.from_numpy(sv_values).to(device=device, dtype=dtype)  # (n_players, n_values)
        base = torch.from_numpy(baseline).to(device=device, dtype=dtype).reshape(1, -1)

        # Enumerate every coalition, including the empty and the full one, and
        # place its additive prediction at the same row the cache uses.
        all_coalitions = np.array(
            list(itertools.product((0, 1), repeat=self.n_players)), dtype=np.int8
        )
        coalition_idxs = torch.from_numpy(pack_more_bits(all_coalitions).astype(np.int64))
        membership = torch.from_numpy(all_coalitions).to(device=device, dtype=dtype)

        sv_predicted_values = torch.zeros_like(self.cache)
        sv_predicted_values[coalition_idxs] = membership @ sv + base

        # Rows that were never evaluated still hold the initial zeros and must
        # not be scored as if they were classifier outputs.
        evaluated = self.cache_mask.bool()
        observed = self.cache[evaluated]
        predicted = sv_predicted_values[evaluated]

        y_mean = observed.nanmean(axis=0)
        sr = (observed - predicted) ** 2
        st = (observed - y_mean) ** 2
        r2 = 1 - sr.nansum(axis=0) / st.nansum(axis=0)
        return r2.cpu().numpy()


    @torch.inference_mode()
    def breakdown(self, value_idx: int) -> list[BreakdownItem]:
        """
        Returns a greedy breakdown of the prediction.

        Starting from the empty coalition, each step adds the not-yet-used
        region whose addition changes the selected output the most (largest
        absolute change; ties go to the lowest region index).

        Returns:
            List of BreakdownItem: first the starting value with
            ``region_id=None`` and ``delta=0``, then one item per added region
            with the new value and its change relative to the previous step.
        """
        current_coalition = np.zeros(self.n_players, dtype=int)
        initial_value = float(
            self.get_values(current_coalition.reshape(1, -1), value_idx=value_idx)[0]
        )
        breakdown = [BreakdownItem(None, initial_value, 0.0)]

        for _ in range(self.n_players):
            remaining = np.flatnonzero(current_coalition == 0)
            if remaining.size == 0:
                break

            # One candidate row per remaining region: current coalition + that region.
            candidates = np.tile(current_coalition, (remaining.size, 1))
            candidates[np.arange(remaining.size), remaining] = 1
            candidate_values = self.get_values(candidates, value_idx=value_idx)

            reference_value = breakdown[-1].value
            best = int(np.argmax(np.abs(candidate_values - reference_value)))
            best_region = int(remaining[best])
            best_value = float(candidate_values[best])

            breakdown.append(BreakdownItem(best_region, best_value, best_value - reference_value))
            current_coalition[best_region] = 1

        return breakdown



class shnapGame(shapiq.Game):
    """
    shapiq game whose players are the regions of a shnapCoalitionEvaluator.

    The worth of a coalition is one selected output (``value_idx``) of the
    evaluator for that coalition.
    """

    def __init__(self, evaluator: shnapCoalitionEvaluator, value_idx: int = 0):
        """
        Args:
            evaluator: Provides n_players, bg_strategy, bg_start_idx and get_values.
            value_idx: Index of the evaluator output used as the game value.
        """
        self.evaluator = evaluator
        self.value_idx = value_idx

        # Player names follow the evaluator's background configuration.
        self.player_names = get_player_names(
            evaluator.n_players,
            evaluator.bg_strategy,
            evaluator.bg_start_idx,
        )

        super().__init__(
            n_players=evaluator.n_players,
            player_names=self.player_names,
            normalization_value=0.0,
        )

    def value_function(self, coalitions):
        """
        Return the game value of each coalition.

        coalitions: (N, M) array of coalitions where N is the number of
        coalitions and M is the number of players.
        """
        return self.evaluator.get_values(coalitions, value_idx=self.value_idx)


class ShapIQPlotsFactory():
    """
    Creates the enabled shapiq plots and wraps them as ``wandb.Image`` objects.

    Each constructor flag switches one plot type on or off; all are off by default.
    """

    def __init__(
        self,
        force_plot: bool = False,
        stacked_bar_plot: bool = False,
        upset_plot: bool = False,
        waterfall_plot: bool = False,
        network_plot: bool = False,
        si_graph_plot: bool = False,
    ):
        self.force_plot = force_plot
        self.stacked_bar_plot = stacked_bar_plot
        self.upset_plot = upset_plot
        self.waterfall_plot = waterfall_plot
        self.network_plot = network_plot
        self.si_graph_plot = si_graph_plot

    def create_plots(self, interaction_values: shapiq.InteractionValues, player_names: list[str], prefix: str =""):
        """
        Render every enabled plot for ``interaction_values``.

        Args:
            interaction_values: Values to visualise.
            player_names: Feature names handed to the plotting functions.
            prefix: String prepended to every key of the returned dict.

        Returns:
            Dict mapping ``f"{prefix}<plot name>"`` to a ``wandb.Image``. The
            network and SI-graph plots are skipped (with a warning) when there
            is exactly one player.
        """
        plots_dict = {}

        def _register(plot_name, figure):
            # Wrap the figure for wandb, then close it to release matplotlib state.
            plots_dict[f"{prefix}{plot_name}"] = wandb.Image(figure)
            plt.close(figure)

        if self.force_plot:
            _register(
                "force_plot",
                interaction_values.plot_force(feature_names=player_names, show=False),
            )

        if self.stacked_bar_plot:
            stacked_fig, _ = interaction_values.plot_stacked_bar(feature_names=player_names, show=False)
            _register("stacked_bar_plot", stacked_fig)

        if self.upset_plot:
            _register(
                "upset_plot",
                interaction_values.plot_upset(feature_names=player_names, show=False),
            )

        if self.waterfall_plot:
            # NOTE(review): the logged figure is the one created here;
            # plot_waterfall presumably draws onto the current figure — confirm.
            waterfall_fig, _ = plt.subplots(1)
            interaction_values.plot_waterfall(feature_names=player_names, show=False)
            _register("waterfall_plot", waterfall_fig)

        wants_interaction_plots = self.network_plot or self.si_graph_plot
        if wants_interaction_plots and (len(player_names) == 1):
            log.warning("Only one player in the game, skipping interaction plots")
            return plots_dict

        if self.network_plot:
            network_fig, _ = shapiq.network_plot(
                first_order_values=interaction_values.get_n_order_values(1),
                second_order_values=interaction_values.get_n_order_values(2),
                feature_names=player_names,
                show=False
            )
            _register("network_plot", network_fig)

        if self.si_graph_plot:
            si_fig, _ = interaction_values.plot_si_graph(feature_names=player_names, show=False)
            _register("si_plot", si_fig)

        return plots_dict


def grouped_data_bar_plot(
    data: torch.Tensor,
    y_label: str,
    title: str,
    regions_start_idx: int = 0,
    bg_start_idx: float = float("inf"),
    bg_strategy: BG_STRATEGY = "ignore"
):
    """
    A generic function to create a grouped bar plot for regions and different years.

    Args:
        data (torch.Tensor): A tensor of shape (N, M) where N is the number of rows
            (placeholder rows followed by regions) and M is the number of years.
            If data is a 1D tensor, it will be treated as a base (without year).
        y_label (str): The label for the y-axis. It also selects the y-range:
            symmetric around the data when it contains "shnap", [-1, 1] when it
            contains "delta", [0, 1] otherwise.
        title (str): The title of the plot.
        regions_start_idx (int): The index of the first region in the data. Rows
            before it are labelled "None".
        bg_start_idx (int): First background region, forwarded to get_player_names
            and therefore counted from the first region row. Infinity (the
            default) means there is no background region.
        bg_strategy: Background naming strategy forwarded to get_player_names.

    Returns:
        The plotly figure.
    """
    assert data.ndim in (1, 2), "Data must be a 1D or 2D tensor"
    assert regions_start_idx >= 0, "regions_start_idx must be non-negative"
    assert bg_start_idx >= regions_start_idx, "bg_start_idx must be greater than or equal to regions_start_idx"

    # Only the rows after the placeholders are real regions. Naming exactly
    # those makes "combine" put "Background" on the last data row instead of
    # on a name that is never displayed.
    n_regions = max(data.shape[0] - regions_start_idx, 0)
    if n_regions == 0:
        region_names = []
    else:
        # get_player_names slices with this value, so it has to be an int;
        # anything at or beyond the end means "no background regions".
        first_bg = n_regions if bg_start_idx >= n_regions else int(bg_start_idx)
        region_names = get_player_names(
            n_players=n_regions,
            bg_strategy=bg_strategy,
            bg_start_idx=first_bg
        )
    player_names = (["None"] * regions_start_idx) + region_names

    columns = ("Region", "Year", "Value")
    df_items = []

    if data.ndim == 1:
        for i, value in enumerate(data):
            df_items.append({
                "Region": player_names[i],
                "Year": "Base risk",
                "Value": value.item()
            })
    else:
        for i in range(data.shape[0]):
            for j in range(data.shape[1]):
                df_items.append({
                    "Region": player_names[i],
                    "Year": f"Year {j + 1}",
                    "Value": data[i, j].item()
                })

    df = pd.DataFrame(df_items, columns=columns)
    fig = px.bar(
        df,
        x="Region",
        y="Value",
        color="Year",
        barmode="group",
        title=title,
        labels={"Value": y_label, "Region": "Inpainted region"}
    )
    if "shnap" in y_label.lower():
        # Hand plotly plain floats rather than a 0-d tensor, and survive empty data.
        if df_items:
            magnitudes = data.abs() if isinstance(data, torch.Tensor) else np.abs(data)
            max_value = float(magnitudes.max()) + 0.1
        else:
            max_value = 0.1
        fig.update_layout(yaxis_range=[-max_value, max_value])
    elif "delta" in y_label.lower():
        fig.update_layout(yaxis_range=[-1, 1])
    else:
        fig.update_layout(yaxis_range=[0, 1])

    return fig


def plot_breakdown(breakdown: list[BreakdownItem], title: str = "Prediction breakdown", scale_by: float = 1.0):
    """
    Plots a breakdown of the prediction into contributions from each region.

    The first item of ``breakdown`` provides the baseline value; every further
    item becomes one horizontal waterfall bar showing its delta.

    Args:
        breakdown: Items as returned by shnapCoalitionEvaluator.breakdown; the
            first one is the starting value, the rest are the added regions.
        title: Figure title.
        scale_by: Factor applied to all values and deltas before plotting.

    Returns:
        The plotly waterfall figure.

    Raises:
        ValueError: If ``breakdown`` is empty, since there is no baseline to plot.
    """
    if len(breakdown) == 0:
        raise ValueError("breakdown must contain at least the baseline item")

    columns = ("Step", "Region", "Value", "Delta")
    df_items = []
    for i, item in enumerate(breakdown):
        region_name = "Base risk" if item.region_id is None else f"Region {item.region_id}"
        df_items.append({
            "Step": i,
            "Region": region_name,
            "Value": item.value,
            "Delta": item.delta
        })
    df = pd.DataFrame(df_items, columns=columns)
    df["Value"] = df["Value"] * scale_by
    df["Delta"] = df["Delta"] * scale_by

    baseline_value = df["Value"].iloc[0]
    steps = df.iloc[1:]

    fig = go.Figure(go.Waterfall(
        name = "20", orientation = "h",
        # One entry per plotted bar. Every bar is a change relative to the
        # running total, which starts at `base`.
        measure = ["relative"] * len(steps),
        x = steps["Delta"],
        y = steps["Region"],
        base = baseline_value,
        textposition = "outside",
        text = steps["Delta"].round(3).astype(str),
        connector = {"line":{"color":"rgb(63, 63, 63)"}}
    ))
    fig.add_vline(x=baseline_value, line_dash="dot", annotation_text="Baseline value", annotation_position="top right")
    fig.update_layout(
        title = title,
        showlegend = False
    )

    return fig