import os
from abc import ABC, abstractmethod
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchopt
from tensordict import TensorDict

from harvest_sed.principal import Principal
from harvest_sed.training.collection import (
    run_multiple_validation_episodes,
)
from harvest_sed.utils import format_principal_returns, format_taxes, mod_step
from harvest_sed.utils.context import Metrics
from harvest_sed.utils.logger import logger

class FixedTaxRate(Principal):
    """Principal that applies one constant tax schedule for the whole run.

    The tax values are pushed to the environments once, at construction time,
    and are handed back unchanged on every ``set_tax_vals`` call.
    """

    def __init__(self, agent_nets, args, envs, tax_rate):
        """Tile ``tax_rate`` over the parallel games and apply it to ``envs``.

        Args:
            agent_nets: Forwarded to ``Principal.__init__``; presumably what
                backs ``self.agent_nets`` (base class not visible here).
            args: Run configuration, forwarded to ``Principal.__init__``;
                read back through ``self.args``.
            envs: Object receiving the tax values via
                ``apply_principal_action``.
            tax_rate: Tax values handed to ``torch.Tensor``; assumed to be a
                flat sequence of numbers for one game -- TODO confirm.
        """
        super().__init__(args, agent_nets)

        # Repeat the schedule ``num_parallel_games`` times along a leading
        # axis (for a 1-D ``tax_rate`` this yields one row per game).
        single_game_taxes = torch.Tensor(tax_rate)
        self.tax_vals_per_game = single_game_taxes.repeat((self.args.num_parallel_games, 1))
        envs.apply_principal_action(self.tax_vals_per_game)

        self.agent_opt = optim.Adam(
            self.agent_nets.parameters(),
            lr=self.args.agent_lr,
            eps=self.args.adam_eps,
        )

        tax_log_entries = format_taxes(self.tax_vals_per_game, prefix="tax_per_game/")
        logger.log_later({**tax_log_entries})

    def set_tax_vals(self, ctx, envs):
        """Return the fixed tax values; ``ctx`` and ``envs`` are not used."""
        return self.tax_vals_per_game

    def after_episode(self, *unused, **kwargs):
        """Do nothing: a fixed tax schedule needs no post-episode work."""
        # NOTE: the positional catch-all is deliberately not named "*args",
        # to avoid a clash with the config object, which is called args.
        return None


class ValidateTaxRate(Principal):
    """Principal that holds a fixed tax schedule and validates it once.

    When ``after_episode`` is called with ``ctx.episode_number`` equal to
    ``args.eps_per_tax_rate``, the schedule is evaluated over
    ``args.num_val_episodes`` validation episodes. The mean and standard
    deviation of the principal's reward are logged and appended to a text file.
    """

    def __init__(self, agent_nets, args, envs, tax_rate):
        """Tile ``tax_rate`` over the parallel games and apply it to ``envs``.

        Args:
            agent_nets: Forwarded to ``Principal.__init__``; presumably what
                backs ``self.agent_nets`` (base class not visible here).
            args: Run configuration, forwarded to ``Principal.__init__``;
                read back through ``self.args``.
            envs: Object receiving the tax values via
                ``apply_principal_action``.
            tax_rate: Tax values handed to ``torch.Tensor``; assumed to be a
                flat sequence of numbers for one game -- TODO confirm.

        Raises:
            AssertionError: If ``args.total_episodes`` differs from
                ``args.eps_per_tax_rate``.
        """
        super().__init__(args, agent_nets)
        # Validation only fires on episode ``eps_per_tax_rate`` (see
        # ``after_episode``), so the run length must equal that value.
        assert self.args.total_episodes == self.args.eps_per_tax_rate, (
            "ValidateTaxRate requires total_episodes == eps_per_tax_rate"
        )
        self.tax_vals_per_game = torch.Tensor(tax_rate).repeat((self.args.num_parallel_games, 1))
        envs.apply_principal_action(self.tax_vals_per_game)
        self.agent_opt = optim.Adam(self.agent_nets.parameters(), lr=self.args.agent_lr, eps=self.args.adam_eps)

    def set_tax_vals(self, ctx, envs):
        """Return the fixed tax values; ``ctx`` and ``envs`` are not used."""
        return self.tax_vals_per_game

    def after_episode(self, ctx, envs, episode_buffer, tax_vals_per_game):
        """Run validation episodes and record the principal's return.

        Does nothing unless ``ctx.episode_number`` equals
        ``args.eps_per_tax_rate``.

        Args:
            ctx: Context object; ``episode_number`` is read from it and it is
                forwarded to ``run_multiple_validation_episodes``.
            envs: Environments forwarded to ``run_multiple_validation_episodes``.
            episode_buffer: Unused.
            tax_vals_per_game: Unused; ``self.tax_vals_per_game`` is what gets
                validated.
        """
        if ctx.episode_number != self.args.eps_per_tax_rate:
            return

        # The second return value (a principal observation) is not needed here.
        val_rewards_per_game, _ = run_multiple_validation_episodes(
            num_val_episodes=self.args.num_val_episodes,
            ctx=ctx,
            envs=envs,
            num_parallel_games=self.args.num_parallel_games,
            episode_length=self.args.episode_length,
            sampling_horizon=self.args.sampling_horizon,
            tax_vals_per_game=self.tax_vals_per_game,
        )

        # Reduce over dim 0 first, then average what remains.
        # NOTE(review): presumably dim 0 indexes validation episodes and the
        # remaining axis indexes games -- confirm against the helper.
        mean_reward = val_rewards_per_game.mean(dim=0).mean().item()
        std_reward = val_rewards_per_game.std(dim=0).mean().item()

        log_prefix = (
            f"{self.args.principal}_{self.args.eps_per_tax_rate}converge_{self.args.num_val_episodes}val"
        )
        logger.log_later(
            {
                f"{log_prefix}/average_mean_return": mean_reward,
                f"{log_prefix}/std_mean_return": std_reward,
            },
            flush=True,
        )

        # Results are appended to a file two directory levels above the
        # current working directory.
        results_path = f"{os.path.dirname(os.path.dirname(os.getcwd()))}/{self.args.validation_filename}.txt"
        # Context manager guarantees the handle is closed even if write fails.
        with open(results_path, "a") as results_file:
            results_file.write(
                f"{self.args.saved_heads_path}; {self.args.eps_per_tax_rate} convergence {self.args.num_val_episodes} repeats {self.args.principal}: mean {mean_reward} - std {std_reward}\n"
            )


