

import torch
import torch.nn as nn
from torchvision import datasets, transforms
import numpy as np
from torchvision.utils import save_image
import time
import os 
import sys
# Make the vqvae directory and its ``pixelcnn`` sub-directory importable.
# Both paths are derived from the current working directory, so the script
# must be launched from the vqvae directory.
#
# ``list.append`` returns None, so the directories are computed first and
# appended afterwards; this way ``current_dir`` / ``pixelcnn_dir`` really hold
# the paths their names promise.
current_dir = os.getcwd()
pixelcnn_dir = os.path.join(current_dir, 'pixelcnn')
sys.path.append(current_dir)
sys.path.append(pixelcnn_dir)

from pixelcnn.models import GatedPixelCNN
import util

# ---------------------------------------------------------------------------
# Hyperparameters (command-line interface)
# ---------------------------------------------------------------------------
import argparse


def _build_parser():
    """Return the argument parser that declares every hyperparameter.

    Each entry of ``option_specs`` is ``(flag, keyword arguments)`` and is
    forwarded unchanged to ``ArgumentParser.add_argument`` in listing order.
    """
    option_specs = (
        # run control
        ("--batch_size", {"type": int, "default": 32}),
        ("--is_label", {"type": int}),
        ("--epochs", {"type": int, "default": 100}),
        ("--log_interval", {"type": int, "default": 100}),
        ("-save", {"action": "store_true"}),
        ("-gen_samples", {"action": "store_true"}),
        ("--model_name", {"type": str, "default": 'WS'}),
        ("--device", {"type": int, "default": 0}),
        # data source
        ("--dataset", {
            "type": str,
            "default": 'LATENT_BLOCK',
            "help": 'accepts CIFAR10 | MNIST | FashionMNIST | LATENT_BLOCK',
        }),
        ("--data_file_path", {"type": str, "default": 'WS'}),
        # model / optimiser
        ("--num_workers", {"type": int, "default": 4}),
        ("--img_dim", {"type": int, "default": 8}),
        ("--input_dim", {
            "type": int,
            "default": 1,
            "help": '1 for grayscale 3 for rgb',
        }),
        ("--n_embeddings", {
            "type": int,
            "default": 512,
            "help": 'number of embeddings from VQ VAE',
        }),
        ("--n_layers", {"type": int, "default": 15}),
        ("--learning_rate", {"type": float, "default": 3e-4}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli


parser = _build_parser()


def train():
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``args``, ``device``, ``model``, ``criterion``,
    ``opt`` and ``train_loader`` that the ``__main__`` section creates.  For
    every batch the model scores each of the ``args.n_embeddings`` codes at
    every position, the cross-entropy against the input codes is
    back-propagated and one optimiser step is taken.  A progress line is
    printed every ``args.log_interval`` batches.  Returns ``None``.
    """
    model.train()  # no-op when the model is already in training mode
    train_loss = []
    n_batches = len(train_loader)
    for batch_idx, (x, label) in enumerate(train_loader):
        # Restarted for every batch, so the reported time covers only the
        # batch that triggers the log line.
        start_time = time.time()

        if args.dataset == 'LATENT_BLOCK':
            # Used as-is: presumably already integer code indices produced
            # by the VQ-VAE -- TODO confirm against the util loader.
            x = x[..., 0].to(device)
        else:
            # Scale values (presumably in [0, 1] after ToTensor) to integer
            # class indices in [0, n_embeddings - 1].  The original referred
            # to an undefined name ``K`` here, which raised NameError.
            x = (x[..., 0] * (args.n_embeddings - 1)).long().to(device)
        label = label.to(device)

        # Train PixelCNN with images
        logits = model(x, label)
        # Move axis 1 (the per-code scores) last so that every row of the
        # flattened logits scores exactly one position of ``x``.
        logits = logits.permute(0, 2, 3, 1).contiguous()

        loss = criterion(
            logits.view(-1, args.n_embeddings),
            x.view(-1)
        )

        opt.zero_grad()
        loss.backward()
        opt.step()

        train_loss.append(loss.item())

        batches_done = batch_idx + 1
        if batches_done % args.log_interval == 0:
            print('\tIter: [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
                batches_done * len(x), len(train_loader.dataset),
                # Percentage of the epoch completed; the scale factor is 100,
                # independent of the logging interval.
                100. * batches_done / n_batches,
                np.asarray(train_loss)[-args.log_interval:].mean(0),
                time.time() - start_time
            ))


def test():
    """Compute the mean validation loss over ``test_loader``.

    Uses the module-level ``args``, ``device``, ``model``, ``criterion`` and
    ``test_loader`` that the ``__main__`` section creates.  The model is put
    into evaluation mode for the duration of the pass and its previous mode is
    restored afterwards, so a following ``train()`` call is unaffected.

    Returns:
        The mean of the per-batch cross-entropy losses (a NumPy scalar).
    """
    start_time = time.time()
    val_loss = []

    was_training = model.training
    model.eval()  # layers such as dropout / batch-norm must not be in train mode here
    try:
        with torch.no_grad():
            for x, label in test_loader:
                if args.dataset == 'LATENT_BLOCK':
                    # Used as-is: presumably already integer code indices --
                    # TODO confirm against the util loader.
                    x = x[..., 0].to(device)
                else:
                    # Scale to integer class indices in [0, n_embeddings - 1].
                    x = (x[..., 0] * (args.n_embeddings - 1)).long().to(device)
                label = label.to(device)
                logits = model(x, label)

                # Move axis 1 (the per-code scores) last so that every row of
                # the flattened logits scores exactly one position of ``x``.
                logits = logits.permute(0, 2, 3, 1).contiguous()
                loss = criterion(
                    logits.view(-1, args.n_embeddings),
                    x.view(-1)
                )

                val_loss.append(loss.item())
    finally:
        # Restore whatever mode the caller had, even if a batch raised.
        model.train(was_training)

    mean_loss = np.asarray(val_loss).mean(0)
    print('Validation Completed!\tLoss: {} Time: {}'.format(
        mean_loss,
        time.time() - start_time
    ))
    return mean_loss


def generate_samples(epoch):
    """Draw 100 label-conditioned samples from the model and print the first.

    The conditioning labels are 0..9 repeated ten times.  ``epoch`` is part of
    the call signature but is not used by this function.  Relies on the
    module-level ``model``, ``device`` and ``args``.
    """
    # [0, 1, ..., 9] tiled ten times -> 100 labels.
    class_ids = torch.arange(10).repeat(10)
    class_ids = class_ids.long().to(device)

    side = args.img_dim
    samples = model.generate(class_ids, shape=(side, side), batch_size=100)

    print(samples[0])



if __name__ == "__main__":
    # Script entry point.  Everything assigned here (``args``, ``device``,
    # ``train_loader``, ``test_loader``, ``model``, ``criterion``, ``opt``) is
    # a module-level global that train(), test() and generate_samples() read.

    args = parser.parse_args()
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{args.device}")
    else:
        # Fall back instead of failing on the first ``.to(device)`` call.
        print("CUDA is not available; running on the CPU.")
        device = torch.device("cpu")

    """
    data loaders
    """
    if args.dataset == 'LATENT_BLOCK':
        _, _, train_loader, test_loader, _ = util.load_data_and_data_loaders_LATENT_BLOCK(
            'LATENT_BLOCK', args.data_file_path, args.batch_size, args.is_label)
    else:
        # Look the dataset class up by attribute rather than eval(): the name
        # comes straight from the command line and must not be executed.
        dataset_cls = getattr(datasets, args.dataset)
        data_root = '../data/{}/'.format(args.dataset)
        train_loader = torch.utils.data.DataLoader(
            dataset_cls(
                data_root, train=True, download=True,
                transform=transforms.ToTensor(),
            ),
            # was ``args.batch_Size`` (AttributeError)
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.num_workers, pin_memory=True
        )
        test_loader = torch.utils.data.DataLoader(
            dataset_cls(
                data_root, train=False,
                transform=transforms.ToTensor(),
            ), batch_size=args.batch_size, shuffle=False,
            num_workers=args.num_workers, pin_memory=True
        )

    model = GatedPixelCNN(args.n_embeddings, args.img_dim**2, args.n_layers).to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    """
    train, test, and log
    """

    # Checkpoint location.  The file name depends only on the command-line
    # arguments, so every save overwrites the previous checkpoint.
    save_is_label = 'True' if args.is_label else 'False'
    save_dir = '/home/ubuntu/sqvae/latent_results'
    save_path = os.path.join(
        save_dir,
        '{}_{}_label_{}_pixelcnn.pt'.format(
            args.model_name, str(args.epochs), save_is_label))

    BEST_LOSS = float('inf')
    LAST_SAVED = -1
    # ``+ 1`` so that exactly ``args.epochs`` epochs are run.
    for epoch in range(1, args.epochs + 1):
        print("\nEpoch {}:".format(epoch))
        train()
        cur_loss = test()

        # ``-save`` forces a checkpoint every epoch; otherwise only a loss at
        # least as good as the best one seen so far triggers a save.
        if args.save or cur_loss <= BEST_LOSS:
            # Keep the true minimum even when the save was forced by ``-save``.
            BEST_LOSS = min(BEST_LOSS, cur_loss)
            LAST_SAVED = epoch

            print("Saving model!")
            # Report the path that is actually written.
            print(save_path)

            os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), save_path)
        else:
            print("Not saving model! Last saved: {}".format(LAST_SAVED))
        if args.gen_samples:
            generate_samples(epoch)
