import torch
import numpy as np

from addict import Dict
import os

import matplotlib.pyplot as plt

from train_wm_base import WMTrainer as _WMTrainer
from mawm.envs.sc2.sc2_trajectories import get_concat_loader, SC2Trajectories
from mawm.losses import SC2ReconLoss, PositionalEmbeddingCosineLoss, SC2HECSReconScore, SC2UnitTypeAccuracy, \
    SC2FriendlyMarkerF1Score, SC2TerrainReconScore, SC2FriendlyMarkerConfusionMatrix, SC2UnitTypeConfusionMatrix
from mawm.visuals import visualize_sc2_recons
from mawm.utils import RoundRobinLoader, drop_agents


class WMTrainer(_WMTrainer):
    def _build_test_loaders(self):
        """Build one test loader per scenario listed under 'test_data/paths'.

        When the config has no 'test_data/paths' entry, ``all_test_loaders`` is
        set to an empty dict and ``test_loader`` to ``None``. Otherwise every
        scenario gets a concatenated loader over its trajectory files, and
        ``test_loader`` wraps all of them in a ``RoundRobinLoader``.
        """
        scenario_paths = self.get('test_data/paths')
        if scenario_paths is None:
            # No test data configured.
            self.all_test_loaders = {}
            self.test_loader = None
            return
        loaders_by_scenario = {}
        for scenario, path_list in scenario_paths.items():
            datasets = [SC2Trajectories(path=trajectory_path, **self.get('data/kwargs'))
                        for trajectory_path in path_list]
            loader_kwargs = self.get('data/loader/kwargs', {})
            loaders_by_scenario[scenario] = get_concat_loader(*datasets, **loader_kwargs)
        self.all_test_loaders = loaders_by_scenario
        self.test_loader = RoundRobinLoader(*loaders_by_scenario.values())

    def _build_loaders(self):
        """Build per-scenario training and validation loaders, then the test loaders.

        Scenario paths are read from 'data/paths' (which must exist). Each
        trajectory file is opened twice: once with a negative split value for
        training and once with the positive value for validation; the split
        magnitude comes from 'data/split' (default 0.8). Per-scenario loaders
        are kept in ``all_train_loaders`` / ``all_validate_loaders`` and are
        also combined into ``train_loader`` / ``validate_loader`` via
        ``RoundRobinLoader``.
        """
        scenario_paths = self.get('data/paths', ensure_exists=True)

        def make_datasets(path_list, sign):
            # `sign` selects which side of the split SC2Trajectories is asked for.
            return [SC2Trajectories(path=trajectory_path,
                                    split=sign * abs(self.get('data/split', 0.8)),
                                    **self.get('data/kwargs'))
                    for trajectory_path in path_list]

        train_loaders = {}
        validate_loaders = {}
        for scenario, path_list in scenario_paths.items():
            train_sets = make_datasets(path_list, -1)
            validate_sets = make_datasets(path_list, 1)
            train_loaders[scenario] = get_concat_loader(
                *train_sets, **self.get('data/loader/kwargs', {}))
            validate_loaders[scenario] = get_concat_loader(
                *validate_sets, **self.get('data/loader/kwargs', {}))
        # Keep the per-scenario loaders and expose round-robin views over them.
        self.all_train_loaders = train_loaders
        self.all_validate_loaders = validate_loaders
        self.train_loader = RoundRobinLoader(*train_loaders.values())
        self.validate_loader = RoundRobinLoader(*validate_loaders.values())
        # Test loaders are optional; the helper copes with a missing config entry.
        self._build_test_loaders()

    def _build_losses_and_metrics(self):
        """Instantiate the reconstruction loss, the positional loss and the recon metrics.

        The reconstruction loss takes its keyword arguments from
        'recon_loss/kwargs' (default: none). Metrics are stored by name in
        ``self.recon_metrics``.
        """
        recon_loss_kwargs = self.get('recon_loss/kwargs', {})
        self.recon_loss = SC2ReconLoss(**recon_loss_kwargs)
        self.positional_loss = PositionalEmbeddingCosineLoss()
        metric_table = {
            'hecs_mse': SC2HECSReconScore(),
            'unit_type_accuracy': SC2UnitTypeAccuracy(),
            'friendly_marker_f1_score': SC2FriendlyMarkerF1Score(),
            'terrain_mse': SC2TerrainReconScore(),
        }
        self.recon_metrics = Dict(metric_table)

    def log_reconstructions(self):
        """Render start- and end-of-sequence reconstructions and log them as images.

        For every sample in the fetched batch, the first and last reconstructed
        time steps of agent 0 are visualised next to their target states. Each
        image is logged through both ``log_image`` and ``wandb_log_image``. The
        wandb step counter is paused while the images are written.
        """
        samples = self.fetch_sample_reconstructions()

        def render(recons, states):
            # Figure -> channel-first array; close the figure once converted.
            figure = visualize_sc2_recons(recons, states)
            image = self.figure2array(figure, channel_first=True)
            plt.close(figure)
            return image

        # Recons and targets are laid out as NTA(17)rΘ. The slices below pick one
        # sample, one time step and agent 0.
        # NOTE(review): targets are read at time indices 1 and -2 while recons
        # use 0 and -1; presumably target_states carries a time offset — confirm.
        self.wandb_pause_step_counter()
        num_samples = samples.recons.shape[0]
        for sample_idx in range(num_samples):
            sample_slice = slice(sample_idx, sample_idx + 1)
            # Start of the sequence
            start_recons = SC2Trajectories.recon_normalizer(
                samples.recons[sample_slice, 0:1, 0:1], return_components=True)
            start_states = SC2Trajectories.split_state_tensor(
                samples.target_states[sample_idx, 1:2, 0:1])
            # End of the sequence
            stop_recons = SC2Trajectories.recon_normalizer(
                samples.recons[sample_slice, -1:, 0:1], return_components=True)
            stop_states = SC2Trajectories.split_state_tensor(
                samples.target_states[sample_idx, -2:-1, 0:1])
            start_image = render(start_recons, start_states)
            stop_image = render(stop_recons, stop_states)
            # Tensorboard
            self.log_image(f'recons/seq_start_{sample_idx}', start_image)
            self.log_image(f'recons/seq_stop_{sample_idx}', stop_image)
            # wandb
            self.wandb_log_image(f'recons_seq_start_{sample_idx}', start_image)
            self.wandb_log_image(f'recons_seq_stop_{sample_idx}', stop_image)
        self.wandb_resume_step_counter()

    def log_rollout(self):
        """Placeholder for rollout logging; currently does nothing and returns None."""
        # TODO Implement this
        return None

    def light_test_suite(self):
        """Run the quick evaluation: one test epoch plus one agent-dropout robustness test.

        The steps run in a fixed order: load the checkpoint (if one is
        requested), set up printing to 'test_suite_results.txt', add the
        confusion-matrix metrics, then run both tests. Results are only
        printed; nothing is returned.
        """
        # Load up the checkpoint
        self.load_checkpoint_if_required()
        # presumably mirrors subsequent self.print output into this file — confirm
        self.print_to_file('test_suite_results.txt')
        # Add the confusion-matrix metrics on top of the training-time metrics
        self._build_evaluation_metrics()
        # Test epoch with OOD data
        self.test_epoch()
        # Test robustness to agent dropout
        self.test_robustness_to_dropped_agents()

    def heavy_test_suite(self):
        """Run the full evaluation and save the results to the log directory.

        Loads the checkpoint (if one is requested), adds the confusion-matrix
        metrics, runs a test epoch on every test scenario and a sweep over
        agent keep-probabilities, and writes everything with ``torch.save`` to
        'heavy_test_suite_results.pkl' inside ``self.log_directory``.

        The saved ``Dict`` has three entries: ``ckpt_info`` (best validation
        loss, step, epoch and checkpoint path; values are ``None`` when no
        checkpoint was loaded), ``ood_tests`` and ``robustness_tests``.
        """
        self.print("Loading Ckpt...")
        ckpt_info = self.load_checkpoint_if_required()
        # Build extra metrics for evaluating confusion matrices
        self._build_evaluation_metrics()
        # Make container for results
        results = Dict()
        if ckpt_info is None:
            # Use self.print like every other message in this method, so the
            # warning is not the only line that bypasses the experiment's printer.
            self.print("Warning: no checkpoint loaded.")
            ckpt_info = {}
        results.ckpt_info.best_validation_loss = ckpt_info.get('best_validation_loss')
        results.ckpt_info.step = ckpt_info.get('step')
        results.ckpt_info.epoch = ckpt_info.get('epoch')
        results.ckpt_info.path = self.get_arg('load', None)
        # Test epoch with OOD data
        self.print("Testing with OOD Data...")
        results.ood_tests = self.test_epoch()
        # Sweep over dropped out agents
        self.print("Initiating Robustness sweep...")
        results.robustness_tests = self.sweep_robustness_to_dropped_agents()
        # Write out results
        result_file = os.path.join(self.log_directory, 'heavy_test_suite_results.pkl')
        self.print(f"Writing results to file: {result_file}")
        torch.save(results, result_file)
        self.print("Done!")

    def _build_evaluation_metrics(self):
        """Extend ``self.recon_metrics`` with the two confusion-matrix metrics.

        The confusion-matrix entries are placed ahead of the existing metrics.
        The method may be called more than once: metrics that are already
        registered, including confusion matrices from an earlier call, are kept
        as they are.
        """
        extended = {
            'friendly_marker_confmat': SC2FriendlyMarkerConfusionMatrix(),
            'unit_type_confmat': SC2UnitTypeConfusionMatrix(),
        }
        # Merge with update() rather than passing **self.recon_metrics next to
        # explicit keywords: on a second call the confmat keys already exist and
        # the duplicate keyword arguments would raise TypeError.
        extended.update(self.recon_metrics)
        self.recon_metrics = Dict(extended)

    def test_epoch(self, verbose=True):
        """Run ``_test_epoch`` on every scenario in ``self.all_test_loaders``.

        Returns a ``Dict`` mapping scenario name to that scenario's result. It
        is empty when no test loaders were built.
        """
        per_scenario = Dict()
        for scenario in self.all_test_loaders:
            scenario_loader = self.all_test_loaders[scenario]
            per_scenario[scenario] = self._test_epoch(scenario, scenario_loader, verbose=verbose)
        return per_scenario

    def _test_epoch(self, scenario, loader, verbose=True):
        """Evaluate the model on the test loader of a single scenario.

        Args:
            scenario: Scenario name; used in progress tags and printed headers.
            loader: Iterable yielding ``(actions, positions, goals, states,
                target_states)`` batches.
            verbose: When true, print the averaged metrics and losses through
                ``self.print``.

        Returns:
            ``None`` when no test loader was built (``self.test_loader`` is
            ``None``). Otherwise a ``Dict`` with ``loss`` (mean reconstruction
            loss), ``metrics`` (per-metric means) and ``num_iters``. If the
            loader yields no batches, ``loss`` and every metric are NaN and
            ``num_iters`` is 0.
        """
        # Make sure the test loader exists
        if self.test_loader is None:
            return
        loss = 0
        # Stays None unless the model emits positional embeddings at least once.
        positional_loss = None
        metrics = {metric: 0 for metric in self.recon_metrics}
        iters = 0
        self.model.eval()
        # `goals` is part of every batch but is not used for evaluation.
        for actions, positions, _goals, states, target_states in self.progress(loader, desc=f'Test_{scenario}',
                                                                               tag=f'test_{scenario}'):
            # Ship tensors to device
            actions = actions.to(self.device)
            positions = positions.to(self.device)
            states = states.to(self.device)
            target_states = target_states.to(self.device)
            with torch.no_grad():
                # Compute reconstructions
                model_outputs = self.model(actions, positions, states)
                # Compute loss and accuracy
                iter_loss = self.recon_loss(model_outputs.recons, target_states)
                if model_outputs.predicted_position_embeddings is not None:
                    iter_positional_loss = self.positional_loss(model_outputs.predicted_position_embeddings,
                                                                model_outputs.target_position_embeddings)
                else:
                    iter_positional_loss = None
                iter_metrics = {metric: self.recon_metrics[metric](model_outputs.recons, target_states)
                                for metric in self.recon_metrics}
                # Add em up
                loss += iter_loss.item()
                if iter_positional_loss is not None:
                    positional_loss = 0 if positional_loss is None else positional_loss
                    positional_loss += iter_positional_loss.item()
                for metric in self.recon_metrics:
                    # Single-element results accumulate as Python numbers,
                    # multi-element results as numpy arrays.
                    metrics[metric] += (iter_metrics[metric].item()
                                        if iter_metrics[metric].numel() == 1 else
                                        iter_metrics[metric].detach().cpu().numpy())
                iters += 1
            self.log_progress(f'test_{scenario}', rloss=iter_loss, ploss=iter_positional_loss)
        if iters == 0:
            # There is nothing to average; report NaNs instead of dividing by zero.
            self.print(f"Warning: test loader for {scenario} yielded no batches.")
            return Dict(loss=float('nan'),
                        metrics={metric: float('nan') for metric in metrics},
                        num_iters=0)
        loss /= iters
        metrics = {metric: val / iters for metric, val in metrics.items()}
        if positional_loss is not None:
            positional_loss /= iters
        results = Dict(loss=loss, metrics=metrics, num_iters=iters)
        if verbose:
            # Print out
            self.print(f"------ Test Epoch with {scenario} Data ------")
            self.print("------ Metrics ------")
            for metric_name in metrics:
                self.print(f"{metric_name} : {metrics[metric_name]}")
            self.print("------ Losses ------")
            self.print(f"Recon Loss      : {loss}")
            # The model may never produce positional embeddings; don't divide None.
            positional_report = 'n/a' if positional_loss is None else positional_loss
            self.print(f"Positional Loss : {positional_report}")
        return results

    def sweep_robustness_to_dropped_agents(self, verbose=True):
        """Run the agent-dropout robustness test over a range of keep-probabilities.

        Nine evenly spaced keep-probabilities from 0.2 to 1 are evaluated.
        Returns a ``Dict`` keyed by keep-probability, each value being the
        result of ``test_robustness_to_dropped_agents`` for that setting.
        """
        sweep_results = Dict()
        keep_probas = np.linspace(0.2, 1, 9)
        for keep_proba in self.progress(keep_probas, desc='SweepRobustness', tag='rob_sweep'):
            outcome = self.test_robustness_to_dropped_agents(keep_proba, verbose=verbose)
            sweep_results[keep_proba] = outcome
            self.log_progress('rob_sweep', rloss=outcome.loss)
        return sweep_results

    def test_robustness_to_dropped_agents(self, agent_keep_proba=0.8, verbose=True):
        """Evaluate reconstruction quality on validation data with some agents dropped.

        For each validation batch, only a fraction of the agents is kept in the
        model inputs (at least one). The model, run with its positional RNN
        disabled, still receives the positions of all agents as
        ``target_positions`` and is scored against the full target states.

        Args:
            agent_keep_proba: Fraction of agents to keep in the inputs.
            verbose: When true, print the averaged metrics and loss through
                ``self.print``.

        Returns:
            A ``Dict`` with ``loss`` (mean reconstruction loss), ``metrics``
            (per-metric means) and ``num_iters``. If the validation loader
            yields no batches, ``loss`` and every metric are NaN and
            ``num_iters`` is 0.
        """
        loss = 0
        metrics = {metric: 0 for metric in self.recon_metrics}
        iters = 0
        self.model.eval()
        # `goals` is part of every batch but is not used for evaluation.
        for actions, positions, _goals, states, target_states in self.progress(self.validate_loader,
                                                                               desc=f'Robustness_{agent_keep_proba}',
                                                                               tag=f'rob_{agent_keep_proba}'):
            # Ship tensors to device
            actions = actions.to(self.device)
            # Keep a handle on the un-dropped positions; they become the targets.
            target_positions = positions = positions.to(self.device)
            states = states.to(self.device)
            target_states = target_states.to(self.device)
            # Find number of agents to keep (axis 2 is taken to be the agent
            # axis — TODO confirm), but never fewer than one.
            keep_num_agents = round(states.shape[2] * agent_keep_proba)
            keep_num_agents = max(keep_num_agents, 1)
            # Drop them agents
            actions, positions, states = drop_agents(int(keep_num_agents), actions, positions, states)
            with torch.no_grad():
                # Compute reconstructions, but disable positional RNN
                with self.model.disable_positional_rnn():
                    model_outputs = self.model(actions, positions, states, target_positions=target_positions)
                # Compute loss and accuracy
                iter_loss = self.recon_loss(model_outputs.recons, target_states)
                iter_metrics = {metric: self.recon_metrics[metric](model_outputs.recons, target_states)
                                for metric in self.recon_metrics}
                # Add em up
                loss += iter_loss.item()
                for metric in self.recon_metrics:
                    # Single-element results accumulate as Python numbers,
                    # multi-element results as numpy arrays.
                    metrics[metric] += (iter_metrics[metric].item()
                                        if iter_metrics[metric].numel() == 1 else
                                        iter_metrics[metric].detach().cpu().numpy())
                iters += 1
            self.log_progress(f'rob_{agent_keep_proba}', rloss=iter_loss)
        if iters == 0:
            # There is nothing to average; report NaNs instead of dividing by zero.
            self.print(f"Warning: validation loader yielded no batches "
                       f"(agent dropout {agent_keep_proba}).")
            return Dict(loss=float('nan'),
                        metrics={metric: float('nan') for metric in metrics},
                        num_iters=0)
        loss /= iters
        metrics = {metric: val / iters for metric, val in metrics.items()}
        results = Dict(loss=loss, metrics=metrics, num_iters=iters)
        if verbose:
            self.print(f"------ Test with Agent Dropout ({agent_keep_proba}) ------")
            # Print out
            self.print("------ Metrics ------")
            for metric_name in metrics:
                self.print(f"{metric_name} : {metrics[metric_name]}")
            self.print("------ Losses ------")
            self.print(f"Recon Loss : {loss}")
        return results


if __name__ == '__main__':
    # Script entry point: build the trainer and hand control to its run loop.
    trainer = WMTrainer()
    trainer.run()
