from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torch.distributions
from torch.optim.lr_scheduler import StepLR

from .base_agent import BaseAgent
from .ppo_agent import PPOAgent
from .dataset import ReplayBuffer, RandomSampler
from .expert_dataset import ExpertDataset
from networks.discriminator import AIRLDiscriminator
from utils.info_dict import Info
from utils.logger import logger
from utils.mpi import mpi_average
from utils.pytorch import (
    optimizer_cuda,
    count_parameters,
    sync_networks,
    sync_grads,
    to_tensor,
)
import copy
import wandb

"""code based on https://github.com/toshikwa/gail-airl-ppo.pytorch/blob/master/gail_airl_ppo/algo/airl.py"""


class AIRLAgent(BaseAgent):
    def __init__(self, config, ob_space, ac_space, env_ob_space):
        """Assemble the AIRL agent.

        Creates the wrapped PPO agent, the main discriminator (plus one extra
        discriminator per ``range(config.num_task - 1)`` when ``config.D_i``
        is set), their Adam optimizers and StepLR schedulers, the expert data
        loaders and the policy rollout buffer.

        Args:
            config: run configuration; every hyper-parameter below is read
                from it.
            ob_space: observation space handed to the networks.
            ac_space: action space handed to the networks and datasets.
            env_ob_space: forwarded unchanged to ``PPOAgent``.

        Raises:
            ValueError: if ``config.gail_rl_algo`` is not ``"ppo"``.
        """
        super().__init__(config, ob_space, ac_space)

        self._ob_space = ob_space
        self._ac_space = ac_space

        # Only PPO is wired up. Fail with a clear message instead of the
        # AttributeError that a missing ``_rl_agent`` would trigger below.
        if self._config.gail_rl_algo != "ppo":
            raise ValueError(
                "AIRLAgent only supports gail_rl_algo='ppo', got %r"
                % (self._config.gail_rl_algo,)
            )
        self._rl_agent = PPOAgent(config, ob_space, ac_space, env_ob_space)
        self._rl_agent.set_reward_function(self._predict_reward)

        # build up networks
        self._discriminator = AIRLDiscriminator(config, ob_space, ac_space)
        self._discriminator_loss = nn.BCEWithLogitsLoss()
        if self._config.D_i:
            self._D_i_list = []
            self.dis_loss_list = []
            for _ in range(self._config.num_task - 1):
                # NOTE(review): unlike the main discriminator, the per-task
                # ones honour ``gail_no_action`` -- confirm this is intended.
                self._D_i_list.append(
                    AIRLDiscriminator(
                        config,
                        ob_space,
                        ac_space if not config.gail_no_action else None,
                    )
                )
                self.dis_loss_list.append(nn.BCEWithLogitsLoss())

        self._network_cuda(config.device)

        # The learning rate is halved every fifth of the planned number of
        # rollouts (max_global_step // rollout_length).
        lr_step_size = (
            self._config.max_global_step // self._config.rollout_length // 5
        )

        # build optimizers
        self._discriminator_optim = optim.Adam(
            self._discriminator.parameters(), lr=config.discriminator_lr
        )

        # build learning rate scheduler
        self._discriminator_lr_scheduler = StepLR(
            self._discriminator_optim, step_size=lr_step_size, gamma=0.5,
        )
        if self._config.D_i:
            self._discriminator_optim_list = []
            self._discriminator_lr_scheduler_list = []
            for i in range(self._config.num_task - 1):
                extra_optim = optim.Adam(
                    self._D_i_list[i].parameters(), lr=config.discriminator_lr
                )
                self._discriminator_optim_list.append(extra_optim)
                self._discriminator_lr_scheduler_list.append(
                    StepLR(extra_optim, step_size=lr_step_size, gamma=0.5)
                )

        # expert dataset
        if config.is_train:

            def load_dataset(path, target_taskID):
                # Every call overwrites ``self._dataset``; in D_i mode it
                # therefore ends up pointing at the dataset of the LAST demo
                # path, which is what update_normalizer() iterates over.
                self._dataset = ExpertDataset(
                    path,
                    config.demo_subsample_interval,
                    ac_space,
                    use_low_level=config.demo_low_level,
                    sample_range_start=config.demo_sample_range_start,
                    sample_range_end=config.demo_sample_range_end,
                    num_task=config.num_task,
                    target_taskID=target_taskID,
                    with_taskID=config.with_taskID,
                    num_target_demos=config.num_target_demos,
                    target_demo_path=config.target_demo_path,
                )
                loader = torch.utils.data.DataLoader(
                    self._dataset,
                    batch_size=self._config.batch_size,
                    shuffle=True,
                    drop_last=True,
                )
                return loader, iter(loader)

            if self._config.D_i:
                # One loader/iterator per "#"-separated demo path.
                self._data_loader, self._data_iter = [], []
                self._traj_loader, self._traj_iter = [], []
                for demo_path in config.demo_path.split("#"):
                    loader, batches = load_dataset(demo_path, None)
                    self._data_loader.append(loader)
                    self._data_iter.append(batches)
            else:
                self._data_loader, self._data_iter = load_dataset(
                    config.demo_path, config.target_taskID
                )

        # load other expert demos
        if config.pretrain_discriminator:
            self._other_dataset = ExpertDataset(
                config.other_demo_path,
                config.demo_subsample_interval,
                ac_space,
                use_low_level=config.demo_low_level,
                sample_range_start=config.demo_sample_range_start,
                sample_range_end=config.demo_sample_range_end,
                num_task=config.num_task,
                target_taskID=None,
            )
            self._other_data_loader = torch.utils.data.DataLoader(
                self._other_dataset,
                batch_size=self._config.batch_size,
                shuffle=True,
                drop_last=True,
            )
            self._other_data_iter = iter(self._other_data_loader)

        # policy dataset
        sampler = RandomSampler(image_crop_size=config.encoder_image_size)
        keys = [
            "ob", "ob_next", "ac", "done", "rew", "ret", "adv",
            "ac_before_activation", "log_pis",
        ]
        if self._config.with_taskID:
            keys.append("id")

        self._buffer = ReplayBuffer(keys, config.rollout_length, sampler.sample_func,)

        self._rl_agent.set_buffer(self._buffer)

        # update observation normalizer with dataset
        self.update_normalizer()

        self._log_creation()

    def pretrain_dis(self):
        """Pre-train the main discriminator before RL training starts.

        The number of iterations is ``len(self._other_dataset) // batch_size``.
        When ``config.pretrain_discriminator == 1`` a first pass splits every
        "other" batch by task id and trains the discriminator with the
        current task's samples in the expert slot and the remaining samples
        in the policy slot.  Afterwards the discriminator is always trained
        with "other" demos in the policy slot and the regular expert demos in
        the expert slot.

        In ``D_i`` mode the expert loaders are stored as per-task lists, so
        the target task's loader
        (``config.target_task_index_in_demo_path``) is used here.

        NOTE(review): the wandb ``step`` restarts at 0 for every task in the
        first pass and again for the second loop; wandb expects
        non-decreasing steps -- confirm this logging is intended.
        """
        num_iter = len(self._other_dataset) // self._config.batch_size

        def _next_other_batch():
            # Restart the loader transparently once it is exhausted.
            try:
                return next(self._other_data_iter)
            except StopIteration:
                self._other_data_iter = iter(self._other_data_loader)
                return next(self._other_data_iter)

        def _next_expert_batch():
            if self._config.D_i:
                # ``_data_iter`` / ``_data_loader`` are lists in D_i mode;
                # calling next() on the list itself would raise TypeError.
                idx = self._config.target_task_index_in_demo_path
                try:
                    return next(self._data_iter[idx])
                except StopIteration:
                    self._data_iter[idx] = iter(self._data_loader[idx])
                    return next(self._data_iter[idx])
            try:
                return next(self._data_iter)
            except StopIteration:
                self._data_iter = iter(self._data_loader)
                return next(self._data_iter)

        def _take(batch, mask, limit):
            # Select the masked rows of every (possibly nested) entry and
            # keep at most ``limit`` of them.
            picked = {}
            for key, value in batch.items():
                if isinstance(value, dict):
                    picked[key] = {
                        sub_key: sub_value[mask][:limit]
                        for sub_key, sub_value in value.items()
                    }
                else:
                    picked[key] = value[mask][:limit]
            return picked

        def _log_scalars(prefix, info, step):
            for key, value in info.items():
                if np.isscalar(value) or (
                    hasattr(value, "shape") and np.prod(value.shape) == 1
                ):
                    wandb.log({"%s/%s" % (prefix, key): value}, step=step)

        def _train_for_other_task():
            num_file = self._other_dataset.num_files
            one_hot_encode = self._other_dataset.one_hot_encode
            for i in range(num_file):
                target_taskID = np.array(one_hot_encode[i]).reshape(1, -1)
                task_id = torch.as_tensor(target_taskID)

                for s_ in range(num_iter * num_file):
                    other_expert_data = _next_other_batch()

                    current_data_indx = (other_expert_data["id"] == task_id).all(1)
                    other_data_index = ~current_data_indx

                    # Both sides are truncated to the smaller group so the
                    # discriminator sees balanced batches.
                    bs = min(
                        int(current_data_indx.sum()), int(other_data_index.sum())
                    )
                    if bs == 0:
                        # One side is empty: a BCE mean over zero samples is
                        # NaN and carries no learning signal, so skip it.
                        continue

                    target_data = _take(other_expert_data, current_data_indx, bs)
                    other_data = _take(other_expert_data, other_data_index, bs)

                    _train_info = self._update_discriminator(other_data, target_data)
                    _log_scalars("pretrain_other_task_dis", _train_info, s_)
            logger.info("Pretrain other tasks finished (first method)")

        if self._config.pretrain_discriminator == 1:
            _train_for_other_task()

        for s_ in range(num_iter):
            expert_data = _next_expert_batch()
            other_expert_data = _next_other_batch()

            # "Other" demos play the role of policy samples here.
            _train_info = self._update_discriminator(other_expert_data, expert_data)
            _log_scalars("pretrain_dis", _train_info, s_)

        logger.info("Pretrain discriminator finished")


    def _predict_reward(self, ob, done, log_pis, ob_next, ID=None):

        with torch.no_grad():
            ret = self._discriminator(ob, done, log_pis, ob_next, ID)
            eps = 1e-10
            s = torch.sigmoid(ret)
            if self._config.gail_reward == "vanilla":  # do i need to delete this?
                reward = -(1 - s + eps).log()
            elif self._config.gail_reward == "gan":
                reward = (s + eps).log() - (1 - s + eps).log()
            elif self._config.gail_reward == "d":
                reward = ret

            if self._config.D_i:
                ret_i = []
                for dis_index in range(self._config.num_task-1):
                    ret_ = self._D_i_list[dis_index](ob, done, log_pis, ob_next, ID)
                    ret_i.append(ret_)
                
                if self._config.discriminator_reward_type == "general":
                    ret_i_ = torch.cat(ret_i, dim=1)
                    ret_i_ = ret_i_.sum(dim=1)
                    reward = reward.squeeze(1)
                    reward += ret_i_*self._config.D_i_coeff
                elif self._config.discriminator_reward_type == "specific":
                    # average
                    ret_i_ = torch.cat(ret_i, dim=1)
                    ret_i_ = ret_i_.mean(dim=1)
                    reward = reward.squeeze(1)
                    reward -= ret_i_*self._config.D_i_coeff

        if self._config.D_i:
            return reward, ret_i
        else:
            return reward

    def predict_reward(self, ob, done, log_pis, ob_next, ID=None):
        """Public reward query on raw (un-normalized, non-tensor) inputs.

        Observations are normalized, everything is moved to
        ``config.device`` via ``to_tensor`` and passed to
        ``_predict_reward``.  Results are brought back to the CPU and
        returned as a Python scalar when the reward holds exactly one
        element, otherwise as numpy arrays.

        Returns:
            ``reward`` if ``config.D_i`` is falsy, otherwise
            ``(reward, ret_i)`` with ``ret_i`` the list of converted
            per-task discriminator outputs.
        """
        device = self._config.device
        ob = to_tensor(self.normalize(ob), device)
        ob_next = to_tensor(self.normalize(ob_next), device)
        done = to_tensor(done, device)
        log_pis = to_tensor(log_pis, device)
        if ID is not None:
            ID = to_tensor(ID, device)

        def _export(tensor, as_scalar):
            tensor = tensor.cpu()
            return tensor.item() if as_scalar else tensor.numpy()

        if self._config.D_i:
            reward, ret_i = self._predict_reward(ob, done, log_pis, ob_next, ID)
            # The reward decides the output form for the per-task values too.
            scalar = torch.numel(reward) == 1
            # ``reward`` is converted exactly once; the old code called
            # ``.cpu().item()`` again on the already converted value.
            return _export(reward, scalar), [_export(r, scalar) for r in ret_i]

        reward = self._predict_reward(ob, done, log_pis, ob_next, ID)
        # The old code tested the undefined name ``v`` here (NameError).
        return _export(reward, torch.numel(reward) == 1)

    def _log_creation(self):
        """Log the agent creation and the main discriminator's parameter
        count; does nothing unless ``config.is_chef`` is set."""
        if not self._config.is_chef:
            return
        logger.info("Creating a AIRL agent")
        num_params = count_parameters(self._discriminator)
        logger.info("The discriminator has %d parameters", num_params)

    def store_episode(self, rollouts):
        self._rl_agent.store_episode(rollouts)

    def state_dict(self):
        if self._config.D_i:
            return {
                "rl_agent": self._rl_agent.state_dict(),
                "discriminator_state_dict": self._discriminator.state_dict(),
                "discriminator_optim_state_dict": self._discriminator_optim.state_dict(),
                "ob_norm_state_dict": self._ob_norm.state_dict(),
                "D_i_list": [self._D_i_list[i].state_dict() for i in range(self._config.num_task-1)],
                "discriminator_optim_state_dict_list": [self._discriminator_optim_list[i].state_dict() for i in range(self._config.num_task-1)],
            }
        else:
                return {
                "rl_agent": self._rl_agent.state_dict(),
                "discriminator_state_dict": self._discriminator.state_dict(),
                "discriminator_optim_state_dict": self._discriminator_optim.state_dict(),
                "ob_norm_state_dict": self._ob_norm.state_dict(),
            }

    def load_state_dict(self, ckpt):
        """Restore the agent from a checkpoint.

        Two layouts are accepted:

        * a dict without the ``"rl_agent"`` key is treated as a checkpoint of
          the RL agent alone; only that agent is restored;
        * a dict as produced by ``state_dict()`` restores the RL agent, the
          discriminator(s), the observation normalizer and the optimizer(s).

        Networks are moved to ``config.device`` before optimizer state is
        loaded, and every optimizer is passed through ``optimizer_cuda``.
        """
        device = self._config.device

        if "rl_agent" not in ckpt:
            self._rl_agent.load_state_dict(ckpt)
            self._network_cuda(device)
            return

        self._rl_agent.load_state_dict(ckpt["rl_agent"])
        self._discriminator.load_state_dict(ckpt["discriminator_state_dict"])
        self._ob_norm.load_state_dict(ckpt["ob_norm_state_dict"])
        if self._config.D_i:
            for i in range(self._config.num_task - 1):
                self._D_i_list[i].load_state_dict(ckpt["D_i_list"][i])
        self._network_cuda(device)

        self._discriminator_optim.load_state_dict(
            ckpt["discriminator_optim_state_dict"]
        )
        optimizer_cuda(self._discriminator_optim, device)

        if self._config.D_i:
            optim_states = ckpt["discriminator_optim_state_dict_list"]
            for i in range(self._config.num_task - 1):
                # Each optimizer gets ITS OWN entry; the old code passed the
                # whole list to every optimizer.
                self._discriminator_optim_list[i].load_state_dict(optim_states[i])
                optimizer_cuda(self._discriminator_optim_list[i], device)

    def _network_cuda(self, device):
        self._discriminator.to(device)
        if self._config.D_i:
            for i in range(self._config.num_task-1):
                self._D_i_list[i].to(device)

    def sync_networks(self):
        """Synchronize the RL agent and every discriminator across workers.

        The per-task discriminators are included because their gradients
        are synchronized in ``_update_discriminator_i``; without matching
        parameters the workers' copies would drift apart.
        """
        self._rl_agent.sync_networks()
        # Inside the method body this name resolves to the module-level
        # helper imported from utils.pytorch, not to this method.
        sync_networks(self._discriminator)
        if self._config.D_i:
            for extra_dis in self._D_i_list:
                sync_networks(extra_dis)

    def update_normalizer(self, obs=None):
        """ Updates normalizers for discriminator and PPO agent. """
        if self._config.ob_norm:
            if obs is None:
                data_loader = torch.utils.data.DataLoader(
                    self._dataset,
                    batch_size=self._config.batch_size,
                    shuffle=False,
                    drop_last=False,
                )
                for obs in data_loader:
                    super().update_normalizer(obs)
                    self._rl_agent.update_normalizer(obs)
            else:
                super().update_normalizer(obs)
                self._rl_agent.update_normalizer(obs)

    def _blend(self, policy_data, other_expert_data, blend_ratio):
        """ Blend expert and policy data. """
        other_demo_num = int(self._config.batch_size*blend_ratio)
        new_transitions = {}
        for k, v in other_expert_data.items():
            if isinstance(v, dict):
                sub_keys = v.keys()
                new_transitions[k] = {
                    sub_key: np.concatenate((policy_data[k][sub_key][:other_demo_num], v[sub_key][other_demo_num:])) for sub_key in sub_keys
                }
            else:
                new_transitions[k] = np.concatenate((policy_data[k][:other_demo_num], v[other_demo_num:]))

        return new_transitions

    def train(self, step=0):
        """Run one training round: discriminator(s), policy, normalizer.

        Steps, in order:

        1. advance the learning-rate scheduler(s);
        2. update the main discriminator ``num_batches`` times
           (``rollout_length // batch_size // discriminator_update_freq``);
        3. in ``D_i`` mode, update every per-task discriminator on the demos
           of each non-target task;
        4. train the wrapped RL agent;
        5. feed ``num_batches`` expert observation batches to the normalizer.

        Args:
            step: current global step; blending of expert demos into the
                policy batch is active while
                ``step < warm_up_steps + blend_steps`` and
                ``blend_ratio != 0``.

        Returns:
            Dict of scalar training statistics.

        NOTE(review): the schedulers are stepped before any optimizer step of
        this round; PyTorch recommends the opposite order -- left unchanged
        to keep the existing schedule.
        NOTE(review): blending in the non-``D_i`` path needs
        ``_other_data_iter``, which only exists when
        ``config.pretrain_discriminator`` is set.
        """
        train_info = Info()
        cfg = self._config

        self._discriminator_lr_scheduler.step()
        if cfg.D_i:
            for i in range(cfg.num_task - 1):
                self._discriminator_lr_scheduler_list[i].step()

        num_batches = (
            cfg.rollout_length
            // cfg.batch_size
            // cfg.discriminator_update_freq
        )
        assert num_batches > 0

        def next_expert_batch(index=None):
            # ``index`` selects a per-task loader (D_i mode); None means the
            # single loader.  Exhausted loaders are restarted transparently.
            if index is None:
                try:
                    return next(self._data_iter)
                except StopIteration:
                    self._data_iter = iter(self._data_loader)
                    return next(self._data_iter)
            try:
                return next(self._data_iter[index])
            except StopIteration:
                self._data_iter[index] = iter(self._data_loader[index])
                return next(self._data_iter[index])

        def next_other_batch():
            try:
                return next(self._other_data_iter)
            except StopIteration:
                self._other_data_iter = iter(self._other_data_loader)
                return next(self._other_data_iter)

        blending = cfg.blend_ratio != 0.0 and step < (
            cfg.warm_up_steps + cfg.blend_steps
        )

        # update the main discriminator
        target_index = cfg.target_task_index_in_demo_path if cfg.D_i else None
        for _ in range(num_batches):
            policy_data = self._buffer.sample(cfg.batch_size)
            expert_data = next_expert_batch(target_index)

            # Work on a copy so the buffer's sample is never altered.
            policy_batch = copy.deepcopy(policy_data)
            # Blending for the main discriminator is only done outside
            # D_i mode.
            if blending and not cfg.D_i:
                policy_batch = self._blend(
                    policy_batch, next_other_batch(), cfg.blend_ratio
                )

            train_info.add(self._update_discriminator(policy_batch, expert_data))

        # update D_i
        if cfg.D_i:
            if blending:
                # First half of the blend phase uses the configured ratio,
                # the second half a fixed 0.8.
                half_num_blend = cfg.warm_up_steps + cfg.blend_steps / 2
                ratio = cfg.blend_ratio if step < half_num_blend else 0.8

            for _ in range(num_batches):
                policy_data = self._buffer.sample(cfg.batch_size)
                D_i_index = 0
                for index in range(cfg.num_task):
                    if index == cfg.target_task_index_in_demo_path:
                        continue

                    expert_data = next_expert_batch(index)

                    policy_batch = copy.deepcopy(policy_data)
                    if blending:
                        # The demos blended in come from the same task's
                        # loader as ``expert_data`` (a second batch).
                        policy_batch = self._blend(
                            policy_batch, next_expert_batch(index), ratio
                        )

                    train_info.add(
                        self._update_discriminator_i(
                            policy_batch, expert_data, D_i_index=D_i_index
                        )
                    )
                    D_i_index += 1

        train_info.add(self._rl_agent.train())

        # In D_i mode the LAST loader feeds the normalizer (previously the
        # hard-coded index 3, which only worked with exactly four paths).
        norm_index = len(self._data_iter) - 1 if cfg.D_i else None
        for _ in range(num_batches):
            expert_data = next_expert_batch(norm_index)
            self.update_normalizer(expert_data["ob"])

        return train_info.get_dict(only_scalar=True)

    def data_preprocess(self, policy_data, expert_data):
        """Convert a policy batch and an expert batch into device tensors.

        Observations (``"ob"``, ``"ob_next"``) are normalized first; actions
        and done flags are converted as they are.

        Returns:
            A 10-tuple ``(p_o, p_ac, p_done, p_o_next, e_o, e_ac, e_done,
            e_o_next, p_id, e_id)``.  The two ids are tensors built from the
            ``"id"`` entries when ``config.with_taskID`` is set and ``None``
            otherwise.
        """
        device = self._config.device

        def as_tensor(value):
            return to_tensor(value, device)

        def unpack(batch):
            ob = as_tensor(self.normalize(batch["ob"]))
            ac = as_tensor(batch["ac"])
            done = as_tensor(batch["done"])
            ob_next = as_tensor(self.normalize(batch["ob_next"]))
            return ob, ac, done, ob_next

        policy_part = unpack(policy_data)
        expert_part = unpack(expert_data)

        if not self._config.with_taskID:
            return policy_part + expert_part + (None, None)

        raw_policy_id = policy_data["id"]
        raw_expert_id = expert_data["id"]
        task_ids = (as_tensor(raw_policy_id), as_tensor(raw_expert_id))
        return policy_part + expert_part + task_ids
    
    def _update_discriminator_i(self, policy_data, expert_data, D_i_index):
        """Run one gradient step on the per-task discriminator ``D_i_index``.

        Policy samples are labelled 0 and expert samples 1 for a
        BCE-with-logits loss; an entropy bonus weighted by
        ``config.gail_entropy_loss_coeff`` is added on top.

        Returns:
            The MPI-averaged scalar statistics, keyed ``D_<index>_...``.

        NOTE(review): the discriminator is called with ``ID=None`` here while
        ``_predict_reward`` forwards the task id -- confirm this is intended.
        """
        stats = Info()
        (
            pol_ob, pol_ac, pol_done, pol_ob_next,
            exp_ob, exp_ac, exp_done, exp_ob_next,
            pol_id, exp_id,
        ) = self.data_preprocess(policy_data, expert_data)

        # Log-probabilities of the given actions under the current policy.
        with torch.no_grad():
            _, _, exp_log_pi, _ = self._rl_agent.actor.act(
                ob=exp_ob, return_log_prob=True, detach_conv=True,
                task_id=exp_id, ac=exp_ac,
            )
            _, _, pol_log_pi, _ = self._rl_agent.actor.act(
                ob=pol_ob, return_log_prob=True, detach_conv=True,
                task_id=pol_id, ac=pol_ac,
            )

        discriminator = self._D_i_list[D_i_index]
        criterion = self.dis_loss_list[D_i_index]
        device = self._config.device

        pol_logit = discriminator(pol_ob, pol_done, pol_log_pi, pol_ob_next, ID=None)
        exp_logit = discriminator(exp_ob, exp_done, exp_log_pi, exp_ob_next, ID=None)

        pol_prob = torch.sigmoid(pol_logit)
        exp_prob = torch.sigmoid(exp_logit)

        pol_target = torch.zeros_like(pol_logit).to(device)
        pol_loss = criterion(pol_logit, pol_target)
        exp_target = torch.ones_like(exp_logit).to(device)
        exp_loss = criterion(exp_logit, exp_target)

        all_logits = torch.cat([pol_logit, exp_logit], dim=0)
        entropy = torch.distributions.Bernoulli(logits=all_logits).entropy().mean()
        entropy_loss = -self._config.gail_entropy_loss_coeff * entropy

        total_loss = pol_loss + exp_loss + entropy_loss

        # update the discriminator
        discriminator.zero_grad()
        total_loss.backward()
        sync_grads(discriminator)
        self._discriminator_optim_list[D_i_index].step()

        prefix = f"D_{D_i_index}_"
        stats[prefix + "policy_output"] = pol_prob.mean().detach().cpu().item()
        stats[prefix + "expert_output"] = exp_prob.mean().detach().cpu().item()
        stats[prefix + "policy_loss"] = pol_loss.detach().cpu().item()
        stats[prefix + "expert_loss"] = exp_loss.detach().cpu().item()
        stats[prefix + "loss"] = total_loss.detach().cpu().item()

        return mpi_average(stats.get_dict(only_scalar=True))


    def _update_discriminator(self, policy_data, expert_data):
        """Run one gradient step on the main discriminator.

        Policy samples are labelled 0 and expert samples 1; the loss is the
        sum of the two BCE-with-logits terms.

        Returns:
            The MPI-averaged scalar statistics, keyed ``airl_...``.

        NOTE(review): the discriminator is called with ``ID=None`` here while
        ``_predict_reward`` forwards the task id -- confirm this is intended.
        """
        stats = Info()

        (
            pol_ob, pol_ac, pol_done, pol_ob_next,
            exp_ob, exp_ac, exp_done, exp_ob_next,
            pol_id, exp_id,
        ) = self.data_preprocess(policy_data, expert_data)

        # Task ids are only passed to the actor when they are enabled.
        if not self._config.with_taskID:
            pol_id, exp_id = None, None

        # Log-probabilities of the given actions under the current policy.
        with torch.no_grad():
            _, _, exp_log_pi, _ = self._rl_agent.actor.act(
                ob=exp_ob, return_log_prob=True, detach_conv=True,
                task_id=exp_id, ac=exp_ac,
            )
            _, _, pol_log_pi, _ = self._rl_agent.actor.act(
                ob=pol_ob, return_log_prob=True, detach_conv=True,
                task_id=pol_id, ac=pol_ac,
            )

        pol_logit = self._discriminator(
            pol_ob, pol_done, pol_log_pi, pol_ob_next, ID=None
        )
        exp_logit = self._discriminator(
            exp_ob, exp_done, exp_log_pi, exp_ob_next, ID=None
        )

        pol_prob = torch.sigmoid(pol_logit)
        exp_prob = torch.sigmoid(exp_logit)

        device = self._config.device
        pol_target = torch.zeros_like(pol_logit).to(device)
        pol_loss = self._discriminator_loss(pol_logit, pol_target)
        exp_target = torch.ones_like(exp_logit).to(device)
        exp_loss = self._discriminator_loss(exp_logit, exp_target)

        total_loss = pol_loss + exp_loss

        # update the discriminator
        self._discriminator.zero_grad()
        total_loss.backward()
        sync_grads(self._discriminator)
        self._discriminator_optim.step()

        stats["airl_policy_output"] = pol_prob.mean().detach().cpu().item()
        stats["airl_expert_output"] = exp_prob.mean().detach().cpu().item()
        stats["airl_policy_loss"] = pol_loss.detach().cpu().item()
        stats["airl_expert_loss"] = exp_loss.detach().cpu().item()
        stats["airl_loss"] = total_loss.detach().cpu().item()

        return mpi_average(stats.get_dict(only_scalar=True))

