import argparse
from pathlib import PosixPath

import torch
from torch import nn
from torch.utils.data import DataLoader
import random
import sys
import os
import csv


from dattri.algorithm.trak import TRAKAttributor
from dattri.benchmark.load import load_benchmark
from dattri.metric import lds
from dattri.task import AttributionTask
from dattri.benchmark.datasets.cifar import train_cifar_resnet9
from dattri.benchmark.utils import SubsetSampler

def trak_margin(logits, labels):
    """Return the per-sample margin ``log(p / (1 - p))`` of the true class.

    The margin is computed directly from the logits as
    ``z_y - logsumexp_{j != y}(z_j)``. This is algebraically equal to
    ``logp - log(1 - exp(logp))``, where ``logp`` is the log-softmax of the
    true class. The logit form stays finite for very confident predictions.
    There, float32 cross-entropy rounds to exactly 0, so ``exp(logp) == 1``
    and the naive form yields ``inf`` with NaN gradients.

    Args:
        logits: Tensor of shape ``(batch, num_classes)``.
        labels: Integer tensor of shape ``(batch,)`` holding true class ids.

    Returns:
        Tensor of shape ``(batch,)`` with one margin per sample.
    """
    # Boolean (batch, num_classes) mask marking each sample's true class.
    # It is built with a comparison (no in-place writes, no data-dependent
    # indexing) so the function is intended to stay usable inside
    # torch.func transforms such as vmap/grad.
    is_target = (
        torch.arange(logits.shape[-1], device=logits.device)
        == labels.unsqueeze(-1)
    )
    correct_logit = torch.where(
        is_target, logits, torch.zeros_like(logits)
    ).sum(dim=-1)
    # Excluding the true class with -inf makes logsumexp run over j != y only.
    other_logits = logits.masked_fill(is_target, float("-inf"))
    return correct_logit - torch.logsumexp(other_logits, dim=-1)


def snapshot_state_dict(module):
    """Return an independent deep copy of ``module.state_dict()``.

    ``nn.Module.state_dict()`` returns references to the module's live
    parameter and buffer tensors. A stored state dict would therefore change
    whenever the module's weights are later overwritten in place, for example
    by ``load_state_dict``, which copies into the existing tensors.
    """
    import copy  # local import keeps this helper self-contained

    return copy.deepcopy(module.state_dict())


def append_csv_row(path, header, row):
    """Append ``row`` to the CSV file at ``path``.

    ``header`` is written first when the file does not exist yet or exists
    but is empty.
    """
    needs_header = not os.path.isfile(path) or os.path.getsize(path) == 0
    with open(path, mode="a", newline="") as csv_file:
        writer = csv.writer(csv_file)
        if needs_header:
            writer.writerow(header)
        writer.writerow(row)


