"""Training objects to handle training, checkpointing, and logging/visualizing metrics
"""
import os
import glob
import time
import wandb
import torch
import numpy as np
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm
from shutil import copy
from abc import abstractmethod
from .evaluation import evaluate_solver
from .utils import logger, model_encoding_str, next_highest_OM 
from .visualization import (
    visualize_solver_heatmap, 
    visualize_solver_phase_space,
    visualize_residual_heatmap,
    visualize_residual_phase_space,
)


class Trainer:
    """Base class handling the shared training, checkpointing, and logging logic.

    Child classes implement `step` (one epoch of optimization that returns a dict of
    metric values keyed by metric name) and may override `postprocess`.
    """

    def __init__(
        self,
        model,
        train_dataset,
        test_dataset,
        hparams
    ):
        """Set up the device, optimizer, dataloaders, output directories, and metric storage.

        Args:
            model: network to train. It is moved to the selected device and this class
                uses its `order` attribute and its `save_model` / `load_model` methods.
            train_dataset: training dataset; must provide `dataloader(...)` and `operator`.
            test_dataset: test dataset; must provide `dataloader(...)`.
            hparams: configuration object read through attribute access (e.g. `epochs`,
                `batch_size`, `skip`, `loss`, `optimizer`, `checkpoints`,
                `checkpoint_interval`, `save_best`, `correct_best`).

        Raises:
            AssertionError: if some checkpoint epoch `c` does not satisfy
                `hparams.epochs - c > hparams.skip`, or if `correct_best` is set
                without `save_best`.
        """
        self.hparams = hparams
        self.device = (
            f"cuda:{hparams.device_id}" if torch.cuda.is_available()
            else "cpu"
        )

        # TODO: maybe handle distributed training
        self.model = model.to(self.device)
        # unwrap the underlying module when the model is wrapped (exposes `.module`)
        self.module = model.module if hasattr(self.model, "module") else model
        logger.info(f"*** Training order-{self.model.order} model on {self.device} ***")

        # the optimizer class is looked up by name in `torch.optim`
        self.optim = getattr(optim, hparams.optimizer.name)(
            self.module.parameters(),
            **hparams.optimizer.params
        )

        # for saving figures, metrics, and models
        self.prefix = hparams.model.prefix
        self.enc = model_encoding_str(hparams)
        self.details = (
            f"{self.prefix}_"  # arbitrary string identifier
            f"bs-{hparams.batch_size}_"  # batchsize
            f"epochs-{hparams.epochs}_"  # epochs
            f"skip-{hparams.skip}"  # validation interval
        )

        # init dataloaders
        self.dataset = train_dataset
        self.test_dataset = test_dataset
        self.init_data()

        # characteristic of the diffeq
        self.operator = self.dataset.operator

        # initialize save directories
        self.init_dirs()

        # for keeping track of metrics
        metrics = [
            hparams.loss,  # visible during training
            'RMSE',  # TESTING METRIC: CANNOT VIEW DURING TRAINING
            'MAE',  # TESTING METRIC: CANNOT VIEW DURING TRAINING
            'Max-AE',  # TESTING METRIC: CANNOT VIEW DURING TRAINING
            'Relative Error',  # TESTING METRIC: CANNOT VIEW DURING TRAINING
        ]

        # NOTE: max-RMSR IS visible during training - computed during each checkpoint
        if self.hparams.compute_error_magnitude:
            metrics.append('Max-RMSR')
        self.metrics = {m: [] for m in metrics}

        # checkpointing and saving stuff
        self.saved_epoch = 0  # for logging when best model was saved
        self.checkpoint_epochs = [c for c in hparams.checkpoints]
        self.checkpoint_interval = hparams.checkpoint_interval
        self.valid_epochs = len(self.checkpoint_epochs) > 0
        self.valid_interval = self.checkpoint_interval > 0
        if self.valid_epochs:
            logger.info(f"*** Checkpointing at epochs {self.checkpoint_epochs}/{hparams.epochs} ***")
        elif self.valid_interval:
            logger.info(f"*** Checkpointing model every {self.checkpoint_interval} epochs ***")

        # check params
        assert all(hparams.epochs - c > hparams.skip for c in self.checkpoint_epochs), (
            f"Cannot skip {hparams.skip} epochs/iterations for checkpoints {self.checkpoint_epochs}"
        )
        assert (hparams.save_best if hparams.correct_best else True), (
            "`save_best` must be True if `correct_best` is set to True"
        )

    def _order_prefix(self):
        """Return the `order-<k>_` file-name prefix for order > 0 models, else an empty string."""
        return f"order-{self.model.order}_" if self.model.order > 0 else ""

    def init_data(self):
        """Initialize dataloaders (subject to change via splitting one single dataset)
        """
        self.trainloader = self.dataset.dataloader(
            batch_size=self.hparams.batch_size,
            shuffle=True,
            pin_memory=True
        )
        self.testloader = self.test_dataset.dataloader(
            batch_size=self.hparams.batch_size,
            pin_memory=True
        )

    def init_dirs(self):
        """Create (if missing) the directories that models, metrics, and figures are saved to
        """
        self.save_dir = f"saved_models/{self.enc}"
        self.metric_dir = f"metrics/{self.enc}"
        self.figure_dir = f"figures/{self.enc}"

        os.makedirs(self.save_dir, exist_ok=True)
        os.makedirs(self.metric_dir, exist_ok=True)
        os.makedirs(self.figure_dir, exist_ok=True)

    def train(self):
        """Handles base training logic. `finish_training` call at the end is flexible

        For every epoch: runs `step`, optionally saves the model when the training loss
        (`hparams.loss` entry of the returned metrics) improves, stores the returned
        metrics, and calls `checkpoint`.
        """
        # track gradients of model
        if self.hparams.wandb.log:
            wandb.watch(self.model)

        begin_msg = f"{38 * '='} [BEGIN TRAINING] {38 * '='}"
        logger.info(begin_msg)

        best_loss = float('inf')
        start_time = time.time()
        for epoch in range(self.hparams.epochs):
            # set training mode
            self.model.train()

            # run the training function
            val_metrics = self.step(epoch)

            # save the best model based on DE residuals
            if self.hparams.save_best and val_metrics[self.hparams.loss] < best_loss:
                self.model.save_model(self.save_dir, f"{self.details}.pt")
                self.saved_epoch = epoch
                best_loss = val_metrics[self.hparams.loss]

            # store val metrics for plotting
            for m in val_metrics:
                self.metrics[m].append(val_metrics[m])

            # checkpoint the model
            self.checkpoint(epoch)

        finish_msg = f"{30 * '='} [FINISHED TRAINING ({time.time() - start_time:.2e}s)]"
        # pad with '=' so the closing banner lines up with the opening one
        logger.info(f"{finish_msg} {(len(begin_msg) - len(finish_msg) - 1) * '='}")
        self.finish_training()

    @abstractmethod
    def step(self, epoch):
        """Performs one epoch of training. Must be implemented by child classes

        Expected to return a dict of metrics; `train` reads the `hparams.loss` entry
        when `save_best` is enabled and appends every entry to `self.metrics`.
        """
        raise NotImplementedError

    @abstractmethod
    def postprocess(self, *args):
        """Custom user function to do post processing after finishing training
        """
        pass

    def checkpoint(self, epoch):
        """Handles saving and checkpointing logic during training

        A checkpoint is written when `epoch` is one of the explicit checkpoint epochs,
        or when interval checkpointing is enabled and `epoch` falls on the interval
        (the last epoch is counted as `epoch + 1`). With `correct_best` the best model
        saved so far is copied; otherwise the current model is saved.
        """
        epoch_checkpoint = self.valid_epochs and epoch in self.checkpoint_epochs

        # only take the modulo when the interval is positive: an interval of 0 would
        # otherwise raise ZeroDivisionError even though interval checkpointing is off
        is_last_epoch = epoch == self.hparams.epochs - 1
        interval_checkpoint = (
            self.valid_interval
            and (epoch + is_last_epoch) % self.checkpoint_interval == 0
        )

        if not (epoch_checkpoint or interval_checkpoint):
            return

        chkpt_model_name = f"checkpoint-{epoch}_{self.details}.pt"

        if self.hparams.correct_best:  # checkpoint best model up until this point
            # same naming scheme as the path `finish_training` loads the best model from
            order_prefix = self._order_prefix()
            model_name = f"{order_prefix}{self.details}.pt"
            chkpt_name = f"{order_prefix}{chkpt_model_name}"
            copy(f"{self.save_dir}/{model_name}", f"{self.save_dir}/{chkpt_name}")
            model_id = "N" if self.model.order == 0 else f"N_e{self.model.order}"
            logger.info(f"Saved best {model_id} (epoch {self.saved_epoch}): {self.save_dir}/{chkpt_name}")
        else:  # checkpoint current model at this point
            self.model.save_model(self.save_dir, chkpt_model_name, logger)

    def log_metrics(self, metrics, epoch):
        """Handle logging to stdout and wandb

        Logs to stdout every `hparams.skip` epochs and on the final epoch; sends the
        metrics to wandb on every call when `hparams.wandb.log` is set.
        NOTE: `metrics` is modified in place ("Best Epoch" is added when logging to stdout
        with `save_best` enabled).
        """
        # end at [epochs] instead of [epochs-1] for cosmetics)
        if epoch % self.hparams.skip == 0 or epoch == self.hparams.epochs - 1:

            # log the epoch during which the current BEST model was saved
            if self.hparams.save_best:
                metrics["Best Epoch"] = self.saved_epoch

            metrics_str = ' | '.join(f"{k} = {v:.5g}" for k, v in metrics.items())
            logger.info(f"Iteration {epoch + (epoch == self.hparams.epochs - 1):6d}: {metrics_str}")

        # send data to weights & biases
        if self.hparams.wandb.log:
            wandb.log(metrics, step=epoch)

    def save_metrics(self):
        """Save training and validation metrics to an `.npz` file in the metric directory
        """
        save_path = f"{self.metric_dir}/{self._order_prefix()}{self.details}.npz"
        np.savez(save_path, **self.metrics)
        logger.info(f"Saved metrics: {save_path}")

    def evaluate(self, checkpoint=False):
        """Evaluate model on test set

        Args:
            checkpoint: only changes the log message (label and trailing newline).

        Returns:
            The `outputs` object returned by `evaluate_solver`.
        """
        metrics, outputs = evaluate_solver(self.model, self.testloader, self.device)

        # load max-RMSR from saved metrics file
        if self.hparams.compute_error_magnitude:
            # ASSUMING that it will get sorted by order, seems like it does, didnt test
            metrics_files = sorted(glob.glob(f"{self.metric_dir}/*{self.prefix}*"))
            if metrics_files:
                # context manager closes the archive; membership is checked via `.files`
                # so this does not rely on `NpzFile.get`
                with np.load(metrics_files[0]) as loaded_metrics:
                    if 'Max-RMSR' in loaded_metrics.files:
                        max_rmsr = loaded_metrics['Max-RMSR']
                        if len(max_rmsr) > 0:  # only report if data exists for this metric
                            metrics['Max-RMSR'] = max_rmsr[0]

        metrics_str = '\n--->| '.join(f"{k} = {v}" for k, v in metrics.items()) + (not checkpoint) * "\n"
        logger.info(f"{'Checkpoint' if checkpoint else 'Test'} results: \n--->| {metrics_str}")
        return outputs

    def visualize(self, predictions):
        """Visualize final model on test set

        Plots a heatmap (plus the groundtruth) for datasets with more than one
        dimension and a phase-space plot for one-dimensional datasets.
        NOTE: overwrites `self.hparams.figures.title`.
        """
        save_prefix = f"{self.figure_dir}/{self._order_prefix()}{self.details}"

        # LaTeX title: N for the base model, N_{eps_k} for each error model
        # (raw strings keep the backslashes without invalid escape sequences)
        models = [rf"_{{\epsilon_{o}}}" if o > 0 else "" for o in range(1 + self.model.order)]
        self.hparams.figures.title = '+'.join(r"$\mathcal{N}" + m + "$" for m in models)

        # special case for the nPBE dataset
        if self.hparams.data.name == 'nPBE':
            predictions = predictions[:self.test_dataset.N**2]

        if self.test_dataset.dims > 1:
            visualize_solver_heatmap(  # plot prediction
                predictions, self.test_dataset, save_prefix, self.hparams.figures
            )
            gt_data = self.test_dataset.sol_data.reshape(self.test_dataset.domain_shape)
            self.hparams.figures.title = "Groundtruth"
            visualize_solver_heatmap(  # plot groundtruth
                gt_data, self.test_dataset, f"{self.figure_dir}/groundtruth", self.hparams.figures
            )
        elif self.test_dataset.dims == 1:
            visualize_solver_phase_space(
                predictions, self.test_dataset, save_prefix, self.hparams.figures
            )

    def finish_training(self):
        """Post-training routines, subject to customization

        Reloads the best saved model (or saves the final one), then saves metrics,
        evaluates on the test set, visualizes the predictions, and calls `postprocess`.
        """
        if self.hparams.save_best:
            model_path = f"{self.save_dir}/{self._order_prefix()}{self.details}.pt"
            self.model.load_model(self.model.order, model_path, logger)
        else:  # wasnt saving best model
            self.model.save_model(self.save_dir, f"{self.details}.pt", logger)
        # save training metrics
        self.save_metrics()
        # evaluate on test set
        outputs = self.evaluate()
        # visualize stuff
        self.visualize(outputs['prediction'])
        # custom user-defined post training stuff
        self.postprocess()


