import os
import tempfile
import shutil
# TensorFlow configuration via environment variables. These assignments are
# placed ahead of the `import tensorflow` further down so that they are
# already set when TensorFlow is imported.
# '3' presumably selects the most restrictive native (C++) log level.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Presumably requests deterministic op implementations, to go with the
# explicit seeding done before training.
os.environ['TF_DETERMINISTIC_OPS'] = '1'

import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
import random as pyrand

# Module-level side effect: fix the Keras backend image layout to
# channels-last for every model built after this import.
K.set_image_data_format('channels_last')

from tensorflow.keras import optimizers
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Conv2D,  MaxPooling2D
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dropout
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.models import Model
from time import time
from tensorflow.keras.utils import Progbar


# Logging configuration: set AutoGraph verbosity to level 1 and restrict the
# TensorFlow Python logger to messages of level ERROR and above.
tf.autograph.set_verbosity(1)
tf.get_logger().setLevel('ERROR')

# from datalib import Cifar10
# from datalib import Lfw
# from datalib import Mnist
from datalib import FashionMnist
from datalib.processors import Normalize, UnitBall

# Training batch size used when the caller's `fit_kwargs` carries no
# 'batch_size' entry.
DEFAULT_BATCH_SIZE = 64

def string_to_architecture(input_shape, num_classes, architecture, multi=True, 
        model_constructor=None, model_args=None, model_kwargs=None):
    """Build a Keras model from a dot-separated architecture description.

    Each dot-separated token of `architecture` adds one layer:

    * ``(channels,kernel_size[,strides[,padding]])`` -- a `Conv2D` layer
      followed by a ReLU activation. `strides` defaults to 1 and `padding`
      defaults to ``'same'``.
    * ``dXX`` -- a `Dropout` layer whose rate is ``float('0.XX')``.
    * any token starting with ``p`` -- a 2x2 `MaxPooling2D` with stride 2.
    * anything else -- parsed as an integer number of units for a `Dense`
      layer followed by a ReLU activation. The first dense layer flattens
      the running tensor if it is not flat yet.

    The output head is a `Dense` layer without activation (i.e. it emits
    logits): `num_classes` units when `multi` is true, a single unit
    otherwise.

    Args:
        input_shape: Shape of one input example (no batch dimension). A
            shape of length 1 is treated as already flat.
        num_classes: Number of output units for the multi-class head.
        architecture: Architecture description, e.g. ``'(32,3).p.d25.128'``.
            Non-string values are converted with `str` first.
        multi: Selects the multi-class head when true, a one-unit head
            otherwise.
        model_constructor: Callable invoked as
            ``model_constructor(inputs, outputs, *model_args, **model_kwargs)``.
            Defaults to `tf.keras.Model`.
        model_args: Extra positional arguments for `model_constructor`.
        model_kwargs: Extra keyword arguments for `model_constructor`.

    Returns:
        Whatever `model_constructor` returns for the assembled graph.

    Raises:
        ValueError: If a convolutional layer is requested on a flat input,
            or if a convolutional token is malformed (no closing
            parenthesis, or fewer than the two required fields).
    """
    if model_constructor is None:
        model_constructor = tf.keras.Model
    # `None` defaults avoid sharing one mutable list/dict between calls.
    model_args = [] if model_args is None else model_args
    model_kwargs = {} if model_kwargs is None else model_kwargs

    # `str(...)` lets a bare integer such as `128` stand for a single dense
    # layer; a string without dots simply yields a one-element list.
    layers = str(architecture).split('.')
    flat = (len(input_shape) == 1)

    x = Input(input_shape)
    z = x

    for layer_string in layers:

        # Convolutional layers are specified in the format:
        #   `(num_channels, kernel_size, strides, padding)`
        # The only required data is `num_channels` and `kernel_size`.
        # If `strides` is not given, then it defaults to 1 (no striding).
        # If `padding` is not given, then it defaults to `'same'`
        if layer_string.startswith('('):
            if flat:
                raise ValueError(
                    'cannot have a convolutional layer on a flat input')

            if not layer_string.endswith(')'):
                raise ValueError(
                    f'malformed convolutional layer {layer_string!r}: '
                    'missing closing parenthesis')

            layer_info = [
                field.strip() for field in layer_string[1:-1].split(',')]
            # Without this check a short spec would either hit an unbound
            # name or silently reuse the previous conv layer's settings.
            if len(layer_info) < 2:
                raise ValueError(
                    f'malformed convolutional layer {layer_string!r}: '
                    'expected at least (num_channels,kernel_size)')

            channels = int(layer_info[0])
            kernel_size = int(layer_info[1])
            strides = int(layer_info[2]) if len(layer_info) >= 3 else 1
            padding = layer_info[3] if len(layer_info) >= 4 else 'same'

            next_layer = Conv2D(channels,
                            kernel_size,
                            strides=strides,
                            padding=padding,
                            activation='linear')
            z = next_layer(z)
            z = Activation('relu')(z)

        # A dropout layer is specified in the format:
        #   dXX
        # Where XX is a number in the range 0-100 that
        # specifies the dropout percentage (rate * 100).
        # For example:
        #   d05 gives a dropout layer with rate of 0.05
        #   d10 gives a dropout layer with rate of 0.1
        #   d1  does *NOT*  give a layer with rate 0.01, this is also rate 0.1!
        elif layer_string.startswith('d'):
            rate = float("0.{}".format(layer_string[1:]))
            z = Dropout(rate)(z)

        # Any token starting with `p` is a fixed 2x2 max-pooling layer.
        elif layer_string.startswith('p'):
            next_layer = MaxPooling2D(pool_size=(2,2), strides=(2,2))
            z = next_layer(z)

        # Dense layers are simply given as integers that
        # specify the number of units to use
        else:
            if not flat:
                z = Flatten()(z)
                flat = True

            z = Dense(int(layer_string))(z)
            z = Activation('relu')(z)

    # The output head always operates on a flat tensor.
    if not flat:
        z = Flatten()(z)

    # No activation on the head: the model emits raw logits.
    if multi:
        y = Dense(num_classes)(z)
    else:
        y = Dense(1)(z)

    return model_constructor(x, y, *model_args, **model_kwargs)

