import argparse
import itertools
import csv

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from dattri.algorithm.tracin import TracInAttributor
from dattri.metric import lds
from dattri.benchmark.load import load_benchmark
from dattri.benchmark.datasets.mnist import create_mlp_model
from dattri.task import AttributionTask

def train_mnist_mlp_with_ckpts(
    dataloader: DataLoader,
    seed: int,
    device: str,
    total_epochs: int,
    save_last_k: int
):
    """Train the MNIST MLP and keep state-dict snapshots of the final epochs.

    The model comes from ``create_mlp_model("mnist")`` and is optimised with
    SGD (lr=0.01, momentum=0.9) on a cross-entropy loss.

    Args:
        dataloader: Yields ``(inputs, labels)`` batches used for every epoch.
        seed: Value used to seed torch (CPU and all CUDA devices), NumPy and
            Python's ``random`` module before the model is created.
        device: Device the model and each batch are moved to.
        total_epochs: Number of passes over ``dataloader``.
        save_last_k: A snapshot is stored after every epoch whose 1-based
            index is greater than ``total_epochs - save_last_k``.

    Returns:
        A ``(model, ckpts)`` tuple: the trained model (still on ``device``)
        and a list of state dicts, oldest first, whose tensors are cloned
        CPU copies.
    """
    import random as pyrandom

    import numpy as np

    # Seed every random number generator this script touches.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    pyrandom.seed(seed)

    net = create_mlp_model("mnist").to(device)
    loss_fn = nn.CrossEntropyLoss()
    sgd = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

    # Epochs with an index at or below this threshold are not checkpointed.
    last_unsaved_epoch = total_epochs - save_last_k

    snapshots = []
    for epoch in range(1, total_epochs + 1):
        net.train()
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            sgd.zero_grad()
            batch_loss = loss_fn(net(batch_inputs), batch_labels)
            batch_loss.backward()
            sgd.step()

        if epoch <= last_unsaved_epoch:
            continue

        # Clone onto the CPU so later optimiser steps cannot alter the
        # stored weights.
        snapshot = {
            name: tensor.cpu().clone()
            for name, tensor in net.state_dict().items()
        }
        snapshots.append(snapshot)
        print(f"  checkpoint saved for epoch {epoch}")

    return net, snapshots

def loss_tracin(params, data_pair):
    """Return the cross-entropy loss of one sample under ``params``.

    The forward pass goes through ``torch.func.functional_call`` on the
    module-level ``model``, so that name must be defined (the ``__main__``
    section assigns it) before this function is called.

    Args:
        params: Parameter mapping handed to ``functional_call`` in place of
            the model's own parameters.
        data_pair: An ``(x, y)`` pair. Both get a leading dimension added
            via ``unsqueeze(0)``; presumably a single un-batched sample and
            its class label -- confirm against the caller.

    Returns:
        The mean-reduced cross-entropy loss tensor for the pair.
    """
    sample, target = data_pair

    # Add a leading batch dimension; the target is cast to int64 for the loss.
    batched_sample = sample.unsqueeze(0)
    batched_target = target.unsqueeze(0).long()

    logits = torch.func.functional_call(model, params, batched_sample)
    return nn.functional.cross_entropy(logits, batched_target)

if __name__ == "__main__":
    # Script entry point: train the MLP, keep the last checkpoints, then
    # evaluate TracIn over a grid of projection dimensions and gradient
    # normalisation settings, writing the mean LDS of each setting to a CSV.
    parser = argparse.ArgumentParser(
        description=(
            "Train an MNIST MLP (default 20 epochs), save the last "
            "checkpoints (default 10), grid-search TracIn"
        )
    )
    parser.add_argument("--device", type=str, default="cpu",
                        help="cpu or cuda")
    parser.add_argument("--total_epochs", type=int, default=20,
                        help="number of epochs to train (default: 20)")
    parser.add_argument("--save_last_k", type=int, default=10,
                        help="how many of the final epochs to save "
                             "(default: 10)")
    args = parser.parse_args()

    # Without at least one epoch and one saved checkpoint the checkpoint
    # list (and hence the TracIn weight list) would be empty, so fail early
    # with a clear message instead of deep inside the attributor.
    if args.total_epochs < 1:
        parser.error("--total_epochs must be at least 1")
    if args.save_last_k < 1:
        parser.error("--save_last_k must be at least 1")

    device = args.device

    model_details, groundtruth = load_benchmark(
        model="mlp", dataset="mnist", metric="lds"
    )

    full_loader = DataLoader(
        model_details["train_dataset"],
        batch_size=64,
        sampler=model_details["train_sampler"],
    )
    print("Training MLP for", args.total_epochs, "epochs...")

    # NOTE: `model` must remain a module-level name because loss_tracin
    # looks it up as a global.
    model, ckpts = train_mnist_mlp_with_ckpts(
        full_loader,
        seed=0,
        device=device,
        total_epochs=args.total_epochs,
        save_last_k=args.save_last_k,
    )

    # Training leaves the model in train mode. Switch to eval mode so that
    # any train-only layers (e.g. dropout) behave deterministically while
    # gradients are computed for attribution.
    # NOTE(review): whether this MLP contains such layers is not visible
    # here; eval() is a no-op for the forward pass if it does not.
    model.eval()

    task = AttributionTask(
        model=model,
        loss_func=loss_tracin,
        checkpoints=ckpts,
    )

    train_loader = DataLoader(
        model_details["train_dataset"],
        batch_size=64,
        shuffle=False,
        sampler=model_details["train_sampler"],
    )
    test_loader = DataLoader(
        model_details["test_dataset"],
        batch_size=64,
        shuffle=False,
        sampler=model_details["test_sampler"],
    )

    # Search grid: None means "no random projection".
    proj_dims = [None, 4096, 2048, 1024, 512]
    normalized_grads = [False, True]
    # One equal weight (1e-3) per saved checkpoint.
    weight_list = torch.ones(len(ckpts), device=device) * 1e-3

    results = []
    print("Running grid search...")
    for proj_dim, ng in itertools.product(proj_dims, normalized_grads):
        kwargs = {
            "task": task,
            "weight_list": weight_list,
            "normalized_grad": ng,
            "device": device,
        }
        if proj_dim is not None:
            kwargs["projector_kwargs"] = {
                "proj_dim": proj_dim,
                "proj_max_batch_size": 32,
                "proj_seed": 0,
                "device": device,
                "use_half_precision": False,
            }

        attributor = TracInAttributor(**kwargs)
        with torch.no_grad():
            score = attributor.attribute(train_loader, test_loader)

        # Average the LDS values, ignoring NaN entries.
        lds_vals = lds(score, groundtruth)[0]
        mean_lds = torch.mean(lds_vals[~torch.isnan(lds_vals)]).item()

        tag = "NoProj" if proj_dim is None else str(proj_dim)
        print(f"[proj_dim={tag}, normalized_grad={ng}] → LDS={mean_lds:.4f}")
        results.append([tag, ng, mean_lds])

    out_csv = "tracin_mnist_ckpts_grid.csv"
    with open(out_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["proj_dim", "normalized_grad", "LDS"])
        writer.writerows(results)

    print(f"Grid search complete. Results saved to {out_csv}")