class OperatorTrainer(Trainer):
    """Trainer that regresses the differential operator residuals (plus boundary and
    initial condition residuals, when the dataset has them) towards zero.
    """

    def __init__(self, model, train_dataset, test_dataset, hparams, load_checkpoint=None):
        """Initialize the base trainer, optionally load a checkpoint, and pick the loss.

        Args:
            model, train_dataset, test_dataset, hparams: forwarded to the parent constructor.
            load_checkpoint: optional sequence of positional arguments that is unpacked
                into `model.load_model(...)` (followed by the logger).

        Raises:
            KeyError: if `hparams.loss` is not one of 'L1', 'MSE', 'Huber'.
        """
        super().__init__(model, train_dataset, test_dataset, hparams)

        # handle loading checkpoint if specified
        if load_checkpoint is not None:
            self.model.load_model(*load_checkpoint, logger)

        # choose regression loss function
        self.loss_fn = {
            'L1': F.l1_loss,
            'MSE': F.mse_loss,  # mean square residual
            'Huber': F.smooth_l1_loss
        }[hparams.loss]

        if hparams.compute_error_magnitude and model.order > 0:
            logger.info(f"*** Computed error model magnitudes: {self.model.mags[:model.order]} ***")

    def step(self, epoch):
        """Train the model over one pass of the data

        Returns:
            dict with the mean training loss under the `hparams.loss` key, the test
            metrics from `evaluate_solver` on logging epochs, and 'Max-RMSR' on
            checkpoint epochs when `compute_error_magnitude` is enabled.
        """
        running_loss = 0.0
        max_residuals = []
        for x in tqdm(self.trainloader,
                      desc=f"Training for epoch {epoch + 1:4d}",
                      total=len(self.trainloader),
                      leave=False):
            x = x.to(self.device)

            # forward pass (regress on PDE operator value) and constraints
            estimate = self.model(x)
            operator_residuals = self.operator(x, estimate)

            # store max operator residual only on the epoch right before an error correction is performed
            if epoch in self.checkpoint_epochs and self.hparams.compute_error_magnitude:
                max_residuals.append(
                    torch.sqrt(operator_residuals.detach().pow(2).mean(dim=-1)).max()
                )

            informed_loss = self.loss_fn(operator_residuals, torch.zeros_like(operator_residuals))
            constraint_loss = 0.0

            # boundary condition error (skipped when fewer than ~1 point would be sampled)
            if self.dataset.has_bcs and len(x) * self.hparams.bc_prop > 1:
                boundary_points = self.dataset.sample_boundary(
                    int(len(x) * self.hparams.bc_prop)
                ).to(self.device)
                boundary_estimate = self.model(boundary_points)
                epsilon_bc = self.dataset.boundary_condition(boundary_points, boundary_estimate)
                bnd_loss = self.loss_fn(epsilon_bc, torch.zeros_like(epsilon_bc))
                constraint_loss += bnd_loss

            # initial condition error (skipped when fewer than ~1 point would be sampled)
            if self.dataset.has_ics and len(x) * self.hparams.ic_prop > 1:
                initial_points = self.dataset.sample_init(
                    int(len(x) * self.hparams.ic_prop)
                ).to(self.device)
                initial_estimate = self.model(initial_points)
                epsilon_ic = self.dataset.initial_condition(initial_points, initial_estimate)
                inc_loss = self.loss_fn(epsilon_ic, torch.zeros_like(epsilon_ic))
                constraint_loss += inc_loss

            loss = informed_loss + self.hparams.alpha * constraint_loss
            running_loss += loss.item()

            # backpropagate
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

        # --- compute metrics ---
        # NOTE: eval_metrics are NOT visible during training - purely for posthoc analysis
        metrics = {self.hparams.loss: running_loss / len(self.trainloader)}

        # log training progress (informed/constraint losses are those of the LAST batch)
        self.log_metrics(
            {**metrics, **{"Informed Loss": informed_loss, "Constraint Loss": constraint_loss}},
            epoch
        )

        # must be invisible during training
        if epoch % self.hparams.skip == 0 or epoch == self.hparams.epochs - 1:
            eval_metrics, _ = evaluate_solver(self.model, self.testloader, self.device)
            metrics.update(eval_metrics)

        # if error model scale is being automatically computed
        # heuristically estimate by rounding the max residual to the next highest order of magnitude
        if epoch in self.checkpoint_epochs and self.hparams.compute_error_magnitude:
            max_res = max(max_residuals)
            self.model.mags[self.model.order] = next_highest_OM(max_res)

            # log this metric only once
            metrics["Max-RMSR"] = max_res.item()

        return metrics

    def postprocess(self):
        """Visualize model residuals

        Only runs for models of order > 0. Compares the true residual of the model up
        to the second-to-last error model against the (scaled) output of the last
        error model. NOTE: overwrites `self.hparams.figures.title`.
        """
        if not self.model.order > 0:
            return

        # run inference on the base model and residual models
        residual_prediction, prediction = [], []
        self.model.eval()
        for x, _ in self.testloader:
            x = x.to(self.device)
            # manually forward pass up until the order-1-th error model
            pred = self.model.reparameterize(self.model.N(x), x) + \
                sum(
                    self.model.mags[o] * self.model.reparameterize(
                        self.model.error_models[f"N_e{o+1}"](x), x, error=True
                    )
                    for o in range(self.model.order-1)
                )
            # order-1-th error model
            res_pred = self.model.mags[self.model.order-1] * self.model.reparameterize(
                self.model.error_models[f"N_e{self.model.order}"](x), x, error=True
            )
            prediction.append(pred)
            residual_prediction.append(res_pred)

        residual = self.test_dataset.sol_data - torch.cat(prediction).detach().cpu().numpy()
        est_residual = torch.cat(residual_prediction).detach().cpu().numpy()

        # plotting stuff (order > 0 is guaranteed here, so the order prefix is always used)
        save_prefix = f"{self.figure_dir}/order-{self.model.order}_{self.details}"

        # raw strings keep the LaTeX backslashes without invalid escape sequences
        self.hparams.figures.title = r"$\mathcal{N}_{\epsilon_" + f"{self.model.order}" + "}$"

        # special case for the nPBE dataset
        if self.hparams.data.name == 'nPBE':
            residual = residual[:self.test_dataset.N**2]
            est_residual = est_residual[:self.test_dataset.N**2]

        if self.test_dataset.dims > 1:
            visualize_residual_heatmap(
                residual, est_residual, self.test_dataset, save_prefix, self.hparams.figures)
        elif self.test_dataset.dims == 1:
            visualize_residual_phase_space(
                residual, est_residual, self.test_dataset, save_prefix, self.hparams.figures)


