import os
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torchvision.utils import save_image

from utils import plot_function, UnNormalize

# Public API of this module: only ANPValidator is exported; the Validator
# base class is an implementation detail shared by concrete validators.
__all__ = ['ANPValidator']

class Validator(object):
    """Base class for evaluating a trained model on a validation loader.

    The constructor restores the model weights from
    ``checkpoints/<dataset>/<stage>/<run>.pth`` (the three fields are read from
    ``args``) and wraps the model in ``nn.DataParallel`` when more than one GPU
    is visible. Subclasses implement :meth:`validate`.
    """

    def __init__(self, model, validloader, args=None):
        """
        Args:
            model: module whose weights are restored from ``checkpoint['model']``.
            validloader: validation batches; stored for use by subclasses.
            args: namespace with at least ``dataset``, ``stage`` and ``run``.
                An optional ``device`` (str or ``torch.device``) is used as the
                ``map_location`` when loading the checkpoint.
        """
        self.args = args
        self.model = model
        self.validloader = validloader

        self.load_checkpoint()

    def load_checkpoint(self):
        """Load ``checkpoint['model']`` into ``self.model`` and report its epoch.

        Terminates the process with exit status 1 (``SystemExit(1)``) if the
        checkpoint file cannot be loaded.
        """
        ckpt_path = 'checkpoints/{}/{}/{}.pth'.format(self.args.dataset, self.args.stage, self.args.run)
        # Map tensors onto the evaluation device so that a checkpoint written on
        # a GPU can still be loaded on a CPU-only machine. Only plain device
        # specs are forwarded; anything else keeps torch's default behaviour.
        device = getattr(self.args, 'device', None)
        map_location = device if isinstance(device, (str, torch.device)) else None
        try:
            checkpoint = torch.load(ckpt_path, map_location=map_location)
        except Exception as err:  # e.g. missing file or unreadable archive
            print('{} not available.'.format(ckpt_path))
            print('({}: {})'.format(type(err).__name__, err))
            # Non-zero status so calling scripts see the failure; a bare
            # ``exit()`` would report success.
            raise SystemExit(1) from err
        self.model.load_state_dict(checkpoint['model'])

        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
        print('Evaluating with : {} trained for {} epochs'.format(ckpt_path, checkpoint['epoch']))

    def validate(self):
        """Run the evaluation; must be implemented by subclasses."""
        raise NotImplementedError

