#!/usr/bin/env python3
import numpy as np
import torch
import os
from tqdm import tqdm

# Select the EGL rendering backend for MuJoCo. This is set before the project
# imports below, presumably because the backend is picked up when those
# modules are imported -- TODO confirm.
os.environ["MUJOCO_GL"] = "egl"

from replay_buffer_diverse import ReplayBuffer
from reward_model_pb2 import PopulationRewardModel
from collections import deque, defaultdict
from discriminator import Discriminator

import utils
import hydra
import wandb

from functools import partial
import gym

class Workspace(object):
    def __init__(self, cfg):
        """Build environments, agents, buffers, discriminator and reward model.

        Args:
            cfg: Hydra config. ``cfg.agent.params.obs_dim``, ``action_dim``
                and ``action_range`` are overwritten here from the
                environment spaces before the agents are instantiated.

        Raises:
            NotImplementedError: If ``cfg.env`` contains ``'metaworld'``.
        """
        # wandb is only initialised when explicitly requested by the config.
        if cfg.wandb:
            wandb.init(project=cfg.wandb_project, entity=cfg.wandb_entity, config=cfg)
            wandb.run.name = f"{cfg.env}_{cfg.experiment}_{cfg.max_feedback}_{cfg.num_interact}_{cfg.seed}_beta_{cfg.agent.params.beta_init}_utd_{cfg.gradient_update}"

        self.work_dir = os.getcwd()
        print(f'workspace: {self.work_dir}')

        tracker_name = f"{cfg.env}_{cfg.experiment}_{cfg.max_feedback}_{cfg.num_interact}_{cfg.seed}_{cfg.teacher_eps_equal}_{cfg.agent.params.beta_init}_{cfg.population_size}_{cfg.copy_agent}_{cfg.disc.on_policy}"
        self.reward_tracker = utils.LogRes(tracker_name)
        self.reward_tracker.set_step(0)

        self.cfg = cfg

        utils.set_seed_everywhere(cfg.seed)
        self.device = torch.device(cfg.device)
        self.log_success = False

        # Per-agent observations and shared infos gathered since the last
        # feedback round.
        self.pop_size = cfg.population_size
        self.obs_since_feedback = [[] for _ in range(self.pop_size)]
        self.infos_since_feedback = []

        # Environments: one per population member plus one used by evaluate().
        if 'metaworld' in cfg.env:
            raise NotImplementedError
        self.envs = [utils.make_env(cfg) for _ in range(self.pop_size)]
        self.env = utils.make_env(cfg)

        # Fill in the agent hyper-parameters that depend on the environment.
        agent_params = cfg.agent.params
        agent_params.obs_dim = self.env.observation_space.shape[0]
        agent_params.action_dim = self.env.action_space.shape[0]
        agent_params.action_range = [
            float(self.env.action_space.low.min()),
            float(self.env.action_space.high.max())
        ]

        # Population of agents, each tagged with its position in the list.
        self.agents = [hydra.utils.instantiate(cfg.agent) for _ in range(self.pop_size)]
        for idx, agent in enumerate(self.agents):
            agent.index = idx

        def _new_buffer():
            # One replay buffer per agent, all sized from the evaluation env.
            return ReplayBuffer(
                self.env.observation_space.shape,
                self.env.action_space.shape,
                int(cfg.replay_buffer_capacity),
                self.device,
                max_episode_len=self.env._max_episode_steps,
            )

        self.replay_buffers = [_new_buffer() for _ in range(self.pop_size)]

        # Counters used for logging and scheduling.
        self.total_feedback = 0
        self.labeled_feedback = 0
        self.step = 0

        # Discriminator over population members; `self.disc` stays False
        # until run() switches it on.
        self.discriminator = Discriminator(
            state_dim=cfg.agent.params.obs_dim,
            num_latents=self.pop_size,
            hidden_size=cfg.disc.hidden_size,
            learning_rate=cfg.disc.lr,
            layernorm=cfg.disc.layernorm,
            device=cfg.device,
        )
        self.latents = self.discriminator.latents
        self.disc = False

        # Reward model shared by the whole population; it is handed the first
        # agent's replay buffer and the discriminator.
        reward_model_kwargs = dict(
            device=cfg.device,
            ensemble_size=cfg.ensemble_size,
            size_segment=cfg.segment,
            max_size=cfg.max_reward_buffer_size,
            activation=cfg.activation,
            lr=cfg.reward_lr,
            mb_size=cfg.reward_batch,
            large_batch=cfg.large_batch,
            label_margin=cfg.label_margin,
            teacher_beta=cfg.teacher_beta,
            teacher_gamma=cfg.teacher_gamma,
            teacher_eps_mistake=cfg.teacher_eps_mistake,
            teacher_eps_skip=cfg.teacher_eps_skip,
            teacher_eps_equal=cfg.teacher_eps_equal,
            data_aug_ratio=cfg.data_aug_ratio,
            hidden_size=cfg.rm_hidden_size,
            num_layers=cfg.rm_num_layers,
            pop_size=cfg.population_size,
            tpa=cfg.tpa,
            replay_buffer=self.replay_buffers[0],
            disc=self.discriminator,
        )
        self.reward_model = PopulationRewardModel(
            self.env.observation_space.shape[0],
            self.env.action_space.shape[0],
            **reward_model_kwargs,
        )
    
        
    def evaluate(self, agent_index):
        """Roll out one agent for ``cfg.num_eval_episodes`` episodes on ``self.env``.

        Args:
            agent_index: Position of the agent in ``self.agents``.

        Returns:
            A 5-tuple ``(mean_true_return, mean_predicted_return,
            mean_intrinsic_return, intrinsic_rewards, success_rate)``.
            ``intrinsic_rewards`` holds every per-step intrinsic reward seen
            across all evaluation episodes. ``success_rate`` is a percentage
            and stays 0 unless ``self.log_success`` is set.
        """
        agent = self.agents[agent_index]
        num_episodes = self.cfg.num_eval_episodes

        sum_predicted_return = 0
        sum_true_return = 0
        sum_intrinsic_return = 0
        success_rate = 0
        intrinsic_rewards = []

        for _ in range(num_episodes):
            obs = self.env.reset()
            agent.reset()

            predicted_return = 0
            true_return = 0
            intrinsic_return = 0
            if self.log_success:
                episode_success = 0

            done = False
            while not done:
                # Deterministic action from the policy in eval mode.
                with utils.eval_mode(agent):
                    action = agent.act(obs, sample=False)

                # Learned reward for the (observation, action) pair.
                predicted_return += self.reward_model.r_hat(np.concatenate([obs, action], axis=-1))

                # Intrinsic reward is only queried once the discriminator is active.
                if self.disc:
                    step_intrinsic = self.discriminator.compute_intrinsic_reward(obs, agent_index)
                else:
                    step_intrinsic = 0
                intrinsic_return += step_intrinsic
                intrinsic_rewards.append(step_intrinsic)

                obs, reward, done, extra = self.env.step(action)
                true_return += reward
                if self.log_success:
                    episode_success = max(episode_success, extra["success"])

            sum_predicted_return += predicted_return
            sum_true_return += true_return
            sum_intrinsic_return += intrinsic_return
            if self.log_success:
                success_rate += episode_success

        # Turn the sums into per-episode averages.
        sum_predicted_return /= num_episodes
        sum_true_return /= num_episodes
        sum_intrinsic_return /= num_episodes
        if self.log_success:
            success_rate /= num_episodes
            success_rate *= 100.0

        return sum_true_return, sum_predicted_return, sum_intrinsic_return, intrinsic_rewards, success_rate


    def learn_reward(self, first_flag=0):
        """Query new feedback and, if any labels exist, retrain the reward model.

        Args:
            first_flag: ``1`` for the first feedback round, in which case
                ``reward_model.uniform_sampling()`` is called without
                arguments; otherwise ``cfg.explore`` is passed to it.

        Side effects:
            Increments ``self.total_feedback`` by ``reward_model.mb_size`` and
            ``self.labeled_feedback`` by the number of labeled queries. After
            a training pass, clears ``obs_since_feedback``,
            ``infos_since_feedback`` and the reward model's ``last_put_*``
            lists. Logs to wandb when ``cfg.wandb`` is set.
        """
        # get feedbacks
        if first_flag == 1:
            # if it is first time to get feedback, need to use random sampling
            labeled_queries = self.reward_model.uniform_sampling()
        else:
            labeled_queries = self.reward_model.uniform_sampling(self.cfg.explore)

        self.total_feedback += self.reward_model.mb_size
        self.labeled_feedback += labeled_queries

        if self.labeled_feedback > 0:
            # Defaults so that every metric logged below is defined on both
            # training paths (the 'metaworld' path yields no diversity
            # accuracy and may run zero iterations).
            total_acc = None
            reward_loss = None
            diversity_accuracy = None
            num_iters = 0

            # update reward
            if 'metaworld' in self.cfg.env:
                for _ in range(self.cfg.reward_update):
                    train_acc, reward_loss = self.reward_model.train_reward()
                    num_iters += 1
                    total_acc = np.mean(train_acc)
                    # Stop early once the ensemble fits the labels well enough.
                    if total_acc > 0.97:
                        break
            else:
                # Number of iterations scales with the amount of labeled feedback.
                num_iters = int(np.ceil(self.cfg.reward_update*self.labeled_feedback/self.reward_model.train_batch_size))
                train_acc, reward_loss, num_iters, diversity_accuracy = self.reward_model.train_reward_iter(num_iters)
                total_acc = np.mean(train_acc)

            # Everything collected since the previous round has been consumed.
            self.obs_since_feedback = [[] for _ in range(self.pop_size)]
            self.infos_since_feedback = []

            self.reward_model.last_put_sa_t_1 = []
            self.reward_model.last_put_sa_t_2 = []
            self.reward_model.last_put_infos_t_1 = []
            self.reward_model.last_put_infos_t_2 = []
            self.reward_model.last_put_labels = []

            if self.cfg.wandb:
                metrics = {
                    "train/reward_model_accuracy": total_acc,
                    "train/reward_model_loss": reward_loss,
                    "train/num_iters": num_iters,
                }
                # Only the iterative trainer reports a diversity accuracy.
                if 'metaworld' not in self.cfg.env:
                    metrics["train/diversity_accuracy"] = diversity_accuracy
                wandb.log(metrics)

        if self.cfg.wandb:
            wandb.log({"train/total_feedback": self.total_feedback})

    def _wandb_log(self, payload):
        """Send ``payload`` to ``wandb.log`` only when ``cfg.wandb`` is enabled.

        ``wandb.init`` is only called when ``cfg.wandb`` is set, so every
        ``wandb.log`` call has to be guarded by the same flag.
        """
        if self.cfg.wandb:
            wandb.log(payload)

    def _reward_schedule_frac(self):
        """Return the fraction passed to ``reward_model.change_batch`` at ``self.step``."""
        total_steps = self.cfg.num_train_steps
        if self.cfg.reward_schedule == 1:
            frac = (total_steps - self.step) / total_steps
            # Never hand a zero fraction to the reward model.
            return 0.01 if frac == 0 else frac
        if self.cfg.reward_schedule == 2:
            return total_steps / (total_steps - self.step + 1)
        return 1

    def _evaluate_population(self):
        """Evaluate every agent and log per-agent and population-level metrics."""
        pop_rewards = []
        pop_rewards_hat = []
        pop_intrinsic_rewards = []
        pop_intrinsic = []

        for i in range(self.pop_size):
            ep_reward, ep_reward_hat, ep_intrinsic_reward, intrinsic_rewards, success_rate = self.evaluate(i)
            pop_rewards.append(ep_reward)
            pop_rewards_hat.append(ep_reward_hat)
            pop_intrinsic_rewards.append(ep_intrinsic_reward)
            pop_intrinsic.append(intrinsic_rewards)

            if self.cfg.wandb:
                agent_metrics = {
                    f"true_episode_reward/true_episode_reward_{i}": ep_reward,
                    f"episode_reward/episode_reward_{i}": ep_reward_hat,
                    f"intrinsic_reward/ep_intrinsic_reward_{i}": ep_intrinsic_reward,
                }
                if self.log_success:
                    agent_metrics[f"success_rate/success_rate_{i}"] = success_rate
                wandb.log(agent_metrics)

        self.reward_tracker.set_step(self.step)
        # NOTE(review): only the last agent's returns reach the tracker, as in
        # the previous implementation -- confirm this is intended rather than
        # a population mean or agent 0.
        self.reward_tracker.log({
            "true_episode_reward": pop_rewards[-1],
            "episode_reward": pop_rewards_hat[-1],
        })

        if not self.cfg.wandb:
            return

        # Calculate population statistics
        pop_rewards = np.array(pop_rewards)
        pop_rewards_hat = np.array(pop_rewards_hat)
        pop_intrinsic_rewards = np.array(pop_intrinsic_rewards)
        # Per-agent step lists can have different lengths, so flatten them
        # into one 1-D array instead of reducing over a ragged nested list.
        flat_intrinsic = np.concatenate([np.ravel(steps) for steps in pop_intrinsic])

        # Agent 1 is used for the beta read-out; fall back to agent 0 when the
        # population has a single member.
        current_beta = self.agents[min(1, self.pop_size - 1)].beta(self.step)

        # Log population metrics in separate categories
        wandb.log({
            'population_true_reward/mean': float(np.mean(pop_rewards)),
            'population_true_reward/std': float(np.std(pop_rewards)),
            'population_true_reward/min': float(np.min(pop_rewards)),
            'population_true_reward/max': float(np.max(pop_rewards)),
            'population_true_reward/reward_gap': float(np.max(pop_rewards) - np.min(pop_rewards)),

            'population_predicted/mean': float(np.mean(pop_rewards_hat)),
            'population_predicted/std': float(np.std(pop_rewards_hat)),
            'population_predicted/min': float(np.min(pop_rewards_hat)),
            'population_predicted/max': float(np.max(pop_rewards_hat)),

            'population_intrinsic/episode_mean': float(np.mean(pop_intrinsic_rewards)),
            'population_intrinsic/episode_std': float(np.std(pop_intrinsic_rewards)),
            'population_intrinsic/episode_min': float(np.min(pop_intrinsic_rewards)),
            'population_intrinsic/episode_max': float(np.max(pop_intrinsic_rewards)),

            'population_intrinsic/mean': float(np.mean(flat_intrinsic)),
            'population_intrinsic/std': float(np.std(flat_intrinsic)),
            'population_intrinsic/min': float(np.min(flat_intrinsic)),
            'population_intrinsic/max': float(np.max(flat_intrinsic)),
        })

        wandb.log({"eval/step": self.step})
        wandb.log({"diversity/beta": current_beta})

    def run(self):
        """Run the training loop for ``cfg.num_train_steps`` environment steps.

        Each iteration steps every population member's environment once. The
        kind of update depends on ``self.step``:

        * up to ``num_seed_steps``: random actions, no agent update;
        * then up to ``num_seed_steps + num_unsup_steps``:
          ``update_state_ent`` (unsupervised exploration);
        * exactly at ``num_seed_steps + num_unsup_steps``: first reward
          learning round, discriminator switched on, buffers relabeled,
          critics reset;
        * afterwards: a reward learning round every ``num_interact`` steps
          while feedback remains, plus regular agent updates.

        Agents and the reward model are saved to ``self.work_dir`` at the end.
        """
        episodes, episode_steps, episode_rewards = [[0]*self.pop_size for _ in range(3)]
        dones = [True]*self.pop_size
        obses = [env.reset() for env in self.envs]

        # Per-agent success flag for the current episode (only maintained
        # when log_success is set).
        episode_success = [0]*self.pop_size
        true_episode_rewards = [0]*self.pop_size

        # store train returns of recent episodes
        avg_train_true_return = [deque([], maxlen=10) for _ in range(self.pop_size)]
        avg_train_predicted_return = [deque([], maxlen=5) for _ in range(self.pop_size)]

        # Step at which the first reward-learning round happens.
        feedback_start = self.cfg.num_seed_steps + self.cfg.num_unsup_steps
        # Buffer used for the periodic episode statistics; index 1 when the
        # population has more than one member, otherwise index 0.
        monitor = min(1, self.pop_size - 1)

        interact_count = 0
        with tqdm(total=self.cfg.num_train_steps) as pbar:
            while self.step < self.cfg.num_train_steps:
                # Start a new episode for every agent whose episode ended.
                for i, is_done in enumerate(dones):
                    if is_done:
                        obses[i] = self.envs[i].reset()
                        self.agents[i].reset()
                        dones[i] = False
                        avg_train_true_return[i].append(true_episode_rewards[i])
                        avg_train_predicted_return[i].append(episode_rewards[i])
                        episode_rewards[i] = 0
                        true_episode_rewards[i] = 0
                        if self.log_success:
                            episode_success[i] = 0
                        episode_steps[i] = 0
                        episodes[i] += 1

                # evaluate agent periodically
                if self.step > 0 and self.step % self.cfg.eval_frequency == 0:
                    self._evaluate_population()

                # sample action for data collection
                if self.step < self.cfg.num_seed_steps:
                    actions = [env.action_space.sample() for env in self.envs]
                else:
                    actions = []
                    for agent, ob in zip(self.agents, obses):
                        # Set each agent to evaluation mode before selecting actions
                        with utils.eval_mode(agent):
                            actions.append(agent.act(ob, sample=True))

                # run training update
                if self.step == feedback_start:
                    # update schedule
                    self.reward_model.change_batch(self._reward_schedule_frac())

                    # first learn reward
                    self.learn_reward(first_flag=1)

                    # first learn discriminator
                    self.disc = True
                    self.discriminator.training = True

                    for i in range(self.pop_size):
                        # relabel buffer
                        self.replay_buffers[i].relabel_with_predictor(self.reward_model, self.discriminator, agent_index=i)

                        # reset Q due to unsupervised exploration
                        self.agents[i].reset_critic()

                        # update agent
                        self.agents[i].update_after_reset(
                            self.replay_buffers[i],
                            self.step,
                            gradient_update=self.cfg.reset_update,
                            policy_update=True,
                        )

                    # reset interact_count
                    interact_count = 0

                elif self.step > feedback_start:
                    # update reward function
                    if self.total_feedback < self.cfg.max_feedback and interact_count == self.cfg.num_interact:
                        # update schedule
                        self.reward_model.change_batch(self._reward_schedule_frac())

                        # Decrease teacher_eps_equal by a fixed amount per feedback round.
                        if self.reward_model.teacher_eps_equal > 0:
                            self.reward_model.teacher_eps_equal -= self.cfg.teacher_eps_equal / (self.cfg.max_feedback // self.cfg.reward_batch)

                        # corner case: new total feed > max feed
                        if self.reward_model.mb_size + self.total_feedback > self.cfg.max_feedback:
                            self.reward_model.set_batch(self.cfg.max_feedback - self.total_feedback)

                        self.learn_reward()

                        if self.cfg.copy_agent:
                            for i in range(1, self.pop_size):
                                utils.copy_agent_target(self.agents[0], self.agents[i])

                        # Relabel replay buffers with new reward model
                        for i in range(self.pop_size):
                            self.replay_buffers[i].relabel_with_predictor(self.reward_model, self.discriminator, agent_index=i)

                            # Stop smerl when feedbacks are over
                            if self.total_feedback == self.cfg.max_feedback and self.cfg.reset_feedback_done:
                                self.replay_buffers[i].training = False

                        interact_count = 0

                    # Re-check the budget: learn_reward() above may have used it up.
                    if self.total_feedback < self.cfg.max_feedback:
                        size = min(self.cfg.max_reward_buffer_size, len(self.reward_model.inputs[0])-1)
                        for i in range(self.pop_size):
                            self.agents[i].update_onpolicy_sample(
                                self.replay_buffers[i], self.step, size, gradient_update=self.cfg.gradient_update, her_ratio=self.cfg.her_ratio)
                    else:
                        for i in range(self.pop_size):
                            self.agents[i].update(self.replay_buffers[i], self.step, gradient_update=self.cfg.gradient_update)

                    if self.cfg.wandb and self.step % 500 == 0 and self.replay_buffers[monitor].training:
                        wandb.log({f'train/training_episodes':self.replay_buffers[monitor].num_episodes})
                        wandb.log({f'train/total_episodes':self.replay_buffers[monitor].total_episodes})
                        for index in range(self.pop_size):
                            wandb.log({f'train_SAC_{index}/diverse_episodes':self.replay_buffers[index].is_diverse})

                # unsupervised exploration
                elif self.step > self.cfg.num_seed_steps:
                    for i in range(self.pop_size):
                        self.agents[i].update_state_ent(
                            self.replay_buffers[i],
                            self.step,
                            gradient_update=self.cfg.gradient_update,
                            K=self.cfg.topK,
                        )

                # Mean predicted return of agent 0 over its recent episodes.
                anchor_return = np.mean(avg_train_predicted_return[0])

                # Update discriminator:
                if self.disc:
                    batch_size = self.cfg.disc.batch_size // self.pop_size
                    if not self.cfg.disc.on_policy:
                        samples = [self.replay_buffers[i].sample(batch_size) for i in range(self.pop_size)]
                    else:
                        size_on = min(
                            self.cfg.max_reward_buffer_size,
                            len(self.reward_model.inputs[0]) - 1,
                        )
                        samples = [self.replay_buffers[i].sample_onpolicy(batch_size, size_on) for i in range(self.pop_size)]
                    sampled_states = torch.cat([sample[0] for sample in samples])
                    sampled_actions = torch.cat([sample[1] for sample in samples])

                    critic_targets = [agent.critic_target for agent in self.agents]
                    self.discriminator.update(sampled_states, sampled_actions, critic_targets, self.step)

                # Step every environment with its agent's action.
                next_obses, rewards, dones, extras = [], [], [], []
                for i in range(self.pop_size):
                    next_obs, reward, done, extra = self.envs[i].step(actions[i])
                    next_obses.append(next_obs)
                    rewards.append(reward)
                    dones.append(done)
                    extras.append(extra)

                # Learned and intrinsic rewards for the pre-step observation.
                reward_hats = []
                intrinsic_rewards = []
                for i in range(self.pop_size):
                    inputs = np.concatenate([obses[i], actions[i]], axis=-1)
                    extrinsic_reward = self.reward_model.r_hat(inputs)
                    intrinsic_reward = self.discriminator.compute_intrinsic_reward(obses[i], z_index=i) if self.disc else 0
                    reward_hats.append(extrinsic_reward)
                    intrinsic_rewards.append(intrinsic_reward)

                # allow infinite bootstrap
                dones = [float(done) for done in dones]
                done_no_maxs = [(
                    0 if episode_steps[i] + 1 == self.envs[i]._max_episode_steps else dones[i]
                ) for i in range(self.pop_size)]

                for i in range(self.pop_size):
                    episode_rewards[i] += reward_hats[i]
                    true_episode_rewards[i] += rewards[i]
                    if self.log_success:
                        episode_success[i] = max(episode_success[i], extras[i]["success"])

                # Anchor value before the override below.
                if self.step % 200 == 0:
                    self._wandb_log({"train/cold_anchor": anchor_return})
                    self._wandb_log({"train/interact_count": interact_count})

                # No diversity in unsupervised exploration
                if interact_count < self.cfg.num_interact // 2 or self.total_feedback < self.reward_model.mb_size:
                    anchor_return = 1e3

                # adding data to the reward training data
                # NOTE(review): agent 0's info dict (extras[0]) is stored for
                # every agent -- confirm this is intended rather than extras[i].
                for i in range(self.pop_size):
                    self.obs_since_feedback[i].append(obses[i])
                    self.reward_model.add_data(obses[i], actions[i], rewards[i], dones[i], index=i, info=extras[0])
                    self.replay_buffers[i].add(
                        obses[i], actions[i], reward_hats[i], intrinsic_rewards[i],
                        next_obses[i], dones[i], done_no_maxs[i], extras[0],
                        anchor_return=anchor_return if i > 0 else None
                    )

                self.infos_since_feedback.append(extras[0])

                if self.step % 500 == 0:
                    # The 1e3 placeholder is reported as 0.
                    if anchor_return != 1e3:
                        self._wandb_log({"train/anchor_return": anchor_return})
                    else:
                        self._wandb_log({"train/anchor_return": 0})

                obses = next_obses
                episode_steps = [step + 1 for step in episode_steps]
                self.step += 1
                interact_count += 1
                pbar.update(1)

        for agent in self.agents:
            agent.save(self.work_dir, self.step)
        self.reward_model.save(self.work_dir, self.step)

        
@hydra.main(config_path='config/train_DPB2.yaml', strict=True)
def main(cfg):
    """Build a ``Workspace`` from the hydra config ``cfg`` and run training."""
    Workspace(cfg).run()

# Script entry point: call the hydra-decorated ``main`` when run directly.
if __name__ == '__main__':
    main()