class FullBatchTrainer(OperatorTrainer, Trainer):
    """Operator trainer that optimizes with a single full-batch LBFGS call."""

    def __init__(self, model, train_dataset, test_dataset, hparams, load_checkpoint=None):
        """Trains model using full-batch optimizer e.g. LBFGS

        `hparams.batch_size` is overwritten with `len(test_dataset)` before the parent
        constructor runs, and the parent's optimizer is replaced by LBFGS with
        `max_iter` and `max_eval` both set to `hparams.epochs`.
        """
        # must happen before the parent constructor reads the batch size
        hparams.batch_size = len(test_dataset)

        super().__init__(model, train_dataset, test_dataset, hparams, load_checkpoint)

        # swap in LBFGS for whatever optimizer the parent created
        budget = hparams.epochs
        self.optim = torch.optim.LBFGS(
            params=self.model.parameters(),
            max_iter=budget,
            max_eval=budget,
            **hparams.optimizer.params
        )
        logger.info("*** Using full-batch trainer with LBFGS optimizer ***")

        # closure-call counter and best training loss seen so far
        self.iter = 0
        self.best_loss = float('inf')

    def _zero_target_loss(self, values):
        """Apply the configured loss between `values` and an all-zero target."""
        return self.loss_fn(values, torch.zeros_like(values))

    def closure(self):
        """Optimizer closure for LBFGS

        Computes the loss on the whole dataset, backpropagates, and performs the
        per-iteration bookkeeping (logging, evaluation, best-model saving, metric
        storage, checkpointing). Returns the loss tensor.
        """
        data = self.dataset

        # operator residual of the model on the full training data
        inputs = data.data
        residuals = self.operator(inputs, self.model(inputs))
        informed_loss = self._zero_target_loss(residuals)

        # NOTE: Assumes existence of data and solution data for boundary and initial conditions
        constraint_loss = 0.0
        if data.has_bcs:  # boundary condition error
            constraint_loss += self._zero_target_loss(self.model(data.bc_data) - data.bc_sol)
        if data.has_ics:  # initial condition error
            constraint_loss += self._zero_target_loss(self.model(data.ic_data) - data.ic_sol)

        loss = informed_loss + self.hparams.alpha * constraint_loss

        # backpropagate
        self.optim.zero_grad()
        loss.backward()

        # ---- metrics: eval metrics are for posthoc analysis only, not visible in training ----
        step = self.iter
        loss_key = self.hparams.loss
        metrics = {loss_key: loss.item()}

        self.log_metrics(
            {**metrics, "Informed Loss": informed_loss, "Constraint Loss": constraint_loss},
            step
        )

        is_final = step == self.hparams.epochs - 1
        if step % self.hparams.skip == 0 or is_final:
            eval_metrics, _ = evaluate_solver(self.model, self.testloader, self.device)
            metrics.update(eval_metrics)

        # on checkpoint iterations, estimate the error model scale by rounding the max
        # root-mean-square residual to the next highest order of magnitude
        if self.hparams.compute_error_magnitude and step in self.checkpoint_epochs:
            max_res = residuals.detach().pow(2).mean(dim=-1).sqrt().max()
            self.model.mags[self.model.order] = next_highest_OM(max_res)
            metrics["Max-RMSR"] = max_res.item()

        # ---- bookkeeping: best model, metric history, checkpoint ----
        current = metrics[loss_key]
        if self.hparams.save_best and current < self.best_loss:
            self.model.save_model(self.save_dir, f"{self.details}.pt")
            self.saved_epoch = step
            self.best_loss = current

        for name, value in metrics.items():
            self.metrics[name].append(value)

        self.checkpoint(step)

        self.iter += 1
        return loss

    def train(self):
        """Overwrites and modifies the training logic for LBFGS full-batch training

        A single `optim.step(self.closure)` call drives the whole optimization; the
        closure does the per-iteration logging and saving.
        """
        # track gradients of model
        if self.hparams.wandb.log:
            wandb.watch(self.model)

        rule = 38 * '='
        begin_msg = f"{rule} [BEGIN TRAINING] {rule}"
        logger.info(begin_msg)

        started = time.time()

        # set training mode, then let LBFGS run
        self.model.train()
        self.optim.step(self.closure)

        elapsed = time.time() - started
        finish_msg = f"{30 * '='} [FINISHED TRAINING ({elapsed:.2e}s)]"
        padding = (len(begin_msg) - len(finish_msg) - 1) * '='
        logger.info(f"{finish_msg} {padding}")
        self.finish_training()
