# DVD agent: the discriminator compares expert trajectories with (state, action) samples.
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torch.distributions
from torch.optim.lr_scheduler import StepLR

from .base_il_agent import BaseILAgent
from .dataset import ReplayBuffer, RandomSampler
from .expert_dataset import * 
from networks.discriminator import DiscriminatorV2
from utils.info_dict import Info
from utils.logger import logger
from utils.mpi import mpi_average
from utils.pytorch import (
    optimizer_cuda,
    count_parameters,
    sync_networks,
    sync_grads,
    to_tensor,
    sample_from_dataloader,
)

import numpy as np
from time import time
import wandb

class GAILAgentV2(BaseILAgent):
    def __init__(self, config, ob_space, ac_space, env_ob_space, layout):
        """Build the discriminator, its optimizer/scheduler and the policy buffer.

        Args:
            config: experiment configuration. Attributes read here:
                ``gail_no_action``, ``device``, ``discriminator_lr``,
                ``gail_rl_algo``, ``max_global_step``, ``rollout_length``,
                ``encoder_image_size``, ``with_taskID``, ``buffer_size``,
                ``batch_size`` and ``discriminator_update_freq``.
            ob_space: observation space, forwarded to the base agent and to
                the discriminator.
            ac_space: action space; the discriminator receives ``None``
                instead when ``config.gail_no_action`` is set.
            env_ob_space: forwarded unchanged to the base agent.
            layout: forwarded unchanged to the base agent.

        Raises:
            AssertionError: for PPO, when ``rollout_length // batch_size //
                discriminator_update_freq`` evaluates to zero.
        """
        super().__init__(config, ob_space, ac_space, env_ob_space, layout)

        # The wrapped RL agent asks the discriminator for its rewards.
        # (self._rl_agent / self._config are presumably created by the base
        # class constructor -- they are not assigned in this file.)
        self._rl_agent.set_reward_function(self._predict_reward)

        is_ppo = config.gail_rl_algo == 'ppo'

        self.encoder = None
        self._discriminator = DiscriminatorV2(
            config,
            ob_space,
            ac_space if not config.gail_no_action else None,
            self.encoder,
        )
        self._discriminator_loss = nn.BCEWithLogitsLoss()

        self._network_cuda(config.device)

        # build optimizers
        self._discriminator_optim = optim.Adam(
            self._discriminator.parameters(), lr=config.discriminator_lr
        )

        # build learning rate scheduler: the learning rate is halved every
        # `lr_step_size` scheduler steps.  For PPO the step budget is first
        # divided by the rollout length (presumably because train() runs once
        # per rollout -- TODO confirm).
        if is_ppo:
            lr_step_size = (
                self._config.max_global_step // self._config.rollout_length // 5
            )
        else:
            lr_step_size = self._config.max_global_step // 5
        self._discriminator_lr_scheduler = StepLR(
            self._discriminator_optim,
            step_size=lr_step_size,
            gamma=0.5,
        )

        # policy dataset
        sampler = RandomSampler(image_crop_size=config.encoder_image_size)
        if is_ppo:
            keys = [
                "ob", "ob_next", "ac", "done", "rew", "ret", "adv",
                "ac_before_activation", "success", "traj_start_indicator",
            ]
        elif config.gail_rl_algo == 'ddpg':
            keys = ["ob", "ob_next", "ac", "done", "done_mask", "rew"]
        else:
            keys = ["ob", "ob_next", "ac", "done", "rew"]

        if self._config.with_taskID:
            keys.append("id")

        # PPO only keeps the latest rollout; other algorithms use a replay
        # buffer of `buffer_size` transitions.
        self._buffer = ReplayBuffer(
            keys,
            config.rollout_length if is_ppo else config.buffer_size,
            sampler.sample_func,
        )

        self._rl_agent.set_buffer(self._buffer)
        self._update_iter = 0

        if is_ppo:
            self._num_batches = (
                config.rollout_length
                // config.batch_size
                // config.discriminator_update_freq
            )
            assert self._num_batches > 0, (
                "rollout_length // batch_size // discriminator_update_freq "
                "must be at least 1"
            )
        else:
            # train() and update_normalizer() read `_num_batches`
            # unconditionally, so it must exist for every algorithm; leaving
            # it unset made the non-PPO paths fail with AttributeError.
            # NOTE(review): one discriminator batch per train() call is a
            # conservative default for off-policy algorithms -- confirm the
            # intended update frequency.
            self._num_batches = 1

        self._log_creation()
    
    def load_training_data(self, sampled_few_demo_index=None):
        # expert dataset
        if self._config.is_train:
            path = self._config.demo_path.split("#")
            self._data_dataset = []
            self._data_loader, self._data_iter = [], []
            self._traj_loader, self._traj_iter = [], []
            for _, _p in enumerate(path):
                data_dataset, data_loader, data_iter, _ = self._load_dataset(_p, sampled_few_demo_index=sampled_few_demo_index)
                self._data_dataset.append(data_dataset)
                self._data_loader.append(data_loader)
                self._data_iter.append(data_iter)
            
            self.sampled_few_demo_index = None
            for _dataset_ite in self._data_dataset:
                if _dataset_ite.sampled_few_demo_index is not None:
                    self.sampled_few_demo_index = _dataset_ite.sampled_few_demo_index
                    break

            for _, _p in enumerate(path):
                traj_dataset, traj_loader, traj_iter = self._load_traj(_p, sampled_few_demo_index=self.sampled_few_demo_index)
                self._traj_loader.append(traj_loader)
                self._traj_iter.append(traj_iter)
           
            if self._config.bc_loss_coeff > 0.0:
                self._rl_agent.set_bc_data_iter(self._data_iter[self._config.target_task_index_in_demo_path])
                self._rl_agent.set_bc_data_loader(self._data_loader[self._config.target_task_index_in_demo_path])

            # update observation normalizer with dataset
            self.update_normalizer()

        else: # this is for visualization of heatmaps. 
            path = self._config.demo_path.split("#")
            target_path = path[self._config.target_task_index_in_demo_path]
            self._data_dataset = []
            data_dataset, data_loader, data_iter, _ = self._load_dataset(target_path, sampled_few_demo_index=sampled_few_demo_index)
            self._data_dataset.append(data_dataset)
            
            self._traj_loader, self._traj_iter = [], []
            traj_dataset, traj_loader, traj_iter = self._load_traj(target_path, sampled_few_demo_index=sampled_few_demo_index)
            self._traj_loader.append(traj_loader)
            self._traj_iter.append(traj_iter)

    def _sample_expert_traj(self, task_id, num_samples=None):
        """Draw expert trajectory samples for ``task_id``.

        Args:
            task_id: index into ``self._traj_loader``.
            num_samples: number of samples requested; any falsy value
                (``None`` or ``0``) falls back to ``config.batch_size``.

        Returns:
            Whatever ``sample_from_dataloader`` returns for that loader.
        """
        batch_size = self._config.batch_size
        requested = num_samples if num_samples else batch_size
        loader = self._traj_loader[task_id]
        return sample_from_dataloader(loader, requested, batch_size)
    
    def _predict_reward(self, ob, ac, ID=None, *argv):
        """Score ``(ob, ac)`` against sampled target-task expert trajectories.

        The reward is the raw discriminator output, evaluated without
        gradient tracking.

        Args:
            ob: dict of observation tensors; ``ob['ob']`` decides whether this
                is a single sample (1-D tensor) or a batch (leading dimension
                is the batch size).
            ac: action input passed straight to the discriminator.
            ID: accepted for interface compatibility; unused.
            *argv: accepted for interface compatibility; unused.

        Returns:
            dict with a single key ``"rew"``.  For a single sample the value
            is the mean over ``config.average_num`` independently sampled
            expert trajectories; for a batch it is the output for one sampled
            set of expert trajectories (no averaging).
        """
        device = self._config.device
        target_task = self._config.target_task_index_in_demo_path

        ob_tensor = ob['ob']
        # 1-D: one sample from rollout.run(); otherwise a batch from
        # ppo.store_episode().
        bs = 1 if ob_tensor.dim() == 1 else ob_tensor.shape[0]
        is_batch = bs > 1

        # Batched queries only ever report the first estimate, so running the
        # remaining `average_num - 1` sampling + forward passes was wasted
        # work.  `min` keeps a non-positive `average_num` failing as before.
        if is_batch:
            num_passes = min(self._config.average_num, 1)
        else:
            num_passes = self._config.average_num

        rew_list = []
        for _ in range(num_passes):
            expert_batch = self._sample_expert_traj(target_task, num_samples=bs)
            e_o_pos = to_tensor(self.normalize(expert_batch["ob"]), device)
            e_ac_pos = to_tensor(expert_batch["ac"], device)

            with torch.no_grad():
                rew_list.append(self._discriminator(ob, ac, e_o_pos, e_ac_pos))

        if is_batch:
            return {"rew": rew_list[0]}
        return {"rew": torch.stack(rew_list, dim=0).mean(dim=0)}
    
    def predict_reward(self, ob, ac=None, ID=None, *argv):
        """Public reward query: normalise inputs, score them, return plain values.

        Args:
            ob: raw observation; normalised and converted to tensors on
                ``config.device`` before scoring.
            ac: optional action; converted to tensors when provided.
            ID: forwarded to ``_predict_reward``.
            *argv: extra positional arguments, forwarded unchanged.

        Returns:
            dict mapping each reward name to a Python scalar when the tensor
            holds exactly one element, otherwise to a NumPy array.
        """
        device = self._config.device
        ob = to_tensor(self.normalize(ob), device)
        # `ac` defaults to None; skip the conversion in that case rather than
        # handing None to the tensor helper.
        if ac is not None:
            ac = to_tensor(ac, device)

        # Unpack the extras: passing `argv` itself would forward one tuple
        # instead of the individual arguments.
        reward_dict = self._predict_reward(ob, ac, ID, *argv)

        return {
            name: (value.cpu().item() if torch.numel(value) == 1 else value.cpu().numpy())
            for name, value in reward_dict.items()
        }
    
    def _log_creation(self):
        if self._config.is_chef:
            logger.info("Creating a DVD agent")
            logger.info(
                "The discriminator has %d parameters",
                count_parameters(self._discriminator),
            )

    def store_episode(self, rollouts, step=0):
        rew_gail, rew_reach = self._rl_agent.store_episode(rollouts, step=step)
        return rew_gail, rew_reach

    def state_dict(self):
        return {
            "rl_agent": self._rl_agent.state_dict(),
            "discriminator_state_dict": self._discriminator.state_dict(),
            "discriminator_optim_state_dict": self._discriminator_optim.state_dict(),
            "ob_norm_state_dict": self._ob_norm.state_dict(),
        }

    def load_state_dict(self, ckpt):
        if "rl_agent" in ckpt:
            self._rl_agent.load_state_dict(ckpt["rl_agent"])
        else:
            self._rl_agent.load_state_dict(ckpt)
            if self._config.pretrained_encoder == 'vae':
                self._discriminator.encoder.copy_conv_weights_from(self._rl_agent._actor.encoder)
            self._network_cuda(self._config.device)
            return

        self._discriminator.load_state_dict(ckpt["discriminator_state_dict"])
        self._ob_norm.load_state_dict(ckpt["ob_norm_state_dict"])
        self._network_cuda(self._config.device)

        self._discriminator_optim.load_state_dict(
            ckpt["discriminator_optim_state_dict"]
        )

        optimizer_cuda(self._discriminator_optim, self._config.device)

    def _network_cuda(self, device):
        self._discriminator.to(device)

    def sync_networks(self):
        """Synchronise the RL agent's networks, then the discriminator."""
        self._rl_agent.sync_networks()
        discriminator = self._discriminator
        # Module-level helper of the same name (imported from utils.pytorch).
        sync_networks(discriminator)

    def update_normalizer(self, obs=None):
        """ Updates normalizers for discriminator and PPO agent. """
        if self._config.ob_norm:
            if obs is None: 
                ob_data_loader_list, ob_data_iter_list = [], []
                for task_id in range(self._config.num_task):
                    ob_data_loader = torch.utils.data.DataLoader(
                        self._data_dataset[task_id], # randomly select one
                        batch_size=self._config.batch_size,
                        shuffle=False,
                        drop_last=False,
                    )
                    ob_data_iter = iter(ob_data_loader)
                    ob_data_loader_list.append(ob_data_loader)
                    ob_data_iter_list.append(ob_data_iter)
                
                for _ in range(self._num_batches):
                    ob_index = np.random.choice(self._config.num_task, size=1)[0]
                    try:
                        expert_data = next(ob_data_iter_list[ob_index])
                    except StopIteration:
                        ob_data_iter_list[ob_index] = iter(ob_data_loader_list[ob_index])
                        expert_data = next(ob_data_iter_list[ob_index])
                    obs = expert_data["ob"]
                    super().update_normalizer(obs)
                    self._rl_agent.update_normalizer(obs)
            else:
                super().update_normalizer(obs)
                self._rl_agent.update_normalizer(obs)

    def _blend(self, policy_data, other_expert_data, blend_ratio):
        """ Blend expert and policy data. """
        other_demo_num = int(self._config.batch_size*blend_ratio)
        new_transitions = {}
        for k, v in other_expert_data.items():
            if isinstance(v, dict):
                sub_keys = v.keys()
                new_transitions[k] = {
                    sub_key: np.concatenate((policy_data[k][sub_key][:other_demo_num], v[sub_key][other_demo_num:])) for sub_key in sub_keys
                }
            else:
                new_transitions[k] = np.concatenate((policy_data[k][:other_demo_num], v[other_demo_num:]))

            # shuffle policy data

        return new_transitions

    def load_data(self, data_iter, data_loader, index):
        try:
            return next(data_iter[index])
        except StopIteration:
            data_iter[index] = iter(data_loader[index])
            return next(data_iter[index])
    
    def train(self, step=0):
        """Run one training round: discriminator, RL agent, then normalizers.

        Unless ``config.is_frozen`` is set, two distinct tasks are drawn (one
        treated as positive, one as negative), the discriminator is updated on
        ``self._num_batches`` batches and its learning-rate scheduler is
        stepped once.  The RL agent is always trained afterwards, and the
        observation normalizers are then refreshed with ``self._num_batches``
        expert batches from randomly chosen tasks.

        Args:
            step: forwarded to ``_update_discriminator``.

        Returns:
            dict of scalar training statistics gathered from the
            discriminator updates and the RL agent.
        """
        config = self._config
        train_info = Info()

        if not config.is_frozen:
            pos_index, neg_index = np.random.choice(
                config.num_task, size=2, replace=False
            )

            # (iterators, loaders, task) for each expert batch, in the order
            # _update_discriminator expects them after the policy batch.
            expert_sources = (
                (self._data_iter, self._data_loader, pos_index),
                (self._data_iter, self._data_loader, neg_index),
                (self._traj_iter, self._traj_loader, pos_index),
                (self._traj_iter, self._traj_loader, neg_index),
            )

            for _ in range(self._num_batches):
                policy_data = self._buffer.sample(config.batch_size)
                expert_batches = [
                    self.load_data(iterators, loaders, task)
                    for iterators, loaders, task in expert_sources
                ]
                disc_info = self._update_discriminator(
                    policy_data, *expert_batches, step=step
                )
                train_info.add(disc_info)

            self._discriminator_lr_scheduler.step()

        train_info.add(self._rl_agent.train())

        for _ in range(self._num_batches):
            task = np.random.choice(config.num_task, size=1)[0]
            batch = self.load_data(self._data_iter, self._data_loader, task)
            self.update_normalizer(batch["ob"])

        return train_info.get_dict(only_scalar=True)

    def _update_discriminator(self, policy_data, expert_data_pos, expert_data_neg, expert_traj_pos, expert_traj_neg, step=0):
        """Take one discriminator gradient step and report its statistics.

        The discriminator scores a sample ``(ob, ac)`` against a reference
        trajectory ``(traj_ob, traj_ac)``.  Pairs are labelled as follows:

        * target 0: the policy sample against either trajectory, and each
          expert sample against the *other* task's trajectory;
        * target 1: each expert sample against its *own* task's trajectory.

        Args:
            policy_data: batch with ``"ob"`` and ``"ac"`` used as the policy
                sample.
            expert_data_pos: expert sample batch of the positive task.
            expert_data_neg: expert sample batch of the negative task.
            expert_traj_pos: reference trajectory batch of the positive task.
            expert_traj_neg: reference trajectory batch of the negative task.
            step: currently unused.

        Returns:
            The scalar statistics of this update after ``mpi_average``.
        """
        info = Info()
        device = self._config.device
        discriminator = self._discriminator

        def _prep_ob(batch):
            # Observations are normalised before conversion to tensors.
            return to_tensor(self.normalize(batch["ob"]), device)

        def _prep_ac(batch):
            return to_tensor(batch["ac"], device)

        batches = (
            policy_data,
            expert_data_pos,
            expert_data_neg,
            expert_traj_pos,
            expert_traj_neg,
        )
        p_o, e_o_pos, e_o_neg, e_o_traj_pos, e_o_traj_neg = [
            _prep_ob(batch) for batch in batches
        ]
        p_ac, e_ac_pos, e_ac_neg, e_ac_traj_pos, e_ac_traj_neg = [
            _prep_ac(batch) for batch in batches
        ]

        policy = (p_o, p_ac)
        sample_pos = (e_o_pos, e_ac_pos)
        sample_neg = (e_o_neg, e_ac_neg)
        traj_pos = (e_o_traj_pos, e_ac_traj_pos)
        traj_neg = (e_o_traj_neg, e_ac_traj_neg)

        # negative pairs
        p_logit1 = discriminator(*policy, *traj_pos)
        p_logit2 = discriminator(*policy, *traj_neg)
        e_logit_neg1 = discriminator(*sample_pos, *traj_neg)
        e_logit_neg2 = discriminator(*sample_neg, *traj_pos)

        # positive pairs
        e_logit_pos1 = discriminator(*sample_pos, *traj_pos)
        e_logit_pos2 = discriminator(*sample_neg, *traj_neg)

        # Probabilities are only used for logging (first pair of each group).
        p_output1 = torch.sigmoid(p_logit1)
        e_output_neg1 = torch.sigmoid(e_logit_neg1)
        e_output_pos1 = torch.sigmoid(e_logit_pos1)

        neg_logit = torch.cat([p_logit1, p_logit2, e_logit_neg1, e_logit_neg2], dim=0)
        pos_logit = torch.cat([e_logit_pos1, e_logit_pos2], dim=0)

        neg_loss = self._discriminator_loss(
            neg_logit, torch.zeros_like(neg_logit, device=device)
        )
        pos_loss = self._discriminator_loss(
            pos_logit, torch.ones_like(pos_logit, device=device)
        )

        # Entropy of the predicted Bernoulli distribution; logged only, it is
        # not part of the loss.
        all_logits = torch.cat([neg_logit, pos_logit], dim=0)
        entropy = torch.distributions.Bernoulli(logits=all_logits).entropy().mean()

        gail_loss = neg_loss + pos_loss

        # update the discriminator
        discriminator.zero_grad()
        gail_loss.backward()
        sync_grads(discriminator)
        self._discriminator_optim.step()

        stats = (
            ("gail_pos_output", e_output_pos1.mean()),
            ("gail_neg_output", e_output_neg1.mean()),
            ("gail_neg2_output", p_output1.mean()),
            ("gail_entropy", entropy),
            ("gail_neg_loss", neg_loss),
            ("gail_pos_loss", pos_loss),
            ("gail_loss", gail_loss),
        )
        for name, value in stats:
            info[name] = value.detach().cpu().item()

        return mpi_average(info.get_dict(only_scalar=True))

    def _compute_grad_pen(self, policy_ob, policy_ac, expert_ob, expert_ac, pretrain_discriminator=False):
        """Compute a gradient-penalty term ``mean((||grad||_2 - 1)^2)``.

        NOTE(review): no live caller of this method is visible in this file,
        and several lines below look inconsistent with a working gradient
        penalty (see the inline notes).  Treat it as unfinished until those
        are confirmed.

        Args:
            policy_ob: policy observations (dict / list / tensor, blended
                recursively with ``expert_ob``).
            policy_ac: policy actions, or ``None`` to skip action blending.
            expert_ob: expert observations with the same structure as
                ``policy_ob``.
            expert_ac: expert actions with the same structure as
                ``policy_ac``.
            pretrain_discriminator: unused in this method.

        Returns:
            Scalar tensor with the penalty value.
        """
        # One mixing coefficient per row.
        # NOTE(review): assumes the inputs' leading dimension equals
        # config.batch_size -- confirm.
        batch_size = self._config.batch_size
        alpha = torch.rand(batch_size, 1, device=self._config.device)

        def blend_dict(a, b, alpha):
            # Recursively interpolate matching leaves of `a` and `b`:
            # alpha * a + (1 - alpha) * b, marking each result as requiring
            # gradients.
            if isinstance(a, dict):
                return OrderedDict(
                    [(k, blend_dict(a[k], b[k], alpha)) for k in a.keys()]
                )
            elif isinstance(a, list):
                return [blend_dict(a[i], b[i], alpha) for i in range(len(a))]
            else:
                expanded_alpha = alpha.expand_as(a)
                ret = expanded_alpha * a + (1 - expanded_alpha) * b
                ret.requires_grad = True
                return ret

        interpolated_ob = blend_dict(policy_ob, expert_ob, alpha)
        inputs = list(interpolated_ob.values())
        if policy_ac is not None:
            interpolated_ac = blend_dict(policy_ac, expert_ac, alpha)
            # NOTE(review): this appends the observation tensors a second
            # time; `interpolated_ac` was presumably intended -- confirm.
            inputs = inputs + list(interpolated_ob.values())
        else:
            interpolated_ac = None

        # NOTE(review): the logit is computed from the original policy/expert
        # inputs, not from the interpolated tensors collected in `inputs`, so
        # those tensors do not appear to take part in this graph; verify how
        # autograd.grad behaves here before enabling the penalty.
        interpolated_logit = self._discriminator(policy_ob, policy_ac, expert_ob, expert_ac)
        # NOTE(review): an all-zero `grad_outputs` yields an all-zero
        # gradient; an all-ones tensor is the usual choice -- confirm.
        zeros = torch.zeros(interpolated_logit.size(), device=self._config.device)

        # Only the gradient with respect to the first entry of `inputs` is
        # kept.
        grad = autograd.grad(
            outputs=interpolated_logit,
            inputs=inputs,
            grad_outputs=zeros,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

        # Penalise the deviation of each row's gradient norm from 1.
        grad_pen = (grad.norm(2, dim=1) - 1).pow(2).mean()
        return grad_pen

    def pretrain_dis(self):
        """Pre-train the discriminator on expert data only.

        Runs ``config.pre_dvd_step`` iterations.  Each iteration draws three
        distinct tasks: the first two play the positive/negative roles, and a
        sample batch from the third stands in for the policy batch.  Scalar
        statistics of every update are logged to wandb under
        ``pretrain_dis/``.
        """
        total_steps = self._config.pre_dvd_step
        for iteration in range(total_steps):
            pos_index, neg_index, neg_index2 = np.random.choice(
                self._config.num_task, size=3, replace=False
            )

            sample_pos = self.load_data(self._data_iter, self._data_loader, pos_index)
            sample_neg = self.load_data(self._data_iter, self._data_loader, neg_index)
            sample_other = self.load_data(self._data_iter, self._data_loader, neg_index2)
            traj_pos = self.load_data(self._traj_iter, self._traj_loader, pos_index)
            traj_neg = self.load_data(self._traj_iter, self._traj_loader, neg_index)

            # The third task's samples take the place of policy data.
            stats = self._update_discriminator(
                sample_other, sample_pos, sample_neg, traj_pos, traj_neg, step=iteration
            )

            for key, value in stats.items():
                single_valued = np.isscalar(value) or (
                    hasattr(value, "shape") and np.prod(value.shape) == 1
                )
                if single_valued:
                    wandb.log({"pretrain_dis/%s" % key: value}, step=iteration)

        logger.info("Pretrain discriminator finished")
