import torch.optim  as optim
import warnings
warnings.filterwarnings('ignore')
from torch.utils.data import Dataset
import torch.nn as nn
import torch
import numpy as np
from opacus import PrivacyEngine
from torch.utils.data import Subset, DataLoader
from Audit_TTT import EnhancedPrivacyEvaluator
class Cifar10FeatureDataset(Dataset):
    """Map-style dataset over pre-extracted feature vectors and integer labels.

    Each item is a dict with two entries:
      * 'feature': float32 tensor built from ``features[idx]`` (after the
        optional transform);
      * 'target': int64 tensor of shape (1,) holding ``targets[idx]``.
    """

    def __init__(self, features, targets, transform=None):
        # features / targets: indexable sequences of the same length (e.g. numpy
        # arrays loaded from .npy files); transform: optional callable applied
        # to each feature before tensor conversion.
        self.features, self.targets, self.transform = features, targets, transform

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        vec, label = self.features[idx], self.targets[idx]
        vec = self.transform(vec) if self.transform else vec
        return dict(feature=torch.FloatTensor(vec), target=torch.LongTensor([label]))
class SimpleNN(nn.Module):
    """Single bias-free linear layer mapping feature vectors to class logits.

    Args:
        input_size: dimensionality of the incoming feature vectors.
        output_size: number of outputs (class logits).
    """

    def __init__(self, input_size=256, output_size=10):
        super(SimpleNN, self).__init__()
        # Fully connected layer. It is placed on CUDA when a GPU is available
        # (as before), but falls back to CPU instead of crashing in `.cuda()`
        # on machines without a GPU.
        default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.fc = nn.Linear(input_size, output_size, bias=False).to(default_device)
        # Xavier-uniform init with gain 0.1, i.e. weights 10x smaller than the
        # standard Xavier bound. Parameters already require grad by default.
        nn.init.xavier_uniform_(self.fc.weight, gain=0.1)

    def forward(self, x):
        # Cast to float32 and move the input to the device the layer actually
        # lives on (rather than the current default CUDA device), so callers may
        # pass float64 and/or CPU tensors.
        x = x.to(device=self.fc.weight.device, dtype=torch.float32)
        return self.fc(x)

