import os
import torch
import numpy as np
import csv
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.distributions import MultivariateNormal
import torch.nn.functional as F
from preact_resnet import PreActResNet18

# Device setup.
# Set the TORCH_DEVICE environment variable (e.g. TORCH_DEVICE=cuda:0 or
# TORCH_DEVICE=cpu) to choose the device explicitly. Otherwise GPU index 5 is
# used when CUDA is available, falling back to the CPU. The hard-coded index
# alone fails ("invalid device ordinal") on machines with fewer than six GPUs.
device = torch.device(
    os.environ.get("TORCH_DEVICE") or ("cuda:5" if torch.cuda.is_available() else "cpu")
)


# Custom Dataset class
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over parallel sequences of images and labels.

    Tensor labels are converted to Python scalars with ``.item()`` when the
    dataset is built. Any other label is stored unchanged.

    NOTE(review): no code in this script constructs this class directly.
    Presumably it exists so that datasets saved with ``torch.save`` can be
    unpickled by ``torch.load``. If so, keep the attribute names ``images``
    and ``labels`` unchanged, because unpickling restores them by name.
    """

    def __init__(self, images, labels):
        self.images = images
        converted = []
        for raw in labels:
            if isinstance(raw, torch.Tensor):
                raw = raw.item()
            converted.append(raw)
        self.labels = converted

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at position ``idx``."""
        return self.images[idx], self.labels[idx]

# Load data and model
def load_data_and_model(balance_ratio):
    """Load the trained weights and the clean/attacked training sets for one run.

    The weights are loaded into the module-level ``model`` instance, which has
    ``probs_hook`` registered on it. That same instance is returned.

    Args:
        balance_ratio (str): Suffix of the run directory
            ``imbalanced_{balance_ratio}_`` under the CIFAR-10 base path.

    Returns:
        tuple: ``(model, clean_train_dataset, attacked_dataset)``.
    """
    base_path = '/data/RPP/saved_models_and_datasets/cifar10/'
    run_dir = os.path.join(base_path, f'imbalanced_{balance_ratio}_')
    model_path = os.path.join(run_dir, 'final_model.pth')
    clean_train_dataset_path = os.path.join(run_dir, 'imbalanced_clean_train_dataset.pth')
    attacked_dataset_path = os.path.join(run_dir, 'attacked_trainset.pth')

    # map_location is needed because the files may have been saved on another
    # GPU, or on a GPU when only the CPU is available now.
    # - Weights go straight to the model's device.
    # - Datasets stay on the CPU. Images are moved to `device` one at a time
    #   during scoring, so the whole set never has to sit in GPU memory.
    # NOTE(review): these files are unpickled. Only load trusted files. On
    # PyTorch versions where torch.load defaults to weights_only=True, loading
    # pickled Dataset objects needs weights_only=False. Confirm the torch
    # version in use.
    model.load_state_dict(torch.load(model_path, map_location=device))
    clean_train_dataset = torch.load(clean_train_dataset_path, map_location='cpu')
    attacked_dataset = torch.load(attacked_dataset_path, map_location='cpu')

    return model, clean_train_dataset, attacked_dataset

# Extract probabilities
# Shared buffer written by probs_hook. The extract_* helpers clear it before a
# forward pass and pop the recorded probabilities afterwards.
probs_list = []

def probs_hook(module, input, output):
    """Forward hook: append ``softmax(output, dim=1)`` to ``probs_list``.

    ``module`` and ``input`` come from the forward-hook calling convention.
    They are not used.
    """
    probs_list.append(F.softmax(output, dim=1))

# Single shared classifier. load_data_and_model loads weights into it, and the
# hook on its `linear` submodule records softmax probabilities in probs_list.
model = PreActResNet18()
model.to(device)  # nn.Module.to moves parameters in place (and returns self)
model.linear.register_forward_hook(probs_hook)
def extract_probabilities_single(image, model, device):
    """Return the probability vector for one unbatched image.

    The model runs in eval mode, without gradients, on ``image`` with a batch
    dimension of 1 added. The result is the softmax that ``probs_hook``
    recorded during that pass, with the batch dimension removed.

    NOTE(review): this only works if ``probs_hook`` is registered on the
    ``model`` passed in. Otherwise ``probs_list`` stays empty and ``pop()``
    raises IndexError.
    """
    model.eval()
    probs_list.clear()
    with torch.no_grad():
        batch = image.to(device).unsqueeze(0)
        model(batch)
    recorded = probs_list.pop()
    return recorded.squeeze(0).to(device)
