import numpy as np
import torch 
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager
from torch.utils.data import DataLoader
from dataclasses import dataclass

from typing import Tuple, List

from .mironov_rdp_accounting import get_noise_multiplier, compute_adp_epsilon_from_accountant, compute_adp_epsilon_from_accountant_with_threshold_check
from src.basic_mechanisms import gaussian_mechanism_rdp_epsilon, gaussian_mechanism_rdp_variance


# Renyi orders used for RDP accounting, taken from Tobaben et al. (2023),
# "On the Efficacy of Differentially Private Few-shot Image Classification" (TMLR):
# 1.01 and 1.05, then 1.1 .. 10.9 in steps of 0.1, then the integers 12 .. 255.
CUSTOM_ALPHAS = [
    1.01,
    1.05,
    *(1 + tenth / 10.0 for tenth in range(1, 100)),
    *range(12, 256),
]


@dataclass
class DPSGDHyperparameters:
    """Hyperparameters consumed by the DP-SGD training routines of this module."""

    # Learning rate handed to the optimizer.
    learning_rate: float
    # Clipping bound passed to Opacus as ``max_grad_norm``.
    max_grad_norm: float
    # Number of epochs of the initial training run (also used to calibrate the noise).
    num_epochs: int
    # Batch size of the training DataLoader.
    batch_size: int


@dataclass
class AccuracyFirstTrainingResult:
    validation_accuracy: float
    test_accuracy: float | None
    additional_epochs: int
    epsilon: float


