from time import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions

from .base_il_agent import BaseILAgent
from .dataset import DataReplayBuffer, RandomSampler
from .expert_dataset import *
from networks.discriminator import DiscriminatorEnsemble
from .expert_dataset import ExpertDataset
from utils.info_dict import Info
from utils.visual_tool import *
from utils.logger import logger
from utils.mpi import mpi_average
from utils.pytorch import (
    optimizer_cuda,
    count_parameters,
    sync_networks,
    sync_grads,
    to_tensor,
    weights_init,
    sample_from_dataloader,
    process_transition,
)
import copy
import numpy as np
import os

class ProxAgent(BaseILAgent):
    """IL agent that additionally trains a reachability discriminator ``D_r``.

    Two replay buffers feed ``D_r``:

    * ``_reachable_buffer`` -- seeded with target-task demonstrations whose rewards are
      labelled 1 (and, optionally, extended with rollouts the DVD discriminator scores
      above ``dvd_relabel_threshold``; see :meth:`store_episode`).
    * ``_unreachable_buffer`` -- seeded with other-task demonstrations whose rewards are
      labelled 0, and extended with policy rollouts in :meth:`store_episode`.

    Buffers are only seeded from demos when ``log_dir`` contains no ``*.pkl`` files.

    ``D_r`` is trained to output 1 on reachable data and to regress (MSE) onto the current
    reward labels of unreachable data. Those labels are periodically overwritten with
    ``D_r``'s own sigmoid predictions (:meth:`relabel_reachability_dense`), optionally
    propagated backwards along the trajectory with a linear decay.
    """

    def __init__(self, config, ob_space, ac_space, env_ob_space, layout):
        # Imported locally so this method does not rely on a wildcard import providing `glob`.
        import glob

        super().__init__(config, ob_space, ac_space, env_ob_space, layout)

        # D_r: takes actions as input only when reach_discriminator_input is truthy.
        self._reachable = DiscriminatorEnsemble(
            config.Dr_ensemble_size,
            config,
            ob_space,
            ac_space if config.reach_discriminator_input else None,
        )
        self._reachable_optim = optim.Adam(self._reachable.parameters(), lr=config.discriminator_lr)
        self._reachable.to(config.device)

        ## Initialize reachable buffer with target task demos and unreachable buffer with other task demos
        sampler = RandomSampler(image_crop_size=config.encoder_image_size)
        buffer_keys = ["ob", "ob_next", "ac", "done", "rew", "traj_start_indicator"]
        self._reachable_buffer = DataReplayBuffer(
            buffer_keys, config.buffer_size, sampler.sample_func, is_reachable_buffer=True
        )
        self._unreachable_buffer = DataReplayBuffer(
            buffer_keys, config.buffer_size, sampler.sample_func
        )

        self._reach_discriminator_loss = nn.MSELoss()

        # TODO: rollout_length instead of buffer size? (2000 vs 1e6)
        paths = glob.glob(os.path.join(self._config.log_dir, "*.pkl"))
        if not paths:
            other_dataset = ExpertDataset(
                config.other_demo_path,
                config.demo_subsample_interval,
                ac_space,
                use_low_level=config.demo_low_level,
                sample_range_start=config.demo_sample_range_start,
                sample_range_end=config.demo_sample_range_end,
                num_task=config.num_task,
                target_taskID=None,
                with_taskID=config.with_taskID,
                frame_stack=self._config.frame_stack if config.encoder_type == "cnn" else None,
            )
            self._unreachable_buffer.store_transitions(other_dataset._data)

            if config.add_more_demos2reachbuffer is None:
                # Reuse the target-task demos already loaded by the base agent.
                reach_data = self._dataset._data
            else:
                # Load a dedicated (possibly larger) set of target-task demos.
                reach_data = ExpertDataset(
                    self._config.target_demo_path,
                    self._config.demo_subsample_interval,
                    self._ac_space,
                    use_low_level=self._config.demo_low_level,
                    sample_range_start=self._config.demo_sample_range_start,
                    sample_range_end=self._config.demo_sample_range_end,
                    num_task=self._config.num_task,
                    target_taskID=self._config.target_taskID,
                    num_target_demos=self._config.add_more_demos2reachbuffer,
                    target_demo_path=self._config.target_demo_path,
                    is_sqil=(self._config.algo == "sqil"),
                    frame_stack=self._config.frame_stack if self._config.encoder_type == "cnn" else None,
                )._data
            # Deep copy so relabelling the buffer cannot mutate the source dataset.
            self._reachable_buffer.store_transitions(copy.deepcopy(reach_data))

            # label the reward of reachable buffer as 1, and unreachable buffer as 0
            self._reachable_buffer._buffer['rew'] = [
                np.ones_like(r) for r in self._reachable_buffer._buffer['rew']
            ]
            self._unreachable_buffer._buffer['rew'] = [
                np.zeros_like(r) for r in self._unreachable_buffer._buffer['rew']
            ]

        if self._config.gail_rl_algo == 'ppo':
            # Number of D_r updates per train() call for on-policy PPO.
            self._num_batches_reach = (
                self._config.rollout_length
                // self._config.batch_size
                // self._config.reachability_discriminator_update_freq
            )
            assert self._num_batches_reach > 0

        ## Define Pretraining Counters
        self._pretrain_outer_loop = 0
        self._pretrain_inner_loop = 0

        ## test correlation between reachability and dvd reward
        self._correlation_dvd = []
        self._correlation_reach = []

        # Ascending unreachable-buffer indices that start a trajectory; used by
        # backwards relabelling. Always defined so later code can test it safely.
        self._matches = []
        if self._config.backwards_relabelling and self._unreachable_buffer._current_size > 0:
            self._update_traj_start_indices()

    def _update_traj_start_indices(self):
        """Rebuild ``self._matches``: indices whose ``traj_start_indicator[0] == 1`` (ascending)."""
        self._matches = [
            i
            for i, item in enumerate(self._unreachable_buffer._buffer["traj_start_indicator"])
            if item[0] == 1
        ]

    def state_dict(self):
        """Return a checkpoint dict with the RL agent, obs normalizer, D_r and its optimizer."""
        return {
            "rl_agent": self._rl_agent.state_dict(),
            "ob_norm_state_dict": self._ob_norm.state_dict(),
            "reachable_state_dict": self._reachable.state_dict(),
            "reachable_optim_state_dict": self._reachable_optim.state_dict(),
        }

    def load_state_dict(self, ckpt):
        """Restore base-agent state, then D_r and its optimizer when present in ``ckpt``."""
        super().load_state_dict(ckpt)
        # Guard on the key actually read below, so checkpoints without D_r state
        # (e.g. from a plain base agent) load instead of raising KeyError.
        if "reachable_state_dict" in ckpt:
            self._reachable.load_state_dict(ckpt["reachable_state_dict"])
            self._reachable.to(self._config.device)
            self._reachable_optim.load_state_dict(ckpt["reachable_optim_state_dict"])
            optimizer_cuda(self._reachable_optim, self._config.device)

    def sync_networks(self):
        """Synchronise base-agent networks and D_r across MPI workers."""
        super().sync_networks()
        sync_networks(self._reachable)

    def _sample_expert_data(self, task_id, num_samples=None):
        """Sample ``num_samples`` (default: batch_size) transitions from task ``task_id``'s loader."""
        return sample_from_dataloader(self._data_loader[task_id], num_samples or self._config.batch_size, self._config.batch_size)

    def _sample_expert_traj(self, task_id, num_samples=None):
        """Sample ``num_samples`` (default: batch_size) items from task ``task_id``'s trajectory loader."""
        return sample_from_dataloader(self._traj_loader[task_id], num_samples or self._config.batch_size, self._config.batch_size)

    def _preprocess_data(self, data, key="ob"):
        """Return ``data[key]`` as a device tensor, normalising it first if ``key`` contains "ob"."""
        data_o = data[key]
        if "ob" in key:
            data_o = self.normalize(data_o)
        return to_tensor(data_o, self._config.device)

    def pretrain(self):
        """Run one D_r pretraining step.

        Every ``reachability_update_frequency`` inner steps, unreachable-buffer labels are
        re-estimated with D_r. The outer counter advances after
        ``pretrain_warmup_inner_n_epochs`` (first outer loop) or ``pretrain_inner_n_epochs``
        inner steps.

        Returns:
            (MPI-averaged info dict, done) where ``done`` is True once the outer counter
            exceeds ``pretrain_n_epochs``.
        """
        info = Info()

        reachable_data, _, _ = self._reachable_buffer.sample(self._config.batch_size)
        unreachable_data, _, _ = self._unreachable_buffer.sample(self._config.batch_size)  # all other tasks
        info.add(self.update_reachability_discriminator(reachable_data, unreachable_data))
        self._pretrain_inner_loop += 1

        if self._pretrain_inner_loop % self._config.reachability_update_frequency == 0:
            # Re-evaluate unreachable data; -1 means "roughly one pass over the buffer".
            num_update_steps = self._config.reachability_update_steps
            if num_update_steps == -1:
                num_update_steps = self._unreachable_buffer._current_size // self._config.batch_size
            for _ in range(num_update_steps):
                data, episode_indx, t_samples = self._unreachable_buffer.sample(self._config.batch_size)
                self.relabel_reachability_dense(data, episode_indx, t_samples)

        if ((self._pretrain_outer_loop == 0 and self._pretrain_inner_loop == self._config.pretrain_warmup_inner_n_epochs)
                or (self._pretrain_outer_loop > 0 and self._pretrain_inner_loop == self._config.pretrain_inner_n_epochs)):
            self._pretrain_outer_loop += 1
            self._pretrain_inner_loop = 0

        done = self._pretrain_outer_loop > self._config.pretrain_n_epochs

        return mpi_average(info.get_dict(only_scalar=False)), done

    def _next_actions(self, episode_indx, t_samples):
        """Return, per sample, the action stored at the following buffer index.

        The action is taken from index ``e_idx + 1`` when that entry's ``ob`` equals this
        entry's ``ob_next`` (same trajectory continues). Otherwise, and for the last stored
        index, a zero action shaped like this entry's action is used.
        """
        buf = self._unreachable_buffer._buffer
        last_idx = self._unreachable_buffer._current_size - 1
        actions = []
        for e_idx, t_idx in zip(episode_indx, t_samples):
            if e_idx == last_idx:
                actions.append(np.zeros_like(buf['ac'][e_idx][t_idx]['ac']))
                continue
            cur_next_ob = buf['ob_next'][e_idx][t_idx]['ob']
            following_ob = buf['ob'][e_idx + 1][t_idx]['ob']
            if (cur_next_ob == following_ob).all():
                actions.append(buf['ac'][e_idx + 1][t_idx]['ac'])
            else:  # if the state is the last state, then the action is 0
                actions.append(np.zeros_like(buf['ac'][e_idx][t_idx]['ac']))
        return np.array(actions)

    def _backward_relabel_range(self, start, end, base_rew, t_indx, step=1):
        """Propagate a decayed reward backwards over indices ``start, start-1, ..., end+1``.

        At each index the candidate label is ``base_rew - dense_reward_scale * step``.
        Entries whose current label exceeds ``relabel_skip_threshold`` are left as-is.
        Iteration stops once the candidate falls below ``backward_relabel_threshold``.

        Returns:
            The step counter after the walk (incremented once per visited index,
            including the one that triggered an early stop).
        """
        rew_buffer = self._unreachable_buffer._buffer['rew']
        for i in range(start, end, -1):
            new_rew = base_rew - self._config.dense_reward_scale * step
            step += 1
            if rew_buffer[i][t_indx] > self._config.relabel_skip_threshold:
                continue
            if new_rew < self._config.backward_relabel_threshold:
                break
            rew_buffer[i][t_indx] = new_rew
        return step

    def relabel_reachability_dense(self, unreachable_data, episode_indx, t_samples):
        """Overwrite unreachable-buffer reward labels with D_r's sigmoid output on ``ob_next``.

        Args:
            unreachable_data: batch sampled from ``_unreachable_buffer``.
            episode_indx, t_samples: buffer coordinates of each sample in the batch.

        With ``backwards_relabelling`` the new label is also propagated, with linear decay,
        to earlier entries of the same trajectory (wrapping around the end of the buffer
        when the trajectory start lies at a higher index).
        """
        buf = self._unreachable_buffer._buffer

        # Spot-check one random sample against the buffer to validate the coordinates.
        check = np.random.randint(len(episode_indx), size=1).item()
        ob = buf['ob_next'][episode_indx[check]][t_samples[check]]['ob']
        ob_verify = unreachable_data['ob_next']['ob'][check]
        assert np.all(ob == ob_verify), "ob != ob_verify: ep_index and t_index are not correct! unreachable_data mismatch with unreachable_buffer!"

        o_next = self._preprocess_data(unreachable_data, key="ob_next")
        ac = None
        if self._config.reach_discriminator_input == 2:
            ac = self._preprocess_data(unreachable_data, key="ac")
        elif self._config.reach_discriminator_input == 1:
            ac = self._preprocess_data({"ac": self._next_actions(episode_indx, t_samples)}, key="ac")

        if ac is not None:
            logit, _ = self._reachable(o_next, ac)
        else:
            logit, _ = self._reachable(o_next)
        # NOTE(review): squeeze() yields one value per sample only for a single-member
        # ensemble with batch_size > 1; larger ensembles would need an explicit reduction.
        output = torch.sigmoid(logit).squeeze().detach().cpu().numpy()

        if self._config.backwards_relabelling and not self._matches:
            # Buffer may have been filled without going through store_episode().
            self._update_traj_start_indices()

        for idx, (ep_indx, t_indx) in enumerate(zip(episode_indx, t_samples)):
            buf['rew'][ep_indx][t_indx] = output[idx]

            if not (self._config.backwards_relabelling and self._matches):
                continue
            # _matches is ascending: the last start <= ep_indx is this trajectory's start.
            # A result of -1 selects the last start, i.e. a trajectory wrapping the buffer end.
            match_indx = np.searchsorted(self._matches, ep_indx, side='right') - 1
            ep_start_indx = self._matches[match_indx]
            if ep_start_indx == ep_indx:
                continue  # trajectory start: nothing earlier to propagate to

            base_rew = output[idx]
            if ep_indx < ep_start_indx:
                # Wrapped trajectory: walk down to index 0, then continue from the buffer end.
                step = self._backward_relabel_range(ep_indx - 1, -1, base_rew, t_indx, step=1)
                if step == (ep_indx + 1):
                    self._backward_relabel_range(
                        self._config.buffer_size - 1, ep_start_indx - 1, base_rew, t_indx, step=step
                    )
            else:
                self._backward_relabel_range(ep_indx - 1, ep_start_indx - 1, base_rew, t_indx, step=1)

    def update_reachability_discriminator(self, reachable_data, unreachable_data):
        """Take one optimizer step on D_r and return MPI-averaged scalar diagnostics.

        Loss = MSE(D_r(reachable), 1) + MSE(D_r(unreachable), stored labels)
        - gail_entropy_loss_coeff * mean Bernoulli entropy of the logits.
        """
        info = Info()

        reachable_o = self._preprocess_data(reachable_data)
        unreachable_o = self._preprocess_data(unreachable_data)

        if self._config.reach_discriminator_input:
            reachable_ac = self._preprocess_data(reachable_data, key="ac")
            unreachable_ac = self._preprocess_data(unreachable_data, key="ac")
        else:
            reachable_ac = None
            unreachable_ac = None

        pos_logit, _ = self._reachable(reachable_o, reachable_ac)
        neg_logit, _ = self._reachable(unreachable_o, unreachable_ac)

        pos_output = torch.sigmoid(pos_logit)
        neg_output = torch.sigmoid(neg_logit)

        # resize rew_neg to match the shape of neg_logit (1, batch_size, 1)
        rew_neg = self._preprocess_data(unreachable_data, key="rew").unsqueeze(dim=0).unsqueeze(dim=-1)
        neg_loss = self._reach_discriminator_loss(neg_output, rew_neg)
        pos_loss = self._reach_discriminator_loss(pos_output, torch.ones_like(pos_output))

        logits = torch.cat([neg_logit, pos_logit], dim=0)
        entropy = torch.distributions.Bernoulli(logits=logits).entropy().mean()
        entropy_loss = -self._config.gail_entropy_loss_coeff * entropy

        reachable_loss = neg_loss + pos_loss + entropy_loss

        self._reachable.zero_grad()
        reachable_loss.backward()
        sync_grads(self._reachable)
        self._reachable_optim.step()

        info["reachable_pos_output"] = pos_output.mean().detach().cpu().item()
        info["reachable_neg_output"] = neg_output.mean().detach().cpu().item()
        info["reachable_loss"] = reachable_loss.detach().cpu().item()

        # Per-ensemble-member means. Outputs are ensemble-first (see the rew_neg reshape
        # above), so reduce over every remaining dim and log plain Python floats.
        pos_outputs = pos_output.detach().reshape(pos_output.shape[0], -1).mean(dim=1).cpu().tolist()
        neg_outputs = neg_output.detach().reshape(neg_output.shape[0], -1).mean(dim=1).cpu().tolist()
        for i, (pos_i, neg_i) in enumerate(zip(pos_outputs, neg_outputs)):
            info[f"reachable{i}_pos_output"] = pos_i
            info[f"reachable{i}_neg_output"] = neg_i

        return mpi_average(info.get_dict(only_scalar=True))

    def store_episode(self, rollouts, step=0):
        """Store a rollout in the base agent and in the unreachable buffer.

        When DVD relabelling is active (``is_DVD_relabel`` and
        ``step > 1000 + warm_up_steps``), transitions the DVD discriminator scores above
        ``dvd_relabel_threshold`` are moved to the reachable buffer instead.

        With ``presort_policy_samples`` the remaining transitions are labelled with
        D_r's sigmoid output; otherwise their rewards are zeroed.

        Returns:
            ``{'DVD_relabel': fraction_moved}`` when DVD relabelling is active, else None.
        """
        super().store_episode(rollouts)
        new_data = copy.deepcopy(rollouts)

        dvd_active = self._config.is_DVD_relabel and step > 1000 + self._config.warm_up_steps
        percent_reachable = None

        if dvd_active:
            # move data to reach buffer if self._discriminator(s,a) > dvd_relabel_threshold
            from utils.pytorch import obs2tensor
            bs = len(new_data['ob']) if isinstance(new_data['ob'], list) else new_data['ob'].shape[0]

            expert_data_pos = self._sample_expert_traj(self._config.target_task_index_in_demo_path, num_samples=bs)
            e_o_pos = self._preprocess_data(expert_data_pos)
            ob = obs2tensor(self.normalize(new_data['ob']), self._config.device)
            e_ac_pos = self._preprocess_data(expert_data_pos, key="ac") if not self._config.gail_no_action else None
            ac = obs2tensor(new_data["ac"], self._config.device) if not self._config.gail_no_action else None

            with torch.no_grad():
                ret = torch.sigmoid(self._discriminator(ob, ac, e_o_pos, e_ac_pos))
                reachable = ret.squeeze(dim=-1) > self._config.dvd_relabel_threshold
                reachable_idx = np.where(reachable.cpu().numpy())[0]
                percent_reachable = len(reachable_idx) / len(reachable)
                if len(reachable_idx) != 0:
                    # store the data to reach buffer and remove them from new_data
                    reachable_data = process_transition(reachable_idx, new_data)

                    # Pop from the highest index down so earlier indices stay valid.
                    for idx in sorted(reachable_idx, reverse=True):
                        for values in new_data.values():
                            values.pop(idx)
                    self._reachable_buffer.store_transitions(reachable_data)
                    if len(reachable) == len(reachable_idx):
                        # Every transition was moved; nothing left for the unreachable buffer.
                        return {'DVD_relabel': percent_reachable}

        if self._config.presort_policy_samples and (self._config.pretrain_prox or (self._config.pretrain_prox == False and step > 0)):
            # Stack per-step entries (dicts of arrays or arrays) into batched arrays.
            new_transitions = {}
            for k, v in new_data.items():
                if isinstance(v[0], dict):
                    new_transitions[k] = {
                        sub_key: np.stack([v_[sub_key] for v_ in v]) for sub_key in v[0].keys()
                    }
                else:
                    new_transitions[k] = np.stack(v)

            # NOTE(review): labels here come from D_r on `ob`, whereas
            # relabel_reachability_dense scores `ob_next` -- confirm this is intended.
            obs = self._preprocess_data(new_transitions, key="ob")
            if self._config.reach_discriminator_input:
                ac = self._preprocess_data(new_transitions, key="ac")
                logit, _ = self._reachable(obs, ac)
            else:
                logit, _ = self._reachable(obs)

            # TODO: add backwards trajectory looking
            output = torch.sigmoid(logit).reshape(-1).detach().cpu().numpy()
            new_transitions['rew'] = output

            _data = process_transition(list(range(output.shape[0])), new_transitions)
            self._unreachable_buffer.store_transitions(_data)
        else:
            new_data['rew'] = np.zeros_like(new_data['rew'])
            self._unreachable_buffer.store_transitions(new_data)

        if self._config.backwards_relabelling and self._unreachable_buffer._current_size > 0:
            self._update_traj_start_indices()

        if dvd_active:
            return {'DVD_relabel': percent_reachable}

    def train(self, step=0):
        """Run one training iteration: D_r updates, optional relabelling, then the RL agent.

        For non-PPO algorithms D_r is updated once every
        ``reachability_discriminator_update_freq`` calls; for PPO it is updated
        ``_num_batches_reach`` times per call.
        """
        train_info = Info()

        if self._config.gail_rl_algo != 'ppo':
            # bool -> 0 or 1 D_r updates this call.
            self._num_batches_reach = (self._update_iter % self._config.reachability_discriminator_update_freq == 0)
            self._update_iter += 1

        for _ in range(self._num_batches_reach):
            reachable_data, _, _ = self._reachable_buffer.sample(self._config.batch_size)
            unreachable_data, _, _ = self._unreachable_buffer.sample(self._config.batch_size)  # all other tasks
            train_info.add(self.update_reachability_discriminator(reachable_data, unreachable_data))

        ### Re-label unreachable data
        relabel_time = time()
        if self._num_batches_reach and step % self._config.reachability_online_update_frequency == 0:
            num_update_steps = self._config.reachability_online_update_steps
            if num_update_steps == -1:
                num_update_steps = self._unreachable_buffer._current_size // self._config.batch_size
            for _ in range(num_update_steps):
                data, episode_indx, t_samples = self._unreachable_buffer.sample(self._config.batch_size)
                self.relabel_reachability_dense(data, episode_indx, t_samples)

        train_info.add({"relabel_time": time() - relabel_time})

        # train RL agent
        train_info.add(self._rl_agent.train())

        for _ in range(self._num_batches_reach):
            # NOTE(review): elsewhere _data_loader is indexed by task id
            # (_sample_expert_data); iterating it directly here may yield loaders rather
            # than batches -- confirm against BaseILAgent.
            try:
                expert_data = next(self._data_iter)
            except StopIteration:
                self._data_iter = iter(self._data_loader)
                expert_data = next(self._data_iter)
            self.update_normalizer(expert_data["ob"])

        return train_info.get_dict(only_scalar=False)

    def _predict_reward(self, ob, ac, ID=None, *argv):
        """Compute the reachability reward from the median D_r logit across the ensemble.

        Returns:
            dict with "rew" (the reward), "mean_reach" / "median_reach" (raw ensemble
            mean/median logits), "reach" (the shifted median used for the reward) and
            "sig_reach" (sigmoid of "reach").

        Raises:
            ValueError: if ``config.reach_reward`` is not a supported mode.
        """
        if self._config.gail_no_action:
            ac = None

        with torch.no_grad():
            if self._config.reach_discriminator_input:
                reach, _ = self._reachable(ob, ac)
            else:
                reach, _ = self._reachable(ob)
            mean_reach = torch.mean(reach, dim=0)
            median_reach = torch.median(reach, dim=0).values
            ### TODO: how to combine reachability reward?
            if self._config.reach_reward == "sum_negative":
                # Out-of-place so median_reach keeps the raw (unshifted) value.
                reach = median_reach - self._config.reach_reward_constant
                reward = torch.minimum(reach, torch.zeros_like(reach)) * self._config.reach_reward_scale
            else:
                raise ValueError(f"Unsupported reach_reward: {self._config.reach_reward!r}")

        # TODO: correlation check between reachability and DVD rewards
        # (self._correlation_dvd / self._correlation_reach) is currently disabled.
        return {
            "rew": reward,
            "mean_reach": mean_reach,
            "median_reach": median_reach,
            "reach": reach,
            "sig_reach": torch.sigmoid(reach),
        }