def extract_probabilities_single_with_noise(image, model, device, covariance_scale):
    """Like ``extract_probabilities_single``, but on a Gaussian-perturbed image.

    Each element gets independent noise ``randn * covariance_scale``, i.e.
    standard deviation ``covariance_scale`` and covariance
    ``covariance_scale**2 * I``.

    NOTE(review): the parameter name suggests covariance ``covariance_scale * I``
    (std ``sqrt(covariance_scale)``). The two readings agree only at 1.0.
    Confirm which one is intended before sweeping other scales.
    """
    perturbed = image + covariance_scale * torch.randn_like(image)
    return extract_probabilities_single(perturbed, model, device)
# Calculate distances
def calculate_inf_norm_distance(image, model, device, covariance_scale, num_noisy_samples=3):
    """Mean L-infinity distance between clean and noise-perturbed probabilities.

    Args:
        image: A single unbatched input image tensor.
        model: Classifier whose hooked layer feeds ``probs_list``.
        device: Device the image and model run on.
        covariance_scale: Multiplier applied to standard-normal noise.
        num_noisy_samples (int): Number of independent noisy draws to average.

    Returns:
        float: Average of ``||p_clean - p_noisy||_inf`` over the noisy draws.

    Raises:
        ValueError: If ``num_noisy_samples`` < 1, because the mean would be
            undefined. Previously this surfaced as ZeroDivisionError after a
            wasted forward pass.
    """
    if num_noisy_samples < 1:
        raise ValueError(f"num_noisy_samples must be >= 1, got {num_noisy_samples}")

    image = image.to(device)
    original_prob = extract_probabilities_single(image, model, device)

    distances = []
    for _ in range(num_noisy_samples):
        noisy_prob = extract_probabilities_single_with_noise(image, model, device, covariance_scale)
        distances.append(torch.norm(original_prob - noisy_prob, p=float('inf')).item())

    return sum(distances) / len(distances)
def process_dataset(dataset, model, device, covariance_scale, num_noisy_samples=3):
    """Score every sample of ``dataset`` with ``calculate_inf_norm_distance``.

    Samples are read through a ``DataLoader`` with batch size 1, and the batch
    dimension is squeezed off before scoring. Labels are ignored.

    Returns:
        list[float]: One mean L-infinity distance per sample, in dataset order.
    """
    loader = DataLoader(dataset, batch_size=1)
    return [
        calculate_inf_norm_distance(batch_img.squeeze(0), model, device,
                                    covariance_scale, num_noisy_samples)
        for batch_img, _label in loader
    ]
def calculate_rates_at_different_cutoffs(validate_set_mean_distances, imbalanced_mean_distances,
                                         attackset_mean_distances, alpha):
    """
    Calculate T (alpha quantile), TPR, and FPR based on validate, imbalanced, and attackset mean distances.

    A sample is flagged when its mean distance is <= T. T is the lower
    alpha-quantile of the n validation (clean calibration) distances: the
    k-th smallest value with k = floor(alpha * (n + 1)), clamped to [1, n].

    Args:
        validate_set_mean_distances (list): Mean distances for the validation (calibration) set.
        imbalanced_mean_distances (list): Mean distances for the imbalanced (non-attacked) set.
        attackset_mean_distances (list): Mean distances for the attack (poisoned) set.
        alpha (float): Significance level.

    Returns:
        tuple: T (alpha-quantile threshold), TPR (fraction of attack-set
        samples flagged), and FPR (fraction of imbalanced clean samples flagged).

    Raises:
        ValueError: If any of the three distance collections is empty.
    """
    # len() rather than truthiness so that numpy arrays are accepted as before.
    for name, values in (("validate_set_mean_distances", validate_set_mean_distances),
                         ("imbalanced_mean_distances", imbalanced_mean_distances),
                         ("attackset_mean_distances", attackset_mean_distances)):
        if len(values) == 0:
            raise ValueError(f"{name} must not be empty")

    # Step 1: Sort the validation distances to find the alpha-quantile threshold.
    validate_set_mean_distances_sorted = sorted(validate_set_mean_distances)
    n = len(validate_set_mean_distances_sorted)

    # 1-based rank k = floor(alpha * (n + 1)). The tiny tolerance keeps float
    # products that should be exact integers from being floored one rank too
    # low, e.g. 0.29 * 100 == 28.999999999999996 must give rank 29, not 28.
    rank = int(alpha * (n + 1) + 1e-9)
    quantile_index = max(0, min(rank - 1, n - 1))  # 0-based, clamped to bounds
    T = validate_set_mean_distances_sorted[quantile_index]

    # Step 2: TPR = fraction of attack (poisoned) samples flagged.
    tpr = sum(distance <= T for distance in attackset_mean_distances) / len(attackset_mean_distances)

    # Step 3: FPR = fraction of clean (imbalanced) samples flagged.
    fpr = sum(distance <= T for distance in imbalanced_mean_distances) / len(imbalanced_mean_distances)

    return T, tpr, fpr
