import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
import torch.nn.functional as F
import math


def _build_loader_pair(dataset_cls, root, transform, batch_size):
    """Return ``(trainloader, testloader)`` for a torchvision dataset class.

    Both splits are downloaded into ``root`` if missing and loaded in the main
    process (``num_workers=0``). Only the training loader shuffles.
    """
    train_set = dataset_cls(root, train=True, transform=transform,
                            target_transform=None, download=True)
    trainloader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=0)

    test_set = dataset_cls(root, train=False, transform=transform,
                           target_transform=None, download=True)
    testloader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=0)
    return trainloader, testloader


def get_loaders(batch_size, dataset="cifar10"):
    """Build train/test DataLoaders for CIFAR-10 or MNIST.

    Images are converted to tensors and normalized with fixed constants
    (three per-channel values for CIFAR-10, one for MNIST). Data is
    downloaded on first use into ``../cifar10_data`` or ``../mnist_data``,
    relative to the current working directory.

    Args:
        batch_size: Batch size used by both loaders.
        dataset: ``"cifar10"`` (default) or ``"mnist"``.

    Returns:
        ``(trainloader, testloader)``; the training loader shuffles, the test
        loader does not.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    if dataset == "cifar10":
        transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                                             std=(0.2470, 0.2435, 0.2616)),
        ])
        return _build_loader_pair(torchvision.datasets.CIFAR10,
                                  '../cifar10_data', transform, batch_size)
    elif dataset == "mnist":
        transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ])
        return _build_loader_pair(torchvision.datasets.MNIST,
                                  '../mnist_data', transform, batch_size)
    else:
        # Fail loudly here instead of returning None, which would only blow up
        # later when the caller tries to unpack the result.
        raise ValueError(
            f"unsupported dataset {dataset!r}; expected 'cifar10' or 'mnist'")
    

def evaluateClassification(model, loader, device, printing=False):
    """Evaluate a classifier's accuracy and one-hot MSE over ``loader``.

    Puts ``model`` in eval mode (it is not switched back afterwards) and runs
    it under ``torch.no_grad()``. Inputs are cast to ``torch.float64`` before
    the forward pass, so the model presumably holds float64 parameters —
    TODO confirm.

    Args:
        model: Callable returning a 3-tuple whose first element is the
            ``(batch, num_classes)`` output; the other two elements are
            ignored here.
        loader: Iterable of ``(x, y)`` batches with integer class labels ``y``.
        device: Device the batches are moved to.
        printing: If True, print the accuracy and MSE.

    Returns:
        ``(acc, mse)``. ``acc`` is the fraction of samples whose argmax
        prediction equals the label. ``mse`` is the squared error between the
        outputs and one-hot targets, averaged over every sample and class
        (a 0-dim tensor).

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total_samples = 0
    num_classes = None
    # Accumulate the *sum* of squared errors: averaging per batch and then
    # dividing by the dataset size would scale the result by 1/batch_size.
    criterion = nn.MSELoss(reduction="sum")
    squared_error = 0.0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device).type(torch.float64), y.to(device)
            y_hat, _, _ = model(x)
            # Infer the class count from the model output instead of
            # hard-coding it.
            num_classes = y_hat.shape[1]
            pred = torch.argmax(y_hat, dim=1)
            y_one_hot = F.one_hot(y, num_classes=num_classes).to(torch.float64)
            correct += (y == pred).sum().item()
            squared_error += criterion(y_hat, y_one_hot)
            total_samples += len(y)

    torch.cuda.empty_cache()
    if total_samples == 0:
        raise ValueError("loader yielded no samples")

    # Normalize by the samples actually evaluated; len(loader.dataset) differs
    # when the loader drops its last batch or uses a subset sampler.
    acc = correct / total_samples
    mse = squared_error / (total_samples * num_classes)

    if printing:
        print(f'Accuracy:\t{acc}')
        print(f'Mean Squared Error:\t{mse}')
    return acc, mse
