import torch
from datetime import datetime
from pathlib import Path
import argparse
from mnist_generator import (ConvDataGenerator, FCDataGenerator,
                             ConvMaskGenerator, FCMaskGenerator)
from mnist_imputer import (ComplementImputer,
                           MaskImputer,
                           FixedNoiseDimImputer)
from mnist_critic import ConvCritic, FCCritic
from masked_mnistMG import IndepMaskedMNIST, BlockMaskedMNIST, ShadowMaskedMNIST, PatchMaskedMNIST
from torch.utils.data import DataLoader
from utils import CriticUpdater, mask_norm, mkdir, mask_data

from misgan_impute import misgan_impute
import numpy as np

import os
import json
import time
# Run on the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
# Pretrained checkpoints live under save_dir; results log and MNIST data under log_dir.
save_dir = './models'
log_dir = './'
data_dir = log_dir + '/mnist-data'
# Pin the process to the first GPU, but only when CUDA is actually usable:
# calling torch.cuda.set_device on a CPU-only machine fails at import time.
if use_cuda:
    torch.cuda.set_device(0)
# Run configurations evaluated by main(); each entry is read for the keys
# 'MASK_FUNCTION' and 'MASK_PARAM' and gets 'METHOD'/'results' written into it.
with open('./run_params.json', 'r') as paramfile:
    param_list = json.load(paramfile)

# Keys of param['results']; each list receives one value per evaluation run.
_RESULT_KEYS = (
    'per_pixel_seen_test_error',
    'per_pixel_masked_test_error',
    'image_seen_test_rmse',
    'image_masked_test_rmse',
)

# Imputer architectures selectable with --imputer.
_IMPUTERS = {
    'comp': ComplementImputer,
    'mask': MaskImputer,
    'fix': FixedNoiseDimImputer,
}

# Evaluation batch size. Deliberately independent of --batch-size (default 64):
# this is the value the evaluation has always used.
_EVAL_BATCH_SIZE = 128


def _parse_args():
    """Parse the MisGAN command-line options.

    Every historical option is still accepted so existing command lines keep
    working. The per-mask options (--mask, --block-len, --obs-prob, --depth,
    --num_patches) are overridden by each entry of run_params.json.
    """
    parser = argparse.ArgumentParser()

    # resume from checkpoint
    parser.add_argument('--resume')

    # training options
    parser.add_argument('--workers', type=int, default=0)
    parser.add_argument('--epoch', type=int, default=500)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--pretrain', default=None)
    parser.add_argument('--imputeronly', action='store_true')

    # log options: 0 to disable plot-interval or save-interval
    parser.add_argument('--plot-interval', type=int, default=10)
    parser.add_argument('--save-interval', type=int, default=0)
    parser.add_argument('--prefix', default='impute')

    # mask options (data): indep|block|shadow|patch
    parser.add_argument('--mask', default='patch')
    # option for block: set to 0 for variable size
    parser.add_argument('--block-len', type=int, default=14)
    # option for indep:
    parser.add_argument('--obs-prob', type=float, default=.2)
    parser.add_argument('--obs-prob-high', type=float, default=None)
    # option for shadow:
    parser.add_argument('--depth', type=float, default=.89)
    # option for patch:
    parser.add_argument('--num_patches', type=float, default=27)

    # model options
    parser.add_argument('--tau', type=float, default=0)
    parser.add_argument('--generator', default='conv', choices=('conv', 'fc'))
    parser.add_argument('--critic', default='conv', choices=('conv', 'fc'))
    parser.add_argument('--alpha', type=float, default=.1)   # 0: separate
    parser.add_argument('--beta', type=float, default=.1)
    parser.add_argument('--gamma', type=float, default=0)
    parser.add_argument('--arch', default='784-784')
    parser.add_argument('--imputer', default='comp', choices=tuple(_IMPUTERS))
    # options for mask generator: sigmoid, hardsigmoid, fusion
    parser.add_argument('--maskgen', default='fusion',
                        choices=('sigmoid', 'hardsigmoid', 'fusion'))
    parser.add_argument('--gp-lambda', type=float, default=10)
    parser.add_argument('--n-critic', type=int, default=5)
    parser.add_argument('--n-latent', type=int, default=128)

    args = parser.parse_args()
    if args.imputeronly and args.pretrain is None:
        parser.error('--imputeronly requires --pretrain')
    return args


def _build_dataset(mask, mask_param, args):
    """Build the masked-MNIST *test* set for mask function ``mask``.

    ``mask_param`` is the run's MASK_PARAM. It is interpreted as obs_prob,
    block_len (0 means variable size), depth or num_patches, depending on
    ``mask``.

    Raises:
        NotImplementedError: for an unknown mask function.
    """
    if mask == 'indep':
        return IndepMaskedMNIST(obs_prob=mask_param, obs_prob_high=args.obs_prob_high,
                                data_dir=data_dir, train=False)
    if mask == 'block':
        block_len = None if mask_param == 0 else mask_param
        return BlockMaskedMNIST(block_len=block_len, data_dir=data_dir, train=False)
    if mask == 'shadow':
        return ShadowMaskedMNIST(depth=mask_param, data_dir=data_dir, train=False)
    if mask == 'patch':
        return PatchMaskedMNIST(num_patches=mask_param, data_dir=data_dir, train=False)
    raise NotImplementedError(f'unknown mask function: {mask!r}')


