import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, Sampler
import argparse
import os
import numpy as np
import random
from resnet9 import ResNet9
from mlp_dropout import ConvNet


# Run on the GPU when PyTorch reports a usable CUDA device, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


class SubsetSamper(Sampler):
    """Sampler that yields a caller-supplied sequence of indices, in order.

    The indices are neither copied nor shuffled, so every pass over the
    sampler produces them in exactly the order they were given.
    """

    def __init__(self, indices):
        # Keep a reference to the caller's sequence; it is replayed on each iteration.
        self.indices = indices

    def __len__(self):
        """Return how many indices one pass over the sampler yields."""
        return len(self.indices)

    def __iter__(self):
        """Yield the stored indices one by one, in their stored order."""
        yield from self.indices


def get_cifar2_indices_and_adjust_labels(dataset):
    """Pick out the cat and dog samples of ``dataset`` and relabel them to 0/1.

    ``dataset.targets`` is modified in place: each label 3 (cat) is rewritten
    to 0 and each label 5 (dog) is rewritten to 1. Labels of every other class
    are left as they are.

    Returns:
        The positions, in ascending order, whose label was 3 or 5 on entry.
    """
    cat_class, dog_class = 3, 5
    labels = dataset.targets
    selected = []
    for position in range(len(dataset)):
        original_label = labels[position]
        if original_label == cat_class:
            labels[position] = 0
            selected.append(position)
        elif original_label == dog_class:
            labels[position] = 1
            selected.append(position)
    return selected


def evaluate_accuracy(model, loader, device):
    """Return the classification accuracy of ``model`` over ``loader`` in percent.

    The model is switched to evaluation mode and gradient tracking is disabled
    for the pass. Callers that keep training afterwards must call
    ``model.train()`` again.

    Args:
        model: Module mapping a batch of inputs to per-class scores, with the
            class scores along dimension 1.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Percentage (0-100) of samples whose highest-scoring class equals the label.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    # No gradient is needed for evaluation
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # The predicted class is the position of the largest class score.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--training_seed", type=int, default=0, help="training seed")
    parser.add_argument("--model", type=str, default="resnet9", help="model to train")
    arg = parser.parse_args()

    # Seed the torch, Python and NumPy RNGs from the same command-line seed.
    torch.manual_seed(arg.training_seed)
    random.seed(arg.training_seed)
    np.random.seed(arg.training_seed)

    # Create the output directory up front: the selected-indices file is
    # written before training starts, so the directory must already exist.
    checkpoint_dir = "./checkpoint"
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Load CIFAR-10 data and create CIFAR-2 subset
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    cifar2_indices = get_cifar2_indices_and_adjust_labels(train_dataset)
    cifar2_indices_test = get_cifar2_indices_and_adjust_labels(test_dataset)

    # Candidate pools: the first 5000 cat/dog training indices and the first
    # 500 cat/dog test indices, in dataset order.
    all_index = cifar2_indices[:5000]
    all_index_test = cifar2_indices_test[:500]
    random.shuffle(all_index)
    # Draw the subsets without replacement. The drawn order is also the order
    # in which the loaders visit the samples on every pass.
    portion_index = np.random.choice(all_index, size=2500, replace=False, p=None)
    portion_index_test = np.random.choice(all_index_test, size=500, replace=False, p=None)
    sampler = SubsetSamper(portion_index)
    sampler_test = SubsetSamper(portion_index_test)
    train_loader = DataLoader(train_dataset, batch_size=64, sampler=sampler)
    test_loader = DataLoader(test_dataset, batch_size=64, sampler=sampler_test)

    # Record which training samples were selected for this seed, one index per line.
    with open(os.path.join(checkpoint_dir,
                           f'selected_indices_seed_{arg.training_seed}.txt'), 'w') as f:
        for idx in portion_index:
            f.write(str(idx) + '\n')

    # Initialize the model, loss function, and optimizer
    if arg.model == "resnet9":
        model = ResNet9(dropout_rate=0).to(device)
    else:
        model = ConvNet(dropout_rate=0).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    # Training loop
    epochs = 50
    for epoch in range(epochs):
        # evaluate_accuracy() leaves the model in eval mode, so switch back each epoch.
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

        accuracy = evaluate_accuracy(model, test_loader, device)
        print(f'Accuracy of the model on the test set: {accuracy:.2f}%')

        accuracy = evaluate_accuracy(model, train_loader, device)
        print(f'Accuracy of the model on the train set: {accuracy:.2f}%')

    print("Training complete")

    # Save the model
    torch.save(model.state_dict(), f"checkpoint/checkpoint_{arg.model}_{arg.training_seed}.pt")
