from torch.utils.data import DataLoader, Dataset, TensorDataset
import torchvision
import torchvision.transforms as transforms
import torch
import numpy as np

class SumOfSinsDataset(Dataset):
    """Map-style dataset whose targets are a sum of sinusoids of the inputs.

    Item ``i`` is the pair ``(x[i], sum_k amplitudes[k] * sin(ws[k] * x[i] + phases[k]))``.
    The three coefficient sequences are paired with ``zip``, so the shortest
    one determines how many sinusoids contribute. With no coefficients the
    target is ``0``.
    """

    def __init__(self, x, amplitudes, ws, phases):
        self.x, self.amplitudes, self.ws, self.phases = x, amplitudes, ws, phases

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        point = self.x[idx]
        terms = zip(self.amplitudes, self.ws, self.phases)
        target = sum(amp * np.sin(freq * point + shift) for amp, freq, shift in terms)
        return point, target


def _flatten(x):
    """Flatten a tensor to 1-D.

    Defined at module level instead of as a lambda so the transform pipeline
    stays picklable if a DataLoader is ever given worker processes.
    """
    return torch.flatten(x)


def _get_torchvision_dataloaders(dataset_cls, datadir, batch_size=64):
    """Build ``(trainloader, testloader)`` for a torchvision MNIST-style dataset class.

    Images go through ``ToTensor``, ``Normalize((0.5,), (0.5,))`` and are then
    flattened to 1-D. Both splits are downloaded into ``datadir`` if missing.
    The training loader shuffles; the test loader does not.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
        transforms.Lambda(_flatten),
    ])
    trainset = dataset_cls(root=datadir, train=True, download=True, transform=transform)
    testset = dataset_cls(root=datadir, train=False, download=True, transform=transform)

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)
    return trainloader, testloader


def get_mnist_dataloaders(args, batch_size=64):
    """Return ``(trainloader, testloader)`` for MNIST stored under ``args.datadir``."""
    return _get_torchvision_dataloaders(torchvision.datasets.MNIST, args.datadir, batch_size)


def get_fashion_mnist_dataloaders(args, batch_size=64):
    """Return ``(trainloader, testloader)`` for Fashion-MNIST stored under ``args.datadir``."""
    return _get_torchvision_dataloaders(torchvision.datasets.FashionMNIST, args.datadir, batch_size)


def get_kmnist_dataloaders(args, batch_size=64):
    """Return ``(trainloader, testloader)`` for KMNIST stored under ``args.datadir``."""
    return _get_torchvision_dataloaders(torchvision.datasets.KMNIST, args.datadir, batch_size)

def get_afro_mnist_dataloaders(name, data_dir, batch_size=64):
    """Build ``(trainloader, testloader)`` for one Afro-MNIST script from ``.npy`` files.

    Reads ``{data_dir}/{name}/{name}_MNIST_{X,y}_{train,test}.npy``. Pixels are
    cast to float32, mapped with ``(x - 128) / 128`` and flattened from
    dimension 1 onward. Labels are cast to ``torch.long``. The training loader
    shuffles; the test loader does not.

    Args:
        name: Script name used in both the sub-directory and the file prefix.
        data_dir: Root directory containing the ``{name}`` sub-directory.
        batch_size: Batch size for both loaders (default 64).
    """
    prefix = f'{data_dir}/{name}/{name}_MNIST'

    def load_images(split):
        # Cast to float BEFORE normalizing. transforms.Normalize rejects
        # non-float tensors, so integer-typed pixel arrays used to fail here.
        # (x - 128) / 128 is the same arithmetic as Normalize((128,), (128,))
        # but has no (..., C, H, W) shape requirement.
        images = torch.from_numpy(np.load(f'{prefix}_X_{split}.npy')).float()
        return ((images - 128.0) / 128.0).flatten(start_dim=1)

    def load_labels(split):
        return torch.from_numpy(np.load(f'{prefix}_y_{split}.npy')).to(torch.long)

    trainset = TensorDataset(load_images('train'), load_labels('train'))
    testset = TensorDataset(load_images('test'), load_labels('test'))

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    return trainloader, testloader

def get_kannada_dataloaders(args, batch_size=64):
    """Build ``(trainloader, testloader)`` for Kannada-MNIST from ``.npz`` archives.

    Reads ``{args.datadir}/Kannada_MNIST/{X,y}_kannada_MNIST_{train,test}.npz``,
    taking each array from the archive key ``'arr_0'``. Pixels are cast to
    float32, mapped with ``(x - 128) / 128`` and flattened from dimension 1
    onward. Labels are cast to ``torch.long``. The training loader shuffles;
    the test loader does not.

    Args:
        args: Namespace providing ``datadir``.
        batch_size: Batch size for both loaders (default 64).
    """
    base = f'{args.datadir}/Kannada_MNIST'

    def load(prefix, split):
        # np.load on an .npz returns an NpzFile that keeps the file open until
        # closed; the context manager releases it once the array is read.
        with np.load(f'{base}/{prefix}_kannada_MNIST_{split}.npz') as archive:
            return archive['arr_0']

    def to_features(images):
        # Same arithmetic as transforms.Normalize((128,), (128,)) on a float
        # tensor, without Normalize's (..., C, H, W) shape requirement.
        return ((torch.from_numpy(images).float() - 128.0) / 128.0).flatten(start_dim=1)

    def to_labels(labels):
        return torch.from_numpy(labels).to(torch.long)

    trainset = TensorDataset(to_features(load('X', 'train')), to_labels(load('y', 'train')))
    testset = TensorDataset(to_features(load('X', 'test')), to_labels(load('y', 'test')))

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    return trainloader, testloader


def get_dataloaders(args):
    """Return ``(trainloader, testloader)`` for the dataset named by ``args.dataset``.

    Supported names: ``mnist``, ``fashion_mnist``, ``kmnist``, ``ethiopic``,
    ``nko``, ``osmanya``, ``vai`` and ``kannada``. The four scripts handled by
    ``get_afro_mnist_dataloaders`` receive only ``args.datadir``; the other
    factories receive ``args`` itself.

    Raises:
        ValueError: if ``args.dataset`` matches none of the supported names.
    """
    # Names are compared with ``==`` in this order. Each factory is wrapped in
    # a lambda so only the selected dataset is actually loaded.
    candidates = (
        ('mnist', lambda: get_mnist_dataloaders(args)),
        ('fashion_mnist', lambda: get_fashion_mnist_dataloaders(args)),
        ('kmnist', lambda: get_kmnist_dataloaders(args)),
        ('ethiopic', lambda: get_afro_mnist_dataloaders('Ethiopic', args.datadir)),
        ('nko', lambda: get_afro_mnist_dataloaders('NKo', args.datadir)),
        ('osmanya', lambda: get_afro_mnist_dataloaders('Osmanya', args.datadir)),
        ('vai', lambda: get_afro_mnist_dataloaders('Vai', args.datadir)),
        ('kannada', lambda: get_kannada_dataloaders(args)),
    )
    for key, build in candidates:
        if args.dataset == key:
            return build()
    raise ValueError(f'Dataset {args.dataset} not recognized')

