from abc import ABC, abstractmethod
from typing import Callable, Any

import numpy
import torch

from distances import metric_distance_diversity


class GraphSavers(ABC):
    def __init__(self):
        self._snapshots: list[tuple[int, torch.Tensor]] = []

    def save_maybe(self, it: int, adj_batch: torch.Tensor):
        if self.should_save(it):
            self._save(it, adj_batch)

    @property
    def snapshots(self) -> list[tuple[int, torch.Tensor]]:
        return self._snapshots

    @property
    def graphs(self) -> torch.Tensor | None:
        if len(self.snapshots) == 0:
            return None
        else:
            return torch.cat([batch for _, batch in self.snapshots], dim=0)

    @abstractmethod
    def should_save(self, it: int) -> bool:
        pass

    @abstractmethod
    def _save(self, it: int, adj_batch: torch.Tensor):
        pass


class TrajectoryTracker(GraphSavers):
    """Store a detached CPU copy of the whole batch on every ``save_every``-th iteration.

    An iteration is saved when it is divisible by ``save_every`` (so iteration 0
    is always saved).
    """

    def __init__(self, save_every: int = 100):
        super().__init__()
        self._save_every: int = save_every

    def should_save(self, it: int) -> bool:
        remainder = it % self._save_every
        return remainder == 0

    def _save(self, it: int, adj_batch: torch.Tensor):
        snapshot = adj_batch.detach().cpu().clone()
        self._snapshots.append((it, snapshot))


class PeriodicSaver(GraphSavers):
    """Store the whole batch at a fixed schedule of evenly spread iterations.

    ``save_count // batch_size`` iteration indices are produced by
    ``numpy.linspace`` between 10 and ``num_iters - 1`` with an integer dtype.
    The integer cast may yield repeated indices when the count is large relative
    to the range. Each scheduled iteration stores a detached CPU copy of the batch.
    """

    def __init__(
        self, num_iters: int, batch_size: int, save_count: int, verbose: bool = False
    ):
        super().__init__()
        # Every save stores a full batch, so schedule save_count // batch_size saves.
        num_saves = save_count // batch_size
        self._sample_iterations: numpy.ndarray = numpy.linspace(
            10, num_iters - 1, num_saves, dtype=int
        )
        if verbose:
            schedule = self._sample_iterations
            print(
                f"Periodic Sampling of graphs at {len(schedule)} iterations specified as {schedule}"
            )

    def should_save(self, it: int) -> bool:
        return bool((self._sample_iterations == it).any())

    def _save(self, it: int, adj_batch: torch.Tensor):
        snapshot = adj_batch.detach().cpu().clone()
        self._snapshots.append((it, snapshot))


class TopSaver(GraphSavers):
    """Store one graph per scheduled iteration: the one that maximises diversity.

    The first save stores ``adj_batch[0]``. Every later save runs
    ``metric_distance_diversity`` on the already-saved graphs plus each
    candidate from the batch, and stores the candidate with the largest
    returned distance (the first such candidate on ties).
    """

    def __init__(
        self,
        num_iters: int,
        batch_size: int,
        save_count: int,
        orca_path: str,
        verbose: bool = False,
    ):
        super().__init__()
        # NOTE: batch_size is not used by this saver; it is kept so the
        # constructor signature stays compatible with existing callers.
        self._sample_iterations: numpy.ndarray = numpy.linspace(
            0, num_iters - 1, save_count, dtype=int
        )
        # Forwarded unchanged to metric_distance_diversity.
        self.orca_path = orca_path
        if verbose:
            print(
                f"Top Sampling of graphs at {len(self._sample_iterations)} iterations specified as {self._sample_iterations}"
            )

    def should_save(self, it: int) -> bool:
        return it in self._sample_iterations

    def _save(self, it: int, adj_batch: torch.Tensor):
        """Append the single graph of ``adj_batch`` that most increases diversity."""
        # Detached CPU copy with a leading dim of 1, so it can be concatenated
        # with the previously saved graphs along dim 0.
        first_graph = adj_batch[0].detach().unsqueeze(0).cpu().clone()
        graphs = self.graphs
        if graphs is None:
            self._snapshots.append((it, first_graph))
            return

        top_g = first_graph
        top_dist: float = -1
        for g in adj_batch:
            # Out-of-place unsqueeze: in-place size/stride changes (unsqueeze_)
            # on a tensor returned by .detach() raise a RuntimeError in PyTorch.
            candidate = g.detach().unsqueeze(0).cpu().clone()
            dist, _ = metric_distance_diversity(
                torch.cat((graphs, candidate), dim=0), self.orca_path
            )
            if dist > top_dist:
                top_dist = dist
                top_g = candidate

        self._snapshots.append((it, top_g))


