import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import numpy as np
import utils
import csv
import pandas as pd

def get_uci_data(dataset='breast', location='./data/wdbc.data'):
    """Load one of the supported UCI tables as a NumPy array.

    Args:
        dataset: One of ``'breast'``, ``'red'``, ``'white'``, ``'banknote'``,
            ``'yeast'`` or ``'concrete'``.
        location: Path of the local CSV file; only read when ``dataset`` is
            ``'breast'``. Every other dataset is fetched from its UCI URL.

    Returns:
        A 2-D ``numpy.ndarray`` with one row per record.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names (the
            previous behaviour was to fall through and return ``None``).
    """
    if dataset == 'breast':
        with open(location, newline='') as csvfile:
            reader = csv.reader(csvfile, delimiter=',')
            # The first two columns are dropped and the rest parsed as floats.
            # Blank lines come back from csv.reader as empty rows and would
            # make the resulting array ragged, so they are skipped.
            rows = [[float(entry) for entry in row[2:]] for row in reader if row]
        return np.array(rows)

    if dataset == 'red':
        url = "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv"
        return np.array(pd.read_csv(url, low_memory=False, sep=';'))

    if dataset == 'white':
        url = "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv"
        return np.array(pd.read_csv(url, low_memory=False, sep=';'))

    # NOTE(review): the two read_csv calls below rely on pandas' default header
    # handling, so the first line of each file is consumed as column names.
    # If these files have no header row, one record is silently lost -- confirm.
    if dataset == 'banknote':
        url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00267/data_banknote_authentication.txt"
        # Keep only the first four columns.
        return np.array(pd.read_csv(url, low_memory=False, sep=','))[:, 0:4]

    if dataset == 'yeast':
        url = "https://archive.ics.uci.edu/ml/machine-learning-databases/yeast/yeast.data"
        # Read once: the file used to be downloaded and parsed twice, the
        # first time only to print its first row.
        frame = pd.read_csv(url, low_memory=False, sep=r'\s+', usecols=[1, 2, 3, 4, 5, 6, 7, 8])
        return np.array(frame)[:, 0:8]

    if dataset == 'concrete':
        url = "https://archive.ics.uci.edu/ml/machine-learning-databases/concrete/compressive/Concrete_Data.xls"
        # Keep only the first nine columns.
        return np.array(pd.read_excel(url))[:, 0:9]

    raise ValueError("Unknown UCI dataset: {!r}".format(dataset))