class ANPValidator(Validator):
    """Validator for ``ANP`` models with context-subset selection.

    NOTE(review): "ANP" presumably means attentive neural process — confirm.

    Provides:
      * ``valid``: one no-grad pass over ``validloader`` that collects the
        per-batch ``RCLoss`` and, optionally, the raw predictions;
      * ``validate``: mean/std of each statistic over ``args.eval_per_run``
        repetitions, per subset size;
      * ``visualize_function`` / ``visualize_celeba``: PDF plots of the
        selected subset, the prediction and the target.
    """

    # Stages whose model prediction is (mu, var, subset_mask, candidate_mask);
    # every other stage returns (mu, var, subset_mask).
    _CANDIDATE_STAGES = ('sss', 'randomautoregressive')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Short stage labels used as plot titles.
        self.names = {
                'sss'           : 'sss',
                'random'        : 'random',
                'randomautoregressive': 'ranauto',
                'autoregressive': 'auto',
                'candidate': 'cand'
                }

    def valid(self, subset_size, keep_preds=None):
        """Run one pass over ``validloader`` without gradients.

        Args:
            subset_size: forwarded unchanged to the model's ``subset_size`` kwarg.
            keep_preds: whether to keep inputs and predictions for plotting.
                ``None`` (default) falls back to ``self.args.visualize``.

        Returns:
            ``(valid_stats, preds)``. ``valid_stats`` maps each statistic name
            (currently only ``'RCLoss'``) to a list with one float per batch.
            ``preds`` holds ``[[context_x, context_y, target_x, target_y], pred]``
            entries, and is empty unless predictions are kept.

        Raises:
            ValueError: if ``args.dataset`` is neither ``'function'`` nor ``'celeba'``.
        """
        dataset = self.args.dataset
        if dataset not in ('function', 'celeba'):
            raise ValueError('Unsupported dataset: {}'.format(dataset))
        if keep_preds is None:
            keep_preds = self.args.visualize
        device = self.args.device

        self.model.eval()
        preds = []
        valid_stats = OrderedDict()
        with torch.no_grad():
            for batch in tqdm(self.validloader, total=len(self.validloader), ncols=75, leave=False):
                if dataset == 'function':
                    # The loader yields a pair. The first tensor is used as both
                    # context_x and target_x, and the second as both context_y and
                    # target_y (context == target set).
                    # NOTE(review): confirm the pair really is (x, y).
                    first, second = batch
                    xs, ys = first.to(device), second.to(device)
                    context_x, context_y, target_x, target_y = xs, ys, xs, ys
                else:  # 'celeba'
                    context_x, context_y, target_x, target_y = (t.to(device) for t in batch)

                log_p, pred = self.model(context=(context_x, context_y), target=(target_x, target_y),
                                         subset_size=subset_size)

                # Negative log-likelihood summed over dim 1, averaged over the rest.
                loss_p = -log_p.sum(dim=1).mean()
                batch_stats = OrderedDict(RCLoss=loss_p.item())
                for key, value in batch_stats.items():
                    valid_stats.setdefault(key, []).append(value)

                if keep_preds:
                    preds.append([[context_x, context_y, target_x, target_y], pred])
        return valid_stats, preds

    def _unpack_pred(self, pred):
        """Split a model prediction into ``(mu, var, subset_mask, candidate_mask)``.

        ``candidate_mask`` is ``None`` for stages that do not produce one.
        """
        if self.args.stage in self._CANDIDATE_STAGES:
            mu, var, subset_mask, candidate_mask = pred
        else:
            mu, var, subset_mask = pred
            candidate_mask = None
        return mu, var, subset_mask, candidate_mask

    @staticmethod
    def _select(points, mask):
        """Keep the points where ``mask >= 0.5``, reshaped to ``(B, -1, points.size(2))``."""
        return torch.masked_select(points, mask.ge(0.5)).view(points.size(0), -1, points.size(2))

    @staticmethod
    def _points_to_image(points_x, points_y, size):
        """Scatter per-point values onto a blank ``(1, 3, H, W)`` image.

        The flat pixel index is ``round(x[..., 0] * H * W + x[..., 1] * W)``.
        Pixels without a point stay zero. The final view to ``(1, 3, H, W)``
        requires a single sample with 3 values per point.
        """
        height, width = size[0], size[1]
        scale = torch.Tensor([[[height * width, width]]]).to(points_x.device)
        index = (points_x * scale).sum(-1).round().to(torch.long)
        index = index.unsqueeze(-1).repeat([1, 1, points_y.shape[-1]])
        image = points_y.new_zeros([points_y.shape[0], height * width, points_y.shape[2]])
        image.scatter_(1, index, points_y)
        return image.transpose(1, 2).view(1, 3, height, width)

    def visualize_function(self):
        """Plot subset, prediction and target for every kept validation batch.

        For each size in ``args.subset_sizes``, one PDF per batch is written to
        ``visualizations/<dataset>/<stage>/<run>/<subset_size>/<i>.pdf``.
        """
        title = self.names.get(self.args.stage, self.args.stage)
        for subset_size in self.args.subset_sizes:
            _, preds = self.valid(subset_size=subset_size)

            path = 'visualizations/{}/{}/{}/{}'.format(self.args.dataset, self.args.stage, self.args.run, subset_size)
            os.makedirs(path, exist_ok=True)

            for i, (inputs, pred) in enumerate(preds):
                context_x, context_y, target_x, target_y = inputs
                mu, var, subset_mask, _ = self._unpack_pred(pred)
                # x and y are reshaped with their own last dimension.
                subset_x = self._select(context_x, subset_mask)
                subset_y = self._select(context_y, subset_mask)

                plot_function(target_x=target_x, target_y=target_y, context_x=subset_x, context_y=subset_y,
                              pred_y=mu, var=var, filename=os.path.join(path, '{}.pdf'.format(i)), title=title)

    def visualize_celeba(self):
        """Save image panels (subset, prediction, target) for every kept batch.

        Per subset size, each batch is written to
        ``visualizations/<dataset>/<stage>/<run>/<subset_size>/<i>.pdf``.
        Stages that also produce a candidate mask get an extra leading
        candidate panel, and the target alone is written to
        ``visualizations/<dataset>/original/<i>.pdf``. Finally, each batch's
        panels for all subset sizes are stacked into
        ``visualizations/<dataset>/<stage>/<run>/all/<i>.pdf``.
        """
        un_normalize = UnNormalize()
        size = self.validloader.dataset.size
        height, width = size[0], size[1]
        path_original = 'visualizations/{}/{}'.format(self.args.dataset, 'original')

        all_images = []
        for subset_size in self.args.subset_sizes:
            _, preds = self.valid(subset_size=subset_size)

            path = 'visualizations/{}/{}/{}/{}'.format(self.args.dataset, self.args.stage, self.args.run, subset_size)
            os.makedirs(path, exist_ok=True)
            os.makedirs(path_original, exist_ok=True)

            pred_images = []
            for i, (inputs, pred) in enumerate(preds):
                context_x, context_y, _, target_y = inputs
                mu, _, subset_mask, candidate_mask = self._unpack_pred(pred)

                # The view to (3, H, W) only works for a single sample with
                # 3 values per pixel. A batch dim is re-added for concatenation.
                target_img = un_normalize(target_y.transpose(1, 2).view(3, height, width)).unsqueeze(0)
                mu_img = un_normalize(mu.transpose(1, 2).view(3, height, width)).unsqueeze(0)
                subset_img = self._points_to_image(self._select(context_x, subset_mask),
                                                   self._select(context_y, subset_mask), size)

                if candidate_mask is not None:
                    candidate_img = self._points_to_image(self._select(context_x, candidate_mask),
                                                          self._select(context_y, candidate_mask), size)
                    image = torch.cat([candidate_img, subset_img, mu_img, target_img], dim=0)
                    save_image(image, os.path.join(path, '{}.pdf'.format(i)), nrow=4)
                    save_image(target_img, os.path.join(path_original, '{}.pdf'.format(i)), nrow=4)
                else:
                    image = torch.cat([subset_img, mu_img, target_img], dim=0)
                    save_image(image, os.path.join(path, '{}.pdf'.format(i)), nrow=4)
                pred_images.append(image)
            all_images.append(pred_images)

        # Aggregate: for each batch index, concatenate the panels from every
        # subset size and save them as a single row.
        path = 'visualizations/{}/{}/{}/{}'.format(self.args.dataset, self.args.stage, self.args.run, 'all')
        os.makedirs(path, exist_ok=True)

        for i, images in enumerate(zip(*all_images)):
            imgs = torch.cat(images, dim=0)
            save_image(imgs, os.path.join(path, '{}.pdf'.format(i)), nrow=imgs.size(0))

    def validate(self):
        """Summarise the validation statistics for each subset size.

        For every size in ``args.subset_sizes``, ``valid`` runs
        ``args.eval_per_run`` times. Each run is reduced to the mean of its
        per-batch values, and the run means are reduced to ``[mean, std]``.

        Returns:
            OrderedDict mapping ``str(subset_size)`` to an OrderedDict
            ``{stat_name: [mean, std]}``.
        """
        results = OrderedDict()
        for subset_size in self.args.subset_sizes:
            run_means = OrderedDict()
            for _run in range(self.args.eval_per_run):
                # Predictions are not needed here; don't hold them in memory.
                run_stats = self.valid(subset_size=subset_size, keep_preds=False)[0]
                for key, values in run_stats.items():
                    run_means.setdefault(key, []).append(np.mean(values))

            results[str(subset_size)] = OrderedDict(
                (key, [np.mean(means), np.std(means)]) for key, means in run_means.items())
        return results