class IncSaver(GraphSavers):
    """Greedily grow the saved set with graphs that raise its diversity.

    At each scheduled iteration, the graphs of the batch are visited in order.
    A graph is added when ``metric_distance_diversity`` of (saved set +
    graph) exceeds the best value seen so far. Every accepted graph raises
    that bar and joins the set used for later comparisons. All graphs accepted
    in one call are stored as a single snapshot.
    """

    def __init__(
        self,
        num_iters: int,
        batch_size: int,
        save_count: int,
        orca_path: str,
        verbose: bool = False,
    ):
        super().__init__()
        # NOTE: batch_size is not used by this saver; it is kept so the
        # constructor signature stays compatible with existing callers.
        self._sample_iterations: numpy.ndarray = numpy.linspace(
            10, num_iters - 1, save_count, dtype=int
        )
        # Forwarded unchanged to metric_distance_diversity.
        self.orca_path = orca_path
        if verbose:
            print(
                f"Top Sampling of graphs at {len(self._sample_iterations)} iterations specified as {self._sample_iterations}"
            )

    def should_save(self, it: int) -> bool:
        return it in self._sample_iterations

    def _save(self, it: int, adj_batch: torch.Tensor):
        """Append every graph of ``adj_batch`` that increases the saved set's diversity."""
        graphs = self.graphs
        to_add: list[torch.Tensor] = []
        candidates = adj_batch
        if graphs is None:
            # Nothing saved yet: seed the set with the first graph and only
            # consider the remaining ones. Otherwise the seed would be compared
            # against itself and stored a second time.
            graphs = adj_batch[0].detach().unsqueeze(0).cpu().clone()
            to_add.append(graphs)
            candidates = adj_batch[1:]

        # A single graph has no pairwise diversity; -1 is meant to let the first
        # candidate through (assumes the metric is >= 0 — TODO confirm).
        best_dist = (
            -1
            if graphs.shape[0] == 1
            else metric_distance_diversity(graphs, self.orca_path)[0]
        )
        for g in candidates:
            # Out-of-place unsqueeze: in-place size/stride changes (unsqueeze_)
            # on a tensor returned by .detach() raise a RuntimeError in PyTorch.
            candidate = g.detach().unsqueeze(0).cpu().clone()
            dist, _ = metric_distance_diversity(
                torch.cat((graphs, candidate), dim=0), self.orca_path
            )
            if dist > best_dist:
                best_dist = dist
                to_add.append(candidate)
                graphs = torch.cat((graphs, candidate), dim=0)

        if to_add:
            self._snapshots.append((it, torch.cat(to_add, dim=0)))


