import random, argparse
from collections import Counter

import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
from torch.utils.data import DataLoader, Subset
from helpers import *
import numpy as np

# Seed every random number generator the script relies on so that runs are
# reproducible: first the hash seed for child processes, then torch, Python's
# `random` module and NumPy's legacy global generator (in that order).
# NOTE: PYTHONHASHSEED is read by the interpreter at start-up, so assigning it
# here only affects subprocesses launched afterwards, not this process.
os.environ['PYTHONHASHSEED'] = '0'
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(0)
del _seed_rng

def calculate_sparsity(model):
    """
    Function to calculate the sparsity of a PyTorch model.
    Sparsity is defined as the ratio of zero-valued parameter entries to the
    total number of parameter entries (every tensor returned by
    ``model.parameters()``, trainable or not).

    Parameters:
    model (torch.nn.Module): The PyTorch model for which to calculate the sparsity.

    Returns:
    float: The sparsity of the model, in [0, 1]. A model without any
    parameters (e.g. an activation module) has sparsity 0.0.
    """
    # Count the total number of parameter entries
    total_params = sum(p.numel() for p in model.parameters())

    # Guard against parameter-free modules, which would otherwise raise
    # ZeroDivisionError below.
    if total_params == 0:
        return 0.0

    # Reuse the shared helper so both functions agree on what "zero" means
    zero_params = count_zero_params(model)
    return zero_params / total_params

def count_zero_params(model):
    """
    Count how many individual parameter entries of a PyTorch model are zero.

    Parameters:
    model (torch.nn.Module): The model whose ``parameters()`` are scanned.

    Returns:
    int: The number of entries, summed over every parameter tensor, that
    compare equal to 0.0.
    """
    zeros = 0
    for tensor in model.parameters():
        # Element-wise comparison gives a boolean mask; its sum is the zero count
        zeros += (tensor == 0.0).sum().item()
    return zeros

def train_model(model, train_loader, criterion, optimizer, device):
    """
    Run one training epoch over ``train_loader`` and report model sparsity.

    Every batch is moved to ``device``, passed through ``model``, scored with
    ``criterion`` and used for a single ``optimizer`` step. After the epoch the
    number of zero-valued parameters and the resulting sparsity are printed.

    Parameters:
    model (torch.nn.Module): The network to train; it is put in training mode.
    train_loader (torch.utils.data.DataLoader): Yields ``(inputs, targets)`` batches.
    criterion (torch.nn.Module): Loss applied to ``(outputs, targets)``.
    optimizer (torch.optim.Optimizer): Optimizer that updates ``model``.
    device (torch.device): Device each batch is moved to.

    Returns:
    float: Sum of the per-batch ``loss.item()`` values divided by the number of
    batches (if ``criterion`` averages within a batch, a short final batch is
    weighted like a full one).

    Raises:
    ZeroDivisionError: If ``train_loader`` has length 0.
    """
    model.train()
    batch_losses = []

    for inputs, targets in train_loader:
        optimizer.zero_grad()
        inputs = inputs.to(device)
        targets = targets.to(device)
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    # Report how many weights ended up exactly zero after this epoch
    print(f"Number of zero parameters: {count_zero_params(model)}")
    print(f"Model sparsity: {calculate_sparsity(model)}")

    return sum(batch_losses) / len(train_loader)

def _prepare_run_dir(model_dst):
    """Ensure the run directory ``model_dst`` exists, printing which case applied.

    Creation errors (``OSError``) propagate, because every later model and log
    write needs this folder.
    """
    if os.path.exists(model_dst):
        print(f"Folder '{model_dst}' already exists.")
        return
    os.makedirs(model_dst, exist_ok=True)
    print(f"Folder '{model_dst}' created successfully.")


def _model_path(model_dst, fold):
    """Return the checkpoint file path used for ``fold`` inside ``model_dst``."""
    return os.path.join(model_dst, 'model' + str(fold) + '.pth')


def _subsample_train_indices(train_index, targets, train_size_percentage):
    """Return a class-stratified subset holding ``train_size_percentage`` % of ``train_index``.

    ``train_index`` holds dataset indices; ``targets`` is the full dataset's label tensor.
    """
    if train_size_percentage == 100:
        return train_index

    # An integer train_size gives exactly this many samples (the complement
    # becomes the discarded part).
    num_train_samples = int(len(train_index) * train_size_percentage / 100)
    sss = StratifiedShuffleSplit(n_splits=1, train_size=num_train_samples, random_state=42)
    subset_positions, _ = next(
        sss.split(np.zeros(len(train_index)), targets[train_index]))

    # split() returns positions *within* train_index, not dataset indices.
    # Map them back; using them directly would pick arbitrary samples from the
    # full dataset, including samples of the validation fold.
    return train_index[subset_positions]


