"""
    reproduce SQIL algorithm using pytorch,
    check tensorflow version reference: https://github.com/rddy/sqil-atari/blob/master/sqil.py
    
    References
    ----------
    [1] SQIL: https://siddharth.io/files/sqil.pdf
    [2] Soft Q-learning (SQL): Tuomas Haarnoja, Haoran Tang, Pieter Abbeel, and Sergey Levine,
        "Reinforcement Learning with Deep Energy-Based Policies," International
        Conference on Machine Learning, 2017. https://arxiv.org/abs/1702.08165

    todo refactor (12/4)
"""
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gym.spaces
from torch.optim.lr_scheduler import StepLR
from algorithms.expert_dataset import ExpertDataset

from .base_il_agent import BaseILAgent
from .sac_agent import SACAgent
from .dqn_agent import DQNAgent
from .dataset import ReplayBuffer, RandomSampler, ReplayBufferPerStep
from networks import Actor, Critic
from utils.info_dict import Info
from utils.logger import logger
from utils.mpi import mpi_average, mpi_sum
from utils.gym_env import spaces_to_shapes
from utils.pytorch import (
    optimizer_cuda,
    count_parameters,
    compute_gradient_norm,
    compute_weight_norm,
    sync_networks,
    sync_grads,
    to_tensor,
)
import copy

class SQILAgent(BaseILAgent):
    """Soft Q Imitation Learning (SQIL) agent.

    The agent's own transitions are stored with reward 0 (see
    ``store_episode``). Demonstration batches are loaded with ``is_sqil=True``.
    Every update concatenates a replay-buffer sample with one batch of
    demonstrations, and optionally with one batch of demonstrations from other
    tasks (``config.other_task_data_proportion > 0``). The update is then
    delegated to ``self._rl_agent._update_network``.
    """

    def __init__(self, config, ob_space, ac_space, env_ob_space, layout):
        # NOTE(review): this mutates the caller's ``config`` in place, so building
        # a second agent from the same config halves the batch size again. It is
        # kept because the base class presumably reads the halved value when it
        # builds the expert data loader -- confirm before changing.
        config.batch_size = int(config.batch_size // 2)
        # Split the (halved) batch between replay-buffer samples and other-task
        # demos. The original intent is roughly half expert and half non-expert
        # data per update.
        self.other_task_data_proportion = config.other_task_data_proportion
        self._policy_batch_size = int(
            config.batch_size // (1 + self.other_task_data_proportion)
        )
        super().__init__(config, ob_space, ac_space, env_ob_space, layout)

        assert self._config.is_relabel_rew is False

        if self.other_task_data_proportion > 0:
            self._init_other_task_data(config, ac_space)

        self._log_creation()

    def _init_other_task_data(self, config, ac_space):
        """Build the dataset, loader and iterator for other-task demonstrations.

        Raises:
            ValueError: if the derived other-task batch size is below 1, which
                DataLoader would otherwise reject with a less specific error.
        """
        other_batch_size = int(
            np.floor(self._policy_batch_size * self.other_task_data_proportion)
        )
        if self._config.is_chef:
            logger.info(
                f"SQIL batch sizes: policy={self._policy_batch_size}, "
                f"other-task={other_batch_size}"
            )
        if other_batch_size < 1:
            raise ValueError(
                "other_task_data_proportion="
                f"{self.other_task_data_proportion} with policy batch size "
                f"{self._policy_batch_size} yields an other-task batch size of "
                f"{other_batch_size}; increase batch_size or the proportion"
            )
        assert config.other_demo_path is not None

        self._other_dataset = ExpertDataset(
            config.other_demo_path,
            config.demo_subsample_interval,
            ac_space,
            use_low_level=config.demo_low_level,
            sample_range_start=config.demo_sample_range_start,
            sample_range_end=config.demo_sample_range_end,
            num_task=config.num_task,
            target_taskID=None,
            with_taskID=config.with_taskID,
            is_sqil=True,
        )
        self._other_data_loader = torch.utils.data.DataLoader(
            self._other_dataset,
            batch_size=other_batch_size,
            shuffle=True,
            drop_last=True,
        )
        self._other_data_iter = iter(self._other_data_loader)

    def _log_creation(self):
        if self._config.is_chef:
            logger.info("Creating a SQIL agent")

    def store_episode(self, rollouts, step=0):
        """Store an agent rollout with every reward replaced by 0 (SQIL labeling).

        NOTE(review): this overwrites ``rollouts["rew"]`` on the caller's object.
        """
        rollouts["rew"] = [0.0] * len(rollouts["rew"])
        super().store_episode(rollouts, step)

    @staticmethod
    def _next_batch(data_iter, data_loader):
        """Return ``(batch, iterator)``, restarting from ``data_loader`` when
        ``data_iter`` is exhausted. The caller must keep the returned iterator."""
        try:
            return next(data_iter), data_iter
        except StopIteration:
            data_iter = iter(data_loader)
            return next(data_iter), data_iter

    @staticmethod
    def _to_numpy(value):
        """Convert a torch tensor (on any device) or array-like to a numpy array."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    @classmethod
    def _merge_batches(cls, policy_batch, demo_batches):
        """Concatenate ``demo_batches`` onto ``policy_batch`` along axis 0.

        Keys follow ``policy_batch``. A dict-valued entry (e.g. a dict
        observation) is merged per sub-key. ``np.concatenate`` always allocates
        a new array, so the result never aliases the demo loaders' tensors.
        That is why no defensive copy of the demo batches is needed.
        """
        merged = {}
        for key, value in policy_batch.items():
            if isinstance(value, dict):
                merged[key] = {
                    sub_key: np.concatenate(
                        [value[sub_key]]
                        + [cls._to_numpy(db[key][sub_key]) for db in demo_batches]
                    )
                    for sub_key in value.keys()
                }
            else:
                merged[key] = np.concatenate(
                    [value] + [cls._to_numpy(db[key]) for db in demo_batches]
                )
        return merged

    def train(self, step=0):
        """Run ``config.num_actor_updates`` SQIL updates.

        Returns:
            The MPI-averaged scalar training info collected from
            ``self._rl_agent._update_network``.
        """
        train_info = Info()

        self._num_updates = 1
        for _ in range(self._config.num_actor_updates):
            policy_batch = self._buffer.sample(self._policy_batch_size)

            expert_data, self._data_iter = self._next_batch(
                self._data_iter, self._data_loader
            )
            demo_batches = [expert_data]
            if self.other_task_data_proportion > 0:
                other_task_data, self._other_data_iter = self._next_batch(
                    self._other_data_iter, self._other_data_loader
                )
                demo_batches.append(other_task_data)

            transitions = self._merge_batches(policy_batch, demo_batches)
            train_info.add(self._rl_agent._update_network(transitions))

        return mpi_average(train_info.get_dict(only_scalar=True))