class RankSaver(GraphSavers):
    """Maintain a fixed pool of graphs, swapping in candidates that lower a pairwise loss.

    The first saved batch becomes the pool. On later saves, the pool member
    whose exclusion gives the lowest masked loss (the "weak link") is
    tentatively replaced by each candidate in turn. A replacement is kept
    when the loss over the whole pool decreases.

    This saver never appends to ``_snapshots``; the pool is exposed through
    the overridden ``graphs`` property.

    NOTE(review): the masks are ``save_count`` x ``save_count``, so the first
    saved batch presumably contains exactly ``save_count`` graphs — confirm.
    """

    def __init__(
        self,
        num_iters: int,
        save_count: int,
        feature_fn: Callable[[torch.Tensor], Any],
        pairwise_fn: Callable[[Any], torch.Tensor],
        loss_fn: Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor],
        device: str,
        verbose: bool = False,
    ) -> None:
        super().__init__()
        self._sc = save_count
        self._sample_iterations: numpy.ndarray = numpy.linspace(
            10, num_iters - 1, save_count, dtype=int
        )
        if verbose:
            print(
                f"Top Sampling of graphs at {len(self._sample_iterations)} iterations specified as {self._sample_iterations}"
            )

        # feature_fn(graphs) -> features; pairwise_fn(features) -> pairwise
        # matrix; loss_fn(pairwise, mask, div) -> scalar loss (lower is better,
        # since only decreases are accepted below).
        self._feature_fn = feature_fn
        self._pairwise_fn = pairwise_fn
        self._loss_fn = loss_fn

        # Pool state, populated by the first _save call.
        self._saved: torch.Tensor | None = None
        self._features: Any = None
        self._pairwise: torch.Tensor | None = None

        # masks[idx] for idx < save_count: off-diagonal ones with row and
        # column idx zeroed, i.e. all pairs among the pool *without* member idx.
        # masks[-1]: all off-diagonal pairs of the full pool.
        # Only the masks are placed on `device`.
        masks: list[torch.Tensor] = []
        for idx in range(save_count):
            mask = 1.0 - torch.eye(save_count, device=device)
            mask[idx, :] = 0
            mask[:, idx] = 0
            masks.append(mask)
        masks.append(1.0 - torch.eye(save_count, device=device))
        self._masks: torch.Tensor = torch.stack(masks)

    @property
    def graphs(self) -> torch.Tensor | None:
        """Detached CPU copy of the current pool, or None before the first save."""
        if self._saved is not None:
            return self._saved.detach().clone().cpu()
        else:
            return None

    def should_save(self, it: int) -> bool:
        return it in self._sample_iterations

    def _find_weak_link(self) -> int:
        """Return the pool index whose exclusion yields the lowest loss (first on ties)."""
        # Number of ones in masks[idx]: ordered off-diagonal pairs among
        # save_count - 1 graphs.
        div = (self._sc - 1) * (self._sc - 2)
        weak_link = 0
        weak_loss = self._loss_fn(self._pairwise, self._masks[0], div)
        for idx in range(1, self._sc):
            new_loss = self._loss_fn(self._pairwise, self._masks[idx], div)
            if new_loss < weak_loss:
                weak_loss = new_loss
                weak_link = idx

        return weak_link

    def _replace_weak_link(self, adj_batch: torch.Tensor) -> None:
        """Try each graph of adj_batch in the weak-link slot; keep it if the full-pool loss drops."""
        # Number of ones in masks[-1]: ordered off-diagonal pairs of the full pool.
        div = self._sc * (self._sc - 1)
        weak_link = self._find_weak_link()
        # Trial copy of the pairwise matrix; only row/column `weak_link` is
        # rewritten per candidate. A rejected candidate's row/column is simply
        # overwritten by the next candidate, because weak_link stays the same.
        pairwise: torch.Tensor = self._pairwise.clone()
        cur_loss = self._loss_fn(self._pairwise, self._masks[-1], div)

        for idx in range(adj_batch.shape[0]):
            g = adj_batch[idx].unsqueeze(0)

            # Assumes feature_fn returns a pair of tensors shaped (n, d) and the
            # pool features are (pool_size, d) — TODO confirm.
            fs: tuple[torch.Tensor, torch.Tensor] = self._feature_fn(g)
            spectral_dist = (self._features[0] - fs[0]).norm(dim=1)
            poly_dist = (self._features[1] - fs[1]).norm(dim=1)
            # NOTE(review): this distance is hard-coded (second component
            # weighted by 5) rather than computed via pairwise_fn; presumably it
            # must match pairwise_fn's definition — confirm.
            dist = spectral_dist + 5 * poly_dist
            # The candidate occupies slot weak_link, so its self-distance is 0.
            dist[weak_link] = 0.0

            pairwise[weak_link, :] = dist
            pairwise[:, weak_link] = dist

            new_loss = self._loss_fn(pairwise, self._masks[-1], div)
            if new_loss < cur_loss:
                # Commit: replace the graph, its features (in place) and the
                # pairwise matrix, then pick a new weak link.
                self._saved[weak_link] = adj_batch[idx]
                self._features[0][weak_link] = fs[0][0]
                self._features[1][weak_link] = fs[1][0]
                self._pairwise = pairwise

                cur_loss = new_loss
                weak_link = self._find_weak_link()
                pairwise: torch.Tensor = self._pairwise.clone()

    def _save(self, it: int, adj_batch: torch.Tensor) -> None:
        """Initialise the pool from the first batch; afterwards try to improve it. `it` is unused."""
        adj_batch = adj_batch.detach().clone()
        if self._saved is None:
            self._saved = adj_batch
            self._features = self._feature_fn(self._saved)
            self._pairwise = self._pairwise_fn(self._features)
        else:
            self._replace_weak_link(adj_batch)


