"""
This code is based on the implementation of Herding in the
Avalanche library https://github.com/ContinualAI/avalanche
"""

from abc import ABC
from typing import (
    List,
    TYPE_CHECKING,
)
import torch.nn.functional as F

import numpy as np
import torch
from numpy import inf
from torch import nn
from torch.utils.data import DataLoader

from avalanche.benchmarks.utils import (
    AvalancheDataset,
)

from MERS.sampling_strategies.SelecetionStrategy import pretrained_representations, calculate_weight
from MERS.mers_utils.storage_policy import ExemplarsSelectionStrategy

if TYPE_CHECKING:
    from avalanche.training.templates import SupervisedTemplate


class HerdingSelectionStrategy(ExemplarsSelectionStrategy, ABC):
    """Herding-based exemplar selection over two feature spaces.

    Each sample is described twice:

    * by the current model's output with its ``linear`` head temporarily
      replaced by an identity (see :meth:`init_features`), and
    * by the representations returned by ``pretrained_representations``
      (feature extraction in the spirit of TypiClust).

    Exemplars are then ordered greedily by herding on a weighted sum of the
    squared distances measured in both spaces.
    """

    def __init__(self, args):
        """Store the experiment configuration.

        Args:
            args: Configuration namespace. ``seed``, ``features_type``,
                ``dataset`` and ``weight_method`` are read here; ``order`` and
                ``nvidia`` are read later by :meth:`init_features`.
        """
        self.features = None  # model-output view, set by init_features
        self.features_ss = None  # pretrained-representation view, set by init_features
        self.seed = args.seed
        self.ss_method = args.features_type
        self.dataset_name = args.dataset
        self.seen_groups = set()  # group keys (classes) already processed
        self.weight_method = args.weight_method
        self.args = args

    def init_features(self, data, model, device, batch_size=None):
        """Compute and cache the two feature views of ``data``.

        ``self.features`` is set to the outputs of ``model`` with its
        ``linear`` attribute temporarily replaced by ``nn.Identity()``, moved
        to CPU. Each row is then divided by ``max(norm, 1)``: rows longer than
        1 are scaled onto the unit sphere and shorter rows are left unchanged.
        ``self.features_ss`` is set to the row-stacked (``np.vstack``) output
        of ``pretrained_representations`` for the same data.

        The model's ``linear`` head and its train/eval mode are restored
        afterwards, even if the forward pass raises.

        Args:
            data: Dataset whose collated batches carry the inputs at
                ``batch[0]``.
            model: Network exposing a ``linear`` attribute.
            device: Device on which the forward pass runs.
            batch_size: Forward-pass batch size. ``None`` (the default) uses a
                single batch of ``len(data)`` samples.

        Raises:
            ValueError: If ``data`` is empty.
        """
        n_samples = len(data)
        if n_samples == 0:
            raise ValueError("cannot extract features from an empty dataset")

        inc_classifier = model.linear
        was_training = model.training
        model.linear = nn.Identity()
        model.eval()
        try:
            dataloader = DataLoader(data, batch_size=batch_size or n_samples, shuffle=False)
            with torch.no_grad():
                outputs = [model(batch[0].to(device)).detach().cpu() for batch in dataloader]
        finally:
            # Put the classifier head and the caller's mode back no matter what.
            model.linear = inc_classifier
            model.train(was_training)

        features = torch.cat(outputs, dim=0)
        print(f'Features shape: {features.shape}')
        # Divide by max(norm, 1): only vectors longer than 1 are rescaled,
        # which also avoids dividing by a zero norm.
        norms = features.norm(dim=1, keepdim=True)
        self.features = features / norms.clamp(min=1)

        ss_features = pretrained_representations(data, self.dataset_name, self.ss_method, seed=self.seed,
                                                 order=self.args.order, nvidia=self.args.nvidia)
        self.features_ss = np.vstack(ss_features)

    @torch.no_grad()
    def make_sorted_indices(
            self, strategy: "SupervisedTemplate", data: AvalancheDataset
    ) -> List[int]:
        """Return the indices of ``data`` in herding selection order.

        Only the first unique target of ``data`` is used as the group key, so
        ``data`` presumably holds a single class. If that key was already
        processed, the identity order ``0..len(data)-1`` is returned without
        recomputing anything.

        Otherwise both feature views are built with :meth:`init_features`.
        Per-space weights start at ``[0.5, 0.5]`` and are handed to
        ``calculate_weight`` together with the per-group buffer size.
        Samples are then picked one at a time. At step ``i`` each unpicked
        candidate ``k`` scores::

            sum_j w_j * || (i * m_j + x_jk) / (i + 1) - c_j ||^2

        where ``m_j`` is the running mean of the already-picked samples and
        ``c_j`` is the mean of all samples in space ``j``. The lowest score
        wins; ties go to the smallest index.

        Returns:
            The selected indices in the order they were picked (all
            ``len(data)`` of them), or ``[]`` for empty data.

        Raises:
            ValueError: If the two feature views disagree in sample count.
        """
        if len(data) == 0:
            return []
        cur_class = list(data.targets.uniques)[0]
        if cur_class in self.seen_groups:  # Buffer was already updated
            return list(range(len(data)))
        self.seen_groups.add(cur_class)

        # Build both feature views (self.features, self.features_ss).
        self.init_features(data, strategy.model, strategy.device)

        device = strategy.device
        dtype = torch.float32
        # as_tensor returns tensors unchanged and wraps NumPy arrays.
        features_list = [
            torch.as_tensor(self.features).to(device=device, dtype=dtype),
            torch.as_tensor(self.features_ss).to(device=device, dtype=dtype),
        ]
        f1, f2 = features_list
        if f1.shape[0] != f2.shape[0]:
            raise ValueError("features and features_ss must have the same number of samples")

        # Herding target of each space: the plain (not normalized) mean.
        centers = [feats.mean(dim=0) for feats in features_list]
        # Running mean of the samples picked so far, one per space.
        current_center = [torch.zeros_like(center) for center in centers]
        candidate_centers = [None] * len(features_list)

        # ---- sklearn needs CPU NumPy for weighting ----
        weights = [0.5, 0.5]
        # NOTE(review): assumes the plugin owning the storage policy sits at
        # strategy.plugins[1], and that its seen_groups is non-empty here.
        storage_policy = strategy.plugins[1].storage_policy
        max_size = storage_policy.max_size // len(storage_policy.seen_groups)
        remaining_features_np = [t.detach().cpu().numpy() for t in features_list]
        # NOTE(review): the return value is ignored; this relies on
        # calculate_weight updating `weights` in place — confirm.
        calculate_weight(
            weight_method=self.weight_method,
            weights=weights,
            remaining_features=remaining_features_np,  # CPU NumPy arrays
            max_size=max_size,
        )

        print(f"Feature space 1 shape: {f1.shape}, Feature space 2 shape: {f2.shape}")
        print(f"Weights: {weights}")

        coeffs = [float(w) for w in weights]
        n_samples = f1.shape[0]
        # Mask of already-picked samples: avoids rebuilding an index tensor
        # from the whole selection list at every step.
        chosen = torch.zeros(n_samples, dtype=torch.bool, device=device)
        selected_indices: List[int] = []
        for i in range(n_samples):
            distances = torch.zeros(n_samples, dtype=dtype, device=device)
            for j, feats in enumerate(features_list):
                # Mean after adding each candidate k (row-wise):
                # μ_{i+1} = (i/(i+1)) * μ_i + (1/(i+1)) * x_k
                candidate_centers[j] = current_center[j] * (i / (i + 1.0)) + feats / (i + 1.0)  # (N, Dj)
                distances += coeffs[j] * (candidate_centers[j] - centers[j]).pow(2).sum(dim=1)

            # Already-picked samples can never win again.
            distances.masked_fill_(chosen, float('inf'))

            new_index = int(distances.argmin().item())
            selected_indices.append(new_index)
            chosen[new_index] = True
            # Advance each running mean to include the picked sample.
            for j in range(len(features_list)):
                current_center[j] = candidate_centers[j][new_index]

        return selected_indices
