# from collections import OrderedDict
# import gzip
# import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
#import torch.autograd as autograd
import torch.distributions
#from torch.optim.lr_scheduler import StepLR

from .gail_agent_v2 import GAILAgentV2
from .dataset import DataReplayBuffer, RandomSampler
from .expert_dataset import *
from networks.discriminator import DiscriminatorEnsemble
from utils.info_dict import Info
from utils.visual_tool import *
from utils.logger import logger
from utils.mpi import mpi_average
from utils.pytorch import (
    optimizer_cuda,
    # count_parameters,
    sync_networks,
    sync_grads,
    to_tensor,
    # weights_init,
    sample_from_dataloader,
    process_transition,
)
import copy
import numpy as np
from time import time
import os

def get_demo_files(demo_file_path):
    """Return the demonstration pickle files matching ``demo_file_path``.

    Args:
        demo_file_path: A path/glob ending in ``.pkl`` (used as-is) or a prefix,
            in which case ``"*.pkl"`` is appended verbatim. A directory must
            therefore end with a path separator to match the files inside it.

    Returns:
        Sorted list of matching paths that are regular files; directories whose
        names happen to end in ``.pkl`` are skipped.
    """
    import glob  # explicit import instead of relying on the wildcard imports above

    if not demo_file_path.endswith(".pkl"):
        demo_file_path = demo_file_path + "*.pkl"
    # glob returns matches in arbitrary (filesystem-dependent) order; sort so the
    # demo ordering is reproducible across machines.
    return sorted(f for f in glob.glob(demo_file_path) if os.path.isfile(f))