class LUFExperiment(object):
    """Leave-one-out training experiment.

    For each selected training index, a model described by `architecture`
    is trained on the training set with that one example removed. `run`
    then collects every model's predictions on the full training and test
    sets and records, per point, whether all models agree on its label
    ("stable") or not ("unstable").

    Trained models are stored as ``loo_<i>.h5`` files, first in a temporary
    directory and, after `save`, under the experiment directory.
    """

    def __init__(self, architecture, data, multi=True, indices=None, num_trials=None, seed=0,
            model_constructor=None, model_args=None, model_kwargs=None,
            compile_kwargs=None, fit_args=None, fit_kwargs=None,
            custom_objects=None, early_stop_cutoff=1.e-1):
        """Configure the experiment; nothing is trained until `run`.

        Args:
            architecture: Architecture string understood by
                `string_to_architecture`.
            data: Dataset object. This class reads `x_tr`, `y_tr`,
                `y_tr_1hot`, `x_te`, `y_te`, `y_te_1hot`, `input_shape`,
                `num_classes` and `_name` from it.
            multi: Truthy for a multi-class setup (one-hot labels, softmax
                confidences); falsy for a binary setup (single-unit head,
                sigmoid confidences).
            indices: Optional sequence of training indices to leave out,
                one per trial. When given, it overrides `num_trials`.
            num_trials: Number of leave-one-out models to train. When both
                `indices` and `num_trials` are `None`, every training
                instance is left out once.
            seed: Seed applied to `random`, NumPy and TensorFlow.
            model_constructor: Model class/callable; defaults to
                `tf.keras.Model`.
            model_args: Extra positional arguments for `model_constructor`.
            model_kwargs: Extra keyword arguments for `model_constructor`.
            compile_kwargs: Keyword arguments for `model.compile`. Missing
                'loss', 'optimizer' and 'metrics' entries are filled in.
            fit_args: Extra positional arguments for `model.fit`.
            fit_kwargs: Keyword arguments for `model.fit`. 'batch_size' is
                consumed to batch the training dataset.
            custom_objects: Mapping forwarded to
                `tf.keras.models.load_model` when models are reloaded.
            early_stop_cutoff: `min_delta` of the early-stopping callback
                that is installed when `fit_kwargs` has no 'epochs'.
        """
        self._arch = architecture
        self._data = data
        self._seed = seed
        self._early_stop_cutoff = early_stop_cutoff
        self._indices = indices
        self._multi = bool(multi)

        # Explicit leave-one-out points take precedence over `num_trials`.
        if indices is not None:
            num_trials = len(indices)
        # With neither given, iterate over the entire training set,
        # leaving each instance out once.
        if num_trials is None:
            num_trials = len(data.x_tr)
        self._trials = num_trials

        if model_constructor is None:
            self._model_con = tf.keras.Model
        else:
            self._model_con = model_constructor

        # Work on a private copy: filling in defaults must neither leak
        # into the caller's dict nor into later instances (a shared
        # default dict would make the first instance's loss stick).
        compile_kwargs = dict(compile_kwargs) if compile_kwargs else {}
        # NOTE(review): these string losses are paired with a model head
        # that has no activation (see `string_to_architecture`), while
        # `_get_prediction_sets` applies softmax/sigmoid itself. If the
        # string aliases expect probabilities, a `from_logits=True` loss
        # is probably what was intended -- confirm before changing, since
        # it alters training results.
        compile_kwargs.setdefault(
            'loss',
            'categorical_crossentropy' if self._multi else 'binary_crossentropy')
        compile_kwargs.setdefault('optimizer', 'adam')
        compile_kwargs.setdefault('metrics', ['acc'])

        self._model_args = list(model_args) if model_args is not None else []
        self._model_kwargs = dict(model_kwargs) if model_kwargs else {}
        self._compile_kwargs = compile_kwargs
        self._fit_args = list(fit_args) if fit_args is not None else []
        self._fit_kwargs = dict(fit_kwargs) if fit_kwargs else {}

        # Kept as `None` when nothing was supplied.
        self._custom_objects = dict(custom_objects) if custom_objects else None

        self._models = None    # list of model file paths once trained/loaded
        self._info = None      # result dict produced by `run`
        self._tmp_dir = None   # TemporaryDirectory, or str path after save/load

    def save(self, base_dir='luf_experiments'):
        """Persist results, trained models and constructor parameters.

        Files are written to
        ``<base_dir>/<data._name>/<constructor>_<arch>_<trials>trials``:
        ``results.npy`` (the `run` output), ``models/loo_<i>.h5`` (moved
        from their current location) and ``init_params.npy``.

        Args:
            base_dir: Root directory for saved experiments; created,
                including missing parents, if it does not exist.

        Returns:
            Path of the experiment directory.

        Raises:
            ValueError: If `run` has not produced models and results yet.
        """
        if self._models is None or self._info is None:
            raise ValueError('call the `run` method before saving')

        exp_dir = os.path.join(
            base_dir,
            self._data._name,
            f'{self._model_con.__name__}_{self._arch}_{self._trials}trials')
        model_dir = os.path.join(exp_dir, 'models')
        os.makedirs(model_dir, exist_ok=True)

        np.save(os.path.join(exp_dir, 'results.npy'), self._info, allow_pickle=True)

        saved_models = []
        for i, model in enumerate(self._models):
            destination = os.path.join(model_dir, f'loo_{i}.h5')
            # Saving twice into the same directory must not try to move a
            # file onto itself.
            if os.path.abspath(model) != os.path.abspath(destination):
                shutil.move(model, destination)
            saved_models.append(destination)
        # Track the new locations so the experiment stays usable after
        # the files have left the temporary directory.
        self._models = saved_models
        self._tmp_dir = model_dir

        # Entries 0-8 keep their historical order (`load` reads them by
        # position); `multi` is appended so binary experiments reload
        # correctly.
        init_params = (
            self._arch,
            self._seed,
            self._trials,
            self._model_con,
            self._model_args,
            self._model_kwargs,
            self._compile_kwargs,
            self._fit_args,
            self._fit_kwargs,
            self._multi,
        )
        # Fill an object array slot by slot: handing NumPy the mixed tuple
        # directly makes it try to infer a (ragged) array shape from the
        # nested lists.
        packed = np.empty(len(init_params), dtype=object)
        for slot, value in enumerate(init_params):
            packed[slot] = value
        np.save(os.path.join(exp_dir, 'init_params.npy'), packed, allow_pickle=True)

        print(f'saved experiment to {exp_dir}')

        return exp_dir

    @staticmethod
    def load(dir, data, custom_objects=None):
        """Recreate an experiment from a directory written by `save`.

        Args:
            dir: Experiment directory (the value returned by `save`).
            data: Dataset object to attach to the restored experiment.
            custom_objects: Mapping used when the stored models are
                reloaded.

        Returns:
            A `LUFExperiment` whose models and results point at the saved
            files. `_info` holds whatever `np.load` returns for
            ``results.npy`` (NOTE(review): for a saved dict that is
            presumably a 0-d object array; callers may need `.item()`).

        Raises:
            FileNotFoundError: If the config file, the model directory or
                the results file is missing.
        """
        init_params = os.path.join(dir, 'init_params.npy')
        if not os.path.isfile(init_params):
            raise FileNotFoundError(f'could not find config file in {init_params}')

        model_dir = os.path.join(dir, 'models')
        if not os.path.isdir(model_dir):
            raise FileNotFoundError(f'could not find model directory at {model_dir}')

        results = os.path.join(dir, 'results.npy')
        if not os.path.isfile(results):
            raise FileNotFoundError(f'could not find results file at {results}')

        # SECURITY: `allow_pickle=True` unpickles arbitrary objects; only
        # load experiment directories from a trusted source.
        config = np.load(init_params, allow_pickle=True)

        num_trials = config[2]
        # Directories written before `multi` was stored hold 9 entries.
        multi = bool(config[9]) if len(config) > 9 else True

        obj = LUFExperiment(
            config[0],
            data,
            multi=multi,
            num_trials=num_trials,
            model_constructor=config[3],
            seed=config[1],
            model_args=config[4],
            model_kwargs=config[5],
            compile_kwargs=config[6],
            fit_args=config[7],
            fit_kwargs=config[8],
            custom_objects=custom_objects)

        obj._models = [os.path.join(model_dir, f'loo_{i}.h5') for i in range(num_trials)]
        obj._tmp_dir = model_dir
        obj._info = np.load(results, allow_pickle=True)

        return obj

    def _train_one_model(self, data, leave_out_index, seed, 
            architecture, model_constructor, model_args, model_kwargs,
            compile_kwargs, fit_args, fit_kwargs):
        """Train one model with a single training example removed.

        Args:
            data: Dataset object (see `__init__`).
            leave_out_index: Index into the training set to drop.
            seed: Seed applied to `random`, NumPy and TensorFlow before
                the model is built.
            architecture: Architecture string for `string_to_architecture`.
            model_constructor: Forwarded to `string_to_architecture`.
            model_args: Forwarded to `string_to_architecture`.
            model_kwargs: Forwarded to `string_to_architecture`.
            compile_kwargs: Keyword arguments for `model.compile`.
            fit_args: Extra positional arguments for `model.fit`.
            fit_kwargs: Keyword arguments for `model.fit`; not modified.

        Returns:
            Tuple ``(model, history)`` with the trained model and the
            object returned by `model.fit`.
        """
        pyrand.seed(seed)
        np.random.seed(seed)
        tf.random.set_seed(seed)

        x_tr = np.delete(data.x_tr, leave_out_index, axis=0)
        labels = data.y_tr_1hot if self._multi else data.y_tr
        y_tr = np.delete(labels, leave_out_index, axis=0)

        # Private copy: the adjustments below must not leak back into the
        # caller's dict, which is reused for every trial.
        fit_kwargs = dict(fit_kwargs)

        # Batching happens on the dataset, so `fit` must not see it again.
        batch_size = fit_kwargs.pop('batch_size', DEFAULT_BATCH_SIZE)
        train = tf.data.Dataset.from_tensor_slices((x_tr, y_tr)).batch(batch_size)

        # Build a fresh list instead of extending the caller's in place;
        # otherwise one more EarlyStopping would pile up on every trial.
        callbacks = list(fit_kwargs.get('callbacks') or [])
        if 'epochs' not in fit_kwargs:
            # No epoch budget given: train until the training loss stops
            # improving by at least `early_stop_cutoff`, capped at 1000.
            callbacks.append(
                tf.keras.callbacks.EarlyStopping(
                    monitor='loss',
                    min_delta=self._early_stop_cutoff,
                    patience=3
                ))
            fit_kwargs['epochs'] = 1000
        fit_kwargs['callbacks'] = callbacks

        # Silent by default, without clashing with a caller's `verbose`.
        fit_kwargs.setdefault('verbose', 0)

        model = string_to_architecture(
            data.input_shape,
            data.num_classes,
            architecture,
            multi=self._multi,
            model_constructor=model_constructor,
            model_args=model_args,
            model_kwargs=model_kwargs)
        model.compile(**compile_kwargs)

        history = model.fit(
            train,
            *fit_args,
            **fit_kwargs)

        return model, history

    def _train_all_models(self):
        """Train and store one model per leave-one-out index.

        Does nothing beyond returning the known paths when models already
        exist (after a previous call or after `load`).

        Returns:
            List of paths of the stored ``loo_<i>.h5`` model files.
        """
        pyrand.seed(self._seed)
        np.random.seed(self._seed)
        tf.random.set_seed(self._seed)

        if self._indices is None:
            indices = np.random.choice(np.arange(len(self._data.x_tr)), self._trials, replace=False)
        else:
            indices = self._indices

        if self._models is not None:
            return self._models

        if self._tmp_dir is None:
            self._tmp_dir = tempfile.TemporaryDirectory(dir='.')
        # `_tmp_dir` is a plain path after `save`/`load`, otherwise a
        # TemporaryDirectory object.
        if isinstance(self._tmp_dir, str):
            tmp_dir_name = self._tmp_dir
        else:
            tmp_dir_name = self._tmp_dir.name

        if self._multi:
            y_te = self._data.y_te_1hot
        else:
            y_te = self._data.y_te

        models = []

        pb = Progbar(self._trials)
        for (i, index) in enumerate(indices):
            model, history = self._train_one_model(
                self._data,
                index,
                self._seed,
                self._arch,
                self._model_con,
                self._model_args,
                self._model_kwargs,
                self._compile_kwargs,
                self._fit_args,
                dict(self._fit_kwargs))

            # Best-effort progress numbers: a custom model or history
            # object may not expose these, and that should not abort the
            # run.
            try:
                num_epochs = len(history.history['loss'])
                train_loss = history.history['loss'][-1]
                train_acc = history.history[model.compiled_metrics.metrics[0].name][-1]
            except Exception:
                num_epochs = 0
                train_loss = 0
                train_acc = 0

            res = model.evaluate(self._data.x_te, y_te, verbose=0)

            pb.add(1, [
                ('train_loss', train_loss),
                ('train_acc', train_acc),
                ('test_loss', res[0]),
                ('test_acc', res[1]),
                ('epochs', num_epochs)
            ])

            model_file = os.path.join(tmp_dir_name, f'loo_{i}.h5')
            model.save(model_file)

            models.append(model_file)
            del model

        self._models = models

        return models

    def _get_prediction_sets(self, models, x, y, batch_size=None):
        """Collect every stored model's predictions on `x`.

        Args:
            models: Paths of the stored model files.
            x: Inputs to predict on.
            y: Unused; kept for interface compatibility.
            batch_size: Prediction batch size; `DEFAULT_BATCH_SIZE` when
                `None`.

        Returns:
            Tuple ``(y_labels, y_confs)``. Per-model predictions are
            stacked along a new last axis. In the multi-class case the
            labels/confidences are the argmax/max over axis 1 of the
            softmax output; in the binary case the confidences are the
            sigmoid output and the labels are ``confidence > 0.5``.
        """
        # `predict` consumes a dataset batch by batch, so the dataset is
        # always batched.
        if batch_size is None:
            batch_size = DEFAULT_BATCH_SIZE
        data = tf.data.Dataset.from_tensor_slices(x).batch(batch_size)

        # The stored models emit logits; turn them into confidences here.
        squash = tf.nn.softmax if self._multi else tf.nn.sigmoid

        y_preds = []
        for m_file in models:
            m = tf.keras.models.load_model(m_file, custom_objects=self._custom_objects)
            y_preds.append(np.expand_dims(squash(m.predict(data)).numpy(), -1))
            del m
        y_preds = np.concatenate(y_preds, axis=-1)

        if self._multi:
            y_labels = y_preds.argmax(axis=1)
            y_confs = y_preds.max(axis=1)
        else:
            y_labels = y_preds > 0.5
            y_confs = y_preds

        return y_labels, y_confs

    @staticmethod
    def _summarize_split(labels, confs):
        """Split one dataset's points into label-stable and label-unstable."""
        # Number of distinct labels the models assign to each point.
        unique_labels = np.array([len(np.unique(y)) for y in labels])
        stable_points = np.where(unique_labels <= 1)[0]
        unstable_points = np.where(unique_labels > 1)[0]
        return {
            'labels': labels,
            'confs': confs,
            'unique_labels': unique_labels,
            'stable_points': stable_points,
            'unstable_points': unstable_points,
            'stable_confs': confs[stable_points],
            'unstable_confs': confs[unstable_points]
        }

    def run(self, eval_batch_size=128, do_save=True):
        """Train the models (if needed) and analyse prediction stability.

        Args:
            eval_batch_size: Batch size used when predicting.
            do_save: Currently unused. NOTE(review): decide whether this
                should gate the confidence dumps below or a call to
                `save`.

        Returns:
            Dict with 'train' and 'test' entries, each holding 'labels',
            'confs', 'unique_labels', 'stable_points', 'unstable_points',
            'stable_confs' and 'unstable_confs'. Also stored on the
            instance for `save`.
        """
        models = self._train_all_models()

        train_labels, train_confs = self._get_prediction_sets(
            models,
            self._data.x_tr,
            self._data.y_tr,
            batch_size=eval_batch_size)

        test_labels, test_confs = self._get_prediction_sets(
            models,
            self._data.x_te,
            self._data.y_te,
            batch_size=eval_batch_size)

        # NOTE(review): the file name interpolates the data object itself
        # (not `self._data._name` as `save` does) and the path is
        # absolute -- confirm this is the intended location.
        np.save(f"/../data/{self._data}_yp_train_loo.npy", train_confs)
        np.save(f"/../data/{self._data}_yp_test_loo.npy", test_confs)

        info = {
            'train': self._summarize_split(train_labels, train_confs),
            'test': self._summarize_split(test_labels, test_confs)
        }

        self._info = info

        return info
