from collections import OrderedDict

import numpy as np
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torch.distributions
from torch.optim.lr_scheduler import StepLR

from .base_agent import BaseAgent
from .ppo_agent import PPOAgent
from .dataset import ReplayBuffer, RandomSampler
from .expert_dataset import ExpertDataset
from networks.discriminator import ACDiscriminator
from utils.info_dict import Info
from utils.logger import logger
from utils.mpi import mpi_average
from utils.pytorch import (
    optimizer_cuda,
    count_parameters,
    sync_networks,
    sync_grads,
    to_tensor,
)
import algorithms
import copy

class ACGAILAgent(BaseAgent):
    def __init__(self, config, ob_space, ac_space, env_ob_space):
        """Build the wrapped RL agent, the discriminator and the data pipelines.

        Args:
            config: experiment configuration; read here for the RL algorithm
                name, learning rate, demo paths, batch size, device, etc.
            ob_space: observation space, passed to the base agent, the
                discriminator and the RL agent.
            ac_space: action space; withheld from the discriminator when
                ``config.gail_no_action`` is set.
            env_ob_space: forwarded unchanged to the RL agent constructor.
        """
        super().__init__(config, ob_space, ac_space)

        self._ob_space = ob_space
        self._ac_space = ac_space

        # NOTE(review): ``_rl_agent`` is only assigned when ``gail_rl_algo`` is
        # set, but it is used unconditionally right below; a ``None`` value
        # would raise AttributeError here.
        if self._config.gail_rl_algo is not None:
            self._rl_agent = algorithms.get_agent_by_name(self._config.gail_rl_algo)(config, ob_space, ac_space,
                                                                                     env_ob_space)

        # the RL agent queries the discriminator for its reward signal
        self._rl_agent.set_reward_function(self._predict_reward)

        # build up networks
        self._discriminator = ACDiscriminator(
            config, ob_space, ac_space if not config.gail_no_action else None
        )
        # real/fake loss on raw logits, and regression loss for the label head
        self._discriminator_loss = nn.BCEWithLogitsLoss()
        self._label_loss = nn.MSELoss()
        self._network_cuda(config.device)

        # build optimizers
        self._discriminator_optim = optim.Adam(
            self._discriminator.parameters(), lr=config.discriminator_lr
        )

        # build learning rate scheduler
        # (learning rate is multiplied by 0.5 every ``step_size`` scheduler steps)
        self._discriminator_lr_scheduler = StepLR(
            self._discriminator_optim,
            step_size=self._config.max_global_step // self._config.rollout_length // 5,
            gamma=0.5,
        )

        # expert dataset
        if config.is_train:
            self._dataset = ExpertDataset( # expert demos for all tasks (offline data + target data)
                config.demo_path,
                config.demo_subsample_interval,
                ac_space,
                use_low_level=config.demo_low_level,
                sample_range_start=config.demo_sample_range_start,
                sample_range_end=config.demo_sample_range_end,
                num_task=config.num_task,
                target_taskID=config.target_taskID,
            )
            self._data_loader = torch.utils.data.DataLoader(
                self._dataset,
                batch_size=self._config.batch_size,
                shuffle=True,
                drop_last=True,
            )
            self._data_iter = iter(self._data_loader)

        # load other expert demos
        # NOTE(review): ``_data_iter`` / ``_data_loader`` only exist when
        # ``config.is_train`` is true; with ``bc_loss_coeff > 0`` outside of
        # training this raises AttributeError.
        if config.bc_loss_coeff > 0.0:
            self._rl_agent.set_bc_data_iter(self._data_iter)
            self._rl_agent.set_bc_data_loader(self._data_loader)

        # policy dataset
        # NOTE: the RandomSampler instance is called immediately; the object it
        # returns is expected to expose ``sample_func``.
        sampler = RandomSampler(image_crop_size = config.encoder_image_size)()
        self._buffer = ReplayBuffer(
            [
                "ob",
                "ob_next",
                "ac",
                "done",
                "rew",
                "ret",
                "adv",
                "ac_before_activation",
                "id",
            ],
            config.rollout_length,
            sampler.sample_func,
        )

        # the RL agent shares the same rollout buffer as the discriminator
        self._rl_agent.set_buffer(self._buffer)

        # update observation normalizer with dataset
        # NOTE(review): with ``ob_norm`` enabled this iterates ``self._dataset``,
        # which is only created when ``config.is_train`` is true.
        self.update_normalizer()

        self._log_creation()

    def pretrain_dis(self):
        """Pretrain the discriminator on relabeled other-task demos vs. expert demos.

        Demos loaded from ``config.other_demo_path`` are relabeled with the
        target task ID and (for every even-indexed sample) have the trailing
        ``num_task * 2`` entries of their observation replaced by a goal
        observation taken from the target demos. These relabeled samples are
        fed to ``_update_discriminator`` in the policy (negative) slot, while
        batches from ``self._dataset`` fill the expert (positive) slot.
        Per-batch scalar statistics are logged to wandb under ``pretrain_dis/``.
        """
        # load target demos
        self._target_dataset = ExpertDataset(
                self._config.target_demo_path,
                self._config.demo_subsample_interval,
                use_low_level=self._config.demo_low_level,
                sample_range_start=self._config.demo_sample_range_start,
                sample_range_end=self._config.demo_sample_range_end,
            )
        # collect the goal part (last ``num_task * 2`` entries) of every target observation
        target_goal_obs = []
        for data in self._target_dataset._data:
            goal_obs = data["ob"]["ob"][-self._config.num_task*2:]
            target_goal_obs.append(goal_obs)

        def uniq(lst):
            # yields the items of ``lst`` while dropping *consecutive* duplicates;
            # two items count as equal when their element-wise difference is all zero
            last = lst[0]
            yield last
            for item in lst:
                if not (item - last).any():
                    continue
                yield item
                last = item

        def deduplicate(l):
            return list(uniq(l))
        target_goal_obs = deduplicate(target_goal_obs)

        self._other_dataset2 = ExpertDataset(
                self._config.other_demo_path,
                self._config.demo_subsample_interval,
                use_low_level=self._config.demo_low_level,
                sample_range_start=self._config.demo_sample_range_start,
                sample_range_end=self._config.demo_sample_range_end,
                num_task=self._config.num_task,
            )

        def get_task_pos(id, goal_obs):
            # maps a 4-way one-hot task ID to a 2-element slice of ``goal_obs``
            # NOTE(review): hard-coded for exactly four tasks; if ``id`` matches
            # none of the patterns, ``task_pos`` is unbound and this raises
            # UnboundLocalError. ``id`` must support ``id - list`` (e.g. ndarray).
            if not (id - [1,0,0,0]).any():
                task_pos = goal_obs[2:4]
            elif not (id - [0,1,0,0]).any():
                task_pos = goal_obs[4:6]
            elif not (id - [0,0,1,0]).any():
                task_pos = goal_obs[6:8]
            elif not (id - [0,0,0,1]).any():
                task_pos = goal_obs[0:2]
            return task_pos

        # task position of the first other-task sample, looked up with its ORIGINAL id
        first_data = self._other_dataset2._data[0]
        task_pos_ = copy.deepcopy(get_task_pos(first_data["id"], first_data["ob"]["ob"][-self._config.num_task*2:]))
        target_taskID = np.array(self._config.target_taskID).reshape(-1)
        # pick the first target goal whose target-task position is farther than 0.1 away
        # NOTE(review): ``target_state`` stays unbound if no target goal qualifies.
        for target in target_goal_obs:
            target_pos = get_task_pos(target_taskID, target)
            # check that the task position and the target position differ
            if np.linalg.norm(task_pos_ - target_pos) > 0.1:
                target_state = copy.deepcopy(target)
                break
        i = 0
        for data in self._other_dataset2._data:
            # every sample is relabeled with the target task ID ...
            data["id"] = target_taskID
            # ... but only even-indexed samples get their goal observation replaced
            if i%2 != 0:
                i += 1
                continue
            i += 1
            goal_obs = copy.deepcopy(data["ob"]["ob"][-self._config.num_task*2:])
            # NOTE(review): ``data["id"]`` was overwritten above, so this lookup
            # uses the target task ID, unlike the initial ``task_pos_`` which used
            # the sample's original id -- confirm this is intended.
            task_pos = copy.deepcopy(get_task_pos(data["id"], goal_obs))

            if not (task_pos_ - task_pos).any():
                # same task position as the previous processed sample: reuse the current target goal
                data["ob"]["ob"][-self._config.num_task*2:] = target_state
                data["ob_next"]["ob"][-self._config.num_task*2:] = target_state
                continue
            else:
                task_pos_ = copy.deepcopy(task_pos)

            # task position changed: search again for a sufficiently different target goal
            for target in target_goal_obs:
                target_pos = get_task_pos(target_taskID, target)
                # check that the task position and the target position differ
                if np.linalg.norm(task_pos_ - target_pos) > 0.1:
                    target_state = copy.deepcopy(target)
                    break
            data["ob"]["ob"][-self._config.num_task*2:] = target_state
            data["ob_next"]["ob"][-self._config.num_task*2:] = target_state

        self._other_data_loader2 = torch.utils.data.DataLoader(
            self._other_dataset2,
            batch_size=self._config.batch_size,
            shuffle=True,
            drop_last=True,
        )
        self._other_data_iter2 = iter(self._other_data_loader2)
        # one pass worth of full batches over the expert dataset
        num_iter = len(self._dataset) // self._config.batch_size

        for s_ in range(num_iter):
            # restart either loader transparently when it is exhausted
            try:
                expert_data = next(self._data_iter)
            except StopIteration:
                self._data_iter = iter(self._data_loader)
                expert_data = next(self._data_iter)

            try:
                other_expert_data = next(self._other_data_iter2)
            except StopIteration:
                self._other_data_iter2 = iter(self._other_data_loader2)
                other_expert_data = next(self._other_data_iter2)

            # relabeled other-task demos take the policy (label 0) slot
            _train_info = self._update_discriminator(other_expert_data, expert_data)

            # log only scalar-like entries
            for k, v in _train_info.items():
                if np.isscalar(v) or (hasattr(v, "shape") and np.prod(v.shape) == 1):
                    wandb.log({"pretrain_dis/%s" % k: v}, step=s_)

        logger.info("Pretrain discriminator finished")

    def _predict_reward(self, ob, ac, ID=None, *argv):
        if self._config.gail_no_action:
            ac = None
        with torch.no_grad():
            ret, _ = self._discriminator(ob, ac, ID)
            eps = 1e-10
            s = torch.sigmoid(ret)
            if self._config.gail_reward == "vanilla":
                reward = -(1 - s + eps).log()
            elif self._config.gail_reward == "gan":
                reward = (s + eps).log() - (1 - s + eps).log()
            elif self._config.gail_reward == "d":
                reward = ret
        return reward

    def predict_reward(self, ob, ac=None, ID=None, *argv):
        """Compute the imitation reward for raw (un-normalized) inputs.

        The observation is normalized and all provided inputs are converted to
        tensors on ``config.device`` before being handed to
        ``_predict_reward``.

        Args:
            ob: raw observation.
            ac: optional action; dropped when ``config.gail_no_action`` is set.
            ID: optional task ID.
            *argv: extra positional arguments forwarded to ``_predict_reward``.

        Returns:
            A Python scalar when the reward holds a single element, otherwise
            a NumPy array.
        """
        ob = self.normalize(ob)
        ob = to_tensor(ob, self._config.device)
        if self._config.gail_no_action:
            ac = None
        if ac is not None:
            ac = to_tensor(ac, self._config.device)
        if ID is not None:
            ID = to_tensor(ID, self._config.device)

        # forward the extra arguments individually rather than as one tuple
        reward = self._predict_reward(ob, ac, ID, *argv)

        # the original tested an undefined name here, raising NameError on every call
        if torch.numel(reward) == 1:
            return reward.cpu().item()
        else:
            return reward.cpu().numpy()

    def _log_creation(self):
        if self._config.is_chef:
            logger.info("Creating a GAIL agent")
            logger.info(
                "The discriminator has %d parameters",
                count_parameters(self._discriminator),
            )

    def store_episode(self, rollouts):
        """Hand the collected rollouts to the wrapped RL agent for storage."""
        self._rl_agent.store_episode(rollouts)

    def state_dict(self):
        return {
            "rl_agent": self._rl_agent.state_dict(),
            "discriminator_state_dict": self._discriminator.state_dict(),
            "discriminator_optim_state_dict": self._discriminator_optim.state_dict(),
            "ob_norm_state_dict": self._ob_norm.state_dict(),
        }

    def load_state_dict(self, ckpt):
        if "rl_agent" in ckpt:
            self._rl_agent.load_state_dict(ckpt["rl_agent"])
        else:
            self._rl_agent.load_state_dict(ckpt)
            self._network_cuda(self._config.device)
            return

        self._discriminator.load_state_dict(ckpt["discriminator_state_dict"])
        self._ob_norm.load_state_dict(ckpt["ob_norm_state_dict"])
        self._network_cuda(self._config.device)

        self._discriminator_optim.load_state_dict(
            ckpt["discriminator_optim_state_dict"]
        )
        optimizer_cuda(self._discriminator_optim, self._config.device)

    def _network_cuda(self, device):
        """Move the discriminator to ``device``."""
        self._discriminator.to(device)

    def sync_networks(self):
        """Synchronize the RL agent's networks and the discriminator.

        Delegates to the RL agent's own ``sync_networks`` and to the
        ``sync_networks`` helper from ``utils.pytorch`` (presumably a
        cross-worker parameter broadcast -- confirm against that helper).
        """
        self._rl_agent.sync_networks()
        sync_networks(self._discriminator)

    def update_normalizer(self, obs=None):
        """ Updates normalizers for discriminator and PPO agent. """
        if self._config.ob_norm:
            if obs is None:
                data_loader = torch.utils.data.DataLoader(
                    self._dataset,
                    batch_size=self._config.batch_size,
                    shuffle=False,
                    drop_last=False,
                )
                for obs in data_loader:
                    super().update_normalizer(obs)
                    self._rl_agent.update_normalizer(obs)
            else:
                super().update_normalizer(obs)
                self._rl_agent.update_normalizer(obs)

    def _blend(self, policy_data, other_expert_data):
        """ Blend expert and policy data. """
        if self._config.blend_ratio:
            other_demo_num = int(self._config.batch_size*self._config.blend_ratio)
            new_transitions = {}
            for k, v in other_expert_data.items():
                if isinstance(v, dict):
                    sub_keys = v.keys()
                    new_transitions[k] = {
                        sub_key: np.concatenate((policy_data[k][sub_key][:other_demo_num], v[sub_key][other_demo_num:])) for sub_key in sub_keys
                    }
                else:
                    new_transitions[k] = np.concatenate((policy_data[k][:other_demo_num], v[other_demo_num:]))

            # shuffle policy data

            return new_transitions

    def train(self, step=0):
        """Run one training round: discriminator updates, then the RL agent.

        Args:
            step: current global step; only used to decide whether other-expert
                demos are still blended into the policy batches.

        Returns:
            Dict of scalar training statistics gathered from the discriminator
            updates and the RL agent.
        """
        train_info = Info()

        # NOTE(review): the scheduler is stepped before any optimizer step of
        # this call; PyTorch recommends stepping it after ``optimizer.step()``.
        self._discriminator_lr_scheduler.step()

        num_batches = (  # = 3 (original author's note; depends on the config values)
                self._config.rollout_length
                // self._config.batch_size
                // self._config.discriminator_update_freq
        )
        assert num_batches > 0
        for _ in range(num_batches):
            policy_data = self._buffer.sample(self._config.batch_size)

            # restart the expert loader transparently when it is exhausted
            try:
                expert_data = next(self._data_iter)
            except StopIteration:
                self._data_iter = iter(self._data_loader)
                expert_data = next(self._data_iter)

            # add other demos to policy samples
            # (work on a copy so the sampled buffer data is never modified)
            policy_data_deepcopy = copy.deepcopy(policy_data)
            if self._config.blend_ratio != 0.0 and step < self._config.warm_up_steps+self._config.blend_steps:
                # NOTE(review): ``_other_data_iter`` / ``_other_data_loader`` are
                # not assigned anywhere in this class as visible here
                # (``pretrain_dis`` creates ``_other_data_iter2``); confirm they
                # are set elsewhere before enabling blending.
                try:
                    other_expert_data = next(self._other_data_iter)
                except StopIteration:
                    self._other_data_iter = iter(self._other_data_loader)
                    other_expert_data = next(self._other_data_iter)
                # blend other demos with policy samples
                policy_data_deepcopy = self._blend(policy_data_deepcopy, other_expert_data)

            _train_info = self._update_discriminator(policy_data_deepcopy, expert_data)
            train_info.add(_train_info)

        # policy update with rewards coming from the freshly updated discriminator
        _train_info = self._rl_agent.train()
        train_info.add(_train_info)

        # draw further expert batches and feed their observations to the normalizers
        for _ in range(num_batches):
            try:
                expert_data = next(self._data_iter)
            except StopIteration:
                self._data_iter = iter(self._data_loader)
                expert_data = next(self._data_iter)
            self.update_normalizer(expert_data["ob"])

        return train_info.get_dict(only_scalar=True)

    def _update_discriminator(self, policy_data, expert_data, pretrain_discriminator=False):
        """Perform one optimizer step on the discriminator.

        The loss is the sum of the binary cross-entropy on policy samples
        (label 0) and expert samples (label 1), an entropy bonus, a gradient
        penalty and ten times the auxiliary label loss.

        Args:
            policy_data: batch providing ``"ob"`` and, depending on the
                configuration, ``"ac"`` and ``"id"``; treated as negatives.
            expert_data: batch with the same layout; treated as positives.
            pretrain_discriminator: forwarded to ``_compute_grad_pen``, where
                it makes the batch size be derived from the data instead of
                the configuration.

        Returns:
            Dict of scalar statistics passed through ``mpi_average``.
        """
        info = Info()

        _to_tensor = lambda x: to_tensor(x, self._config.device)

        # pre-process observations
        p_o = _to_tensor(self.normalize(policy_data["ob"]))
        e_o = _to_tensor(self.normalize(expert_data["ob"]))

        if self._config.gail_no_action:
            p_ac = None
            e_ac = None
        else:
            p_ac = _to_tensor(policy_data["ac"])
            e_ac = _to_tensor(expert_data["ac"])

        if self._config.dis_with_taskID:
            p_id = _to_tensor(policy_data["id"])
            e_id = _to_tensor(expert_data["id"])
        else:
            p_id = None
            e_id = None

        p_logit, p_label = self._discriminator(p_o, p_ac, p_id)
        e_logit, e_label = self._discriminator(e_o, e_ac, e_id)

        p_label_output = torch.softmax(p_label, dim=1)
        e_label_output = torch.softmax(e_label, dim=1)

        # Default to a zero auxiliary loss so that running without any label
        # objective no longer fails on unbound variables further down.
        p_label_loss = p_logit.new_zeros(())
        e_label_loss = e_logit.new_zeros(())

        if self._config.label_taskID:
            p_true_label = _to_tensor(policy_data["id"])
            e_true_label = _to_tensor(expert_data["id"])
            p_label_loss = self._label_loss(p_label_output, p_true_label)
            e_label_loss = self._label_loss(e_label_output, e_true_label)

        # when both label options are enabled, this one takes precedence
        if self._config.label_goal_obs:
            # NOTE(review): elsewhere in this class observations are handled as
            # dicts (e.g. ``data["ob"]["ob"]``); slicing ``p_o`` / ``e_o``
            # directly would fail for a dict -- confirm the expected type.
            p_label_loss = self._label_loss(p_label_output, p_o[-self._config.num_task*2:])
            e_label_loss = self._label_loss(e_label_output, e_o[-self._config.num_task*2:])
        label_loss = p_label_loss + e_label_loss

        # probabilities are only used for logging; the loss works on logits
        p_output = torch.sigmoid(p_logit)
        e_output = torch.sigmoid(e_logit)

        p_loss = self._discriminator_loss(
            p_logit, torch.zeros_like(p_logit).to(self._config.device)
        )
        e_loss = self._discriminator_loss(
            e_logit, torch.ones_like(e_logit).to(self._config.device)
        )

        logits = torch.cat([p_logit, e_logit], dim=0)
        entropy = torch.distributions.Bernoulli(logits=logits).entropy().mean()
        entropy_loss = -self._config.gail_entropy_loss_coeff * entropy

        grad_pen = self._compute_grad_pen(p_o, p_ac, e_o, e_ac, p_id, pretrain_discriminator)
        grad_pen_loss = self._config.gail_grad_penalty_coeff * grad_pen

        gail_loss = p_loss + e_loss + entropy_loss + grad_pen_loss + label_loss*10.0

        # update the discriminator
        self._discriminator.zero_grad()
        gail_loss.backward()
        sync_grads(self._discriminator)
        self._discriminator_optim.step()

        info["gail_neg_output"] = p_output.mean().detach().cpu().item()
        info["gail_pos_output"] = e_output.mean().detach().cpu().item()
        info["gail_entropy"] = entropy.detach().cpu().item()
        info["gail_neg_loss"] = p_loss.detach().cpu().item()
        info["gail_pos_loss"] = e_loss.detach().cpu().item()
        info["gail_entropy_loss"] = entropy_loss.detach().cpu().item()
        info["gail_grad_pen"] = grad_pen.detach().cpu().item()
        info["gail_grad_loss"] = grad_pen_loss.detach().cpu().item()
        info["gail_loss"] = gail_loss.detach().cpu().item()
        info["gail_label_loss"] = label_loss.detach().cpu().item()
        info["gail_neg_label_loss"] = p_label_loss.detach().cpu().item()
        info["gail_pos_label_loss"] = e_label_loss.detach().cpu().item()

        return mpi_average(info.get_dict(only_scalar=True))

    def _compute_grad_pen(self, policy_ob, policy_ac, expert_ob, expert_ac, policy_id=None, pretrain_discriminator=False):
        if pretrain_discriminator:
            if isinstance(policy_ob, OrderedDict) or isinstance(policy_ob, dict):
                batch_size = list(policy_ob.values())[0].shape[0]
            else:
                batch_size = policy_ob[0].shape[0]
        else:
            batch_size = self._config.batch_size
        alpha = torch.rand(batch_size, 1, device=self._config.device)

        def blend_dict(a, b, alpha):
            if isinstance(a, dict):
                return OrderedDict(
                    [(k, blend_dict(a[k], b[k], alpha)) for k in a.keys()]
                )
            elif isinstance(a, list):
                return [blend_dict(a[i], b[i], alpha) for i in range(len(a))]
            else:
                expanded_alpha = alpha.expand_as(a)
                ret = expanded_alpha * a + (1 - expanded_alpha) * b
                ret.requires_grad = True
                return ret

        interpolated_ob = blend_dict(policy_ob, expert_ob, alpha)
        inputs = list(interpolated_ob.values())
        if policy_ac is not None:
            interpolated_ac = blend_dict(policy_ac, expert_ac, alpha)
            inputs = inputs + list(interpolated_ob.values())
        else:
            interpolated_ac = None

        interpolated_logit, _ = self._discriminator(interpolated_ob, interpolated_ac, policy_id)
        ones = torch.ones(interpolated_logit.size(), device=self._config.device)

        # if policy_id is not None:
        #     policy_id.requires_grad_()
        #     inputs = inputs + list(policy_id)

        grad = autograd.grad(
            outputs=interpolated_logit,
            inputs=inputs,
            grad_outputs=ones,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        #grad = grad[0]

        grad_pen = (grad.norm(2, dim=1) - 1).pow(2).mean()
        return grad_pen
