import os

import torch
from torch.nn.parallel.data_parallel import data_parallel

import seaborn as sns

from speedrun import BaseExperiment, TensorboardMixin, IOMixin
from speedrun.logging.plotting import MatplotlibMixin

# from speedrun.logging.wandb import WandBMixin
from mawm.utils import MockWandBMixin as WandBMixin

from mawm.models import world_models as wm
from mawm.models import utils as mutils
from mawm import utils


# Apply seaborn's default theme to matplotlib plots (runs once, when this module is imported).
sns.set()


class WMTrainer(TensorboardMixin, MatplotlibMixin, IOMixin, WandBMixin, BaseExperiment):
    WANDB_PROJECT = 'mawm'

    def __init__(self, experiment_directory=None, dump_configuration=True, update_git_revision=True,
                 auto_setup=True, build=True):
        """Set up the experiment and, optionally, build its training components.

        Args:
            experiment_directory: Forwarded to the base experiment constructor;
                must not be None when `auto_setup` is falsy.
            dump_configuration: Forwarded to `auto_setup`.
            update_git_revision: Forwarded to `auto_setup`.
            auto_setup: If truthy, run `auto_setup`; otherwise run
                `read_only_setup`.
            build: If truthy, call `_build` once the setup is done.
        """
        # Without auto-setup an experiment directory is mandatory.
        assert auto_setup or experiment_directory is not None
        super().__init__(experiment_directory=experiment_directory)
        if not auto_setup:
            # Read only mode
            self.read_only_setup()
        else:
            self.auto_setup(dump_configuration=dump_configuration,
                            update_git_revision=update_git_revision)
        if build:
            self._build()

    def read_only_setup(self):
        """Run the setup steps used when `auto_setup` is disabled.

        Returns:
            This experiment instance.
        """
        setup_stages = (
            self.record_args,
            self.read_config_file,
            self.read_macro,
            self.update_configuration_from_args,
        )
        # The stages run strictly in the order listed above.
        for run_stage in setup_stages:
            run_stage()
        return self

    def _build_loaders(self):
        """Initialise the training and validation loader attributes.

        Both are set to None here. `train_epoch` and `validate_epoch` iterate
        over them, so they have to be replaced by real loaders before training.
        """
        self.train_loader = self.validate_loader = None

    def _build_model(self):
        """Instantiate the model named in the config and move it to `self.device`.

        The class is looked up by name ('model/name', default 'MAWMWithRNN') in
        the `world_models` module and constructed with 'model/kwargs'.
        """
        model_name = self.get('model/name', 'MAWMWithRNN')
        model_kwargs = self.get('model/kwargs', {})
        # noinspection PyPep8Naming
        ModelClass = getattr(wm, model_name)
        model = ModelClass(**model_kwargs)
        self.model: torch.nn.Module = model.to(self.device)

    def _build_optim_and_scheduler(self):
        """Create the optimizer and the learning-rate scheduler from the config.

        Both classes are resolved by name: the optimizer in `torch.optim`
        ('optim/name', default 'Adam') and the scheduler in
        `torch.optim.lr_scheduler` ('optim/lr_annealer_name', default
        'ReduceLROnPlateau').
        """
        # Optimizer over all model parameters
        optim_name = self.get('optim/name', 'Adam')
        optim_kwargs = self.get('optim/kwargs', {'lr': 0.0003})
        self.optim = getattr(torch.optim, optim_name)(self.model.parameters(), **optim_kwargs)
        # Scheduler wrapping that optimizer
        annealer_name = self.get('optim/lr_annealer_name', 'ReduceLROnPlateau')
        annealer_kwargs = self.get('optim/lr_annealer_kwargs', {})
        self.scheduler = getattr(torch.optim.lr_scheduler, annealer_name)(self.optim, **annealer_kwargs)

    def _build_losses_and_metrics(self):
        """Initialise the loss and metric attributes.

        All three are set to None here. `train_epoch` and `validate_epoch` call
        the two losses and index `recon_metrics` by metric name, so they have
        to be replaced by real objects before training.
        """
        for attribute in ('recon_loss', 'positional_loss', 'recon_metrics'):
            setattr(self, attribute, None)

    def _build(self):
        """Build loaders, model, optimizer/scheduler and losses, in that order.

        Afterwards the reconstruction-to-state processor used for rollouts is
        created from 'model/rollouts/recon_to_state_processor_kwargs' and
        registered with the model.
        """
        self._build_loaders()
        self._build_model()
        # The optimizer needs the model parameters, hence after the model.
        self._build_optim_and_scheduler()
        self._build_losses_and_metrics()
        # Bind the right recon-to-state processor for rollouts
        processor_kwargs = self.get('model/rollouts/recon_to_state_processor_kwargs', {})
        recon_to_state_processor = utils.make_reconstruction_to_state_processor(**processor_kwargs)
        self.model.register_recon_to_state_processor(recon_to_state_processor)

    @property
    def device(self):
        """Device identifier from the 'device' config entry (default 'cuda')."""
        configured_device = self.get('device', 'cuda')
        return configured_device

    @property
    def is_multi_gpu(self):
        """Whether forward passes should go through `data_parallel`.

        Truthy only if the device is 'cuda', more than one CUDA device is
        visible and the 'multi_gpu' argument is set.
        """
        if not self.device == 'cuda':
            return False
        if not torch.cuda.device_count() > 1:
            return False
        return self.get_arg('multi_gpu', False)

    @property
    def device_ids(self):
        """Indices 0..N-1 of the N CUDA devices reported by torch, as a list."""
        num_devices = torch.cuda.device_count()
        return [*range(num_devices)]

    DEFAULT_DISPATCH = 'train'

    def train(self):
        """Run the training loop from the current epoch up to 'training/num_epochs'.

        Per epoch: train, validate, checkpoint when the validation result is
        flagged as best (or when the 'checkpoint_every_epoch' argument is set),
        step the learning-rate scheduler, and report to the progress bar.
        """
        # Log config to tensorboard
        self.log_configuration()
        # Load checkpoint if required
        self.load_checkpoint_if_required()
        # Initialize wandb. `load_checkpoint_if_required` reads the checkpoint path
        # and the step-preservation flag from the arguments, so the resume flag
        # must look there as well; the configuration is kept as a fallback.
        load_path = self.get_arg('load', self.get('load'))
        preserve_steps = self.get_arg('preserve_steps', self.get('preserve_steps', False))
        self.initialize_wandb(resume=(load_path is not None and preserve_steps))
        # Go!
        for _ in self.progress(range(self.epoch, self.get('training/num_epochs')), desc='Epochs', tag='ep'):
            # Train
            self.train_epoch()
            # Validate
            validation_results = self.validate_epoch()
            # Checkpoint if better than best
            if validation_results.best or self.get_arg('checkpoint_every_epoch', False):
                self.checkpoint()
            # Anneal lr
            if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # Unlike other schedulers, ReduceLROnPlateau needs access to the validation loss
                self.scheduler.step(validation_results.validation_loss)
            else:
                self.scheduler.step()
            # Log to progress bar
            self.log_progress('ep', vloss=validation_results.validation_loss, **validation_results.metrics)
            # Increment counter
            self.next_epoch()

    def train_epoch(self):
        """Train the model for one pass over `self.train_loader`.

        For every batch: forward pass (through `data_parallel` when
        `is_multi_gpu` is truthy), reconstruction loss plus the positional loss
        when the model returns position embeddings, backward pass with optional
        gradient-norm clipping ('training/grad_clip'), optimizer step, logging,
        and a step-counter increment.
        """
        self.model.train()
        # The clipping threshold does not change within an epoch; read it once.
        grad_clip = self.get('training/grad_clip')
        for actions, positions, goals, states, target_states in self.progress(self.train_loader, desc='Training',
                                                                              tag='train'):
            # Ship tensors to device (goals are not used by the forward pass)
            actions = actions.to(self.device)
            positions = positions.to(self.device)
            states = states.to(self.device)
            target_states = target_states.to(self.device)
            # Compute reconstructions
            if not self.is_multi_gpu:
                model_outputs = self.model(actions, positions, states)
            else:
                model_outputs = data_parallel(self.model, (actions, positions, states))
            # Compute loss
            loss = recon_loss = self.recon_loss(model_outputs.recons, target_states)
            if model_outputs.predicted_position_embeddings is not None:
                positional_loss = self.positional_loss(model_outputs.predicted_position_embeddings,
                                                       model_outputs.target_position_embeddings)
                loss = loss + positional_loss
            else:
                positional_loss = None
            # Update
            self.optim.zero_grad()
            loss.backward()
            # Clip grad if required
            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=grad_clip)
            self.optim.step()
            # Visualize
            if self.log_scalars_now:
                wandb_log = {}
                with torch.no_grad():
                    for metric in self.recon_metrics:
                        # Log plain Python numbers, as `validate_epoch` does, rather
                        # than tensors that may still live on the training device.
                        metric_value = self.recon_metrics[metric](model_outputs.recons, target_states).item()
                        self.log_scalar(f'training/{metric}', metric_value)
                        wandb_log[f"training_{metric}"] = metric_value
                # 'training/loss' is the reconstruction loss only
                self.log_scalar('training/loss', recon_loss.item())
                wandb_log['training_recon_loss'] = recon_loss.item()
                if positional_loss is not None:
                    self.log_scalar('training/positional_loss', positional_loss.item())
                    self.log_scalar('training/total_loss', loss.item())
                    wandb_log['training_positional_loss'] = positional_loss.item()
                    wandb_log['training_total_loss'] = loss.item()
                self.log_scalar('training/lr', self.optim.param_groups[0]['lr'])
                wandb_log['lr'] = self.optim.param_groups[0]['lr']
                # Log wandb
                self.wandb_log(**wandb_log)
            # Log images
            if self.log_images_now:
                self.log_reconstructions()
            # Log progress
            self.log_progress('train', rloss=recon_loss.item(),
                              ploss=(positional_loss.item() if positional_loss is not None else None))
            # Next step
            self.next_step()
    def validate_epoch(self):
        """Evaluate the model on `self.validate_loader` and log the averaged results.

        The reconstruction loss and every metric in `self.recon_metrics` are
        averaged over all batches. The positional loss is averaged over the
        batches for which the model returned position embeddings, and stays
        None if there were none. The result counts as "best" when either
        averaged loss beats the value stored in the cache.

        Returns:
            A bundle with `validation_loss`, `positional_loss`, `best` and
            `metrics` (metric name -> averaged value).

        Raises:
            ZeroDivisionError: If the validation loader yields no batches.
        """
        loss = 0
        positional_loss = None
        positional_iters = 0
        metrics = {metric: 0 for metric in self.recon_metrics}
        iters = 0
        self.model.eval()
        for actions, positions, goals, states, target_states in self.progress(self.validate_loader, desc='Validation',
                                                                              tag='val'):
            # Ship tensors to device (goals are not used by the forward pass)
            actions = actions.to(self.device)
            positions = positions.to(self.device)
            states = states.to(self.device)
            target_states = target_states.to(self.device)
            with torch.no_grad():
                # Compute reconstructions
                if not self.is_multi_gpu:
                    model_outputs = self.model(actions, positions, states)
                else:
                    model_outputs = data_parallel(self.model, (actions, positions, states))
                # Compute loss and accuracy, as plain Python numbers
                iter_loss = self.recon_loss(model_outputs.recons, target_states).item()
                if model_outputs.predicted_position_embeddings is not None:
                    iter_positional_loss = self.positional_loss(model_outputs.predicted_position_embeddings,
                                                                model_outputs.target_position_embeddings).item()
                else:
                    iter_positional_loss = None
                # Add em up
                loss += iter_loss
                if iter_positional_loss is not None:
                    positional_loss = iter_positional_loss + (0 if positional_loss is None else positional_loss)
                    positional_iters += 1
                for metric in self.recon_metrics:
                    metrics[metric] += self.recon_metrics[metric](model_outputs.recons, target_states).item()
                iters += 1
            # Same value types as the progress logging in `train_epoch`
            self.log_progress('val', rloss=iter_loss, ploss=iter_positional_loss)
        loss /= iters
        metrics = {metric: val / iters for metric, val in metrics.items()}
        if positional_loss is not None:
            # Average like the other quantities; previously the raw sum was reported.
            positional_loss /= positional_iters
        # Log
        wandb_log = {}
        self.log_scalar('validation/loss', loss)
        wandb_log['validation_loss'] = loss
        if positional_loss is not None:
            self.log_scalar('validation/positional_loss', positional_loss)
            wandb_log['validation_positional_loss'] = positional_loss
        for metric in metrics:
            self.log_scalar(f'validation/{metric}', metrics[metric])
            wandb_log[f"validation_{metric}"] = metrics[metric]
        self.log_rollout()
        # Find out if either loss is the best seen so far
        is_best_loss = loss < self.read_from_cache('best_validation_loss', float('inf'))
        if is_best_loss:
            self.write_to_cache('best_validation_loss', loss)
        is_best_positional_loss = \
            (positional_loss is not None and
             (positional_loss < self.read_from_cache('best_validation_positional_loss', float('inf'))))
        if is_best_positional_loss:
            self.write_to_cache('best_validation_positional_loss', positional_loss)
        best = is_best_loss or is_best_positional_loss
        wandb_log['best'] = best
        self.wandb_log(**wandb_log)
        return self.bundle(validation_loss=loss, positional_loss=positional_loss, best=best, metrics=metrics)

    def checkpoint(self, force=True):
        """Save model, optimizer and scheduler state to `self.checkpoint_path`.

        The cached best validation loss and the current step and epoch are
        stored alongside the state dicts.

        Args:
            force: Accepted but not used by this implementation.
        """
        snapshot = dict(
            model=self.model.state_dict(),
            optim=self.optim.state_dict(),
            scheduler=self.scheduler.state_dict(),
            best_validation_loss=self.read_from_cache('best_validation_loss', float('inf')),
            step=self.step,
            epoch=self.epoch,
        )
        torch.save(snapshot, self.checkpoint_path)
        # NOTE: `mark_as_latest_checkpoint` is currently not invoked from here.

    def mark_as_latest_checkpoint(self):
        """Point the 'ckpt_latest.pt' symlink at the current checkpoint file.

        Does nothing if the file at `self.checkpoint_path` does not exist. An
        existing 'ckpt_latest.pt' entry in the checkpoint directory is replaced.
        """
        checkpoint_path = self.checkpoint_path
        if not os.path.exists(checkpoint_path):
            return
        latest_link = os.path.join(self.checkpoint_directory, 'ckpt_latest.pt')
        # `lexists` is also true for a dangling symlink (one whose target was
        # deleted); `exists` is not, which left the stale link in place and
        # made `os.symlink` fail with FileExistsError.
        if os.path.lexists(latest_link):
            os.remove(latest_link)
        os.symlink(checkpoint_path, latest_link)

    def load_checkpoint_if_required(self, path=None):
        """Restore model, optimizer and scheduler state from a checkpoint, if one is given.

        Args:
            path: Checkpoint file to load. If None, the 'load' argument is used;
                if that is None as well, nothing is loaded.

        Returns:
            The loaded checkpoint dictionary, or None if nothing was loaded.
        """
        path = self.get_arg('load', None) if path is None else path
        if path is None:
            return
        # NOTE: `torch.load` unpickles the file, so only load checkpoints from
        # trusted sources. `map_location` lets a checkpoint written on one
        # device (e.g. a GPU) be restored on the device this run uses.
        info = torch.load(path, map_location=self.device)
        # Set state dicts
        self.model.load_state_dict(info['model'])
        self.optim.load_state_dict(info['optim'])
        self.scheduler.load_state_dict(info['scheduler'])
        # `.get` so a checkpoint lacking this entry does not fail after the state was already restored
        self.print(f"Loaded checkpoint from {path}; best validation loss was {info.get('best_validation_loss')}.")
        # Preserve steps for tensorboard
        preserve_steps = self.get_arg('preserve_steps', False)
        if preserve_steps:
            step = info.get('step')
            epoch = info.get('epoch')
            # Keep the current counters when the checkpoint does not carry them
            self._step = step if step is not None else self._step
            self._epoch = epoch if epoch is not None else self._epoch
        return info

    def read_sample_inputs(self):
        """Return a small, cached sample of validation inputs on `self.device`.

        On first use, the first three entries of every tensor in the first
        validation batch are stored in the cache under 'sample_inputs'; later
        calls reuse that cached sample.

        Returns:
            The tuple (actions, positions, goals, states, target_states).
        """
        if 'sample_inputs' not in self.cache_keys:
            first_batch = next(iter(self.validate_loader))
            self.write_to_cache('sample_inputs', [tensor[0:3] for tensor in first_batch])
        cached_sample = self.read_from_cache('sample_inputs', ensure_exists=True)
        on_device = [tensor.to(self.device) for tensor in cached_sample]
        actions, positions, goals, states, target_states = on_device
        return actions, positions, goals, states, target_states

    def fetch_sample_reconstructions(self):
        """Run the model on the cached sample inputs without tracking gradients.

        Returns:
            A bundle holding CPU copies of `recons`, `target_states`, `actions`,
            `positions` and `states`.
        """
        actions, positions, goals, states, target_states = self.read_sample_inputs()
        with torch.no_grad():
            model_outputs = self.model(actions, positions, states)
        # Collect everything first, then move it to the CPU in one place.
        on_device = dict(recons=model_outputs.recons, target_states=target_states,
                         actions=actions, positions=positions, states=states)
        return self.bundle(**{name: tensor.cpu() for name, tensor in on_device.items()})

    def fetch_sample_rollouts(self):
        """Roll the model out on the cached sample inputs.

        The rollout is delegated to `mutils.rollout`, with `prompt_till_step`
        taken from 'validation/rollouts/prompt_till_step' (default 0.2).

        Returns:
            The object returned by `mutils.rollout`, with the sample inputs
            attached as `sample_inputs`.
        """
        actions, positions, goals, states, target_states = self.read_sample_inputs()
        prompt_till_step = self.get('validation/rollouts/prompt_till_step', 0.2)
        # Note the argument order expected by the helper: states come before goals.
        rollout_info = mutils.rollout(self.model, actions, positions, states, goals,
                                      prompt_till_step=prompt_till_step)
        sample_inputs = (actions, positions, goals, states, target_states)
        rollout_info.sample_inputs = sample_inputs
        return rollout_info

    def log_reconstructions(self):
        """Hook called by `train_epoch` when images are due; does nothing here."""

    def log_rollout(self):
        """Hook called at the end of `validate_epoch`; logs nothing here.

        Only the model's `ROLLOUTS_IMPLEMENTED` flag is checked.
        """
        if self.model.ROLLOUTS_IMPLEMENTED:
            # Rollouts are available, but there is nothing to log at this level.
            pass

    def log_configuration(self):
        """Write the string form of the configuration to the logger under the tag 'config'."""
        config_text = str(self._config)
        self.logger.add_text('config', config_text)
