from collections import OrderedDict
import math

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import gym

from .bc_agent import BCAgent
from .expert_dataset import ExpertDataset
from networks import Actor
from utils.info_dict import Info
from utils.logger import logger
from utils.mpi import mpi_average
from utils.pytorch import (
    optimizer_cuda,
    count_parameters,
    compute_gradient_norm,
    compute_weight_norm,
    sync_networks,
    sync_grads,
    to_tensor,
    sample_from_dataloader,
)
from utils.general import cat_dict_numpy, cat_dict_tensor


class MTBCAgent(BCAgent):
    """Multi-task behavior-cloning (BC) agent.

    Training data comes from several demonstration datasets listed in
    ``config.demo_path`` and separated by ``#``. The dataset whose path equals
    ``config.target_demo_path`` is the target task. Each training batch holds
    ``floor(batch_size * (1 - mt_balance))`` target samples plus
    ``floor(batch_size * mt_balance / (num_tasks - 1))`` samples from every
    other task. One call to :meth:`train` runs until the target task's
    transition loader is exhausted.

    Only demo-conditioned policies (``config.demo_conditioned_policy``) are
    currently supported.
    """

    def __init__(self, config, ob_space, ac_space, env_ob_space, layout):
        # HACK: for now only demo-conditioned policies are supported; this
        # should be refactored to cover all multi-task BC policies.
        assert config.demo_conditioned_policy

        super().__init__(config, ob_space, ac_space, env_ob_space, layout)

        if config.is_train:
            self._setup_mt_data(config.demo_path.split("#"))

    def _mt_batch_sizes(self, num_tasks):
        """Return ``(task_batch_size, mt_batch_size)`` for ``num_tasks`` datasets.

        ``task_batch_size`` is the per-batch share of the target task and
        ``mt_batch_size`` is the share of each of the ``num_tasks - 1`` other
        tasks. With a single task there are no other tasks, so
        ``mt_batch_size`` is 0 rather than a division by zero.
        """
        batch_size = self._config.batch_size
        mt_balance = self._config.mt_balance
        task_batch_size = math.floor(batch_size * (1 - mt_balance))
        num_other_tasks = num_tasks - 1
        if num_other_tasks <= 0:
            return task_batch_size, 0
        mt_batch_size = math.floor(batch_size * mt_balance // num_other_tasks)
        return task_batch_size, mt_batch_size

    def _setup_mt_data(self, paths):
        """Build one transition loader and one trajectory loader per task.

        Populates ``_data_dataset``, ``_data_loader``, ``_data_iter``,
        ``_traj_loader`` and ``_traj_iter``. All of these lists are indexed in
        ``paths`` order. Also sets ``_target_task_idx``.

        Args:
            paths: list of demo dataset paths, one per task.

        Raises:
            ValueError: if ``config.target_demo_path`` is not one of ``paths``.
        """
        target_path = self._config.target_demo_path
        target_indices = [i for i, p in enumerate(paths) if p == target_path]
        if not target_indices:
            # Without a target task, _load_mt_transitions has no loader whose
            # exhaustion ends the epoch, so fail at construction time.
            raise ValueError(
                "target_demo_path %r is not one of the demo paths %r"
                % (target_path, paths)
            )
        # If the target path is listed more than once, the last occurrence
        # is the one whose exhaustion ends an epoch.
        self._target_task_idx = target_indices[-1]

        task_batch_size, mt_batch_size = self._mt_batch_sizes(len(paths))
        batch_sizes = [
            task_batch_size if p == target_path else mt_batch_size for p in paths
        ]

        self._data_dataset = []
        self._data_loader, self._data_iter = [], []
        self._traj_loader, self._traj_iter = [], []

        for p, batch_size in zip(paths, batch_sizes):
            data_dataset, data_loader, data_iter = self._load_dataset(
                p, batch_size=batch_size
            )
            self._data_dataset.append(data_dataset)
            self._data_loader.append(data_loader)
            self._data_iter.append(data_iter)

        # Every trajectory loader is built with the ``sampled_few_demo_index``
        # of the first transition dataset that has one (None if none does).
        sampled_few_demo_index = next(
            (
                d.sampled_few_demo_index
                for d in self._data_dataset
                if d.sampled_few_demo_index is not None
            ),
            None,
        )

        for p, batch_size in zip(paths, batch_sizes):
            _, traj_loader, traj_iter = self._load_traj(
                p,
                sampled_few_demo_index=sampled_few_demo_index,
                batch_size=batch_size,
            )
            self._traj_loader.append(traj_loader)
            self._traj_iter.append(traj_iter)

    def _log_creation(self):
        """Log agent creation and actor size (chef process only)."""
        if self._config.is_chef:
            logger.info("Creating a Multi-task BC agent")
            logger.info("The actor has %d parameters", count_parameters(self._actor))

    def _load_mt_transitions(self):
        """Yield multi-task training batches until the target loader runs out.

        Each yielded dict merges one batch from every task's transition
        loader, in task order:
        - nested-dict values are merged with ``cat_dict_numpy``;
        - all other values are merged with ``np.concatenate`` on axis 0.

        When the policy is demo-conditioned, ``transitions["demo"]`` holds the
        observations returned by ``_sample_expert_traj`` for every task,
        merged with ``cat_dict_tensor``.

        A non-target task whose iterator is exhausted is restarted
        immediately. When the target task's iterator is exhausted, it is
        reset for the next epoch and the generator stops. Any batches already
        drawn from other tasks in that round are discarded.

        Raises:
            RuntimeError: if a non-target task's loader yields no batches even
                after being restarted.
        """
        while True:
            transitions = {}
            for i, data_iter in enumerate(self._data_iter):
                try:
                    task_batch = next(data_iter)
                except StopIteration:
                    # Reset so that the next pass starts from a fresh iterator.
                    self._data_iter[i] = iter(self._data_loader[i])
                    if i == self._target_task_idx:
                        return
                    try:
                        task_batch = next(self._data_iter[i])
                    except StopIteration:
                        # Letting StopIteration escape a generator turns it
                        # into an opaque RuntimeError (PEP 479); say why.
                        raise RuntimeError(
                            "Data loader for task %d yields no batches" % i
                        ) from None

                if not transitions:
                    transitions = task_batch
                    continue
                for k, v in transitions.items():
                    if isinstance(v, dict):
                        transitions[k] = cat_dict_numpy([v, task_batch[k]])
                    else:
                        transitions[k] = np.concatenate(
                            [v, task_batch[k]], axis=0
                        )

            if self._config.demo_conditioned_policy:
                demos = {}
                for demo_loader in self._traj_loader:
                    demo_o, _ = self._sample_expert_traj(demo_loader)
                    demos = demo_o if not demos else cat_dict_tensor([demos, demo_o])
                transitions["demo"] = demos

            yield transitions

    def train(self, step=0):
        """Run one epoch of multi-task BC training.

        Updates the network on every batch from ``_load_mt_transitions``,
        increments ``_epoch``, steps the actor LR scheduler, and records the
        actor's gradient and weight norms.

        Args:
            step: current global step, forwarded to ``_update_network``.

        Returns:
            The scalar entries of the accumulated training info.
        """
        train_info = Info()

        for transitions in self._load_mt_transitions():
            train_info.add(self._update_network(transitions, train=True, step=step))
        self._epoch += 1
        self._actor_lr_scheduler.step()

        train_info.add(
            {
                "actor_grad_norm": compute_gradient_norm(self._actor),
                "actor_weight_norm": compute_weight_norm(self._actor),
            }
        )
        train_info = train_info.get_dict(only_scalar=True)
        logger.info("BC loss %f", train_info["actor_loss"])
        return train_info
