from dataclasses import asdict, dataclass, field
from math import sqrt
from typing import Any, Literal, overload

from util.argparser_dataclass import parse_args_to_dataclass


@dataclass
class LoggingConfig:
    log: Literal["wandb", "none"] = field(default="wandb", metadata={"help": "Logging method to use"})
    debug: str | None = field(default=None)


@dataclass
class SeedConfig:
    """Holds the single integer seed of a run."""

    seed: int = field(
        default=0,
        metadata={"help": "Integer to use for python, torch, and cuda seeds"},
    )


@dataclass
class NSeedsConfig:
    """Holds how many seeds to use."""

    n_seeds: int = field(
        default=10,
        metadata={"help": "Number of seeds"},
    )


@dataclass
class DatasetConfig:
    """Describes a dataset/environment setup and derives a name fragment from it."""

    # Required: there is no default environment.
    env: str = field(metadata={"help": "Environment type"})
    n_envs: int = field(default=100000, metadata={"help": "Number of environments"})
    context_len: int = field(default=100, metadata={"help": "Context length, context horizon"})
    n_states: int = field(default=5, metadata={"help": "Number of states for configurable environments"})
    n_actions: int = field(default=5, metadata={"help": "Number of actions for configurable environments (e.g., arms)"})
    variance: float = field(default=0.3, metadata={"help": "Variance of the noise added to the rewards"})

    def get_summary(self) -> str:
        """Return an underscore-separated string identifying this dataset setup.

        The environment name, number of environments and context length are
        always present. The state count is added for "darkroom" and "chain",
        the action count for "bandit", and the variance for "bandit" and
        "chain".
        """
        env_name = self.env
        parts = [f"{env_name}", f"envs{self.n_envs}", f"ctxlen{self.context_len}"]
        # FIXME: leaky abstraction
        if env_name in ("darkroom", "chain"):
            parts.append(f"states{self.n_states}")
        if env_name == "bandit":
            parts.append(f"actions{self.n_actions}")
        if env_name in ("bandit", "chain"):
            parts.append(f"variance{self.variance}")
        return "_".join(parts)


@dataclass
class ModelConfig:
    arch: str | None = field(default=None, metadata={"help": "Architecture to use, from a set of standard ones"})

    shuffle: bool = field(default=False)
    lr: float = field(default=0.0001)
    dropout: float = field(default=0.0)
    n_embd: int = field(default=32)
    n_layer: int = field(default=4)
    n_head: int = field(default=4)

    n_epochs: int = field(default=10, metadata={"help": "Number of epochs"})

    def __post_init__(self):
        if self.arch is None:
            return

        if self.arch == "1":
            self.shuffle = True
            self.lr = 0.0001
            self.dropout = 0.0
            self.n_embd = 32
            self.n_layer = 4
            self.n_head = 4
        else:
            raise NotImplementedError()

    def get_summary(self) -> str:
        if self.arch is None:
            summary = f"shuffle{self.shuffle}"
            summary += f"lr{self.lr}"
            summary += f"do{self.dropout}"
            summary += f"embd{self.n_embd}"
            summary += f"layer{self.n_layer}"
            summary += f"head{self.n_head}"
        else:
            summary = f"arch{self.arch}"

        summary += f"_epochs{self.n_epochs}"
        return summary

    def get_params(self, other: dict[str, Any] = {}) -> dict[str, Any]:
        return {**vars(self), **other}


@dataclass
class PPOConfig:
    lr: float = field(default=0.002)
    gamma: float = field(default=0.95, metadata={"help": "Discount factor"})
    lam: float = field(default=0.95, metadata={"help": "Lambda for GAE"})
    clip: float = field(default=0.2, metadata={"help": "Clipping range for policy loss"})
    clipvf: float | None = field(default=None, metadata={"help": "Clipping range for value function loss"})
    n_layer: int = field(default=2)
    n_hidden: int = field(default=256, metadata={"help": "Number of hidden units between linear (hidden dim)"})
    n_episodes: int = field(default=5, metadata={"help": "Number of episodes to collect for training"})
    value_loss_coef: float = field(default=0.5)
    entropy_bonus_coef: float = field(default=0.01)

    n_epochs: int = field(default=10)

    def get_summary(self) -> str:
        summary = "ppo_"
        summary += f"lr{self.lr}"
        summary += f"gamma{self.gamma}"
        summary += f"lam{self.lam}"
        summary += f"clip{self.clip}clipvf{self.clipvf}"
        summary += f"layer{self.n_layer}"
        summary += f"hidden{self.n_hidden}"
        summary += f"episodes{self.n_episodes}"
        summary += f"valuelosscoef{self.value_loss_coef}"
        summary += f"entropybonuscoef{self.entropy_bonus_coef}"

        summary += f"_epochs{self.n_epochs}"
        return summary


