import torch
from torchvision.datasets import MNIST
import pickle
import argparse


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import pandas as pd

from config import columns

# Names of the datasets handled as binary classification: for these, SimpleNN
# applies a sigmoid to its output and the training script uses nn.BCELoss.
binary_list = ['DIABETES', 'GERMAN', 'no2', 'news', 'spam']

# Define the neural network model
class SimpleNN(nn.Module):
    """Fully connected network: Linear/ReLU hidden stack plus a Linear output.

    Args:
        layers: Width of each hidden layer, in order. May be empty, in which
            case the network is a single linear layer.
        input_dim: Number of input features; inputs are flattened to
            ``(-1, input_dim)`` in ``forward``.
        output_dim: Number of output units.
        use_sig: Whether ``forward`` applies a sigmoid to the output. When
            left as ``None`` the choice falls back to the script-level
            globals (``args.dataset in binary_list``), which is the historic
            behaviour and requires those globals to exist.
    """

    def __init__(self, layers: list, input_dim: int, output_dim: int, use_sig=None):
        super().__init__()
        # Remember the input width so forward() does not depend on globals.
        self.input_dim = input_dim

        widths = [input_dim] + list(layers)
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.ReLU())
        modules.append(nn.Linear(widths[-1], output_dim))
        # ModuleList (not ParameterList) is the container meant for
        # sub-modules; it registers every layer's parameters with this module.
        self.layer_list = nn.ModuleList(modules)

        if use_sig is None:
            # Backward-compatible default: consult the training script's
            # globals, exactly as the constructor always did.
            use_sig = args.dataset in binary_list
        self.use_sig = bool(use_sig)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_dim)`` and run it through the layers.

        Returns the raw output of the last linear layer, or its sigmoid when
        ``use_sig`` is set.
        """
        x = x.view(-1, self.input_dim)  # Flatten the input tensor
        for layer in self.layer_list:
            x = layer(x)
        return torch.sigmoid(x) if self.use_sig else x


class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs a feature tensor with a target tensor.

    Sample ``i`` is ``(data_tensor[i], target_tensor[i])``; the number of
    samples is the size of the first dimension of ``data_tensor``.
    """

    def __init__(self, data_tensor, target_tensor):
        """Store the given tensors.

        Args:
            data_tensor (Tensor): Features, indexed along the first dimension.
            target_tensor (Tensor): Targets, indexed along the first dimension.
        """
        # Only references are kept; nothing is copied or validated here.
        self.data, self.target = data_tensor, target_tensor

    def __len__(self):
        """Return the number of samples (first dimension of the features)."""
        sample_count = self.data.size(0)
        return sample_count

    def __getitem__(self, idx):
        """Return the ``(data, target)`` pair stored at position ``idx``.

        Args:
            idx (int): Index of the sample to fetch.
        """
        features = self.data[idx]
        label = self.target[idx]
        return features, label


def _load_pickled_dataset(name, cast_to_float32):
    """Load the pickled train/test splits of a tabular dataset.

    Reads ``data/<name lower-cased>_train.h5`` and ``..._test.h5``. Despite
    the ``.h5`` extension the files are read with ``pickle`` and are expected
    to hold an ``(x, y)`` pair of tensors each.

    Args:
        name: Dataset name as given on the command line.
        cast_to_float32: Whether features and targets are cast to float32.

    Returns:
        tuple: ``(trainset, testset)`` as ``CustomDataset`` instances whose
        targets carry a trailing singleton dimension.
    """
    splits = []
    for split in ('train', 'test'):
        # SECURITY: pickle.load can execute arbitrary code; only point this
        # at files from a trusted source.
        with open("data/{0}_{1}.h5".format(name.lower(), split), 'rb') as f:
            features, targets = pickle.load(f)
        # Targets get a trailing dimension so they line up with the
        # (batch, 1) model output.
        targets = targets.unsqueeze(-1)
        if cast_to_float32:
            features = features.type(torch.float32)
            targets = targets.type(torch.float32)
        splits.append(CustomDataset(features, targets))
    return splits[0], splits[1]