class AccuracyFirstDPSGD:
    def __init__(self, model, hyperparameters: DPSGDHyperparameters, device, delta: float):
        self.model = model
        self.hyperparameters = hyperparameters
        self.device = device
        self.privacy_engine = None
        self.loss = None
        self.train_loader = None
        self.optimizer = None
        self.delta = delta

    def initial_train(self, train_dataset, initial_epsilon: float, seed: int):
        self.model, self.privacy_engine, self.optimizer, self.loss, self.train_loader = train_with_dp(
            model=self.model,
            train_dataset=train_dataset,
            learning_rate=self.hyperparameters.learning_rate,
            max_grad_norm=self.hyperparameters.max_grad_norm,
            num_epochs=self.hyperparameters.num_epochs,
            batch_size=self.hyperparameters.batch_size,
            max_physical_batch_size=128,
            epsilon=initial_epsilon,
            delta=self.delta,
            device=self.device,
            seed=seed
        )

    def continue_training(self, additional_epochs: int):
        dp_training_loop(
            model=self.model,
            train_loader=self.train_loader,
            optimizer=self.optimizer,
            loss=self.loss,
            num_epochs=additional_epochs,
            max_physical_batch_size=128,
            device=self.device
        )

    def accuracy_first_train(
        self, accuracy_threshold: float, train_dataset, validation_dataset, 
        test_dataset,
        initial_epsilon: float, max_releases: int, epochs_per_release: int, 
        public_validation_set: bool, seed: int):
        if max_releases <= 0:
            raise ValueError(f"max_additional_epochs must be positive, got {max_releases}")
        if epochs_per_release <= 0:
            raise ValueError(f"accuracy_check_interval must be positive, got {epochs_per_release}")

        self.initial_train(train_dataset=train_dataset, initial_epsilon=initial_epsilon, seed=seed)

        epsilon, best_alpha, best_rdp_epsilon = self.get_adp_epsilon(return_best_alpha=True)
        if epsilon > initial_epsilon:
            raise RuntimeError(f"Initial training exceeded initial_epsilon: epsilon={epsilon}, initial_epsilon={initial_epsilon}")

        learning_rate_reduction_after_initial_training = 10
        reduce_learning_rate(self.optimizer, learning_rate_reduction_after_initial_training)
        
        threshold_check_variance = 0
        threshold_check_rdp_epsilons = []
        if not public_validation_set:
            threshold_check_sensitivity = 1.0 / len(validation_dataset) * np.sqrt(max_releases - 1)
            threshold_check_epsilon = best_rdp_epsilon
            threshold_check_variance = gaussian_mechanism_rdp_variance(
                epsilon=threshold_check_epsilon,
                alpha=best_alpha,
                sensitivity=threshold_check_sensitivity
            )
            threshold_check_rdp_epsilons = [
                gaussian_mechanism_rdp_epsilon(np.sqrt(threshold_check_variance), alpha, threshold_check_sensitivity)
                for alpha in CUSTOM_ALPHAS
            ]

        additional_epochs = 0
        releases = []
        for i in range(max_releases):
            validation_accuracy = evaluate(self.model, validation_dataset, self.device)
            if not public_validation_set:
                current_epsilon = self.get_adp_epsilon_with_threshold_check(threshold_check_rdp_epsilons)
            else:
                current_epsilon = self.get_adp_epsilon()

            test_accuracy = None
            if test_dataset is not None:
                test_accuracy = evaluate(self.model, test_dataset, self.device)
            releases.append(AccuracyFirstTrainingResult(
                validation_accuracy=validation_accuracy,
                test_accuracy=test_accuracy,
                additional_epochs=additional_epochs,
                epsilon=current_epsilon,
            ))

            if i == max_releases - 1: # Skip threshold check and additional training on last iteration
                break

            if public_validation_set:
                noisy_validation_accuracy = validation_accuracy
            else:
                noisy_validation_accuracy = validation_accuracy + np.random.normal(0, np.sqrt(threshold_check_variance))
            if noisy_validation_accuracy >= accuracy_threshold:
                break

            self.continue_training(additional_epochs=epochs_per_release)
            additional_epochs += epochs_per_release

        return releases, threshold_check_variance

    def get_adp_epsilon(self, return_best_alpha=False) -> float | Tuple[float, float, float]:
        if self.privacy_engine is None:
            raise ValueError("Model has not been trained with DP yet.")
        
        return compute_adp_epsilon_from_accountant(
            accountant=self.privacy_engine.accountant,
            delta=self.delta,
            alphas=CUSTOM_ALPHAS,
            return_best_alpha=return_best_alpha,
        )

    def get_adp_epsilon_with_threshold_check(self, threshold_check_rdp_epsilons) -> float:
        if self.privacy_engine is None:
            raise ValueError("Model has not been trained with DP yet.")

        return compute_adp_epsilon_from_accountant_with_threshold_check(
            accountant=self.privacy_engine.accountant,
            delta=self.delta,
            alphas=CUSTOM_ALPHAS,
            threshold_check_rdp_epsilons=threshold_check_rdp_epsilons
        )



