import argparse
import torch
import os
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader, Dataset, random_split

from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import MultiStepLR
from cuda_selector import auto_cuda

# Fix the global RNG seed so weight initialisation and prior sampling are repeatable.
torch.manual_seed(123)

# Command-line interface. Help strings describe what this script actually does
# with each value (the dataset is MNIST and the z_dim default is 20).
parser = argparse.ArgumentParser(description='Class conditional regularized WAE-MMD on MNIST')
parser.add_argument('--batch_size', type=int, default=100, metavar='n',
                    help='mini-batch size (default: 100)')
parser.add_argument('--epochs', type=int, default=500,
                    help='number of training epochs (default: 500)')
parser.add_argument('--lr', type=float, default=5e-3,
                    help='initial Adam learning rate (default: 5e-3)')
parser.add_argument('--z_dim', type=int, default=20, help='hidden dimension of z (default: 20)')
parser.add_argument('--reg_lambda', type=float, default=100,
                    help='weight of the MMD term in the loss (default: 100)')
parser.add_argument('--z_var', type=float, default=1.414,
                    help='scale used (squared) in the IMQ kernel base constant (default: 1.414)')
parser.add_argument("--data_dir", type=str, default="dataset",
                    help="directory of the MNIST dataset (downloaded there if missing)")
parser.add_argument("--save_dir", type=str, default="models/", help="directory to save models")

# ------------------- Params -------------------------------
# Parsed once at import time; the functions below read `args` as a module global.
args = parser.parse_args()

# ------------------- Functions -------------------------------
def mmd_loss_imq(z1, z2):
    """Estimate the MMD between two batches of samples with the IMQ kernel sum.

    The within-set terms exclude the diagonal (each sample compared with
    itself) and are normalised by ``n * (n - 1)``; the cross term is the mean
    of all pairwise kernel values between the two sets.

    Args:
        z1: first batch of samples; assumed 2-D ``(n1, z_dim)`` so that
            ``imq_kernel`` yields a square matrix -- TODO confirm with callers.
        z2: second batch of samples, ``(n2, z_dim)`` under the same assumption.

    Returns:
        Scalar tensor ``mmd_z + mmd_z_prior - 2 * mmd_cross``.

    Note:
        Each batch needs at least two samples; with a single sample the
        within-set normaliser is zero and the result is not finite.
    """
    # Batch sizes come from the arguments themselves. The previous code read a
    # module-level `z`, which only worked when the caller happened to define it.
    n1 = z1.shape[0]
    n2 = z2.shape[0]

    k_z = imq_kernel(z1, z1)
    k_z_prior = imq_kernel(z2, z2)
    k_cross = imq_kernel(z1, z2)

    # Subtracting diag(diag(K)) zeroes the self-similarity entries.
    mmd_z = (k_z - k_z.diag().diag()).sum() / ((n1 - 1) * n1)
    mmd_z_prior = (k_z_prior - k_z_prior.diag().diag()).sum() / ((n2 - 1) * n2)
    mmd_cross = k_cross.sum() / (n1 * n2)

    return mmd_z + mmd_z_prior - 2 * mmd_cross

def imq_kernel(z1, z2, z_dim=None, z_var=None):
    """Return the pairwise inverse-multiquadratic kernel matrix, summed over scales.

    For every pair ``(z1[i], z2[j])`` this computes
    ``sum_s C_s / (C_s + ||z1[i] - z2[j]||^2)`` with ``C_s = s * Cbase`` and
    ``Cbase = 2 * z_dim * z_var ** 2``.

    Args:
        z1: tensor whose last dimension holds the features; the result has one
            row per entry of ``z1`` along its first dimension.
        z2: tensor with the same feature size; one column per entry.
        z_dim: latent dimension used in ``Cbase``. Defaults to ``args.z_dim``.
        z_var: scale used (squared) in ``Cbase``. Defaults to ``args.z_var``.

    Returns:
        Tensor of shape ``[len(z1) x len(z2)]`` containing the summed kernel values.
    """
    # Fall back to the command-line settings only when no override is given,
    # so existing two-argument callers behave as before.
    if z_dim is None:
        z_dim = args.z_dim
    if z_var is None:
        z_var = args.z_var

    Cbase = 2.0 * z_dim * z_var ** 2

    # Squared pairwise distances do not depend on the scale: compute them once,
    # directly as a sum of squares rather than squaring a norm.
    sq_dist = (z1.unsqueeze(1) - z2.unsqueeze(0)).pow(2).sum(dim=-1)

    k = 0
    for scale in (.1, .2, .5, 1., 2., 5., 10.):
        C = scale * Cbase
        k = k + C / (C + sq_dist)

    return k


# ------------------- Select Device -------------------------------
cuda_available = torch.cuda.is_available()
save_dir = args.save_dir
data_root = args.data_dir
# exist_ok makes this a single call with no gap between the check and the creation.
os.makedirs(save_dir, exist_ok=True)

if cuda_available:
    # auto_cuda() is expected to return a device string such as 'cuda:3'
    # (TODO confirm against cuda_selector). Parse everything after the last
    # ':' so that multi-digit indices (e.g. 'cuda:10') are read in full;
    # taking only the final character would turn 10 into 0.
    device_id = int(auto_cuda().rsplit(':', 1)[-1])
    torch.cuda.set_device(device_id)
    # Plain 'cuda' resolves to the current device selected just above.
    device = torch.device('cuda')
    print('Using cuda:{}'.format(torch.cuda.current_device()))

