"""
Agent is something which converts states into actions and has state
"""

import copy

import numpy as np
import torch
import random
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from . import actions

from torch.autograd import Variable
import collections
from torch.distributions import Normal

def extract(a, t, x_shape):
    """Gather entries of ``a`` at indices ``t`` and shape them for broadcasting.

    :param a: tensor to gather from along its last dimension
    :param t: index tensor whose first dimension is the batch size
    :param x_shape: shape of the tensor the result will be broadcast against
    :return: gathered values reshaped to ``(batch, 1, ..., 1)`` with
        ``len(x_shape)`` dimensions in total
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *singleton_dims)

def default_states_preprocessor(states):
    """
    Convert a list of states into a batched tensor suitable for the model.

    ``np.asarray`` is used instead of ``np.array(..., copy=False)``: under
    NumPy >= 2.0 the latter raises ``ValueError`` whenever a copy is required,
    which is always the case when stacking a Python list into one array.

    :param states: list of numpy arrays (or array-likes) with states
    :return: torch.Tensor with a leading batch dimension of ``len(states)``
    """
    if len(states) == 1:
        # Fast path: add the batch dimension to the single state.
        np_states = np.expand_dims(states[0], 0)
    else:
        np_states = np.asarray([np.asarray(s) for s in states])
    return torch.tensor(np_states)


def float32_preprocessor(states):
    """
    Convert states into a float32 tensor.

    :param states: anything ``numpy`` can turn into an array
    :return: torch.Tensor of dtype float32 (an independent copy)
    """
    return torch.tensor(np.asarray(states, dtype=np.float32))


class MOOF_QL_AGENT:
    """Preference-conditioned offline actor-critic agent.

    One class serves several algorithms, selected by ``args.algo``
    (``'IQL'``, ``'Diffusion-QL'``, ``'TD3+BC'``, ``'BC'``; ``'CVAE-QL'`` is
    only recognised by :meth:`cal_bc_loss`). The actor, critic and value
    function are all called with a preference vector alongside the state.
    The critic returns a pair ``(q1, q2)`` of vector-valued estimates, which
    are scalarised by a batched dot product with the preference wherever a
    scalar value is needed.
    """

    def __init__(self, actor, critic, value_function, device, behavior_prior, args, preprocessor=float32_preprocessor):
        """Store the networks and build target copies, optimizers and optional LR schedulers.

        :param actor: policy network, called as ``actor(state, pref)``
        :param critic: twin Q network, called as ``critic(state, pref, action)``
            and returning ``(q1, q2)``; the critic and value-function optimizers
            are only created when ``critic`` is truthy
        :param value_function: value network, called as ``value_function(state, pref)``
        :param device: torch device that preprocessed states are moved to in ``__call__``
        :param behavior_prior: optional module that gets its own Adam optimizer
        :param args: config namespace (``algo``, ``lr_actor``, ``lr_critic``,
            ``lr_decay``, ``time_steps``, ``gamma``, ``tau``, ``policy_noise``, ...)
        :param preprocessor: callable turning raw states into a tensor, or None
        """
        super().__init__()

        self.args = args
        self.preprocessor = preprocessor
        self.device = device
        self.actor = actor
        self.critic = critic
        self.value_function = value_function
        self.actor_target = copy.deepcopy(self.actor)
        self.critic_target = copy.deepcopy(self.critic)
        self.behavior_prior = behavior_prior
        self.behavior_optimizer = torch.optim.Adam(behavior_prior.parameters()) if behavior_prior is not None else None

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.args.lr_actor)
        has_critic = bool(critic)
        if has_critic:
            self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.args.lr_critic)
            self.vf_optimizer = torch.optim.Adam(self.value_function.parameters(), lr=self.args.lr_critic)

        if self.args.lr_decay:
            # The schedulers are stepped once every 1000 training iterations,
            # hence T_max = time_steps // 1000.
            t_max = self.args.time_steps // 1000
            self.actor_lr_scheduler = CosineAnnealingLR(self.actor_optimizer, T_max=t_max, eta_min=0.)
            # The critic/vf optimizers only exist when a critic was given.
            if has_critic:
                self.critic_lr_scheduler = CosineAnnealingLR(self.critic_optimizer, T_max=t_max, eta_min=0.)
                self.vf_lr_scheduler = CosineAnnealingLR(self.vf_optimizer, T_max=t_max, eta_min=0.)

        self.state_size = args.obs_shape
        self.action_size = args.action_shape
        self.reward_size = args.reward_size
        self.preference = None
        self.total_it = 0
        self.weight_num = args.weight_num
        self.max_action = args.max_action[0]
        self.gamma = args.gamma
        self.tau = args.tau
        # Noise scales in args are relative to the action bound.
        self.policy_noise = args.policy_noise * self.max_action
        self.noise_clip = args.noise_clip * self.max_action
        self.policy_freq = args.policy_freq
        self.deterministic = False
        self.w_batch = []
        self.clip_grad_max_norm = 5 if self.args.algo == 'Diffusion-QL' else 100

    def sample_actions(self, state, pref, actor=None, deterministic = False, need_log_pi = False):
        """Query ``actor`` (default ``self.actor``) for actions.

        Behaviour depends on ``args.algo``:

        * ``'IQL'``: Gaussian policy ``actor -> (mean, log_std)``. Returns the
          clamped reparameterised sample, or the clamped mean when
          ``deterministic``. With ``need_log_pi`` returns
          ``(clamped_sample, log_prob_of_unclamped_sample, dist)``.
        * ``'Diffusion-QL'``: returns ``actor(state, pref)`` as-is.
        * ``'TD3+BC'`` / ``'BC'``: deterministic actor. Unless ``deterministic``,
          clipped Gaussian noise is added and the result is clamped.
          ``'BC'`` feeds an all-zero preference to the actor.

        :raises ValueError: for any other ``args.algo``.
        """
        if actor is None:
            actor = self.actor

        if self.args.algo == 'IQL':
            mean, log_std = actor(state, pref)
            dist = Normal(mean, log_std.exp())
            actions = dist.rsample()  # reparameterisation trick: mean + std * N(0, 1)
            log_prob = dist.log_prob(actions)  # log-prob of the *unclamped* sample
            actions = actions.clamp(-self.max_action, self.max_action)
            if need_log_pi:
                return actions, log_prob, dist
            return mean.clamp(-self.max_action, self.max_action) if deterministic else actions
        if self.args.algo == 'Diffusion-QL':
            return actor(state, pref)
        if self.args.algo in ('TD3+BC', 'BC'):
            if self.args.algo == 'BC':
                # Plain BC ignores the preference. Use a fresh zero tensor rather
                # than writing into ``pref`` so the caller's tensor (e.g. a
                # replay-buffer batch) is left untouched.
                pref = torch.zeros_like(pref)
            actions = actor(state, pref)
            noise = torch.randn_like(actions).to(self.device)
            clipped_noise = (noise * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            noise_actions = (actions + clipped_noise).clamp(-self.max_action, self.max_action)
            return actions if deterministic else noise_actions
        raise ValueError(f"sample_actions: unsupported algo {self.args.algo!r}")

    def __call__(self, states, preference, deterministic = False):
        """Select actions for a batch of raw states under the given preferences.

        For ``'Diffusion-QL'`` (sampling from ``actor_target``) and ``'IQL'``
        (sampling from ``actor``), 50 candidate actions are drawn per state.
        For each state, the candidate whose preference-weighted
        ``min(q1, q2)`` under ``critic_target`` is largest is returned;
        ``deterministic`` is ignored on this path. Other algorithms query
        :meth:`sample_actions` once.

        :param states: raw states, passed through ``self.preprocessor`` if set
        :param preference: preference tensor with one row per state (assumed to
            already live on ``self.device`` -- TODO confirm at call sites)
        :param deterministic: forwarded to :meth:`sample_actions` on the single-sample path
        :return: numpy array of actions, one row per state
        """
        # Set type for states
        if self.preprocessor is not None:
            states = self.preprocessor(states)
            if torch.is_tensor(states):
                states = states.to(self.device)

        # Choose action for a given policy
        if self.args.algo == 'Diffusion-QL' or self.args.algo == 'IQL':
            batch_size = states.shape[0]
            repeat_sample_num = 50
            state_rpt = torch.repeat_interleave(states, repeats=repeat_sample_num, dim=0)
            pref_rpt = torch.repeat_interleave(preference, repeats=repeat_sample_num, dim=0)
            with torch.no_grad():
                if self.args.algo == 'Diffusion-QL':
                    actions = self.actor_target.sample(state_rpt, pref_rpt)
                elif self.args.algo == 'IQL':
                    actions = self.sample_actions(state_rpt, pref_rpt, self.actor, deterministic=False)
                q1, q2 = self.critic_target(state_rpt, pref_rpt, actions)
                # Scalarise each Q vector with its preference (batched dot product).
                q1 = torch.bmm(pref_rpt.unsqueeze(1), q1.unsqueeze(2)).squeeze()
                q2 = torch.bmm(pref_rpt.unsqueeze(1), q2.unsqueeze(2)).squeeze()
                q_value = torch.min(q1, q2)
                q_value = q_value.reshape(batch_size, repeat_sample_num)
                # Pick the best of the candidates for every state.
                idx = torch.argmax(q_value, dim=1, keepdim=True)
                idx = idx.unsqueeze(-1).expand(-1, -1, self.action_size)
                actions = actions.reshape(batch_size, repeat_sample_num, self.action_size)
                actions = torch.gather(actions, dim=1, index=idx)
            actions = actions.squeeze(1).cpu().data.numpy()
        else:
            with torch.no_grad():
                actions = self.sample_actions(states, preference, self.actor, deterministic=deterministic).cpu().data.numpy()
        return actions

    def cal_bc_loss(self, policy_action, behavior_action, state, pref):
        """Behaviour-cloning loss for the configured algorithm.

        * ``'Diffusion-QL'``: ``self.actor.loss(behavior_action, state, pref)``.
        * ``'CVAE-QL'``: reconstruction MSE + 0.5 * KL to a standard normal.
        * ``'TD3+BC'`` / ``'BC'``: per-sample mean squared action error.
        * otherwise: zeros, one per sample.
        """
        if self.args.algo == 'Diffusion-QL':
            bc_loss = self.actor.loss(behavior_action, state, pref)
        elif self.args.algo == 'CVAE-QL':
            recon, mean, std = self.actor(state, pref, behavior_action)
            recon_loss = F.mse_loss(recon, behavior_action)
            KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
            bc_loss = recon_loss + 0.5 * KL_loss
        elif self.args.algo == 'TD3+BC' or self.args.algo == 'BC':
            bc_loss = torch.square(policy_action - behavior_action).mean(-1)
        else:
            bc_loss = torch.zeros(len(policy_action)).to(self.device)
        return bc_loss

    def get_vf_tderror(self, state, action, reward, next_state, w_batch, not_done):
        """Element-wise squared error of the value function (no reduction).

        IQL: ``0.5 * (V(s) - Q1_target(s, a))^2``; otherwise the one-step TD
        error ``0.5 * (r + gamma * not_done * V(s') - V(s))^2``.
        """
        if self.args.algo == 'IQL':
            target_Q = self.critic_target.Q1(state, w_batch, action)
            current_V = self.value_function(state, w_batch)
            td_error = 0.5 * torch.square(current_V - target_Q)
        else:
            v = self.value_function(state, w_batch)
            target_v = self.gamma * not_done * self.value_function(next_state, w_batch) + reward
            td_error = 0.5 * torch.square(target_v - v)
        return td_error

    def get_critic_tderror(self, state, action, reward, next_state, w_batch, not_done):
        """Element-wise ``0.5 * (r + gamma * not_done * V(s') - min(Q1_t, Q2_t)(s, a))^2``."""
        q1, q2 = self.critic_target(state, w_batch, action)
        q = self.min_qvec_operator(q1, q2, w_batch)
        target_v = self.gamma * not_done * self.value_function(next_state, w_batch) + reward
        td_error = 0.5 * torch.square(target_v - q)
        return td_error

    def get_log_action_prob(self, state, action, w_batch):
        """(Approximate) log-probability of ``action`` under the current policy.

        * ``'IQL'``: exact Gaussian log-prob summed over action dimensions.
        * ``'Diffusion-QL'``: a proxy -- negative squared reconstruction error
          of the denoised action at a random diffusion step, averaged over 32
          noise draws per sample.
        * ``'TD3+BC'``: Gaussian centred on the deterministic action with
          standard deviation ``policy_noise``.

        :raises NotImplementedError: for any other ``args.algo``.
        """
        if self.args.algo == 'IQL':
            _, _, dist = self.sample_actions(state, w_batch, self.actor, need_log_pi=True)
            log_prob = dist.log_prob(action).sum(dim=1)
        elif self.args.algo == 'Diffusion-QL':  # approximate the log prob of diffusion policies with variational lower bound
            REPEAT_SAMPLE_NUM = 32
            state = torch.repeat_interleave(state, repeats=REPEAT_SAMPLE_NUM, dim=0)
            action = torch.repeat_interleave(action, repeats=REPEAT_SAMPLE_NUM, dim=0)
            w_batch = torch.repeat_interleave(w_batch, repeats=REPEAT_SAMPLE_NUM, dim=0)
            state_pref = torch.cat([state, w_batch], dim=1)
            t = torch.randint(0, self.actor_target.n_timesteps, (len(action),), device=self.args.device).long()
            noise = torch.randn_like(action)
            x_noisy = self.actor_target.q_sample(x_start=action, t=t, noise=noise)
            x_recon = self.actor_target.predict_start_from_noise(x_noisy, t=t, noise=self.actor_target.model(x_noisy, t, state_pref))
            x_recon.clamp_(-self.max_action, self.max_action)
            log_prob = -torch.square(x_recon - action).sum(1)
            log_prob = log_prob.reshape(-1, REPEAT_SAMPLE_NUM).mean(dim=1)
        elif self.args.algo == 'TD3+BC':
            agent_action = self.sample_actions(state, w_batch, self.actor, deterministic=True)
            dist = Normal(agent_action, torch.ones_like(agent_action).to(self.device) * self.policy_noise)
            log_prob = dist.log_prob(action).sum(dim=1)
        else:
            raise NotImplementedError(f"get_log_action_prob: unsupported algo {self.args.algo!r}")
        return log_prob

    def min_qvec_operator(self, q1, q2, w_obj):
        """Element-wise minimum of the two Q vectors (``w_obj`` is currently unused)."""
        min_Q = torch.min(q1, q2)
        return min_Q

    # Learn from batch
    def train_regularized_policy(self, replay_buffer, writer):
        """One TD3-style training iteration on a replay-buffer batch.

        Every call updates the twin critic towards
        ``r + gamma * not_done * min(Q1', Q2')`` evaluated at a (noisy)
        target-actor action. Every ``policy_freq`` calls it additionally:

        * updates the actor with ``-lmbda/weight_bc_loss * mean(w.Q) + mean(bc)``,
          where ``lmbda`` normalises by the mean magnitude of ``w.Q`` and one
          critic head is picked at random;
        * regresses the value function onto that critic head;
        * soft-updates the critic target and (for Diffusion-QL only every 5th
          actor step once ``total_it >= 1000``) the actor target.

        LR schedulers, if enabled, are stepped every 1000 iterations.
        """
        self.writer = writer
        batch_size = self.args.batch_size
        self.total_it += 1
        batch = replay_buffer.sample(batch_size)
        state_batch, action_batch, next_state_batch, reward_batch, not_done, w_obj_batch = batch

        with torch.no_grad():
            # Compute the target Q value
            noise_next_action_batch = self.sample_actions(next_state_batch, w_obj_batch, self.actor_target, deterministic=False)
            target_Q1, target_Q2 = self.critic_target(next_state_batch, w_obj_batch, noise_next_action_batch)
            target_Q = self.min_qvec_operator(target_Q1, target_Q2, w_obj_batch)
            target_Q = reward_batch + not_done * self.gamma * target_Q
        # Get current Q values
        current_Q1, current_Q2 = self.critic(state_batch, w_obj_batch, action_batch)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.clip_grad_max_norm)
        self.critic_optimizer.step()

        # Delayed policy updates
        if self.total_it % self.policy_freq == 0:
            # Compute Actor Loss
            policy_action = self.sample_actions(state_batch, w_obj_batch, actor=self.actor, deterministic=True)
            Q1, Q2 = self.critic(state_batch, w_obj_batch, policy_action)
            Q = Q1 if np.random.uniform() > 0.5 else Q2
            bc_term = self.cal_bc_loss(policy_action, action_batch, state_batch, w_obj_batch)
            wQ = torch.bmm(w_obj_batch.unsqueeze(1), Q.unsqueeze(2)).squeeze()
            lmbda = 1 / wQ.abs().mean().detach()
            weight = 0 if self.args.weight_bc_loss == 0 else 1.0 / self.args.weight_bc_loss
            actor_loss = -weight * lmbda * wQ.mean() + bc_term.mean()

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.clip_grad_max_norm)
            self.actor_optimizer.step()

            # Optimize the value function
            vf = self.value_function(state_batch, w_obj_batch)
            vf_loss = F.mse_loss(vf, Q.detach())
            self.vf_optimizer.zero_grad()
            vf_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.value_function.parameters(), max_norm=self.clip_grad_max_norm)
            self.vf_optimizer.step()

            # Soft update the target networks
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            actor_update_step = self.total_it // self.policy_freq
            if self.args.algo != 'Diffusion-QL' or \
                    (actor_update_step % 5 == 0 and self.total_it >= 1000):
                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            # Write the results to tensorboard
            if (self.total_it % 5000) == 0:
                writer.add_scalar('Loss/Actor_Loss', actor_loss, self.total_it)
                writer.add_scalar('Loss/Actor_wq', wQ.mean(), self.total_it)
                writer.add_scalar('Loss/Critic_Loss', critic_loss, self.total_it)
                writer.add_scalar('Loss/vf_Loss', vf_loss, self.total_it)
                writer.add_scalar('Loss/bc_term', bc_term.mean(), self.total_it)
                for k in range(current_Q1.shape[1]):
                    writer.add_scalar(f'Loss/Critic_objective_{k}', current_Q1.reshape(batch_size, -1)[:, k].mean(), self.total_it)

        if self.args.lr_decay and self.total_it % 1000 == 0:
            self.actor_lr_scheduler.step()
            self.critic_lr_scheduler.step()
            self.vf_lr_scheduler.step()

    # Learn from batch
    def train_bc_policy(self, replay_buffer, writer):
        """One pure behaviour-cloning update of the actor, followed by a soft target update."""
        self.writer = writer
        batch_size = self.args.batch_size
        self.total_it += 1
        batch = replay_buffer.sample(batch_size)
        state_batch, action_batch, next_state_batch, reward_batch, not_done, w_obj_batch = batch

        policy_action = self.sample_actions(state_batch, w_obj_batch, actor=self.actor, deterministic=True)
        bc_term = self.cal_bc_loss(policy_action, action_batch, state_batch, w_obj_batch)
        actor_loss = bc_term.mean()

        # Optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.clip_grad_max_norm)
        self.actor_optimizer.step()

        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        # Write the results to tensorboard
        if (self.total_it % 5000) == 0:
            writer.add_scalar('Loss/Actor_Loss', actor_loss, self.total_it)

    # Learn from batch
    def train_iql(self, replay_buffer, writer):
        """One IQL-style training iteration on a replay-buffer batch.

        * Critic: MSE towards ``r + gamma * not_done * V(s')``, plus an angle
          penalty between ``w`` and each Q vector that is only active while
          ``total_it <= iql_warmup_step``.
        * Value function: asymmetrically weighted squared error against
          ``min(Q1_t, Q2_t)(s, a)``. The weight (``iql_quantile`` or
          ``1 - iql_quantile``) is chosen by the sign of the
          preference-weighted error.
        * Actor (every ``policy_freq`` calls, after warm-up): advantage-weighted
          log-likelihood of dataset actions (optionally noise-perturbed), with
          weights ``exp(adv / iql_beta)`` optionally clipped.

        NOTE(review): unlike ``train_regularized_policy`` this never steps the
        LR schedulers, even when ``args.lr_decay`` is set -- confirm intent.
        """
        self.writer = writer
        batch_size = self.args.batch_size
        self.total_it += 1

        batch = replay_buffer.sample(batch_size)
        state_batch, action_batch, next_state_batch, reward_batch, not_done, w_obj_batch = batch

        # The angle penalty is only used during warm-up.
        angle_loss_coeff = self.args.angle_loss_coeff if self.total_it <= self.args.iql_warmup_step else 0

        # Optimize the critic
        current_Q1, current_Q2 = self.critic(state_batch, w_obj_batch, action_batch)
        with torch.no_grad():
            target_V = self.value_function(next_state_batch, w_obj_batch)
        target_V = reward_batch + not_done * self.gamma * target_V

        # Angle (degrees) between the preference and each Q vector; the cosine is clamped to [0, 0.9999].
        angle_term_1 = torch.rad2deg(torch.acos(torch.clamp(F.cosine_similarity(w_obj_batch, current_Q1), 0, 0.9999)))
        angle_term_2 = torch.rad2deg(torch.acos(torch.clamp(F.cosine_similarity(w_obj_batch, current_Q2), 0, 0.9999)))
        angle_loss = angle_term_1.mean() + angle_term_2.mean()

        critic_loss = F.mse_loss(current_Q1, target_V) + F.mse_loss(current_Q2, target_V) + angle_loss_coeff * angle_loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.clip_grad_max_norm)
        self.critic_optimizer.step()

        # Optimize the value function
        with torch.no_grad():
            target_Q1, target_Q2 = self.critic_target(state_batch, w_obj_batch, action_batch)
        min_target_Q = self.min_qvec_operator(target_Q1, target_Q2, w_obj_batch)
        current_V = self.value_function(state_batch, w_obj_batch)
        vf_err = torch.bmm(w_obj_batch.unsqueeze(1), (current_V - min_target_Q).unsqueeze(2)).squeeze(-1)
        vf_sign = (vf_err > 0).float()
        vf_weight = (1 - vf_sign) * self.args.iql_quantile + vf_sign * (1 - self.args.iql_quantile)
        vf_loss = (vf_weight * torch.square(current_V - min_target_Q)).mean()
        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.value_function.parameters(), max_norm=self.clip_grad_max_norm)
        self.vf_optimizer.step()

        # Soft update the target networks
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        # Delayed policy updates
        if self.total_it % self.policy_freq == 0 and self.total_it >= self.args.iql_warmup_step:
            # Optimize the actor
            _, _, dist = self.sample_actions(state_batch, w_obj_batch, actor=self.actor, deterministic=False, need_log_pi=True)

            if self.policy_noise > 0:
                # Perturb dataset actions and re-evaluate the target critic on them.
                noise = torch.randn_like(action_batch).to(self.device)
                noise_clip = (noise * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
                noise_actions = (action_batch + noise_clip).clamp(-self.max_action, self.max_action)
                with torch.no_grad():
                    noise_target_Q1, noise_target_Q2 = self.critic_target(state_batch, w_obj_batch, noise_actions)
                min_target_Q = self.min_qvec_operator(noise_target_Q1, noise_target_Q2, w_obj_batch)
                log_action_prob = dist.log_prob(noise_actions)
            else:
                log_action_prob = dist.log_prob(action_batch)

            adv = torch.bmm(w_obj_batch.unsqueeze(1), (min_target_Q - current_V).unsqueeze(2)).squeeze(-1).detach()
            exp_adv = torch.exp(adv / self.args.iql_beta)
            if self.args.iql_clip_score is not None:
                exp_adv = torch.clamp(exp_adv, max=self.args.iql_clip_score)
            if self.args.iql_clip_min_score is not None:  # if adv < iql_clip_min_score, set weight = 0
                exp_adv = exp_adv * (adv >= self.args.iql_clip_min_score)

            actor_loss = (-log_action_prob * exp_adv).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.clip_grad_max_norm)
            self.actor_optimizer.step()

            # Write the results to tensorboard
            if (self.total_it % 5000) == 0:
                writer.add_scalar('Loss/Actor_Loss', actor_loss, self.total_it)
                writer.add_scalar('Loss/Critic_Loss', critic_loss, self.total_it)
                writer.add_scalar('Loss/vf_Loss', vf_loss, self.total_it)
                writer.add_scalar('Loss/angle_loss', angle_loss, self.total_it)
                writer.add_scalar('Loss/adv', adv.mean(), self.total_it)
                writer.add_scalar('Loss/exp_adv', exp_adv.mean(), self.total_it)
                writer.add_scalar('Loss/log_action_prob', log_action_prob.mean(), self.total_it)
                writer.add_scalar('Loss/vf_err', vf_err.mean(), self.total_it)
                writer.add_scalar('Loss/vf_sign', vf_sign.mean(), self.total_it)

                if self.args.use_wandb:
                    import wandb
                    # Loop-invariant: preference-weighted min target Q for every sample.
                    # reshape(-1) keeps a 1-D tensor even when the batch has one sample.
                    wQ = torch.bmm(w_obj_batch.unsqueeze(1), min_target_Q.unsqueeze(2)).reshape(-1).detach()
                    for k in range(state_batch.shape[0]):
                        wandb.log({'Loss/pref_1': w_obj_batch[k, 0].item(), 'Loss/V_1': current_V[k, 0].item(), 'Loss/V_2': current_V[k, 1].item(),
                                   'Loss/Q_1': target_Q1[k, 0].item(), 'Loss/Q_2': target_Q1[k, 1].item(), 'Loss/wQ': wQ[k].item()})

                for k in range(current_Q1.shape[1]):
                    writer.add_scalar(f'Loss/Critic_objective_{k}', target_Q1.reshape(batch_size, 1, -1)[:, :, k].mean(), self.total_it)
        
