from abc import ABC, abstractmethod
from typing import Callable, Iterable, Optional, List, Dict
from torch import nn, Tensor
from torch.utils.data import Dataset
from attributors.utils import convert_to_list

class Property(ABC):
    """
    Base class for a model property measured on a dataset, for example:
    (1) Accuracy
    (2) Loss
    (3) Fairness
    (4) Robustness
    (5) Privacy

    Concrete subclasses must implement ``forward`` (the property's value)
    and ``backward`` (the property's gradient).
    """

    def __init__(self, dataset: Dataset, name: str, batch_size: Optional[int] = None):
        """
        Store the configuration of the property.

        :dataset: dataset the property is associated with
        :name: identifier for this property
        :batch_size: stored as ``self.batch_size``; when omitted, it defaults
            to ``len(dataset)`` (the whole dataset as a single batch)
        """
        self.dataset = dataset
        self.name = name
        if batch_size is None:
            # No explicit size given: fall back to the full dataset length.
            batch_size = len(self.dataset)
        self.batch_size = batch_size

    @abstractmethod
    def forward(self, model: nn.Module) -> float:
        """
        Return the scalar value of the property for ``model``.
        """

    @abstractmethod
    def backward(self, model: nn.Module) -> Tensor:
        """
        Return the gradient of the property for ``model``.
        """

class DataAttributor(ABC):
    """
    Abstract class to implement data attribution methods such as:
    (1) Datamodels
    (2) Influence functions
    (3) TRAK
    (4) Leave One Out (LOO)
    (5) Greedy Algorithm

    Subclasses implement ``compute_group_attributions``. Ranking groups by
    their attribution score is provided by ``compute_group_rankings``.
    """

    def __init__(self,
                 dataset: Dataset,
                 group_ids: List[int],
                 train_model: Optional[Callable[[Iterable[Dataset]], nn.Module]] = None,
                 model: Optional[nn.Module] = None):
        """
        Initialize Data Attributor

        :dataset: torch.utils.data.Dataset object of the training data
        :group_ids: group IDs for each training datapoint
        :train_model: callable mapping an iterable of datasets to an nn.Module.
            TODO: Refactor this aspect to account for weights and potential caching
        :model: Optional model input. Possible for influence functions, not for data models
        :raises ValueError: if neither ``model`` nor ``train_model`` is provided
        """
        if model is None and train_model is None:
            raise ValueError("Either model or train_model must be provided.")
        self.dataset = dataset
        self.group_ids = convert_to_list(group_ids)
        # De-duplicate while keeping first-occurrence order. A set's iteration
        # order is a hashing artifact, and for str ids it changes between runs
        # because of hash randomization.
        self.unique_group_ids = list(dict.fromkeys(self.group_ids))
        self.train_model = train_model
        self.model = model

    @abstractmethod
    def compute_group_attributions(self,
                                   property_fn: Callable[[nn.Module], float],
                                   **kwargs) -> Dict[int, float]:
        """
        Computes scalar data attributions per group for a given property function.

        :property_fn: Takes in a model (nn.Module) and returns a scalar (float)
        :kwargs: method-specific options. ``compute_group_rankings`` forwards its
            extra keyword arguments here, so implementations should accept them.

        :return: Dict of {group_id: scalar attribution}
        """
        pass

    def compute_group_rankings(self,
                               property_fn: Callable[[nn.Module], float],
                               verbose: bool = False,
                               **kwargs) -> List[int]:
        """
        Rank groups by attribution score, highest score first.

        :property_fn: Takes in a model (nn.Module) and returns a scalar (float)
        :verbose: if True, print each group's score before ranking
        :kwargs: forwarded unchanged to ``compute_group_attributions``

        :return: group IDs ordered by descending attribution. Groups with equal
            scores keep the order in which the attribution dict yields them,
            because ``sorted`` is stable even with ``reverse=True``.
        """
        scores_per_group = self.compute_group_attributions(property_fn, **kwargs)
        if verbose:
            print("\nGroup Scores:")
            for group_id, score in scores_per_group.items():
                print(f"Group {group_id}: {score}")
        return sorted(scores_per_group, key=scores_per_group.__getitem__, reverse=True)