from typing import Callable, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, Subset, DataLoader, default_collate
from tqdm import tqdm
from transformers import default_data_collator

from attributors.abstract_classes import DataAttributor, Property
from attributors.utils import flatten_params
from data.utils import process_batch

class InfluenceGNAttributor(DataAttributor):
    """
    Implement influence functions via a Fisher approximation of the Hessian
    Discussed in Barshan et al., "RelatIF: Identifying Explanatory Training Examples via Relative Influence"
    Paper: https://arxiv.org/abs/2003.11630

    Training-loss gradients are averaged per group (one row of ``phi`` per
    unique group id), optionally compressed with a random Gaussian projection,
    and the Hessian is approximated by the Gauss-Newton / Fisher product
    ``phi^T phi``.

    This is also used to ultimately implement `TrakAttributor`
    """

    def __init__(self,
                 dataset: Dataset,
                 group_ids: List[int],
                 model: Optional[nn.Module] = None,
                 train_model: Optional[Callable[[Dataset, List[bool]], nn.Module]] = None):
        """
        Initialize Influence Functions.

        Args:
            dataset: Training dataset whose groups are attributed.
            group_ids: Group id of every example, positionally aligned with
                ``dataset`` (position ``i`` is used as an index into ``dataset``).
            model: Optional already-trained model. When ``None``,
                ``compute_group_attributions`` calls ``self.train_model`` first.
            train_model: Optional training callable, forwarded to ``DataAttributor``.
        """
        super().__init__(dataset, group_ids, train_model=train_model, model=model)

    def compute_group_attributions(self,
                                   property: Property,
                                   train_loss_fn: Callable[[torch.Tensor, torch.Tensor], float],
                                   device: Optional[str] = 'cpu',
                                   use_model_cache: Optional[bool] = False,
                                   verbose: Optional[bool] = True,
                                   projection_dim: Optional[int] = 32,
                                   temperature: float = 1) -> Dict[int, float]:
        """
        Compute scalar Gauss-Newton influence scores for every group.

        Args:
            property: Test property. ``property.backward(model, device=...)`` must
                return a gradient whose leading dimension equals the number of
                flattened model parameters (it is left-multiplied by the projection).
            train_loss_fn: Training loss, called as
                ``train_loss_fn(model_outputs / temperature, labels)``.
            device: Device for the model, the projection matrix and all gradients.
            use_model_cache: Forwarded to ``self.train_model`` when no model is set.
            verbose: Print progress messages and show the per-group progress bar.
            projection_dim: Number of rows of the random Gaussian projection.
                ``None`` disables projection; the Hessian estimate is then
                ``num_params x num_params``.
            temperature: Divisor applied to the model outputs before the loss.

        Returns:
            Dict mapping each id in ``self.unique_group_ids`` to its score
            ``-g_test^T (phi^T phi)^+ phi_group``.
        """
        # Train model if not provided
        if self.model is None:
            self.model = self.train_model(self.dataset, device=device, use_model_cache=use_model_cache, verbose=verbose)
        self.model.to(device)
        flattened_params = flatten_params(self.model)
        self.num_params = len(flattened_params)
        if verbose:
            print(f"\nNumber of model parameters: {self.num_params}")

        # Compute test property gradients
        if verbose:
            print("Computing test property gradients...")
        test_property_grads = property.backward(self.model, device=device)

        # Compute projection matrix
        if projection_dim is not None:
            if verbose:
                print(f"Computing projection matrix of dimension {projection_dim}...")
                expected_mem = self.num_params * projection_dim * 4 / 1e9
                print(f"Expected memory usage for projection matrix: {expected_mem:.2f} GB")
            # No 1/sqrt(k) scaling is needed: scaling P by c multiplies both the
            # projected gradients and the projected Hessian by c^2, and the
            # pseudo-inverse cancels it, so the final scores are unchanged.
            self.P = torch.randn(projection_dim, self.num_params).to(device)
            # Project test property gradients
            test_property_grads = self.P @ test_property_grads

        # Compute phi (gradient of train loss w.r.t. model parameters in projection space per group)
        if verbose:
            print("Computing projected train gradients...")
        phi = self._compute_phi(self.dataset, projection_dim, train_loss_fn, temperature, device,
                                verbose=bool(verbose))

        # Fisher / Gauss-Newton estimate of the Hessian from the per-group gradients
        if verbose:
            print("Computing Fisher estimate of Hessian...")
        hessian_est = phi.t() @ phi

        # Compute scores; row i of phi corresponds to self.unique_group_ids[i]
        if verbose:
            print("Computing scores per group...")
        scores = -1. * test_property_grads.t() @ torch.pinverse(hessian_est) @ phi.t()
        return {unique_group_id: scores[i].item() for i, unique_group_id in enumerate(self.unique_group_ids)}

    def _compute_phi(self,
                     data_subset: Dataset,
                     projection_dim: Optional[int],
                     loss_fn: Callable[[torch.Tensor, torch.Tensor], float],
                     temperature: float = 1,
                     device: str = 'cpu',
                     verbose: bool = True) -> torch.Tensor:
        """
        Stack the mean (optionally projected) training-loss gradient of each group.

        Args:
            data_subset: Unused. Groups are always built from ``self.dataset``,
                because ``self.group_ids`` holds positions into it. The parameter
                is kept for backward compatibility.
            projection_dim: Projection size, or ``None`` to skip projection.
            loss_fn: Training loss; see ``_compute_train_grads``.
            temperature: Divisor applied to the model outputs before the loss.
            device: Device on which batches are processed.
            verbose: Show a progress bar over the groups.

        Returns:
            ``torch.vstack`` of the per-group gradients, one row per entry of
            ``self.unique_group_ids`` in that order.
        """
        phi = []
        for unique_group_id in tqdm(self.unique_group_ids, disable=not verbose):
            group_indices = [i for i, group_id in enumerate(self.group_ids) if group_id == unique_group_id]
            group_subset = Subset(self.dataset, group_indices)
            phi.append(self._compute_train_grads(group_subset, projection_dim, loss_fn=loss_fn,
                                                 temperature=temperature, device=device))
        return torch.vstack(phi)

    def _compute_train_grads(self,
                             data_subset: Dataset,
                             projection_dim: Optional[int],
                             temperature: float,
                             loss_fn: Callable,
                             batch_size: Optional[int] = None,
                             device: str = 'cpu') -> torch.Tensor:
        """
        Return the average training-loss gradient over ``data_subset``.

        The gradient is projected with ``self.P`` when ``projection_dim`` is not
        ``None``.

        Args:
            data_subset: Examples to average over (must be non-empty).
            projection_dim: Projection size, or ``None`` to keep the full gradient.
            temperature: Divisor applied to the model outputs before the loss.
            loss_fn: Called as ``loss_fn(model_outputs / temperature, labels)``.
            batch_size: DataLoader batch size. Defaults to the subset size,
                capped at 512 for QNLI datasets and 1024 otherwise.
            device: Device on which batches are processed.

        Raises:
            ValueError: If ``data_subset`` is empty.
        """
        num_examples = len(data_subset)
        if num_examples == 0:
            raise ValueError("Cannot compute training gradients for an empty data subset.")

        is_qnli = self.dataset.name in ['qnli', 'qnli_noisy']
        max_batch = 512 if is_qnli else 1024
        if batch_size is None:
            batch_size = min(num_examples, max_batch)

        collate_fn = default_data_collator if is_qnli else default_collate
        dataloader = DataLoader(data_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
        out_grads = None
        for batch in dataloader:
            self.model.zero_grad()
            inputs, labels = process_batch(batch, device=device)
            objective = loss_fn(self.model(**inputs) / temperature, labels)
            objective.backward()
            # Multiply by the batch size so the (presumably mean-reduced) batch
            # gradient becomes a sum. Dividing by num_examples at the end then
            # yields the exact mean, even when the last batch is smaller.
            gradients = flatten_params(self.model, gradients=True) * labels.size(0)

            if projection_dim is not None:
                gradients = self.P @ gradients

            if out_grads is None:
                out_grads = gradients
            else:
                # In-place add avoids a second full-size buffer when unprojected
                out_grads += gradients
            torch.cuda.empty_cache()

        return out_grads / num_examples