from lib.PEDA.modt.evaluation.evaluator_rvs import EvaluatorRVS as Evaluator
from sklearn.linear_model import LinearRegression
from lib.PEDA.modt.evaluation.evaluate_episodes import EvalEpisode
from lib.PEDA.modt.utils import undominated_indices
from gym.spaces import Box
from torch import nn

class MOOF_PEDA_AGENT:
    """Wrapper around a PEDA/MODT-style model trained on fixed trajectories.

    The constructor fits one linear regression per objective, from preference
    to total raw return, on the (approximately) undominated trajectories.
    These regressors supply return targets at evaluation time and in
    :meth:`get_log_action_prob`. States fed to the model are the raw state
    concatenated with the preference (``concat_state_pref=True``).
    """

    def __init__(self, model, args, buffer, env, state_mean, state_std, reward_scale_weight):
        """
        :param model: network exposing ``training_step``, ``get_probabilities``, ``parameters``
        :param args: config namespace (``obs_shape``, ``action_shape``, ``reward_size``,
            ``max_episode_len``, ``device``, ``batch_size``, ...)
        :param buffer: dataset object exposing ``trajectories`` (dicts with
            ``'preference'`` and ``'raw_rewards'`` arrays)
        :param env: evaluation environment (``action_space.high`` is used as action scale)
        :param state_mean: state normalisation mean; row 0 is used
        :param state_std: state normalisation std; row 0 is used
        :param reward_scale_weight: forwarded to the evaluator as ``rtg_input_scale``
        """
        from torch.optim import AdamW as Optimizer
        from lib.PEDA.modt.evaluation.evaluator_rvs import EvaluatorRVS as Evaluator
        super().__init__()
        self.args = args
        self.reward_scale_weight = reward_scale_weight
        state_dim, act_dim, pref_dim, rtg_dim = args.obs_shape, args.action_shape, args.reward_size, args.reward_size
        # The model's state input includes the preference vector.
        state_dim += pref_dim
        trajectories = buffer.trajectories

        preferences = np.array([traj['preference'][0, :] for traj in trajectories])
        returns = np.array([np.sum(np.multiply(traj['raw_rewards'], traj['preference'])) for traj in trajectories])
        returns_mo = np.array([traj['raw_rewards'].sum(axis=0) for traj in trajectories])
        non_dom = undominated_indices(returns_mo, tolerance=0.1)
        # One regressor per objective: preference -> total raw return on that objective.
        self.lrModels = [LinearRegression() for _ in range(pref_dim)]
        for obj, lrModel in enumerate(self.lrModels):
            lrModel.fit(preferences[non_dom], returns_mo[non_dom, obj])  #TODO
        max_prefs = np.max(preferences, axis=0)
        min_prefs = np.min(preferences, axis=0)
        min_each_obj_step = np.min(np.vstack([np.min(traj['raw_rewards'], axis=0) for traj in trajectories]), axis=0)
        max_each_obj_step = np.max(np.vstack([np.max(traj['raw_rewards'], axis=0) for traj in trajectories]), axis=0)

        # Preference dimensions are appended to the state unnormalised (mean 0, std 1).
        state_mean = np.concatenate((state_mean[0], np.zeros(pref_dim)))
        state_std = np.concatenate((state_std[0], np.ones(pref_dim)))
        self.evaluator = Evaluator(
            env, state_dim, act_dim, pref_dim, rtg_dim,
            max_ep_len=args.max_episode_len,
            scale=1.0,
            state_mean=state_mean,
            state_std=state_std,
            min_each_obj_step=min_each_obj_step,
            max_each_obj_step=max_each_obj_step,
            act_scale=np.array(env.action_space.high),
            use_obj=-1,
            concat_state_pref=True,
            concat_rtg_pref=False,
            concat_act_pref=False,
            normalize_reward=False,
            video_dir=None,
            device=args.device,
            mode='normal',
            logsdir=None,
            eval_only=True,
            rtg_input_scale=reward_scale_weight,
        )

        self.eval_episodes = EvalEpisode(
            evaluator=self.evaluator,
            num_eval_episodes=1,
            max_each_obj_traj=np.max(returns_mo, axis=0),
            rtg_scale=1.0,
            lrModels=self.lrModels,
            use_max_rtg=False
        )

        optimizer = Optimizer(
            model.parameters(),
            lr=1e-4,
            weight_decay=1e-3,
        )

        # Linear LR warm-up over the first 10000 scheduler steps.
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lambda steps: min((steps+1)/10000, 1)
        )

        loss_fn = lambda s_hat, a_hat, r_hat, pref_hat, s, a, r, pref: \
            torch.mean((a_hat - a) ** 2)

        max_raw_r = np.multiply(np.max(returns_mo, axis=0), max_prefs) # based on weighted values
        min_raw_r = np.multiply(np.min(returns_mo, axis=0), min_prefs)
        max_final_r = np.max(returns)
        min_final_r = np.min(returns)

        self.model = model
        self.optimizer = optimizer
        self.get_batch = buffer
        self.loss_fn = loss_fn
        # for plotting purposes
        self.dataset_min_prefs = min_prefs
        self.dataset_max_prefs = max_prefs
        self.dataset_min_raw_r = min_raw_r # weighted
        self.dataset_max_raw_r = max_raw_r
        self.dataset_min_final_r = min_final_r
        self.dataset_max_final_r = max_final_r
        self.scheduler = scheduler
        self.concat_rtg_pref = 0
        self.concat_act_pref = 0
        self.total_it = 0

    def eval_one_episode(self, prefs=None, num_eval_episodes=1, **kwargs):
        """Roll out the model for each preference and return mean raw returns.

        :param prefs: 2-D array of preference vectors (one per row). Rows are
            L1-normalised on a float64 copy; the caller's array is not modified.
        :param num_eval_episodes: episodes per evaluation function
        :return: ndarray with one row per evaluation function returned by
            ``EvalEpisode`` (presumably one per preference -- TODO confirm),
            holding the mean unweighted raw return over episodes
        :raises ValueError: if ``prefs`` is None
        """
        if prefs is None:
            raise ValueError("eval_one_episode requires `prefs` (2-D array of preference vectors)")
        self.model.eval()
        self.eval_episodes.num_eval_episodes = num_eval_episodes
        prefs = prefs.astype(np.float64)
        prefs /= np.linalg.norm(prefs, ord=1, axis=1, keepdims=True)
        eval_fns = self.eval_episodes(pref_set=prefs, reward_scale_weight=self.reward_scale_weight)
        set_unweighted_raw_return = []
        for eval_fn in eval_fns:
            _, _, unweighted_raw_returns, _, _ = eval_fn(self.model, 0)
            set_unweighted_raw_return.append(np.mean(unweighted_raw_returns, axis=0))
        return np.array(set_unweighted_raw_return)

    def train_step(self, replay_buffer, writer):
        """One supervised update on a sampled batch; returns the loss as a float.

        The model input is the state concatenated (dim 1) with the average
        return-to-go; singleton dimensions are squeezed first.
        """
        self.total_it += 1
        self.model.train()
        self.writer = writer
        batch_size = self.args.batch_size
        states, actions, raw_return, avg_rtg, timesteps, attention_mask, pref = replay_buffer.sample(batch_size)
        states = torch.squeeze(states)
        actions = torch.squeeze(actions)
        avg_rtg = torch.squeeze(avg_rtg)
        if len(avg_rtg.shape) == 1:
            # Single objective: restore the feature dimension lost by squeeze.
            avg_rtg = torch.unsqueeze(avg_rtg, dim=-1)

        states = torch.cat((states, avg_rtg), dim=1)

        loss = self.model.training_step(
            (states, actions),
            batch_idx=0 # doesn't matter in source code
        )

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.scheduler is not None:
            self.scheduler.step()

        if (self.total_it % 5000) == 0:
            writer.add_scalar('Loss/Actor_Loss', loss, self.total_it)

        return loss.detach().cpu().item()

    def get_log_action_prob(self, state, action, w_batch):
        """Log-probability of ``action`` under the model, conditioned on a predicted return.

        The return-to-go fed to the model is the per-objective linear-regression
        prediction for the L1-normalised preference, divided by
        ``max_episode_len``. It is cast to ``state.dtype`` because the sklearn
        predictions are float64 while the model is fed ``state``.
        """
        augment_state = torch.cat([state, w_batch], dim=1)
        w_batch_numpy = w_batch.cpu().detach().numpy().astype(np.float64)
        w_batch_numpy /= np.linalg.norm(w_batch_numpy, ord=1, axis=1, keepdims=True)
        rtg = np.stack([self.lrModels[i].predict(w_batch_numpy) for i in range(self.args.reward_size)], axis=-1)
        rtg = torch.from_numpy(rtg / self.args.max_episode_len).to(device=self.args.device, dtype=state.dtype)
        log_prob = self.model.get_probabilities(augment_state, rtg, action).log()
        return log_prob

    def get_critic_tderror(self, state, action, reward, next_state, w_batch, not_done):
        """No critic: returns a tiny constant (1e-10) shaped like ``reward``."""
        return (torch.zeros_like(reward) + 1e-10).float().to(self.args.device)

    def get_vf_tderror(self, state, action, reward, next_state, w_batch, not_done):
        """No value function: returns a tiny constant (1e-10) shaped like ``reward``."""
        return (torch.zeros_like(reward) + 1e-10).float().to(self.args.device)



