import torch
import torch.nn as nn
from torch.utils.data import Dataset, Subset

from typing import Optional, Callable, List
from attributors.abstract_classes import DataAttributor
from attributors.influence_gauss_newton import InfluenceGNAttributor
from attributors.utils import array_in_list

import numpy as np

class TrakAttributor(DataAttributor):
    """
    Implementation of Park et al., "TRAK: Attributing Model Behavior at Scale"
    Paper: https://arxiv.org/abs/2303.14186
    Code: https://github.com/MadryLab/trak

    Models are trained on random subsets of groups. Each model's group
    attributions are computed with ``InfluenceGNAttributor`` on the datapoints
    it was trained on, and the per-group scores are summed across models.
    """

    def __init__(self,
                 dataset: Dataset,
                 group_ids: List[int],
                 train_model: Callable[..., nn.Module]):
        """
        Initialize TRAK.

        Args:
            dataset: Training dataset to attribute over.
            group_ids: Group id for each datapoint of ``dataset``; passed to the
                base class (presumably stored as ``self.group_ids``).
            train_model: Called as ``train_model(dataset, per_datapoint_weights,
                use_model_cache=..., device=..., verbose=...)`` and expected to
                return a trained model. ``per_datapoint_weights`` is a list of
                0/1 ints, one per datapoint.
        """
        super().__init__(dataset, group_ids, train_model=train_model)

    def train_subset_models(self,
                            num_subsets: Optional[int] = 5,
                            subsampling_frac: Optional[float] = 0.75,
                            use_model_cache: Optional[bool] = False,
                            device: Optional[str] = 'cpu',
                            verbose: bool = True):
        """
        Train models on random subsets of the groups.

        Each subset keeps every group independently with probability
        ``subsampling_frac``. Empty subsets are always rejected. Subsets that
        were already drawn are rejected as long as an unseen non-empty subset
        can still be drawn; once every reachable subset has been used, repeats
        are allowed so the sampler always terminates.

        Args:
            num_subsets: Number of models to train.
            subsampling_frac: Probability of keeping each group; must be in (0, 1].
            use_model_cache: Forwarded to ``self.train_model``.
            device: Forwarded to ``self.train_model``.
            verbose: Print progress and forward to ``self.train_model``.

        Returns:
            ``(model_list, datapoint_weights_list)``: the trained models and, for
            each model, a list of 0/1 ints (one per entry of ``self.group_ids``)
            marking the datapoints it was trained on.

        Raises:
            ValueError: if ``subsampling_frac`` is not in (0, 1], or if there are
                no groups while ``num_subsets > 0``. Both would otherwise make
                the rejection-sampling loop spin forever.
        """
        n_groups = len(self.unique_group_ids)
        if not 0 < subsampling_frac <= 1:
            raise ValueError(f"subsampling_frac must be in (0, 1], got {subsampling_frac}")
        if num_subsets > 0 and n_groups == 0:
            raise ValueError("Cannot sample subsets: there are no groups")

        # Number of distinct non-empty subsets the sampler can actually produce.
        # With subsampling_frac == 1 every group is always kept, so only one.
        n_reachable = 1 if subsampling_frac == 1 else 2 ** n_groups - 1

        # Position of each group id in unique_group_ids (first occurrence, the
        # same answer list.index gives), built once instead of searching the
        # list for every datapoint.
        group_position = {}
        for position, group_id in enumerate(self.unique_group_ids):
            group_position.setdefault(group_id, position)

        groups_sampled = set()  # group-weight vectors (as tuples) already used
        datapoint_weights_list = []
        model_list = []

        for subset_id in range(num_subsets):
            # Only insist on a fresh subset while one can still be drawn;
            # otherwise the rejection loop below could never terminate.
            require_new = len(groups_sampled) < n_reachable
            while True:
                group_weights = np.random.binomial(n=1, p=subsampling_frac, size=n_groups)
                key = tuple(group_weights.tolist())
                if group_weights.sum() == 0:
                    continue
                if require_new and key in groups_sampled:
                    continue
                break
            groups_sampled.add(key)

            # Expand group weights to per-datapoint 0/1 weights.
            per_datapoint_weights = [int(group_weights[group_position[group_id]] != 0)
                                     for group_id in self.group_ids]
            datapoint_weights_list.append(per_datapoint_weights)

            if verbose:
                print(f"\nTraining model {subset_id} on subset {group_weights.tolist()}")
            # Training needs autograd even if the caller is inside torch.no_grad().
            with torch.enable_grad():
                model = self.train_model(self.dataset, per_datapoint_weights,
                                         use_model_cache=use_model_cache, device=device, verbose=verbose)
            model_list.append(model)

        return model_list, datapoint_weights_list

    def compute_group_attributions(self,
                             property_fn: Callable[[nn.Module], float],
                             train_loss_fn: Callable[[torch.Tensor, torch.Tensor], float],
                             use_model_cache: Optional[bool] = False,
                             device: Optional[str] = 'cpu',
                             num_subsets: Optional[int] = 5,
                             subsampling_frac: Optional[float] = 0.75,
                             projection_dim: Optional[int] = 32,
                             soft_thresh_param: Optional[float] = 0.001) -> dict:
        """
        Compute a scalar TRAK attribution score for every group.

        TRAK Method:
        1. Train models on random subsets of the data
        2. Compute attributions for each group within each subset using InfluenceGNAttributor
        3. Sum attributions per group across all subsets

        Args:
            property_fn: Forwarded to ``InfluenceGNAttributor.compute_group_attributions``.
            train_loss_fn: Forwarded to ``InfluenceGNAttributor.compute_group_attributions``.
            use_model_cache: Forwarded to model training and to the influence attributor.
            device: Forwarded to model training and to the influence attributor.
            num_subsets: Number of subset models to train.
            subsampling_frac: Probability of keeping each group in a subset.
            projection_dim: Forwarded to the influence attributor.
            soft_thresh_param: Summed scores whose absolute value is below this
                are set to 0.

        Returns:
            Dict mapping every id in ``self.unique_group_ids`` to its summed score.
            A group absent from every subset keeps the initial score 0.
        """
        model_list, datapoint_weights_list = self.train_subset_models(num_subsets, subsampling_frac,
                                                                      use_model_cache=use_model_cache,
                                                                      device=device)
        scores_per_group = {unique_group_id: 0 for unique_group_id in self.unique_group_ids}
        for model, weights in zip(model_list, datapoint_weights_list):
            # Indices of the datapoints this model was trained on.
            selected_indices = [index for index, value in enumerate(weights) if value == 1]
            data_subset = Subset(self.dataset, selected_indices)

            # Attribute only over the groups present in this subset.
            subset_group_ids = [self.group_ids[i] for i in selected_indices]
            influence = InfluenceGNAttributor(dataset=data_subset, group_ids=subset_group_ids, model=model)
            subset_scores = influence.compute_group_attributions(property_fn=property_fn,
                                                                 train_loss_fn=train_loss_fn,
                                                                 use_model_cache=use_model_cache,
                                                                 device=device,
                                                                 projection_dim=projection_dim)

            # Groups absent from this subset receive no contribution from this model.
            for group_id, score in subset_scores.items():
                scores_per_group[group_id] += score

        # NOTE(review): this zeroes small scores but leaves the rest unchanged
        # (a hard threshold). Classic soft thresholding would also shrink the
        # surviving scores by soft_thresh_param; confirm which is intended.
        return {group_id: (0 if abs(score) < soft_thresh_param else score)
                for group_id, score in scores_per_group.items()}