if __name__ == "__main__":
    # Run TRAK on the dattri ResNet-9 / CIFAR-2 LDS benchmark:
    # 1. Train several models on random half subsets of the training set.
    # 2. Use them as TRAK checkpoints.
    # 3. Report the mean LDS (ignoring NaN entries).
    # 4. Append the result to a CSV file.
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cpu", type=str)
    # NOTE(review): --test_type is parsed but not used anywhere below.
    parser.add_argument("--test_type", default="epoch", type=str)
    parser.add_argument("--regularization", type=float, default=0.0)
    parser.add_argument("--proj_dim", type=int, default=512)
    parser.add_argument("--epoch_num", type=int, default=50)
    parser.add_argument(
        "--num_checkpoints",
        type=int,
        default=10,
        help="number of half-subset models used as TRAK checkpoints",
    )
    parser.add_argument(
        "--result_path",
        type=str,
        default="grid.csv",
        help="CSV file the result row is appended to",
    )
    args = parser.parse_args()

    # Load pre-trained benchmark
    model_details, groundtruth = load_benchmark(
        model="resnet9", dataset="cifar2", metric="lds"
    )

    train_loader = DataLoader(
        model_details["train_dataset"],
        batch_size=64,
        sampler=model_details["train_sampler"],
    )
    print("train dataset size:", len(model_details["train_sampler"]))

    test_loader = DataLoader(
        model_details["test_dataset"],
        batch_size=64,
        sampler=model_details["test_sampler"],
    )
    print("test dataset size:", len(model_details["test_sampler"]))

    # Only the architecture of ``model`` matters to f and m below: they
    # evaluate it through functional_call with the params they are given.
    # ``model`` is rebound by the training loop before they are used.
    model = model_details["model"].to(args.device)

    def f(params, data_target_pair):
        """Return the TRAK output function for one (image, label) sample.

        The output is the true-class margin log(p / (1 - p)).
        """
        image, label = data_target_pair
        yhat = torch.func.functional_call(model, params, image.unsqueeze(0))
        # The single-sample batch gives a (1,)-shaped margin; reduce it
        # to a 0-d scalar.
        return trak_margin(yhat, label.unsqueeze(0)).sum()

    def m(params, image_label_pair):
        """Return the probability assigned to the true class, exp(-CE)."""
        image, label = image_label_pair
        image_t = image.unsqueeze(0)
        label_t = label.unsqueeze(0)
        loss = nn.CrossEntropyLoss()
        yhat = torch.func.functional_call(model, params, image_t)
        p = torch.exp(-loss(yhat, label_t))
        return p

    # Half subsets are drawn from the benchmark's own training indices.
    # Sorting makes the seeded draws independent of the sampler's iteration
    # order.
    # NOTE(review): assumes the sampler yields dataset indices (as an
    # index-list sampler does) — confirm for other sampler types.
    train_pool = sorted(int(idx) for idx in model_details["train_sampler"])
    subset_size = len(train_pool) // 2

    half_model_checkpoints = []
    for i in range(args.num_checkpoints):
        random.seed(i)
        train_index = random.sample(train_pool, subset_size)
        sub_train_sampler = SubsetSampler(train_index)
        print("sub train sampler size:", len(sub_train_sampler))
        sub_train_loader = DataLoader(
            model_details["train_dataset"],
            batch_size=64,
            sampler=sub_train_sampler,
        )
        model = train_cifar_resnet9(
            sub_train_loader,
            seed=i,
            device=args.device,
            num_epochs=args.epoch_num,
        )
        # Store a copy rather than the live state dict. The last trained
        # ``model`` becomes the task model below; if the task restores
        # checkpoints into it in place (presumably via load_state_dict),
        # an aliased entry would be silently overwritten.
        half_model_checkpoints.append(snapshot_state_dict(model))

    task = AttributionTask(
        model=model.to(args.device),
        loss_func=f,
        checkpoints=half_model_checkpoints,
    )

    PROJECTOR_KWARGS = {
        "proj_dim": args.proj_dim,
        "proj_max_batch_size": 32,
        "proj_seed": 0,
        "device": args.device,
        "use_half_precision": False,
    }

    attributor = TRAKAttributor(
        task=task,
        correct_probability_func=m,
        device=args.device,
        projector_kwargs=PROJECTOR_KWARGS,
        regularization=args.regularization,
    )

    with torch.no_grad():
        score = attributor.attribute(test_loader, train_loader)
    # lds(...) returns a tuple; its first element is presumably the
    # per-test-sample LDS. NaN entries are excluded from the mean.
    lds_score = lds(score, groundtruth)[0]
    mean_lds = torch.mean(lds_score[~torch.isnan(lds_score)])

    # Print final result with all parameters
    print(f"[Result] regularization={args.regularization}, proj_dim={args.proj_dim}, epoch_num={args.epoch_num} => LDS={mean_lds.item():.4f}")

    # Optional: Save results to CSV
    append_csv_row(
        args.result_path,
        ["regularization", "proj_dim", "epoch_num", "LDS"],
        [args.regularization, args.proj_dim, args.epoch_num, mean_lds.item()],
    )
