import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.func import vmap


from trak import TRAKAttributor
from if_fim import IFFIMAttributor
from dattri.metric import lds
from dattri.benchmark.load import load_benchmark
from dattri.task import AttributionTask
from dattri.benchmark.utils import SubsetSampler

# Device the model is moved to and that is handed to the attributors.
device = "cuda"
# Selects which attributor class is built further down in this script.
use_IFFIM = True  # NOTE: set to False to use TRAK
# When True, the regularization values are taken from quantiles of the kernel
# eigenvalues instead of the fixed grid.
use_baselines = False  # NOTE: set to True to employ baselines

# Only the first element of the benchmark result is kept; it is indexed below
# with the keys "model", "train_dataset", "test_dataset" and the sampler keys.
model_details = load_benchmark(model="resnet9", dataset="cifar2", metric="lds")[0]

# Replace the samplers with fixed index ranges (1000 train / 100 test).
# NOTE(review): presumably these sizes must match the ground-truth file loaded
# below ("size1000_100" in its path) -- confirm.
model_details["train_sampler"] = SubsetSampler(range(1000))
model_details["test_sampler"] = SubsetSampler(range(100))

# Paths to 50 checkpoint files; only ensemble_models[0] is used in this script.
ensemble_models = [
    # NOTE: pick different settings
    f"./cifar2_resnet9_1.0_seed0_size1000_100/checkpoints/{i}/model_weights_0.pt"
    for i in range(50)
]

# Ground truth passed to the LDS metric at the end of the script.
# NOTE(review): this directory ("0.5_seed100") differs from the checkpoint
# directory above ("1.0_seed0") -- confirm that the mismatch is intended.
# NOTE: pick different settings
groundtruth = torch.load("./cifar2_resnet9_0.5_seed100_size1000_100/ground_truth.pt")


# Both loaders share the same settings (batches of 250, no shuffling) and only
# differ in the dataset/sampler pair taken from model_details.
train_loader, test_loader = (
    DataLoader(dataset, sampler=sampler, batch_size=250, shuffle=False)
    for dataset, sampler in (
        (model_details["train_dataset"], model_details["train_sampler"]),
        (model_details["test_dataset"], model_details["test_sampler"]),
    )
)


def f_trak(params, data_target_pair):
    """Return log(p / (1 - p)) for one sample, p being the true-class probability.

    Args:
        params: Parameters handed to ``torch.func.functional_call`` for
            ``model_details["model"]``.
        data_target_pair: ``(image, label)`` for a single sample; a leading
            batch dimension of size 1 is added to both before the forward pass.

    Returns:
        ``log_p - log(1 - exp(log_p))`` where ``log_p`` is the negated
        cross-entropy of the sample.
    """
    sample, target = data_target_pair
    logits = torch.func.functional_call(
        model_details["model"], params, sample.unsqueeze(0)
    )
    # With a batch of one, the negated cross-entropy is log p of the true class.
    log_p = -nn.functional.cross_entropy(logits, target.unsqueeze(0))
    return log_p - torch.log(1 - log_p.exp())


def loss_trak(params, data_target_pair):
    """Return the cross-entropy loss of the model on a single sample.

    Args:
        params: Parameters handed to ``torch.func.functional_call`` for
            ``model_details["model"]``.
        data_target_pair: ``(image, label)`` for a single sample; a leading
            batch dimension of size 1 is added to both before the forward pass.

    Returns:
        The cross-entropy between the model output and the label.
    """
    sample, target = data_target_pair
    logits = torch.func.functional_call(
        model_details["model"], params, sample.unsqueeze(0)
    )
    return nn.functional.cross_entropy(logits, target.unsqueeze(0))


# Two attribution tasks over the same model and the first checkpoint; they
# differ only in the per-sample function (f_trak for task_f, loss_trak for
# task_l).
task_f, task_l = (
    AttributionTask(
        loss_func=per_sample_fn,
        model=model_details["model"].to(device),
        checkpoints=ensemble_models[0],
    )
    for per_sample_fn in (f_trak, loss_trak)
)


