from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Any, Callable, Dict, List, Optional, Union

    from dattri.task import AttributionTask


import torch
from torch.func import vmap
from tqdm import tqdm

from dattri.func.projection import random_project
from dattri.func.utils import _unflatten_params

from dattri.algorithm.base import BaseAttributor
from dattri.algorithm.utils import _check_shuffle

# Default keyword arguments that ``IFFIMAttributor.calc_grad_p`` forwards to
# ``random_project``. Entries supplied through the constructor's
# ``projector_kwargs`` argument override the values below.
DEFAULT_PROJECTOR_KWARGS = {
    # ``None`` makes ``calc_grad_p`` skip the random projection; otherwise the
    # gradients are also divided by ``sqrt(proj_dim)`` before being projected.
    "proj_dim": 512,
    "proj_max_batch_size": 32,
    "proj_seed": 0,
    # NOTE(review): this is independent of the attributor's own ``device``
    # argument -- confirm whether the two are expected to match.
    "device": "cpu",
    "use_half_precision": False,
}


class IFFIMAttributor(BaseAttributor):
    """Attribute test samples to training samples through an inverse Gram kernel.

    For every checkpoint of ``task_f``, :meth:`cache` stacks the (optionally
    randomly projected) per-sample training gradients of ``task_l``'s loss into
    a matrix ``X`` and stores ``(X^T X + regularization * I)^{-1} X^T``.
    :meth:`attribute` multiplies the per-sample test gradients of ``task_f``'s
    target function with that matrix and averages the result over checkpoints.

    Attributes populated by :meth:`cache`:
        num_train: Number of training samples seen in the train dataloader.
        train_grads: Per-checkpoint processed training gradients.
        kernels: Per-checkpoint ``X^T X`` (before regularization).
        inv_kernels: Per-checkpoint inverse of the regularized kernel.
        inv_XTX_XT_list: Per-checkpoint ``inverse_kernel @ X^T``.

    Attributes populated by :meth:`attribute`:
        test_grads: Per-checkpoint processed test gradients.
    """

    def __init__(
        self,
        task_f: AttributionTask,
        task_l: AttributionTask,
        correct_probability_func: Callable,
        projector_kwargs: Optional[Dict[str, Any]] = None,
        layer_name: Optional[Union[str, List[str]]] = None,
        device: str = "cpu",
        regularization: float = 0.0,
    ) -> None:
        """Store the tasks and build the per-sample gradient functions.

        Args:
            task_f: Task providing the checkpoints, the parameters and the
                target function whose gradients are taken on test data.
            task_l: Task providing the loss function whose gradients are taken
                on training data.
            correct_probability_func: Callable that is wrapped with ``vmap``
                over its second argument. It is stored on the instance but not
                used by the methods of this class.
            projector_kwargs: Overrides for ``DEFAULT_PROJECTOR_KWARGS``. Set
                ``"proj_dim"`` to ``None`` to disable the random projection.
            layer_name: Layer name(s) to restrict the parameters and gradients
                to; ``None`` means no restriction is requested.
            device: Device that tuple/list batches are moved to.
            regularization: Value added to the diagonal of each kernel before
                it is inverted.
        """
        self.task_f = task_f
        self.task_l = task_l
        # Work on a copy: updating the module-level defaults in place would
        # leak this instance's overrides into every other instance.
        self.projector_kwargs = dict(DEFAULT_PROJECTOR_KWARGS)
        if projector_kwargs is not None:
            self.projector_kwargs.update(projector_kwargs)
        self.layer_name = layer_name
        self.device = device
        self.grad_target_func = self.task_f.get_grad_target_func(in_dims=(None, 0))
        self.grad_loss_func = self.task_l.get_grad_loss_func(in_dims=(None, 0))
        self.correct_probability_func = vmap(
            correct_probability_func,
            in_dims=(None, 0),
            randomness="different",
        )
        self.full_train_dataloader = None
        self.regularization = regularization

    def calc_grad_p(
        self, grad_t: torch.Tensor, batch_size: int, ckpt_idx: int
    ) -> torch.Tensor:
        """Scale a batch of per-sample gradients and optionally project it.

        NaNs are replaced via ``torch.nan_to_num`` and the result is divided by
        ``sqrt(num_train)``. When ``proj_dim`` is not ``None`` the gradients
        are additionally divided by ``sqrt(proj_dim)`` and passed through
        ``random_project``.

        Args:
            grad_t: Per-sample gradients; the first dimension indexes samples.
            batch_size: Ignored -- the batch size is always read from
                ``grad_t.shape[0]``. Kept for interface compatibility.
            ckpt_idx: Checkpoint index, forwarded to the projector as
                ``ensemble_id``.

        Returns:
            The processed gradients. The projected result is a detached copy.
        """
        grad_t = torch.nan_to_num(grad_t) / (self.num_train**0.5)
        proj_dim = self.projector_kwargs["proj_dim"]
        if proj_dim is None:
            return grad_t

        # In-place is safe: ``grad_t`` is the fresh tensor created above, not
        # the caller's tensor.
        grad_t /= proj_dim**0.5
        projector = random_project(
            grad_t,
            grad_t.shape[0],
            **self.projector_kwargs,
        )
        return projector(grad_t, ensemble_id=ckpt_idx).clone().detach()

    def _move_to_device(self, batch: Any) -> Any:
        """Move tuple/list batches to ``self.device``; pass others through."""
        if isinstance(batch, (tuple, list)):
            return tuple(item.to(self.device) for item in batch)
        return batch

    def _refresh_grad_funcs(self, ckpt_idx: int) -> None:
        """Rebuild both gradient functions for ``ckpt_idx`` if layers are set.

        With ``layer_name`` left as ``None`` the functions created in
        ``__init__`` are reused for every checkpoint.
        """
        if self.layer_name is None:
            return
        self.grad_target_func = self.task_f.get_grad_target_func(
            in_dims=(None, 0),
            layer_name=self.layer_name,
            ckpt_idx=ckpt_idx,
        )
        self.grad_loss_func = self.task_l.get_grad_loss_func(
            in_dims=(None, 0),
            layer_name=self.layer_name,
            ckpt_idx=ckpt_idx,
        )

    def _collect_projected_grads(
        self,
        dataloader: torch.utils.data.DataLoader,
        grad_func: Callable,
        parameters: Any,
        ckpt_idx: int,
        desc: str,
    ) -> torch.Tensor:
        """Run ``grad_func`` over ``dataloader`` and stack the processed rows."""
        grads = []
        for batch in tqdm(dataloader, desc=desc, leave=False):
            grad_t = grad_func(parameters, self._move_to_device(batch))
            grads.append(self.calc_grad_p(grad_t, grad_t.shape[0], ckpt_idx))
        return torch.cat(grads, dim=0)

    def cache(
        self,
        full_train_dataloader: torch.utils.data.DataLoader,
    ) -> None:
        """Precompute the per-checkpoint inverse-kernel products.

        Args:
            full_train_dataloader: Dataloader over the training set. It is
                validated with ``_check_shuffle`` and iterated once to count
                the samples, then once per checkpoint.
        """
        _check_shuffle(full_train_dataloader)

        # One full pass just to count samples: ``calc_grad_p`` needs the total
        # before the first gradient batch can be scaled.
        self.num_train = sum(
            len(batch[0]) if isinstance(batch, (tuple, list)) else len(batch)
            for batch in full_train_dataloader
        )

        self.kernels = []
        self.inv_kernels = []
        self.train_grads = []

        self.full_train_dataloader = full_train_dataloader
        inv_XTX_XT_list = []
        for ckpt_idx in range(len(self.task_f.get_checkpoints())):
            parameters, _ = self.task_f.get_param(
                ckpt_idx=ckpt_idx,
                layer_name=self.layer_name,
            )
            self._refresh_grad_funcs(ckpt_idx)

            train_grads = self._collect_projected_grads(
                self.full_train_dataloader,
                self.grad_loss_func,
                parameters,
                ckpt_idx,
                desc="calculating gradient of training set...",
            )
            self.train_grads.append(train_grads.clone().nan_to_num())

            kernel_matrix = train_grads.T @ train_grads
            # Keep the unregularized kernel; the diagonal is modified in place
            # right below.
            self.kernels.append(kernel_matrix.clone())
            kernel_matrix.diagonal().add_(self.regularization)

            inv_kernel = torch.linalg.inv(kernel_matrix)
            self.inv_kernels.append(inv_kernel)
            inv_XTX_XT_list.append(inv_kernel @ train_grads.T)
        self.inv_XTX_XT_list = inv_XTX_XT_list

    def attribute(
        self,
        test_dataloader: torch.utils.data.DataLoader,
    ) -> torch.Tensor:
        """Compute attribution scores for the samples in ``test_dataloader``.

        Args:
            test_dataloader: Dataloader over the test set, validated with
                ``_check_shuffle``.

        Returns:
            The negated, checkpoint-averaged product of the test gradients with
            the cached inverse-kernel matrices, transposed so that rows index
            training samples and columns index test samples.

        Raises:
            RuntimeError: If :meth:`cache` has not completed beforehand.
        """
        if getattr(self, "inv_XTX_XT_list", None) is None:
            msg = "No cached training statistics found; call cache() first."
            raise RuntimeError(msg)
        _check_shuffle(test_dataloader)

        self.test_grads = []

        running_xinv_XTX_XT = 0
        for ckpt_idx in range(len(self.task_f.get_checkpoints())):
            parameters, _ = self.task_f.get_param(
                ckpt_idx=ckpt_idx,
                layer_name=self.layer_name,
            )
            self._refresh_grad_funcs(ckpt_idx)

            test_grads = self._collect_projected_grads(
                test_dataloader,
                self.grad_target_func,
                parameters,
                ckpt_idx,
                desc="calculating gradient of test set...",
            )
            self.test_grads.append(test_grads)

            # Running mean over checkpoints: ``ckpt_idx`` equals the number of
            # checkpoints already folded into the average.
            running_xinv_XTX_XT = (
                running_xinv_XTX_XT * ckpt_idx
                + test_grads @ self.inv_XTX_XT_list[ckpt_idx]
            ) / (ckpt_idx + 1)
        return -running_xinv_XTX_XT.T