else:
    # Use a torch.device here too so `device` has the same type on both branches.
    device = torch.device('cpu')
    print('Using CPU')


# ------------------- Dataset -------------------------------
class TransformDataset(Dataset):
    """Wrap an (image, target) dataset and split each image into a masked copy and a crop.

    Each item becomes ``(masked, corner, target)`` where ``corner`` is a copy of
    ``image[:, 14:, :14]`` and ``masked`` is a copy of the image with that same
    region set to zero. With a (C, H, W) layout -- presumably what the wrapped
    dataset yields, TODO confirm -- that region is the lower-left block.
    """

    def __init__(self, dataset):
        # Any object supporting len() and integer indexing that yields (tensor, target).
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, label = self.dataset[index]

        # Copy the crop first, then blank the same region in a separate copy;
        # the wrapped dataset's tensor itself is never modified.
        corner = image[:, 14:, :14].detach().clone()
        masked = image.detach().clone()
        masked[:, 14:, :14] = 0

        return masked, corner, label


# Root folder for the MNIST files, taken from --data_dir.
data_root = args.data_dir

# MNIST training split as tensors; fetched into data_root when not already present.
full_dataset = datasets.MNIST(
    download=True,
    transform=transforms.ToTensor(),
    train=True,
    root=data_root,
)

# ------------------- Loader -------------------------------
# Every sample is wrapped so batches arrive as (masked image, corner crop, label),
# reshuffled each epoch.
train_loader = DataLoader(
    dataset=TransformDataset(full_dataset),
    shuffle=True,
    batch_size=args.batch_size,
)

# ------------------- Models -------------------------------
from models import MNISTEncoder, MNISTDecoder

# Build the encoder before the decoder: construction order determines how the
# seeded RNG is consumed during parameter initialisation.
encoder = MNISTEncoder(z_dim=args.z_dim).to(device)
decoder = MNISTDecoder(z_dim=args.z_dim).to(device)

# One Adam optimiser per network, each with a step schedule that halves the
# learning rate at epochs 100, 150 and 200.
enc_optim = optim.Adam(encoder.parameters(), lr=args.lr)
enc_scheduler = MultiStepLR(enc_optim, gamma=0.5, milestones=[100, 150, 200])

dec_optim = optim.Adam(decoder.parameters(), lr=args.lr)
dec_scheduler = MultiStepLR(dec_optim, gamma=0.5, milestones=[100, 150, 200])

# Reconstruction criterion.
mse = nn.MSELoss()

# Bookkeeping state for training / checkpointing.
best_model_weights = None
best_val_loss = float('inf')
patience = 5
wait = 0
train_loss = 0
val_loss = 0

# Checkpoints go into the directory given by --save_dir (created during setup)
# instead of whatever the current working directory happens to be.
checkpoint_path = os.path.join(
    save_dir, 'MNIST_WAE_zdim{}_cornercrop_all.pth'.format(args.z_dim)
)

with tqdm(range(args.epochs)) as pbar:
    for epoch in pbar:
        encoder.train()
        decoder.train()

        # Running sums over the epoch, kept as Python floats so no autograd
        # graph is retained between batches.
        train_loss = 0.0
        epoch_recon = 0.0
        epoch_mmd = 0.0
        count = 0

        # Only the masked image is used; the corner crop and label are ignored.
        for x, _, _ in train_loader:
            x = x.to(device)
            z = encoder(x)
            x_recon = decoder(z)
            recon_loss = mse(x_recon, x)

            # Match the encoded batch against samples from a standard normal.
            # NOTE(review): the prior is unit-variance while the kernel uses
            # args.z_var -- confirm this is intended.
            z_prior2 = torch.randn_like(z)
            mmd_loss = args.reg_lambda * mmd_loss_imq(z, z_prior2)

            loss = recon_loss + mmd_loss
            enc_optim.zero_grad()
            dec_optim.zero_grad()
            loss.backward()
            enc_optim.step()
            dec_optim.step()

            train_loss += loss.item()
            epoch_recon += recon_loss.item()
            epoch_mmd += mmd_loss.item()
            count += 1

        # Guard against an empty loader so the averages stay well-defined.
        denom = max(count, 1)
        train_loss /= denom
        epoch_recon /= denom
        epoch_mmd /= denom

        # Report epoch averages as plain numbers rather than the raw tensors
        # from whichever batch happened to be last.
        pbar.set_postfix_str(
            'Loss:{:.4f} / Recon:{:.4f} / MMD:{:.4f}'.format(train_loss, epoch_recon, epoch_mmd)
        )

        enc_scheduler.step()
        dec_scheduler.step()

        # Overwrite the checkpoint after every epoch so an interrupted run
        # still leaves the latest weights and optimiser/scheduler state on disk.
        best_model_weights = {
            'encoder': encoder.state_dict(),
            'decoder': decoder.state_dict(),
            'enc_optim': enc_optim.state_dict(),
            'dec_optim': dec_optim.state_dict(),
            'enc_scheduler': enc_scheduler.state_dict(),
            'dec_scheduler': dec_scheduler.state_dict(),
        }
        torch.save(best_model_weights, checkpoint_path)

# Final save; skipped when no epoch ran, so a checkpoint holding None is never written.
if best_model_weights is not None:
    torch.save(best_model_weights, checkpoint_path)
