import numpy as np
import copy
from sklearn import metrics


class train_logger:

    '''
    Accumulates per-batch predictions and losses during a phase, stores
    per-epoch metrics in ``self.logs``, and remembers the predictions of
    the best epoch seen so far.

    Intended call order per phase (inferred from the methods below; confirm
    against the training loop): ``update_phase_data`` for every batch, then
    ``compute_phase_metrics``, ``update_best_results`` / ``report``, and
    finally ``reset_phase_data``.
    '''

    def __init__(self, params):
        '''
        params: mapping of settings. Within this class only
        ``params['stop_metric']`` is read (see ``get_stop_metric``).
        '''

        self.params = params

        # epoch-level objects:
        # np.inf rather than np.Inf: the capitalised alias was removed in
        # NumPy 2.0 and raises AttributeError there.
        self.best_stop_metric = -np.inf
        self.best_epoch = -1
        self.running_loss = 0.0
        self.num_examples = 0

        # batch-level objects:
        self.temp_preds = []
        self.temp_true = []  # true labels
        # Not populated by this class, but update_best_results stores it, so
        # it must exist even before the first reset_phase_data() call.
        self.temp_indices = []
        self.temp_batch_domain_loss = []

        # output objects: logs[field][phase] -> dict (or list of best
        # predictions / labels once update_best_results has stored them).
        self.logs = {
            field: {phase: {} for phase in ('train', 'test')}
            for field in ('metrics', 'best_preds', 'gt', 'targ')
        }

    def compute_phase_metrics(self, phase, epoch, exp_mode):
        '''
        Compute and store end-of-phase metrics under
        ``self.logs['metrics'][phase][epoch]``.

        phase: phase name; 'train' additionally records the losses.
        epoch: key under which the metrics are stored.
        exp_mode: when equal to 'adaptation' (and phase is 'train') the mean
            of the collected domain losses is stored; otherwise the sentinel
            -999 is stored for the train phase.

        For non-train phases 'class_loss' and 'domain_loss' are stored as None.
        Raises ZeroDivisionError for the train phase if no examples were
        accumulated.
        '''

        # setdefault: phases other than 'train'/'test' (e.g. 'val') are not
        # pre-seeded in __init__ and would otherwise raise KeyError here.
        phase_logs = self.logs['metrics'].setdefault(phase, {})
        epoch_metrics = phase_logs[epoch] = {}

        # compute metrics w.r.t. ground truth labels:
        epoch_metrics.update(compute_metrics(self.temp_preds, self.temp_true))

        if phase == 'train':
            epoch_metrics['class_loss'] = self.running_loss / self.num_examples
            if exp_mode == 'adaptation':
                epoch_metrics['domain_loss'] = np.mean(self.temp_batch_domain_loss)
            else:
                # sentinel: no domain loss is tracked outside adaptation mode
                epoch_metrics['domain_loss'] = -999
        else:
            epoch_metrics['class_loss'] = None
            epoch_metrics['domain_loss'] = None

    def get_stop_metric(self, phase, epoch):
        '''
        Return the stored metric named by ``params['stop_metric']`` for the
        given phase and epoch. Raises KeyError if it has not been computed.
        '''
        return self.logs['metrics'][phase][epoch][self.params['stop_metric']]

    def update_phase_data(self, batch, P, phase):
        '''
        Store data from a batch for later use in computing metrics.

        batch: mapping read for 'preds_np', 'labels_np', and depending on the
            branch 'images', 'da_loss_np' and 'cs_loss_np'.
        P: mapping whose 'exp_mode' entry selects the adaptation branch.
        phase: phase name; losses are only accumulated for 'train'.
        '''

        if (P['exp_mode'] == 'adaptation') and (phase == 'train'):
            # In this branch the sample count comes from the labels rather
            # than the images (presumably only part of the batch is
            # labelled -- confirm against the data loader).
            nb_effective_samples = batch['labels_np'].shape[0]
            self.temp_batch_domain_loss.append(float(batch['da_loss_np']))
        else:
            nb_effective_samples = batch['images'].shape[0]

        preds = batch['preds_np']
        labels = batch['labels_np']
        for i in range(nb_effective_samples):
            self.temp_preds.append(preds[i, :].tolist())
            self.temp_true.append(labels[i])
        self.num_examples += nb_effective_samples

        if phase == 'train':
            # Weight the batch loss by its size so that dividing by
            # num_examples later yields a per-example average.
            self.running_loss += float(batch['cs_loss_np'] * nb_effective_samples)

    def reset_phase_data(self):

        '''
        Reset for a new phase.
        '''

        # Rebind to fresh lists instead of clearing in place:
        # update_best_results keeps references to the previous lists.
        self.temp_preds = []
        self.temp_true = []
        self.temp_indices = []
        self.temp_batch_domain_loss = []
        self.running_loss = 0.0
        self.num_examples = 0  # int, consistent with __init__

    def _store_best(self, phase):
        '''
        Record the current phase buffers as the best results for ``phase``.
        '''
        self.logs['best_preds'][phase] = self.temp_preds
        self.logs['gt'][phase] = self.temp_true
        # 'idx' is not pre-seeded in __init__, so create it on first use.
        self.logs.setdefault('idx', {})[phase] = self.temp_indices

    def update_best_results(self, phase, epoch):

        '''
        Update the current best epoch info if applicable.

        For 'val', compare this epoch's stop metric against the best so far
        and, if strictly greater, record it together with the phase buffers.
        For 'test', store the phase buffers only when ``epoch`` is the
        current best epoch. Any other phase is a no-op.

        Returns True only when a new best was found in the 'val' phase.
        '''

        if phase == 'val':
            cur_stop_metric = self.get_stop_metric(phase, epoch)
            if cur_stop_metric > self.best_stop_metric:
                self.best_stop_metric = cur_stop_metric
                self.best_epoch = epoch
                self._store_best(phase)
                return True  # new best found
            return False  # new best not found

        if phase == 'test' and epoch == self.best_epoch:
            self._store_best(phase)
        return False

    def get_logs(self):

        '''
        Return a deep copy of all log data.
        '''

        return copy.deepcopy(self.logs)

    def report(self, t_i, t_f, phase, epoch, lmd=None):
        '''
        Print a one-line summary of the stored metrics for (phase, epoch).

        t_i, t_f: start and end timestamps; their difference is divided by 60
            and printed as minutes (so presumably seconds -- confirm).
        lmd: optional value printed as 'lamda' for the train phase; when
            given, the domain loss is printed as well.
        '''
        elapsed_min = (t_f - t_i) / 60.0
        epoch_metrics = self.logs['metrics'][phase][epoch]

        if phase != 'train':
            report = '[{}] time: {:.2f} min, ACC: {:.3f}'.format(
                phase,
                elapsed_min,
                epoch_metrics['acc'],
            )
        elif lmd is None:
            report = '[{}] time: {:.2f} min, CE: {:.5f}, ACC: {:.3f}'.format(
                phase,
                elapsed_min,
                epoch_metrics['class_loss'],
                epoch_metrics['acc'],
            )
        else:
            report = '[{}] time: {:.2f} min, CE: {:.5f}, DA: {:.4f}, ACC: {:.3f}, lamda: {:.3f}'.format(
                phase,
                elapsed_min,
                epoch_metrics['class_loss'],
                epoch_metrics['domain_loss'],
                epoch_metrics['acc'],
                lmd,
            )

        print(report, flush=True)


def compute_metrics(y_pred, y_true):

    '''
    Compute classification metrics from per-example score vectors.

    y_pred: sequence of per-example vectors; the predicted class of each
        example is the argmax along axis 1.
    y_true: sequence of labels, passed to the accuracy computation as-is
        after conversion to an array.

    Returns a dict with a single key, 'acc', holding the value returned by
    sklearn's ``metrics.accuracy_score``.
    '''

    scores = np.array(y_pred)
    labels = np.array(y_true)

    # collapse each score vector to the index of its largest entry
    predicted_classes = np.argmax(scores, axis=1)

    return {'acc': metrics.accuracy_score(labels, predicted_classes)}