#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 20 00:20:41 2020

"""

import numpy as np
import pandas as pd
import tensorflow as tf
import os
from metrics import distance_to_minimum
# Seed NumPy's global RNG at import time so the source datasets drawn with
# np.random.choice in getresponse are reproducible across runs (for the same
# sequence of calls, since every call keeps consuming the same global stream).
np.random.seed(42)

# helper functions

def _predict_target(model, batch, cardinality, targetdataset, nsource, fn, split, sourcesplit):
    """Predict every hyperparameter index of one target dataset.

    Each instance pairs ``targetdataset`` and a hyperparameter index with a source
    dataset drawn uniformly from ``range(nsource)``. Instances are grouped into
    batches of ``batch.batch_size`` and passed to ``model.predict``.

    Returns:
        (summary, mse, absolute): the per-instance DataFrame of predictions, plus the
        mean squared and mean absolute error between the batch responses and the
        target predictions.
    """
    response, targetresponse, sourceresponse, similarity = [], [], [], []
    rows = []  # one (targetdataset, hyperparameter, sourcedataset) triple per instance
    hyperparameter = 0
    while hyperparameter < cardinality:
        batch.clear()
        # NOTE(review): batches are always filled to batch.batch_size. If
        # cardinality is not a multiple of batch_size, the last batch asks `fn`
        # for hyperparameter indices >= cardinality. Confirm `fn` supports that.
        for _ in range(batch.batch_size):
            sourcedataset = np.random.choice(np.arange(nsource))
            instance = fn(targetdataset, sourcedataset, hyperparameter,
                          targetsplit=split, sourcesplit=sourcesplit)
            batch.append(instance)
            rows.append((targetdataset, hyperparameter, sourcedataset))
            hyperparameter += 1
        # group the appended instances into the model inputs
        batch.collect()
        targety, sourcey, distance = model.predict(x=batch.input, y=batch.output)
        response.append(batch.output['response'])
        targetresponse.append(targety)
        sourceresponse.append(sourcey)
        similarity.append(distance)

    # Flatten the per-batch outputs into single (1, N) rows.
    response, targetresponse, sourceresponse, similarity = (
        tf.reshape(tensor=tf.stack(values), shape=(1, -1))
        for values in (response, targetresponse, sourceresponse, similarity)
    )
    mse = tf.keras.losses.mean_squared_error(y_true=response, y_pred=targetresponse)
    absolute = tf.keras.losses.mean_absolute_error(y_true=response, y_pred=targetresponse)

    # Build the instance bookkeeping once instead of concatenating row by row.
    distribution = np.asarray(rows).reshape(-1, 3)
    summary = pd.DataFrame(
        np.concatenate([
            targetresponse.numpy().reshape(-1, 1),
            response.numpy().reshape(-1, 1),
            sourceresponse.numpy().reshape(-1, 1),
            similarity.numpy().reshape(-1, 1),
            distribution[:, [0]],  # targetdataset
            distribution[:, [2]],  # sourcedataset
            distribution[:, [1]],  # hyperparameter
        ], axis=1),
        columns=['targetresponse', 'response', 'sourceresponse', 'similarity',
                 'targetdataset', 'sourcedataset', 'hyperparameter'],
    )
    return summary, float(mse.numpy()[0]), float(absolute.numpy()[0])


def getresponse(model,batch,dataset,ntarget,split,nsource,sourcesplit,fn,directory,index=False,return_response=False):
    """Predict responses for all hyperparameters of the target dataset(s) and summarise them.

    For each target dataset, every hyperparameter index ``0 .. dataset.cardinality - 1``
    is paired with a randomly drawn source dataset. The instances are built with ``fn``
    and fed to ``model.predict`` batch by batch. Per-target predictions, a loss summary
    and the ``distance_to_minimum`` curves are written as CSV files.

    Args:
        model: object whose ``predict(x=..., y=...)`` returns ``(targety, sourcey, distance)``.
        batch: batch container exposing ``clear()``, ``append()``, ``collect()``,
            ``batch_size``, ``input`` and ``output['response']``.
        dataset: object exposing ``cardinality`` (number of hyperparameter indices);
            also forwarded to ``distance_to_minimum``.
        ntarget: number of target datasets to evaluate, or the single target index
            when ``index`` is True.
        split: target split, forwarded to ``fn`` / ``distance_to_minimum`` and used as
            the output sub-directory unless ``return_response`` is True.
        nsource: number of source datasets to sample from.
        sourcesplit: source split, forwarded to ``fn``.
        fn: instance builder ``fn(target, source, hyperparameter, targetsplit=..., sourcesplit=...)``.
        directory: root output directory (created if missing).
        index: if True, evaluate only target dataset ``ntarget``.
        return_response: if True, write directly into ``directory`` and return the raw
            summaries instead of the min/argmin dictionary.

    Returns:
        ``{'min_20', 'min_50', 'argmin'}`` computed on the mean distance-to-minimum
        curve. When ``return_response`` is True, returns ``(results, d2m)`` instead:
        the per-target summary DataFrames and the mean curve as a DataFrame.

    Raises:
        ValueError: if there is no target dataset to evaluate (``ntarget < 1`` with
            ``index`` False).
    """
    if not index and ntarget < 1:
        raise ValueError(f"ntarget must be >= 1 when index is False, got {ntarget}")

    cardinality = dataset.cardinality
    # With `index`, the single evaluated target is `ntarget` itself, not position 0.
    target_ids = [ntarget] if index else list(range(ntarget))

    results = []
    loss_rows = []
    for targetdataset in target_ids:
        summary, mse, absolute = _predict_target(model, batch, cardinality, targetdataset,
                                                 nsource, fn, split, sourcesplit)
        loss_rows.append([mse, absolute, targetdataset])
        results.append(summary)

    # Explicit column names so each row lands under mse/absolute/targetdataset.
    losses = pd.DataFrame(loss_rows, columns=['mse', 'absolute', 'targetdataset'])

    savedir = os.path.join(directory, '' if return_response else split)
    os.makedirs(savedir, exist_ok=True)
    losses.to_csv(os.path.join(savedir, "loss-summary.csv"))

    for targetdataset, summary in zip(target_ids, results):
        summary.to_csv(os.path.join(savedir, f"overview-targetdataset-{targetdataset}.csv"))

    d2m_per_dataset = np.vstack([
        distance_to_minimum(summary, targetdataset, split=split, dataset=dataset)
        for targetdataset, summary in zip(target_ids, results)
    ])
    pd.DataFrame(d2m_per_dataset).to_csv(os.path.join(savedir, "distance-to-minimum-per-dataset.csv"))
    d2m = pd.DataFrame(d2m_per_dataset.mean(axis=0))
    d2m.to_csv(os.path.join(savedir, "distance-to-minimum.csv"))

    if return_response:
        return results, d2m
    curve = np.asarray(d2m)
    return {'min_20': np.min(curve[:20]), 'min_50': np.min(curve[:50]), 'argmin': np.argmin(curve)}
        