import os
import jax
import torch
import time
import random
import numpy as np
import pandas as pd
from copy import deepcopy
from collections import OrderedDict
from torch.utils.tensorboard import SummaryWriter

from src.utils import util, ntk_util, batchnorm_utils, diagonality_util
from src.builders import model_builder, dataloader_builder, checkpointer_builder,\
                         optimizer_builder, criterion_builder, scheduler_builder,\
                         meter_builder, evaluator_builder
from src.utils.probability_utils import project_into_probability_simplex, calc_MI_for_pairwise_features

# JAX/XLA GPU memory settings: disable up-front preallocation and use the
# platform allocator with a 0.80 memory fraction.
# NOTE(review): these only take effect if JAX has not yet initialised its
# GPU backend; `import jax` above runs earlier -- confirm that no import
# touches a JAX device before this point.
os.environ.update({
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": '.80',
})
# Uncomment for deterministic torch initialisation:
# torch.manual_seed(0)


class BaseEngine:
    """Shared setup for engines: configuration, device selection and logging.

    On construction it loads the config file, exposes its ``model``,
    ``train``, ``eval`` and ``data`` sections (plus ``model.diagonality``
    and ``eval.standard``) as attributes, chooses a torch device, reports the
    visible GPU/CPU device counts through ``logger`` and opens a TensorBoard
    ``SummaryWriter`` under ``<save_dir>/logs``.

    Subclasses provide the actual behaviour of :meth:`run` and
    :meth:`evaluate`.
    """

    def __init__(self, config_path, logger, save_dir):
        self.logger = logger

        cfg = util.load_config(config_path)
        self.model_config = cfg['model']
        self.train_config = cfg['train']
        self.eval_config = cfg['eval']
        self.data_config = cfg['data']
        self.diagonality_config = self.model_config['diagonality']
        self.eval_standard = self.eval_config['standard']

        # Prefer CUDA whenever it is present; otherwise run on the CPU.
        has_gpu = torch.cuda.is_available()
        self.device = torch.device('cuda' if has_gpu else 'cpu')
        self.num_devices = torch.cuda.device_count()

        if has_gpu:
            self.logger.warn('GPU is available with {} devices.'.format(self.num_devices))
        else:
            self.logger.warn('GPU is not available.')
        self.logger.warn('CPU is available with {} devices.'.format(jax.device_count('cpu')))

        # TensorBoard event files go to <save_dir>/logs.
        self.save_dir = save_dir
        logs_path = os.path.join(self.save_dir, 'logs')
        os.makedirs(logs_path, exist_ok=True)
        self.writer = SummaryWriter(log_dir=logs_path)

    def run(self):
        """Run the engine's main workflow (no-op here; subclasses override)."""

    def evaluate(self):
        """Evaluate with the engine (no-op here; subclasses override)."""


