"""Training procedure for NICE.
"""

import argparse
import torch, torchvision
import numpy as np
import nice, utils
from masked_uci_nf import IndepMaskedUCI
import time

# Number of trained checkpoints evaluated: files "<prefix>run<k>.tar", k = 0..4.
_NUM_CHECKPOINT_RUNS = 5
# Independent imputation chains drawn per checkpoint.
_NUM_IMPUTATIONS = 25
# Printed outer steps per chain, and sampler steps per outer step.
_STEPS_PER_IMPUTATION = 10
_SAMPLER_STEPS = 100
# Coupling-network width/depth used for every dataset.
# NOTE(review): these override checkpoint_dict['mid_dim'] / ['hidden'];
# confirm that is intended (load_state_dict fails if they disagree).
_MID_DIM, _HIDDEN = 120, 5
# dataset name -> (full_dim, num_horizontal, extra IndepMaskedUCI kwargs).
# horizontal_copies is only passed for the datasets that always passed it.
_DATASET_SPECS = {
    'breast': (30, 1, {}),
    'red': (12, 1, {}),
    'white': (12, 1, {}),
    'banknote': (4, 1, {'horizontal_copies': 1}),
    'yeast': (8, 1, {}),
    'concrete': (18, 2, {'horizontal_copies': 2}),
}


def _masked_mse(pred, target, masks, scaling, std):
    """Return the mean squared error over missing entries (where masks == 0).

    The error is multiplied by ``scaling`` and divided by ``std`` *before*
    squaring, so every metric printed or recorded uses the same formula.
    """
    missing = 1.0 - masks
    sq_err = (pred - target).mul(scaling).div(std).mul(missing) ** 2
    return float(sq_err.sum() / float(missing.sum()))


def _build_testset(dataset):
    """Return ``(testset, full_dim, num_horizontal)`` for a supported dataset.

    Raises:
        ValueError: if ``dataset`` has no entry in ``_DATASET_SPECS``.
    """
    try:
        full_dim, num_horizontal, extra_kwargs = _DATASET_SPECS[dataset]
    except KeyError:
        raise ValueError(
            f"Unsupported dataset: {dataset!r} "
            f"(expected one of {sorted(_DATASET_SPECS)})") from None
    testset = IndepMaskedUCI(0.5, train=False, num_copies=1, dataset=dataset,
                             **extra_kwargs)
    return testset, full_dim, num_horizontal


def _build_prior(latent, device):
    """Return the latent prior named by the checkpoint's ``latent`` field.

    Raises:
        ValueError: if ``latent`` is neither 'normal' nor 'logistic'.
    """
    if latent == 'normal':
        return torch.distributions.Normal(
            torch.tensor(0.).to(device), torch.tensor(1.).to(device))
    if latent == 'logistic':
        return utils.StandardLogistic()
    raise ValueError(f"Unsupported latent prior: {latent!r}")


