import os 
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import datetime

from tensorflow.python.ops.clip_ops import clip_by_value
import utils.models as um
import json
import tensorflow as tf
import csv

#tf.compat.v1.enable_eager_execution()

tfk = tf.keras

# NOTE(review): presumably the default root for experiment directories; it is
# not referenced by the code visible in this file — confirm before relying on it.
EXPERIMENT_BASEPATH = "/home/documentation/experiments"

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# The memory-growth setting must be identical for every visible GPU, so apply it
# to all of them rather than only to the first one.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for _gpu in physical_devices:
    tf.config.experimental.set_memory_growth(_gpu, True)


class Experiment:
    """Own a VAE model together with its experiment directory and training loop.

    ``experiment_configs`` is a dict. The keys read by this class are
    ``"experiment_directory"``, ``"model_configs"``, ``"max_epochs"``,
    ``"beta1"`` and ``"beta2"``.

    The experiment directory holds:

    - ``configs.json``
    - the ``weights`` checkpoint
    - ``log.csv``, appended by the Keras ``CSVLogger`` during ``train``
    - reconstruction plots

    Creating an ``Experiment`` on an existing directory validates the saved
    model configs, reloads the saved weights and resumes training right after
    the last epoch recorded in ``log.csv``.
    """

    # Training / validation datasets, attached via set_datasets().
    ds_train = None
    ds_val = None
    # Epoch index passed to fit() as initial_epoch (0 for a fresh experiment).
    last_epoch = 0

    def __init__(self, experiment_configs):
        """Set up directory, configs and model, then create or restore the weights."""
        # TODO: if type(experiment_configs) == str, treat it as a path to a configs file.
        self.set_experiment_variables(experiment_configs)

        # If there was no model previously, save the fresh weights; otherwise load them.
        # Only weights are persisted because the nested subclassed model loses some
        # properties if it is not saved properly. This is also why
        # set_experiment_variables always builds a new model object.
        # The ".index" file presumably marks an existing TF-format checkpoint.
        if not os.path.isfile(str(self.model_path) + ".index"):
            print("Creating model...")
            self.model.save_weights(self.model_path)
        else:
            print("Loading model...")
            self.model.load_weights(self.model_path)

        self.model.compile(optimizer=tfk.optimizers.Adam(learning_rate=0.00001))
        # self.log_experiment()  # optional: dump nested model summary / visuals

    def set_experiment_variables(self, experiment_configs):
        """Store the configs and derive directory, saved state, model and epoch budget."""
        self.experiment_configs = experiment_configs

        self.set_directory()
        self.set_configs()
        self.set_model()

        self.max_epochs = self.experiment_configs["max_epochs"]

    def set_directory(self):
        """Read the experiment directory from the configs and create it if missing."""
        self.directory = self.experiment_configs["experiment_directory"]
        if not os.path.isdir(self.directory):
            print("Creating directory for experiment...")
            os.makedirs(self.directory, exist_ok=True)

    def set_configs(self):
        """Save the configs for a new experiment, or validate them for an existing one.

        In a fresh directory the configs are written to ``configs.json``. When
        that file already exists, its ``model_configs`` must equal the current
        ones, and ``last_epoch`` is set to resume after the last logged epoch.

        Raises:
            ValueError: if the saved ``model_configs`` differ from the current ones.
        """
        self.configs_path = os.path.join(self.directory, "configs.json")

        if not os.path.isfile(self.configs_path):
            print("Saving experiment configs...")
            with open(self.configs_path, "w") as configs_json:
                json.dump(self.experiment_configs, configs_json)
            return

        print("Loading experiment configs...")
        with open(self.configs_path) as configs_json:
            saved_configs = json.load(configs_json)

        # JSON reads tuples back as lists, so shapes in model_configs must be lists
        # for this comparison to succeed.
        if saved_configs["model_configs"] != self.experiment_configs["model_configs"]:
            raise ValueError("Experiment configs corruption. Saved data does not match current configurations.")

        self.last_epoch = self._resume_epoch()

    def _resume_epoch(self):
        """Return the epoch index at which training should resume, from ``log.csv``.

        The ``epoch`` column holds the 0-based index of each completed epoch, so
        training resumes at the last one + 1. A missing, empty or header-only log
        (e.g. a crash before the first epoch finished) means nothing to resume.
        """
        log_path = os.path.join(self.directory, "log.csv")
        if not os.path.isfile(log_path):
            return 0
        try:
            old_log = pd.read_csv(log_path)
        except pd.errors.EmptyDataError:
            return 0
        if old_log.empty:
            return 0
        return int(old_log["epoch"].iloc[-1]) + 1

    def set_model(self):
        """Build a new ``um.VAE_model`` from the model configs and set the weights path."""
        self.model_path = os.path.join(self.directory, "weights")
        self.model_configs = self.experiment_configs["model_configs"]
        self.model = um.VAE_model(self.model_configs)

    def train(self, log_dir=None, learning_rate=0.00001):
        """Recompile with Adam(``learning_rate``) and fit up to ``max_epochs``.

        Checkpoints the weights every epoch, appends metrics to ``log.csv``,
        drives the beta schedules through ``LossCallback`` and, when ``log_dir``
        is given, also writes TensorBoard logs. Does nothing but print a notice
        if no training dataset has been set.
        """
        self.model.compile(optimizer=tfk.optimizers.Adam(learning_rate=learning_rate))
        if self.ds_train is None:
            print("No data given!")
            return

        callbacks = [
            tfk.callbacks.ModelCheckpoint(filepath=self.model_path, save_weights_only=True, save_best_only=False),
            tfk.callbacks.CSVLogger(os.path.join(self.directory, "log.csv"), append=True),
            LossCallback(self.directory,
                         beta1_configs=self.experiment_configs["beta1"],
                         beta2_configs=self.experiment_configs["beta2"]),
        ]
        if log_dir:
            callbacks.append(tfk.callbacks.TensorBoard(log_dir=log_dir))

        history = self.model.fit(self.ds_train, callbacks=callbacks, validation_data=self.ds_val,
                                 epochs=self.max_epochs, initial_epoch=self.last_epoch)

        # Advance the resume point so a later train() call continues instead of
        # repeating the epochs that were just run.
        if history.epoch:
            self.last_epoch = history.epoch[-1] + 1

    def set_datasets(self, train_dataset, val_dataset):
        """Attach the training and validation datasets used by train()/plotting."""
        self.ds_train = train_dataset
        self.ds_val = val_dataset

    def log_experiment(self):
        """Write the nested model summary and visualisation into the experiment directory."""
        um.nested_model_summary(self.model, path=self.directory)
        um.nested_model_visual(self.model, path=self.directory)

    def plot_reconstruction(self, split="val", shape=None):
        """Plot up to 10 inputs from the first batch of ``split`` with their reconstructions.

        Rows are the original images and the sample / mode / mean of the model
        output. The figure is saved as ``reconstruction_{split}.png`` in the
        experiment directory and shown.

        Args:
            split: ``"val"`` or ``"train"``.
            shape: optional target shape passed to ``tf.reshape`` for every row
                (e.g. to turn flat vectors into images); None leaves data as-is.

        Raises:
            ValueError: for an unknown ``split``.
        """
        # Not tested yet; the goal is to feed the same dataset to different models / setups.
        if split == "val":
            dataset = self.ds_val
        elif split == "train":
            dataset = self.ds_train
        else:
            raise ValueError(f"Plotting reconstruction for split {split} is not possible.")

        # Element [0] of the first batch — presumably the inputs; TODO confirm batch structure.
        x = next(iter(dataset))[0][:10]
        # NOTE(review): the model output is used as a distribution (.sample/.mode/.mean).
        xhat = self.model(x)

        def _maybe_reshape(tensor):
            # Only reshape when a target shape was requested.
            return tensor if shape is None else tf.reshape(tensor, shape)

        xs = [
            _maybe_reshape(x),
            _maybe_reshape(xhat.sample()),
            _maybe_reshape(xhat.mode()),
            _maybe_reshape(xhat.mean()),
        ]

        row_titles = ["Original", "Sampled", "Mode", "Mean"]
        n = xs[0].shape[0]
        m = len(row_titles)

        # squeeze=False keeps `axes` 2-D even when only one example is plotted.
        fig, axes = plt.subplots(nrows=m, ncols=n, figsize=(12, 8), squeeze=False)

        for ax, title in zip(axes[:, 0], row_titles):
            ax.set_title(title, size='large')

        for i, images in enumerate(xs):
            for j in range(n):
                axes[i, j].imshow(images[j], interpolation='none', cmap='gray')
                axes[i, j].axis('off')

        fig.tight_layout()
        fig.savefig(os.path.join(self.directory, f"reconstruction_{split}.png"), facecolor="white")
        plt.show()



