"""This experiment benchmarks specified TDA methods on the MNIST-10 dataset
to collect LDS scores for different hyperparameter settings."""

# ruff: noqa

import argparse
import torch
from torch import nn
from torch.utils.data import DataLoader
import itertools # For LiSSA grid search
import random
import sys
import csv
import datetime

from dattri.algorithm.influence_function import (
    IFAttributorCG,
    IFAttributorLiSSA,
    IFAttributorArnoldi, # Kept for ATTRIBUTOR_DICT, but not used in this specific run
    IFAttributorExplicit,
)

from dattri.algorithm.tracin import TracInAttributor
from dattri.algorithm.trak import TRAKAttributor
from dattri.algorithm.rps import RPSAttributor

from dattri.metric import lds, loo_corr
from dattri.benchmark.load import load_benchmark
from dattri.task import AttributionTask
from dattri.benchmark.datasets.mnist import train_mnist_mlp
from dattri.benchmark.utils import SubsetSampler

# --- Hyperparameter grid: every method is evaluated at each training length ---

# Number of MLP training epochs; each value yields one configuration per method.
epoch_search_values = [50, 70, 100, 150]

# Fixed attributor keyword arguments per method. Only the epoch count varies
# across the grid; these values are forwarded unchanged to the attributor.
default_params = {
    "if-explicit": {
        "regularization": 1e-5,
        "layer_name": ["fc3.weight", "fc3.bias"],
    },
    "if-cg": {"regularization": 1e-2, "max_iter": 10},
    "if-lissa": {
        "damping": 1e-3,
        "scaling": 5,
        "recursion_depth": 1000,
        "batch_size": 50,
    },
}

# method name -> list of configs, each {"epoch_num": n, **default_params[method]},
# ordered as in epoch_search_values.
IHVP_SEARCH_SPACE_MODIFIED = {
    name: [{"epoch_num": n_epochs, **fixed} for n_epochs in epoch_search_values]
    for name, fixed in default_params.items()
}

# Registry: CLI method name -> dattri influence-function attributor class.
# "if-arnoldi" is registered but is not offered as a --method choice here.
ATTRIBUTOR_DICT = {
    "if-explicit": IFAttributorExplicit,
    "if-cg": IFAttributorCG,
    "if-lissa": IFAttributorLiSSA,
    "if-arnoldi": IFAttributorArnoldi,
}

# Registry: metric name -> dattri metric function (this run only uses "lds").
METRICS_DICT = {"lds": lds, "loo": loo_corr}