def _build_optimizer(optimizer_name, model, lr, momentum):
    """Create the optimizer named ``optimizer_name`` ('adam' or 'sgd') for ``model``."""
    if optimizer_name == 'adam':
        # ``momentum`` is used as Adam's beta1
        return optim.Adam(model.parameters(), lr=lr, betas=(momentum, 0.999))
    if optimizer_name == 'sgd':
        return optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    raise ValueError(f"Unsupported optimizer: {optimizer_name!r}")


def _train_folds(optimizer_name, model_dst, full_train_dataset, lr, momentum,
                 train_size_percentage, device, num_epochs, num_folds):
    """Train and save one model per stratified CV fold; return the per-fold validation accuracies."""
    targets = full_train_dataset.targets
    skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)
    cv_val_accuracy = []

    for fold, (train_index, val_index) in enumerate(
            skf.split(np.zeros(len(full_train_dataset)), targets)):
        train_indices = _subsample_train_indices(train_index, targets, train_size_percentage)
        train_subset = Subset(full_train_dataset, train_indices)
        val_subset = Subset(full_train_dataset, val_index)

        print(f"Fold {fold + 1}:")
        print(f"Training set size (after subsetting): {len(train_subset)}")
        print(f"Validation set size: {len(val_subset)}")

        # Read labels from the targets tensor rather than iterating the Subset,
        # which would load and transform every image just to obtain its label.
        class_distribution = Counter(targets[train_indices].tolist())
        for class_label, count in class_distribution.items():
            print(f"Class {class_label}: {count} samples")

        train_loader = DataLoader(train_subset, batch_size=64, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=64, shuffle=False)

        model = CNNModel200k().to(device)  # Move model to GPU if available
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print("Number of parameters:", num_params)

        criterion = nn.CrossEntropyLoss()
        optimizer = _build_optimizer(optimizer_name, model, lr, momentum)

        printer = PrintToFile(os.path.join(
            model_dst, 'mnist_loss_lr' + str(lr) + 'momentum' + str(momentum) + 'fold' + str(fold) + '.txt'))
        printer.start()
        try:
            print(f"Fold {fold}/{num_folds}")
            for epoch in range(num_epochs):
                epoch_loss = train_model(model, train_loader, criterion, optimizer, device)
                print(f"Epoch {epoch}/{num_epochs}, Loss: {epoch_loss}")

                val_accuracy, val_loss = evaluate_model(model, criterion, val_loader, device)
                print(f"Validation Accuracy on epoch {epoch}: {val_accuracy:.10f}%")
                print(f"Validation Loss on epoch {epoch}: {val_loss:.10f}")
                print()
        finally:
            # Stop the file logger even if training raises
            printer.stop()

        torch.save(model.state_dict(), _model_path(model_dst, fold))

        accuracy, _ = evaluate_model(model, criterion, val_loader, device)
        cv_val_accuracy.append(accuracy)
        print(f"Validation Accuracy on Fold {fold + 1}: {accuracy:.10f}%")
        print()

    average_cv_val_accuracy = sum(cv_val_accuracy) / len(cv_val_accuracy)
    print(f"Average Cross-Validation Validation Accuracy: {average_cv_val_accuracy:.10f}%")
    return cv_val_accuracy


def _test_folds(model_dst, test_dataset, lr, momentum, device, num_folds):
    """Evaluate every saved fold model on ``test_dataset``; write the mean accuracy to a summary file."""
    # Neither the loader nor the (stateless) loss depends on the fold
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    criterion = nn.CrossEntropyLoss()
    cv_test_mnist_accuracy = []

    for fold in range(num_folds):
        model = CNNModel200k().to(device)  # Move model to GPU if available
        model.load_state_dict(torch.load(_model_path(model_dst, fold), map_location=device))

        test_accuracy, _ = evaluate_model(model, criterion, test_loader, device)
        cv_test_mnist_accuracy.append(test_accuracy)
        print(f"Test MNIST Accuracy on Fold {fold}: {test_accuracy:.5f}%")
        print()

    average_cv_test_mnist_accuracy = sum(cv_test_mnist_accuracy) / len(cv_test_mnist_accuracy)
    printer = PrintToFile(os.path.join(
        model_dst, 'mnist_lr' + str(lr) + 'momentum' + str(momentum) + '.txt'))
    printer.start()
    try:
        print(f"Average Cross-Validation Test MNIST Accuracy: {average_cv_test_mnist_accuracy:.10f}%")
    finally:
        printer.stop()
    return cv_test_mnist_accuracy