def train_with_dp(
    model, train_dataset, learning_rate, max_grad_norm, num_epochs,
    batch_size, max_physical_batch_size,
    epsilon, delta, device, seed
    ):
    """Train ``model`` with DP-SGD (Opacus, RDP accountant) for ``num_epochs`` epochs.

    The noise multiplier is calibrated with ``get_noise_multiplier`` so that
    ``num_epochs`` epochs of Poisson-sampled training target ``(epsilon, delta)``.

    Args:
        model: Module to train.
        train_dataset: Dataset to train on.
        learning_rate: Learning rate of the Adam optimizer.
        max_grad_norm: Clipping bound passed to Opacus.
        num_epochs: Number of training epochs.
        batch_size: Batch size of the training DataLoader.
        max_physical_batch_size: Physical batch size cap for the training loop.
        epsilon: Target epsilon for noise calibration.
        delta: Target delta for noise calibration.
        device: Device batches are moved to and on which the noise generator is created.
        seed: Seed for the random generators.

    Returns:
        Tuple ``(model, privacy_engine, optimizer, loss, train_loader)`` where
        model, optimizer and train_loader are the objects returned by
        ``PrivacyEngine.make_private``.
    """
    seeded_generator = torch.Generator(device=device)
    seeded_generator.manual_seed(seed)

    # Batch sampling draws CPU tensors and therefore needs a CPU generator; a
    # generator living on another device makes the loader fail with a device
    # mismatch. On CPU the single shared generator is kept, so the random
    # stream of CPU runs is unchanged.
    if seeded_generator.device.type == "cpu":
        sampling_generator = seeded_generator
    else:
        sampling_generator = torch.Generator()
        sampling_generator.manual_seed(seed)

    privacy_engine = PrivacyEngine(accountant="rdp", secure_mode=False)
    # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss = torch.nn.functional.cross_entropy
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=sampling_generator)

    model, optimizer, train_loader = privacy_engine.make_private(
        module=model,
        optimizer=optimizer,
        data_loader=train_loader,
        poisson_sampling=True,
        noise_multiplier=get_noise_multiplier(
            target_epsilon=epsilon,
            target_delta=delta,
            # One batch per step out of len(train_loader) batches per epoch.
            sample_rate=1 / len(train_loader),
            epochs=num_epochs,
            alphas=CUSTOM_ALPHAS,
        ),
        epochs=num_epochs,
        max_grad_norm=max_grad_norm,
        noise_generator=seeded_generator,
        loss_reduction="mean",
        alphas=CUSTOM_ALPHAS,
    )

    dp_training_loop(
        model=model,
        train_loader=train_loader,
        optimizer=optimizer,
        loss=loss,
        num_epochs=num_epochs,
        max_physical_batch_size=max_physical_batch_size,
        device=device
    )

    return model, privacy_engine, optimizer, loss, train_loader


def dp_training_loop(
    model, train_loader, optimizer,
    loss, num_epochs, max_physical_batch_size, device
    ):
    """Run ``num_epochs`` epochs of training, iterating through Opacus' BatchMemoryManager.

    Args:
        model: Module to train; it is put into training mode first.
        train_loader: Data loader yielding ``(features, labels)`` batches.
        optimizer: Optimizer stepped once per yielded batch.
        loss: Callable ``loss(logits, labels)`` returning a value with ``backward()``.
        num_epochs: Number of passes over ``train_loader``.
        max_physical_batch_size: Cap passed to ``BatchMemoryManager``.
        device: Device each batch is moved to.
    """
    model.train()
    for _ in range(num_epochs):
        memory_manager = BatchMemoryManager(
            data_loader=train_loader,
            max_physical_batch_size=max_physical_batch_size,
            optimizer=optimizer,
        )
        with memory_manager as physical_batches:
            for features, labels in physical_batches:
                features = features.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                # Make sure autograd is on even if a caller disabled it globally.
                torch.set_grad_enabled(True)
                predictions = model(features)
                batch_loss = loss(predictions, labels)
                batch_loss.backward()
                # Drop the reference to the forward output before the optimizer step.
                del predictions
                optimizer.step()


def evaluate(model, test_dataset, device, batch_size=64):
    """Return the top-1 accuracy of ``model`` on ``test_dataset``.

    The model is switched to evaluation mode and is left in that mode.

    Args:
        model: Module mapping a feature batch to per-class logits.
        test_dataset: Dataset yielding ``(features, label)`` pairs.
        device: Device each batch is moved to before the forward pass.
        batch_size: Evaluation batch size (defaults to 64).

    Returns:
        Fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: If the dataset yields no samples.
    """
    loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    model.eval()
    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for features, labels in loader:
            features = features.to(device)
            labels = labels.to(device)
            predicted = model(features).argmax(dim=1)
            num_seen += labels.size(0)
            num_correct += (predicted == labels).sum().item()

    if num_seen == 0:
        # Without this guard the division below fails with a bare ZeroDivisionError.
        raise ValueError("Cannot compute accuracy: the evaluation dataset is empty")
    return num_correct / num_seen


def reduce_learning_rate(optimizer, reduction_factor):
    """Divide the learning rate of every parameter group of ``optimizer`` by ``reduction_factor``.

    The parameter groups are modified in place; nothing is returned.
    """
    for group in optimizer.param_groups:
        group["lr"] /= reduction_factor