import os
from abc import ABC, abstractmethod
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchopt
from tensordict import TensorDict

from harvest_sed.principal import Principal
from harvest_sed.principal.llm_utils import (
    api_handler_factory,
    latest_rate_and_return_matching_prompt_vocab,
)
from harvest_sed.principal.prompts import get_prompt_backbone
from harvest_sed.training.collection import (
    run_multiple_validation_episodes,
    run_validation_episode,
)
from harvest_sed.utils import format_principal_returns, format_taxes, mod_step
from harvest_sed.neural.agent_architectures import DesignerNet, PrincipalAgent

from harvest_sed.utils.context import Metrics
from harvest_sed.utils.logger import logger


class Designer(Principal):
    def __init__(self, agent_nets, args, num_agents, device, principal_obs_length, num_brackets):
        super().__init__(args, agent_nets)

        self.device = device
        self.num_agents = num_agents
        self.lr_logging = args.principal_lr
        self.principal_obs_length = principal_obs_length

        """ Tax function net. """
        self.tax_decider = DesignerNet(
            principal_obs_length=principal_obs_length,
            id_hidden_dimension=args.aid_hidden_dimension,
            num_brackets=num_brackets,
            sigmoid_shift=args.aid_sigmoid_shift,
            output_multiplier=(3 if args.env_name == "clean_up" else 1)
        ).to(device)

        self.generation_counter = 0

        """ Optimizer for tax function net. """
        self.optimizer = optim.Adam(
            self.tax_decider.parameters(),
            lr=self.args.principal_lr,
            eps=self.args.adam_eps,
        )

        self.mean_rewards = torch.zeros(self.args.num_parallel_games)
        self.num_mean_updates = 1 # technically we just did one (but mainly just can't divide by zero)

        """ Previous episode rewards we feed as input to tax function - set to None for initialisation in _decide_tax_vals """
        self.principal_observation = None

        """ Agent nets and optimizers, held here as principal is responsible for stepping agent nets. """
        assert (
            self.args.freeze_agent_net_core
        ), "AID on full net is highly memory-intensive. Only actor and critic heads trained."
        # note that use_accelerated_op flag is controlling for explosion in second derivative of Adam sqrt term
        self.agent_actor_opt = torchopt.MetaAdam(
            self.agent_nets.actor,
            lr=self.args.agent_lr,
            eps=self.args.adam_eps,
            use_accelerated_op=True,
        )
        self.agent_critic_opt = optim.Adam(
            self.agent_nets.critic.parameters(),
            lr=self.args.agent_lr,
            eps=self.args.adam_eps,
        )
        self.agent_actor_opt.step = mod_step

    def agent_step(self, metrics):
        """Override standard agent step method.
        Actor head is trained with a differentiable optimizer.
        Critic head is trained with a regular pytorch optimizer.
        "network" submodule, the rest of the agent net, is assumed to be frozen.
        """

        """ Form actor and critic components of PPO loss separately. """
        actor_loss = metrics.pg_loss - self.args.agent_ent_coef * metrics.entropy_loss
        critic_loss = metrics.v_loss * self.args.vf_coef

        """ Step actor head. """
        self.agent_actor_opt.step(self=self.agent_actor_opt, loss=actor_loss, max_grad_norm=self.args.max_grad_norm)

        """ Step critic head. """
        self.agent_critic_opt.zero_grad()
        critic_loss.backward()
        nn.utils.clip_grad_norm_(self.agent_nets.critic.parameters(), self.args.max_grad_norm)
        self.agent_critic_opt.step()

    def _process_observation(self, principal_observation):
        """Private method for processing the total set of observations principal is allowed.

        Args:
            principal_observation (dict[Tensor]): information from validation episode principal is allowed access to,
                                                  or None on first step.

        Returns:
            observation principal will use to generate action.
        """

        if self.args.principal_gets_aux:
            if principal_observation is None:
                """ First generation. """
                obs = torch.zeros(self.args.num_parallel_games, self.principal_obs_length)
            else:
                match self.args.env_name:
                    case "commons_harvest__open":
                        obs = principal_observation["apples_trajectory"].T/64
                    case "clean_up":
                        obs = torch.zeros(self.args.num_parallel_games, self.principal_obs_length)
                        for _ in range(len(principal_observation)):
                            regrowth_traj = principal_observation["regrowth_trajectory"].T
                            num_cleaned_traj = principal_observation["num_cleaned_trajectory"].T
                            downsampled_regrowth = regrowth_traj[:, :: 2]
                            downsampled_num_cleaned = num_cleaned_traj[:, :: 2]
                            obs += torch.cat([downsampled_regrowth, downsampled_num_cleaned], dim=-1)
                        obs /= len(principal_observation)
        else:
            obs = torch.full((self.args.num_parallel_games, self.principal_obs_length), 1).float()

        return obs

    def set_tax_vals(self, ctx, envs):
        """Trainable parameters need to be detached here to prevent flow to previous episodes."""
        torchopt.stop_gradient(self.agent_nets.actor)
        torchopt.stop_gradient(self.agent_actor_opt)

        if self.args.reset_agent_nets:
            self.agent_nets.load_state_dict({name: self.og_nets[name] for name in self.og_nets})

        """ Decide next tax values. """
        tax_vals_per_game = self.tax_decider(self._process_observation(self.principal_observation).to(self.device)).cpu()
        self.generation_counter += 1

        """ Set tax values in the environments. """
        envs.apply_principal_action(tax_vals_per_game)

        """ Return tax values without gradient for bookkeeping - gradient flows through tax rates set in the environments. """
        return tax_vals_per_game.clone().detach()

    def diff_through_ind_set_tax_vals(self, ctx, envs):

        """ Decide next tax values. """
        tax_vals_per_game = self.tax_decider(self._process_observation(self.principal_observation).to(self.device)).cpu()
        self.generation_counter += 1

        """ Set tax values in env with no gradient flowing to rewards. """
        envs.apply_principal_action(tax_vals_per_game.detach().clone())

        """ Keep gradient on tax vals output. """
        return tax_vals_per_game

    def after_episode(self, ctx, envs, episode, tax_vals_per_game):

        """ Collect a validation episode. """
        print("\nCOLLECTING VALIDATION EPISODE\n")
        validation_episode = run_validation_episode(
            ctx,
            envs,
            self.args.num_parallel_games,
            self.args.episode_length,
            self.args.sampling_horizon,
            tax_vals_per_game,
            keep_log_prob_grads=True,
            log_prefix="validation/"
        )

        """ Record information from validation episode that principal is allowed access to. """
        self.principal_observation = validation_episode.get_episode_principal_observation()

        logger.log_later(
            {
                "combined_val_train/episode": ctx.total_episode_number,
                "principal_final/principal_step": self.generation_counter,
                "principal_final/principal_lr": self.lr_logging,
                "principal_final/principal_running_mean": self.mean_rewards,
                **format_principal_returns(validation_episode.principal_cumulative_reward, prefix="principal_final/"),
                **format_taxes(tax_vals_per_game, prefix="principal_final/"),
            },
            flush=True,
        )

        """ Step tax function. """
        print("\nSTEPPING TAX FUNCTION\n")
        self._step_tax_function(validation_episode)

    def _step_tax_function(self, validation_episode):
        """One step on tax function parameters, using data from a validation episode."""

        """ Principal reward-to-go in validation trajectory. """
        ones = torch.ones_like(validation_episode.principal_reward_trajectory)
        # gamma_prod is $$[\gamma, \gamma^2, \dots, \gamma^n]$$
        gamma_prod = torch.cumprod(ones * self.args.gamma, dim=0)
        # principal_returns here is $$\left[\sum_{t=1}^n{r_t\gamma^t}, \sum_{t=2}^n{r_t\gamma^t}, \dots, r_n\gamma^n\right]$$
        principal_returns = torch.flip(
            torch.cumsum(
                torch.flip(validation_episode.principal_reward_trajectory * gamma_prod, dims=[0]),
                dim=0,
            ),
            dims=[0],
        )
        # principal_returns here is $$\left[\left(r_1+r_2\gamma+\dots\right),\left(r_2+r_3\gamma+\dots\right),\dots, r_n\right]$$
        principal_returns = principal_returns / gamma_prod


        principal_rewards_per_game = validation_episode.principal_reward_trajectory.sum(dim=0)
        baselined_total_reward = principal_rewards_per_game - self.mean_rewards


        """ Update running mean value estimator. (i.e. "step critic estimator")"""
        self.mean_rewards += (principal_rewards_per_game - self.mean_rewards) / self.num_mean_updates
        self.num_mean_updates += 1


        """ Principal policy gradient loss.

        Agent policy parameters have been updated in training episode using rewards dependent on tax function parameters.
        Since these new policies were update differentiably, their parameters have gradient dependency on tax function parameters.
        In validation episode, we collected log-probabilities of actions produced by these updated nets.
        These log-probabilities retain gradient dependency on tax function parameters.
        We use them in a policy gradient loss with the principal's validation episode reward trajectory to step tax function parameters.
        """
        # Action log-probabilities in validation episode for updated agent policies, summed over agents in each parallel game.
        sum_agents_log_probs = validation_episode.agent_episode_log_probs.sum(dim=-1)
        # Policy gradient loss to differentiate back tax function's effect on principal returns in validation episode.
        principal_pg_loss = -(sum_agents_log_probs * baselined_total_reward).sum()

        """ Step tax function """
        principal_pg_loss.backward()
        nn.utils.clip_grad_norm_(self.tax_decider.parameters(), self.args.max_grad_norm)
        self.optimizer.step()
        self.optimizer.zero_grad()

    def save_params(self, episode_number):
        """Override standard save method to also save tax function parameters."""
        self.make_save_dir(episode_number)
        torch.save(
            self.agent_nets.state_dict(),
            f"./saved_params/ep{episode_number}/seed_{self.args.seed}_gets_aux_{self.args.principal_gets_aux}_agents_net_ep{episode_number}.pt",
        )
        torch.save(
            self.tax_decider.state_dict(),
            f"./saved_params/ep{episode_number}/seed_{self.args.seed}_gets_aux_{self.args.principal_gets_aux}_tax_net_ep{episode_number}.pt",
        )


