import torch
from torch.utils.data import DataLoader
from torch.func import vmap


from trak import TRAKAttributor
from if_fim import IFFIMAttributor
from dattri.metric import lds
from dattri.benchmark.load import load_benchmark
from dattri.task import AttributionTask

from dattri.benchmark.models.MusicTransformer.utilities.constants import TOKEN_PAD

# Prefer the GPU, but fall back to CPU instead of failing at the first
# ``.to(device)`` call on machines without CUDA (CPU runs will be much slower).
device = "cuda" if torch.cuda.is_available() else "cpu"
use_IFFIM = True  # NOTE: set to False to use TRAK
use_baselines = False  # NOTE: set to True to employ baselines

# ``model_details`` bundles the model, datasets, training sampler and
# checkpoints ("model", "train_dataset", "test_dataset", "train_sampler",
# "models_full"); ``groundtruth`` is the LDS reference used for scoring.
model_details, groundtruth = load_benchmark(
    model="musictransformer", dataset="maestro", metric="lds"
)


# shuffle=False because the ordering/selection of training samples is
# delegated to the benchmark's ``train_sampler`` (DataLoader does not allow a
# custom sampler together with shuffle=True).
train_loader = DataLoader(
    model_details["train_dataset"],
    shuffle=False,
    batch_size=64,
    sampler=model_details["train_sampler"],
)
test_loader = DataLoader(
    model_details["test_dataset"],
    shuffle=False,
    batch_size=64,
)


def f_trak(params, data_target_pair):
    """Log-odds of the true final token, used as the attributed model output.

    Computes ``log(p / (1 - p))`` where ``p`` is the probability the model
    (evaluated with ``params``) assigns to the last target token.

    Args:
        params: parameter mapping accepted by ``torch.func.functional_call``
            for ``model_details["model"]``.
        data_target_pair: ``(x, y)`` for a single sample; a batch dimension
            of 1 is added here (presumably the caller vmaps over samples).

    Returns:
        The log-odds with the size-1 batch dimension squeezed away. A target
        equal to ``TOKEN_PAD`` is ignored by the loss (loss 0, so p = 1),
        which yields +inf.
    """
    x, y = data_target_pair
    x_t = x.unsqueeze(0)
    y_t = y.unsqueeze(0)
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=TOKEN_PAD, reduction="none")

    output = torch.func.functional_call(model_details["model"], params, x_t)
    output_last = output[:, -1, :]
    y_last = y_t[:, -1]

    logp = -loss_fn(output_last, y_last)
    # log(1 - exp(logp)) == log(-expm1(logp)). The expm1 form avoids
    # catastrophic cancellation when p is close to 1 (logp close to 0): there
    # ``1 - exp(logp)`` rounds to exactly 0 in float32 and would turn a finite
    # log-odds into +inf (and its gradient into inf/NaN).
    logit_func = logp - torch.log(-torch.expm1(logp))
    return logit_func.squeeze(0)


def loss_trak(params, data_target_pair):
    """Cross-entropy loss of the model's prediction at the final position.

    Args:
        params: parameter mapping accepted by ``torch.func.functional_call``
            for ``model_details["model"]``.
        data_target_pair: ``(x, y)`` for a single sample; a leading batch
            dimension of 1 is added before the forward pass.

    Returns:
        The per-sample loss with the size-1 batch dimension squeezed away.
        Targets equal to ``TOKEN_PAD`` are ignored (loss 0).
    """
    inputs, targets = data_target_pair
    criterion = torch.nn.CrossEntropyLoss(ignore_index=TOKEN_PAD, reduction="none")

    logits = torch.func.functional_call(model_details["model"], params, inputs[None])
    # Only the last sequence position is scored.
    per_sample = criterion(logits[:, -1], targets[None, -1])
    return per_sample.squeeze(0)


def correctness_p(params, data_target_pair):
    """Probability the model assigns to the true final token.

    Args:
        params: parameter mapping accepted by ``torch.func.functional_call``
            for ``model_details["model"]``.
        data_target_pair: ``(x, y)`` for a single sample; a leading batch
            dimension of 1 is added before the forward pass.

    Returns:
        ``exp(-cross_entropy)`` at the last position. Unlike ``f_trak`` and
        ``loss_trak`` the size-1 batch dimension is NOT squeezed. A target
        equal to ``TOKEN_PAD`` is ignored by the loss, giving probability 1.
    """
    inputs, targets = data_target_pair
    batched_inputs = inputs[None]
    last_target = targets[None, -1]

    logits = torch.func.functional_call(
        model_details["model"], params, batched_inputs
    )
    criterion = torch.nn.CrossEntropyLoss(ignore_index=TOKEN_PAD, reduction="none")
    nll = criterion(logits[:, -1], last_target)
    return (-nll).exp()


