from typing import List, Callable, Optional, Iterable, Tuple
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector
from torch.utils.data import DataLoader, Dataset, RandomSampler, Subset
from torch.utils.data.dataloader import default_collate

from attributors.abstract_classes import DataAttributor, Property
from attributors.utils import set_attr, del_attr, flatten_params

from transformers import default_data_collator
from data.utils import process_batch

class InfluenceLissaAttributor(DataAttributor):
    """
    Group-level data attribution with influence functions (LiSSA variant).

    Each group is scored by the negated inner product between the group's
    averaged training-loss gradient and an estimate of ``H^-1 @ v``, where
    ``v`` is the gradient of the test property and the inverse
    Hessian-vector product is approximated iteratively (LiSSA recursion)
    instead of inverting the Hessian explicitly.
    """

    # Dataset names whose samples are collated with the HuggingFace
    # ``default_data_collator`` instead of torch's ``default_collate``.
    _HF_COLLATED_DATASETS = ('qnli', 'qnli_noisy')
    # Upper bounds on the automatically chosen batch size used when
    # accumulating training gradients.
    _MAX_GRAD_BATCH_HF = 512
    _MAX_GRAD_BATCH_DEFAULT = 1024

    def __init__(self,
                 dataset: Dataset,
                 group_ids: List[int],
                 model: Optional[nn.Module] = None,
                 train_model: Optional[Callable[[Dataset, List[bool]], nn.Module]] = None):
        """
        Initialize Influence Functions.

        All arguments are forwarded unchanged to ``DataAttributor.__init__``.

        Args:
            dataset: Training dataset. This class additionally reads
                ``dataset.name`` to pick the collate function.
            group_ids: Group identifiers (presumably one per training sample,
                aligned with dataset indices -- confirm against the base class).
            model: Optional already-trained model. When ``None``, the model is
                trained lazily in ``compute_group_attributions``.
            train_model: Optional callable used to train the model when
                ``model`` is ``None``.
        """
        super().__init__(dataset, group_ids, train_model=train_model, model=model)

    def compute_group_attributions(self,
                                   property: Property,
                                   train_loss_fn: Callable[..., torch.Tensor],
                                   device: Optional[str] = 'cpu',
                                   use_model_cache: Optional[bool] = False,
                                   verbose: Optional[bool] = True,
                                   damp: Optional[float] = 0.001,
                                   repeat: Optional[int] = 20,
                                   depth: Optional[int] = 200,
                                   scale: Optional[float] = 50,
                                   batch_size: Optional[int] = 1) -> dict:
        """
        Compute one scalar LiSSA influence score per group.

        Args:
            property: Test property; ``property.backward(model, device=...)``
                must return the flattened gradient of the property with
                respect to the trainable model parameters.
            train_loss_fn: Called as ``train_loss_fn(model_output, labels)``;
                must return a scalar tensor that supports ``.backward()``.
            device: Device the model and batches are moved to.
            use_model_cache: Forwarded to ``train_model`` when the model has to
                be trained here.
            verbose: When truthy, print progress and timing messages.
            damp: Damping added to every Hessian-vector product.
            repeat: Number of independent LiSSA passes that are averaged.
            depth: Number of training samples drawn (with replacement) per pass.
            scale: Scaling factor of the LiSSA recursion.
            batch_size: Batch size of the loader used for the LiSSA recursion.

        Returns:
            A dict mapping every id in ``self.unique_group_ids`` to its score
            (a Python float). Note that this is a mapping of scores, not a
            ranking.
        """
        # Train model if not provided
        if self.model is None:
            self.model = self.train_model(self.dataset, device=device, use_model_cache=use_model_cache, verbose=verbose)
        self.model.to(device)

        # Names and shapes are cached because the parameters are temporarily
        # removed from the model while the inverse HVP is computed.
        named_params = self._model_params()
        self.params_names = tuple(name for name, _ in named_params)
        self.params_shape = tuple(p.shape for _, p in named_params)
        if verbose:
            print(f"\nNumber of model parameters: {sum(p.numel() for _, p in named_params)}")

        # Compute test property gradients
        if verbose:
            print("Computing test property gradients...")
        test_property_grads = property.backward(self.model, device=device)

        # Compute inverse HVP (H^-1 @ v)
        # TODO: Check/fix device implementation
        if verbose:
            print("Computing inverse HVP...")
        stest = self._compute_inverse_hvp(train_loss_fn, test_property_grads, damp, repeat, depth, scale,
                                          batch_size, device, verbose=verbose)

        # Bucket the sample indices by group in a single pass instead of
        # rescanning every group id once per unique group.
        indices_by_group = {}
        for index, group_id in enumerate(self.group_ids):
            indices_by_group.setdefault(group_id, []).append(index)

        # Compute scores per group
        if verbose:
            print("Computing group attributions...")
        scores_per_group = {}
        for unique_group_id in self.unique_group_ids:
            # Get group as batch
            group_indices_to_select = indices_by_group.get(unique_group_id, [])
            data_subset = Subset(self.dataset, group_indices_to_select)
            gradients = self._compute_train_grads(data_subset, loss_fn=train_loss_fn, device=device)

            # Compute scores for group
            score = -1. * gradients @ stest  # * len(group_indices_to_select) / len(self.dataset)
            scores_per_group[unique_group_id] = score.detach().cpu().item()

        return scores_per_group

    def _uses_hf_collator(self) -> bool:
        """Return whether the dataset is collated with ``default_data_collator``."""
        return self.dataset.name in self._HF_COLLATED_DATASETS

    def _compute_train_grads(self,
                             data_subset: Dataset,
                             loss_fn: Callable,
                             batch_size: int = None,
                             device: str = 'cpu'):
        """
        Compute the flattened training-loss gradient averaged over ``data_subset``.

        Every batch gradient is multiplied by the batch size before being
        summed, and the sum is divided by ``len(data_subset)``. This yields the
        per-sample average provided ``loss_fn`` returns the batch mean
        (assumption -- confirm against the loss functions in use).

        Args:
            data_subset: Samples to average the gradient over.
            loss_fn: Called as ``loss_fn(model_output, labels)``.
            batch_size: Loader batch size. Defaults to the subset size, capped
                at a dataset-dependent maximum.
            device: Device passed to ``process_batch``.

        Returns:
            The averaged gradient as produced by
            ``flatten_params(model, gradients=True)``.

        Raises:
            ValueError: If ``data_subset`` is empty.
        """
        subset_size = len(data_subset)
        if subset_size == 0:
            raise ValueError("Cannot compute training gradients for an empty data subset")

        hf_collated = self._uses_hf_collator()
        if batch_size is None:
            max_batch = self._MAX_GRAD_BATCH_HF if hf_collated else self._MAX_GRAD_BATCH_DEFAULT
            batch_size = min(subset_size, max_batch)

        collate_fn = default_data_collator if hf_collated else default_collate
        dataloader = DataLoader(data_subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
        out_grads = None
        for batch in dataloader:
            self.model.zero_grad()
            inputs, labels = process_batch(batch, device=device)
            objective = loss_fn(self.model(**inputs), labels)
            objective.backward()
            # Weight by the batch size so that a smaller final batch does not
            # get over-represented in the average.
            gradients = flatten_params(self.model, gradients=True) * labels.size(0)

            if out_grads is None:
                out_grads = gradients
            else:
                out_grads += gradients
            torch.cuda.empty_cache()

        return out_grads / subset_size

    def _compute_inverse_hvp(self, loss_fn, vec, damp, repeat, depth, scale, batch_size, device, verbose=True):
        """
        Estimate ``H^-1 @ vec`` with the LiSSA recursion.

        One pass starts from ``h = vec`` and, for every sampled batch, applies
        ``h <- h + vec - (H_batch @ h + damp * h) / scale``; the pass estimate
        is ``h / scale``. The estimates of ``repeat`` passes are averaged.
        Each pass draws ``depth`` samples with replacement, grouped into
        batches of ``batch_size``.

        The model parameters are removed from the model for the duration of
        the computation and are always re-registered afterwards, including
        when an exception is raised.

        Args:
            loss_fn: Called as ``loss_fn(model_output, labels)``.
            vec: Flat vector to multiply with the inverse Hessian.
            damp: Damping added to every Hessian-vector product.
            repeat: Number of passes to average (must be at least 1).
            depth: Number of samples drawn per pass.
            scale: Scaling factor of the recursion.
            batch_size: Batch size of the sampling loader.
            device: Device batches are moved to.
            verbose: When truthy, print the time taken by every HVP.

        Returns:
            The averaged estimate, a tensor shaped like ``vec``.

        Raises:
            ValueError: If ``repeat`` is smaller than 1.
        """
        # Validate and build the loader *before* touching the model, so an
        # invalid argument cannot leave the model without its parameters.
        if repeat < 1:
            raise ValueError(f"repeat must be a positive integer, got {repeat}")

        sampler = RandomSampler(data_source=self.dataset, replacement=True,
                                num_samples=depth)
        if self._uses_hf_collator():
            collate_fn = default_data_collator
        else:
            def collate_fn(samples):
                return tuple(x_.to(device) for x_ in default_collate(samples))
        train_loader = DataLoader(dataset=self.dataset, batch_size=batch_size,
                                  shuffle=False, sampler=sampler,
                                  collate_fn=collate_fn)

        params = self._model_make_functional()
        try:
            flat_params = parameters_to_vector(params)
            ihvp = 0.0

            # Repeat for stability
            for _ in range(repeat):

                # Sum across trials
                h_est = vec.clone()
                for batch in train_loader:
                    hvp_batch = self._compute_batch_hvp(loss_fn, flat_params, batch, vec=h_est,
                                                        device=device, verbose=verbose)
                    with torch.no_grad():
                        hvp_batch += damp * h_est
                        h_est += vec - hvp_batch / scale

                ihvp += h_est / scale
        finally:
            # Always hand the model back with real nn.Parameters, even if the
            # recursion failed half-way (e.g. out-of-memory in an HVP).
            with torch.no_grad():
                self._model_reinsert_params(params, register=True)

        return ihvp / repeat

    def _compute_batch_hvp(self, loss_fn, flat_params, batch, vec, device, verbose=True):
        """
        Compute the Hessian-vector product of the batch loss with ``vec``.

        The loss is expressed as a function of the flat parameter vector by
        re-inserting the reshaped vector into the (functionalized) model, and
        ``torch.autograd.functional.hvp`` differentiates through it.

        Args:
            loss_fn: Called as ``loss_fn(model_output, labels)``.
            flat_params: Flat vector of all trainable parameters.
            batch: Collated batch, passed to ``process_batch``.
            vec: Vector to multiply the Hessian with.
            device: Device passed to ``process_batch``.
            verbose: When truthy, print the time taken by the HVP.

        Returns:
            The Hessian-vector product, a tensor shaped like ``vec``.
        """
        import time  # local import: only needed for the timing report below

        def f(theta_):
            reshaped_theta = self._reshape_like_params(theta_)
            self._model_reinsert_params(reshaped_theta)
            inputs, labels = process_batch(batch, device=device)

            # Restrict scaled-dot-product attention to the math kernel
            # (presumably because the fused kernels do not support the double
            # backward needed for an HVP -- TODO confirm).
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
                loss = loss_fn(self.model(**inputs), labels)
            return loss

        # perf_counter is monotonic, unlike time.time(), so the measured
        # interval cannot be distorted by wall-clock adjustments.
        start_time = time.perf_counter()
        hvp = torch.autograd.functional.hvp(f, flat_params, v=vec)[1]
        hvp_time = time.perf_counter() - start_time

        if verbose:
            print(f"HVP computation time: {hvp_time:.4f} seconds")

        return hvp

    def _model_make_functional(self):
        """
        Strip the trainable parameters from the model and return them.

        Returns:
            Tuple of detached copies (sharing storage, ``requires_grad=True``)
            of the trainable parameters, in ``self.params_names`` order. After
            this call the attributes named in ``self.params_names`` no longer
            exist on the model until ``_model_reinsert_params`` is called.
        """
        params = tuple(p.detach().requires_grad_() for p in self._model_params(False))
        for name in self.params_names:
            del_attr(self.model, name.split("."))
        return params

    def _model_reinsert_params(self, params, register=False):
        """
        Set ``params`` on the model under the names in ``self.params_names``.

        Args:
            params: Tensors in ``self.params_names`` order.
            register: When true, wrap every tensor in ``torch.nn.Parameter``
                (restoring a regular model); otherwise set the plain tensors
                so that autograd can differentiate through them.
        """
        for name, p in zip(self.params_names, params):
            set_attr(self.model, name.split("."), torch.nn.Parameter(p) if register else p)

    def _model_params(self, with_names=True):
        """
        Return the model parameters that have ``requires_grad`` set.

        Args:
            with_names: When true, return ``(name, parameter)`` pairs;
                otherwise return the parameters only.
        """
        return tuple((name, p) if with_names else p for name, p in self.model.named_parameters() if p.requires_grad)

    def _reshape_like_params(self, vec):
        """
        Split the flat vector ``vec`` into views shaped like the parameters.

        Consecutive slices of ``vec`` are taken in ``self.params_shape`` order.

        Returns:
            Tuple of tensors, one per entry of ``self.params_shape``.
        """
        pointer = 0
        split_tensors = []
        for dim in self.params_shape:
            num_param = dim.numel()
            split_tensors.append(vec[pointer: pointer + num_param].view(dim))
            pointer += num_param
        return tuple(split_tensors)