def m_trak(params, image_label_pair):
    """Return the probability the model assigns to the true class of one sample.

    Args:
        params: Parameters handed to ``torch.func.functional_call`` for
            ``model_details["model"]``.
        image_label_pair: ``(image, label)`` for a single sample; a leading
            batch dimension of size 1 is added to both, and the label is cast
            to ``long`` before the loss is evaluated.

    Returns:
        ``exp(-cross_entropy)`` for the sample.
    """
    sample, target = image_label_pair
    logits = torch.func.functional_call(
        model_details["model"], params, sample.unsqueeze(0)
    )
    neg_log_p = nn.functional.cross_entropy(logits, target.unsqueeze(0).long())
    return torch.exp(-neg_log_p)


# Settings forwarded to the attributors as ``projector_kwargs``.
projector_kwargs = {
    "proj_dim": 4096,  # NOTE: set projection dim; None means no projection
    # Reuse the module-level device instead of a second hard-coded "cuda", so
    # the projector follows the attributor when `device` is changed.
    "device": device,
}


def calc_eigs():
    """Return the eigenvalues of the first cached kernel of the attributor.

    Builds the attributor selected by ``use_IFFIM`` with a fixed
    regularization of 1e-2, caches it on ``train_loader`` under
    ``torch.no_grad()`` and passes ``attributor.kernels[0]`` to
    ``torch.linalg.eigvalsh``.
    """
    # Arguments common to both attributor classes.
    shared = {
        "correct_probability_func": m_trak,
        "device": device,
        "projector_kwargs": projector_kwargs,
        "regularization": 1e-2,
    }
    if not use_IFFIM:
        attributor = TRAKAttributor(task=task_f, **shared)
    else:
        attributor = IFFIMAttributor(task_f=task_f, task_l=task_l, **shared)

    with torch.no_grad():
        attributor.cache(train_loader)

    kernel = attributor.kernels[0]
    return torch.linalg.eigvalsh(kernel)


# Regularization values to sweep: a fixed grid by default, or the 0.5 / 0.7 /
# 0.9 quantiles of the kernel eigenvalues when baselines are enabled.
if not use_baselines:
    lambs = [0, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]
else:
    W = calc_eigs()
    lambs = [torch.quantile(W, level).item() for level in (0.5, 0.7, 0.9)]


def surrogate_indicator(j, g, ginv):
    """Return the ratio ``s2 / sqrt(s1 * s3)`` for one vector.

    ``s_k`` is ``(j @ ginv @ ... @ ginv) @ (g @ j)`` with ``ginv`` applied
    ``k`` times on the right of ``j``.

    Args:
        j: Vector of shape (p,).
        g: Matrix of shape (p, p).
        ginv: Matrix of shape (p, p).

    Returns:
        A scalar tensor; NaN when ``s1 * s3`` is negative.
    """
    g_times_j = g @ j
    left = j
    products = []
    # Each pass applies one more factor of ginv before contracting with g @ j.
    for _ in range(3):
        left = left @ ginv
        products.append(left @ g_times_j)
    first, second, third = products
    return second / torch.sqrt(first * third)


# Sweep over the regularization values: for each one, build a fresh attributor,
# compute attribution scores, and report the surrogate indicator and the LDS.
for d in lambs:
    attributor = (
        IFFIMAttributor(
            regularization=d,
            projector_kwargs=projector_kwargs,
            device=device,
            correct_probability_func=m_trak,
            task_f=task_f,
            task_l=task_l,
        )
        if use_IFFIM
        else TRAKAttributor(
            regularization=d,
            projector_kwargs=projector_kwargs,
            device=device,
            correct_probability_func=m_trak,
            task=task_f,
        )
    )

    with torch.no_grad():
        attributor.cache(train_loader)
        score = attributor.attribute(test_loader)

    # Evaluate the indicator once per row of test_grads[0]; the kernel and its
    # inverse are shared across rows. NaNs are zeroed in all three inputs first.
    res = vmap(surrogate_indicator, in_dims=(0, None, None))(
        attributor.test_grads[0].nan_to_num(),
        attributor.kernels[0].nan_to_num(),
        attributor.inv_kernels[0].nan_to_num(),
    )
    # Positive infinities are replaced by 1; NaNs are skipped by nanmean.
    res[torch.isposinf(res)] = 1
    print(f"Surrogate indicator (lambda = {d}): {torch.nanmean(res)}")

    # Keep only the first element returned by lds and average its non-NaN
    # entries.
    metric_score = lds(score, groundtruth)[0]
    metric_score = metric_score[~metric_score.isnan()].mean()

    print(f"LDS (lambda = {d}): {metric_score}")