# Both attribution tasks share the same model (moved to ``device``) and the
# first entry of ``models_full`` as checkpoint. They differ only in the
# function being attributed: the log-odds ``f_trak`` vs. the loss ``loss_trak``.
_shared_task_kwargs = {
    "model": model_details["model"].to(device),
    "checkpoints": model_details["models_full"][0],
}
task_f = AttributionTask(loss_func=f_trak, **_shared_task_kwargs)
task_l = AttributionTask(loss_func=loss_trak, **_shared_task_kwargs)

# Gradient-projection settings forwarded to every attributor.
projector_kwargs = dict(
    proj_dim=4096,  # NOTE: set projection dim; None means no projection
    device=device,
    use_half_precision=False,
)


def calc_eigs():
    """Return the eigenvalues of the first cached kernel.

    Builds an ``IFFIMAttributor`` when ``use_IFFIM`` is set, otherwise a
    ``TRAKAttributor``, both with regularization 1e-2. It then caches the
    training loader and returns ``torch.linalg.eigvalsh`` of ``kernels[0]``
    (eigenvalues in ascending order).
    """
    shared = {
        "correct_probability_func": correctness_p,
        "device": device,
        "projector_kwargs": projector_kwargs,
        "regularization": 1e-2,
    }
    attributor = (
        IFFIMAttributor(task_f=task_f, task_l=task_l, **shared)
        if use_IFFIM
        else TRAKAttributor(task=task_f, **shared)
    )

    with torch.no_grad():
        attributor.cache(train_loader)

    # eigvalsh assumes a symmetric matrix (it reads only the lower triangle
    # by default).
    kernel = attributor.kernels[0]
    return torch.linalg.eigvalsh(kernel)


if not use_baselines:
    # Fixed, hand-picked sweep of regularization strengths.
    lambs = [0, 1e-2, 5e-2, 1e-1, 5e-1, 1e0, 5e0, 1e1, 5e1, 1e2]
else:
    # Baseline sweep: the 10/30/50/70/90% quantiles of the kernel's
    # eigenvalue spectrum, converted to Python floats.
    W = calc_eigs()
    lambs = [torch.quantile(W, level).item() for level in (0.1, 0.3, 0.5, 0.7, 0.9)]


def surrogate_indicator(j, g, ginv):
    """Normalized ratio of quadratic forms used as a surrogate quality score.

    With ``m_k = j @ ginv^k @ (g @ j)`` for k = 1, 2, 3, returns
    ``m_2 / sqrt(m_1 * m_3)``.

    When ``g`` and ``ginv`` are symmetric positive semidefinite and commute
    (e.g. ``ginv`` is the inverse of a shifted ``g``), Cauchy-Schwarz gives
    ``m_2**2 <= m_1 * m_3``, so the result lies in [0, 1]. That is an
    assumption about the inputs, not something checked here. A zero
    denominator yields inf/NaN.

    Args:
        j: vector of shape (p,).
        g: matrix of shape (p, p).
        ginv: matrix of shape (p, p).
    """
    g_j = g @ j
    row1 = j @ ginv  # j^T ginv
    row2 = row1 @ ginv  # j^T ginv^2
    row3 = row2 @ ginv  # j^T ginv^3

    numerator = row2 @ g_j
    denominator = ((row1 @ g_j) * (row3 @ g_j)).sqrt()
    return numerator / denominator


def _build_attributor(regularization):
    """Build the attributor selected by ``use_IFFIM`` with the given regularization."""
    shared = {
        "correct_probability_func": correctness_p,
        "device": device,
        "projector_kwargs": projector_kwargs,
        "regularization": regularization,
    }
    if use_IFFIM:
        return IFFIMAttributor(task_f=task_f, task_l=task_l, **shared)
    return TRAKAttributor(task=task_f, **shared)


for d in lambs:
    attributor = _build_attributor(d)

    with torch.no_grad():
        attributor.cache(train_loader)
        score = attributor.attribute(test_loader)

        # One surrogate-indicator value per test gradient (row of
        # test_grads[0]). nan_to_num zeroes NaNs and clamps infs in the inputs.
        # This runs under no_grad as well: nothing here needs autograd.
        res = vmap(surrogate_indicator, in_dims=(0, None, None))(
            attributor.test_grads[0].nan_to_num(),
            attributor.kernels[0].nan_to_num(),
            attributor.inv_kernels[0].nan_to_num(),
        )

    # A zero denominator with a positive numerator gives +inf; it is counted
    # as 1 (presumably the indicator's upper bound). NaNs (0/0) are skipped by
    # nanmean.
    res[res.isinf() & (res > 0)] = 1
    print(f"Surrogate indicator (lambda = {d}): {res.nanmean().item()}")

    # First element of lds(...) (presumably per-test-sample correlations);
    # NaN entries are excluded from the average.
    metric_score = lds(score, groundtruth)[0].nanmean()
    print(f"LDS (lambda = {d}): {metric_score.item()}")