class Engine(BaseEngine):
    """Concrete training / evaluation engine.

    ``run`` trains ``train_config['trials']`` independent trials, rebuilding
    every component per trial. ``evaluate`` evaluates checkpoints, optionally
    polling for new ones. NTK "diagonality" tests (``test_diagonality``) are
    interleaved with training according to ``model_config['diagonality']``.
    """

    def __init__(self, config_path, logger, save_dir):
        super().__init__(config_path, logger, save_dir)

    def _build(self, mode, init=False):
        """Build every training/evaluation component and snapshot initial weights.

        Creates dataloaders, the model (wrapped with ``util.DataParallel`` when
        more than one CUDA device is visible), optimizer, scheduler, criterion,
        loss/pr meters, evaluator and checkpointer. The result of
        ``checkpointer.load`` is stored in ``self.misc``, and the model's
        current weights are written to ``<save_dir>/checkpoint_initial.pth``.

        Args:
            mode: forwarded to ``checkpointer.load`` (this class uses
                ``'train'`` and ``'eval'``).
            init: forwarded to ``checkpointer_builder.build``; ``run`` sets it
                only for the first trial.
        """
        # Build a dataloader
        self.dataloaders = dataloader_builder.build(
            self.data_config, self.logger)

        # Build a model (deepcopy so builders cannot mutate the stored config)
        self.models = model_builder.build(
            deepcopy(self.model_config), self.data_config, self.logger)

        # Use multi GPUs if available
        if not isinstance(self.models['model'], torch.nn.DataParallel):
            if torch.cuda.device_count() > 1:
                self.models['model'] = util.DataParallel(self.models['model'])
            self.models['model'] = self.models['model'].to(self.device)

        # Build an optimizer, scheduler, criterion, meters and evaluator
        self.optimizer = optimizer_builder.build(
            self.train_config['optimizer'], self.models['model'].parameters(), self.logger)
        self.scheduler = scheduler_builder.build(
            self.train_config, self.optimizer, self.logger,
            self.train_config['num_epochs'], len(self.dataloaders['train']))
        self.criterion = criterion_builder.build(
            self.train_config, self.model_config, self.logger)
        self.loss_meter, self.pr_meter = meter_builder.build(
            self.model_config, self.logger)
        self.evaluator = evaluator_builder.build(
            self.eval_config, self.logger)

        # Build a checkpointer and load any configured checkpoint
        self.checkpointer = checkpointer_builder.build(
            self.save_dir, self.logger, self.models['model'], self.optimizer,
            self.scheduler, self.eval_standard, init=init)
        checkpoint_path = self.model_config.get('checkpoint_path', '')
        self.misc = self.checkpointer.load(
            mode=mode, checkpoint_path=checkpoint_path, use_latest=False)

        # Snapshot the current weights. When the model was wrapped for
        # multi-GPU above, save the inner module's weights.
        checkpoint_path = os.path.join(self.save_dir, 'checkpoint_initial.pth')
        model_params = {'trial': 0}
        if torch.cuda.device_count() > 1:
            model_params['state_dict'] = self.models['model'].module.state_dict()
        else:
            model_params['state_dict'] = self.models['model'].state_dict()
        torch.save(model_params, checkpoint_path)

    def run(self):
        """Train ``train_config['trials']`` (default 1) independent trials.

        Each trial rebuilds all components from scratch. Only the first trial
        passes ``init=True`` to the checkpointer builder. After each trial the
        checkpointer records that trial's results and is reset.
        """
        trials = self.train_config.get('trials', 1)
        for trial in range(trials):
            self._build(mode='train', init=(trial == 0))
            self._train(trial)
            self.checkpointer.record_results(trial)
            self.checkpointer.reset()

    def _current_lr(self):
        """Return the first parameter group's learning rate from the scheduler.

        Prefers ``get_last_lr()``. On PyTorch schedulers, calling ``get_lr()``
        outside ``step()`` emits a warning and, for chainable schedulers such
        as StepLR, can return a value other than the rate actually in use.
        Falls back to ``get_lr()`` for schedulers without ``get_last_lr``.
        """
        get_last_lr = getattr(self.scheduler, 'get_last_lr', None)
        lrs = get_last_lr() if get_last_lr is not None else self.scheduler.get_lr()
        return lrs[0]

    def _train(self, trial):
        """Train one trial for ``train_config['num_epochs']`` epochs (default 200).

        Per epoch:

        - run the diagonality test, either periodically or at the listed
          ``target_epochs``;
        - train one epoch and log the learning rate;
        - step epoch-level schedulers;
        - once past ``evaluate_after`` (a fraction of the run, default 0.8),
          evaluate and save a checkpoint.

        If the final epoch had no diagonality test, one is run with
        ``epoch + 1``. Unless ``adjust_batchnorm_stats`` is False, BatchNorm
        statistics are then recomputed over the training set, and the model
        is evaluated and checkpointed once more.

        Args:
            trial: index of the current trial, forwarded to logging,
                evaluation and the checkpointer.
        """
        start_epoch, num_steps = 0, 0
        num_epochs = self.train_config.get('num_epochs', 200)
        evaluate_after = self.train_config.get('evaluate_after', 0.8)

        self.logger.info(
            'Trial {} - train for {} epochs starting from epoch {}'.format(
                trial, num_epochs, start_epoch))

        if self.train_config.get('manual_train_control', False):
            # Debug hook: drop into an interactive shell before training.
            print('Training...')
            import IPython
            IPython.embed()

        # Start training
        for epoch in range(start_epoch, start_epoch + num_epochs):

            applied_test = False

            epoch_diff = epoch - start_epoch
            if self.diagonality_config.get('schedule') == 'periodic':
                if epoch_diff % self.diagonality_config['period'] == 0:
                    self.test_diagonality(epoch_diff)
                    applied_test = True
            else:
                if epoch_diff in self.diagonality_config['target_epochs']:
                    self.test_diagonality(epoch_diff)
                    applied_test = True

            # With training disabled, stop right after the first epoch's
            # diagonality check.
            if not self.train_config.get('needs_training', True):
                break

            train_start = time.time()
            num_steps = self._train_one_epoch(epoch, num_steps)
            torch.cuda.empty_cache()  # after train
            train_time = time.time() - train_start

            lr = self._current_lr()

            self.logger.infov(
                '[Epoch {}] with lr: {:5f} completed in {:3f} - train loss: {:4f}'
                .format(epoch, lr, train_time, self.loss_meter.avg))
            self.writer.add_scalar('Train/learning_rate', lr, global_step=num_steps)

            # 'onecycle' is stepped per batch in _train_one_epoch; all other
            # schedules are stepped once per epoch here.
            if self.train_config['lr_schedule']['name'] not in ['onecycle']:
                self.scheduler.step()

            self.loss_meter.reset()

            if epoch - start_epoch > evaluate_after * num_epochs:
                is_last_epoch = start_epoch + num_epochs == epoch + 1
                eval_metrics = self._evaluate_once(trial, epoch, num_steps, is_last_epoch)
                torch.cuda.empty_cache()  # after eval
                self.checkpointer.save(epoch, num_steps, trial, eval_metrics)
                self.logger.info(
                    '[Epoch {}] - {}: {:4f}'.format(
                        epoch, self.eval_standard, eval_metrics[self.eval_standard]))
                self.logger.info(
                    '[Epoch {}] - best {}: {:4f}'.format(
                        epoch, self.eval_standard, self.checkpointer.best_eval_metric))

            # Always measure diagonality after the last epoch, unless it was
            # already measured this epoch.
            if not applied_test and epoch - start_epoch == num_epochs - 1:
                self.test_diagonality(epoch + 1)

        if self.train_config.get('adjust_batchnorm_stats', True):
            self._adjust_batchnorm_to_population()
            eval_metrics = self._evaluate_once(trial, start_epoch + num_epochs, num_steps, True)
            self.checkpointer.save(start_epoch + num_epochs, num_steps, trial, eval_metrics)
            self.logger.info(
                'After BN adjust - {}: {:4f}'.format(self.eval_standard, eval_metrics[self.eval_standard]))
            self.logger.info(
                'After BN adjust - best {}: {:4f}'.format(self.eval_standard, self.checkpointer.best_eval_metric))

    def _train_one_epoch(self, epoch, num_steps):
        """Run one pass over the training loader.

        Args:
            epoch: epoch index, used only for loss logging.
            num_steps: running count of training samples seen so far. It is
                incremented by each batch's size, not by one per batch.

        Returns:
            ``num_steps`` plus the number of samples processed in this epoch.
        """
        self.models['model'].train()
        dataloader = self.dataloaders['train']
        num_batches = len(dataloader)
        # Log about ten times per epoch. The interval must be an integer: the
        # old float modulus `i % (num_batches / 10)` only fired at multiples
        # of a fractional interval (e.g. with 391 batches, only at i == 0).
        log_every = max(1, num_batches // 10)
        # 'onecycle' schedules are stepped once per batch.
        step_per_batch = self.train_config['lr_schedule']['name'] in ['onecycle']

        for i, input_dict in enumerate(dataloader):
            input_dict = util.to_device(input_dict, self.device)

            # Forward propagation
            self.optimizer.zero_grad()
            output_dict = self.models['model'](input_dict)

            # Compute losses (the criterion reads labels from the output dict)
            output_dict['labels'] = input_dict['labels']
            losses = self.criterion(output_dict)
            loss = losses['loss']

            # Backward propagation
            loss.backward()
            self.optimizer.step()

            # Track and periodically print the running loss
            batch_size = input_dict['inputs'].size(0)
            self.loss_meter.update(loss.item(), batch_size)
            if i % log_every == 0:
                self.loss_meter.print_log(epoch, i, num_batches)

            if step_per_batch:
                self.scheduler.step()

            num_steps += batch_size
            del input_dict

        return num_steps

    def evaluate(self):
        """Evaluate checkpoints, polling for new ones unless a file path is configured.

        Builds components in ``'eval'`` mode and evaluates the checkpoint that
        ``_build`` loaded. If ``model_config['checkpoint_path']`` is an
        existing file, stops after that evaluation. Otherwise it reloads the
        latest checkpoint each round, waiting 60 seconds whenever that
        checkpoint has already been evaluated. This loop never ends on its own.
        """
        def _get_misc_info(misc):
            infos = ['epoch', 'num_steps', 'checkpoint_path']
            return (misc[info] for info in infos)

        # _evaluate_once and checkpointer.save take a trial index, but
        # checkpoint evaluation here is not tied to a training trial.
        # NOTE(review): confirm 0 is the intended trial id in eval mode.
        trial = 0

        self._build(mode='eval')
        epoch, num_steps, current_checkpoint_path = _get_misc_info(self.misc)
        last_evaluated_checkpoint_path = None
        while True:
            if last_evaluated_checkpoint_path == current_checkpoint_path:
                self.logger.warn('Found already evaluated checkpoint. Will try again in 60 seconds.')
                time.sleep(60)
            else:
                eval_metrics = self._evaluate_once(trial, epoch, num_steps)
                last_evaluated_checkpoint_path = current_checkpoint_path
                self.checkpointer.save(epoch, num_steps, trial, eval_metrics)

            # Reload a checkpoint. Break if file path was given as checkpoint path.
            checkpoint_path = self.model_config.get('checkpoint_path', '')
            if os.path.isfile(checkpoint_path):
                break
            misc = self.checkpointer.load(
                mode='eval', checkpoint_path=checkpoint_path, use_latest=True)
            epoch, num_steps, current_checkpoint_path = _get_misc_info(misc)

    def _evaluate_once(self, trial, epoch, num_steps, is_last_epoch=False):
        """Evaluate the model on the ``'val'`` loader.

        Collects logits and labels for the whole split without gradients,
        feeds them to the evaluator in a single update, prints its log, and
        resets it.

        Args:
            trial: trial index (not used inside this method).
            epoch: epoch index used for logging.
            num_steps: sample count forwarded to ``evaluator.print_log``.
            is_last_epoch: gates a mutual-information diagnostic that is
                currently disabled.

        Returns:
            ``{self.eval_standard: self.evaluator.compute()}``.
        """
        dataloader = self.dataloaders['val']

        self.models['model'].eval()
        self.logger.info('[Epoch {}] Evaluating...'.format(epoch))
        labels = []
        outputs = []

        with torch.no_grad():
            for input_dict in dataloader:
                input_dict = util.to_device(input_dict, self.device)
                # Forward propagation
                output_dict = self.models['model'](input_dict)
                labels.append(input_dict['labels'])
                outputs.append(output_dict['logits'])
                del input_dict['inputs']

        output_dict = {
            'logits': torch.cat(outputs),
            'labels': torch.cat(labels)
        }
        # Per-batch tensors are no longer needed once concatenated.
        del outputs, labels

        # Mutual-information diagnostic; deliberately disabled by `and False`.
        if is_last_epoch and False:
            probs = project_into_probability_simplex(output_dict['logits'].detach().cpu().numpy())
            mis = calc_MI_for_pairwise_features(probs)
            print('Mutual information table:')
            print(pd.DataFrame(mis))
            print('Norm of them:', np.linalg.norm(mis))

        self.evaluator.update(output_dict)

        del output_dict
        torch.cuda.empty_cache()

        self.evaluator.print_log(epoch, num_steps)
        eval_metric = self.evaluator.compute()

        # Reset the evaluator for the next call
        self.evaluator.reset()
        return {self.eval_standard: eval_metric}

    def _forward_pass(self, partition='subset', chunk_size=1000):
        """Return the model's logits, as a numpy array, for one data partition.

        ``ntk_util.get_full_data`` materialises the inputs ``X`` for
        ``dataloaders[partition]`` from the unlabeled dataset. They are fed
        through the model in chunks of ``chunk_size`` samples under
        ``torch.no_grad()``, and the per-chunk logits are concatenated along
        axis 0.
        NOTE(review): the model's train/eval mode is not changed here.

        Args:
            partition: key into ``self.dataloaders``.
            chunk_size: number of samples per forward call (default 1000).
        """
        unlabeled_dataset = self.dataloaders['unlabeled'].dataset
        data = self.dataloaders[partition]
        X, _ = ntk_util.get_full_data(unlabeled_dataset, data)
        num_samples = X.shape[0]
        predictions = []
        with torch.no_grad():
            for start in range(0, num_samples, chunk_size):
                chunk = X[start:start + chunk_size].to(self.device)
                logits = self.models['model']({'inputs': chunk})['logits']
                predictions.append(logits.detach().cpu().numpy())
        return np.concatenate(predictions)

    def _adjust_batchnorm_to_population(self, num_passes=3):
        """Recompute BatchNorm running statistics over the training set.

        Builds a fresh DataLoader over the training dataset (no shuffle
        argument, ``drop_last=True``). It then reconfigures BN layers via
        ``batchnorm_utils.adjust_bn_layers_to_compute_populatin_stats``, runs
        ``num_passes`` gradient-free forward passes, and finally applies
        ``batchnorm_utils.restore_original_settings_of_bn_layers``.
        NOTE(review): this method does not set the model's train/eval mode.
        Running statistics only update if the batchnorm_utils helper puts the
        BN layers in training mode -- confirm.

        Args:
            num_passes: number of full passes over the training data
                (default 3).
        """
        self.logger.info('Adjusting BatchNorm statistics to population values...')

        net = self.models['model']
        train_dataset = self.dataloaders['train'].dataset
        trainloader = torch.utils.data.DataLoader(train_dataset,
                                                  batch_size=self.data_config['batch_size'],
                                                  num_workers=self.data_config['num_workers'],
                                                  drop_last=True)

        net.apply(batchnorm_utils.adjust_bn_layers_to_compute_populatin_stats)
        with torch.no_grad():
            for _ in range(num_passes):
                for input_dict in trainloader:
                    input_dict = util.to_device(input_dict, self.device)
                    net(input_dict)
        net.apply(batchnorm_utils.restore_original_settings_of_bn_layers)

        self.logger.info('BatchNorm statistics adjusted.')

    def test_diagonality(self, epoch):
        """Compute and save the NTK diagonality measurements for ``epoch``.

        Unless ``diagonality.update_model_params`` is False, first refreshes
        ``models['ntk_params']`` from the torch model via
        ``ntk_util.update_ntk_params``, honouring
        ``model_arch.bn_with_running_stats`` (default True). It then delegates
        to ``diagonality_util.compute_and_save_ntk`` with at least one device.

        Args:
            epoch: epoch label passed through to ``compute_and_save_ntk``.
        """
        bn_with_running_stats = self.model_config['model_arch'].get('bn_with_running_stats', True)

        if self.diagonality_config.get('update_model_params', True):
            self.models['ntk_params'] = ntk_util.update_ntk_params(
                self.models['ntk_params'], self.models['model'], bn_with_running_stats)

        # num_devices is the CUDA device count from __init__; with no GPU the
        # computation runs on a single device.
        device_count = max(self.num_devices, 1)
        diagonality_util.compute_and_save_ntk(self.dataloaders, self.models, epoch,
                                              self.model_config, self.data_config, self.save_dir, device_count)