class LossCallback(tfk.callbacks.Callback):
    """Keras callback that drives the beta schedules and keeps an extra per-epoch log.

    At the start of every epoch it sets ``model.beta1`` / ``model.beta2`` from
    ``beta_scheduler``. At the end of every epoch it:

    - appends the epoch, both betas and the ``nll`` / ``reg1`` / ``kl2`` /
      ``loss`` metrics to ``extra_log.csv`` in ``directory``;
    - every 500 epochs (including epoch 0), also saves a weight snapshot
      next to it.

    Args:
        directory: directory that receives ``extra_log.csv`` and the snapshots.
        beta1_configs: schedule dict for beta1 passed to ``beta_scheduler``
            (keys ``start``, ``stop``, ``init``, ``final``).
        beta2_configs: schedule dict for beta2, same format.
    """

    # Header of extra_log.csv; must match the row order built in on_epoch_end.
    _FIELDS = ["epoch", "beta1", "beta2", "nll", "kl1", "kl2", "loss"]

    def __init__(self, directory, beta1_configs, beta2_configs):
        super().__init__()

        self.save_path = os.path.join(directory, "extra_log.csv")
        self.beta1_configs = beta1_configs
        self.beta2_configs = beta2_configs

    def _append_row(self, row):
        """Append one row to the extra log CSV."""
        # The csv module requires newline='' to avoid blank lines on some platforms.
        with open(self.save_path, 'a', newline='') as log_file:
            csv.writer(log_file).writerow(row)

    def on_train_begin(self, logs=None):
        # Only write the header for a new or empty file, so a resumed run keeps the
        # rows from previous runs (matching the append-mode CSVLogger for log.csv)
        # instead of truncating them.
        if not os.path.isfile(self.save_path) or os.path.getsize(self.save_path) == 0:
            self._append_row(self._FIELDS)

        return super().on_train_begin(logs=logs)

    def on_epoch_begin(self, epoch, logs=None):
        # Update the model's beta variables in place for the upcoming epoch.
        next_beta1 = tf.cast(beta_scheduler(epoch, self.beta1_configs), tf.float32)
        next_beta2 = tf.cast(beta_scheduler(epoch, self.beta2_configs), tf.float32)

        tfk.backend.set_value(self.model.beta1, next_beta1)
        tfk.backend.set_value(self.model.beta2, next_beta2)

        return super().on_epoch_begin(epoch, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        loss = logs["loss"]

        beta1 = self.model.beta1.numpy()
        beta2 = self.model.beta2.numpy()
        # The model reports its first regularisation term as "reg1"; it is
        # written under the "kl1" column of the extra log.
        kl1 = logs["reg1"]
        kl2 = logs["kl2"]
        nll_loss = logs["nll"]

        self._append_row([epoch, beta1, beta2, nll_loss, kl1, kl2, loss])

        if epoch % 500 == 0:
            filename = f"snapshot_{int(epoch)}"
            save_p = os.path.join(os.path.dirname(self.save_path), filename)
            self.model.save_weights(save_p)

        return super().on_epoch_end(epoch, logs=logs)


def beta_scheduler(epoch, beta_configs):
    """Return the beta weight for ``epoch`` under a piecewise-linear schedule.

    ``beta_configs`` must provide ``start``, ``stop``, ``init`` and ``final``:

    - up to and including ``start``: ``init``;
    - between ``start`` and ``stop`` (inclusive): linear ramp to ``final``;
    - after ``stop``: ``final``.
    """
    start, stop, init, final = (
        beta_configs[key] for key in ("start", "stop", "init", "final")
    )

    if epoch <= start:
        # Plateau before the ramp begins.
        return init

    if epoch <= stop:
        # Inside the ramp (start < epoch <= stop, so the span is never zero here).
        span = stop - start
        return init + (final - init) * (epoch - start) / span

    # Past the ramp: hold the final value.
    return final