import traceback
from copy import deepcopy
from pathlib import Path

import torch
import numpy as np
from torch import optim
from torch.nn import functional as F
from tqdm import tqdm

from relnet.agent.base_agent import Agent
from relnet.io.file_paths import FilePaths
from relnet.utils.config_utils import get_device_placement
from relnet.utils.general_utils import get_memory_usage_str


class PyTorchAgent(Agent):
    NUM_BASELINE_OBJ_SAMPLES = 100

    def __init__(self):
        """Initialise the base agent and the defaults specific to this class.

        Assertions start enabled, no evaluation-history file is open yet, and
        the validation bookkeeping is reset via ``setup_step_metrics``.
        """
        super().__init__()
        # No history file is open until ``setup_histories_file`` is called.
        self.hist_out = None
        self.enable_assertions = True

        self.setup_step_metrics()

    def setup_step_metrics(self):
        self.validation_change_threshold = 0.
        self.best_validation_changed_steps = {}
        self.best_validation_losses = {}
        self.step = 0

    def setup_graphs(self, train_g_list, validation_g_list):
        self.train_g_list = train_g_list
        self.validation_g_list = validation_g_list

        self.train_initial_obj_values = self.get_baseline_obj_values(self.train_g_list, use_zeros=True)
        self.validation_initial_obj_values = self.get_baseline_obj_values(self.validation_g_list, use_zeros=True)

    def get_baseline_obj_values(self, g_list, use_zeros=False):
        all_baseline_vals = np.zeros((self.NUM_BASELINE_OBJ_SAMPLES, len(g_list)), dtype=np.float64)
        if use_zeros:
            return np.mean(all_baseline_vals, axis=0)
        raise ValueError("Not supported.")


    def save_model_checkpoints(self, model_suffix=None):
        model_path = self.get_model_path(model_suffix, init_dir=True)
        torch.save(self.net.state_dict(), model_path)

    def restore_model_from_checkpoint(self, model_suffix=None):
        model_path = self.get_model_path(model_suffix, init_dir=True)
        checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
        self.net.load_state_dict(checkpoint)

    def get_model_path(self, model_suffix, init_dir=False):
        model_dir = self.checkpoints_path / self.model_identifier_prefix
        if init_dir:
            model_dir.mkdir(parents=True, exist_ok=True)

        if model_suffix is None:
            model_path = model_dir / f"{self.algorithm_name}_agent.model"
        else:
            model_path = model_dir / f"{self.algorithm_name}_agent_{model_suffix}.model"

        return model_path

    def check_validation_loss_if_req(self, step_number, max_steps,
                                     make_action_kwargs=None,
                                     model_tag='default',
                                     save_model_if_better=True,
                                     save_with_tag=False):

        if step_number % self.validation_check_interval == 0 or step_number == max_steps:
            self.check_validation_loss(step_number, max_steps,
                                       make_action_kwargs,
                                       model_tag,
                                       save_model_if_better,
                                       save_with_tag)

    def check_validation_loss(self, step_number, max_steps,
                              make_action_kwargs=None,
                              model_tag='default',
                              save_model_if_better=True,
                              save_with_tag=False):

        if model_tag not in self.best_validation_changed_steps:
            self.best_validation_changed_steps[model_tag] = -1
            self.best_validation_losses[model_tag] = float("inf")

        validation_loss = self.log_validation_loss(step_number, model_tag, make_action_kwargs=make_action_kwargs)
        if self.log_progress: self.logger.info(
            f"<<{self.algorithm_name}>> {model_tag if model_tag != 'default' else 'model'} validation loss: {validation_loss: .4f} at step "
            f"{step_number}.")

        if (self.best_validation_losses[model_tag] - validation_loss) > self.validation_change_threshold:
            if self.log_progress: self.logger.info(
                f"rejoice! found a better validation loss for model {model_tag} at step {step_number}.")
            self.best_validation_changed_steps[model_tag] = step_number
            self.best_validation_losses[model_tag] = validation_loss

            if save_model_if_better:
                if self.log_progress: self.logger.info("saving model since validation loss is better.")
                model_suffix = model_tag if save_with_tag else None
                self.save_model_checkpoints(model_suffix=model_suffix)

    def log_validation_loss(self, step, model_tag, make_action_kwargs=None):
        validation_loss = self.predict_and_score(self.validation_g_list, make_action_kwargs)
        if self.options['log_memory_usage']:
            self.logger.info(get_memory_usage_str())
        if self.log_tf_summaries:
            try:
                import tensorflow as tf
                tf.summary.scalar(f"{model_tag}_validation_loss", validation_loss, step=step)
                self.file_writer.flush()
            except BaseException:
                if self.logger is not None:
                    self.logger.warn("caught an exception when trying to flush TF data.")
                    self.logger.warn(traceback.format_exc())

        if self.hist_out is not None:
            self.hist_out.write('%d,%s,%.6f\n' % (step, model_tag, validation_loss))
            try:
                self.hist_out.flush()
            except BaseException:
                if self.logger is not None:
                    self.logger.warn("caught an exception when trying to flush evaluation history.")
                    self.logger.warn(traceback.format_exc())

        return validation_loss

    def count_parameters(self):
        return sum(p.numel() for p in self.net.parameters() if p.requires_grad)

    def print_model_parameters(self, only_first_layer=True):
        param_list = self.net.parameters()

        for params in param_list:
            print(params.view(-1).data)
            if only_first_layer:
                break

    def setup(self, options, hyperparams):
        """Configure the agent from ``options`` after the base-class setup.

        Keys read from ``options`` (defaults apply when a key is absent):
          - ``validation_check_interval`` (default 1)
          - ``max_validation_consecutive_steps`` (default 1500)
          - ``pytorch_full_print``: if truthy, torch prints tensors in full
          - ``enable_assertions``: overrides the current value when present
          - ``model_identifier_prefix`` (default ``FilePaths.DEFAULT_MODEL_PREFIX``)
          - ``restore_model`` (default ``False``)
          - ``log_tf_summaries``: if truthy, a TensorFlow summary writer is
            created and installed as the default writer

        Directory locations are read from ``self.options['file_paths']``.
        """
        super().setup(options, hyperparams)

        self.validation_check_interval = (
            options['validation_check_interval']
            if 'validation_check_interval' in options else 1
        )
        self.max_validation_consecutive_steps = (
            options['max_validation_consecutive_steps']
            if 'max_validation_consecutive_steps' in options else 1500
        )

        if 'pytorch_full_print' in options and options['pytorch_full_print']:
            torch.set_printoptions(profile="full")

        if 'enable_assertions' in options:
            self.enable_assertions = options['enable_assertions']

        self.model_identifier_prefix = (
            options['model_identifier_prefix']
            if 'model_identifier_prefix' in options else FilePaths.DEFAULT_MODEL_PREFIX
        )
        self.restore_model = options['restore_model'] if 'restore_model' in options else False

        self.models_path = self.options['file_paths'].models_dir
        self.checkpoints_path = self.options['file_paths'].checkpoints_dir

        summaries_requested = 'log_tf_summaries' in options and options['log_tf_summaries']
        if not summaries_requested:
            self.log_tf_summaries = False
            return

        self.summaries_path = self.options['file_paths'].summaries_dir
        # Imported lazily so TensorFlow is only required when summaries are enabled.
        import tensorflow as tf

        self.log_tf_summaries = True
        run_dir = self.get_summaries_run_path()
        self.file_writer = tf.summary.create_file_writer(str(run_dir))
        self.file_writer.set_as_default()


    def get_summaries_run_path(self):
        return self.summaries_path / f"{self.model_identifier_prefix}-summaries"

    def setup_histories_file(self):
        """Start a fresh evaluation-history file and keep it open in ``self.hist_out``.

        The file name comes from ``FilePaths.construct_history_file_name`` for
        the current model prefix, inside the configured eval-histories
        directory. An existing file at that path is deleted first.

        A handle left open by an earlier call is closed before the new one is
        opened, so calling this repeatedly (e.g. once per ``train`` run) does
        not leak file descriptors.
        """
        previous = getattr(self, 'hist_out', None)
        if previous is not None and not previous.closed:
            previous.close()
        self.hist_out = None

        self.eval_histories_path = self.options['file_paths'].eval_histories_dir
        model_history_filename = self.eval_histories_path / FilePaths.construct_history_file_name(
            self.model_identifier_prefix)
        model_history_file = Path(model_history_filename)
        if model_history_file.exists():
            model_history_file.unlink()
        # Kept open for incremental appends; closed by ``finalize``.
        self.hist_out = open(model_history_filename, 'a')

    def finalize(self):
        if self.hist_out is not None and not self.hist_out.closed:
            self.hist_out.close()
        if self.log_tf_summaries:
            self.file_writer.close()



    def train(self, train_g_list, validation_g_list, max_steps, **kwargs):
        """Train ``self.net`` with Adam on an MSE loss, with validation-based early stopping.

        Each step of the outer loop makes one full pass over the batches
        yielded by ``self.load_data`` for the current ordering of the training
        graphs. After every step the ordering is reshuffled and the validation
        loss is checked (subject to ``validation_check_interval``), saving a
        checkpoint whenever it improves. Training stops early once more than
        ``max_validation_consecutive_steps`` steps have passed since the last
        improvement. On exit the weights are restored from the default
        checkpoint file.

        Args:
            train_g_list: training graphs; indexed by position and passed to
                ``graph_ds.get_gts_for_hashes`` to obtain regression targets.
            validation_g_list: graphs used for the validation loss.
            max_steps: maximum number of outer-loop steps.
            **kwargs: forwarded to every ``self.net(...)`` call.
        """
        self.setup_graphs(train_g_list, validation_g_list)
        self.train_idxes = list(range(len(train_g_list)))

        self.setup_step_metrics()
        self.setup_histories_file()

        self.setup_predictor()
        # Save an initial checkpoint so the restore at the end always has a file to load.
        self.save_model_checkpoints()

        # Validation check at step 0 (before any update); this also creates the
        # "default" entry in best_validation_changed_steps read further below.
        with torch.no_grad():
            self.check_validation_loss_if_req(self.step, max_steps, save_model_if_better=True)

        pbar = tqdm(range(1, max_steps + 1), unit='steps', disable=None, desc='main SL loop')

        optimizer = optim.Adam(self.net.parameters(), lr=self.hyperparams['learning_rate'])

        # The loop variable is the instance attribute itself, so self.step
        # always holds the current step number.
        for self.step in pbar:
            shuffled_gs = [train_g_list[idx] for idx in self.train_idxes]

            gts = self.graph_ds.get_gts_for_hashes(shuffled_gs)
            # TODO: this needs to be adapted for MLP as it's not compatible.
            bs = self.hyperparams['batch_size']

            loader = self.load_data(shuffled_gs, bs)
            for i, data_batch in enumerate(loader):
                if "mlp" in self.algorithm_name:
                    batch_preds = self.net(data_batch, **kwargs).flatten()
                else:
                    # Here the net's output is stacked into one tensor before flattening.
                    batch_preds = torch.stack(self.net(data_batch, **kwargs)).flatten()

                # NOTE(review): slicing gts by batch index assumes the loader yields
                # batches of `bs` items in the same order as shuffled_gs -- confirm
                # against load_data.
                gt_idx_start = i * bs
                gt_idx_end = (i + 1) * bs

                batch_gts = torch.FloatTensor(gts[gt_idx_start: gt_idx_end])
                if get_device_placement() == 'GPU':
                    batch_gts = batch_gts.cuda()

                optimizer.zero_grad()
                loss = F.mse_loss(batch_preds, batch_gts)
                loss.backward()
                optimizer.step()
                pbar.set_description('loss: %0.5f' % (loss))

            # Debugging aid: uncomment to dump the trainable parameters after each step.
            # for name, param in self.net.named_parameters():
            #     if param.requires_grad:
            #         print(name, param.data)


            # Reshuffle the ordering used by the next step (the first step runs
            # in the original order).
            np.random.shuffle(self.train_idxes)
            with torch.no_grad():
                self.check_validation_loss_if_req(self.step, max_steps, save_model_if_better=True)

            # Early stopping: too many steps since the validation loss last improved.
            if self.step - self.best_validation_changed_steps["default"] > self.max_validation_consecutive_steps:
                self.logger.info(f"validation loss plateaued for {self.max_validation_consecutive_steps} steps; stopping training.")
                break

        # Reload the weights from the default checkpoint file (the one written
        # on validation improvement).
        self.restore_model_from_checkpoint()