def main(args):
    """Evaluate imputation error of trained NICE checkpoints on UCI test data.

    For each checkpoint ``args.checkpoint + 'run<k>.tar'``, the mask stored in
    the checkpoint is applied to the test set. Its missing entries are then
    filled by ``_NUM_IMPUTATIONS`` independent chains driven by
    ``testset.reset_latents`` / ``testset.get_latents``. Observed entries are
    always kept from the data. Two error statistics over missing entries are
    collected (see ``_masked_mse``):

    * individual: one value per chain, using that chain's final clamped
      reconstruction;
    * conditional mean: one value per checkpoint, using the average of all
      chains, also averaged over the ``num_horizontal`` feature blocks.

    The mean and standard deviation of both are printed at the end.

    Args:
        args: parsed CLI namespace; only ``args.checkpoint`` is read.

    Raises:
        ValueError: if a checkpoint names an unsupported dataset or prior.
    """
    device = torch.device("cuda:0")
    ind_mse = []
    avg_mse = []
    for run in range(_NUM_CHECKPOINT_RUNS):
        checkpoint_file = args.checkpoint + 'run' + str(run) + '.tar'
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint_dict = torch.load(checkpoint_file)

        dataset = checkpoint_dict['dataset']
        latent = checkpoint_dict['latent']
        coupling = checkpoint_dict['coupling']
        mask_config = checkpoint_dict['mask_config']
        scaling = checkpoint_dict['scaling']
        clamp_min = checkpoint_dict['min']
        clamp_max = checkpoint_dict['max']
        trained_masks = checkpoint_dict['masks']

        testset, full_dim, num_horizontal = _build_testset(dataset)
        prior = _build_prior(latent, device)

        flow = nice.NICE(prior=prior, prior_noise=prior,
                         coupling=coupling,
                         in_out_dim=full_dim,
                         mid_dim=_MID_DIM,
                         hidden=_HIDDEN,
                         mask_config=mask_config,
                         permutations_list=checkpoint_dict['permutations']).to(device)
        flow.load_state_dict(checkpoint_dict['model_state_dict'])
        for p in flow.scaling.parameters():
            print(p.data)
        flow.eval()  # inference mode

        sample_std = 1.0
        scaling_dev = scaling.to(device)
        testset.mask = trained_masks
        # Per-feature std of the test data, viewed as (1, D) to broadcast over rows.
        true_std = testset.data.std(dim=0).view(1, testset.data.shape[1])
        testset.init_latents(flow)
        print(true_std.min())

        # Width of one feature block; presumably the row holds num_horizontal
        # copies of the feature vector side by side (TODO confirm).
        single_length = int(testset.data.shape[1] / num_horizontal)
        first = slice(0, single_length)
        # Running sum of chain outputs (kept on CPU), averaged over blocks.
        means = torch.zeros([testset.data.shape[0], single_length])
        reconstructions = []
        reconstruction_logp = []  # flow.log_prob of each chain's reconstruction
        with torch.no_grad():
            for ind_num in range(_NUM_IMPUTATIONS):
                testset.reset_latents(flow, batch_size=5000, model_reset=True,
                                      sample_std=sample_std)
                for _ in range(_STEPS_PER_IMPUTATION):
                    acceptance_ratio = testset.get_latents(
                        flow, step=True, num_steps=_SAMPLER_STEPS,
                        sample_std=sample_std)
                    projected_end = testset.projected_ends
                    masks = testset.mask
                    inputs = testset.data
                    # Keep observed entries (mask 1), use sampled values elsewhere.
                    recon = projected_end.mul(1.0 - masks) + inputs.mul(masks)
                    print(run, ind_num, acceptance_ratio,
                          -(flow.log_prob(inputs) - flow.log_prob(projected_end)).mean(),
                          _masked_mse(recon, inputs, masks, scaling_dev, true_std))
                recon = torch.max(recon, clamp_min)
                recon = torch.min(recon, clamp_max)
                reconstructions.append(recon.clone().to(device))
                reconstruction_logp.append(flow.log_prob(recon).unsqueeze(1))

                # Weight every chain so far by exp(log_prob), normalized across
                # chains; softmax performs the max-subtraction for stability.
                weights = torch.softmax(torch.stack(reconstruction_logp), dim=0)
                weighted_mean = (torch.stack(reconstructions) * weights).sum(0)

                recon_cpu = recon.cpu()
                for cpy_idx in range(num_horizontal):
                    block = slice(cpy_idx * single_length,
                                  (cpy_idx + 1) * single_length)
                    means += recon_cpu[:, block] / num_horizontal

                # Running conditional-mean error (same formula as avg_mse below).
                print(_masked_mse(means.to(device) / float(ind_num + 1),
                                  inputs[:, first], masks[:, first],
                                  scaling_dev[:, first], true_std[:, first]))
                print(_masked_mse(weighted_mean.to(device), inputs, masks,
                                  scaling_dev, true_std))
                ind_mse.append(_masked_mse(recon, inputs, masks,
                                           scaling_dev, true_std))

            avg_mse.append(_masked_mse(means.to(device) / float(_NUM_IMPUTATIONS),
                                       inputs[:, first], masks[:, first],
                                       scaling_dev[:, first], true_std[:, first]))

        print('Finished Testing')
    print('Individual: ', np.mean(ind_mse), np.std(ind_mse))
    print('Conditional Mean: ', np.mean(avg_mse), np.std(avg_mse))

def build_parser():
    """Return the command-line parser for this evaluation script.

    ``--sample_size`` is accepted for backward compatibility with existing
    invocations, but the evaluation never uses it.
    """
    # The first positional argument of ArgumentParser is ``prog`` (shown in
    # the usage line), so the text goes in ``description`` instead.
    parser = argparse.ArgumentParser(
        description='Evaluate NICE missing-value imputation on UCI test data.')
    parser.add_argument('--checkpoint',
                        help="checkpoint path prefix; files '<prefix>run<k>.tar' "
                             "for k = 0..4 are loaded.",
                        type=str,
                        default='./models/red/red_bs3000_normal_cp4_md120_hd5_missing__')
    parser.add_argument('--sample_size',
                        help='unused; kept for backward compatibility.',
                        type=int,
                        default=64)
    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())
