import argparse
from pathlib import PosixPath

import torch
from torch import nn
from torch.utils.data import DataLoader
import random
import sys
import os
import csv

from dattri.algorithm.trak import TRAKAttributor
from dattri.benchmark.load import load_benchmark
from dattri.metric import lds
from dattri.task import AttributionTask
from dattri.benchmark.datasets.mnist import train_mnist_mlp, create_mnist_dataset
from dattri.benchmark.utils import SubsetSampler

if __name__ == "__main__":
    # Command-line options: one invocation evaluates a single grid point.
    # NOTE(review): --test_type is parsed but not read anywhere in this script.
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--test_type", type=str, default="epoch")
    parser.add_argument("--regularization", default=0.0, type=float)
    parser.add_argument("--proj_dim", default=512, type=int)
    parser.add_argument("--epoch_num", default=20, type=int)
    args = parser.parse_args()

    # Pre-trained MNIST MLP benchmark together with its LDS ground truth.
    model_details, groundtruth = load_benchmark(model="mlp", dataset="mnist", metric="lds")

    # One loader per split, each driven by the benchmark's own sampler; the
    # split size is reported right after its loader is built.
    loaders = {}
    for split in ("train", "test"):
        split_sampler = model_details[f"{split}_sampler"]
        loaders[split] = DataLoader(
            model_details[f"{split}_dataset"],
            batch_size=64,
            sampler=split_sampler,
        )
        print(f"{split} dataset size:", len(split_sampler))
    train_loader, test_loader = loaders["train"], loaders["test"]

    model = model_details["model"].to(args.device)

    def f(params, data_target_pair):
        """Return the log-odds of the correct class for one sample.

        Mathematically this is ``log(p / (1 - p))`` with
        ``p = softmax(logits)[label]``, which simplifies exactly to
        ``logits[label] - logsumexp(logits[other classes])``. That form is
        evaluated here because it stays finite for confident predictions.
        The alternative ``logp - log(1 - exp(logp))`` hits ``log(0) = -inf``
        as soon as the cross-entropy rounds to 0. It then returns ``+inf``
        and yields non-finite gradients.

        Args:
            params: Parameter mapping substituted into the module-level
                ``model`` via ``torch.func.functional_call``.
            data_target_pair: ``(image, label)`` for a single sample; a batch
                dimension of 1 is added here.

        Returns:
            A 0-dim tensor holding the log-odds.
        """
        image, label = data_target_pair
        image_t = image.unsqueeze(0)
        label_t = label.unsqueeze(0)
        yhat = torch.func.functional_call(model, params, image_t)
        # Mark the correct class by comparison (not F.one_hot or in-place
        # indexing) so the function stays usable under torch.func.vmap.
        classes = torch.arange(yhat.shape[-1], device=yhat.device)
        is_correct = classes == label_t.unsqueeze(-1)
        correct_logit = yhat.gather(-1, label_t.unsqueeze(-1)).squeeze(-1)
        # Excluding the correct class with -inf makes logsumexp range over the
        # competing classes only, i.e. log(1 - p) + logsumexp(all logits).
        other_logsumexp = torch.logsumexp(
            yhat.masked_fill(is_correct, float("-inf")), dim=-1
        )
        return (correct_logit - other_logsumexp).sum()

    def m(params, image_label_pair):
        """Return the probability the model assigns to the true label.

        Used as TRAK's ``correct_probability_func``. It runs the module-level
        ``model`` with ``params`` swapped in on one ``(image, label)`` sample,
        adding a batch dimension of 1, and returns ``exp(-cross_entropy)`` as
        a 0-dim tensor.
        """
        x, y = image_label_pair
        logits = torch.func.functional_call(model, params, x.unsqueeze(0))
        nll = nn.CrossEntropyLoss()(logits, y.unsqueeze(0))
        return torch.exp(-nll)

    import copy

    # Retrain the MLP on NUM_CHECKPOINTS random half-size subsets of the
    # training pool and keep a snapshot of each model's weights; these
    # snapshots are the checkpoints handed to the attribution task.
    NUM_CHECKPOINTS = 10
    # NOTE(review): assumes the benchmark's training pool is indices
    # 0..TRAIN_POOL_SIZE-1 of model_details["train_dataset"] -- confirm against
    # model_details["train_sampler"].
    TRAIN_POOL_SIZE = 5000
    SUBSET_SIZE = TRAIN_POOL_SIZE // 2
    half_model_checkpoints = []
    for i in range(NUM_CHECKPOINTS):
        # Reseeding the global RNG with the checkpoint index makes each
        # subset reproducible across runs.
        random.seed(i)
        train_index = random.sample(range(TRAIN_POOL_SIZE), SUBSET_SIZE)
        sub_train_sampler = SubsetSampler(train_index)
        print("sub train sampler size:", len(sub_train_sampler))
        sub_train_loader = DataLoader(
            model_details["train_dataset"],
            batch_size=64,
            sampler=sub_train_sampler,
        )
        model = train_mnist_mlp(
            sub_train_loader, seed=i, device=args.device, epoch_num=args.epoch_num
        )
        # state_dict() returns tensors that share storage with the live
        # parameters. The last `model` trained here is the one later handed
        # to AttributionTask, so if checkpoints are ever loaded into it in
        # place (load_state_dict copies into the existing tensors), an
        # un-copied snapshot of that model would be silently overwritten.
        # deepcopy gives every snapshot its own storage and keeps the
        # state_dict's OrderedDict type and metadata.
        half_model_checkpoints.append(copy.deepcopy(model.state_dict()))

    # Settings forwarded as-is to the TRAK projector.
    PROJECTOR_KWARGS = dict(
        proj_dim=args.proj_dim,
        proj_max_batch_size=32,
        proj_seed=0,
        device=args.device,
        use_half_precision=False,
    )

    # `model` at this point is the last network produced by the retraining
    # loop; the task pairs it with the log-odds function `f` and every saved
    # half-data checkpoint.
    task = AttributionTask(
        model=model.to(args.device),
        loss_func=f,
        checkpoints=half_model_checkpoints,
    )

    attributor = TRAKAttributor(
        task=task,
        device=args.device,
        correct_probability_func=m,
        regularization=args.regularization,
        projector_kwargs=PROJECTOR_KWARGS,
    )

    # Attribution scores for the test set against the full training set.
    with torch.no_grad():
        score = attributor.attribute(test_loader, train_loader)

    # Compare against the LDS ground truth; the first element of lds()'s
    # result holds the LDS values, and NaN entries are left out of the mean.
    lds_score = lds(score, groundtruth)[0]
    keep = torch.logical_not(torch.isnan(lds_score))
    mean_lds = lds_score[keep].mean()

    # Report the grid point together with its score.
    print(
        f"[Result] regularization={args.regularization}, proj_dim={args.proj_dim}, "
        f"epoch_num={args.epoch_num} => LDS={mean_lds.item():.4f}"
    )

    # Optional: append this run's grid point and score to a CSV shared by all
    # grid runs.
    result_path = "grid.csv"
    # Use a distinct handle name: binding `f` here would shadow the loss
    # function `f` defined earlier in this script.
    with open(result_path, mode="a", newline="", encoding="utf-8") as csv_file:
        writer = csv.writer(csv_file)
        # Append mode positions the stream at end-of-file, so tell() == 0
        # means the file was just created or is empty; either way it needs the
        # header. Checking the open handle avoids a separate, racy existence
        # check.
        if csv_file.tell() == 0:
            writer.writerow(["regularization", "proj_dim", "epoch_num", "LDS"])
        writer.writerow(
            [args.regularization, args.proj_dim, args.epoch_num, mean_lds.item()]
        )