if __name__ == '__main__':
    import os  # only the script entry point needs it

    # Pickled tabular datasets: name -> (feature count, cast tensors to float32)
    pickled_datasets = {
        'DIABETES': (8, False),
        'no2': (7, True),
        'news': (58, True),
        'spam': (57, True),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('--layer_num', required=True, type=int)
    parser.add_argument('--layer_size', required=True, type=int)
    # `choices` rejects unknown names up front instead of failing later with
    # a NameError; the option is no longer `required`, so the default applies.
    parser.add_argument('--dataset', default='MNIST', type=str,
                        choices=['MNIST', 'GERMAN'] + list(pickled_datasets),
                        help="the dataset to train on")
    parser.add_argument('--ensemble_size', required=True, type=int, help="Number of MLPs in ensemble")
    parser.add_argument('--use_cpu', action='store_true', help="Force torch to use CPU")
    # NOTE: `args` must stay a module-level name; SimpleNN may read it.
    args = parser.parse_args()
    for option in ('layer_num', 'layer_size', 'ensemble_size'):
        if getattr(args, option) < 1:
            parser.error("--{0} must be a positive integer".format(option))

    # NOTE: training is pinned to the CPU, so --use_cpu currently has no
    # effect; it is still accepted so existing command lines keep working.
    device = torch.device('cpu')

    layer_list = [args.layer_size] * args.layer_num
    model_path = "Networks/{0}x{1}x{2}.h5".format(args.layer_size, args.layer_num, args.ensemble_size)
    # Create the output directory now, so a missing folder cannot discard the
    # trained ensemble at the very end of the run.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    print("Training an MLP ensemble of {2} with {0} layers of size {1}".format(args.layer_num, args.layer_size, args.ensemble_size))

    # Load the selected dataset; every branch sets args.data_dim,
    # args.output_dim, trainset and testset.
    if args.dataset == "MNIST":
        args.data_dim = 784
        args.output_dim = 10
        mnist_train = torchvision.datasets.MNIST(root='./data', train=True, download=True)
        mnist_test = torchvision.datasets.MNIST(root='./data', train=False, download=True)
        # The raw image tensors are used directly (divided by 255), so no
        # torchvision transform is involved.
        trainset = CustomDataset(mnist_train.data/255., mnist_train.targets)
        testset = CustomDataset(mnist_test.data/255., mnist_test.targets)
    elif args.dataset == 'GERMAN':
        args.data_dim = 20
        args.output_dim = 1
        data = pd.read_csv("data/german.data", sep=' ', header=None)
        data.columns = columns
        y = data['target'] - 1
        # Features
        x = data.drop('target', axis=1)
        # Replace categorical (object) columns by their integer category codes
        cat_columns = x.select_dtypes(['object']).columns
        x[cat_columns] = x[cat_columns].apply(lambda col: col.astype('category').cat.codes)
        # Min-max normalise the features
        for column in x.columns:
            column_min = x[column].min()
            column_range = x[column].max() - column_min
            if column_range == 0:
                # A constant column would give 0/0 = NaN; map it to zero.
                x[column] = 0.0
            else:
                x[column] = (x[column] - column_min) / column_range

        x = torch.tensor(x.values, dtype=torch.float32, device=device)
        y = torch.tensor(y.values, dtype=torch.float32, device=device).unsqueeze(-1)
        # First 800 rows train, the remainder test
        trainset = CustomDataset(x[:800], y[:800])
        testset = CustomDataset(x[800:], y[800:])
    else:
        args.data_dim, cast_to_float32 = pickled_datasets[args.dataset]
        args.output_dim = 1
        trainset, testset = _load_pickled_dataset(args.dataset, cast_to_float32)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, generator=torch.Generator(device=device))
    testloader = torch.utils.data.DataLoader(testset, batch_size=10000, shuffle=False, generator=torch.Generator(device=device))

    # One [weights, biases] pair per linear layer; the last axis of each
    # tensor indexes the ensemble member.
    widths = [args.data_dim] + layer_list + [args.output_dim]
    means = [
        [torch.zeros((fan_out, fan_in, args.ensemble_size)), torch.zeros(fan_out, args.ensemble_size)]
        for fan_in, fan_out in zip(widths[:-1], widths[1:])
    ]

    # `choices` guarantees the dataset is either MNIST or in binary_list.
    if args.dataset in binary_list:
        criterion = nn.BCELoss()
    else:
        criterion = nn.CrossEntropyLoss()
    num_epochs = 5

    for model_num in range(args.ensemble_size):
        # Initialize a fresh model and optimizer for this ensemble member
        model = SimpleNN(layer_list, args.data_dim, args.output_dim)
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        # Training the model
        for epoch in range(num_epochs):
            running_loss = 0.0
            for batch_idx, (inputs, labels) in enumerate(trainloader):
                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                # Backward pass and optimize
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                if batch_idx % 100 == 99:  # Print every 100 mini-batches
                    print(f'Epoch [{epoch + 1}, {batch_idx + 1}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0

        print('Finished Training')

        # Copy this member's weights and biases into its slice of `means`
        trainable_layers = [layer for layer in model.layer_list if not isinstance(layer, nn.ReLU)]
        for (weights, biases), layer in zip(means, trainable_layers):
            weights[:, :, model_num] = layer.weight.data
            biases[:, model_num] = layer.bias.data

        # Testing the model
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in testloader:
                outputs = model(images)
                if args.dataset == 'MNIST':
                    predicted = torch.argmax(outputs, dim=-1)
                else:
                    # Sigmoid outputs are rounded to a 0/1 prediction
                    predicted = torch.round(outputs)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f'Accuracy of the network on the test images: {100 * correct / total:.2f}%')

    # The stacked parameters are written with pickle (the .h5 suffix is only
    # a file name, not the HDF5 format).
    with open(model_path, 'wb') as f:
        pickle.dump(means, f)