def main(opt):
    """
    Main function to run the training or testing process.

    For each optimizer ('adam', then 'sgd') a run directory
    ``../runs/<optimizer>_mnist_<pct>/lr<lr>momentum<momentum>/`` is prepared.
    In TRAIN mode a CNNModel200k is trained for 30 epochs on each of 10
    stratified cross-validation folds of the MNIST training set, and each
    fold's weights are saved. In TEST mode those saved fold models are
    evaluated on the MNIST test set, and their mean accuracy is written to a
    summary file.

    Parameters:
    opt (argparse.Namespace): Parsed command-line arguments providing ``lr``,
        ``momentum``, ``train_size_percentage`` and ``mode``.

    Raises:
    ValueError: If ``mode`` is not TRAIN or TEST (case-insensitive), or if
        ``train_size_percentage`` is outside (0, 100].
    """
    lr = opt.lr
    momentum = opt.momentum
    train_size_percentage = opt.train_size_percentage
    mode = str(opt.mode).upper()

    # Fail fast instead of silently doing nothing or crashing mid-fold
    if mode not in ('TRAIN', 'TEST'):
        raise ValueError(f"Unknown mode {opt.mode!r}; expected 'TRAIN' or 'TEST'.")
    if not 0 < train_size_percentage <= 100:
        raise ValueError(
            f"train_size_percentage must be in (0, 100], got {train_size_percentage}.")

    num_epochs = 30
    num_folds = 10

    # Check if GPU is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Train and test currently share one normalisation; they are kept separate
    # so that train-only augmentation can be added later.
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])

    # Loaded once and shared by both optimizer runs
    full_train_mnist_dataset = torchvision.datasets.MNIST(
        root='../data/mnist/', train=True, download=True, transform=train_transform)
    test_mnist_dataset = torchvision.datasets.MNIST(
        root='../data/mnist/', train=False, download=True, transform=test_transform)

    for optimizer_name in ('adam', 'sgd'):
        model_dst = ('../runs/' + optimizer_name + '_mnist_' + str(train_size_percentage)
                     + '/lr' + str(lr) + 'momentum' + str(momentum) + '/')
        _prepare_run_dir(model_dst)

        if mode == 'TRAIN':
            _train_folds(optimizer_name, model_dst, full_train_mnist_dataset, lr, momentum,
                         train_size_percentage, device, num_epochs, num_folds)
        else:
            _test_folds(model_dst, test_mnist_dataset, lr, momentum, device, num_folds)

def parse_opt(argv=None):
    """
    Function to parse command-line arguments.

    Parameters:
    argv (list[str] | None): Arguments to parse. ``None`` (the default) parses
        ``sys.argv[1:]``, exactly as calling ``parse_opt()`` always did.

    Returns:
    argparse.Namespace: Parsed command-line arguments:
        ``lr`` (float, required), ``momentum`` (float, required),
        ``train_size_percentage`` (int, default 100) and ``mode``
        ('TRAIN' or 'TEST'; accepted case-insensitively; default 'TRAIN').
    """
    parser = argparse.ArgumentParser()
    # lr and momentum name the run directory and configure the optimizer; without
    # them the optimizer constructors fail later, so require them up front.
    parser.add_argument('--lr', type=float, required=True,
                        help='learning rate, e.g. 0.01')
    parser.add_argument('--momentum', type=float, required=True,
                        help='momentum (SGD momentum / Adam beta1), e.g. 0.9')
    parser.add_argument('--train_size_percentage', type=int, default=100,
                        help='percentage of training data to use (1-100)')  # Default to the full dataset
    # str.upper makes "train"/"test" work; choices rejects anything else early.
    parser.add_argument('--mode', type=str.upper, choices=('TRAIN', 'TEST'), default='TRAIN',
                        help='train mode or test mode')
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: parse the command line and hand the options to main().
    main(parse_opt())