class RandomTaxRate(Principal):
    """Principal that periodically replaces the tax schedule with a random one.

    A new schedule is sampled whenever ``ctx.episode_number`` is a multiple of
    ``args.eps_per_tax_rate``. Under the same condition, ``after_episode``
    validates the current schedule and logs the result.
    """

    def __init__(self, agent_nets, args, envs, num_brackets, multiplier):
        """Store the sampling parameters and build the agent optimizer.

        Args:
            agent_nets: Forwarded to ``Principal.__init__``; presumably what
                backs ``self.agent_nets`` (base class not visible here).
            args: Run configuration, forwarded to ``Principal.__init__``;
                read back through ``self.args``.
            envs: Not used during construction.
            num_brackets: Number of tax values sampled per schedule.
            multiplier: Factor applied to the raw ``torch.rand`` samples.
        """
        super().__init__(args, agent_nets)

        self.num_brackets = num_brackets
        self.multiplier = multiplier

        # Counts how many schedules have been sampled so far.
        self.generation_counter = 0

        self.agent_opt = optim.Adam(
            self.agent_nets.parameters(),
            lr=self.args.agent_lr,
            eps=self.args.adam_eps,
        )

    def generate_random_tax_rates(self):
        """Sample one schedule and repeat it for every parallel game."""
        # ``torch.rand`` draws ``num_brackets`` values, which are then scaled.
        sampled_schedule = torch.rand(self.num_brackets) * self.multiplier
        return sampled_schedule.repeat((self.args.num_parallel_games, 1))

    def set_tax_vals(self, ctx, envs):
        """Return the current tax values, resampling on period boundaries.

        NOTE(review): ``self.tax_vals_per_game`` is only assigned here, so the
        first call is assumed to land on a boundary episode -- confirm.
        """
        resample_due = ctx.episode_number % self.args.eps_per_tax_rate == 0
        if not resample_due:
            return self.tax_vals_per_game

        self.tax_vals_per_game = self.generate_random_tax_rates()
        self.generation_counter += 1
        envs.apply_principal_action(self.tax_vals_per_game)
        return self.tax_vals_per_game

    def after_episode(self, ctx, envs, episode_buffer, tax_vals_per_game):
        """Validate the current schedule on period boundaries and log it.

        ``episode_buffer`` and the ``tax_vals_per_game`` argument are unused;
        ``self.tax_vals_per_game`` is what gets validated.
        """
        on_boundary = ctx.episode_number % self.args.eps_per_tax_rate == 0
        if not on_boundary:
            return

        # Only the rewards are needed; the second return value is discarded.
        val_rewards_per_game, _ = run_multiple_validation_episodes(
            num_val_episodes=self.args.num_val_episodes,
            ctx=ctx,
            envs=envs,
            num_parallel_games=self.args.num_parallel_games,
            episode_length=self.args.episode_length,
            sampling_horizon=self.args.sampling_horizon,
            tax_vals_per_game=self.tax_vals_per_game,
        )

        # Reduce over dim 0 first, then average what remains.
        per_game_mean = val_rewards_per_game.mean(dim=0)
        per_game_std = val_rewards_per_game.std(dim=0)
        mean_reward = per_game_mean.mean().item()
        std_reward = per_game_std.mean().item()

        summary = {
            "combined_val_train/episode": ctx.total_episode_number,
            "principal_final/principal_step": self.generation_counter,
            "principal_final/returns": mean_reward,
            "principal_final/std": std_reward,
            **format_taxes(self.tax_vals_per_game, prefix="principal_final/"),
        }
        logger.log_later(summary, flush=True)

        logger.log_random_results(
            self.tax_vals_per_game.numpy(),
            mean_reward,
            std_reward,
            self.generation_counter,
        )