class ReachableGAILAgent(GAILAgentV2):

    def __init__(self, config, ob_space, ac_space, env_ob_space, layout):
        """Set up the reachability discriminator D_r, its optimizer and its data buffers.

        On top of the parent GAIL agent this creates:
          * ``self._reachable``: a ``DiscriminatorEnsemble`` with ``config.Dr_ensemble_size``
            members, given ``ac_space`` only when ``config.reach_discriminator_input`` is
            truthy, and an Adam optimizer with learning rate ``config.discriminator_lr``;
          * ``self._reachable_buffer`` / ``self._unreachable_buffer``: replay buffers of
            capacity ``config.buffer_size`` over ``self.buffer_keys``;
          * the pretraining counters and the bookkeeping used by ``store_episode``.
        """
        super().__init__(config, ob_space, ac_space, env_ob_space, layout)

        self._reachable = DiscriminatorEnsemble(
            config.Dr_ensemble_size,
            config,
            ob_space,
            ac_space if config.reach_discriminator_input else None,
        )
        # Move the network to its device *before* constructing the optimizer, as the
        # torch.optim docs recommend, so the optimizer references the final parameters.
        self._reachable.to(config.device)
        self._reachable_optim = optim.Adam(self._reachable.parameters(), lr=config.discriminator_lr)
        self.layout = layout

        # The reachable buffer is seeded with target-task demos and the unreachable
        # buffer with the other tasks' demos (see load_training_data).
        sampler = RandomSampler(image_crop_size=config.encoder_image_size)
        self.buffer_keys = ["ob", "ob_next", "ac", "done", "rew", "traj_start_indicator"]
        self._reachable_buffer = DataReplayBuffer(
            self.buffer_keys, config.buffer_size, sampler.sample_func, is_reachable_buffer=True
        )
        self._unreachable_buffer = DataReplayBuffer(
            self.buffer_keys, config.buffer_size, sampler.sample_func
        )
        self._reach_discriminator_loss = nn.MSELoss()

        if self._config.gail_rl_algo == 'ppo':
            # Number of D_r updates (and normalizer updates) per train() call.
            self._num_batches_reach = (
                self._config.rollout_length
                // self._config.batch_size
                // self._config.reachability_discriminator_update_freq
            )
            assert self._num_batches_reach > 0
        # NOTE(review): for other algorithms _num_batches_reach is not set here although
        # train() reads it unconditionally -- confirm the parent class defines it.

        # Sorted trajectory-start indices of the unreachable buffer; rebuilt by
        # load_training_data / store_episode and read by relabel_reachability_dense.
        self._matches = []

        # Pretraining counters (advanced by pretrain()).
        self._pretrain_outer_loop = 0
        self._pretrain_inner_loop = 0
        # Set to 1 by store_episode when the last remaining trajectory start of a rollout
        # was moved to the reachable buffer; the next rollout's first state then starts.
        self.label_start = 0

    def load_training_data(self, sampled_few_demo_index=None):
        """Load demos via the parent, then seed the two reachability buffers.

        When ``config.log_dir`` contains no ``*.pkl`` file, every demo dataset in
        ``self._data_dataset`` is deep-copied into a buffer. The one at index
        ``config.target_task_index_in_demo_path`` goes to the reachable buffer with all
        rewards set to 1; all others go to the unreachable buffer with rewards set to 0.
        Finally ``self._matches`` (trajectory-start indices of the unreachable buffer)
        is rebuilt, or reset to ``[]`` when that buffer is empty.
        """
        import glob  # explicit import instead of relying on the wildcard imports

        super().load_training_data(sampled_few_demo_index)
        # TODO: size by rollout_length instead of buffer size? (2000 vs 1e6)
        # NOTE(review): presumably pickles in log_dir mean buffers are restored elsewhere.
        paths = glob.glob(os.path.join(self._config.log_dir, "*.pkl"))
        if len(paths) == 0:
            for i, data in enumerate(self._data_dataset):
                # deepcopy so that the buffers never alias the original dataset
                data_copy = copy.deepcopy(data._data)
                if i == self._config.target_task_index_in_demo_path:
                    self._reachable_buffer.store_transitions(data_copy)
                else:
                    self._unreachable_buffer.store_transitions(data_copy)

            reach_buf = self._reachable_buffer._buffer
            unreach_buf = self._unreachable_buffer._buffer
            # Demo labels: target-task states are reachable (1), all others are not (0).
            reach_buf['rew'] = [np.ones_like(r) for r in reach_buf['rew']]
            unreach_buf['rew'] = [np.zeros_like(r) for r in unreach_buf['rew']]

        if self._unreachable_buffer._current_size > 0:
            self._matches = list(np.flatnonzero(self._unreachable_buffer._buffer["traj_start_indicator"]))
        else:
            # No stored transitions -> no trajectory starts (avoid stale or missing state).
            self._matches = []
    
    def state_dict(self):
        """Return a checkpoint dict for the RL agent, the GAIL discriminator and its
        optimizer, the observation normalizer, and the reachability discriminator and
        its optimizer.

        NOTE(review): the dict is built directly instead of extending
        ``super().state_dict()``; keep its keys in sync with what the parent's
        ``load_state_dict`` reads.
        """
        ckpt = {}
        ckpt["rl_agent"] = self._rl_agent.state_dict()
        ckpt["discriminator_state_dict"] = self._discriminator.state_dict()
        ckpt["discriminator_optim_state_dict"] = self._discriminator_optim.state_dict()
        ckpt["ob_norm_state_dict"] = self._ob_norm.state_dict()
        ckpt["reachable_state_dict"] = self._reachable.state_dict()
        ckpt["reachable_optim_state_dict"] = self._reachable_optim.state_dict()
        return ckpt

    def load_state_dict(self, ckpt):
        """Restore the parent agent's state, then the reachability discriminator's.

        Reachability entries are only read from checkpoints that contain ``"rl_agent"``.
        Each entry is optional: if ``"reachable_state_dict"`` or
        ``"reachable_optim_state_dict"`` is missing, the current network or optimizer
        is kept (with a logged warning) instead of failing with ``KeyError``.
        """
        super().load_state_dict(ckpt)
        if "rl_agent" not in ckpt:
            return

        if "reachable_state_dict" in ckpt:
            self._reachable.load_state_dict(ckpt["reachable_state_dict"])
            self._reachable.to(self._config.device)
        else:
            logger.warning("Checkpoint has no 'reachable_state_dict'; keeping the current reachability discriminator.")

        if "reachable_optim_state_dict" in ckpt:
            self._reachable_optim.load_state_dict(ckpt["reachable_optim_state_dict"])
            # Move restored optimizer state tensors to the configured device.
            optimizer_cuda(self._reachable_optim, self._config.device)
        else:
            logger.warning("Checkpoint has no 'reachable_optim_state_dict'; keeping the current reachability optimizer state.")

    def sync_networks(self):
        """Synchronise the parent's networks, then D_r, using utils.pytorch.sync_networks."""
        super().sync_networks()
        # The bare name refers to the module-level helper imported from utils.pytorch,
        # not to this method.
        sync_networks(self._reachable)

    def _sample_expert_data(self, task_id, num_samples=None):
        """Draw expert transitions for ``task_id`` from ``self._data_loader``.

        A falsy ``num_samples`` (``None`` or 0) means ``config.batch_size`` samples.
        """
        batch_size = self._config.batch_size
        count = num_samples if num_samples else batch_size
        return sample_from_dataloader(self._data_loader[task_id], count, batch_size)

    def _sample_expert_traj(self, task_id, num_samples=None):
        """Draw expert trajectory samples for ``task_id`` from ``self._traj_loader``.

        A falsy ``num_samples`` (``None`` or 0) means ``config.batch_size`` samples.
        """
        batch_size = self._config.batch_size
        count = num_samples if num_samples else batch_size
        return sample_from_dataloader(self._traj_loader[task_id], count, batch_size)

    def _preprocess_data(self, data, key="ob"):
        """Return ``data[key]`` as a tensor on ``config.device``.

        Keys containing the substring ``"ob"`` (e.g. ``"ob"``, ``"ob_next"``) are
        passed through ``self.normalize`` first; other keys (``"ac"``, ``"rew"``)
        are converted unchanged.
        """
        value = data[key]
        if "ob" in key:
            value = self.normalize(value)
        return to_tensor(value, self._config.device)

    def pretrain(self):
        """Run one pretraining step of the reachability discriminator D_r.

        Each call does three things:
          1. One D_r update on a reachable and an unreachable batch.
          2. Every ``config.reachability_update_frequency`` inner steps, relabel
             ``config.reachability_update_steps`` unreachable batches. A value of -1
             means the whole unreachable buffer.
          3. Advance the inner/outer counters. The first outer epoch lasts
             ``pretrain_warmup_inner_n_epochs`` inner steps; later epochs last
             ``pretrain_inner_n_epochs``.

        Returns:
            ``(info, done)``: MPI-averaged training stats, and whether the outer
            counter has exceeded ``config.pretrain_n_epochs``.
        """
        info = Info()
        cfg = self._config
        batch_size = cfg.batch_size

        pos_batch, _, _ = self._reachable_buffer.sample(batch_size)
        neg_batch, _, _ = self._unreachable_buffer.sample(batch_size)  # all other tasks
        info.add(self.update_reachability_discriminator(pos_batch, neg_batch))
        self._pretrain_inner_loop += 1

        if self._pretrain_inner_loop % cfg.reachability_update_frequency == 0:
            # Re-evaluate the unreachable data with the updated D_r.
            n_relabel = cfg.reachability_update_steps
            if n_relabel == -1:
                n_relabel = self._unreachable_buffer._current_size // batch_size
            for _ in range(n_relabel):
                batch, ep_idx, t_idx = self._unreachable_buffer.sample(batch_size)
                self.relabel_reachability_dense(batch, ep_idx, t_idx)

        outer = self._pretrain_outer_loop
        inner = self._pretrain_inner_loop
        finished_warmup_epoch = outer == 0 and inner == cfg.pretrain_warmup_inner_n_epochs
        finished_regular_epoch = outer > 0 and inner == cfg.pretrain_inner_n_epochs
        if finished_warmup_epoch or finished_regular_epoch:
            self._pretrain_outer_loop += 1
            self._pretrain_inner_loop = 0

        done = self._pretrain_outer_loop > cfg.pretrain_n_epochs
        return mpi_average(info.get_dict(only_scalar=False)), done

    def relabel_reachability_dense(self, unreachable_data, episode_indx, t_samples):
        """Overwrite the rewards of sampled unreachable transitions with D_r's score.

        Each sampled transition ``(episode_indx[j], t_samples[j])`` receives
        ``sigmoid(D_r(ob_next))`` as its reward. If ``config.backwards_relabelling`` is
        set, the score is also propagated backwards to earlier transitions of the same
        trajectory:
          * it decays linearly by ``config.dense_reward_scale`` per step;
          * the walk stops once the decayed value drops below
            ``config.backward_relabel_threshold``;
          * trajectory starts come from ``self._matches`` (sorted buffer indices);
          * a transition before the first recorded start is treated as part of the
            last trajectory, and the walk wraps around at ``config.buffer_size - 1``.

        Args:
            unreachable_data: batch sampled from the unreachable buffer; must contain
                ``"ob_next"``.
            episode_indx: buffer indices of the sampled transitions.
            t_samples: per-sample index within each buffer entry.
        """
        o_next = self._preprocess_data(unreachable_data, key="ob_next")
        # The scores are only used as labels -> no autograd graph needed.
        with torch.no_grad():
            logit, _ = self._reachable(o_next)
        output = torch.sigmoid(logit).squeeze().detach().cpu().numpy()

        rew_buffer = self._unreachable_buffer._buffer['rew']
        scale = self._config.dense_reward_scale
        threshold = self._config.backward_relabel_threshold

        def update_reward_range(base_rew, t_indx, start, end, step=1):
            """Write decayed ``base_rew`` to buffer indices start, start-1, ..., end+1.

            Stops early once the decayed value is below the threshold. Returns the step
            counter, which is also incremented on the iteration that stops the walk.
            """
            for i in range(start, end, -1):
                new_rew = base_rew - scale * step
                step += 1
                if new_rew < threshold:
                    break
                rew_buffer[i][t_indx] = new_rew
            return step

        backwards = self._config.backwards_relabelling
        # Convert once: np.searchsorted would otherwise re-convert the list on every call.
        matches = np.asarray(self._matches) if backwards else None

        for idx, (ep_indx, t_indx) in enumerate(zip(episode_indx, t_samples)):
            base_rew = output[idx]
            rew_buffer[ep_indx][t_indx] = base_rew
            if not backwards:
                continue

            # Position of the last trajectory start <= ep_indx (-1 if it precedes all starts).
            match_indx = np.searchsorted(matches, ep_indx, side='right') - 1
            if match_indx >= 0 and matches[match_indx] == ep_indx:
                # ep_indx itself starts a trajectory: nothing earlier to relabel.
                continue
            ep_start_indx = matches[match_indx]

            if ep_indx < ep_start_indx:
                # Wrapped trajectory: walk down to index 0, then continue from the end of
                # the buffer down to (and including) ep_start_indx.
                step = update_reward_range(base_rew, t_indx, ep_indx - 1, -1, 1)
                if step == (ep_indx + 1):
                    update_reward_range(base_rew, t_indx, self._config.buffer_size - 1, ep_start_indx - 1, step)
            else:
                update_reward_range(base_rew, t_indx, ep_indx - 1, ep_start_indx - 1, step=1)

    def relabel_unreach_buffer(self, step=0, target_k=None):
        """Densely relabel every reward in the unreachable buffer in one backward sweep.

        Walking from the newest transition to the oldest, the buffer is cut into
        segments. A segment "head" is placed at:
          * the last stored transition;
          * every transition whose successor starts a new trajectory;
          * otherwise, every ``K`` transitions.

        Each head is scored with ``sigmoid(D_r(ob_next))``. A transition ``d`` steps
        before its head gets ``head_score - config.dense_reward_scale * d``, or 0 when
        that value is below ``config.backward_relabel_threshold``. Rewards are written
        to ``rew[t][0]``.

        Args:
            step: unused; kept for call-site compatibility.
            target_k: segment length ``K``; ``config.k`` is used when ``None``.
        """
        buf = self._unreachable_buffer._buffer
        T = len(buf['ob'])
        if T == 0:
            return  # nothing stored yet
        K_target = self._config.k if target_k is None else target_k
        # is_start[t]: transition t begins a trajectory.
        is_start = [buf['traj_start_indicator'][t][0] == 1 for t in range(T)]

        # Pass 1: find segment heads, newest first.
        heads = []
        K = K_target
        for t in reversed(range(T)):
            next_is_start = t + 1 < T and is_start[t + 1]
            if t == T - 1 or next_is_start or K == 0:
                K = K_target
                heads.append(t)
            K -= 1

        # Pass 2: score the heads' next observations in chunks of 128.
        head_scores = np.zeros(len(heads))
        with torch.no_grad():
            for lo in range(0, len(heads), 128):
                chunk = heads[lo:lo + 128]
                ob = {'ob': np.stack([buf['ob_next'][h][0]['ob'] for h in chunk])}
                ob = to_tensor(self.normalize(ob), self._config.device)
                logit, _ = self._reachable(ob)
                head_scores[lo:lo + len(chunk)] = torch.sigmoid(logit).squeeze().detach().cpu().numpy()
        score_of = dict(zip(heads, head_scores))

        # Pass 3: write linearly decayed scores backwards from each head.
        scale = self._config.dense_reward_scale
        threshold = self._config.backward_relabel_threshold
        for t in reversed(range(T)):
            if t in score_of:  # always true at t == T - 1, so head_score is bound
                head_score = score_of[t]
                dist = 0
            new_rew = head_score - scale * dist
            if new_rew < threshold:
                new_rew = 0
            buf['rew'][t][0] = new_rew
            dist += 1


    def update_reachability_discriminator(self, reachable_data, unreachable_data):
        """Take one optimizer step on the reachability discriminator D_r.

        D_r's sigmoid output is regressed with ``self._reach_discriminator_loss`` (MSE):
        towards 1 on ``reachable_data``, and towards the stored ``"rew"`` labels of
        ``unreachable_data`` (0 for seeded demos, relabelled scores otherwise).
        Gradients are synchronised with ``sync_grads`` before the step.

        Returns:
            MPI-averaged scalar stats: mean outputs and losses for both batches.
        """
        pos_ob = self._preprocess_data(reachable_data)
        neg_ob = self._preprocess_data(unreachable_data)

        # D_r is queried on observations only (no actions).
        pos_logit, _ = self._reachable(pos_ob, None)
        neg_logit, _ = self._reachable(neg_ob, None)
        pos_output = torch.sigmoid(pos_logit)
        neg_output = torch.sigmoid(neg_logit)

        # Soft targets for the negative batch, reshaped to (1, batch_size, 1) to line up
        # with D_r's output (assumes "rew" arrives as shape (batch_size,) -- confirm).
        neg_target = self._preprocess_data(unreachable_data, key="rew").unsqueeze(dim=0).unsqueeze(dim=-1)
        neg_loss = self._reach_discriminator_loss(neg_output, neg_target)
        pos_target = torch.ones_like(pos_output).to(self._config.device)
        pos_loss = self._reach_discriminator_loss(pos_output, pos_target)
        reachable_loss = neg_loss + pos_loss

        self._reachable.zero_grad()
        reachable_loss.backward()
        sync_grads(self._reachable)
        self._reachable_optim.step()

        info = Info()
        info["reachable_pos_output"] = pos_output.mean().detach().cpu().item()
        info["reachable_neg_output"] = neg_output.mean().detach().cpu().item()
        info["reachable_loss"] = reachable_loss.detach().cpu().item()
        info["reachable_neg_loss"] = neg_loss.detach().cpu().item()
        info["reachable_pos_loss"] = pos_loss.detach().cpu().item()
        return mpi_average(info.get_dict(only_scalar=True))

    def store_episode(self, rollouts, step=0):
        """Store a rollout and route its transitions between the reachability buffers.

        ``super().store_episode(rollouts)`` is called first; it returns ``rew_gail`` and
        ``rew_reach``, which are treated here as pre-sigmoid scores.

        DVD relabelling is active when ``config.is_DVD_relabel`` is set and
        ``step > 1000 + config.warm_up_steps``. In that case:
          * transitions with ``sigmoid(rew_gail) > config.dvd_relabel_threshold`` are
            moved to the reachable buffer and removed from the rollout copy;
          * a removed trajectory start is handed on to the next surviving transition,
            or to the next rollout's first state if none survives.

        All remaining transitions are stored in the unreachable buffer with reward
        ``sigmoid(rew_reach)``.

        Returns:
            ``{'DVD_relabel': fraction_moved}`` while DVD relabelling is active,
            otherwise ``None``.
        """
        rew_gail, rew_reach = super().store_episode(rollouts)
        new_data = copy.deepcopy(rollouts)

        dvd_active = self._config.is_DVD_relabel and step > 1000 + self._config.warm_up_steps
        # Original (pre-removal) indices of transitions moved to the reachable buffer,
        # in descending order.
        moved_idx = []
        if dvd_active:
            if self.label_start == 1:
                # The previous rollout's last trajectory start was moved away.
                new_data['traj_start_indicator'][0] = 1
                self.label_start = 0

            threshold = self._config.dvd_relabel_threshold
            reachable = torch.sigmoid(rew_gail) > threshold
            reachable_idx = np.where(reachable.cpu().numpy())[0]
            percent_reachable = len(reachable_idx) / len(reachable)
            if len(reachable_idx) != 0:
                reachable_data = process_transition(reachable_idx, new_data)
                self._reachable_buffer.store_transitions(reachable_data)

                # Pop from the back so earlier indices stay valid while removing.
                moved_idx = sorted(reachable_idx, reverse=True)
                for idx in moved_idx:
                    if new_data['traj_start_indicator'][idx] == 1:
                        # All later moved entries are already gone, so idx + 1 is the
                        # next surviving transition: it inherits the trajectory start.
                        if idx + 1 < len(new_data['traj_start_indicator']):
                            new_data['traj_start_indicator'][idx + 1] = 1
                        else:
                            self.label_start = 1
                    for k in new_data:
                        new_data[k].pop(idx)

                if len(reachable) == len(moved_idx):
                    # Everything was reachable: nothing left for the unreachable buffer.
                    return {'DVD_relabel': percent_reachable}

        reach_rew = torch.sigmoid(rew_reach).reshape(-1).detach().cpu().numpy()
        if len(moved_idx) > 0:
            # rew_reach was computed on the full rollout; drop the entries of moved
            # transitions so the rewards stay aligned with the surviving ones.
            reach_rew = np.delete(reach_rew, moved_idx)
        new_data['rew'] = reach_rew
        self._unreachable_buffer.store_transitions(new_data)

        self._matches = list(np.flatnonzero(self._unreachable_buffer._buffer["traj_start_indicator"]))

        if dvd_active:
            return {'DVD_relabel': percent_reachable}

    def train(self, step=0):
        """Run one training iteration and return the collected stats.

        1. Unless ``config.is_frozen``: update the GAIL discriminator
           ``self._num_batches`` times on one randomly drawn (positive, negative)
           task pair, then step its LR scheduler.
        2. Update D_r ``self._num_batches_reach`` times.
        3. When ``step`` is a multiple of ``config.reachability_online_update_frequency``,
           relabel ``config.reachability_online_update_steps`` unreachable batches
           (-1 means the whole buffer).
        4. Train the RL agent.
        5. Feed ``self._num_batches_reach`` expert observation batches to the normalizer.
        """
        train_info = Info()
        cfg = self._config

        # --- GAIL discriminator ---
        if not cfg.is_frozen:
            pos_index, neg_index = np.random.choice(cfg.num_task, size=2, replace=False)
            for _ in range(self._num_batches):
                policy_data = self._buffer.sample(cfg.batch_size)
                expert_data_pos = self._sample_expert_data(pos_index)
                expert_data_neg = self._sample_expert_data(neg_index)
                expert_traj_pos = self._sample_expert_traj(pos_index)
                expert_traj_neg = self._sample_expert_traj(neg_index)
                _train_info = self._update_discriminator(
                    policy_data, expert_data_pos, expert_data_neg, expert_traj_pos, expert_traj_neg, step=step
                )
                train_info.add(_train_info)
            # Step the scheduler *after* the optimizer updates, as torch.optim requires;
            # stepping first skips the schedule's initial learning rate.
            # NOTE(review): assumes _update_discriminator performs the optimizer step.
            self._discriminator_lr_scheduler.step()

        # --- reachability discriminator D_r ---
        for _ in range(self._num_batches_reach):
            reachable_data, _, _ = self._reachable_buffer.sample(cfg.batch_size)
            unreachable_data, _, _ = self._unreachable_buffer.sample(cfg.batch_size)  # all other tasks
            train_info.add(self.update_reachability_discriminator(reachable_data, unreachable_data))

        # --- online relabelling of the unreachable buffer ---
        if self._num_batches_reach and step % cfg.reachability_online_update_frequency == 0:
            num_update_steps = cfg.reachability_online_update_steps
            if num_update_steps == -1:
                num_update_steps = self._unreachable_buffer._current_size // cfg.batch_size
            for _ in range(num_update_steps):
                data, episode_indx, t_samples = self._unreachable_buffer.sample(cfg.batch_size)
                self.relabel_reachability_dense(data, episode_indx, t_samples)

        # --- RL agent ---
        train_info.add(self._rl_agent.train())

        # --- observation normalizer ---
        for _ in range(self._num_batches_reach):
            ob_index = np.random.randint(cfg.num_task, size=1).item()
            try:
                expert_data = next(self._data_iter[ob_index])
            except StopIteration:
                # Loader exhausted: restart its iterator.
                self._data_iter[ob_index] = iter(self._data_loader[ob_index])
                expert_data = next(self._data_iter[ob_index])
            self.update_normalizer(expert_data["ob"])

        return train_info.get_dict(only_scalar=False)

    def _predict_reward(self, ob, ac, ID=None, *argv):
        """Compute the reward for ``(ob, ac)``, combining the GAIL discriminator and D_r.

        The GAIL discriminator scores the batch against ``bs`` expert samples of the
        target task (``config.target_task_index_in_demo_path``). D_r returns a logit
        ``reach`` for ``ob``. ``config.reach_reward`` selects how the two are combined:

        * ``"sum"``: ``gail * gail_scale + (reach - reach_reward_constant) * reach_scale``
        * ``"sum_negative"``: ``gail * gail_scale + min(reach, 0) * reach_scale``,
          which only penalises negative D_r logits
        * ``"tier_sum"``: ``sigmoid(gail) + (sigmoid(reach) - 1) * reach_scale``; the
          reach penalty is applied only where ``sigmoid(gail) < gail_discriminator_threshold``
        * ``"reach_only"``: the raw D_r logit
        * ``"post_sigmoid"``:
          ``gail * gail_scale + (sigmoid(reach) - reach_reward_constant) * reach_scale``

        ``ID`` and ``*argv`` are accepted for interface compatibility and ignored.

        Returns:
            dict with ``"rew"`` (the combined reward), ``"gail_rew"`` (the unscaled
            pre-sigmoid GAIL output, or its sigmoid in ``"tier_sum"`` mode),
            ``"reach_rew"`` and the raw D_r logit ``"reach"``.

        Raises:
            ValueError: if ``config.reach_reward`` is not one of the modes above.
        """
        if ob['ob'].dim() == 1:
            bs = 1
        else:
            bs = ob['ob'].shape[0]

        expert_data_pos = self._sample_expert_traj(self._config.target_task_index_in_demo_path, num_samples=bs)
        e_o_pos = self._preprocess_data(expert_data_pos)
        e_ac_pos = self._preprocess_data(expert_data_pos, key="ac")

        mode = self._config.reach_reward
        with torch.no_grad():
            # GAIL discriminator reward (pre-sigmoid).
            ret = self._discriminator(ob, ac, e_o_pos, e_ac_pos)
            s = torch.sigmoid(ret)
            gail_reward = ret
            reward = gail_reward * self._config.gail_reward_scale

            # Reachability discriminator logit.
            reach, _ = self._reachable(ob)

            if mode == "sum":
                reach_reward = (reach - self._config.reach_reward_constant) * self._config.reach_reward_scale
                reward = reward + reach_reward
            elif mode == "sum_negative":
                # Only penalise states D_r scores as unreachable (negative logits).
                reach_reward = torch.minimum(reach, torch.zeros_like(reach)) * self._config.reach_reward_scale
                reward = reward + reach_reward
            elif mode == "tier_sum":
                # Penalty (sigmoid(reach) - 1) <= 0, applied only where the GAIL score is
                # below the threshold; it is zero wherever s >= threshold.
                reach_reward = torch.multiply(
                    torch.sigmoid(reach) - 1, s < self._config.gail_discriminator_threshold
                ) * self._config.reach_reward_scale
                gail_reward = s
                reward = s + reach_reward
            elif mode == "reach_only":
                reach_reward = reach
                reward = reach
            elif mode == "post_sigmoid":
                sig_reach = torch.sigmoid(reach)
                reach_reward = (sig_reach - self._config.reach_reward_constant) * self._config.reach_reward_scale
                reward = reward + reach_reward
            else:
                raise ValueError(f"Unknown reach_reward mode: {mode!r}")

        reward_dict = {"rew": reward, "gail_rew": gail_reward, "reach_rew": reach_reward, "reach": reach}
        return reward_dict

