import random
import time
import numpy as np
import torch
import torch.optim as optim


from harvest_sed.principal import Principal
from harvest_sed.training.collection import run_multiple_validation_episodes
from harvest_sed.utils import format_principal_returns, format_taxes
from harvest_sed.utils.logger import logger


class BanditWrapper(Principal):
    """Principal whose tax rates are picked by a multi-armed bandit.

    Each bandit arm is an integer that decodes (base
    ``num_discretized_rates``) into three tax rates. A new arm is drawn
    whenever ``ctx.episode_number`` is a multiple of
    ``args.eps_per_tax_rate``; on those same episodes the bandit is updated
    with the mean principal reward measured over validation episodes.
    """

    def __init__(self, agent_nets, args, bandit, num_discretized_rates):
        """Set up the wrapper.

        Args:
            agent_nets: forwarded to ``Principal.__init__``; its parameters
                are handed to the Adam optimiser created here.
            args: run configuration. This class reads ``env_name``,
                ``agent_lr``, ``adam_eps``, ``eps_per_tax_rate``,
                ``reset_agent_nets``, ``num_val_episodes``,
                ``num_parallel_games``, ``episode_length`` and
                ``sampling_horizon`` from it.
            bandit: object exposing ``get_action()`` and
                ``update_params(arm, reward)``.
            num_discretized_rates: number of grid points per tax rate.
        """
        super().__init__(args, agent_nets)

        self.num_discretized_rates = num_discretized_rates
        # The largest tax rate on the grid depends on the environment.
        if args.env_name == "commons_harvest__open":
            self.upper_bound = 1
        else:
            self.upper_bound = 3
        self.agent_opt = optim.Adam(
            self.agent_nets.parameters(),
            lr=self.args.agent_lr,
            eps=self.args.adam_eps,
        )
        self.principal_step = 0
        self.bandit = bandit

    def set_tax_vals(self, ctx, envs):
        """Return the current tax values, drawing a new arm when one is due.

        On episodes where ``ctx.episode_number`` is a multiple of
        ``args.eps_per_tax_rate`` a new arm is pulled, decoded, and applied
        to ``envs``; otherwise the previously chosen values are returned.
        """
        new_rate_due = ctx.episode_number % self.args.eps_per_tax_rate == 0
        if not new_rate_due:
            return self.tax_vals_per_game

        self.bandit_action = self.bandit.get_action()
        self.tax_vals_per_game = self.convert_action(self.bandit_action)
        self.principal_step += 1
        envs.apply_principal_action(self.tax_vals_per_game)

        if self.args.reset_agent_nets:
            # NOTE(review): ``og_nets`` is not defined in this class;
            # presumably a saved state dict provided by Principal - confirm.
            saved_state = {}
            for name in self.og_nets:
                saved_state[name] = self.og_nets[name]
            self.agent_nets.load_state_dict(saved_state)

        return self.tax_vals_per_game

    def after_episode(self, ctx, envs, episode_buffer, tax_vals_per_game):
        """Run validation, log the result and update the bandit.

        Does nothing unless ``ctx.episode_number`` is a multiple of
        ``args.eps_per_tax_rate``. ``episode_buffer`` is not used here.
        """
        if ctx.episode_number % self.args.eps_per_tax_rate != 0:
            return

        # The second value returned by the validation helper is not needed.
        val_rewards_per_game, _ = run_multiple_validation_episodes(
            num_val_episodes=self.args.num_val_episodes,
            ctx=ctx,
            envs=envs,
            num_parallel_games=self.args.num_parallel_games,
            episode_length=self.args.episode_length,
            sampling_horizon=self.args.sampling_horizon,
            tax_vals_per_game=tax_vals_per_game,
        )
        # Average over the first dimension of the validation rewards.
        mean_reward_per_game = val_rewards_per_game.mean(dim=0)

        payload = {
            "combined_val_train/episode": ctx.total_episode_number,
            "principal_final/principal_step": self.principal_step,
        }
        payload.update(format_principal_returns(mean_reward_per_game, prefix="principal_final/"))
        payload.update(format_taxes(tax_vals_per_game, prefix="principal_final/"))
        logger.log_later(payload, flush=True)

        self.bandit.update_params(self.bandit_action, mean_reward_per_game)

    def convert_action(self, action):
        """Decode an arm index into a ``1 x 3`` tensor of tax rates.

        ``action`` is read as a three-digit number in base
        ``num_discretized_rates``; the most significant digit selects the
        first rate. Each digit is scaled onto an evenly spaced grid running
        from 0 to ``upper_bound``.
        """
        base = self.num_discretized_rates

        low_digit = action % base
        mid_digit = (action // base) % base
        high_digit = (action // base // base) % base

        grid_step = self.upper_bound / (base - 1)

        return torch.Tensor([[high_digit * grid_step, mid_digit * grid_step, low_digit * grid_step]])


class EpsilonGreedy:
    """Epsilon-greedy bandit with one value estimate per arm.

    ``Q`` holds the running value estimate of every arm and ``N`` the
    number of updates each arm has received. Randomness comes from NumPy's
    global generator; ``seed`` is only used to label snapshot files.
    """

    def __init__(self, arm_count, epsilon, seed, env_name, stepsize="avg"):
        """Create the bandit.

        Args:
            arm_count: number of arms.
            epsilon: threshold compared against a uniform draw; a draw not
                above it triggers a uniformly random arm.
            seed: used in snapshot filenames.
            env_name: used in snapshot filenames.
            stepsize: ``"avg"`` for a ``1 / N[arm]`` (sample-average) step,
                otherwise the value is used as a constant step size.
        """
        self.arm_count = arm_count
        self.epsilon = epsilon
        self.seed = seed
        self.env = env_name
        self.principal_step = 0
        self.Q = np.zeros(arm_count)
        self.N = np.zeros(arm_count)

        def sample_average_step(arm):
            return 1 / self.N[arm]

        def constant_step(arm):
            return stepsize

        self.get_stepsize = sample_average_step if stepsize == "avg" else constant_step

    def get_action(self):
        """Pick an arm and, every 200th call, dump ``Q`` and ``N`` to disk."""
        explore = not (np.random.uniform(0, 1) > self.epsilon)
        if explore:
            chosen = np.random.randint(0, self.arm_count)
        else:
            chosen = self.Q.argmax()

        self.principal_step += 1
        snapshot_due = self.principal_step % 200 == 0
        if snapshot_due:
            # Snapshots are written to the current working directory.
            np.save(f"step_{self.principal_step}_{self.env}_eps_greedy_Q_seed_{self.seed}_{time.time()}.npy", self.Q)
            np.save(f"step_{self.principal_step}_{self.env}_eps_greedy_N_seed_{self.seed}_{time.time()}.npy", self.N)
        return chosen

    def update_params(self, arm, reward):
        """Move the estimate of ``arm`` towards ``reward``."""
        # The count is bumped first so the "avg" step size never divides by 0.
        self.N[arm] += 1
        step = self.get_stepsize(arm)
        self.Q[arm] += step * (reward - self.Q[arm])


class UCB:
    """Upper-confidence-bound bandit with one value estimate per arm.

    Every arm is pulled once, in a seeded random order, before the UCB rule
    ``argmax(Q + coef * sqrt(ln(t) / N))`` takes over. ``Q`` holds the value
    estimates and ``N`` the number of updates per arm.
    """

    def __init__(self, arm_count, coef, seed, env_name, stepsize="avg"):
        """Create the bandit.

        Args:
            arm_count: number of arms.
            coef: multiplier on the confidence term.
            seed: seeds the initial pull order; also used in snapshot
                filenames.
            env_name: used in snapshot filenames.
            stepsize: ``"avg"`` for a ``1 / N[arm]`` (sample-average) step,
                otherwise the value is used as a constant step size.
        """
        self.coef = coef
        self.seed = seed
        self.env = env_name
        self.arm_count = arm_count
        self.Q = np.zeros(arm_count)
        self.N = np.zeros(arm_count)
        if stepsize == "avg":
            self.get_stepsize = lambda arm: 1 / self.N[arm]
        else:
            self.get_stepsize = lambda arm: stepsize
        self.timestep = 0

        # Use a private generator: it yields the same order that seeding the
        # module-level generator would, but does not reset the global
        # ``random`` state for the rest of the program.
        self.initial_pull_order = random.Random(seed).sample(range(arm_count), arm_count)
        self.initial_pulls_completed = 0

    def get_action(self):
        """Pick an arm and, every 200th call, dump ``Q`` and ``N`` to disk."""
        self.timestep += 1
        if self.timestep % 200 == 0:
            # Snapshots are written to the current working directory.
            np.save(f"step_{self.timestep}_{self.env}_ucb_Q_step_{self.timestep}_seed_{self.seed}_{time.time()}.npy", self.Q)
            np.save(f"step_{self.timestep}_{self.env}_ucb_N_step_{self.timestep}_seed_{self.seed}_{time.time()}.npy", self.N)

        # All arms need to be played once to begin with. This could be made
        # more concise by adding a small epsilon to self.N at initialisation,
        # but this is clearer.
        if self.initial_pulls_completed < self.arm_count:
            action = self.initial_pull_order[self.initial_pulls_completed]
            self.initial_pulls_completed += 1
            return action

        # The confidence term relies on update_params having been called for
        # every arm, so that no entry of N is zero.
        ln_timestep = np.log(np.full(self.arm_count, self.timestep))
        confidence = self.coef * np.sqrt(ln_timestep / self.N)
        action = np.argmax(self.Q + confidence)
        return action

    def update_params(self, arm, reward):
        """Move the estimate of ``arm`` towards ``reward``."""
        # The count is bumped first so the "avg" step size never divides by 0.
        self.N[arm] += 1
        self.Q[arm] += self.get_stepsize(arm) * (reward - self.Q[arm])


class ThompsonSampling:
    """Thompson-sampling bandit with Gaussian arms.

    Arm distributions assumed Gaussian with unknown mean and variance.
    Conjugate prior is normal-inverse-gamma NIG(mean, count, shape, scale)
    See: https://en.wikipedia.org/wiki/Normal-inverse-gamma_distribution
    Posterior derivation from excellent article https://gertjanvandenburg.com/blog/thompson_sampling/

    Per-arm state: ``N`` (update count), ``mean`` (running reward mean),
    ``ssd`` (sum of squared deviations from that mean), ``rho`` (posterior
    mean) and ``beta`` (posterior scale). Randomness comes from NumPy's
    global generator; ``seed`` is only used to label snapshot files.
    """

    def __init__(self, arm_count, seed, env_name, prior_mean=0, prior_count=0.05, prior_shape=1, prior_scale=25):
        """Create the bandit with the same NIG prior on every arm.

        Args:
            arm_count: number of arms.
            seed: used in snapshot filenames.
            env_name: used in snapshot filenames.
            prior_mean: prior mean of each arm.
            prior_count: pseudo-count attached to the prior mean.
            prior_shape: shape of the inverse-gamma prior on the variance.
            prior_scale: scale of the inverse-gamma prior on the variance.
        """
        self.seed = seed
        self.env = env_name
        self.principal_step = 0
        self.prior_mean = prior_mean
        self.prior_count = prior_count
        self.prior_shape = prior_shape
        self.prior_scale = prior_scale

        self.N = np.zeros(arm_count)
        self.mean = np.zeros(arm_count)
        self.ssd = np.zeros(arm_count)
        # Force a float dtype: np.full otherwise takes the dtype of the fill
        # value, and the integer defaults (0 and 25) would produce integer
        # arrays that silently truncate every posterior update.
        self.rho = np.full(arm_count, self.prior_mean, dtype=float)
        self.beta = np.full(arm_count, self.prior_scale, dtype=float)

    def get_action(self):
        """Sample each arm's posterior and return the arm with the largest draw.

        Every 200th call the per-arm state arrays are dumped to disk first.
        """
        self.principal_step += 1
        if self.principal_step % 200 == 0:
            # Snapshots are written to the current working directory.
            np.save(f"step_{self.principal_step}_{self.env}_thompson_N_seed_{self.seed}_{time.time()}.npy", self.N)
            np.save(f"step_{self.principal_step}_{self.env}_thompson_mean_seed_{self.seed}_{time.time()}.npy", self.mean)
            np.save(f"step_{self.principal_step}_{self.env}_thompson_rho_seed_{self.seed}_{time.time()}.npy", self.rho)
            np.save(f"step_{self.principal_step}_{self.env}_thompson_ssd_seed_{self.seed}_{time.time()}.npy", self.ssd)
            np.save(f"step_{self.principal_step}_{self.env}_thompson_beta_seed_{self.seed}_{time.time()}.npy", self.beta)
        # Variance: reciprocal of a gamma draw, i.e. an inverse-gamma sample.
        sigma2 = 1.0 / np.random.gamma(0.5 * self.N + self.prior_shape, 1.0 / self.beta)
        # Mean: normal around the posterior mean, scaled by the sampled variance.
        mus = np.random.normal(self.rho, np.sqrt(sigma2 / (self.N + self.prior_count)))
        return mus.argmax()

    def update_params(self, arm, reward):
        """Fold ``reward`` into the posterior of ``arm``."""
        old_N, old_mean = self.N[arm], self.mean[arm]
        self.N[arm] += 1
        self.mean[arm] += 1 / self.N[arm] * (reward - self.mean[arm])
        self.rho[arm] = (self.prior_count * self.prior_mean + self.N[arm] * self.mean[arm]) / (self.prior_count + self.N[arm])
        # Incremental update of sum((x - mean)^2) = sum(x^2) - N * mean^2.
        self.ssd[arm] += reward**2 + old_N * old_mean**2 - self.N[arm] * self.mean[arm] ** 2
        self.beta[arm] = (
            self.prior_scale
            + 0.5 * self.ssd[arm]
            + (self.N[arm] * self.prior_count * (self.mean[arm] - self.prior_mean) ** 2 / (2 * (self.N[arm] + self.prior_count)))
        )