def set_seed(seed):
    """Seed Python's ``random`` and torch (CPU, plus all CUDA devices if present).

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def mean_ignoring_nan(values):
    """Return the mean of the non-NaN entries of a float tensor as a Python float.

    Args:
        values: Tensor of per-sample metric values; NaN entries are skipped.

    Returns:
        The mean of the non-NaN entries, or ``float("nan")`` when there are
        none (all entries NaN, or the tensor is empty).
    """
    valid = values[~torch.isnan(values)]
    if valid.numel() == 0:
        return float("nan")
    return torch.mean(valid).item()


def write_results_csv(csv_path, results, param_fields, metric_name):
    """Write one CSV row per grid-search result.

    Args:
        csv_path: Destination file path (overwritten if it already exists).
        results: List of dicts holding a ``"config"`` dict, a
            ``f"{metric_name}_score"`` entry, and an optional ``"error_message"``.
        param_fields: Config keys emitted as the leading columns, in order.
            Config keys not listed here are ignored.
        metric_name: Metric name used to build the score column header.

    Raises:
        OSError: If the file cannot be opened or written.
    """
    score_field = f"{metric_name}_score"
    fieldnames = [*param_fields, score_field, "error_message"]
    with open(csv_path, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, extrasaction="ignore")
        writer.writeheader()
        for result in results:
            # Sequence-valued params (e.g. layer_name) become their repr string.
            row = {
                key: str(value) if isinstance(value, (list, tuple)) else value
                for key, value in result["config"].items()
            }
            row[score_field] = result[score_field]
            row["error_message"] = result.get("error_message", "")
            writer.writerow(row)


if __name__ == "__main__":
    import os
    import traceback

    argparser = argparse.ArgumentParser(description="Benchmark TDA methods with specified grid searches.")
    argparser.add_argument(
        "--method",
        type=str,
        required=True,
        choices=["if-explicit", "if-cg", "if-lissa"],
        help="The TDA method to benchmark from the defined grid searches.",
    )
    argparser.add_argument(
        "--device", type=str, default="cuda", help="The device to run the experiment."
    )
    argparser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility."
    )
    argparser.add_argument(
        "--output_dir",
        type=str,
        default=".",
        help="Directory in which the results CSV is written (created if missing).",
    )

    # This benchmark is fixed to the MNIST / MLP / LDS setting.
    args_dataset = "mnist"
    args_model = "mlp"
    args_metric = "lds"

    args = argparser.parse_args()

    # Set seed for reproducibility
    set_seed(args.seed)

    print(f"Running experiment for method: {args.method}")
    print(f"Dataset: {args_dataset}, Model: {args_model}, Metric: {args_metric}")
    print(f"Device: {args.device}")
    print(f"Seed: {args.seed}")

    print("Loading benchmark data...")
    model_details, groundtruth = load_benchmark(
        model=args_model, dataset=args_dataset, metric=args_metric
    )
    print("Benchmark data loaded.")

    # Batch sizes from the original script for MNIST
    train_loader_cache_batch_size = 5000
    train_loader_batch_size = 64
    test_loader_batch_size = 500

    train_loader_cache = DataLoader(
        model_details["train_dataset"],
        shuffle=False,
        batch_size=train_loader_cache_batch_size,
        sampler=model_details["train_sampler"],
    )
    train_loader = DataLoader(
        model_details["train_dataset"],
        shuffle=False,
        batch_size=train_loader_batch_size,
        sampler=model_details["train_sampler"],
    )
    test_loader = DataLoader(
        model_details["test_dataset"],
        shuffle=False,
        batch_size=test_loader_batch_size,
        sampler=model_details["test_sampler"],
    )
    print("DataLoaders prepared.")

    def loss_if(params, data_target_pair):
        """Cross-entropy loss of the benchmark model run with the parameters ``params``.

        ``torch.func.functional_call`` evaluates ``model_details["model"]``'s
        forward pass with ``params`` substituted for its own parameters.
        """
        image, label = data_target_pair
        yhat = torch.func.functional_call(model_details["model"], params, image)
        return nn.CrossEntropyLoss()(yhat, label.long())

    all_results_for_method = []
    score_field = f"{args_metric}_score"
    metric_fn = METRICS_DICT[args_metric]

    if args.method in IHVP_SEARCH_SPACE_MODIFIED:
        search_space = IHVP_SEARCH_SPACE_MODIFIED[args.method]
        total_configs = len(search_space)
        print(f"Starting grid search for {args.method} with {total_configs} configurations.")

        for i, full_config in enumerate(search_space):
            current_epoch_num = full_config["epoch_num"]
            # Everything except the epoch count is an attributor constructor kwarg.
            attributor_params = {k: v for k, v in full_config.items() if k != "epoch_num"}

            print(f"\nRunning Configuration {i+1}/{total_configs}: Epochs={current_epoch_num}, Params={attributor_params}")

            # Re-seed so this configuration's result does not depend on the RNG
            # state left behind by the configurations that ran before it.
            set_seed(args.seed)

            try:
                # Train the model for the current number of epochs
                print(f"Training model for {current_epoch_num} epochs...")
                trained_model = train_mnist_mlp(
                    train_loader,
                    seed=args.seed,
                    device=args.device,
                    epoch_num=current_epoch_num,
                )
                print("Model training complete.")

                # Create AttributionTask with the newly trained model
                task = AttributionTask(
                    model=trained_model.to(args.device),
                    loss_func=loss_if,
                    checkpoints=trained_model.state_dict(),
                )

                attributor = ATTRIBUTOR_DICT[args.method](
                    task=task,
                    device=args.device,
                    **attributor_params,
                )

                print("Caching attributor...")
                attributor.cache(train_loader_cache)
                print("Caching complete.")

                print("Calculating scores...")
                with torch.no_grad():
                    score = attributor.attribute(train_loader, test_loader)
                print("Score calculation complete.")

                # Index 0 of the metric output is averaged (NaNs skipped) into one number.
                metric_score_val = mean_ignoring_nan(metric_fn(score, groundtruth)[0])

                print(f"Config: {full_config} -> {args_metric}: {metric_score_val:.6f}")
                all_results_for_method.append({"config": full_config, score_field: metric_score_val})

            except Exception as e:
                # Best-effort sweep: record the failure (with traceback for
                # debugging) and continue with the next configuration.
                print(f"Error during configuration {full_config}: {e}")
                traceback.print_exc()
                all_results_for_method.append({"config": full_config, score_field: "Error", "error_message": str(e)})

            print(f"Completed configuration {i+1}/{total_configs}.")

        print(f"\n--- All results for {args.method} ({args_dataset}, {args_model}) ---")
        for res in all_results_for_method:
            print(res)
        print("--- End of results ---")

        # Save results to CSV. The path is relative to --output_dir; it was
        # previously hard-coded under "/", which is usually not writable.
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_filename = os.path.join(
            args.output_dir,
            f"results_{args.method}_{args_dataset}_{args_model}_{timestamp}.csv",
        )

        # Parameter columns: every config key in first-seen order, so the CSV
        # schema always matches the search space definition.
        param_fields = list(dict.fromkeys(key for config in search_space for key in config))

        print(f"Writing results to {csv_filename}")

        try:
            os.makedirs(args.output_dir, exist_ok=True)
            write_results_csv(csv_filename, all_results_for_method, param_fields, args_metric)
            print(f"Results successfully saved to {csv_filename}")
        except (OSError, csv.Error) as e:
            print(f"Error writing CSV file: {e}")

    else:
        print(f"Method {args.method} is not one of the configured methods for this specific benchmark run.")

    print("Experiment finished.")