#!/usr/bin/env python
# -*- coding:utf-8 -*-
# software: PyCharm
# basic functions
import os
import sys
import math
import numpy as np
import shutil
import argparse

# torch functions
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# local functions
from densenet_mnist import DenseNet
from models import *

# Command-line arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--batchSz', type=int, default=512, help='mini batch size')
parser.add_argument('--nEpochs', type=int, default=1000, help='the number of outter loop')
parser.add_argument('--bridge_num', type=int, default=10, help='the length of chain')
parser.add_argument('--cuda_device', type=int, default=0, help='choose cuda device')
parser.add_argument('--sub_size', type=int, default=20000, help='subsample size')
parser.add_argument('--no-cuda', action='store_true', help='if TRUE, cuda will not be used')
parser.add_argument('--save', help='path to save results')
parser.add_argument('--seed', type=int, default=1, help='random seed')
# Read the real command line: the previous `parse_args(args=[])` silently ignored
# every flag declared above. parse_known_args also tolerates extra arguments that
# notebook kernels inject (e.g. Jupyter's `-f <connection-file>`), which is
# presumably why `args=[]` was used; leftovers are reported instead of fatal.
args, unknown_args = parser.parse_known_args()
if unknown_args:
    print('Ignoring unrecognised arguments:', unknown_args, file=sys.stderr)

args.cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")
args.save = args.save or 'Results/MNIST'
print(args)

if args.cuda:
    # Select the GPU before seeding: torch.cuda.manual_seed seeds only the
    # *current* device, so seeding first would seed the default device rather
    # than args.cuda_device.
    torch.cuda.set_device(args.cuda_device)
setup_seed(args.seed, args.cuda)  # project helper (presumably from the `models` star import)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Output directories: the results root and a sub-folder for trained network
# weights. exist_ok=True already makes these calls no-ops when the folders
# exist, so no separate existence check is needed.
os.makedirs(args.save, exist_ok=True)

net_saved_path = os.path.join(args.save, 'SavedNet')
os.makedirs(net_saved_path, exist_ok=True)

# Extra DataLoader options for CUDA runs.
# NOTE(review): appears not to be passed to any loader — confirm before removing.
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
# Normalise with mean 0.1307 / std 0.3081; the same transform for both splits.
trainTransform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
testTransform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_set = dset.MNIST(root='mnist', train=True, download=True,
                       transform=trainTransform)
test_set = dset.MNIST(root='mnist', train=False, download=True,
                      transform=testTransform)

# Materialise every transformed image of each split, in dataset order, as one
# tensor. Batches are collected first and concatenated once: growing a tensor
# with torch.cat inside the loop re-copies all previous data on every batch.
trainLoader = DataLoader(train_set, batch_size=64, shuffle=False)
train_image_data_full = torch.cat([images for images, _ in trainLoader], 0)

testLoader = DataLoader(test_set, batch_size=64, shuffle=False)
test_image_data = torch.cat([images for images, _ in testLoader], 0)

# Number of independent repetitions; each draws a fresh random training subsample.
rep_num = 5
n_full = train_image_data_full.shape[0]
if not 0 < args.sub_size <= n_full:
    raise ValueError('--sub_size must be in [1, {}], got {}'.format(n_full, args.sub_size))

# ---- Loop-invariant setup (identical for every repetition) -------------------
image_h = train_image_data_full.shape[2]  # images are treated as image_h x image_h below
dim = image_h ** 2
bridge_num = args.bridge_num
testLoader = DataLoader(test_image_data, batch_size=args.batchSz, shuffle=False)
test_n = test_image_data.shape[0]

# Reference term: average per-pixel log-density of the test images under a
# standard normal N(0, I) of dimension `dim`.
dist = torch.distributions.multivariate_normal.MultivariateNormal(
    torch.zeros(dim), torch.eye(dim))
ref_loglike_temp = 0
for data in testLoader:
    sample = data.reshape((data.shape[0], dim))
    ref_loglike_temp += dist.log_prob(sample).sum().item()
ref_loglike = ref_loglike_temp / (test_n * dim)

FinalResults = np.zeros(shape=(rep_num, 1), dtype=float)
for i in range(rep_num):
    # Random subsample without replacement. Permute *all* indices (not just
    # range(sub_size)) so that each repetition sees a different subset.
    ind = np.random.permutation(n_full)[:args.sub_size]
    train_image_data = train_image_data_full[ind, :, :, :]
    n = train_image_data.shape[0]

    # Standard-normal draws that replace image pixels as the bridge advances.
    Training_p = torch.randn(n, 1, image_h, image_h)

    # Per-sample Bernoulli masks for bridge points m = 0..bridge_num: column m
    # is 1 with probability m / bridge_num (all zeros at m=0, all ones at the end).
    train_delta_array = torch.zeros(n, bridge_num + 1)
    for m in range(bridge_num + 1):
        sampler1 = torch.distributions.bernoulli.Bernoulli(torch.tensor([m / bridge_num]))
        train_delta_array[:, m] = sampler1.sample((n,))[:, 0]

    # Train one ratio network per adjacent pair of bridge points and
    # accumulate its summed output on the test set.
    mtre_temp = 0
    for m in range(bridge_num):
        # Shape (n, 1, 1, 1) broadcasts each sample's mask over its whole image.
        train_delta1 = train_delta_array[:, m].reshape(n, 1, 1, 1)
        train_delta2 = train_delta_array[:, m + 1].reshape(n, 1, 1, 1)

        log_R_net = DenseNet(growthRate=12, depth=10, reduction=0.5, bottleneck=True)
        print('  + Number of params (net) : {}'.format(
            sum(p.data.nelement() for p in log_R_net.parameters())))
        if args.cuda:
            log_R_net = log_R_net.cuda()
        optimizer_R = optim.Adam(log_R_net.parameters(), weight_decay=1e-4)

        # q: samples mixed with noise per the m-th mask; p: per the (m+1)-th mask.
        Train_q_Loader = DataLoader(
            (1 - train_delta1) * train_image_data + train_delta1 * Training_p,
            batch_size=args.batchSz, shuffle=True)
        Train_p_Loader = DataLoader(
            (1 - train_delta2) * train_image_data + train_delta2 * Training_p,
            batch_size=args.batchSz, shuffle=True)
        Training_breprocess(log_R_net, optimizer_R, Train_p_Loader, Train_q_Loader, args.nEpochs, device)
        # NOTE(review): the filename has no repetition index, so each repetition
        # overwrites the weights saved by the previous one.
        torch.save(log_R_net.state_dict(), os.path.join(net_saved_path, 'mtre_net' + str(m) + '.pt'))

        # Score in inference mode so normalisation/dropout layers (if any) use
        # their running statistics and are not updated by the test data.
        log_R_net.eval()
        with torch.no_grad():
            for data in testLoader:
                mtre_temp += log_R_net(data.to(device)).sum().item()
    FinalResults[i] = mtre_temp / (test_n * dim) + ref_loglike
print(FinalResults)
print('Mean:\t', np.mean(FinalResults), 'Std:\t', np.std(FinalResults))


