import warnings
import numpy as np
from skopt import Optimizer
import torch
import torch.optim as optim


from harvest_sed.principal import Principal
from harvest_sed.training.collection import run_multiple_validation_episodes
from harvest_sed.utils import format_principal_returns, format_taxes
from harvest_sed.utils.logger import logger


class GaussianRegression(Principal):
    """Principal that tunes tax rates with Gaussian-process Bayesian optimisation.

    Every ``args.eps_per_tax_rate`` episodes a new 3-dimensional tax vector is
    requested from a scikit-optimize ``Optimizer`` (``base_estimator="GP"``) and
    applied to the environments. In ``after_episode`` (on the same episode-index
    condition) validation episodes are run, and the mean principal reward for that
    tax vector is reported back to the optimizer.

    Only a single game (``args.num_parallel_games == 1``) is supported.
    """

    def __init__(
        self,
        agent_nets,
        args,
        multiplier,
        n_initial_points=10,
        acq_func="EI",
        acq_optimizer="sampling",
        initial_point_generator="grid",
    ):
        """Set up the agent optimiser and the GP-based tax optimiser.

        Args:
            agent_nets: agent networks; their parameters are trained with Adam.
            args: run configuration. Reads ``num_parallel_games``, ``agent_lr``,
                ``adam_eps`` here, plus ``eps_per_tax_rate``, ``reset_agent_nets``,
                ``num_val_episodes``, ``episode_length`` and ``sampling_horizon``
                in the other methods.
            multiplier: upper bound of each tax dimension's search range
                ``[0.0, multiplier]``.
            n_initial_points, acq_func, acq_optimizer, initial_point_generator:
                forwarded unchanged to ``skopt.Optimizer``.

        Raises:
            NotImplementedError: if ``args.num_parallel_games > 1``.
        """
        super().__init__(args, agent_nets)
        if args.num_parallel_games > 1:
            raise NotImplementedError(
                "GaussianRegression only supports num_parallel_games == 1"
            )
        # NOTE(review): self.agent_nets / self.args are presumably set by
        # Principal.__init__ — confirm against the base class.
        self.agent_opt = optim.Adam(self.agent_nets.parameters(), lr=self.args.agent_lr, eps=self.args.adam_eps)
        # Number of tax suggestions drawn so far (logged as principal_step).
        self.principal_step = 0

        self.opt = Optimizer(
            # Three independent real-valued tax dimensions, each in [0, multiplier].
            dimensions=[(0.0, multiplier)] * 3,
            base_estimator="GP",
            n_initial_points=n_initial_points,
            acq_func=acq_func,
            acq_optimizer=acq_optimizer,
            initial_point_generator=initial_point_generator,
        )

    def set_tax_vals(self, ctx, envs):
        """Return the current tax values, drawing a new suggestion when due.

        On episodes where ``ctx.episode_number`` is a multiple of
        ``args.eps_per_tax_rate``, a new point is requested from the optimizer,
        stored as a ``(1, 3)`` float tensor, and applied via
        ``envs.apply_principal_action``. If ``args.reset_agent_nets`` is set, the
        agent networks are also reloaded from ``self.og_nets``.

        On other episodes the previously drawn tax tensor is returned unchanged.
        This means the first call must fall on such an episode; otherwise
        ``self.tax_vals_per_game`` does not exist yet.
        """
        if ctx.episode_number % self.args.eps_per_tax_rate == 0:
            self.suggestion = self.opt.ask()
            # Wrap in an outer list so the tensor has a leading per-game axis.
            self.tax_vals_per_game = torch.Tensor([self.suggestion])
            self.principal_step += 1
            envs.apply_principal_action(self.tax_vals_per_game)
            if self.args.reset_agent_nets:
                # NOTE(review): self.og_nets is not defined in this class;
                # presumably a snapshot of initial weights from Principal — confirm.
                self.agent_nets.load_state_dict({name: self.og_nets[name] for name in self.og_nets})
        return self.tax_vals_per_game

    def after_episode(self, ctx, envs, episode_buffer, tax_vals_per_game):
        """Evaluate the current tax suggestion and report it to the optimizer.

        On episodes where ``ctx.episode_number`` is a multiple of
        ``args.eps_per_tax_rate``, this runs ``args.num_val_episodes`` validation
        episodes with ``tax_vals_per_game``, logs the mean principal reward and
        the taxes, and feeds the result back to the GP optimizer.
        ``episode_buffer`` is unused here.
        """
        # NOTE(review): this uses the same episode-index condition as set_tax_vals.
        # Whether agents have trained under the new taxes before evaluation
        # depends on when ctx.episode_number advances — confirm intended timing.
        if ctx.episode_number % self.args.eps_per_tax_rate == 0:

            all_val_eps_principal_rewards_per_game, principal_observation = run_multiple_validation_episodes(
                num_val_episodes=self.args.num_val_episodes,
                ctx=ctx,
                envs=envs,
                num_parallel_games=self.args.num_parallel_games,
                episode_length=self.args.episode_length,
                sampling_horizon=self.args.sampling_horizon,
                tax_vals_per_game=tax_vals_per_game,
            )
            # Average over validation episodes (dim 0).
            mean_principal_reward_per_game = all_val_eps_principal_rewards_per_game.mean(dim=0)

            logger.log_later(
                {
                    "combined_val_train/episode": ctx.total_episode_number,
                    "principal_final/principal_step": self.principal_step,
                    **format_principal_returns(mean_principal_reward_per_game, prefix="principal_final/"),
                    **format_taxes(tax_vals_per_game, prefix="principal_final/"),
                },
                flush=True,
            )

            # skopt's Optimizer *minimises* the objective value it is told, while
            # the principal wants to *maximise* its reward, so report the negated
            # reward. (.item() relies on there being exactly one game.)
            self.opt.tell(self.suggestion, -mean_principal_reward_per_game.item())
