import copy
import io
from typing import Optional

import torch
from avalanche.evaluation import PluginMetric
from avalanche.evaluation.metric_results import MetricResult, MetricValue
from avalanche.evaluation.metric_utils import get_metric_name
from avalanche.training.templates import SupervisedTemplate
from torch import Tensor, nn


class ParameterSpaceDistance(PluginMetric[Tensor]):
    """
    The ParameterSpaceDistance Metric.

    Instances of this metric measure how far the model parameters have moved
    from a reference ("initial") snapshot. The value is the relative L2
    distance ``||w_ref - w||_2 / ||w_ref||_2`` computed on the flattened
    parameter vectors, and it is refreshed after each training experience.

    The reference snapshot is taken:

    - in ``before_training`` when ``experience_for_initial_params == 0``;
    - in the first ``after_training_exp`` that finds no reference when
      ``experience_for_initial_params == 1``.

    In both cases the snapshot is only taken while no reference is stored, so
    it is kept until ``reset`` is called.

    Each time `result` is called, this metric returns the distance measured
    since the last `reset`.

    The reset method will bring the metric to its initial state. By default
    this metric in its initial state will return None.
    """

    def __init__(self, experience_for_initial_params=0):
        """
        Creates the metric.

        :param experience_for_initial_params: selects when the reference
            parameters are captured: 0 captures them before training, 1
            captures them after the first training experience.
        """
        super().__init__()
        self.experience_for_initial_params = experience_for_initial_params
        # Flattened reference parameters and their cached L2 norm.
        self.initial_weights: Optional[torch.Tensor] = None
        self.initial_weights_norm: Optional[float] = None
        # None until a reference snapshot exists.
        self.distance: Optional[float] = None

    @staticmethod
    def _flatten_parameters(model: nn.Module) -> torch.Tensor:
        """
        Returns all parameters of `model` as one flat CPU tensor.

        The parameters are detached first so the snapshot carries no autograd
        history, and the model itself is neither copied nor moved.

        :param model: the module whose parameters are flattened.
        :return: a 1-D tensor holding a copy of the parameter values.
        """
        return nn.utils.parameters_to_vector(
            [param.detach().cpu() for param in model.parameters()]
        )

    def _set_reference(self, weights: torch.Tensor) -> None:
        """
        Stores `weights` as the reference snapshot.

        :param weights: the flattened parameters to measure against.
        :return: None.
        """
        self.initial_weights = weights
        self.initial_weights_norm = torch.linalg.vector_norm(
            weights, ord=2
        ).item()
        # The reference is at distance zero from itself.
        self.distance = 0.0

    def update(self, weights: torch.Tensor):
        """
        Update the distance to the initial weights.

        :param weights: the weight tensor at current experience
        :return: None.
        :raises RuntimeError: if no reference snapshot has been captured.
        """
        if self.initial_weights is None or self.initial_weights_norm is None:
            raise RuntimeError(
                "ParameterSpaceDistance has no reference parameters to "
                "compare against; they are captured only when "
                "experience_for_initial_params is 0 or 1 (got "
                f"{self.experience_for_initial_params!r})."
            )

        absolute_distance = torch.linalg.vector_norm(
            self.initial_weights - weights, ord=2
        ).item()

        if self.initial_weights_norm == 0:
            # The relative distance is undefined for an all-zero reference:
            # report 0 when nothing moved and infinity otherwise, instead of
            # raising ZeroDivisionError.
            self.distance = 0.0 if absolute_distance == 0 else float("inf")
        else:
            # Measure relative distance
            self.distance = absolute_distance / self.initial_weights_norm

    def result(self) -> Optional[float]:
        """
        Retrieves the distance to the initial weights.

        :return: The distance to the initial weights, or None if no reference
            snapshot has been captured since construction or the last reset.
        """
        return self.distance

    def reset(self) -> None:
        """
        Resets the metric.

        Drops the reference snapshot and the last measured distance.

        :return: None.
        """
        self.initial_weights = None
        self.initial_weights_norm = None
        self.distance = None

    def _package_result(self, strategy) -> "MetricResult":
        """
        Wraps the current distance in a list of metric values.

        :param strategy: the strategy providing the metric name context and
            the training iteration counter used as the x value.
        :return: a single-element list of `MetricValue`, or None when there
            is nothing to report yet.
        """
        distance = self.result()
        if distance is None:
            return None

        metric_name = get_metric_name(
            self, strategy, add_experience=True, add_task=False
        )
        return [
            MetricValue(self, metric_name, distance, strategy.clock.train_iterations)
        ]

    def before_training(self, strategy: "SupervisedTemplate") -> "MetricResult":
        """
        Captures the reference parameters when configured to do so.

        :param strategy: the strategy whose model is being trained.
        :return: None, no value is emitted at this point.
        """
        # This callback runs at the start of every training call, so the
        # snapshot is taken only while no reference is stored; otherwise a
        # per-experience training loop would overwrite it each time.
        if self.experience_for_initial_params == 0 and self.initial_weights is None:
            self._set_reference(self._flatten_parameters(strategy.model))
        return None

    def after_training_exp(self, strategy: "SupervisedTemplate") -> "MetricResult":
        """
        Measures the distance after a training experience and emits it.

        :param strategy: the strategy whose model has just been trained.
        :return: the packaged metric value.
        """
        weights = self._flatten_parameters(strategy.model)

        if self.experience_for_initial_params == 1 and self.initial_weights is None:
            self._set_reference(weights)
        else:
            self.update(weights)

        return self._package_result(strategy)

    def __str__(self):
        return f"ParameterSpaceDistance(experience_for_initial_params={self.experience_for_initial_params})"


# Names exported by `from <module> import *`.
__all__ = ["ParameterSpaceDistance"]