class MaskedUCI(Dataset):
    """UCI table with a per-entry observation mask and per-row sampler state.

    The constructor loads the table through ``get_uci_data``, asks the
    subclass for a mask via ``generate_masks`` and then tiles data and mask.
    ``init_latents`` / ``reset_latents`` / ``get_latents`` maintain two extra
    per-row tensors, ``latents`` and ``projected_ends``, using a ``flow``
    object that must provide:

    * ``flow.f(x)`` -- returns a sequence; only element ``[0]`` is used,
    * ``flow.g(z)`` -- returns a tensor,
    * ``flow.log_prob(x)`` -- returns one value per row (only used by
      ``get_latents``).

    All tensors created by the methods are placed on ``self.data.device``, so
    the ``device`` passed to the constructor is honoured throughout.
    """

    def __init__(self, image_size=28, random_seed=0, train=True, dataset='breast', device=torch.device("cuda:0"), num_copies = 10, horizontal_copies=1):
        """Load the table, generate the masks and tile both.

        Args:
            image_size: Not used; ``self.image_size`` is derived from the
                loaded data instead.
            random_seed: Seed of ``self.rnd`` (a NumPy ``RandomState``).
            train: Stored as ``self.train``; not otherwise used by this class.
            dataset: Dataset name forwarded to ``get_uci_data``.
            device: Device the data and mask tensors are moved to.
            num_copies: How many times the rows are repeated.
            horizontal_copies: How many times the columns are repeated.
        """
        self.rnd = np.random.RandomState(random_seed)
        self.train = train
        self.data = torch.Tensor(get_uci_data(dataset=dataset)).to(device)
        # While generate_masks() runs, image_size is the un-tiled column count
        # and len(self) is the un-tiled row count.
        self.image_size = self.data.shape[1]
        self.base_len = self.data.shape[0]
        self.generate_masks()
        self.image_size = self.data.shape[1]*horizontal_copies
        self.mask = self.mask.float().to(device)
        self.data = self.data.repeat(num_copies, horizontal_copies)
        self.mask = self.mask.repeat(num_copies, horizontal_copies)
        # Placeholders; init_latents()/reset_latents() replace them with
        # tensors of the same shape as self.data.
        self.latents = torch.zeros(len(self.data))
        self.projected_ends = torch.zeros(len(self.data))
        # Statistics filled in by init_latents().
        self.stdev = None
        self.means = None
        self.observed_values = None
        self.min = None
        self.max = None

    def init_latents(self, flow, batch_size=200):
        """Standardise the data in place and initialise the sampler state.

        Computes the per-column mean and standard deviation over the entries
        whose mask is 1 (stored on the CPU as ``self.means`` / ``self.stdev``,
        the latter floored at 1e-6), overwrites ``self.data`` with the
        standardised values, draws fresh ``self.latents`` and fills
        ``self.projected_ends`` and ``self.min`` / ``self.max``.

        Because ``self.data`` is modified in place, calling this method a
        second time standardises already standardised data.

        Args:
            flow: Object providing ``f`` and ``g`` (see the class docstring).
            batch_size: Number of rows processed per ``flow`` call.
        """
        device = self.data.device
        width = self.image_size
        total = len(self.data)

        # Drawn on the CPU and moved afterwards, so the CPU generator is
        # consumed exactly as it was when this code hard-coded .cuda().
        self.latents = torch.empty(self.data.shape).normal_(mean=0.0, std=1.0).to(device)
        self.projected_ends = self.data.mul(1.0 - self.mask)

        # Fraction of observed entries per column.
        # NOTE(review): a column with no observed entry gives a zero here and
        # therefore NaN/inf statistics below -- confirm masks rule this out.
        vis_prob = self.mask.mean(dim=0)
        observed_mean = (self.data.mul(self.mask)).mean(dim=0).div(vis_prob)
        self.means = observed_mean.view(1, width).cpu()
        spread = (((self.data - observed_mean).mul(self.mask)**2).mean(dim=0)).div(vis_prob)**0.5
        floor = 1e-6*torch.ones(self.means.shape).to(device)
        self.stdev = torch.max(spread, floor).view(1, width).cpu()

        means = self.means.to(device)
        stdev = self.stdev.to(device)
        with torch.no_grad():
            for start in range(0, total, batch_size):
                rows = slice(start, min(start + batch_size, total))
                self.latents[rows] = self.latents[rows].mul(stdev)
                original_proposals = flow.f(self.latents[rows])[0]
                self.data[rows] = (self.data[rows] - means).div(stdev)
                inputs = self.data[rows]
                # The result of this draw was never used; the draw itself is
                # kept so the CPU random stream seen by later code is unchanged.
                torch.empty([len(inputs), 1]).bernoulli(0.5)
                original_proposals = original_proposals.view(len(original_proposals), width)
                masks = self.mask[rows].view(len(inputs), width)
                observed = inputs.view(len(inputs), width).mul(masks)

                # Observed entries come from the data, the rest from the flow.
                projected_end = flow.g(original_proposals).mul(1.0 - masks) + observed
                self.projected_ends[rows] = projected_end.view(len(projected_end), width)

                # Unobserved entries are pushed to +100 so they do not win the
                # minimum.
                # NOTE(review): the maximum has no matching offset, so
                # unobserved entries contribute 0 to it -- confirm intended.
                shifted = observed + 100*(1.0 - masks)
                if start == 0:
                    self.min = shifted.min(dim=0)[0]
                    self.max = observed.max(dim=0)[0]
                else:
                    self.min = torch.min(shifted, self.min).min(dim=0)[0]
                    self.max = torch.max(observed, self.max).max(dim=0)[0]

    def reset_latents(self, flow, batch_size=200, model_reset=False, sample_std=1.0):
        """Redraw ``self.latents`` and recompute ``self.projected_ends``.

        Args:
            flow: Object providing ``f`` and ``g`` (see the class docstring).
            batch_size: Number of rows processed per ``flow`` call.
            model_reset: If true, each batch of latents is replaced by
                ``flow.g`` applied to fresh normal noise scaled by
                ``sample_std``; otherwise the initial standard-normal draw is
                kept.
            sample_std: Scale of the noise used when ``model_reset`` is true.
        """
        device = self.data.device
        width = self.image_size
        total = len(self.data)

        # Drawn on the CPU and moved afterwards to keep the random stream.
        self.latents = torch.empty(total, width).normal_(mean=0.0, std=1.0).to(device)
        self.projected_ends = torch.zeros(self.data.shape).to(device)
        with torch.no_grad():
            for start in range(0, total, batch_size):
                rows = slice(start, min(start + batch_size, total))
                if model_reset:
                    noise = torch.empty((self.latents[rows]).shape).normal_(mean=0.0, std=1.0).mul(sample_std).to(device)
                    self.latents[rows] = flow.g(noise)
                original_proposals = flow.f(self.latents[rows])[0]
                original_proposals = original_proposals.view(len(original_proposals), width)
                masks = self.mask[rows]
                masks = masks.view(len(masks), width)
                inputs = self.data[rows]
                inputs = inputs.view(len(inputs), width)

                # Observed entries come from the data, the rest from the flow.
                projected_end = flow.g(original_proposals).mul(1.0 - masks) + inputs.mul(masks)
                self.projected_ends[rows] = projected_end.view(len(projected_end), width)

    def get_latents(self, flow, batch_size=10000, step=True, num_steps=5, sample_std=1.0):
        """Run accept/reject updates of ``self.latents`` and report the acceptance rate.

        For every batch a new proposal is built from ``flow.f`` of the current
        latents: per row, either a small perturbation of the current value or
        a fresh normal draw scaled by ``sample_std`` (chosen with probability
        0.5 each). Accepted rows are written back to ``self.latents`` and
        ``self.projected_ends``.

        Args:
            flow: Object providing ``f``, ``g`` and ``log_prob``.
            batch_size: Number of rows processed per batch.
            step: If false, nothing is updated.
            num_steps: Number of sweeps over the whole dataset.
            sample_std: Scale of the fresh-draw proposals.

        Returns:
            ``accepted / attempted`` as a float, or ``0.0`` when no proposal
            was attempted (``step`` false, ``num_steps`` zero or an empty
            dataset); that case used to raise ``ZeroDivisionError``.
        """
        device = self.data.device
        width = self.image_size
        total = len(self.data)

        acceptances = 0
        tries = 0
        resample_prob = 0.5
        prop_std = 0.01
        gibbs_prob = 1.0

        with torch.no_grad():
            if step and total > 0:
                all_indices = range(0, total)
                # Buffers reused by every batch; they are refilled in place and
                # sliced down to the batch length.
                current_mask_holder = torch.empty(batch_size, width, device=device)
                resample_mask_holder = torch.empty([batch_size, 1], device=device)
                acceptance_samples = torch.empty([batch_size, 1], device=device)
                normal_sample_holder = torch.empty(batch_size, width, device=device)
                perturbations = torch.empty(batch_size, width, device=device)

                # Every sweep visits all rows once; the index set never shrinks.
                for _ in range(0, num_steps):
                    for start in range(0, total, batch_size):
                        stop = min(start + batch_size, total)
                        batch_length = stop - start
                        affected_indices = all_indices[start:stop]

                        original_proposals = flow.f(self.latents[affected_indices])[0]
                        inputs = self.data[affected_indices]
                        # 1 -> perturb the current value, 0 -> draw a fresh one.
                        resample_mask_holder.bernoulli_(1.0 - resample_prob)
                        resample_mask = resample_mask_holder[0:batch_length]
                        masks = self.mask[affected_indices]
                        original_proposals = original_proposals.view(len(original_proposals), width)
                        masks = masks.view(len(masks), width)
                        inputs = inputs.view(len(inputs), width)
                        projected_end = flow.g(original_proposals).mul(1.0 - masks) + inputs.mul(masks)

                        current_mask_holder.bernoulli_(gibbs_prob)
                        current_mask = current_mask_holder[0:batch_length]
                        normal_sample_holder.normal_(mean=0.0, std=1.0)
                        normal_samples = normal_sample_holder[0:batch_length]
                        new_proposals = (original_proposals + normal_samples.mul(prop_std)).mul(current_mask).mul(resample_mask)
                        new_proposals += normal_samples.mul(sample_std).mul(1.0 - resample_mask)
                        proposal = flow.g(new_proposals)

                        perturbations.normal_(mean=0.0, std=1.0)
                        current_perturbations = 0.01*perturbations[0:batch_length]
                        # Change in squared error on the observed entries
                        # (current minus proposed), scaled by 1e6.
                        bayes_mod = (1e6)*(((flow.g(original_proposals)-inputs + current_perturbations).mul(masks)**2).sum(dim=1)/2.0 - ((proposal-inputs+current_perturbations).mul(masks)**2).sum(dim=1)/2.0)
                        # Observed entries are replaced by the perturbed data
                        # in both the proposed and the current point.
                        proposal = proposal.mul(1.0 - masks) + (inputs+current_perturbations).mul(masks)
                        projected_end = projected_end.mul(1.0 - masks) + (inputs+current_perturbations).mul(masks)
                        acceptance_samples.uniform_()

                        log_ratio = bayes_mod.unsqueeze(1) + flow.log_prob(proposal).unsqueeze(1) - flow.log_prob(projected_end).unsqueeze(1) + ((new_proposals**2).sum(1)/2.0 - (original_proposals**2).sum(1)/2.0).unsqueeze(1).mul(1.0 - resample_mask)/sample_std**2
                        # Clamped before exponentiating to avoid overflow.
                        acceptance_prob = torch.exp(torch.clamp(log_ratio, -25, 25))
                        current_acceptances = (acceptance_samples[0:batch_length] < acceptance_prob).float()
                        acceptances += float(current_acceptances.sum())
                        tries += float(len(current_acceptances))

                        # Accepted rows take the new proposal, the rest keep
                        # the old one.
                        original_proposals = new_proposals.mul(current_acceptances) + original_proposals.mul(1.0 - current_acceptances)
                        self.latents[affected_indices] = flow.g(original_proposals)
                        self.projected_ends[affected_indices] = (flow.g(original_proposals).mul(1.0 - masks) + inputs.mul(masks)).view(len(original_proposals), width)

        if tries == 0:
            return 0.0
        return acceptances/tries

    def __getitem__(self, index):
        """Return ``(data, mask, projected_end, latent, index)`` for ``index``."""
        return self.data[index], self.mask[index], self.projected_ends[index], self.latents[index], index

    def __len__(self):
        """Return the number of rows currently held in ``self.data``."""
        return len(self.data)

    def generate_masks(self):
        """Set ``self.mask``; must be implemented by subclasses."""
        raise NotImplementedError


class IndepMaskedUCI(MaskedUCI):
    """MaskedUCI whose mask entries are independent Bernoulli draws.

    Each row gets its own observation probability: ``obs_prob`` when
    ``obs_prob_high`` is ``None``, otherwise a value drawn uniformly from
    ``[obs_prob, obs_prob_high)`` with the dataset's NumPy ``RandomState``.
    """

    def __init__(self, obs_prob=.2, obs_prob_high=None, *args, **kwargs):
        """Store the probability range, then defer to ``MaskedUCI.__init__``.

        The two attributes must exist before the parent constructor runs,
        because that constructor calls ``generate_masks``.
        """
        self.prob = obs_prob
        self.prob_high = obs_prob_high
        super().__init__(*args, **kwargs)

    def generate_masks(self):
        """Fill ``self.mask`` with one Bernoulli row per dataset row."""
        width = self.image_size
        low, high = self.prob, self.prob_high
        row_count = len(self)
        self.mask = torch.ByteTensor(row_count, width)
        for row in range(row_count):
            # The per-row rate is chosen first, then the row is sampled, so
            # both random streams are consumed in a fixed order.
            rate = low if high is None else self.rnd.uniform(low, high)
            self.mask[row] = torch.ByteTensor(width).bernoulli_(rate)


