import torch
import torch.nn as nn
from torchvision.utils import make_grid
import numpy as np

import os
import seaborn as sns
from addict import Dict

from mawm.envs import bb_trajectories_data_loader
from mawm.losses import ReconAccuracy, PositionalEmbeddingCosineLoss, REF1Score, BBConfusionMatrix
from mawm import utils

from train_wm_base import WMTrainer as _WMTrainer

# Apply seaborn's default theme globally as soon as this module is imported.
sns.set()


class WMTrainer(_WMTrainer):

    def _build_loaders(self):
        """Create the training and validation loaders, then the test loaders.

        Paths are read from the ``data/train_path`` and ``data/val_path``
        config entries; both loaders receive the ``data/kwargs`` options.
        """
        split_paths = {
            'train_loader': self.get('data/train_path'),
            'validate_loader': self.get('data/val_path'),
        }
        for attribute, path in split_paths.items():
            loader = bb_trajectories_data_loader(path=path, **self.get('data/kwargs'))
            setattr(self, attribute, loader)
        # Test loaders are optional and are handled by their own builder.
        self._build_test_loaders()

    def _build_test_loaders(self):
        paths = self.get('test_data/paths')
        if paths is None:
            self.all_test_loaders = {}
            self.test_loader = None
            return
        all_test_loaders = {}
        for scenario in paths:
            all_test_loaders[scenario] = bb_trajectories_data_loader(path=paths[scenario], **self.get('data/kwargs'))
        self.all_test_loaders = all_test_loaders

    def _build_losses_and_metrics(self):
        """Instantiate the training objectives and the reconstruction metrics.

        Sets ``recon_loss``, ``positional_loss`` and ``recon_metrics`` (a
        ``Dict`` with the keys ``accuracy`` and ``f1_score``).
        """
        # Objectives
        self.recon_loss = nn.BCEWithLogitsLoss()
        self.positional_loss = PositionalEmbeddingCosineLoss()
        # Metrics, keyed by the name they are reported under
        metric_factories = (('accuracy', ReconAccuracy), ('f1_score', REF1Score))
        self.recon_metrics = Dict({name: factory() for name, factory in metric_factories})

    def log_reconstructions(self):
        """Log reconstructions and ground truth for one random time-step.

        A single time index is drawn once and used for every sequence in the
        sampled batch. For each sequence, the reconstruction (passed through
        a sigmoid) and the matching target are tiled into image grids and
        sent to both ``log_image`` and ``wandb_log_image``.

        The time index is drawn from ``[40, sequence_length - 2)`` when the
        configured sequence is long enough for that range to be non-empty.
        Shorter sequences fall back to the latest available index instead of
        raising ``ValueError`` from ``np.random.randint``.
        """
        # Exclusive upper bound on the time index; never below 1 so that the
        # range [lower, upper) always contains at least one value.
        upper = max(self.get('data/kwargs/sequence_length') - 2, 1)
        lower = min(40, upper - 1)
        t_idx = np.random.randint(lower, upper)
        samples = self.fetch_sample_reconstructions()
        # NOTE(review): the original comment states recons/targets are shaped
        # N, T, A, 21, h, w -- confirm against fetch_sample_reconstructions.
        self.wandb_pause_step_counter()
        try:
            for batch_idx in range(samples.recons.shape[0]):
                # Everything left after indexing (sequence, time) is handed to
                # make_grid as one batch of images.
                recon_grid = make_grid(samples.recons[batch_idx, t_idx, :, :, :, :].sigmoid(),
                                       nrow=5, pad_value=1)
                target_grid = make_grid(samples.target_states[batch_idx, t_idx, :, :, :, :],
                                        nrow=5, pad_value=1)
                # Dump to tensorboard
                self.log_image(f'recons/seq_{batch_idx}', recon_grid)
                self.log_image(f'groundtruth/seq_{batch_idx}', target_grid)
                # ... and to wandb
                self.wandb_log_image(f'recons_seq_{batch_idx}', recon_grid)
                self.wandb_log_image(f'groundtruth_seq_{batch_idx}', target_grid)
        finally:
            # Always undo the pause, even if one of the logging calls raises.
            self.wandb_resume_step_counter()

    def log_rollout(self):
        if not self.model.ROLLOUTS_IMPLEMENTED:
            return
        # TODO

    def _build_evaluation_metrics(self):
        """Extend ``self.recon_metrics`` with a confusion-matrix metric.

        Safe to call repeatedly. An existing ``confmat`` entry is replaced by
        a fresh ``BBConfusionMatrix`` instead of being passed a second time
        as a keyword, which would raise ``TypeError`` (duplicate keyword
        argument) when more than one test suite runs in the same session.
        """
        existing = {name: metric for name, metric in self.recon_metrics.items()
                    if name != 'confmat'}
        # 'confmat' stays first, followed by the previously registered metrics.
        self.recon_metrics = Dict(confmat=BBConfusionMatrix(), **existing)

    def heavy_test_suite(self):
        """Run the OOD test epoch and the agent-dropout sweep, then save results.

        The collected results (checkpoint info, ``ood_tests`` and
        ``robustness_tests``) are written with ``torch.save`` to
        ``heavy_test_suite_results.pkl`` inside ``self.log_directory``.
        """
        self.print("Loading Ckpt...")
        ckpt_info = self.load_checkpoint_if_required()
        # The confusion-matrix metric is only added for evaluation runs.
        self._build_evaluation_metrics()

        results = Dict()
        if ckpt_info is None:
            print("Warning: no checkpoint loaded.")
            ckpt_info = {}
        # Record which checkpoint the numbers below were produced with.
        for field in ('best_validation_loss', 'step', 'epoch'):
            results.ckpt_info[field] = ckpt_info.get(field)
        results.ckpt_info['path'] = self.get_arg('load', None)

        # Stage 1: every configured test scenario.
        self.print("Testing with OOD Data...")
        results['ood_tests'] = self.test_epoch()
        # Stage 2: sweep over the fraction of agents that is kept.
        self.print("Initiating Robustness sweep...")
        results['robustness_tests'] = self.sweep_robustness_to_dropped_agents()

        # Persist everything next to the other logs.
        result_file = os.path.join(self.log_directory, 'heavy_test_suite_results.pkl')
        self.print(f"Writing results to file: {result_file}")
        torch.save(results, result_file)
        self.print("Done!")

    def massive_test_suite(self):
        """Run the scenario x agent-dropout grid sweep and save the results.

        The saved object keeps the per-scenario sweep results at the top
        level (``results[scenario][keep_proba]``), and additionally carries
        the checkpoint metadata under ``results.ckpt_info``. Previously that
        metadata was collected and then discarded, because ``results`` was
        rebound to the sweep output.

        Results are written with ``torch.save`` to
        ``massive_test_suite_results.pkl`` inside ``self.log_directory``.
        """
        self.print(">>>> Massive Test Suite <<<<")
        self.print("Loading Ckpt...")
        ckpt_info = self.load_checkpoint_if_required()
        # Build extra metrics for evaluating confusion matrices
        self._build_evaluation_metrics()
        # Make container for results
        results = Dict()
        if ckpt_info is None:
            print("Warning: no checkpoint loaded.")
            ckpt_info = {}
        results.ckpt_info.best_validation_loss = ckpt_info.get('best_validation_loss')
        results.ckpt_info.step = ckpt_info.get('step')
        results.ckpt_info.epoch = ckpt_info.get('epoch')
        results.ckpt_info.path = self.get_arg('load', None)
        # Grid sweep over OOD and robustness
        sweep_results = self.sweep_ood_robustness_to_dropped_agents(True)
        # Merge into the container rather than replacing it, so the checkpoint
        # info above is saved too. A scenario literally named 'ckpt_info'
        # would overwrite that entry.
        for scenario, scenario_results in sweep_results.items():
            results[scenario] = scenario_results
        # Write out results
        result_file = os.path.join(self.log_directory, 'massive_test_suite_results.pkl')
        self.print(f"Writing results to file: {result_file}")
        torch.save(results, result_file)
        self.print("Done!")

    def sweep_robustness_to_dropped_agents(self, verbose=True):
        """Evaluate agent-dropout robustness over a range of keep probabilities.

        Runs ``test_robustness_to_dropped_agents`` for nine evenly spaced keep
        probabilities from 0.2 to 1.0 (inclusive) on the default loader.

        Args:
            verbose: forwarded to ``test_robustness_to_dropped_agents``.

        Returns:
            A ``Dict`` mapping each keep probability (as produced by
            ``np.linspace``) to the result of the corresponding test.
        """
        results = Dict()
        for keep_proba in self.progress(np.linspace(0.2, 1, 9), desc='SweepRobustness', tag='rob_sweep'):
            outcome = self.test_robustness_to_dropped_agents(keep_proba, verbose=verbose)
            results[keep_proba] = outcome
            self.log_progress('rob_sweep', rloss=outcome.loss)
        return results

    def sweep_ood_robustness_to_dropped_agents(self, verbose=True):
        """Run the agent-dropout sweep on every configured test scenario.

        Args:
            verbose: forwarded to ``test_robustness_to_dropped_agents``.

        Returns:
            A nested ``Dict`` indexed as ``grid[scenario][keep_proba]``, with
            nine keep probabilities from 0.2 to 1.0 per scenario.
        """
        grid = Dict()
        scenario_iter = self.progress(self.all_test_loaders.items())
        for scenario, loader in scenario_iter:
            keep_probas = np.linspace(0.2, 1, 9)
            for keep_proba in self.progress(keep_probas):
                # Header so the per-test printout can be attributed to a scenario.
                self.print(f"------ {scenario} ------")
                outcome = self.test_robustness_to_dropped_agents(
                    keep_proba, verbose=verbose, loader=loader)
                grid[scenario][keep_proba] = outcome
        return grid

    def test_robustness_to_dropped_agents(self, agent_keep_proba=0.8, verbose=True, loader=None):
        """Evaluate reconstructions when only a subset of agents is fed to the model.

        For every batch, ``round(states.shape[2] * agent_keep_proba)`` agents
        (at least one) are kept via ``utils.drop_agents``. The model is run
        with its positional RNN disabled and is scored against the full,
        un-dropped ``target_states``.

        Args:
            agent_keep_proba: fraction of agents to keep in each batch.
            verbose: when true, print the averaged metrics and loss.
            loader: data loader to evaluate on; defaults to
                ``self.validate_loader``.

        Returns:
            ``Dict`` with ``loss`` (mean reconstruction loss), ``metrics``
            (per-metric means) and ``num_iters`` (number of batches).

        Raises:
            ZeroDivisionError: if the loader yields no batches.
        """
        running_loss = 0
        running_metrics = dict.fromkeys(self.recon_metrics, 0)
        num_batches = 0
        self.model.eval()
        if loader is None:
            loader = self.validate_loader
        progress_tag = f'rob_{agent_keep_proba}'
        batches = self.progress(loader, desc=f'Robustness_{agent_keep_proba}', tag=progress_tag)
        for actions, positions, goals, states, target_states in batches:
            # Move the tensors to the device
            actions = actions.to(self.device)
            positions = positions.to(self.device)
            # Targets keep the positions of all agents, taken before any are dropped.
            target_positions = positions
            states = states.to(self.device)
            target_states = target_states.to(self.device)
            # Number of agents that survive, never fewer than one
            num_kept = max(round(states.shape[2] * agent_keep_proba), 1)
            actions, positions, states = utils.drop_agents(int(num_kept), actions, positions, states)
            with torch.no_grad():
                # Reconstruct with the positional RNN switched off
                with self.model.disable_positional_rnn():
                    model_outputs = self.model(actions, positions, states, target_positions=target_positions)
                batch_loss = self.recon_loss(model_outputs.recons, target_states)
                batch_metrics = {name: self.recon_metrics[name](model_outputs.recons, target_states)
                                 for name in self.recon_metrics}
                # Accumulate: scalars as Python numbers, everything else as numpy arrays
                running_loss += batch_loss.item()
                for name, value in batch_metrics.items():
                    if value.numel() == 1:
                        running_metrics[name] += value.item()
                    else:
                        running_metrics[name] += value.detach().cpu().numpy()
                num_batches += 1
            self.log_progress(progress_tag, rloss=batch_loss)
        # Turn the sums into per-batch means
        mean_loss = running_loss / num_batches
        mean_metrics = {name: total / num_batches for name, total in running_metrics.items()}
        results = Dict(loss=mean_loss, metrics=mean_metrics, num_iters=num_batches)
        if not verbose:
            return results
        self.print(f"------ Test with Agent Dropout ({agent_keep_proba}) ------")
        self.print("------ Metrics ------")
        for metric_name, metric_value in mean_metrics.items():
            self.print(f"{metric_name} : {metric_value}")
        self.print("------ Losses ------")
        self.print(f"Recon Loss : {mean_loss}")
        return results

    def test_epoch(self, verbose=True):
        """Run ``_test_epoch`` on every loader in ``self.all_test_loaders``.

        Args:
            verbose: forwarded to ``_test_epoch``.

        Returns:
            ``Dict`` mapping each scenario name to that scenario's results.
        """
        per_scenario = Dict()
        for name, scenario_loader in self.all_test_loaders.items():
            outcome = self._test_epoch(name, scenario_loader, verbose=verbose)
            per_scenario[name] = outcome
        return per_scenario

    def _test_epoch(self, scenario, loader, verbose=True):
        """Evaluate the model on one test scenario.

        The model is put in eval mode and run, without gradients and with its
        positional RNN disabled, on every batch of ``loader``. The
        reconstruction loss and all ``self.recon_metrics`` are averaged over
        the batches.

        Args:
            scenario: scenario name, used for progress tags and the printout.
            loader: data loader yielding
                ``(actions, positions, goals, states, target_states)``.
            verbose: when true, print the averaged metrics and loss.

        Returns:
            ``Dict`` with ``loss`` (mean reconstruction loss), ``metrics``
            (per-metric means) and ``num_iters`` (number of batches).

        Raises:
            ZeroDivisionError: if the loader yields no batches.
        """
        loss_sum = 0
        metric_sums = dict.fromkeys(self.recon_metrics, 0)
        batch_count = 0
        self.model.eval()
        tag = f'test_{scenario}'
        for batch in self.progress(loader, desc=f'Test_{scenario}', tag=tag):
            actions, positions, goals, states, target_states = batch
            # Move the tensors to the device
            actions = actions.to(self.device)
            positions = positions.to(self.device)
            target_positions = positions
            states = states.to(self.device)
            target_states = target_states.to(self.device)
            with torch.no_grad():
                # Reconstruct with the positional RNN switched off
                with self.model.disable_positional_rnn():
                    model_outputs = self.model(actions, positions, states, target_positions=target_positions)
                recons = model_outputs.recons
                iter_loss = self.recon_loss(recons, target_states)
                iter_metrics = {name: self.recon_metrics[name](model_outputs.recons, target_states)
                                for name in self.recon_metrics}
                # Accumulate: scalars as Python numbers, everything else as numpy arrays
                loss_sum += iter_loss.item()
                for name, value in iter_metrics.items():
                    if value.numel() == 1:
                        metric_sums[name] += value.item()
                    else:
                        metric_sums[name] += value.detach().cpu().numpy()
                batch_count += 1
            self.log_progress(tag, rloss=iter_loss)
        # Turn the sums into per-batch means
        mean_loss = loss_sum / batch_count
        mean_metrics = {name: total / batch_count for name, total in metric_sums.items()}
        results = Dict(loss=mean_loss, metrics=mean_metrics, num_iters=batch_count)
        if not verbose:
            return results
        self.print(f"------ Test Epoch with ({scenario}) Data ------")
        self.print("------ Metrics ------")
        for metric_name, metric_value in mean_metrics.items():
            self.print(f"{metric_name} : {metric_value}")
        self.print("------ Losses ------")
        self.print(f"Recon Loss : {mean_loss}")
        return results


if __name__ == '__main__':
    # Script entry point: build the trainer and hand control to its run loop.
    trainer = WMTrainer()
    trainer.run()