class EnhancedClassifier(nn.Module):
    """MLP head: [Linear -> ReLU -> Dropout] per hidden layer, then Linear + LayerNorm.

    Args:
        input_dim: size of the input feature vectors.
        hidden_dims: widths of the hidden layers, in order. An empty sequence
            gives just the final Linear + LayerNorm.
        output_dim: size of the output (e.g. number of classes).
        dropout: dropout probability applied after each hidden ReLU.
    """

    # A tuple default avoids the shared-mutable-default pitfall of a list;
    # any iterable of ints (including a list) is still accepted.
    def __init__(self, input_dim=256, hidden_dims=(512, 256), output_dim=10, dropout=0.3):
        super(EnhancedClassifier, self).__init__()

        # Feature projection layers
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            # An nn.BatchNorm1d(hidden_dim) after the Linear is left disabled.
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            prev_dim = hidden_dim

        self.feature_encoder = nn.Sequential(*layers)

        # Final projection to the target dimension, normalised with LayerNorm.
        self.final_proj = nn.Linear(prev_dim, output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        features = self.feature_encoder(x)
        features = self.final_proj(features)
        return self.layer_norm(features)
# --- Feature source --------------------------------------------------------
# Exactly one of the three options below should be active; each one defines
# `dataset_train` and `dataset_val`.

# Option 1: 768-dim ViT embedding features.
# features_, targets_ = (np.load('./vit_trained_features_768.npy'),
#                        np.load('./vit_trained_labels_768.npy'))
# dataset_train = Cifar10FeatureDataset(features_, targets_)
# features_, targets_ = (np.load('./vit_trained_feature_test_768.npy'),
#                        np.load('./vit_trained_labels_test_768.npy'))
# dataset_val = Cifar10FeatureDataset(features_, targets_)

# Option 2 (active): 256-dim embedding features.
features_, targets_ = (np.load('./trained_features_256.npy'),
                       np.load('./trained_labels_256.npy'))
dataset_train = Cifar10FeatureDataset(features_, targets_)
features_, targets_ = (np.load('./trained_feature_test_256.npy'),
                       np.load('./trained_labels_test_256.npy'))
dataset_val = Cifar10FeatureDataset(features_, targets_)

# Option 3: raw features, with an 80/20 train/val split of the training set.
# features_, targets_ = (np.load('./cifar10_features/train_features.npy'),
#                        np.load('./cifar10_features/train_labels.npy'))
# dataset_train = Cifar10FeatureDataset(features_, targets_)
# train_size = int(0.8 * len(dataset_train))
# val_size = len(dataset_train) - train_size
# dataset_train, dataset_val = torch.utils.data.random_split(
#     dataset_train,
#     [train_size, val_size],
#     generator=torch.Generator().manual_seed(42),  # omit for an unseeded split
# )

# --- Loaders, target model and optimisation --------------------------------
train_loader = DataLoader(dataset_train, batch_size=256, shuffle=True, num_workers=1)
val_loader = DataLoader(dataset_val, batch_size=256, shuffle=False, num_workers=1)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SimpleNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
# Multiplies the learning rate by 0.1 every 35 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=35, gamma=0.1)

# Optional DP-SGD training: uncomment to wrap model/optimizer/loader with Opacus.
# privacy_engine = PrivacyEngine()
# target_epsilon = 1
# model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
#     module=model,
#     optimizer=optimizer,
#     data_loader=train_loader,
#     epochs=10,
#     target_epsilon=target_epsilon,
#     target_delta=1e-5,
#     max_grad_norm=1,
# )
# print(optimizer.noise_multiplier)

def _batch_to_device(batch, device):
    """Return (features, labels) from a dict batch, moved to `device`.

    Labels are flattened with reshape(-1) rather than squeeze(). squeeze()
    would turn a final batch of size 1 into a 0-d tensor, which
    CrossEntropyLoss rejects against (1, C) logits.
    """
    features = batch['feature'].to(device)
    labels = batch['target'].reshape(-1).to(device)
    return features, labels


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader`; return (mean loss, accuracy in %)."""
    model.train()
    running_loss = 0.0
    running_corrects = 0

    for batch in loader:
        features, labels = _batch_to_device(batch, device)
        optimizer.zero_grad()
        outputs = model(features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # softmax is monotonic, so argmax over the raw logits gives the same class.
        preds = outputs.argmax(dim=1)
        running_loss += loss.item() * features.size(0)
        running_corrects += (preds == labels).sum().item()

    n_samples = max(len(loader.dataset), 1)  # guard against an empty dataset
    return running_loss / n_samples, running_corrects / n_samples * 100


def _evaluate_accuracy(model, loader, device):
    """Return top-1 accuracy (in %) of `model` over `loader`, without tracking gradients."""
    model.eval()
    corrects = 0
    with torch.no_grad():
        for batch in loader:
            features, labels = _batch_to_device(batch, device)
            corrects += (model(features).argmax(dim=1) == labels).sum().item()
    return corrects / max(len(loader.dataset), 1) * 100


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device,
                num_epochs=20, checkpoint_path='./best_model.pth'):
    """Train `model`, validate after every epoch, and checkpoint the best epoch.

    Batches are expected to be dicts with 'feature' and 'target' entries, as
    produced by Cifar10FeatureDataset.

    Args:
        model: module to train.
        train_loader: loader for the training set.
        val_loader: loader for the validation set.
        criterion: loss taking (logits, labels).
        optimizer: optimizer over model.parameters().
        scheduler: LR scheduler, stepped once per epoch.
        device: device the batches are moved to.
        num_epochs: number of epochs to run.
        checkpoint_path: where model.state_dict() of the best epoch is saved.

    Returns:
        The best validation accuracy in percent (0.0 when num_epochs is 0).
    """
    best_val_acc = 0.0
    saved_checkpoint = False

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print('-' * 10)

        # Training phase
        epoch_loss, epoch_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f'Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.2f}%')
        # With Opacus enabled, the spent budget can be reported here,
        # e.g. privacy_engine.get_epsilon(1e-5).

        # Validation phase
        val_epoch_acc = _evaluate_accuracy(model, val_loader, device)
        print(f'Val Acc: {val_epoch_acc:.2f}%')
        scheduler.step()

        # Always save after the first epoch, so a checkpoint exists even if
        # validation accuracy never rises above 0%.
        if not saved_checkpoint or val_epoch_acc > best_val_acc:
            best_val_acc = val_epoch_acc
            torch.save(model.state_dict(), checkpoint_path)
            saved_checkpoint = True

    print('Training complete!')
    print(f'Best validation accuracy: {best_val_acc:.2f}%')
    return best_val_acc

def create_simple_mia_datasets(dataset_train, dataset_test, target_size=10000, seed=None):
    """Draw disjoint member / non-member subsets from the training dataset.

    Both subsets come from `dataset_train` through a plain random permutation
    of its indices (no class stratification).

    Args:
        dataset_train: any sized, indexable dataset.
        dataset_test: accepted for interface compatibility; not used.
        target_size: requested size of each subset. It is reduced to
            len(dataset_train) // 2 when the dataset is too small for two.
        seed: optional int for a reproducible split. When None (the default),
            numpy's global RNG is used, exactly as before.

    Returns:
        (member_subset, non_member_subset) as torch.utils.data.Subset objects.
    """
    train_size = len(dataset_train)

    if train_size < 2 * target_size:
        target_size = train_size // 2
        print(f"Adjusting target_size to {target_size}")

    # Simple random split without stratification. A dedicated RandomState is
    # used when seeded, so the global numpy RNG is left untouched.
    indices = list(range(train_size))
    shuffle = np.random.shuffle if seed is None else np.random.RandomState(seed).shuffle
    shuffle(indices)

    member_subset = Subset(dataset_train, indices[:target_size])
    non_member_subset = Subset(dataset_train, indices[target_size:2 * target_size])

    print("Created MIA datasets:")
    print(f"  Members: {len(member_subset)} samples")
    print(f"  Non-members: {len(non_member_subset)} samples")

    return member_subset, non_member_subset

import pickle


def main():
    """Script entry point: run the membership-inference audit on the KDE predictor.

    Members are a random 10000-sample subset of the training features, and the
    validation loader serves as the non-member set.
    """
    # Per the original notes: a parameter-free, pre-fitted predictor is used,
    # presumably because it is better for privacy. The per-class densities
    # stored in the pickle were fitted on a 40000-sample split of the training
    # data.
    # NOTE(review): pickle.load can execute arbitrary code; only load a
    # classes_kdes.pkl that this project produced itself.
    with open("classes_kdes.pkl", "rb") as fh:
        kdes = pickle.load(fh)

    members, held_out = create_simple_mia_datasets(
        dataset_train, None, target_size=10000)
    member_loader = DataLoader(members, batch_size=256, shuffle=True, num_workers=4)
    # `held_out` (training-set non-members) is not used. The validation loader
    # is the non-member set instead. Alternative:
    #   non_member_loader = DataLoader(held_out, batch_size=256, shuffle=True, num_workers=4)
    non_member_loader = val_loader

    # Alternative target model: train the classification layer first, then
    # audit it instead of the KDEs:
    #   train_model(model, member_loader, non_member_loader, criterion,
    #               optimizer, scheduler, device, num_epochs=10)
    #   evaluator = EnhancedPrivacyEvaluator(model, device, num_classes=10)
    #   results = evaluator.run_complete_privacy_audit(...same arguments as below...)

    evaluator = EnhancedPrivacyEvaluator(None, device, num_classes=10, kdes=kdes)
    audit = evaluator.run_complete_privacy_audit(
        member_loader=member_loader,
        non_member_loader=non_member_loader,
        n_splits=5,
        n_bootstraps=1000,
    )
    print(audit['audit_results']['results']['MLP-Enhanced']['non_member_acc'])


if __name__ == '__main__':
    main()
