import tensorflow as tf
import numpy as np

from metrics.osr import evaluation
from metrics.metrics import weights_convergence_metric

class Model:
    """Custom distributed training / evaluation loop around a Keras-style model.

    Each step runs through ``parallel_strategy``. Per-epoch train, validation
    and test results, plus optional open-set recognition scores, are recorded
    in ``self.history``.

    Args:
        model: model called as ``model(images, training=...)``; its
            ``losses``, ``trainable_variables`` and ``get_layer`` are used.
        optimizer: optimizer whose ``apply_gradients`` is called each step.
        loss_helper: project object providing ``loss(labels, predictions)``
            (used as a per-example loss), ``predicted_class(predictions)``,
            ``distance_based``, ``nb_classes`` and ``nb_features``.
        metrics: re-iterable collection (e.g. a list) of metric objects with
            ``name``, ``update_state``, ``result`` and ``reset_states``.
            ``result()`` may return a tensor or a dict of tensors.
        parallel_strategy: distribution strategy used through ``scope``,
            ``run``, ``reduce`` and ``experimental_local_results``.
        global_batch_size: batch size summed over all replicas; used to
            average the per-example loss.
        nb_batches: expected number of training batches per epoch (the
            progress bar target).
        verbose: progress bar verbosity.
    """

    def __init__(self, 
                 model, 
                 optimizer, 
                 loss_helper, 
                 metrics, 
                 parallel_strategy, 
                 global_batch_size, 
                 nb_batches, 
                 verbose=1
            ):
        self.model = model
        self.optimizer = optimizer
        self.loss_helper = loss_helper
        self.parallel_strategy = parallel_strategy
        self.nb_batches = nb_batches
        self.verbose = verbose

        # Loss function and loss trackers are created under the strategy
        # scope so they can be used from replica context.
        with self.parallel_strategy.scope():
            def compute_loss(labels, predictions, model_losses):
                """Return ``(total_loss, regularization_loss)`` for one replica.

                The per-example loss is averaged over ``global_batch_size``
                (not the per-replica batch), so that summing the result over
                replicas gives the global-batch average.
                """
                per_example_loss = self.loss_helper.loss(labels, predictions)
                loss = tf.nn.compute_average_loss(
                    per_example_loss, 
                    global_batch_size=global_batch_size
                )
                if model_losses:
                    # Regularization losses are scaled by the number of
                    # replicas for the same reason.
                    ml = tf.nn.scale_regularization_loss(
                        tf.add_n(model_losses)
                    )
                    loss += ml # TODO : test difference when not added to loss
                else:
                    ml = 0.0
                return loss, ml
            self.compute_loss = compute_loss

            self.val_loss = tf.keras.metrics.Mean(name='val_loss')
            self.model_losses = tf.keras.metrics.Mean(name='model_losses')

        self.metrics = metrics
        self.__init_history()


    #--------------------------------------------------------------------------#
    # Metrics

    def update_metrics(self, y_true, y_pred):
        """Update every metric with ``y_true`` and the classes predicted from ``y_pred``."""
        if not self.metrics:
            return
        # The predicted class does not depend on the metric: compute it once.
        pred_label = self.loss_helper.predicted_class(y_pred)
        for metric in self.metrics:
            metric.update_state(y_true, pred_label)

    def reset_metrics(self):
        """Reset the state of every metric."""
        for metric in self.metrics:
            metric.reset_states()

    def get_metrics(self):
        """Return ``{metric name: result}``; each result is a tensor or a dict of tensors."""
        return {m.name: m.result() for m in self.metrics}

    @staticmethod
    def _to_numpy(value):
        """Convert a tensor to numpy; return any other value unchanged."""
        return value.numpy() if hasattr(value, "numpy") else value

    @classmethod
    def _metrics_to_numpy(cls, metrics_dict):
        """Convert a ``get_metrics()`` result (tensors / dicts of tensors) to numpy."""
        return {
            name: ({k: cls._to_numpy(v) for k, v in val.items()}
                   if isinstance(val, dict) else cls._to_numpy(val))
            for name, val in metrics_dict.items()
        }

    def __init_history(self):
        """Create empty per-epoch history lists for every tracked quantity.

        For every metric, ``<name>``, ``val_<name>`` and ``test_<name>`` keys
        are created. Dict-valued metrics get one list per inner key.
        """
        self.history = {}
        for key, val in self.get_metrics().items():
            for prefix in ("", "val_", "test_"):
                self.history[prefix + key] = (
                    {k: [] for k in val} if isinstance(val, dict) else []
                )

        for key in ("loss", "val_loss", "test_loss", "model_losses",
                    "weights_metric",
                    "real_auroc", "max_val_auroc", "oscr"):
            self.history[key] = []

    def __update_history(self, 
                         train_loss, train_metrics, 
                         val_loss, val_metrics, 
                         test_loss, test_metrics,
                         weights_metric,
                         osr_results):
        """Append one epoch of results to ``self.history``.

        ``weights_metric`` and ``osr_results`` are only recorded when not
        None. ``osr_results`` must contain the keys ``real_auroc``,
        ``max_val_auroc`` and ``oscr``.
        """
        self.history['loss'].append(train_loss)
        self.history['val_loss'].append(val_loss)
        self.history['test_loss'].append(test_loss)
        self.history['model_losses'].append(self.model_losses.result().numpy())
        if weights_metric is not None:
            self.history['weights_metric'].append(list(weights_metric.numpy()))

        if osr_results is not None:
            for key in ('real_auroc', 'max_val_auroc', 'oscr'):
                self.history[key].append(osr_results[key])

        for prefix, metrics in (("", train_metrics),
                                ("val_", val_metrics),
                                ("test_", test_metrics)):
            for key, val in metrics.items():
                if isinstance(val, dict):
                    for k, v in val.items():
                        self.history[prefix + key][k].append(v)
                else:
                    self.history[prefix + key].append(val)


    #--------------------------------------------------------------------------#
    # Training
    def train_step(self, inputs):
        """Run forward/backward on one replica's ``(images, labels)`` batch.

        Returns this replica's loss, already divided by the global batch size.
        """
        images, labels = inputs

        with tf.GradientTape() as tape:
            predictions = self.model(images, training=True)
            loss, ml = self.compute_loss(labels, predictions, self.model.losses)

        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(
            zip(gradients, self.model.trainable_variables)
        )

        self.model_losses.update_state(ml)
        self.update_metrics(labels, predictions)
        return loss

    @tf.function
    def distributed_train_step(self, dataset_inputs):
        """Run ``train_step`` on every replica and return the summed loss.

        Each replica's loss is already divided by the global batch size, so
        the SUM over replicas is the global-batch average.
        """
        per_replica_losses = self.parallel_strategy.run(
            self.train_step, 
            args=(dataset_inputs,)
        )
        # TODO: reduce metrics if dict ?
        return self.parallel_strategy.reduce(
            tf.distribute.ReduceOp.SUM, 
            per_replica_losses, 
            axis=None
        )

    #--------------------------------------------------------------------------#
    # Evaluation
    def test_step(self, inputs):
        """Evaluate one replica batch, updating ``val_loss`` and the metrics."""
        images, labels = inputs

        predictions = self.model(images, training=False)
        t_loss = self.loss_helper.loss(labels, predictions)

        self.val_loss.update_state(t_loss)
        self.update_metrics(labels, predictions)

    @tf.function
    def distributed_test_step(self, dataset_inputs):
        """Run ``test_step`` on every replica."""
        return self.parallel_strategy.run(
            self.test_step, 
            args=(dataset_inputs,)
        )

    def test(self, ds_test_dict):
        """Evaluate on a dataset and return ``(loss, metrics)`` as numpy values.

        Metrics and ``val_loss`` are reset first, so the results cover only
        this dataset.
        """
        self.reset_metrics()
        self.val_loss.reset_states()

        for x in ds_test_dict:
            self.distributed_test_step(x)

        metrics = self._metrics_to_numpy(self.get_metrics())
        return self.val_loss.result().numpy(), metrics

    #--------------------------------------------------------------------------#
    # Prediction
    def predict_step(self, inputs, compute_metrics=False):
        """Predict one replica batch; optionally update ``val_loss`` and the metrics."""
        images, labels = inputs
        predictions = self.model(images, training=False)

        if compute_metrics:
            t_loss = self.loss_helper.loss(labels, predictions)

            self.val_loss.update_state(t_loss)
            self.update_metrics(labels, predictions)

        return predictions

    @tf.function
    def distributed_prediction_step(self, inputs, compute_metrics=False):
        """Run ``predict_step`` on every replica and return per-replica local results.

        ``compute_metrics`` is a Python bool, so ``tf.function`` traces a
        separate graph for each value.
        """
        predictions = self.parallel_strategy.run(
            self.predict_step,
            args=(inputs, compute_metrics)
        )
        return self.parallel_strategy.experimental_local_results(predictions)

    def predict(self, ds, compute_metrics=False):
        """Return the predictions for every batch of ``ds``, concatenated along axis 0.

        When ``compute_metrics`` is True, metrics and ``val_loss`` are reset
        first and then accumulated over ``ds``.
        """
        if compute_metrics:
            self.reset_metrics()
            self.val_loss.reset_states()

        all_predictions = []
        for x in ds:
            # One entry per local replica for this batch.
            all_predictions.extend(
                self.distributed_prediction_step(x, compute_metrics)
            )

        return tf.concat(all_predictions, axis=0).numpy()

    #--------------------------------------------------------------------------#
    # Training loop helpers
    @staticmethod
    def _progbar_values(metrics):
        """Flatten a metrics dict into ``(name, value)`` pairs for a Progbar.

        Dict-valued metrics contribute their inner ``(key, value)`` pairs
        directly, without the outer metric name.
        """
        values = []
        for name, val in metrics.items():
            if isinstance(val, dict):
                values.extend(val.items())
            else:
                values.append((name, val))
        return values

    def _weights_convergence_metric(self):
        """Compute and print the convergence metric of the ``last_conv`` weights.

        Returns None when the loss is not distance based or when the model
        has no layer named ``last_conv``.
        """
        if not self.loss_helper.distance_based:
            return None
        try:
            layer = self.model.get_layer("last_conv")
        except ValueError:
            # Keras raises ValueError for an unknown layer name.
            return None

        weights = layer.get_weights()[0]
        weights_metric = weights_convergence_metric(weights,
                                                    self.loss_helper.nb_classes,
                                                    self.loss_helper.nb_features)
        print("Weights convergence metric (std):", weights_metric.numpy())
        return weights_metric

    #--------------------------------------------------------------------------#
    # Training loop
    def train(self, datasets, epochs):
        """Train for ``epochs`` epochs, evaluating after each one.

        Args:
            datasets: dict with ``"ds_train_known"`` and ``"ds_test_known"``
                (required), plus ``"ds_val_known"`` and ``"ds_test_unknown"``
                (optional: missing or None skips validation / open-set
                evaluation).
            epochs: number of epochs.

        Returns:
            ``self.history``, extended with one entry per epoch.
        """
        ds_train_dist = datasets["ds_train_known"]
        ds_val_known = datasets.get("ds_val_known")
        ds_test_known = datasets["ds_test_known"]
        ds_test_unknown = datasets.get("ds_test_unknown")

        for epoch in range(epochs):
            tf.print("Learning rate:", self.optimizer.learning_rate)

            # Reset the metrics at the start of the next epoch
            self.reset_metrics()
            self.val_loss.reset_states()
            self.model_losses.reset_states()

            total_loss = 0.0
            n_batch = 0

            print(f'Epoch {epoch + 1:>2}/{epochs:<3}')
            progbar = tf.keras.utils.Progbar(
                self.nb_batches,
                stateful_metrics=['loss', 'val_loss'],
                verbose=self.verbose,
            )

            for x in ds_train_dist:
                n_batch += 1
                total_loss += self.distributed_train_step(x)
                progbar.update(
                    n_batch,
                    values=[('loss', (total_loss / tf.cast(n_batch, dtype=tf.float32)))],
                    finalize=False,
                )

            train_loss = self._to_numpy(
                total_loss / tf.cast(n_batch, dtype=tf.float32)
            )
            train_metrics = self._metrics_to_numpy(self.get_metrics())
            print(train_metrics)

            prog_values = self._progbar_values(train_metrics)
            weights_metric = self._weights_convergence_metric()

            prog_values.insert(0, ('loss', train_loss))
            progbar.update(n_batch, values=prog_values, finalize=True)

            # Validation (test() resets metrics and val_loss itself)
            val_loss, val_metrics = None, {}
            if ds_val_known is not None:
                val_loss, val_metrics = self.test(ds_val_known)

                print("Validation results :")
                print(f"\tLoss: {val_loss}\n\t", end="")
                for k, v in val_metrics.items():
                    print(f"{k}: {v}", end=" ")
                print()

            # Test. predict(compute_metrics=True) resets metrics and val_loss.
            # Predicting on unknowns leaves them untouched, so the loss and
            # metrics below describe the known test set only.
            pred_known = self.predict(ds_test_known, compute_metrics=True)
            pred_unknown = None
            if ds_test_unknown is not None:
                pred_unknown = self.predict(ds_test_unknown)

            test_loss = self.val_loss.result().numpy()
            test_metrics = self._metrics_to_numpy(self.get_metrics())

            print("Test results :")                
            print(f"\tLoss (on known):{test_loss}\n\t", end="")
            for k, v in test_metrics.items():
                print(f"{k}: {v}", end=" ")
            print()

            results = None
            if ds_test_unknown is not None:
                # Labels are only needed for open-set evaluation; collect
                # them lazily to avoid an extra pass over the test set.
                known_labels = np.array(
                    [y.numpy() for _, y in ds_test_known.unbatch()]
                )
                results = evaluation(pred_known, pred_unknown, known_labels,
                                     self.loss_helper)

            self.__update_history(train_loss, train_metrics, 
                                  val_loss, val_metrics, 
                                  test_loss, test_metrics, 
                                  weights_metric,
                                  results)
            print()

        return self.history
