#!/usr/bin/env python3
"""
Compute influence functions for a trained VAE using efficient HVP+CG (Koh & Liang, 2017, Sec. 3).
"""
import os
import sys
from pathlib import Path
from typing import Tuple

import torch
from torch.utils.data import DataLoader, Subset

# Append the directory two levels above this file to the import search path.
# This must run before the `utils` / `model` imports below, which presumably
# live under that directory (verify against the repository layout).
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

from utils.model_utils import load_trained_model
from utils.data_utils import get_mnist_data
from model.vae_models.vae import vae_loss_per_sample
from utils.influence import hvp, conjugate_gradient, s_test


def select_device() -> torch.device:
    """Choose the torch device to run on.

    The ``VAE_DEVICE`` environment variable is honoured when it is exactly
    ``'cpu'`` or ``'cuda'``; any other value (or no value) selects CUDA when
    available and CPU otherwise. A CUDA choice is verified by allocating a
    small tensor, and the function falls back to CPU if that fails.

    Returns:
        The selected ``torch.device``.
    """
    env_choice = os.environ.get('VAE_DEVICE', None)

    # An explicit CUDA request that cannot be satisfied: report and bail out.
    if env_choice == 'cuda' and not torch.cuda.is_available():
        print('CUDA is not available. Falling back to CPU.')
        return torch.device('cpu')

    if env_choice in ('cpu', 'cuda'):
        chosen = torch.device(env_choice)
    elif torch.cuda.is_available():
        chosen = torch.device('cuda')
    else:
        chosen = torch.device('cpu')
    print(f'Using device: {chosen}')

    if str(chosen) != 'cuda':
        return chosen

    # is_available() can be True while allocation still fails, so probe the
    # device with a throwaway tensor before committing to it.
    try:
        torch.tensor([0.0], device=chosen)
    except Exception as err:
        print(f'CUDA device check failed ({type(err).__name__}: {err}). Falling back to CPU.')
        print('Using device: cpu')
        return torch.device('cpu')
    return chosen


def main() -> None:
    """Compute influence scores of training samples on a test batch and save them.

    Steps:
      1. Parse CLI options and load the trained model/config for the split.
      2. Take the first ``num_train_samples`` / ``num_test_samples`` items of
         the MNIST train / test sets.
      3. Compute the gradient of the mean loss over the *first* test batch
         (so at most ``batch_size`` test samples contribute).
      4. Obtain ``s`` from ``s_test`` (intended as ``H^{-1} g_test``, with the
         Hessian taken over the training-subset loss).
      5. For every training sample ``z_i`` compute ``-grad L(z_i) . s``.
      6. Save ``{'influences': tensor}`` with ``torch.save``.

    The default output path is
    ``<experiment_dir>/influence_split<split_idx>.pt``.

    Raises:
        SystemExit: via ``argparse`` when an argument is missing/invalid or a
            size option is not a positive integer.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Compute IF for VAE with HVP+CG')
    parser.add_argument('--experiment_dir', type=str, required=True)
    parser.add_argument('--split_idx', type=int, default=0)
    parser.add_argument('--num_train_samples', type=int, default=512, help='subset train for demo speed')
    parser.add_argument('--num_test_samples', type=int, default=64)
    parser.add_argument('--damp', type=float, default=0.01)
    parser.add_argument('--scale', type=float, default=25.0)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--out', type=str, default=None)
    args = parser.parse_args()

    # Fail fast with a readable message; otherwise an empty subset surfaces
    # much later as StopIteration or a torch.stack error.
    for option in ('num_train_samples', 'num_test_samples', 'batch_size'):
        if getattr(args, option) <= 0:
            parser.error(f'--{option} must be a positive integer')

    device = select_device()

    # Load trained model and config
    model, config = load_trained_model(args.experiment_dir, split_idx=args.split_idx, device=device)
    model.zero_grad()
    beta = config['training']['beta']
    # NOTE(review): the model's train/eval mode is left as returned by
    # load_trained_model -- confirm it is the mode wanted for influence scores.

    # Data: deterministic prefixes of the train/test sets, iterated in order
    # so influence index i corresponds to training sample i.
    train_dataset, test_dataset = get_mnist_data()
    train_subset = Subset(train_dataset, list(range(min(len(train_dataset), args.num_train_samples))))
    test_subset = Subset(test_dataset, list(range(min(len(test_dataset), args.num_test_samples))))
    train_loader = DataLoader(train_subset, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_subset, batch_size=args.batch_size, shuffle=False)

    def loss_fn_for_hvp(model_: torch.nn.Module) -> torch.Tensor:
        """Return the mean per-sample loss over the whole training subset.

        Losses are summed per sample and divided by the sample count, so a
        smaller final batch is not over-weighted (as a mean of batch means
        would do).
        """
        loss_sum = 0.0
        num_samples = 0
        for data, _ in train_loader:
            data = data.to(device)
            recon, mu, logvar = model_(data)
            per_sample = vae_loss_per_sample(recon, data, mu, logvar, beta=beta)
            loss_sum = loss_sum + per_sample.sum()
            num_samples += per_sample.shape[0]
        return loss_sum / max(num_samples, 1)

    # Test-side gradient: averaged over the samples of the first test batch.
    test_x, _ = next(iter(test_loader))
    test_x = test_x.to(device)
    recon_t, mu_t, logvar_t = model(test_x)
    test_per_sample = vae_loss_per_sample(recon_t, test_x, mu_t, logvar_t, beta=beta)

    params = [p for p in model.parameters() if p.requires_grad]
    test_loss = test_per_sample.mean()
    g_test = torch.autograd.grad(test_loss, params, create_graph=False)
    g_test_flat = torch.cat([g.reshape(-1) for g in g_test]).detach()

    # Compute s_test = H^{-1} g_test
    s = s_test(model, loss_fn_for_hvp, g_test_flat, damp=args.damp, scale=args.scale)

    # Per-sample training gradients, flattened in the same parameter order as
    # g_test_flat, then dotted with s.
    influences = []
    with torch.enable_grad():
        for data, _ in train_loader:
            data = data.to(device)
            recon, mu, logvar = model(data)
            per_sample = vae_loss_per_sample(recon, data, mu, logvar, beta=beta)
            for i in range(per_sample.shape[0]):
                # retain_graph: the batch graph is reused for every sample.
                gz = torch.autograd.grad(per_sample[i], params, retain_graph=True)
                gz_flat = torch.cat([g.reshape(-1) for g in gz])
                influences.append(-torch.dot(gz_flat, s).detach().cpu())

    influences_tensor = torch.stack(influences)
    out_path = args.out or os.path.join(args.experiment_dir, f'influence_split{args.split_idx}.pt')
    # Create the target directory so a fresh --out location does not fail
    # after all the expensive work is done.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save({'influences': influences_tensor}, out_path)
    print(f'Saved influences to {out_path}')


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()