class MOOF_PROMPT_MODT_AGENT:
    """Wrapper around a prompt-conditioned multi-objective decision-transformer model.

    The constructor fits one linear regression per objective, from preference
    to total raw return, on the (approximately) undominated trajectories, for
    use by the evaluator. States are fed without the preference appended
    (``concat_state_pref=False``).
    """

    def __init__(self, model, args, buffer, env, state_mean, state_std, reward_scale_weight):
        """
        :param model: sequence model whose ``forward`` returns
            ``(action_preds, return_preds, pref_preds)``
        :param args: config namespace (``obs_shape``, ``action_shape``, ``reward_size``,
            ``max_episode_len``, ``device``, ``batch_size``, ...)
        :param buffer: dataset object exposing ``trajectories``, ``act_dim`` and ``pref_dim``
        :param env: evaluation environment (``action_space.high`` is used as action scale)
        :param state_mean: state normalisation mean; row 0 is used
        :param state_std: state normalisation std; row 0 is used
        :param reward_scale_weight: forwarded to the evaluator as ``rtg_input_scale``
        """
        from torch.optim import AdamW as Optimizer
        from lib.PEDA.modt.evaluation.evaluator_dt import EvaluatorPromptDT as Evaluator
        super().__init__()
        self.args = args
        self.reward_scale_weight = reward_scale_weight
        state_dim, act_dim, pref_dim, rtg_dim = args.obs_shape, args.action_shape, args.reward_size, args.reward_size
        trajectories = buffer.trajectories

        preferences = np.array([traj['preference'][0, :] for traj in trajectories])
        returns = np.array([np.sum(np.multiply(traj['raw_rewards'], traj['preference'])) for traj in trajectories])
        returns_mo = np.array([traj['raw_rewards'].sum(axis=0) for traj in trajectories])
        non_dom = undominated_indices(returns_mo, tolerance=0.1)
        # One regressor per objective: preference -> total raw return on that objective.
        self.lrModels = [LinearRegression() for _ in range(pref_dim)]
        for obj, lrModel in enumerate(self.lrModels):
            lrModel.fit(preferences[non_dom], returns_mo[non_dom, obj])  #TODO
        max_prefs = np.max(preferences, axis=0)
        min_prefs = np.min(preferences, axis=0)
        min_each_obj_step = np.min(np.vstack([np.min(traj['raw_rewards'], axis=0) for traj in trajectories]), axis=0)
        max_each_obj_step = np.max(np.vstack([np.max(traj['raw_rewards'], axis=0) for traj in trajectories]), axis=0)

        state_mean = state_mean[0]
        state_std = state_std[0]
        self.evaluator = Evaluator(
            env, state_dim, act_dim, pref_dim, rtg_dim,
            max_ep_len=args.max_episode_len,
            scale=1.0,
            state_mean=state_mean,
            state_std=state_std,
            min_each_obj_step=min_each_obj_step,
            max_each_obj_step=max_each_obj_step,
            act_scale=np.array(env.action_space.high),
            use_obj=-1,
            concat_state_pref=False,
            concat_rtg_pref=False,
            concat_act_pref=False,
            normalize_reward=False,
            video_dir=None,
            device=args.device,
            mode='normal',
            logsdir=None,
            eval_only=True,
            rtg_input_scale=reward_scale_weight,
        )

        self.eval_episodes = EvalEpisode(
            evaluator=self.evaluator,
            num_eval_episodes=1,
            max_each_obj_traj=np.max(returns_mo, axis=0),
            rtg_scale=1.0,
            lrModels=self.lrModels,
            use_max_rtg=False
        )

        optimizer = Optimizer(
            model.parameters(),
            lr=1e-4,
            weight_decay=1e-3,
        )

        # Linear LR warm-up over the first 10000 scheduler steps.
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lambda steps: min((steps+1)/10000, 1)
        )

        # Only the action prediction contributes to the training loss.
        loss_fn = lambda s_hat, a_hat, r_hat, pref_hat, s, a, r, pref: \
            torch.mean((a_hat - a) ** 2)

        max_raw_r = np.multiply(np.max(returns_mo, axis=0), max_prefs) # based on weighted values
        min_raw_r = np.multiply(np.min(returns_mo, axis=0), min_prefs)
        max_final_r = np.max(returns)
        min_final_r = np.min(returns)

        self.model = model
        self.optimizer = optimizer
        self.get_batch = buffer
        self.loss_fn = loss_fn
        # for plotting purposes
        self.dataset_min_prefs = min_prefs
        self.dataset_max_prefs = max_prefs
        self.dataset_min_raw_r = min_raw_r # weighted
        self.dataset_max_raw_r = max_raw_r
        self.dataset_min_final_r = min_final_r
        self.dataset_max_final_r = max_final_r
        self.scheduler = scheduler
        self.concat_rtg_pref = 0
        self.concat_act_pref = 0
        self.total_it = 0

    def eval_one_episode(self, prefs=None, prompts=None, num_eval_episodes=1, **kwargs):
        """Roll out the model for each preference with its prompt; return mean raw returns.

        :param prefs: 2-D array of preference vectors (one per row). Rows are
            L1-normalised on a float64 copy; the caller's array is not modified.
        :param prompts: indexable of prompts; ``prompts[i]`` goes to the i-th
            evaluation function
        :param num_eval_episodes: episodes per evaluation function
        :return: ndarray with one row per evaluation function, holding the
            mean unweighted raw return over episodes
        :raises ValueError: if ``prefs`` or ``prompts`` is None
        """
        if prefs is None:
            raise ValueError("eval_one_episode requires `prefs` (2-D array of preference vectors)")
        if prompts is None:
            raise ValueError("eval_one_episode requires `prompts` (one prompt per preference)")
        self.model.eval()
        self.eval_episodes.num_eval_episodes = num_eval_episodes
        prefs = prefs.astype(np.float64)
        prefs /= np.linalg.norm(prefs, ord=1, axis=1, keepdims=True)
        eval_fns = self.eval_episodes(pref_set=prefs, reward_scale_weight=self.reward_scale_weight)
        set_unweighted_raw_return = []
        for i, eval_fn in enumerate(eval_fns):
            _, _, unweighted_raw_returns, _, _ = eval_fn(self.model, 0, prompt=prompts[i])
            set_unweighted_raw_return.append(np.mean(unweighted_raw_returns, axis=0))
        return np.array(set_unweighted_raw_return)

    def train_step(self, replay_buffer, writer):
        """One supervised update on a sampled batch of sub-trajectories; returns the loss as a float.

        Predictions and targets are flattened and restricted to positions where
        ``attention_mask > 0`` before the loss is computed. Gradients are
        clipped to norm 0.25.
        """
        self.total_it += 1
        self.model.train()
        self.writer = writer
        batch_size = self.args.batch_size
        states, actions, raw_return, rtg, timesteps, attention_mask, pref = replay_buffer.sample(batch_size)
        # Drop the last return-to-go step.
        rtg = rtg[:, :-1]

        action_target = torch.clone(actions)
        return_target = torch.clone(raw_return)
        pref_target = torch.clone(pref)

        action_preds, return_preds, pref_preds = self.model.forward(
            states, actions, rtg, pref, timesteps, attention_mask=attention_mask
        )

        # Flattened mask of valid (non-padded) timesteps, shared by every target.
        valid = attention_mask.reshape(-1) > 0

        act_dim = self.get_batch.act_dim
        action_preds = action_preds.reshape(-1, act_dim)[valid]
        action_target = action_target.reshape(-1, act_dim)[valid]

        pref_dim = self.get_batch.pref_dim

        return_preds = return_preds.reshape(-1, pref_dim)[valid]
        return_target = return_target.reshape(-1, pref_dim)[valid]

        pref_preds = pref_preds.reshape(-1, pref_dim)[valid]
        pref_target = pref_target.reshape(-1, pref_dim)[valid]

        loss = self.loss_fn(
            None, action_preds, return_preds, pref_preds,
            None, action_target, return_target, pref_target,
        )
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.25)
        self.optimizer.step()

        if self.scheduler is not None:
            self.scheduler.step()

        if (self.total_it % 5000) == 0:
            writer.add_scalar('Loss/Actor_Loss', loss, self.total_it)

        return loss.detach().cpu().item()