def _ratio(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero."""
    return total / count if count else float('nan')


def _evaluate_imputer(imputer, data_loader, batch_size, data_shape):
    """Run ``imputer`` once over ``data_loader`` and return the error metrics.

    Returns a dict keyed by ``_RESULT_KEYS``:
      * per-pixel mean squared error over observed ("seen") and over missing
        ("masked") pixels, pooled over the whole pass;
      * per-image RMSE over seen / masked pixels, averaged over images.

    An image with no seen (resp. masked) pixels has no defined RMSE for that
    subset. It is left out of that average instead of adding a 0/0 = NaN.
    A metric with nothing to average comes out as NaN.
    """
    # Noise buffer reused across batches. The loader uses drop_last=True, so
    # every batch has exactly batch_size samples.
    impu_noise = torch.empty(batch_size, *data_shape, device=device)
    seen_sq_err = masked_sq_err = 0.0
    n_seen_px = n_masked_px = 0.0
    seen_rmse_sum = masked_rmse_sum = 0.0
    n_seen_img = n_masked_img = 0

    imputer.eval()
    # Pure evaluation. Without no_grad the running sums would keep every
    # batch's autograd graph alive until the end of the pass.
    with torch.no_grad():
        for real_data, real_mask, _, _ in data_loader:
            # Assume real_data and real_mask have the same number of channels;
            # add the channel axis to the mask.
            real_mask = real_mask.float()[:, None].to(device)
            real_data = real_data.to(device)

            impu_noise.uniform_()
            imputed = imputer(real_data, real_mask, impu_noise)

            n = len(real_data)
            diff = real_data.reshape(n, -1) - imputed.reshape(n, -1)
            seen = real_mask.reshape(n, -1)
            missing = 1 - seen

            # Per-image sums of squared errors and pixel counts.
            seen_err = ((diff * seen) ** 2).sum(1)
            masked_err = ((diff * missing) ** 2).sum(1)
            seen_cnt = seen.sum(1)
            masked_cnt = missing.sum(1)

            seen_sq_err += seen_err.sum().item()
            masked_sq_err += masked_err.sum().item()
            n_seen_px += seen_cnt.sum().item()
            n_masked_px += masked_cnt.sum().item()

            has_seen = seen_cnt > 0
            seen_rmse_sum += (seen_err[has_seen] / seen_cnt[has_seen]).sqrt().sum().item()
            n_seen_img += int(has_seen.sum().item())

            has_masked = masked_cnt > 0
            masked_rmse_sum += (masked_err[has_masked] / masked_cnt[has_masked]).sqrt().sum().item()
            n_masked_img += int(has_masked.sum().item())

    return {
        'per_pixel_seen_test_error': _ratio(seen_sq_err, n_seen_px),
        'per_pixel_masked_test_error': _ratio(masked_sq_err, n_masked_px),
        'image_seen_test_rmse': _ratio(seen_rmse_sum, n_seen_img),
        'image_masked_test_rmse': _ratio(masked_rmse_sum, n_masked_img),
    }


def _append_log(log_path, entry):
    """Append ``entry`` to the JSON list in ``log_path``, creating it if absent."""
    try:
        with open(log_path, 'r') as logfile:
            entries = json.load(logfile)
    except FileNotFoundError:
        entries = []
    with open(log_path, 'w') as logfile:
        json.dump(entries + [entry], logfile)


def main(num_runs=10):
    """Evaluate the pretrained MisGAN imputer for every run configuration.

    For each entry of ``param_list``:
      1. load ``<save_dir>/mnist/<MASK_FUNCTION>_<MASK_PARAM>/log/checkpoint.pth``;
      2. evaluate its imputer ``num_runs`` times on the masked MNIST test set;
      3. record the per-run metrics in ``param['results']``;
      4. append the entry to ``<log_dir>/rmse_log.json``.
    """
    args = _parse_args()
    imputer_cls = _IMPUTERS[args.imputer]
    hid_lens = [int(n) for n in args.arch.split('-')]
    log_path = Path(log_dir) / 'rmse_log.json'

    for param in param_list:
        param['results'] = {key: [] for key in _RESULT_KEYS}
        param['METHOD'] = 'MisGAN'
        mask = param['MASK_FUNCTION']
        label = mask + '_' + str(param['MASK_PARAM'])

        # Everything below is identical across repetitions, so build it once.
        checkpoint = torch.load(Path(save_dir) / 'mnist' / label / 'log' / 'checkpoint.pth',
                                map_location='cpu')
        data = _build_dataset(mask, param['MASK_PARAM'], args)
        imputer = imputer_cls(arch=hid_lens).to(device)
        imputer.load_state_dict(checkpoint['imputer'])
        data_loader = DataLoader(data, batch_size=_EVAL_BATCH_SIZE, shuffle=True,
                                 drop_last=True, num_workers=args.workers)
        data_shape = data[0][0].shape

        for _ in range(num_runs):
            starttime = time.time()
            metrics = _evaluate_imputer(imputer, data_loader, _EVAL_BATCH_SIZE, data_shape)
            print(label + ' imputer fid:', time.time() - starttime)
            print('%.4f\t%.4f\t%.4f' % (metrics['per_pixel_seen_test_error'],
                                        metrics['per_pixel_masked_test_error'],
                                        metrics['image_masked_test_rmse']))
            for key, value in metrics.items():
                param['results'][key].append(float(value))

        _append_log(log_path, param)
       
        


if __name__ == '__main__':
    # Script entry point: main() returns None, so this exits with status 0
    # exactly as falling off the end of the module would.
    raise SystemExit(main())