@dataclass
class PPOMWConfig:
    ppo_lr: float = field(default=2.5e-4)
    gamma: float = field(default=0.99)
    gae_lambda: float = field(default=0.95)
    clip_coef: float = field(default=0.1)
    norm_adv: bool = field(default=True)
    clip_vloss: bool = field(default=True)
    ent_coef: float = field(default=0.01)
    vf_coef: float = field(default=0.5)
    max_grad_norm: float = field(default=0.5)
    target_kl: float | None = field(default=None)
    update_epochs: int = field(default=4)


@dataclass
class NPGConfig:
    """Hyperparameters for NPG training, with a name fragment derived from them."""

    lr: float = 0.0003
    gamma: float = field(default=0.98, metadata={"help": "Discount factor"})
    lam: float = field(default=0.95, metadata={"help": "Lambda for GAE"})
    n_episodes: int = field(default=5, metadata={"help": "Number of episodes to collect for training"})
    n_epochs: int = 10

    def get_summary(self) -> str:
        """Return a string identifying this NPG configuration.

        The hyperparameters are concatenated without separators after the
        "npg_" prefix; the epoch count follows after an underscore.
        """
        hyperparams = "".join(
            (
                f"lr{self.lr}",
                f"gamma{self.gamma}",
                f"lam{self.lam}",
                f"episodes{self.n_episodes}",
            )
        )
        return f"npg_{hyperparams}_epochs{self.n_epochs}"


@dataclass
class QLearningConfig:
    """Hyperparameters for Q-learning, with a name fragment derived from them."""

    lr: float = 0.1
    gamma: float = field(default=0.95, metadata={"help": "Discount factor"})
    n_episodes: int = field(default=1, metadata={"help": "Number of episodes to collect for training"})
    frac_greedy_start: float = field(default=0.9, metadata={"help": "Epsilon greedy value to start with"})
    frac_greedy_end: float = field(default=0.05, metadata={"help": "Epsilon greedy value to end with"})
    # The default is the int 1000 and is rendered as "1000" in the summary.
    frac_greedy_decay_factor: float = field(default=1000, metadata={"help": "Epsilon greedy exponential decay factor"})

    n_epochs: int = 100

    def get_summary(self) -> str:
        """Return a string identifying this Q-learning configuration.

        Layout: "qlearning_" + core hyperparameters, then the epsilon-greedy
        schedule, then the epoch count, the last two each after an underscore.
        """
        core = f"lr{self.lr}gamma{self.gamma}episodes{self.n_episodes}"
        schedule = f"greedy{self.frac_greedy_start}to{self.frac_greedy_end}decay{self.frac_greedy_decay_factor}"
        return f"qlearning_{core}_{schedule}_epochs{self.n_epochs}"


@dataclass
class EvalConfig:
    epoch: int | None = field(default=None, metadata={"help": "Epoch checkpoint of the model to evaluate. Leave blank for final checkpoint"})
    n_envs_eval: int = field(default=100, metadata={"help": "Number of environments to run in parallel for eval"})
    n_steps_eval: int | None = field(default=None, metadata={"help": "Maximum number of steps in the environment during evaluation"})

    def get_summary(self) -> str:
        summary = f"epoch{self.epoch}"
        summary += f"_evalenvs{self.n_envs_eval}steps{self.n_steps_eval}"
        return summary


@dataclass
class AdversarialTrainingConfig:
    """Settings for adversarial training, with a name fragment derived from them."""

    n_rounds: int = field(default=20, metadata={"help": "Number of adversarial training rounds to perform"})
    eps_episodes: float = field(default=0.8, metadata={"help": "Fraction of episodes poisoned"})
    eps_steps: float = field(default=0.4, metadata={"help": "Fraction of steps within an episode poisoned"})
    victim_iters: int = field(default=20, metadata={"help": "Number of iterations the victim should be trained for per dataset"})
    victim_lr: float = field(default=0.00003, metadata={"help": "Learning rate for the victim"})
    attacker_iters: int = field(default=20, metadata={"help": "Number of iterations the attacker should be trained for per dataset"})
    attacker_lr: float = field(default=0.03, metadata={"help": "Learning rate for the attacker"})
    # The default is the int 10 and is rendered as "10" in the summary.
    budget_regularizer: float = field(default=10, metadata={"help": "Regularizer constant for the budget in the loss function for the attacker"})
    max_poison_diff: float = field(default=3.0, metadata={"help": "Maximum distance from the original means the attacker is able to poison"})
    attacker_against: str = field(default="dpt", metadata={"help": "Which agent is the attack trying to poison"})
    # Not part of the summary string.
    log_round_rewards: bool = False

    def get_summary(self, print_against: bool = True) -> str:
        """Return an underscore-separated string identifying this setup.

        Args:
            print_against: When true, a trailing "against<agent>" segment
                naming ``attacker_against`` is included.
        """
        segments = [
            f"rounds{self.n_rounds}",
            f"epse{self.eps_episodes}",
            f"epss{self.eps_steps}",
            f"victimiters{self.victim_iters}lr{self.victim_lr}",
            f"attackeriters{self.attacker_iters}lr{self.attacker_lr}",
            f"maxpoisondiff{self.max_poison_diff}reg{self.budget_regularizer}",
        ]
        if print_against:
            segments.append(f"against{self.attacker_against}")
        return "_".join(segments)


