import torch
import numpy as np
import logging
from copy import deepcopy
from typing import List
from src.server.strategies.valuations import ClientValuation

# Per-module logger; handlers and levels are configured by the application.
logger: logging.Logger = logging.getLogger(name=__name__)


class ShapleyMC(ClientValuation):
    """Assign each federated client a similarity-based contribution score.

    Despite the class name, no Monte Carlo Shapley sampling is performed:
    every client receives one score computed from its flattened parameters,
    restricted to the layers that the current global weights and *all*
    client updates share with an identical shape.

    Supported ``similarity_metric`` values (case-insensitive):

    * ``"cosine"``: cosine similarity to the sample-weighted mean of the
      client vectors.
    * ``"euclidean"``: ``1 / (1 + ||client - mean||_2)``.
    * ``"alignment"``: dot product with the mean divided by the product of
      norms (cosine with a ``1e-8`` stabiliser).
    * ``"entropy"``: Shannon entropy of the client's normalised absolute
      weights (independent of the other clients).
    * ``"pairwise"``: mean cosine similarity to every other client.
    """

    # Metric names accepted by ``evaluate``; checked once before any scoring.
    _SUPPORTED_METRICS = ("cosine", "euclidean", "alignment", "entropy", "pairwise")

    def __init__(self, *args, similarity_metric="cosine", **kwargs):
        """Initialise the valuation strategy.

        Args:
            *args: Forwarded unchanged to ``ClientValuation.__init__``.
            similarity_metric: Scoring metric name (keyword-only). It is
                lower-cased here and validated when ``evaluate`` runs.
            **kwargs: Forwarded unchanged to ``ClientValuation.__init__``. An
                optional ``device`` entry is also read here. It defaults to
                CUDA when available, otherwise CPU.
        """
        super().__init__(*args, **kwargs)
        self.device = kwargs.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        self.similarity_metric = similarity_metric.lower()
        # NOTE(review): ``self.model`` is presumably set by the base class — confirm.
        self.model.to(self.device)

    def _shared_indices(self, weights_list: List[List[np.ndarray]]) -> List[int]:
        """Return layer indices whose shape matches across every weight list.

        The first list is the reference. Index ``i`` is kept only if every
        list has more than ``i`` layers and its ``i``-th array has the same
        shape as the reference's. An empty ``weights_list`` yields ``[]``.
        """
        if not weights_list:
            return []
        return [
            i
            for i, ref_layer in enumerate(weights_list[0])
            if all(len(w) > i and w[i].shape == ref_layer.shape for w in weights_list)
        ]

    def _flatten_selected_weights(self, weights: List[np.ndarray], indices: List[int]) -> torch.Tensor:
        """Flatten the arrays at ``indices`` and concatenate them into one 1-D tensor on ``self.device``.

        ``indices`` must be non-empty, because ``torch.cat`` rejects an empty list.
        """
        return torch.cat([torch.tensor(weights[i].flatten(), device=self.device) for i in indices])

    def _cosine_similarity(self, a: torch.Tensor, b: torch.Tensor) -> float:
        """Return the cosine similarity of two 1-D tensors, or 0.0 if either has zero norm."""
        if a.norm() == 0 or b.norm() == 0:
            return 0.0
        return float(torch.nn.functional.cosine_similarity(a, b, dim=0).item())

    def _euclidean_similarity(self, a: torch.Tensor, b: torch.Tensor) -> float:
        """Map the Euclidean distance ``d`` between ``a`` and ``b`` to ``1 / (1 + d)``, which lies in (0, 1]."""
        return 1.0 / (1.0 + torch.norm(a - b).item())

    def _dot_alignment(self, a: torch.Tensor, b: torch.Tensor) -> float:
        """Return ``a.b / (|a| * |b| + 1e-8)``; the epsilon prevents division by zero for zero vectors."""
        return float(torch.dot(a, b) / (a.norm() * b.norm() + 1e-8))

    def _entropy_score(self, a: torch.Tensor) -> float:
        """Return the Shannon entropy (natural log) of ``|a|`` normalised to sum to 1.

        An all-zero vector has no distribution to normalise. In that case
        0.0 is returned instead of the NaN that ``0 / 0`` would produce.
        """
        abs_weights = torch.abs(a)
        total = abs_weights.sum()
        if total == 0:
            return 0.0
        probs = abs_weights / total
        # The 1e-10 offset keeps log() finite for zero-probability entries.
        return -torch.sum(probs * torch.log(probs + 1e-10)).item()

    def _pairwise_avg_similarity(self, client_vecs: List[torch.Tensor], i: int) -> float:
        """Return the mean cosine similarity between client ``i`` and every other client.

        With fewer than two clients there are no peers to compare against.
        In that case 0.0 is returned rather than dividing by zero.
        """
        peers = len(client_vecs) - 1
        if peers <= 0:
            return 0.0
        anchor = client_vecs[i]
        total = sum(self._cosine_similarity(anchor, v) for j, v in enumerate(client_vecs) if j != i)
        return total / peers

    def _weighted_mean(self, client_vecs: List[torch.Tensor], client_samples: List[int]) -> torch.Tensor:
        """Return the average of ``client_vecs`` weighted by ``client_samples``.

        If the sample counts sum to zero or less, the weights are undefined.
        A uniform average is used instead of dividing by zero.
        """
        total_samples = sum(client_samples)
        if total_samples > 0:
            fractions = [n / total_samples for n in client_samples]
        else:
            fractions = [1.0 / len(client_vecs)] * len(client_vecs)
        return sum(v * f for v, f in zip(client_vecs, fractions))

    def evaluate(
        self,
        current_weights: List[np.ndarray],
        weights_1: List[List[np.ndarray]],
        client_samples: List[int],
        **kwargs
    ) -> dict:
        """Score every client in ``weights_1`` with the configured metric.

        Args:
            current_weights: Current global model weights. They are only used
                to decide which layers are shared.
            weights_1: One list of layer arrays per client.
            client_samples: Per-client sample counts. They weight the mean
                vector used by the ``cosine``/``euclidean``/``alignment`` metrics.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            ``{"Shapley_<metric>": scores}`` with one float per client, in the
            order of ``weights_1``. All scores are 0.0 when no layer is shared.

        Raises:
            ValueError: If ``similarity_metric`` is not a supported metric.
        """
        metric = self.similarity_metric
        logger.info("[ShapleyMC] Using similarity metric: %s", metric)
        # Validate once, up front, so a misconfiguration surfaces even when
        # there are no clients or no shared layers to score.
        if metric not in self._SUPPORTED_METRICS:
            raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")

        result_key = f"Shapley_{metric}"
        num_clients = len(weights_1)
        if num_clients == 0:
            return {result_key: []}
        if len(client_samples) != num_clients:
            # zip() in the weighted mean would silently drop the extra entries.
            logger.warning(
                "[ShapleyMC] %d client weight sets but %d sample counts.",
                num_clients,
                len(client_samples),
            )

        shared_idxs = self._shared_indices([current_weights, *weights_1])
        if not shared_idxs:
            logger.warning("No shared weights among clients. Returning zero similarity.")
            return {result_key: [0.0] * num_clients}

        client_vecs = [self._flatten_selected_weights(w, shared_idxs) for w in weights_1]

        if metric == "entropy":
            scores = [self._entropy_score(v) for v in client_vecs]
        elif metric == "pairwise":
            scores = [self._pairwise_avg_similarity(client_vecs, i) for i in range(num_clients)]
        else:
            # The remaining metrics compare each client with the weighted mean,
            # so it is only built when one of them actually needs it.
            global_vec = self._weighted_mean(client_vecs, client_samples)
            compare = {
                "cosine": self._cosine_similarity,
                "euclidean": self._euclidean_similarity,
                "alignment": self._dot_alignment,
            }[metric]
            scores = [compare(v, global_vec) for v in client_vecs]

        return {result_key: scores}