class DualRLPrincipal(Principal):
    def __init__(self, agent_nets, args, num_agents, device, principal_obs_length, num_brackets):
        super().__init__(args, agent_nets)

        self.device = device
        self.num_agents = num_agents
        self.principal_obs_length = principal_obs_length

        """ Principal policy net. """
        self.principal_agent = PrincipalAgent(principal_obs_length, num_brackets, args.dual_rl_hidden_dim, args.dual_rl_num_hidden_layers).to(device)        

        """ Optimizer for principal policy net. """
        self.optimizer = optim.Adam(
            self.principal_agent.parameters(),
            lr=self.args.principal_lr,
            eps=self.args.adam_eps,
        )

        self.principal_step = 0

        """ Since we have trajectories of length one only here, the value function for a policy is just
        the expected reward over actions sampled from this policy. We estimate this value function by
        keeping a running mean of rewards received so far. This won't be perfect, but should be a good
        enough baseline to reduce variance - and allows us to leave out critic nets. """
        self.mean_rewards = torch.zeros(self.args.num_parallel_games)

        """ Since this principal can produce NO-OP "do not change tax rate" actions, we must initialise some first tax rate for NO-OP masking.
        This is not first tax rate set in the environment. That is stored in the field last_tax_rates_set and is initialised by call to policy net. """
        self.previous_proposed_tax_rates = torch.zeros(self.args.num_parallel_games, num_brackets)

        """ Agent optimizer, held here as principal is responsible for stepping agent nets. """
        self.agent_opt = optim.Adam(self.agent_nets.parameters(), lr=self.args.agent_lr, eps=self.args.adam_eps)

    def _process_observation(self, principal_observation):
        """Private method for processing the total set of observations principal is allowed.

        Args:
            principal_observation (dict[Tensor]): information from validation episode principal is allowed access to,
                                                  or None on first step.

        Returns:
            observation principal will use to generate action.
        """
        if self.args.principal_gets_aux:
            if principal_observation is None:
                """ First generation. """
                obs = torch.zeros(self.args.num_parallel_games, self.principal_obs_length)
            else:
                match self.args.env_name:
                    case "commons_harvest__open":
                        obs = principal_observation[-1]["apples_trajectory"].T/64
                    case "clean_up":
                        obs = torch.zeros(self.args.num_parallel_games, self.principal_obs_length)
                        for _ in range(len(principal_observation)):
                            regrowth_traj = principal_observation[-1]["regrowth_trajectory"].T
                            num_cleaned_traj = principal_observation[-1]["num_cleaned_trajectory"].T
                            downsampled_regrowth = regrowth_traj[:, :: 2]
                            downsampled_num_cleaned = num_cleaned_traj[:, :: 2]
                            obs += torch.cat([downsampled_regrowth, downsampled_num_cleaned], dim=-1)
                        obs /= len(principal_observation)
        else:
            obs = torch.full((self.args.num_parallel_games, self.principal_obs_length), 1).float()

        return obs

    def _decide_tax_vals(self, principal_observation):
        """Private method for generating new tax rates.

        Args:
            principal_observation (Tensor[num_parallel_games, principal_obs_length]): principal observation

        Returns:
            proposed_tax_vals_per_game (Tensor[num_parallel_games, num_brackets] - no gradient): new tax values, proposed as actions by policy network
        """

        """ Extract information to act on from information principal is allowed to observe.
        Save as a field as a one-step buffer for optimisation later. """
        self.observation_to_act_on = self._process_observation(principal_observation).to(self.device)

        """ Produce output from policy net. This is a one-step trajectory we will update from in optimization later. """
        with torch.no_grad():
            # this will already be on device and we leave it there for optimisation steps after episode.
            tax_decision_data: TensorDict = self.principal_agent(self.observation_to_act_on)

        """ Retrieve action sampled from policy net. """
        action = tax_decision_data["actions"].cpu()

        """ Mask out action 21, which is a NO-OP corresponding to leaving a tax rate unchanged. """
        no_op_mask = action == 21

        """ Discretising [0,1] into intervals of 0.05, with 1.05s to ignore where we had NO-OP 21s -- equivalent for cleanup [0,3] """
        if self.args.env_name == "clean_up":
            unmasked_tax_vals_per_game = action * 0.15
        else:
            unmasked_tax_vals_per_game = action / 20


        """ Apply no-op mask to repeat previous tax values where needed. """
        proposed_tax_vals_per_game = torch.where(
            no_op_mask, self.previous_proposed_tax_rates, unmasked_tax_vals_per_game
        )

        """ Return proposed new tax values, before any maximum tax cap clipping, and one-step trajectory to step from when optimizing. """
        return proposed_tax_vals_per_game, tax_decision_data

    def set_tax_vals(self, ctx, envs):
        """Will be triggered at episode 0 and at start of every tax period."""
        if ctx.episode_number % self.args.eps_per_tax_rate == 0:
            """If we have also just finished a tax period."""
            if ctx.episode_number > 0:

                """ Run multiple validation episodes. 
                Aggregates mean principal reward per game within these episodes,
                as well as all information principal is allowed to use in deciding next action. """
                all_val_eps_principal_rewards_per_game, principal_observation = run_multiple_validation_episodes(
                    num_val_episodes=self.args.num_val_episodes,
                    ctx=ctx,
                    envs=envs,
                    num_parallel_games=self.args.num_parallel_games,
                    episode_length=self.args.episode_length,
                    sampling_horizon=self.args.sampling_horizon,
                    tax_vals_per_game=self.last_tax_rates_set,
                    log_prefix=f"validation/",
                )
                mean_principal_reward_per_game = all_val_eps_principal_rewards_per_game.mean(dim=0)

                logger.log_later(
                    {
                        f"principal_final/principal_step": self.principal_step,
                        **format_principal_returns(mean_principal_reward_per_game, prefix="principal_final/"),
                        **format_taxes(
                            self.last_tax_rates_set, prefix="principal_final/"
                        ),
                    },
                    flush=True,
                )
                """ Step principal nets. """
                self._step_nets(ctx, mean_principal_reward_per_game)

                """ Reset agent nets to their original state. """
                if self.args.reset_agent_nets:
                    self.agent_nets.load_state_dict({name: self.og_nets[name] for name in self.og_nets})
            else:
                principal_observation = None

            """ Generate new proposed tax rates and data collected along this one-step trajectory.
            Uses principal observation generated from validation episodes just run, or None if first step. """
            proposed_tax_vals_per_game, tax_decision_data = self._decide_tax_vals(principal_observation)

            self.principal_step += 1

            """ Store buffer of one-step trajectory for use in optimization as a field. """
            self.tax_decision_data = (
                tax_decision_data  # already on device as was created there and hasn't been moved off
            )

            """ Store proposed tax rates for use with NO-OPs next time. """
            self.previous_proposed_tax_rates = proposed_tax_vals_per_game

            """ Cap proposed tax rates according to annealed maximum tax rate cap. """
            if self.args.initial_max_tax_rate < 1:
                max_tax_rate = min(
                    1.0,
                    self.args.initial_max_tax_rate
                    + (1.0 - self.args.initial_max_tax_rate)
                    * (ctx.episode_number / (self.args.num_tax_annealment_episodes - 1)),
                )
                capped_tax_vals_per_game = torch.min(
                    proposed_tax_vals_per_game,
                    torch.full_like(proposed_tax_vals_per_game, max_tax_rate),
                )
            else:
                capped_tax_vals_per_game = proposed_tax_vals_per_game

            """ Set capped tax rates in the environments. """
            envs.apply_principal_action(capped_tax_vals_per_game)

            self.last_tax_rates_set = capped_tax_vals_per_game

        """ Return the capped tax rates that we set in the environments. """
        return self.last_tax_rates_set

    def _step_nets(self, ctx, principal_rewards_per_game):
        """Private method for updating principal policy and critic (which here is just a running mean)."""
        
        if self.args.dual_rl_use_running_mean:
            """ In this one-step trajectory case, reward received is exactly the Q function.
            Mean reward approximates value function, yielding an estimate of the advantages. """
            advantage_estimates = principal_rewards_per_game - self.mean_rewards
        else:
            advantage_estimates = principal_rewards_per_game.to(self.device) - self.tax_decision_data["values"]
            returns_estimates = principal_rewards_per_game.to(self.device)

        """ Update running mean value estimator. (i.e. "step critic estimator")"""
        self.mean_rewards += (principal_rewards_per_game - self.mean_rewards) / self.principal_step

        """ Loop over desired number of update epochs on this one-step trajectory. """
        for epoch in range(self.args.dual_rl_principal_num_policy_updates_per_collection_update):
            """New policy net outputs, given original sigma observation and action from this one-step trajectory."""
            new_net_outputs: TensorDict = self.principal_agent(
                self.observation_to_act_on,
                self.tax_decision_data["actions"],  # already on device from when it was generated
            )

            """ Compute policy loss and training metrics. """
            metrics: Metrics = ctx.alg.get_policy_loss(
                new_agent_net_outputs=new_net_outputs,
                trajectory_agent_net_outputs=self.tax_decision_data,
                advantages=advantage_estimates.to(self.device),
                norm_adv=False,
                clip_coef=self.args.principal_clip_coef,
            )
            if self.args.dual_rl_use_running_mean:
                loss = metrics.pg_loss - self.args.principal_ent_coef * metrics.entropy_loss
            else:
                metrics.v_loss = ctx.alg.get_value_loss(
                    new_agent_net_outputs=new_net_outputs,
                    trajectory_agent_net_outputs=self.tax_decision_data,
                    returns=returns_estimates,
                    clip_vloss=self.args.clip_vloss,
                    clip_coef=self.args.value_clip_coef,
                )
                """ Form policy loss - no critic so no value loss. """
                loss = metrics.pg_loss - self.args.principal_ent_coef * metrics.entropy_loss + metrics.v_loss * self.args.principal_vf_coef

            """ Step policy net. """
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.principal_agent.parameters(), self.args.max_grad_norm)
            self.optimizer.step()

            # logger.log_distribution(new_net_outputs["distribution"][0],0,self.principal_step, epoch)

            logger.log_later(
                {
                    "principal_opt/step": ((ctx.episode_number // self.args.eps_per_tax_rate) -1) * self.args.num_policy_updates_per_collection_update + epoch,
                    f"principal_opt/policy_update_epoch": epoch,
                    f"principal_opt/pg_loss": metrics.pg_loss,
                    f"principal_opt/running_mean": self.mean_rewards.item(),
                    f"principal_opt/value_estimate": self.tax_decision_data["values"].item(),
                    f"principal_opt/entropy_loss": metrics.entropy_loss,
                    f"principal_opt/approx_kl": metrics.approx_kl,
                    f"principal_opt/clipfrac": np.mean(metrics.clipfracs),
                },
                flush=True,
            )

    def after_episode(self, *unused, **kwargs):
        pass

    def save_params(self, episode_number):
        """Override standard save method to also save principal policy net parameters."""
        self.make_save_dir(episode_number)
        torch.save(
            self.agent_nets.state_dict(),
            f"./saved_params/ep{episode_number}/agents_net_ep{episode_number}.pt",
        )
        torch.save(
            self.principal_agent.state_dict(),
            f"./saved_params/ep{episode_number}/principal_agent_net_ep{episode_number}.pt",
        )


class LLMPrincipal(Principal):
    def __init__(self, agent_nets, args, num_brackets, envs) -> None:
        super().__init__(args, agent_nets)
  
        assert args.num_parallel_games == 1, "LLM method configured for one parallel game only."
        assert args.saved_core_path != "", "LLM method requires a saved core path."

        self.agent_opt = optim.Adam(self.agent_nets.parameters(), lr=self.args.agent_lr, eps=self.args.adam_eps)

        self.rates_and_returns_in_natural_language = ""
        self.generation_counter = 1
        self.max_tax = 1.0 if args.env_name == "commons_harvest__open" else 3.0
        self.total_generations = args.total_principal_steps
        self.prompt_backbone = get_prompt_backbone(args.llm_prompt_style, self.total_generations)
        self.api_handler = api_handler_factory(args.llm_model, args.temperature)
        
        self.num_brackets = num_brackets

        """ Initialise first tax rate as all zeros. """
        self.current_tax_tensor = self._decide_tax_vals()

        """ Set first tax rate in the environment. Game ID is zero as we require one parallel game only. """
        envs.apply_principal_action(self.current_tax_tensor)
    
    def set_tax_vals(self, ctx, envs):
        if ctx.episode_number > 0 and ctx.episode_number % self.args.eps_per_tax_rate == 0:
            """This runs validation episodes, records results to history, and logs LLM return."""

            self.evaluate_and_record(ctx, envs, val_log_prefix="validation/")

            self.generation_counter += 1

            """ Decide a new tax rate and store it in current tax rate field. """
            # self.current_tax_tensor is of shape (num_parallel_games, num_brackets), which here is (1, num_brackets)
            self.current_tax_tensor = self._decide_tax_vals()

            """ Set new tax rate in environment. """
            envs.apply_principal_action(self.current_tax_tensor)

            """ Reset agent nets, readying them for more convergence episodes under new tax rate. """
            if self.args.reset_agent_nets:
                self.agent_nets.load_state_dict({name: self.og_nets[name] for name in self.og_nets})

        return self.current_tax_tensor

    def _decide_tax_vals(self):
        """See prompts.py for a written out example of the prompt structure"""
        prompt = (
            self.prompt_backbone["general_explanation"]
            + self.prompt_backbone["provide_history"]
            + self.rates_and_returns_in_natural_language
            + self.prompt_backbone["reminder"]
        )
        if self.generation_counter == 1:
            prompt += "This is your first attempt"

        if self.generation_counter == self.total_generations:
            prompt += self.prompt_backbone["last_try"]

        new_tax_rate = self.generate_and_log_response(
            prompt=prompt,
        )
        return torch.Tensor([new_tax_rate])

    def generate_and_log_response(self, prompt):
        try:
            """Here, we only want to log the prompt/response if the reponse was able to be validated."""
            unvalidated_repsonse = self.api_handler.query_llm(
                prompt
            )
            tax_rate = self.api_handler.parse_and_validate_response(unvalidated_repsonse, self.num_brackets)
            logger.log_prompt_and_response(
                prompt=prompt,
                response=unvalidated_repsonse,
                generation_number=self.generation_counter,
            )
        except Exception as e:
            """Here, we log the prompt/response before validation, so that if an error is raised we can see why"""
            prompt += "It is vitally important that your tax rate is surrounded by one set of dollar signs."
            unvalidated_repsonse = self.api_handler.query_llm(
                prompt
            )
            logger.log_prompt_and_response(
                prompt=prompt,
                response=unvalidated_repsonse,
                generation_number=self.generation_counter,
            )
            tax_rate = self.api_handler.parse_and_validate_response(unvalidated_repsonse, self.num_brackets)

        for tax in range(len(tax_rate)):
            if tax_rate[tax] > self.max_tax:
                tax_rate[tax] = self.max_tax
            if tax_rate[tax] < 0:
                tax_rate[tax] = 0
                
        return tax_rate

    def after_episode(self, ctx, envs, episode_buffer, tax_vals_per_game):
        self.current_tax_tensor = tax_vals_per_game

    def _process_observation(self, principal_observation):
        obs = {}
        match self.args.env_name:
            case "commons_harvest__open":
                obs['mean_cumulative_raw_rewards'] = torch.stack(
                    list(principal_observation[i]["cumulative_agent_raw_rewards"] for i in range(len(principal_observation)))
                ).mean(dim=0)
                obs['mean_apple_trajectory'] = torch.stack(
                    list(principal_observation[i]["apples_trajectory"] for i in range(len(principal_observation)))
                ).mean(dim=0)
            case "clean_up":
                obs['mean_cumulative_raw_rewards'] = torch.stack(
                    list(principal_observation[i]["cumulative_agent_raw_rewards"] for i in range(len(principal_observation)))
                ).mean(dim=0)
                obs['mean_cumulative_num_cleaned'] = torch.stack(
                    list(principal_observation[i]["num_cleaned"] for i in range(len(principal_observation)))
                ).mean(dim=0)
                obs['mean_regrowth_trajectory'] = torch.stack(
                    list(principal_observation[i]["regrowth_trajectory"] for i in range(len(principal_observation)))
                ).mean(dim=0)
        return obs
    
    def evaluate_and_record(self, ctx, envs, val_log_prefix):
        """
        After self.eps_per_tax_rate episodes of agent training on the latest tax rate, we copy their nets and run self.num_val_episodes validation episodes
        on the agents, resetting their nets to the copied ones each time. We then take the mean of the cumulative rewards from these episodes to reduce the
        variance in the LLM's reward signal. This mean is then used as the reward for the LLM's tax rate generation, which is added to the history.
        """

        all_val_eps_principal_rewards_per_game, principal_observation = run_multiple_validation_episodes(
                num_val_episodes=self.args.num_val_episodes,
                ctx=ctx,
                envs=envs,
                num_parallel_games=self.args.num_parallel_games,
                episode_length=self.args.episode_length,
                sampling_horizon=self.args.sampling_horizon,
                tax_vals_per_game=self.current_tax_tensor,
                log_prefix=val_log_prefix,
            )

        # assuming only one parallel game so can take mean without specifying dimension
        mean_reward = all_val_eps_principal_rewards_per_game.mean().item()
        std_reward = all_val_eps_principal_rewards_per_game.std().item()

        tax = self.current_tax_tensor[0].numpy()

        val_obs = self._process_observation(principal_observation)

        latest_rate_and_return = latest_rate_and_return_matching_prompt_vocab(
            self.args.env_name,
            self.args.llm_gets_aux,
            self.generation_counter,
            tax,
            val_obs,
            self.args.episode_length,
            mean_reward,
        )
        self.rates_and_returns_in_natural_language += latest_rate_and_return

        logger.log_later(
            {
                "combined_val_train/episode": ctx.total_episode_number,
                "principal_final/principal_step": self.generation_counter,
                f"principal_final/returns": mean_reward,
                f"principal_final/std": std_reward,
                **format_taxes(self.current_tax_tensor, prefix="principal_final/"),
            },
            flush=True,
        )
