from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Any, Callable, Dict, List, Optional, Union

    from dattri.task import AttributionTask


import torch
from torch.func import vmap
from tqdm import tqdm

from dattri.func.projection import random_project
from dattri.func.utils import _unflatten_params

from dattri.algorithm.base import BaseAttributor
from dattri.algorithm.utils import _check_shuffle

# Default keyword arguments that TRAKAttributor forwards to ``random_project``
# (see ``TRAKAttributor.calc_grad_p``). Per-instance overrides are merged on
# top of these; copy this mapping rather than mutating it in place.
DEFAULT_PROJECTOR_KWARGS: Dict[str, Any] = {
    # Target projection dimension; ``None`` skips projection entirely.
    "proj_dim": 512,
    "proj_max_batch_size": 32,
    "proj_seed": 0,
    # Device handed to the projector; independent of the attributor's own
    # ``device`` argument.
    "device": "cpu",
    "use_half_precision": False,
}


class TRAKAttributor(BaseAttributor):
    """TRAK-style data attributor.

    For every checkpoint returned by ``task.get_checkpoints()`` the attributor
    computes normalized (and optionally randomly projected) per-sample
    gradients of the training set ``X`` and of the test set ``G`` and forms

        G @ (X^T X + regularization * I)^{-1} @ X^T

    These per-checkpoint matrices are averaged over checkpoints and then
    multiplied column-wise by ``Q = 1 - p_correct`` of each training sample
    (also averaged over checkpoints). Scores are returned transposed, with
    shape ``(num_train, num_test)``.
    """

    def __init__(
        self,
        task: AttributionTask,
        correct_probability_func: Callable,
        projector_kwargs: Optional[Dict[str, Any]] = None,
        layer_name: Optional[Union[str, List[str]]] = None,
        device: str = "cpu",
        regularization: float = 0.0,
    ) -> None:
        """Initialize the attributor.

        Args:
            task: The attribution task supplying checkpoints, parameters, the
                model and the gradient functions.
            correct_probability_func: Function ``f(params, sample)``; it is
                vmapped over the batch dimension of the data. Presumably it
                returns the model's probability of the correct label; it is
                used to build ``Q = 1 - f``.
            projector_kwargs: Overrides merged on top of a copy of
                ``DEFAULT_PROJECTOR_KWARGS`` and forwarded to
                ``random_project``. ``proj_dim=None`` disables projection.
            layer_name: When not ``None``, parameters and gradient functions
                are requested for these layer(s) at every checkpoint.
            device: Device that tuple/list batches and ``Q`` are moved to.
            regularization: Value added to the diagonal of ``X^T X`` before
                it is inverted.
        """
        self.task = task
        # Copy so per-instance overrides never leak into the module-level
        # defaults (and thereby into every other TRAKAttributor instance).
        self.projector_kwargs = DEFAULT_PROJECTOR_KWARGS.copy()
        if projector_kwargs is not None:
            self.projector_kwargs.update(projector_kwargs)
        self.layer_name = layer_name
        self.device = device
        self.grad_target_func = self.task.get_grad_target_func(in_dims=(None, 0))
        self.grad_loss_func = self.task.get_grad_loss_func(in_dims=(None, 0))
        self.correct_probability_func = vmap(
            correct_probability_func,
            in_dims=(None, 0),
            randomness="different",
        )
        self.full_train_dataloader = None
        self.regularization = regularization

    def calc_grad_p(
        self, grad_t: torch.Tensor, batch_size: int, ckpt_idx: int
    ) -> torch.Tensor:
        """Normalize and (optionally) randomly project a batch of gradients.

        Args:
            grad_t: Per-sample gradients; the first dimension is the batch.
            batch_size: Kept for backward compatibility only; the batch size
                is always re-derived from ``grad_t.shape[0]``.
            ckpt_idx: Checkpoint index, passed to the projector as
                ``ensemble_id``.

        Returns:
            ``nan_to_num(grad_t) / sqrt(num_train)``. When ``proj_dim`` is not
            ``None``, that tensor is further divided by ``sqrt(proj_dim)`` and
            passed through ``random_project``.
        """
        grad_t = torch.nan_to_num(grad_t) / (self.num_train**0.5)
        proj_dim = self.projector_kwargs["proj_dim"]
        if proj_dim is None:
            return grad_t
        grad_t /= proj_dim**0.5
        batch_size = grad_t.shape[0]
        projector = random_project(grad_t, batch_size, **self.projector_kwargs)
        return projector(grad_t, ensemble_id=ckpt_idx).clone().detach()

    @staticmethod
    def _count_samples(dataloader: torch.utils.data.DataLoader) -> int:
        """Count the samples produced by one full pass over ``dataloader``.

        Tuple/list batches are sized by their first element; any other batch
        by ``len(batch)``.
        """
        total = 0
        for batch in dataloader:
            first = batch[0] if isinstance(batch, (tuple, list)) else batch
            total += len(first)
        return total

    def _to_device(self, batch: Any) -> Any:
        """Move each element of a tuple/list batch to ``self.device``.

        Batches of any other type are returned unchanged.
        """
        if isinstance(batch, (tuple, list)):
            return tuple(data.to(self.device) for data in batch)
        return batch

    def _load_checkpoint(self, ckpt_idx: int) -> tuple:
        """Fetch parameters for ``ckpt_idx`` and refresh the gradient functions.

        Returns:
            ``(parameters, full_parameters)``: ``parameters`` is requested with
            ``self.layer_name``; ``full_parameters`` without a layer filter.
        """
        parameters, _ = self.task.get_param(
            ckpt_idx=ckpt_idx,
            layer_name=self.layer_name,
        )
        full_parameters, _ = self.task.get_param(ckpt_idx=ckpt_idx)
        if self.layer_name is not None:
            self.grad_target_func = self.task.get_grad_target_func(
                in_dims=(None, 0),
                layer_name=self.layer_name,
                ckpt_idx=ckpt_idx,
            )
            self.grad_loss_func = self.task.get_grad_loss_func(
                in_dims=(None, 0),
                layer_name=self.layer_name,
                ckpt_idx=ckpt_idx,
            )
        return parameters, full_parameters

    def _project_train_set(
        self,
        dataloader: torch.utils.data.DataLoader,
        parameters: Any,
        full_parameters: Any,
        ckpt_idx: int,
    ) -> tuple:
        """Compute projected training gradients and ``Q`` for one checkpoint.

        Returns:
            ``(grads, q)``: the concatenated output of :meth:`calc_grad_p` for
            every training batch, and ``1 - p_correct`` flattened to one value
            per sample.
        """
        # Loop-invariant: the unflattened parameters do not depend on the batch.
        model_params = _unflatten_params(full_parameters, self.task.get_model())
        grads = []
        q_parts = []
        for train_data in tqdm(
            dataloader,
            desc="calculating gradient of training set...",
            leave=False,
        ):
            train_batch_data = self._to_device(train_data)
            grad_t = self.grad_loss_func(parameters, train_batch_data)
            batch_size = grad_t.shape[0]
            grads.append(self.calc_grad_p(grad_t, batch_size, ckpt_idx))
            # Flatten so a (batch, 1) probability output cannot broadcast
            # against the (batch,) ones vector into a (batch, batch) matrix.
            probs = self.correct_probability_func(
                model_params,
                train_batch_data,
            ).flatten()
            q_parts.append(
                (torch.ones(batch_size, device=self.device) - probs)
                .clone()
                .detach(),
            )
        return torch.cat(grads, dim=0), torch.cat(q_parts, dim=0)

    def _project_test_set(
        self,
        dataloader: torch.utils.data.DataLoader,
        parameters: Any,
        ckpt_idx: int,
    ) -> torch.Tensor:
        """Return the concatenated projected target gradients of the test set."""
        grads = []
        for test_data in tqdm(
            dataloader,
            desc="calculating gradient of test set...",
            leave=False,
        ):
            test_batch_data = self._to_device(test_data)
            grad_t = self.grad_target_func(parameters, test_batch_data)
            grads.append(self.calc_grad_p(grad_t, grad_t.shape[0], ckpt_idx))
        return torch.cat(grads, dim=0)

    def cache(
        self,
        full_train_dataloader: torch.utils.data.DataLoader,
    ) -> None:
        """Precompute the training-side quantities for every checkpoint.

        After this call, :meth:`attribute` can be run without a training
        loader. Sets ``num_train``, ``full_train_dataloader``, ``train_grads``,
        ``kernels`` (``X^T X`` before regularization), ``inv_kernels`` (inverse
        of the regularized kernel), ``inv_XTX_XT_list`` (``inv_kernel @ X^T``
        per checkpoint) and ``Q`` (``1 - p_correct`` averaged over checkpoints).

        Args:
            full_train_dataloader: Loader over the full training set; it is
                validated by ``_check_shuffle`` (presumably rejects shuffling).
        """
        _check_shuffle(full_train_dataloader)

        self.num_train = self._count_samples(full_train_dataloader)

        self.kernels = []
        self.inv_kernels = []
        self.train_grads = []

        self.full_train_dataloader = full_train_dataloader
        inv_XTX_XT_list = []
        running_q = 0
        for ckpt_idx in range(len(self.task.get_checkpoints())):
            parameters, full_parameters = self._load_checkpoint(ckpt_idx)
            train_grads, q = self._project_train_set(
                full_train_dataloader,
                parameters,
                full_parameters,
                ckpt_idx,
            )

            self.train_grads.append(train_grads.clone().nan_to_num())

            kernel_matrix = train_grads.T @ train_grads
            # Store the unregularized kernel before adding to its diagonal.
            self.kernels.append(kernel_matrix.clone())
            kernel_matrix.diagonal().add_(self.regularization)

            inv_kernel = torch.linalg.inv(kernel_matrix)
            self.inv_kernels.append(inv_kernel)
            inv_XTX_XT_list.append(inv_kernel @ train_grads.T)

            # Running mean of Q over the checkpoints processed so far.
            running_q = (running_q * ckpt_idx + q) / (ckpt_idx + 1)
        self.inv_XTX_XT_list = inv_XTX_XT_list
        self.Q = running_q

    def attribute(  # noqa: PLR0912, PLR0914, PLR0915
        self,
        test_dataloader: torch.utils.data.DataLoader,
        train_dataloader: Optional[torch.utils.data.DataLoader] = None,
    ) -> torch.Tensor:
        """Compute attribution scores of training samples for test samples.

        Exactly one training source must be available: either a loader cached
        with :meth:`cache`, or ``train_dataloader`` passed here (not both).

        Args:
            test_dataloader: Loader over the test samples.
            train_dataloader: Optional training loader; only allowed when
                :meth:`cache` has not been called.

        Returns:
            Tensor of shape ``(num_train, num_test)``.

        Raises:
            ValueError: If both or neither training sources are provided.
        """
        _check_shuffle(test_dataloader)
        if train_dataloader is not None:
            _check_shuffle(train_dataloader)

        self.test_grads = []

        if train_dataloader is not None and self.full_train_dataloader is not None:
            message = (
                "You have cached a training loader by .cache() and you are "
                "trying to attribute a different training loader. If this new "
                "training loader is a subset of the cached training loader, "
                "please don't input the training dataloader in .attribute() "
                "and directly use index to select the corresponding scores."
            )
            raise ValueError(message)
        if train_dataloader is None and self.full_train_dataloader is None:
            message = (
                "You did not state a training loader in .attribute() and you "
                "did not cache a training loader by .cache(). Please provide a "
                "training loader or cache a training loader."
            )
            raise ValueError(message)

        if train_dataloader is not None:
            # calc_grad_p normalizes by num_train, which .cache() would
            # otherwise have set; derive it from the loader given here.
            self.num_train = self._count_samples(train_dataloader)

        running_scores = 0
        running_q = 0
        for ckpt_idx in range(len(self.task.get_checkpoints())):
            parameters, full_parameters = self._load_checkpoint(ckpt_idx)

            if train_dataloader is not None:
                train_grads, q = self._project_train_set(
                    train_dataloader,
                    parameters,
                    full_parameters,
                    ckpt_idx,
                )

            test_grads = self._project_test_set(
                test_dataloader,
                parameters,
                ckpt_idx,
            )
            self.test_grads.append(test_grads)

            if train_dataloader is not None:
                kernel_matrix = train_grads.T @ train_grads
                kernel_matrix.diagonal().add_(self.regularization)
                ckpt_scores = test_grads @ (
                    torch.linalg.inv(kernel_matrix) @ train_grads.T
                )
                running_q = (running_q * ckpt_idx + q) / (ckpt_idx + 1)
            else:
                ckpt_scores = test_grads @ self.inv_XTX_XT_list[ckpt_idx]

            # Running mean over the checkpoints processed so far.
            running_scores = (running_scores * ckpt_idx + ckpt_scores) / (
                ckpt_idx + 1
            )

        final_q = running_q if train_dataloader is not None else self.Q
        return (running_scores * final_q.to(self.device).unsqueeze(0)).T
