import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, Dataset
from vanilla_vae import VanillaVAE
import argparse


def parse_args(argv=None):
    """Parse command-line options for VAE training.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the usual
            command-line behaviour; passing a list makes the function
            usable from code and tests.

    Returns:
        argparse.Namespace with ``env_id``, ``load_dataset_name`` and
        ``save_model_name`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env-id", type=str, default='RoadRunner-v5',
                        help="environment id; selects the action-space size")

    parser.add_argument("--load-dataset-name", type=str, default=None,
                        help="file name under dataset/ (default: '<env-id>_')")
    parser.add_argument("--save-model-name", type=str, default='',
                        help="suffix appended to saved_models/VAE_<env-id>_")
    return parser.parse_args(argv)

class MyDataset(Dataset):
    """Map-style dataset that serves the ``'obs'`` entry of a dataset mapping.

    Args:
        dataset: mapping with an ``'obs'`` key whose value supports ``len()``
            and indexing; item ``idx`` of this dataset is ``dataset['obs'][idx]``.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def _observations(self):
        """Return the indexable collection stored under ``'obs'``."""
        return self.dataset['obs']

    def __len__(self):
        observations = self._observations()
        return len(observations)

    def __getitem__(self, idx):
        observations = self._observations()
        return observations[idx]

# Discrete action-space size for each supported environment id.
ENV_NUM_ACTIONS = {
    'RoadRunner-v5': 18,
    'Riverraid-v5': 18,
    'SpaceInvaders-v5': 6,
}


def _save_first_channel(images, path):
    """Save channel 0 of the first sample in ``images`` to ``path`` as an image.

    A fresh figure is created and closed on every call so repeated saves do
    not keep stacking images onto (and growing) the global pyplot state.
    """
    fig, ax = plt.subplots()
    ax.imshow(images.detach().cpu()[0, 0, :, :])
    fig.savefig(path)
    plt.close(fig)


if __name__ == '__main__':
    import os

    args = parse_args()
    if args.load_dataset_name is None:
        args.load_dataset_name = args.env_id + '_'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.env_id not in ENV_NUM_ACTIONS:
        raise TypeError('env-id not recognized')
    num_actions = ENV_NUM_ACTIONS[args.env_id]

    # NOTE: torch.load unpickles the file -- only load datasets you trust.
    dataset_raw = torch.load('dataset/' + args.load_dataset_name)
    # Fail early: with no batches the loop below would never define
    # train_loss/results, and the checkpoint step would crash obscurely.
    if len(dataset_raw['obs']) == 0:
        raise ValueError('dataset contains no observations')

    batch_size = 64
    learning_rate = 0.001
    num_epochs = 50000
    data_loader = DataLoader(dataset=dataset_raw['obs'],
                             batch_size=batch_size,
                             shuffle=True)

    model = VanillaVAE(4, num_actions, hidden_dims=[32, 64, 128, 256, 512], out_dim_temp=3).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    model_path = 'saved_models/VAE_' + args.env_id + '_' + args.save_model_name
    reconst_fig_path = 'figs/VAE_' + args.env_id + '_reconst.jpg'
    ori_fig_path = 'figs/VAE_' + args.env_id + '_ori.jpg'
    # Make sure output directories exist before the first save.
    for out_path in (model_path, reconst_fig_path, ori_fig_path):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    for epoch in range(num_epochs):
        for i, x in enumerate(data_loader):
            x = x.to(device)
            # One forward pass per batch; its outputs feed both the loss and
            # the figures saved at checkpoint time.
            results = model(x)
            # M_N presumably weights the KL term (hint: batch size / number
            # of training images) -- confirm in VanillaVAE.loss_function.
            train_loss = model.loss_function(*results, M_N=1)

            loss = train_loss['loss']
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % 10 == 0:
                print("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                      .format(epoch + 1, num_epochs, i + 1, len(data_loader),
                              train_loss['Reconstruction_Loss'].item(), train_loss['KLD'].item()))

        # Checkpoint every 20 epochs and always after the last epoch.
        if epoch % 20 == 0 or epoch == num_epochs - 1:
            torch.save(model.state_dict(), model_path)
            print('model saved: ', model_path)

            # Snapshot the last batch of the epoch. results[0] is saved as the
            # reconstruction and results[1] as the "ori" image -- presumably
            # the model returns its input there; confirm in VanillaVAE.forward.
            _save_first_channel(results[0], reconst_fig_path)
            _save_first_channel(results[1], ori_fig_path)