#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 25 16:21:42 2020

"""

import tensorflow as tf
import pandas as pd
import random
import numpy as np
# Seed the global random number generators of NumPy, the standard library and
# TensorFlow when the module is imported. The sampling code below draws from
# both `np.random` and `random`.
np.random.seed(318)
random.seed(3718)
tf.random.set_seed(0)
class Batch(object):
    """Buffer that accumulates sampled items and turns them into model tensors.

    Items are added one at a time with :meth:`append`; :meth:`collect` then
    builds ``self.input`` and ``self.output`` from everything buffered.

    Parameters
    ----------
    batch_size : int
        Number of items that make up a full batch.
    fixed_shape : bool, default True
        When true, :meth:`collect` refuses to run unless exactly
        ``batch_size`` items have been buffered.
    """

    def __init__(self,batch_size,fixed_shape = True):

        self.batch_size = batch_size
        self.fixed_shape = fixed_shape
        # model targets; populated by collect()
        self.output = None
        self.clear()

    def clear(self):
        """Drop every buffered item and reset the assembled model input."""
        # flattened triplets
        self.x = []
        # number of instances per item in triplets
        self.instances = []
        # number of features per item in triplets
        self.features = []
        # number of classes per item in triplets
        self.classes = []
        # surrogate value of the positive item in the triplets
        self.ypos = []
        # surrogate value of the negative item in the triplets
        self.yneg = []
        # surrogate task
        self.surr = []
        # model input
        self.input = None

    def append(self,instance):
        """Add one item to the buffer.

        If the buffer already holds ``batch_size`` items it is cleared first,
        so the new item starts a fresh batch.

        Parameters
        ----------
        instance : sequence
            Indexable with at least seven entries, stored as
            ``(x, instances, features, classes, ypos, surr, yneg)``.
            Note that position 5 is ``surr`` and position 6 is ``yneg``.
        """
        if len(self.x)==self.batch_size:

            self.clear()

        self.x.append(instance[0])
        self.instances.append(instance[1])
        self.features.append(instance[2])
        self.classes.append(instance[3])
        self.ypos.append(instance[4])
        self.surr.append(instance[5])
        self.yneg.append(instance[6])

    def collect(self):
        """Assemble ``self.input`` and ``self.output`` from the buffered items.

        ``self.input`` becomes a 5-tuple: the concatenated ``x`` entries, the
        concatenated-and-transposed ``classes``, ``features`` and ``instances``
        entries cast to ``int32`` (in that order), and the stacked ``surr``
        entries. ``self.output`` becomes a dict with the keys ``'response'``
        and ``'similaritytarget'``.

        Raises
        ------
        ValueError
            If ``fixed_shape`` is set and the number of buffered items differs
            from ``batch_size``.
        """
        if len(self.x)!= self.batch_size and self.fixed_shape:
            # A bare string cannot be raised in Python 3 (it fails with
            # "exceptions must derive from BaseException"), so wrap the
            # message in a real exception type.
            raise ValueError(f'Batch formation incomplete!\n{len(self.x)}!={self.batch_size}')
        self.input = (tf.concat(self.x,axis=0),
                      tf.cast(tf.transpose(tf.concat(self.classes,axis=0)),dtype=tf.int32),
                      tf.cast(tf.transpose(tf.concat(self.features,axis=0)),dtype=tf.int32),
                      tf.cast(tf.transpose(tf.concat(self.instances,axis=0)),dtype=tf.int32),
                      tf.stack(self.surr),
                      )
        # NOTE(review): 'similaritytarget' is sized from batch_size, not from
        # the number of buffered items; with fixed_shape=False and a partial
        # batch the two can differ - confirm this is intended.
        self.output = {'response':tf.cast(tf.concat([self.ypos],axis=0),dtype=tf.float32),
                       'similaritytarget':tf.concat([tf.ones(self.batch_size),tf.zeros(self.batch_size)],axis=0)}

def pool(n,ntotal,shuffle):
    """Return the indices ``0 .. ntotal-1`` with ``n`` left out.

    Parameters
    ----------
    n
        Index to exclude; a value outside the range (e.g. ``-1``) excludes
        nothing.
    ntotal : int
        Number of indices to enumerate.
    shuffle : bool
        When true the result is shuffled in place with ``random.shuffle``.

    Returns
    -------
    list of int
    """
    candidates = [index for index in range(ntotal) if index != n]
    if not shuffle:
        return candidates
    random.shuffle(candidates)
    return candidates

class Sampling(object):
    """Fill a batch with items for a randomly drawn (or reused) target dataset.

    Every draw is appended to ``self.distribution``, a DataFrame with the
    columns ``targetdataset``, ``sourcedataset`` and ``hyperparameters``.
    """

    def __init__(self,dataset,fixed_hyperparameter):
        self.dataset = dataset
        self.fixed_hyperparameter = fixed_hyperparameter
        # target index and hyperparameters of the latest draw, read back
        # when sample() is called with reuse=True
        self.targetdataset = None
        self.hyperparameters = None
        # log of every (target, source, hyperparameter) row drawn so far
        self.distribution = pd.DataFrame(data=None,columns=['targetdataset','sourcedataset','hyperparameters'])

    def sample(self,batch,split,sourcesplit,reuse=False):
        """Clear ``batch`` and refill it with ``batch.batch_size`` new items.

        Parameters
        ----------
        batch
            Object exposing ``batch_size``, ``clear()`` and ``append()``.
        split
            Key into ``dataset.orig_data`` the target dataset is drawn from.
        sourcesplit
            Key into ``dataset.orig_data`` the source datasets are drawn from.
        reuse : bool, default False
            Reuse the target dataset of the previous call. With
            ``fixed_hyperparameter`` set, the previous hyperparameters are
            reused as well; otherwise they are drawn afresh.

        Returns
        -------
        The same ``batch`` object, refilled.
        """
        source_count = len(self.dataset.orig_data[sourcesplit])
        target_count = len(self.dataset.orig_data[split])
        size = batch.batch_size

        # target dataset: previous one on reuse, otherwise a uniform draw
        if reuse:
            targetdataset = self.targetdataset
        else:
            targetdataset = np.random.choice(target_count)

        # hyperparameters: one shared value for the whole batch, or
        # batch_size distinct values
        if not self.fixed_hyperparameter:
            hyperparameters = np.random.choice(np.arange(self.dataset.cardinality),size=size,replace=False)
        elif reuse:
            hyperparameters = self.hyperparameters
        else:
            shared = np.random.choice(np.arange(self.dataset.cardinality))
            hyperparameters = np.asarray([shared] * size)

        batch.clear()

        # candidate source datasets; the target itself is left out only when
        # both are taken from the same split
        excluded = targetdataset if split == sourcesplit else -1
        swimmingpool = pool(excluded, source_count, shuffle=True)
        # batch_size distinct source datasets
        sourcedataset = np.random.choice(swimmingpool, size, replace=False)

        for source, hyperparameter in zip(sourcedataset, hyperparameters):
            batch.append(
                self.dataset.instances(targetdataset, source, hyperparameter,
                                       split=split, sourcesplit=sourcesplit)
            )

        # log the draw: one row per batch item
        target_column = np.asarray([targetdataset] * size)[:, None]
        rows = np.concatenate(
            [target_column, sourcedataset[:, None], hyperparameters[:, None]],
            axis=1,
        )
        drawn = pd.DataFrame(rows, columns=['targetdataset','sourcedataset','hyperparameters'])
        self.distribution = pd.concat([self.distribution, drawn], axis=0, ignore_index=True)

        # remember the draw for a later reuse=True call
        self.hyperparameters = hyperparameters
        self.targetdataset = targetdataset
        return batch
    
class TestSampling(object):
    """Fill a batch for a caller-supplied target dataset and hyperparameter set.

    Every draw is appended to ``self.distribution``, a DataFrame with the
    columns ``targetdataset``, ``sourcedataset`` and ``hyperparameters``.
    """

    def __init__(self,dataset,fixed_hyperparameter):
        self.fixed_hyperparameter = fixed_hyperparameter
        self.dataset = dataset
        # log of every (target, source, hyperparameter) row drawn so far
        self.distribution = pd.DataFrame(data=None,columns=['targetdataset','sourcedataset','hyperparameters'])

    def sample(self,batch,split,sourcesplit,targetdataset,collection,index=None):
        """Clear ``batch`` and refill it with ``batch.batch_size`` new items.

        Parameters
        ----------
        batch
            Object exposing ``batch_size``, ``clear()`` and ``append()``.
        split
            Split passed to ``dataset.instances`` for the target dataset.
        sourcesplit
            Key into ``dataset.orig_data`` the source datasets are drawn from.
        targetdataset
            Target dataset passed to ``dataset.instances``.
        collection
            Candidate hyperparameters.
        index : optional
            Position in ``collection`` to use for every batch item. When it is
            ``None`` and ``fixed_hyperparameter`` is false, hyperparameters are
            drawn at random from ``collection`` instead.

        Returns
        -------
        The same ``batch`` object, refilled.
        """
        source_count = len(self.dataset.orig_data[sourcesplit])
        size = batch.batch_size

        if index is None and not self.fixed_hyperparameter:
            # sample with replacement only when the collection is too small
            # to provide batch_size distinct values
            with_replacement = bool(len(collection) < size)
            hyperparameters = np.random.choice(collection, size=size, replace=with_replacement)
        else:
            # the same entry of the collection for every batch item
            hyperparameters = np.asarray([collection[index]] * size).reshape(-1,)

        batch.clear()

        # candidate source datasets; the target itself is left out only when
        # both are taken from the same split
        excluded = targetdataset if split == sourcesplit else -1
        swimmingpool = pool(excluded, source_count, shuffle=True)
        # batch_size distinct source datasets
        sourcedataset = np.random.choice(swimmingpool, size, replace=False)

        for source, hyperparameter in zip(sourcedataset, hyperparameters):
            batch.append(
                self.dataset.instances(targetdataset, source, hyperparameter,
                                       split=split, sourcesplit=sourcesplit)
            )

        # log the draw: one row per batch item
        target_column = np.asarray([targetdataset] * size)[:, None]
        rows = np.concatenate(
            [target_column, sourcedataset[:, None], hyperparameters[:, None]],
            axis=1,
        )
        drawn = pd.DataFrame(rows, columns=['targetdataset','sourcedataset','hyperparameters'])
        self.distribution = pd.concat([self.distribution, drawn], axis=0, ignore_index=True)

        return batch