@dataclass
class AdaptiveAttackerConfig:
    adaptiveatt_arch: str | None = field(default=None, metadata={"help": "Architecture to use, from a set of standard ones"})

    adaptiveatt_shuffle: bool = field(default=False)
    adaptiveatt_lr: float = field(default=0.0001)
    adaptiveatt_dropout: float = field(default=0.0)
    adaptiveatt_n_embd: int = field(default=32)
    adaptiveatt_n_layer: int = field(default=4)
    adaptiveatt_n_head: int = field(default=4)

    adaptiveatt_context_len: int = field(default=500)

    def __post_init__(self):
        if self.adaptiveatt_arch is None:
            return

        if self.adaptiveatt_arch == "1":
            self.adaptiveatt_shuffle = True
            self.adaptiveatt_lr = 0.0001
            self.adaptiveatt_dropout = 0.0
            self.adaptiveatt_n_embd = 32
            self.adaptiveatt_n_layer = 4
            self.adaptiveatt_n_head = 4
        else:
            raise NotImplementedError()

    @property
    def arch(self) -> str | None:
        return self.adaptiveatt_arch

    @property
    def shuffle(self) -> int:
        return self.adaptiveatt_shuffle

    @property
    def lr(self) -> float:
        return self.adaptiveatt_lr

    @property
    def dropout(self) -> float:
        return self.adaptiveatt_dropout

    @property
    def n_embd(self) -> int:
        return self.adaptiveatt_n_embd

    @property
    def n_layer(self) -> int:
        return self.adaptiveatt_n_layer

    @property
    def n_head(self) -> int:
        return self.adaptiveatt_n_head

    @property
    def context_len(self) -> int:
        return self.adaptiveatt_context_len

    def get_summary(self) -> str:
        summary = "adaptiveatt_"
        if self.adaptiveatt_arch is None:
            summary += f"shuffle{self.adaptiveatt_shuffle}"
            summary += f"lr{self.adaptiveatt_lr}"
            summary += f"do{self.adaptiveatt_dropout}"
            summary += f"embd{self.adaptiveatt_n_embd}"
            summary += f"layer{self.adaptiveatt_n_layer}"
            summary += f"head{self.adaptiveatt_n_head}"
        else:
            summary += f"arch{self.adaptiveatt_arch}"

        return summary

    def get_params_unprefixed(self, other: dict[str, Any] = {}) -> dict[str, Any]:
        return {**{key.removeprefix("adaptiveatt_"): val for key, val in vars(self).items()}, **other}


def get_model_name(dataset_config: DatasetConfig, model_config: ModelConfig | PPOConfig | NPGConfig | QLearningConfig) -> str:
    """Return the dataset summary and the model summary joined by an underscore."""
    dataset_part = dataset_config.get_summary()
    model_part = model_config.get_summary()
    return f"{dataset_part}_{model_part}"


@overload
def get_model_save_name(base_name: str, seed_config: SeedConfig) -> str: ...
@overload
def get_model_save_name(base_name: str, seed_config: SeedConfig, max_epochs: int, epoch: int | None) -> str: ...
def get_model_save_name(base_name: str, seed_config: SeedConfig, max_epochs: int | None = None, epoch: int | None = None) -> str:
    """Return the save name of a model for a given seed and, optionally, epoch.

    Args:
        base_name: Name to extend.
        seed_config: Supplies the seed appended as ``_seed<seed>``.
        max_epochs: Total number of epochs, or ``None`` when not applicable.
        epoch: Epoch of the checkpoint, or ``None`` for no specific epoch.

    Returns:
        ``<base_name>_seed<seed>``, followed by ``_epoch<epoch>`` only when
        both ``epoch`` and ``max_epochs`` are given and they differ.
    """
    name = f"{base_name}_seed{seed_config.seed}"
    # Identity tests against None; the checkpoint at max_epochs gets no suffix.
    if epoch is not None and max_epochs is not None and epoch != max_epochs:
        name += f"_epoch{epoch}"
    return name