# Save results to CSV
def save_results_to_csv(results, csv_file='result_set_cifar10.csv'):
    """Write the experiment results to ``csv_file``, overwriting it.

    Args:
        results (dict): Maps ``(balance_ratio, validate_set_name, alpha,
            covariance_scale)`` to ``(T, tpr, fpr)``. Each entry becomes one row.
        csv_file (str): Output path.
    """
    header = ['Balance Ratio', 'Validation Dataset', 'Alpha', 'Covariance Scale', 'Threshold', 'TPR', 'FPR']
    with open(csv_file, 'w', newline='') as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(
            [ratio, val_name, alpha, scale, threshold, tpr, fpr]
            for (ratio, val_name, alpha, scale), (threshold, tpr, fpr) in results.items()
        )

# Parameters
# balance_ratios = ['0.001','0.005', '0.01', '0.1', '0.5', '1']
# validation_datasets = ['val_100_samples.pth', 'val_200_samples.pth', 'val_400_samples.pth', 'val_600_samples.pth', 'val_800_samples.pth', 'val_1000_samples.pth']
# cutoff_alphas = [0.05, 0.1]
# covariance_scales = [0.1, 0.5, 1.0, 1.5, 2.0, 2.5,3.0]
# num_noisy_samples = 3

balance_ratios = ['1']
cutoff_alphas = [0.05, 0.1]
validation_datasets = ['val_100_samples.pth']
covariance_scales = [1.0]
num_noisy_samples = 3

calibration_dir = '/data/RPP/calibration_data/cifar10_calibration_sets'

# Results storage
results = {}
# Iterate over balance ratios and validation datasets
for balance_ratio in balance_ratios:
    model, imbalanced_clean_train_dataset, attackset = load_data_and_model(balance_ratio)
    # Distances on the clean training set and the attack set depend only on the
    # model (this balance ratio) and the noise scale, not on the validation set.
    # Score them once per covariance scale and reuse them for every validation
    # set instead of re-scoring both full datasets each time.
    train_distances_by_scale = {}
    for validate_set_name in validation_datasets:
        # Keep calibration data on the CPU; images are moved to `device` per sample.
        validate_set = torch.load(os.path.join(calibration_dir, validate_set_name), map_location='cpu')
        for covariance_scale in covariance_scales:
            validate_distances = process_dataset(validate_set, model, device, covariance_scale, num_noisy_samples)
            if covariance_scale not in train_distances_by_scale:
                # Tuple items are evaluated left to right: clean set first, then attack set.
                train_distances_by_scale[covariance_scale] = (
                    process_dataset(imbalanced_clean_train_dataset, model, device, covariance_scale,
                                    num_noisy_samples),
                    process_dataset(attackset, model, device, covariance_scale, num_noisy_samples),
                )
            imbalanced_mean_distances, attackset_mean_distances = train_distances_by_scale[covariance_scale]
            for alpha in cutoff_alphas:
                T, tpr, fpr = calculate_rates_at_different_cutoffs(validate_distances, imbalanced_mean_distances,
                                                                   attackset_mean_distances, alpha)
                results[(balance_ratio, validate_set_name, alpha, covariance_scale)] = (T, tpr, fpr)

                print(
                    f"Balance Ratio: {balance_ratio}, Validation Dataset: {validate_set_name}, Alpha: {alpha}, Covariance Scale: {covariance_scale}")
                print(f"  Threshold: {T:.4f}, TPR: {tpr:.4f}, FPR: {fpr:.4f}")

# Save to CSV
save_results_to_csv(results)
print("Results saved to 'result_set_cifar10.csv'.")
