import torch
from torchvision import datasets, transforms
import src.nn.modules as mod
from src.coded_vae import CodedVAE
import os
import random
import numpy as np
from torch.utils.data import DataLoader

def seed_worker(worker_id):
    """Re-seed NumPy's global RNG and Python's ``random`` inside a DataLoader worker.

    Meant to be passed as a DataLoader ``worker_init_fn``. A 32-bit seed is
    derived from the current torch seed (``torch.initial_seed()``) and applied
    first to ``np.random`` and then to ``random``.

    Args:
        worker_id: Worker index supplied by the DataLoader; not used here.
    """
    # NumPy only accepts seeds in [0, 2**32), so fold the torch seed into 32 bits.
    derived_seed = torch.initial_seed() & 0xFFFFFFFF
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

# ---- Reproducibility ---- #
# Dedicated, seeded RNG handed to the DataLoaders further down.
# Generator.manual_seed returns the generator itself, so creation and seeding
# fit in one expression.
g = torch.Generator().manual_seed(0)

# ---- GPU ---- #
# Restrict this process to device index 0. This only takes effect if CUDA has
# not been initialised yet.
# NOTE(review): the project imports above run first; confirm they do not touch CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# ---- Experiment configuration ---- #
inf_type = 'uncoded'   # inference type, forwarded to CodedVAE(inference=...)
batch_size = 128       # mini-batch size used by every DataLoader
bits = 8               # forwarded to the encoder/decoder factories and CodedVAE
likelihood = 'ber'     # forwarded to CodedVAE(likelihood=...)

# ---- Load data ---- #
dataset = 'MNIST' # Dataset can be 'MNIST', 'FMNIST', 'CIFAR10', or 'IMAGENET'

# Transform pipelines shared by the branches below. ToTensor yields float
# tensors in [0, 1]; the RGB datasets additionally apply a per-channel
# Normalize(mean=0.5, std=0.5), which maps values to [-1, 1].
# NOTE(review): `likelihood` is 'ber' above; if that denotes a Bernoulli
# likelihood it expects targets in [0, 1], which only the MNIST/FMNIST
# pipelines produce — confirm before training on CIFAR10/IMAGENET.
to_tensor = transforms.Compose([transforms.ToTensor()])
to_tensor_normalized = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def _make_loaders(train_set, test_set, num_workers=0):
    """Return ``(trainloader, testloader)`` built from the given datasets.

    Both loaders use the global ``batch_size`` and are wired to the seeded
    generator ``g`` and ``seed_worker``. As a result, the training shuffle
    order and any worker-side NumPy/``random`` state are reproducible. The
    training loader shuffles; the test loader keeps dataset order.
    """
    common = dict(batch_size=batch_size, num_workers=num_workers,
                  worker_init_fn=seed_worker, generator=g)
    return (DataLoader(train_set, shuffle=True, **common),
            DataLoader(test_set, shuffle=False, **common))


if dataset == 'MNIST':
    D = 28*28
    # download=True only fetches the files if they are not already under root.
    trainset = datasets.MNIST('./data/', download=True, train=True, transform=to_tensor)
    testset = datasets.MNIST('./data/', download=True, train=False, transform=to_tensor)
    trainloader, testloader = _make_loaders(trainset, testset)

elif dataset == 'FMNIST':
    D = 28*28
    trainset = datasets.FashionMNIST('./data/FMNIST/', download=True, train=True, transform=to_tensor)
    testset = datasets.FashionMNIST('./data/FMNIST/', download=True, train=False, transform=to_tensor)
    trainloader, testloader = _make_loaders(trainset, testset)

elif dataset == 'CIFAR10':
    D = 32*32*3
    trainset = datasets.CIFAR10(root='./data/CIFAR10/', download=True, train=True, transform=to_tensor_normalized)
    testset = datasets.CIFAR10('./data/CIFAR10/', download=True, train=False, transform=to_tensor_normalized)
    trainloader, testloader = _make_loaders(trainset, testset)

elif dataset == 'IMAGENET':
    D = 64*64*3
    # NOTE(review): in the standard Tiny ImageNet layout the `test` split has
    # no per-class sub-folders/labels, so ImageFolder would assign every test
    # image to a single class — confirm before relying on test labels.
    trainset = datasets.ImageFolder('./data/tiny-imagenet-200/train', transform=to_tensor_normalized)
    testset = datasets.ImageFolder('./data/tiny-imagenet-200/test', transform=to_tensor_normalized)
    trainloader, testloader = _make_loaders(trainset, testset, num_workers=2)

else:
    # Fail fast here instead of with a NameError on D/trainloader further down.
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected 'MNIST', 'FMNIST', 'CIFAR10' or 'IMAGENET'"
    )

# ---- Get encoder and decoder networks---- #
# Both factories receive the architecture name, `bits` and the selected `dataset`.
enc = mod.get_encoder('cnn', bits, dataset)
dec = mod.get_decoder('cnnskip', bits, dataset)

# ---- Declare model ---- #
model = CodedVAE(enc, dec, bits, likelihood=likelihood, beta=15, lr=1e-4, inference=inf_type, seed=0)

# ---- Train model ---- #
# The three returned values are kept for inspection. Judging by their names
# they track the ELBO, KL and reconstruction terms over training, and
# n_epochs_wu is presumably the number of warm-up epochs — TODO confirm
# against CodedVAE.train.
elbo_evol, kl_evol, rec_evol = model.train(trainloader, n_epochs=100, n_epochs_wu=0)

# ---- Save model ---- #
# Derive the checkpoint name from the inference type so that changing
# `inf_type` does not produce a file mislabelled as "uncoded" or overwrite
# the uncoded run. For inf_type='uncoded' this is 'my_checkpoint_uncoded.pt'.
checkpoint_path = f'my_checkpoint_{inf_type}.pt'
model.save(checkpoint_path)
print('Model saved!')