def get_adv_trained_model_name(
    dataset_config: DatasetConfig,
    model_config: ModelConfig,
    eval_config: EvalConfig,
    adv_train_config: AdversarialTrainingConfig,
    *,
    print_against: bool = True,
) -> str:
    """Return the name of an adversarially trained model.

    The name is the model name, the evaluation summary and the adversarial
    training summary, joined by underscores. ``print_against`` is forwarded
    to the adversarial training summary.
    """
    model_name = get_model_name(dataset_config, model_config)
    eval_part = eval_config.get_summary()
    adv_part = adv_train_config.get_summary(print_against=print_against)
    return f"{model_name}_{eval_part}_{adv_part}"


def get_adaptive_adv_trained_model_name(
    dataset_config: DatasetConfig,
    model_config: ModelConfig,
    eval_config: EvalConfig,
    adv_train_config: AdversarialTrainingConfig,
    adaptive_attacker_config: AdaptiveAttackerConfig,
    *,
    print_against: bool = True,
) -> str:
    """Return the name of a model adversarially trained against an adaptive attacker.

    This is the adversarially-trained model name followed by an underscore
    and the adaptive attacker's summary. ``print_against`` is forwarded to
    the adversarially-trained model name.
    """
    adv_trained_name = get_adv_trained_model_name(
        dataset_config,
        model_config,
        eval_config,
        adv_train_config,
        print_against=print_against,
    )
    attacker_part = adaptive_attacker_config.get_summary()
    return f"{adv_trained_name}_{attacker_part}"


def get_legacy_filename_config(model_config: ModelConfig, dataset_config: DatasetConfig, seed_config: SeedConfig) -> dict[str, Any]:
    """Build the legacy-format config dict that includes ``dim``, ``var`` and ``cov``.

    Starts from every field of ``model_config`` and adds the legacy keys:
    ``dim`` is the number of actions, ``var`` the reward-noise variance,
    and ``n_hists``, ``n_samples`` and ``cov`` are fixed constants. Legacy
    keys win over model fields of the same name.
    """
    config: dict[str, Any] = asdict(model_config)
    config.update(
        {
            "n_envs": dataset_config.n_envs,
            "n_hists": 1,
            "n_samples": 1,
            "H": dataset_config.context_len,
            "dim": dataset_config.n_actions,
            "seed": seed_config.seed,
            "var": dataset_config.variance,
            "cov": 0.0,
        }
    )
    return config


def get_legacy_darkroom_config(model_config: ModelConfig, dataset_config: DatasetConfig, seed_config: SeedConfig) -> dict[str, Any]:
    """Build the legacy-format config dict whose ``dim`` comes from the state count.

    Starts from every field of ``model_config`` and adds the legacy keys.
    ``dim`` is the square root of the number of states, truncated to an
    int. Legacy keys win over model fields of the same name.
    """
    # Truncation: a non-square state count rounds down (e.g. 5 -> 2).
    side = int(sqrt(dataset_config.n_states))
    config: dict[str, Any] = asdict(model_config)
    config.update(
        {
            "n_envs": dataset_config.n_envs,
            "n_hists": 1,
            "n_samples": 1,
            "H": dataset_config.context_len,
            "dim": side,
            "seed": seed_config.seed,
        }
    )
    return config


def get_legacy_miniworld_config(model_config: ModelConfig, dataset_config: DatasetConfig, seed_config: SeedConfig) -> dict[str, Any]:
    """Build the legacy-format config dict that has no ``dim`` entry.

    Starts from every field of ``model_config`` and adds the legacy keys.
    Legacy keys win over model fields of the same name.
    """
    config: dict[str, Any] = asdict(model_config)
    config.update(
        {
            "n_envs": dataset_config.n_envs,
            "n_hists": 1,
            "n_samples": 1,
            "H": dataset_config.context_len,
            "seed": seed_config.seed,
        }
    )
    return config


@dataclass
class PrintAdvAgainstsConfig:
    """Options for aggregating results under a setup folder."""

    # Required: there is no default folder.
    setup_dir: str = field(
        metadata={"help": "Folder name under which to aggregate results"},
    )
    n_steps: int = field(
        default=500,
        metadata={"help": "Number of steps taken in environment"},
    )
    # Free-form string; its interpretation is not defined in this file.
    meta: str = ""
