import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, Subset, DataLoader, default_collate

from typing import Callable, List, Optional
from attributors.abstract_classes import DataAttributor, Property
from attributors.utils import flatten_params
from transformers import default_data_collator
from data.utils import process_batch
from tqdm import tqdm

class TracInAttributor(DataAttributor):
    """
    Group-level data attribution via a gradient dot product (identity-Hessian influence).

    Each group's score is the negated dot product between the flattened gradient of a
    test property and the group's aggregated training-loss gradient, both taken at the
    single model held in ``self.model``. This corresponds to an influence-function
    estimate in which the Hessian is replaced by the identity.

    Derived from the technique used in Pruthi et al, "Estimating Training Data Influence
    by Tracing Gradient Descent"
    Paper: https://arxiv.org/pdf/2002.08484
    """

    def __init__(self,
                 dataset: Dataset,
                 group_ids: List[int],
                 model: Optional[nn.Module] = None,
                 train_model: Optional[Callable[[Dataset, List[bool]], nn.Module]] = None):
        """
        Initialize the attributor.

        All arguments are forwarded unchanged to ``DataAttributor.__init__``. The methods
        of this class later read ``self.dataset``, ``self.group_ids``,
        ``self.unique_group_ids``, ``self.model`` and ``self.train_model``, which the
        base class is expected to set up (NOTE(review): confirm against DataAttributor).

        Args:
            dataset: Training dataset; must expose a ``name`` attribute, which is used
                to pick the batch-size cap and the collate function.
            group_ids: Group id of each training example, aligned with ``dataset`` indices.
            model: Optional already-trained model. If omitted, ``train_model`` is used.
            train_model: Optional callable that trains and returns a model.
        """
        super().__init__(dataset, group_ids, train_model=train_model, model=model)

    def compute_group_attributions(self,
                                   property: Property,
                                   train_loss_fn: Callable[[torch.Tensor, torch.Tensor], float],
                                   device: Optional[str] = 'cpu',
                                   use_model_cache: Optional[bool] = False,
                                   verbose: Optional[bool] = True,
                                   temperature: Optional[int] = 1) -> dict:
        """
        Compute one scalar attribution score per group.

        Args:
            property: Test property; its ``backward(model, device=...)`` result is used
                as the test-gradient vector.
            train_loss_fn: Training loss, called as ``loss_fn(outputs / temperature, labels)``.
            device: Device the model is moved to and batches are processed on.
            use_model_cache: Forwarded to ``train_model`` when a model has to be trained.
            verbose: Whether to print progress messages (also forwarded to ``train_model``).
            temperature: Divisor applied to the model outputs before the loss.

        Returns:
            Dict mapping each id in ``self.unique_group_ids`` to its score (a Python float).

        Raises:
            ValueError: If neither a model nor a ``train_model`` callable is available.
        """
        # Train model if not provided
        if self.model is None:
            if self.train_model is None:
                # Fail with an actionable message instead of "'NoneType' object is not callable".
                raise ValueError("Either `model` or `train_model` must be provided.")
            self.model = self.train_model(self.dataset, device=device, use_model_cache=use_model_cache, verbose=verbose)
        self.model.to(device)
        flattened_params = flatten_params(self.model)
        self.num_params = len(flattened_params)
        if verbose:
            print(f"\nNumber of model parameters: {self.num_params}")

        # Compute test property gradients
        if verbose:
            print("Computing test property gradients...")
        test_property_grads = property.backward(self.model, device=device)

        # Score every group against the test property gradient
        if verbose:
            print("Computing scores per group...")
        scores_per_group = self._compute_phi_product(test_property_grads, self.dataset, train_loss_fn, temperature, device)

        return scores_per_group

    def _compute_phi_product(self,
                             test_property_grads: torch.Tensor,
                             data_subset: Dataset,
                             loss_fn: Callable[[torch.Tensor, torch.Tensor], float],
                             temperature: int = 1,
                             device: str = 'cpu'):
        """
        Compute ``-(test_property_grads @ group_gradient)`` for every group.

        Args:
            test_property_grads: Flattened test-property gradient vector.
            data_subset: Unused; kept only for interface compatibility. Groups are
                always drawn from ``self.dataset``.
            loss_fn: Training loss forwarded to ``_compute_train_grads``.
            temperature: Output temperature forwarded to ``_compute_train_grads``.
            device: Device forwarded to ``_compute_train_grads``.

        Returns:
            Dict mapping each id in ``self.unique_group_ids`` to its score as a float.
        """
        # Bucket example indices by group in a single pass instead of rescanning
        # `self.group_ids` once per unique group. Indices stay in ascending order,
        # exactly as the per-group scan produced them. Group ids must be hashable
        # (they are used as keys of the result dict as well).
        indices_by_group = {}
        for index, group_id in enumerate(self.group_ids):
            indices_by_group.setdefault(group_id, []).append(index)

        scores_per_group = {}
        for unique_group_id in tqdm(self.unique_group_ids, desc='Computing scores per group'):
            group_indices = indices_by_group.get(unique_group_id, [])
            group_subset = Subset(self.dataset, group_indices)
            gradients = self._compute_train_grads(group_subset, loss_fn=loss_fn, temperature=temperature, device=device)
            score = -1. * test_property_grads @ gradients
            scores_per_group[unique_group_id] = score.detach().cpu().item()
        return scores_per_group

    def _compute_train_grads(self,
                             data_subset: Dataset,
                             temperature: int,
                             loss_fn: Callable,
                             batch_size: Optional[int] = None,
                             device: str = 'cpu'):
        """
        Compute the flattened training-loss gradient aggregated over ``data_subset``.

        Each batch gradient is weighted by its number of labels, the weighted gradients
        are summed, and the sum is divided by ``len(data_subset)``. This equals the mean
        per-example gradient when ``loss_fn`` averages over the batch
        (NOTE(review): confirm the reduction used by the loss functions passed in).

        Args:
            data_subset: Examples to compute the gradient over; must be non-empty.
            temperature: Divisor applied to the model outputs before the loss.
            loss_fn: Loss called as ``loss_fn(outputs / temperature, labels)``.
            batch_size: Batch size; defaults to ``min(len(data_subset), cap)`` where the
                cap is 512 for the 'qnli' / 'qnli_noisy' datasets and 1024 otherwise.
            device: Device passed to ``process_batch``.

        Returns:
            Flattened gradient tensor as produced by ``flatten_params(..., gradients=True)``.

        Raises:
            ValueError: If ``data_subset`` is empty.
        """
        num_examples = len(data_subset)
        if num_examples == 0:
            # Without this guard an empty subset either hands DataLoader a batch size
            # of 0 or ends up evaluating `None / 0` after a loop that never ran.
            raise ValueError("Cannot compute training gradients for an empty data subset.")

        is_qnli = self.dataset.name in ['qnli', 'qnli_noisy']
        if batch_size is None:
            max_batch = 512 if is_qnli else 1024
            batch_size = min(num_examples, max_batch)

        collate_fn = default_data_collator if is_qnli else default_collate
        dataloader = DataLoader(data_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
        out_grads = None
        for batch in dataloader:
            # Clear gradients left over from the previous batch (or from the test property).
            self.model.zero_grad()
            inputs, labels = process_batch(batch, device=device)
            objective = loss_fn(self.model(**inputs) / temperature, labels)
            objective.backward()
            # Weight by batch size so that a smaller final batch contributes proportionally.
            # The multiplication yields a fresh tensor, so accumulating in place below
            # does not touch the model's own gradient buffers.
            gradients = flatten_params(self.model, gradients=True) * labels.size(0)

            if out_grads is None:
                out_grads = gradients
            else:
                out_grads += gradients
            torch.cuda.empty_cache()

        return out_grads / num_examples
    