class StandardDeviationSaver(GraphSavers):
    """Keep a fixed pool of graphs whose ``adj.sum() / 2`` values are widely spread.

    The first saved batch becomes the pool. Each graph is summarised by
    ``adj.sum() / 2``, which is its edge count if ``adj`` is a symmetric 0/1
    adjacency matrix (an assumption). On later saves, the slot whose removal
    leaves the largest std of the remaining values is tentatively overwritten
    by each candidate in turn. A swap is kept whenever the std over the whole
    pool increases.

    NOTE(review): the masks are ``save_count`` long, so the first saved batch
    presumably contains exactly ``save_count`` graphs — confirm.
    """

    def __init__(
        self,
        num_iters: int,
        save_count: int,
        device: str,
        verbose: bool = False,
    ):
        super().__init__()
        self._sc = save_count
        self._sample_iterations: numpy.ndarray = numpy.linspace(
            10, num_iters - 1, save_count, dtype=int
        )
        if verbose:
            print(
                f"Top Sampling of graphs at {len(self._sample_iterations)} iterations specified as {self._sample_iterations}"
            )

        self._saved: torch.Tensor | None = None

        # Boolean masks of shape (save_count + 1, save_count): row i < save_count
        # selects every slot except i; the last row selects all slots.
        leave_one_out = torch.eye(save_count, device=device) == 0
        keep_all = torch.ones((1, save_count), device=device) == 1
        self._masks: torch.Tensor = torch.cat((leave_one_out, keep_all), dim=0)

    @property
    def graphs(self) -> torch.Tensor | None:
        """Detached CPU copy of the current pool, or ``None`` before the first save."""
        if self._saved is None:
            return None
        return self._saved.detach().clone().cpu()

    def should_save(self, it: int) -> bool:
        return bool((self._sample_iterations == it).any())

    def _find_weak_link(self) -> int:
        """Return the slot whose exclusion leaves the highest std (first such slot on ties)."""

        def spread_without(slot: int):
            return self._edges[self._masks[slot]].std()

        # builtin max only replaces its pick on a strict ">", so ties keep the earliest slot.
        return max(range(self._sc), key=spread_without)

    def _replace_weak_link(self, adj_batch: torch.Tensor):
        """Try each candidate in the weakest slot, keeping swaps that raise the pool's std."""
        target = self._find_weak_link()
        trial = self._edges.clone()
        best_std = trial.std()

        for candidate in adj_batch:
            trial[target] = candidate.sum() / 2
            trial_std = trial.std()
            if not trial_std > best_std:
                # Rejected: the next candidate overwrites the same slot in `trial`.
                continue

            self._saved[target] = candidate
            self._edges = trial
            best_std = trial_std
            target = self._find_weak_link()
            trial = self._edges.clone()

    def _save(self, it: int, adj_batch: torch.Tensor):
        """Initialise the pool from the first batch; afterwards try to improve it."""
        batch = adj_batch.detach().clone()
        if self._saved is not None:
            self._replace_weak_link(batch)
            return
        self._saved = batch
        # Per-graph sum over both matrix dimensions, halved.
        self._edges = batch.sum(dim=(1, 2)) / 2
