import argparse
import torch, torchvision
import numpy as np
import nice, utils
from matplotlib import pyplot as pl
from masked_mnist_nf import IndepMaskedMNIST, BlockMaskedMNIST, PatchMaskedMNIST

def _masked_rmse(pred, true, mask, scaling):
    """Return the mean per-sample RMSE of ``pred`` over the masked-out entries.

    For each row, the ``scaling``-weighted error is restricted to entries
    where ``mask`` is 0 (weight ``1 - mask``). It is squared, summed, divided
    by the number of such entries, and square-rooted. The result is the mean
    of these per-row values.

    Args:
        pred: predicted values, 2-D ``(N, D)``.
        true: ground-truth values, 2-D ``(N, D)``.
        mask: presumably a binary 0/1 mask (TODO confirm), 2-D ``(N, D)``;
            only entries with ``mask == 0`` contribute to the error.
        scaling: per-entry weight broadcastable against ``(N, D)``.

    Returns:
        The averaged RMSE as a Python float. A row with no ``mask == 0``
        entries yields 0/0 = NaN, which propagates into the result.
    """
    missing = 1.0 - mask
    sq_err = ((pred - true).mul(scaling).mul(missing) ** 2).sum(1)
    per_sample = sq_err.div(missing.sum(1)).sqrt()
    return float(per_sample.mean())


def main(test_dir="patch_97", device="cuda:0", num_runs=10, num_splits=5,
         models_dir="./models"):
    """Report imputation RMSE for individual runs and for their ensemble mean.

    Loads ``num_runs`` checkpoints named
    ``<models_dir>/<test_dir>/tests/impute_<run>.tar`` (run = 0 .. num_runs-1).
    Each run's completions are clamped to the checkpoint's ``min``/``max``
    bounds. The samples are then split into ``num_splits`` equal contiguous
    chunks; remainder samples beyond ``num_splits * (N // num_splits)`` are
    ignored.

    For every run and chunk, the masked RMSE is printed and collected as an
    "individual" score. The completions averaged over all runs are scored the
    same way per chunk as the "average" score. Mean and standard deviation of
    both lists are printed.

    Args:
        test_dir: sub-directory of ``models_dir`` holding the checkpoints.
        device: device every checkpoint tensor is loaded onto.
        num_runs: number of checkpoints to read; must be >= 1.
        num_splits: number of equal chunks per run; must be >= 1.
        models_dir: root directory of the model outputs.

    Returns:
        ``(individual, average)``: lists of per-chunk RMSE floats, of length
        ``num_runs * num_splits`` and ``num_splits`` respectively.

    Raises:
        ValueError: if ``num_runs`` or ``num_splits`` is less than 1.
    """
    if num_runs < 1:
        raise ValueError("num_runs must be >= 1, got %r" % (num_runs,))
    if num_splits < 1:
        raise ValueError("num_splits must be >= 1, got %r" % (num_splits,))
    device = torch.device(device)
    individual = []
    average = []
    proj_sum = None
    with torch.no_grad():
        for run in range(num_runs):
            checkpoint_file = f"{models_dir}/{test_dir}/tests/impute_{run}.tar"
            # map_location puts every tensor on one device, regardless of the
            # device the checkpoint was saved from.
            ckpt = torch.load(checkpoint_file, map_location=device)
            n = len(ckpt['completions'])
            projected_end = ckpt['completions'].view(n, -1)
            inputs = ckpt['true'].view(n, -1)
            masks = ckpt['mask'].view(n, -1)
            scaling = ckpt['scaling']
            # Element-wise clamp; torch.max/min (not clamp) so tensor bounds
            # broadcast per entry.
            projected_end = torch.max(projected_end, ckpt['min'])
            projected_end = torch.min(projected_end, ckpt['max'])
            proj_sum = projected_end if proj_sum is None else proj_sum + projected_end

            chunk = len(inputs) // num_splits
            for i in range(num_splits):
                rows = slice(i * chunk, (i + 1) * chunk)
                single_rmse = _masked_rmse(projected_end[rows], inputs[rows],
                                           masks[rows], scaling)
                print(single_rmse)
                individual.append(single_rmse)

        # The ensemble is scored against the ground truth, mask and scaling of
        # the last run. This assumes every run imputes the same samples with
        # the same mask (NOTE(review): not checked here).
        average_proj = proj_sum / num_runs
        for i in range(num_splits):
            rows = slice(i * chunk, (i + 1) * chunk)
            average.append(_masked_rmse(average_proj[rows], inputs[rows],
                                        masks[rows], scaling))
    print("Individual: ", np.mean(individual), np.std(individual))
    print("Average: ", np.mean(average), np.std(average))
    return individual, average
                    


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
