import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Sampler
import argparse
import random
import numpy as np
import random
import os


# Select the compute device once, at import time: CUDA when a GPU is usable,
# otherwise the CPU. The model and every batch are moved to this device later.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Define a simple Deep Neural Network
class SimpleDNN(nn.Module):
    """Three-layer fully connected classifier for 28x28 single-channel images.

    Architecture: 784 -> 128 -> 64 -> 10. A ReLU follows each of the first two
    linear layers, and each ReLU is followed by a dropout layer.

    Args:
        dropout_rate: Drop probability shared by both dropout layers
            (default 0.5; pass 0 to make them no-ops).
    """

    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # input is the flattened 28x28 image
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(128, 64)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(64, 10)  # one logit per MNIST digit class

    def forward(self, x):
        """Return raw class logits of shape (N, 10).

        ``x`` is flattened to (N, 784), so its element count must be a
        multiple of 784 (e.g. a (N, 1, 28, 28) image batch). ``reshape`` is
        used instead of ``view`` so non-contiguous inputs are accepted too.
        """
        x = x.reshape(-1, 28 * 28)
        x = self.dropout1(torch.relu(self.fc1(x)))
        x = self.dropout2(torch.relu(self.fc2(x)))
        return self.fc3(x)

    def _set_dropout_training(self, training):
        """Put every submodule whose class name starts with 'Dropout' into
        train mode (``training=True``) or eval mode (``training=False``)."""
        for module in self.modules():
            if module.__class__.__name__.startswith('Dropout'):
                module.train(training)

    def enable_dropout(self):
        """Switch only the dropout layers to training mode so they keep
        dropping activations at test time; other layers keep their mode."""
        self._set_dropout_training(True)

    def disable_dropout(self):
        """Switch only the dropout layers back to eval mode (no dropping);
        other layers keep their mode."""
        self._set_dropout_training(False)


class SubsetSamper(Sampler):
    """Sampler that yields a fixed, caller-supplied sequence of dataset indices.

    The indices come out in exactly the order given, on every pass over the
    sampler; nothing is shuffled here.

    Args:
        indices: A sized iterable of dataset indices (the script passes a
            1-D NumPy array; a plain list or range works as well).
    """

    def __init__(self, indices):
        self.indices = indices

    def __len__(self):
        # Number of samples one pass will yield.
        return len(self.indices)

    def __iter__(self):
        # Fresh iterator per pass, so every epoch replays the same order.
        return (index for index in self.indices)

def evaluate_accuracy(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a percentage.

    Switches ``model`` to eval mode and runs without gradient tracking. The
    caller must call ``model.train()`` again before further training.

    Args:
        model: Module mapping a batch of inputs to per-class scores (N, C).
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``100 * correct / total`` as a float, or ``nan`` when ``loader``
        yields no samples (instead of raising ``ZeroDivisionError``).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            # Predicted class = index of the highest score along the class dim.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        return float("nan")
    return 100 * correct / total


if __name__ == "__main__":
    import time

    parser = argparse.ArgumentParser()
    parser.add_argument("--training_seed", type=int, default=0, help="training seed")
    parser.add_argument("--epochs", type=int, default=100, help="number of training epochs")
    parser.add_argument("--checkpoint_dir", default="./checkpoint",
                        help="directory for the selected-index list and final weights")
    args = parser.parse_args()

    # Seed every RNG the script uses. The order of the RNG-consuming steps
    # below (shuffle, the two np.random.choice draws, model init, loader
    # iteration) is kept fixed so a given seed reproduces the same run.
    torch.manual_seed(args.training_seed)
    random.seed(args.training_seed)
    np.random.seed(args.training_seed)

    # Load MNIST: ToTensor scales pixels to [0, 1], Normalize maps them to [-1, 1].
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    all_index = list(range(5000))
    all_index_test = list(range(500))

    # Train on 2500 distinct indices drawn from the first 5000 training images.
    random.shuffle(all_index)
    portion_index = np.random.choice(all_index, size=2500, replace=False, p=None)
    sampler = SubsetSamper(portion_index)
    # NOTE: SubsetSamper replays this exact order every epoch (no reshuffling).
    train_loader = DataLoader(train_dataset, batch_size=64, sampler=sampler)

    # Evaluate on a random permutation of the first 500 test images.
    portion_index_test = np.random.choice(all_index_test, size=500, replace=False, p=None)
    sampler_test = SubsetSamper(portion_index_test)
    test_loader = DataLoader(test_dataset, batch_size=64, sampler=sampler_test)

    # open() and torch.save() both fail if the output directory is missing.
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    index_path = os.path.join(
        args.checkpoint_dir,
        f'selected_indices_seed_{args.training_seed}_sample_{len(portion_index)}.txt',
    )
    with open(index_path, 'w', encoding='utf-8') as f:
        f.writelines(f"{idx}\n" for idx in portion_index)

    # dropout_rate=0 makes both dropout layers no-ops for this run.
    model = SimpleDNN(dropout_rate=0).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    start = time.perf_counter()
    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

        test_accuracy = evaluate_accuracy(model, test_loader, device)
        print(f'Accuracy of the model on the test set: {test_accuracy:.2f}%')

        train_accuracy = evaluate_accuracy(model, train_loader, device)
        print(f'Accuracy of the model on the train set: {train_accuracy:.2f}%')

    print("Time used:", time.perf_counter() - start)
    print("Training complete")
    checkpoint_path = os.path.join(
        args.checkpoint_dir,
        f"checkpoint_{args.training_seed}_sample_{len(portion_index)}.pt",
    )
    torch.save(model.state_dict(), checkpoint_path)
