from codetiming import Timer
import torch
import wandb
import os

from networks import QNetworkCNNActionUnshared, QNetworkCNNShared, QNetworkCNNSusShared
from policy import PolicyLinProg
from marl_exp.utils import create_wandb_tags_and_config
from marl_exp.replay import ReplayMemoryMARL
from utils import save_frames_as_gif
from config import MODELS_SAVE_PATH, RENDER_SAVE_PATH, SUBSTRATE_NAME, DEVICE, DTYPE, DELTA, TRAINING_TYPE


class PolicyLinProgMARL(PolicyLinProg):
    def __init__(self,
                 env,
                 render_env=None,
                 reward_type='common_reward',
                 SW_prop=0,
                 extra_SW_prop=0,
                 hid_size=64,
                 n_hid_layers=1,
                 gamma=.99,
                 normalize_obs=False,
                 lr_start=1e-3, lr_end=1e-3,
                 eps_start=1., eps_end=0.,
                 batch_size=256,
                 n_batches=50_000,
                 n_interactions=1,
                 n_warm_start_batches=25,
                 target_update_freq=100,
                 log_wandb=False,
                 log_freq=1000,  # should equal episode duration
                 n_eval_episodes=20,
                 scheduling_speed=1.,
                 buffer_size=int(1e5),
                 train_principal=True,
                 val_agent=True,
                 save_path='',
                 tags=None
                 ):
        """Build networks, replay buffers, timers and bookkeeping for the multi-agent learner.

        Args:
            env: Vectorized environment. Read here for ``observation_space``, ``action_space``,
                ``get_attr("n_agents")`` and ``num_envs``.
            render_env: Optional environment used by ``render``; ``None`` disables rendering.
            reward_type: ``'common_reward'``, ``'decentralized'``, ``'constant_proportion'`` or
                ``None`` (treated as the baseline throughout this class).
            SW_prop: Proportion passed to ``get_payments_prop``; stored as ``SW_prop + extra_SW_prop``.
            extra_SW_prop: Added to ``SW_prop``; also stored on its own and used by ``log`` to
                suffix the wandb key with ``'_extra'`` when positive.
            normalize_obs: Forwarded to the Q-networks as ``normalize``.
            log_freq: ``log`` is called whenever the batch counter is a multiple of this value.
            n_eval_episodes: Number of episodes each evaluation pass runs for.
            buffer_size: Capacity given to both replay memories.
            train_principal: Stored as ``train_flag``; gates the training phase in ``train``.
            val_agent: Stored as ``val_agent_flag``; gates the validation phase in ``train``.
            save_path: Appended to ``MODELS_SAVE_PATH`` to form the checkpoint prefix.
            tags: Extra wandb tags; ``None`` means no extra tags.

        All remaining arguments are forwarded positionally to ``PolicyLinProg.__init__``.
        """
        # NOTE(review): the literal 0 fills the parent's slot after `log_wandb` — presumably its own
        # log frequency; this class keeps `log_freq` itself. Confirm against PolicyLinProg.
        super(PolicyLinProgMARL, self).__init__(env, hid_size, n_hid_layers, gamma, lr_start, lr_end,
                                                eps_start, eps_end, batch_size,
                                                n_batches, n_interactions, n_warm_start_batches, target_update_freq,
                                                log_wandb, 0, scheduling_speed)
        self.inp_size = env.observation_space.shape[1:]
        self.out_size = env.action_space[0].n
        self.n_agents = env.get_attr("n_agents")[0]
        self.n_envs = self.env.num_envs

        self.reward_type = reward_type
        self.SW_prop = SW_prop + extra_SW_prop
        self.extra_SW_prop = extra_SW_prop

        self.iteration = 0
        self.n_episodes = 0
        self.reset_selfish_agents()
        self.normalize_obs = normalize_obs

        # Keep the caller's render environment; render() returns early when it is None.
        self.render_env = render_env
        self.log_freq = log_freq
        self.n_eval_episodes = n_eval_episodes

        self.buffer = ReplayMemoryMARL(buffer_size, batch_size)
        self.buffer_agent = ReplayMemoryMARL(buffer_size, batch_size)
        self.init_networks()
        self.init_timers()

        self.train_flag = train_principal
        self.val_agent_flag = val_agent
        self.model_save_path = MODELS_SAVE_PATH + save_path
        self.tags = [] if tags is None else tags

    def init_principal(self):
        """Create the principal's online Q-network, its Adam optimizer and an eval-mode target copy."""
        def build_network():
            return QNetworkCNNShared(self.inp_size, self.hid_size, self.out_size, self.n_hid_layers,
                                     normalize=self.normalize_obs)

        online = build_network()
        # Every parameter gradient is clamped element-wise to [-1, 1] as it is computed.
        for param in online.parameters():
            param.register_hook(lambda grad: grad.clamp(-1, 1))
        self.Q_principal = online

        self.opt_principal = torch.optim.Adam(online.parameters(), self.lr_start)

        # The target starts as an exact copy of the online weights.
        target = build_network()
        target.load_state_dict(online.state_dict())
        target.eval()
        self.Q_principal_target = target

    def init_agent(self):
        """Create the agents' online Q-network, its Adam optimizer and an eval-mode target copy."""
        def build_network():
            return QNetworkCNNSusShared(self.inp_size, self.hid_size, self.out_size, self.n_hid_layers,
                                        n_agents=self.n_agents, normalize=self.normalize_obs)

        online = build_network()
        # Every parameter gradient is clamped element-wise to [-1, 1] as it is computed.
        for param in online.parameters():
            param.register_hook(lambda grad: grad.clamp(-1, 1))
        self.Q_agent = online
        self.opt_agent = torch.optim.Adam(online.parameters(), self.lr_start)

        # The target starts as an exact copy of the online weights.
        target = build_network()
        target.load_state_dict(online.state_dict())
        target.eval()
        self.Q_agent_target = target

    def init_agent_val(self):
        """Create the validation agents' online Q-network, its Adam optimizer and an eval-mode target copy."""
        def build_network():
            return QNetworkCNNActionUnshared(self.inp_size, self.hid_size, self.out_size,
                                             self.n_hid_layers, self.n_agents, normalize=self.normalize_obs)

        online = build_network()
        # Every parameter gradient is clamped element-wise to [-1, 1] as it is computed.
        for param in online.parameters():
            param.register_hook(lambda grad: grad.clamp(-1, 1))
        self.Q_agent_val = online

        self.opt_agent_val = torch.optim.Adam(online.parameters(), self.lr_start)

        # The target starts as an exact copy of the online weights.
        target = build_network()
        target.load_state_dict(online.state_dict())
        target.eval()
        self.Q_agent_target_val = target

    def synch_target_networks(self):
        self.Q_principal_target.load_state_dict(self.Q_principal.state_dict())
        self.Q_agent_target.load_state_dict(self.Q_agent.state_dict())
        self.Q_agent_target_val.load_state_dict(self.Q_agent_val.state_dict())

    def schedule(self):
        """Run the parent's scheduling step, then write ``self.lr`` into the validation optimizer."""
        super().schedule()
        # The validation optimizer is created in this class, so its learning rate is set here.
        for group in self.opt_agent_val.param_groups:
            group.update(lr=self.lr)

    def init_wandb(self):
        if self.log_wandb:
            config, tags = create_wandb_tags_and_config(self)
            if TRAINING_TYPE is not None:
                tags.append(TRAINING_TYPE)
            tags.extend(self.tags)
            tags.append(SUBSTRATE_NAME)
            tags.append('baseline' if self.reward_type is None else self.reward_type)
            wandb.init(project="rl_contracts", config=config, tags=tags, reinit=True)

    def process_state(self, state):
        return torch.from_numpy(state)

    def process_tensor(self, tensor):
        """Cast ``tensor`` to the configured ``DTYPE`` and then move it to the configured ``DEVICE``."""
        cast = tensor.to(DTYPE)
        return cast.to(DEVICE)

    def init_timers(self):
        """Create the three phase timers and the collection/training/evaluation timers."""
        report = "{name}: {:.4f} seconds"
        val_banner = '\nVALIDATION PHASE' if self.reward_type is not None else '\nBASELINE'

        # Whole-phase timers, each with its own banner text.
        self.timer_principal = Timer(name="Total principal", text=report, initial_text='\nPRINCIPAL PHASE')
        self.timer_agent = Timer(name="Total agent", text=report, initial_text='\nAGENT PHASE')
        self.timer_val = Timer(name="Total validation", text=report, initial_text=val_banner)

        # Created with logger=None; log() reads their accumulated totals by name from Timer.timers.
        for attr, label in (('timer_collect', "collection"),
                            ('timer_train', "training"),
                            ('timer_eval', "evaluation")):
            setattr(self, attr, Timer(name=label, text=report, logger=None))

    def reset_timers(self):
        """Clear every entry accumulated in the class-level ``Timer.timers`` collection."""
        accumulated = Timer.timers
        accumulated.clear()

    def increment_iteration(self):
        self.iteration += self.n_envs * self.n_interactions

    def get_actual_iteration(self):
        return self.iteration // (self.n_envs * self.n_interactions)

    @torch.no_grad()
    def act(self, states, eps=0.):
        """Return the principal's actions for ``states``.

        For reward types other than ``'common_reward'`` and ``'decentralized'`` the principal
        network is not queried and an all-zero ``(n_envs, n_agents)`` tensor is returned.
        """
        principal_active = self.reward_type in {'common_reward', 'decentralized'}
        if principal_active:
            _, actions = self.Q_principal(self.process_tensor(states), sample=True, eps=eps)
            return actions
        return torch.zeros((self.n_envs, self.n_agents), dtype=DTYPE, device=DEVICE)

    @torch.no_grad()
    def act_agent(self, states, actions_p, eps=0.):
        """Return joint actions: the agent network's choice where ``selfish_agents`` is True, else ``actions_p``."""
        _, own_actions = self.Q_agent(self.process_tensor(states), self.selfish_agents, sample=True, eps=eps)
        return torch.where(self.selfish_agents, own_actions, actions_p)

    @torch.no_grad()
    def act_val(self, states, actions_p, eps=0.):
        """Return the validation agents' actions given ``states`` and the principal's actions ``actions_p``."""
        network_out = self.Q_agent_val(self.process_tensor(states), actions_p, sample=True, eps=eps)
        _, actions = network_out
        return actions

    def reset_selfish_agents(self):
        """Mark no agent as selfish in any environment and clear the ``sus_agents`` mask as well."""
        mask_shape = (self.n_envs, self.n_agents)
        self.selfish_agents = torch.zeros(mask_shape, dtype=torch.bool, device=DEVICE)
        self.reset_sus_agents()

    def update_selfish_agents(self):
        """Resample the ``(n_envs, n_agents)`` selfish mask.

        For each environment a count is drawn from ``{0, ..., n_agents}`` with ``torch.randint``
        and that many agents, placed by a random permutation, are marked True.
        """
        selfish_counts = torch.randint(0, self.n_agents + 1, (self.n_envs,))
        slots = torch.arange(self.n_agents)
        mask = torch.empty((self.n_envs, self.n_agents), dtype=torch.bool)
        for env_row, count in zip(mask, selfish_counts):
            shuffle = torch.randperm(self.n_agents)
            # Before shuffling, the trailing `count` slots are True and the leading ones False.
            ordered = slots >= (self.n_agents - count)
            env_row.copy_(ordered[shuffle])
        self.selfish_agents = mask.to(DEVICE)

    def reset_sus_agents(self):
        """Replace ``sus_agents`` with a fresh all-False ``(n_envs, n_agents)`` mask on ``DEVICE``."""
        mask_shape = (self.n_envs, self.n_agents)
        self.sus_agents = torch.zeros(mask_shape, dtype=torch.bool, device=DEVICE)

    def update_sus_agents(self, actions_p, actions):
        mask_sus = actions_p != actions
        self.sus_agents[~mask_sus] = False
        self.sus_agents[mask_sus] = True

    @torch.no_grad()
    def get_contracts(self, q_values_a, actions_p):
        if self.reward_type is None:
            return torch.zeros_like(q_values_a)
        payments = q_values_a.max(-1, keepdim=True)[0] - q_values_a.gather(-1, actions_p) + DELTA
        contracts = torch.zeros_like(q_values_a).scatter(-1, actions_p, payments)
        return contracts

    def get_mod_rewards(self, rewards: torch.Tensor):
        return rewards.sum(-1, keepdim=True).repeat_interleave(rewards.shape[-1], -1)

    def get_payments_baseline(self, rewards):
        return torch.zeros_like(rewards)

    def get_payments_from_contracts(self, contracts, actions):
        return contracts.gather(-1, actions).squeeze(-1).clip(0)

    def get_payments_prop(self, rewards, prop):
        return prop * (self.get_mod_rewards(rewards) - rewards).clip(0)

    @torch.no_grad()
    def get_payments(self, states, rewards, actions_p, actions):
        if self.reward_type is None:
            return self.get_payments_baseline(rewards)
        rewards = self.process_tensor(rewards)
        if self.reward_type in {'common_reward', 'decentralized'}:
            states = self.process_tensor(states)
            self.update_sus_agents(actions_p, actions)
            q_values_a = self.Q_agent(states, self.sus_agents)
            contracts = self.get_contracts(q_values_a, actions_p.unsqueeze(-1))
            payments = self.get_payments_from_contracts(contracts, actions.unsqueeze(-1))
            payments += self.get_payments_prop(rewards, self.SW_prop)
        elif self.reward_type == 'constant_proportion':
            payments = self.get_payments_prop(rewards, self.SW_prop)
        return payments.cpu()

    def reset_stuff(self):
        """Reset counters, both replay buffers, the environment, schedules, timers and the wandb run.

        Also makes sure the model checkpoint directory exists.
        """
        self.iteration = 0
        self.n_episodes = 0
        self.buffer.reset()
        self.buffer_agent.reset()
        self.last_states = self.process_state(self.env.reset())
        self.reset_schedules()
        self.reset_timers()
        self.init_wandb()
        # exist_ok avoids the check-then-create race when several runs share a save path.
        os.makedirs(self.model_save_path, exist_ok=True)

    def step(self, actions):
        """Step the vectorized environment with ``actions``.

        Returns:
            ``(next_states, rewards, done, infos)`` where ``rewards`` is a torch tensor and
            ``done`` is True as soon as any environment reports done.
        """
        env = self.env
        env.step_async(actions.cpu().numpy())
        next_states, raw_rewards, dones, infos = env.step_wait()
        rewards = torch.from_numpy(raw_rewards)
        any_done = any(dones)
        return next_states, rewards, any_done, infos

    @torch.no_grad()
    def _collect_experience(self, mode='train'):
        """Take one step in every environment and store the transition in the replay buffer(s).

        Args:
            mode: ``'train'`` acts with the principal and the agent network and also fills
                ``buffer_agent``; ``'val'`` acts with the validation agents and adds the
                payments to the rewards before storing.

        Raises:
            ValueError: If ``mode`` is neither ``'train'`` nor ``'val'``.
        """
        states = self.last_states.clone()
        if mode == 'train':
            actions_p = self.act(states, eps=self.eps)
            actions = self.act_agent(states, actions_p, eps=self.eps)
        elif mode == 'val':
            actions_p = self.act(states)
            actions = self.act_val(states, actions_p, eps=self.eps)
        else:
            raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")

        next_states, rewards, done, infos = self.step(actions)

        if mode == 'val':
            payments = self.get_payments(states, rewards, actions_p, actions)
            if self.reward_type == 'decentralized':  # payments come out of agents' pockets
                payments -= payments.flip(-1)  # ONLY CORRECT FOR TWO AGENTS
            rewards += payments

        if done:
            self.n_episodes += 1
            # The returned observations are kept as the start of the next rollout step.
            self.last_states = self.process_state(next_states.copy())
            # `infos` holds one dict per environment, so index by environment rather than by
            # agent. Environments that did not finish carry no terminal observation and keep
            # their real next state.
            for env_idx, info in enumerate(infos):
                if "terminal_observation" in info:
                    next_states[env_idx] = info["terminal_observation"]  # bootstrapping
            next_states = self.process_state(next_states)
        else:
            self.last_states = next_states = self.process_state(next_states)

        self.buffer.push({
            'states': states,
            'actions': actions,
            'rewards': rewards,
            'next_states': next_states,
        })
        if mode == 'train':
            self.buffer_agent.push({
                'states': states,
                'actions': actions,
                'rewards': rewards,
                'next_states': next_states,
                'selfish_agents': self.selfish_agents.cpu(),
            })
        if done:
            # The masks are redrawn / cleared only after the finished step has been stored.
            self.update_selfish_agents()
            self.reset_sus_agents()

    def _update_priority(self, transitions, priority, buffer=None):
        """Attach ``priority`` (detached, on CPU) as ``transitions['td_error']`` and update the buffer.

        ``buffer`` defaults to ``self.buffer`` when ``None``.
        """
        target_buffer = self.buffer if buffer is None else buffer
        transitions['td_error'] = priority.detach().cpu()
        target_buffer.update_priority(transitions)

    def train(self):
        if self.reward_type in {'common_reward', 'decentralized'}:
            if self.reward_type == 'common_reward' and self.train_flag:
                self.timer_principal.start()
                self.reset_stuff()
                prop = self._train()
                if self.log_wandb:
                    wandb.finish()
                self.timer_principal.stop()
                torch.save(self.Q_principal.state_dict(), f'{self.model_save_path}principal.pt')
                torch.save(self.Q_agent.state_dict(), f'{self.model_save_path}agent.pt')
                self.render()
            else:
                self.Q_principal.load_state_dict(torch.load(f'{self.model_save_path}principal.pt'))
                self.Q_agent.load_state_dict(torch.load(f'{self.model_save_path}agent.pt'))
                self.synch_target_networks()

        if self.reward_type != 'common_reward' or self.val_agent_flag:
            self.timer_val.start()
            self.reset_stuff()
            prop = self._validate()
            if self.log_wandb:
                wandb.finish()
            self.timer_val.stop()
            if self.reward_type == 'common_reward':
                save_name = f'{self.model_save_path}val.pt'
            elif self.reward_type is None:
                save_name = f'{self.model_save_path}baseline.pt'
            else:
                save_name = f'{self.model_save_path}{self.reward_type}.pt'
            torch.save(self.Q_agent_val.state_dict(), save_name)
            self.render('val')
        return prop

    def _train(self):
        for b in range(self.n_batches+1):
            self.increment_iteration()

            self.timer_collect.start()
            for _ in range(self.n_interactions):
                self._collect_experience()
            self.timer_collect.stop()

            if b < self.n_warm_start_batches:
                continue

            self.timer_train.start()

            self._train_principal()
            self._train_agent()

            self.schedule()

            if self.get_actual_iteration() % self.target_update_freq == 0:
                self.synch_target_networks()

            self.timer_train.stop()

            if self.get_actual_iteration() % self.log_freq == 0:
                prop = self.log()
        return prop

    def _train_principal(self):
        """Run one double-DQN update of the principal on a batch sampled from ``self.buffer``.

        The target subtracts the contract payment (scaled by ``1 / (n_agents - 1) / 10``) from
        the reward before bootstrapping. Priorities of the sampled transitions are refreshed
        from the absolute TD error.
        """
        transitions = self.buffer.sample()
        states, actions, rewards, next_states = (
            self.process_tensor(transitions['states']),
            transitions['actions'].to(DEVICE),
            self.process_tensor(transitions['rewards']),
            self.process_tensor(transitions['next_states']),
        )
        actions = actions.unsqueeze(-1)
        if TRAINING_TYPE is None:
            rewards = self.get_mod_rewards(rewards)

        # The agent's Q-values only feed the (gradient-free) targets, so no autograd graph is
        # needed for this forward pass.
        with torch.no_grad():
            q_values_a = self.Q_agent(states)
            contracts = self.get_contracts(q_values_a, actions)
            payments = self.get_payments_from_contracts(contracts, actions).squeeze(-1) - DELTA

        q_values = self.Q_principal(states).gather(-1, actions).squeeze(-1)
        with torch.no_grad():
            next_actions_principal = self.get_best_actions(self.Q_principal(next_states))  # double DQN
            next_q_values = self.Q_principal_target(next_states).gather(-1, next_actions_principal).squeeze(-1)
            targets = rewards - payments / (self.n_agents - 1) / 10 + self.gamma * next_q_values  # bootstrap if done

        if TRAINING_TYPE == 'vdn':
            q_values, targets = q_values.sum(-1), targets.sum(-1)

        loss = self.get_loss(q_values, targets)
        self.update_principal(loss)
        # NOTE(review): the principal's loss is folded into `loss_agent`, the same running
        # average `_train_agent` updates — confirm whether a separate principal metric was meant.
        self.loss_agent = self.mixing_coef * self.loss_agent + (1 - self.mixing_coef) * loss.cpu().item()

        priority = (q_values - targets).abs()
        if TRAINING_TYPE is None:
            priority = priority.sum(-1)
        self._update_priority(transitions, priority)

    def _train_agent(self):
        """Run one double-DQN update of the agent network on a batch sampled from ``self.buffer_agent``."""
        batch = self.buffer_agent.sample()
        states = self.process_tensor(batch['states'])
        actions = batch['actions'].to(DEVICE).unsqueeze(-1)
        rewards = self.process_tensor(batch['rewards'])
        next_states = self.process_tensor(batch['next_states'])
        selfish_agents = batch['selfish_agents'].to(DEVICE).unsqueeze(-1)

        q_values = self.Q_agent(states, selfish_agents).gather(-1, actions).squeeze(-1)
        with torch.no_grad():
            ### the principal pays to match the highest utility exactly, so we can just use the highest utility
            # Double DQN: the online network picks the action, the target network scores it.
            greedy_next = self.get_best_actions(self.Q_agent(next_states, selfish_agents))
            next_q_values = self.Q_agent_target(next_states, selfish_agents).gather(-1, greedy_next).squeeze(-1)
            targets = rewards + self.gamma * next_q_values  # bootstrap if done

        loss = self.get_loss(q_values, targets)
        self.update_agent(loss)
        # Exponential moving average of the loss, weighted by mixing_coef.
        self.loss_agent = self.mixing_coef * self.loss_agent + (1 - self.mixing_coef) * loss.item()

        td_abs = (q_values - targets).abs()
        self._update_priority(batch, td_abs.sum(-1), buffer=self.buffer_agent)

    def _validate(self):
        for b in range(self.n_batches+1):
            self.increment_iteration()

            self.timer_collect.start()
            for _ in range(self.n_interactions):
                self._collect_experience('val')
            self.timer_collect.stop()

            if b < self.n_warm_start_batches:
                continue

            self.timer_train.start()

            transitions = self.buffer.sample()
            states, actions, rewards, next_states = (
                self.process_tensor(transitions['states']),
                transitions['actions'].to(DEVICE),
                self.process_tensor(transitions['rewards']),
                self.process_tensor(transitions['next_states']),
            )
            actions = actions.unsqueeze(-1)

            if self.reward_type in {'common_reward', 'decentralized'}:
                actions_principal = self.get_best_actions(self.Q_principal(states))
                next_actions_principal = self.get_best_actions(self.Q_principal(next_states))
            else:
                actions_principal = next_actions_principal = torch.zeros_like(actions)

            q_values = self.Q_agent_val(states, actions_principal).gather(-1, actions).squeeze(-1)
            with torch.no_grad():
                next_actions = self.get_best_actions(self.Q_agent_val(next_states, next_actions_principal))  # double DQN
                next_q_values = self.Q_agent_target_val(next_states, next_actions_principal).gather(-1, next_actions).squeeze(-1)
                targets = rewards + self.gamma * next_q_values  # bootstrap if done

            loss = self.get_loss(q_values, targets)

            self.update_agent_val(loss)
            self.loss_agent_val = self.mixing_coef * self.loss_agent_val + (1 - self.mixing_coef) * loss.item() / self.n_agents

            priority = (q_values - targets).abs().sum(-1)
            self._update_priority(transitions, priority)
            self.schedule()

            self.timer_train.stop()

            if self.get_actual_iteration() % self.target_update_freq == 0:
                self.synch_target_networks()

            if self.get_actual_iteration() % self.log_freq == 0:
                prop = self.log('val')
        return prop

    @torch.no_grad()
    def _evaluate(self, mode='agent'):
        """Roll out the current policies for at least ``n_eval_episodes`` episodes.

        In ``'val'`` mode the validation agents act; otherwise the principal's actions are
        executed directly.

        Returns:
            ``(SW, payment, rew_diff, acc)``: social welfare averaged per episode, and the
            absolute payment, absolute gap between summed and own reward, and agreement rate
            with the principal, each averaged per environment step.
        """
        self.reset_sus_agents()

        total_sw = total_payment = total_rew_diff = total_acc = 0.
        episodes_done = steps_seen = 0
        obs = self.process_tensor(self.process_state(self.env.reset()))
        while episodes_done < self.n_eval_episodes:
            steps_seen += self.n_envs
            actions_p = self.act(obs)
            actions = self.act_val(obs, actions_p) if mode == 'val' else actions_p

            next_states, rewards, done, infos = self.step(actions)

            step_payments = self.get_payments(obs, rewards, actions_p, actions)
            pooled_rewards = self.get_mod_rewards(rewards)
            total_payment += step_payments.abs().mean(-1).sum().item()
            total_rew_diff += (pooled_rewards - rewards).abs().mean(-1).sum().item()
            total_acc += (actions_p == actions).float().mean(1).sum().item()

            obs = self.process_tensor(self.process_state(next_states))
            if done:
                # Every environment contributes one finished episode.
                total_sw += sum(info['social_welfare'] for info in infos)
                episodes_done += self.n_envs
                self.reset_sus_agents()

        return (total_sw / episodes_done,
                total_payment / steps_seen,
                total_rew_diff / steps_seen,
                total_acc / steps_seen)

    @torch.no_grad()
    def _evaluate_equilibrium(self):
        """Compare welfare under unilateral deviation and unilateral cooperation.

        Pass 1 runs one round where every agent follows the principal, then one round per agent
        in which only that agent plays its validation action. Pass 2 mirrors this: one round
        where every agent plays its validation action, then one round per agent in which only
        that agent follows the principal.

        Returns:
            ``(one_deviate / all_cooperate, all_deviate / one_cooperate)`` welfare ratios.
            NOTE(review): either denominator can be zero, which raises ``ZeroDivisionError``.
        """
        self.reset_sus_agents()
        welfare_all_follow = welfare_one_deviates = 0.
        welfare_all_deviate = welfare_one_follows = 0.
        obs = self.process_tensor(self.process_state(self.env.reset()))

        # Pass 1: index -1 means nobody deviates; otherwise exactly one agent does.
        for deviator in range(-1, self.n_agents):
            finished = 0
            while finished < self.n_eval_episodes:
                actions_p = self.act(obs)
                actions = actions_p.clone()
                if deviator >= 0:
                    own_actions = self.act_val(obs, actions_p)
                    actions[:, deviator] = own_actions[:, deviator]

                next_states, rewards, done, _ = self.step(actions)

                payments = self.get_payments(obs, rewards, actions_p, actions)
                if deviator < 0:
                    welfare_all_follow += rewards.sum().item() + payments.sum().item()
                else:
                    welfare_one_deviates += (rewards[:, deviator].sum().item()
                                             + payments[:, deviator].sum().item())

                obs = self.process_tensor(self.process_state(next_states))
                if done:
                    finished += self.n_envs

        # Pass 2: index -1 means nobody cooperates; otherwise exactly one agent does.
        for cooperator in range(-1, self.n_agents):
            finished = 0
            while finished < self.n_eval_episodes:
                actions_p = self.act(obs)
                actions = self.act_val(obs, actions_p)
                if cooperator >= 0:
                    actions[:, cooperator] = actions_p[:, cooperator]

                next_states, rewards, done, _ = self.step(actions)

                payments = self.get_payments(obs, rewards, actions_p, actions)
                if cooperator < 0:
                    welfare_all_deviate += rewards.sum().item() + payments.sum().item()
                else:
                    welfare_one_follows += (rewards[:, cooperator].sum().item()
                                            + payments[:, cooperator].sum().item())

                obs = self.process_tensor(self.process_state(next_states))
                if done:
                    finished += self.n_envs

        return welfare_one_deviates / welfare_all_follow, welfare_all_deviate / welfare_one_follows

    @torch.no_grad()
    def log(self, mode='agent'):
        """Evaluate the current policies, print a summary, optionally log to wandb, reset timers.

        Args:
            mode: ``'agent'`` during training or ``'val'`` during validation. In ``'val'`` mode
                with a contract reward type the equilibrium ratios are evaluated as well.

        Returns:
            The average payment divided by the average reward gap, capped at 1. When the gap
            is exactly zero the ratio is undefined; 1 is returned if any payment was made
            and 0 otherwise.
        """
        self.timer_eval.start()
        SW, payment, rew_diff, acc = self._evaluate(mode)
        if rew_diff == 0:
            # Happens e.g. when every reward seen during evaluation is zero; avoid dividing by it.
            payment_rew_diff_prop = 1 if payment > 0 else 0
        else:
            payment_rew_diff_prop = min(payment / rew_diff, 1)

        loss = self.loss_agent
        one_deviate_welfare_prop, all_deviate_welfare_prop = 0, 0
        if mode == 'val':
            loss = self.loss_agent_val
            if self.reward_type in {'common_reward', 'decentralized'}:
                one_deviate_welfare_prop, all_deviate_welfare_prop = self._evaluate_equilibrium()
        self.timer_eval.stop()

        print(f'\nIteration: {self.iteration}')
        print(f'Loss: {round(loss, 4)}, '
              f'Social Welfare: {round(SW, 2)}, '
              f'Avg payment: {round(payment, 3)}, '
              f'Avg reward diff: {round(rew_diff, 3)}, '
              f'Accuracy: {round(acc, 3)}, '
              f'One deviate welfare prop: {round(one_deviate_welfare_prop, 3)}, '
              f'All deviate welfare prop: {round(all_deviate_welfare_prop, 3)}, '
              f'Number of episodes: {self.n_episodes}'
              f'\nTime collection: {round(Timer.timers.total("collection"))} sec, '
              f'Time training: {round(Timer.timers.total("training"))} sec, '
              f'Time evaluation: {round(Timer.timers.total("evaluation"))} sec'
              )

        if self.log_wandb:
            # From here on `mode` is reused as the wandb key that groups the metrics.
            if self.reward_type is None:
                mode = 'baseline'
            else:
                if self.reward_type == 'common_reward':
                    if mode == 'agent':
                        mode = 'train_agent'
                else:
                    mode = self.reward_type
                if self.extra_SW_prop > 0:
                    mode = mode + '_extra'

            wandb.log(
                {mode: {
                    'loss': loss,
                    'SW': SW,
                    'one_deviate_welfare_prop': one_deviate_welfare_prop,
                    'all_deviate_welfare_prop': all_deviate_welfare_prop,
                    'avg_payment': payment,
                    'avg_rew_diff': rew_diff,
                    'payment_rew_diff_prop': payment_rew_diff_prop,
                    'accuracy': acc,
                    'n_episodes': self.n_episodes,
                }},
                step=self.iteration,
                commit=True
            )
        self.reset_timers()
        return payment_rew_diff_prop

    @torch.no_grad()
    def render(self, mode='train'):
        if self.render_env is None:
            return
        self.buffer.reset()  # to free some memory
        fn = mode + '.gif' if self.reward_type is not None else 'baseline.gif'

        state, info = self.render_env.reset()
        done = False
        frames = []
        while not done:
            frames.append(self.render_env.render())
            state = self.process_state(state).unsqueeze(0)

            actions = self.act(state)
            if mode == 'val':
                actions = self.act_val(state, actions)
            actions = actions.cpu()

            state, _, terminated, truncated, info = self.render_env.step(actions.squeeze(0).numpy())
            done = terminated or truncated
        save_frames_as_gif(frames, path=RENDER_SAVE_PATH, filename=fn)
