from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
import inspect
from typing import Generic, Type, TypeVar

from flax.nnx.filterlib import Filter
from gymnasium.spaces import Box
import numpy as np

from offline.envs.registration import (
    get_registered_datasets,
    make_env_and_load_data,
)
from offline.helper import compile_act, evaluate, log_evaluation_results
from offline.modules.policy import Policy, StateT
from offline.types import OfflineDataWithInfos
from offline.utils.dataset import normalize_rewards, unsquash_actions
from offline.utils.logger import Logger
from offline.utils.parser import ArgumentParser
from offline.utils.misc import set_seed
from offline.utils.suppress_warnings import (
    suppress_absl_warnings,
    suppress_gymnasium_warnings,
)
from offline.utils.tqdm import trange


# Small positive constant: used in this module as the lower bound for the
# per-feature observation std and passed to ``normalize_rewards``.
EPS = 1e-5


@dataclass(frozen=True)
class Arguments:
    """Immutable settings consumed by the generic training loop.

    Algorithm-specific argument classes are expected to subclass this and
    add their own fields.
    """

    # Name of the registered dataset handed to ``make_env_and_load_data``.
    dataset: str
    # Whether evaluation rollouts are performed at all.
    eval: bool
    # Number of episodes passed to ``evaluate``.
    eval_episodes: int
    # Evaluate / checkpoint every ``eval_freq`` steps; values <= 0 disable
    # the periodic hook.
    eval_freq: int
    # Logger used to persist statistics, evaluation results and models.
    logger: Logger
    # Standardize dataset observations with their mean and std.
    normalize_observations: bool
    # Rescale dataset rewards (unless the caller skips normalization).
    normalize_rewards: bool
    # Factor multiplied with the reward-normalization coefficient; the
    # command-line parser declares it with ``type=float``.
    reward_multiplier: float
    # Save the final policy after training.
    save: bool
    # Save the policy at every periodic hook.
    save_checkpoints: bool
    # Write evaluation results to ``.npz`` files.
    save_eval_results: bool
    # Seed for ``set_seed`` and for evaluation.
    seed: int
    # Number of calls to the training function.
    total_steps: int
    # Apply ``unsquash_actions`` to the dataset actions.
    unsquash: bool

    @classmethod
    def from_kwargs(cls, **kwargs):
        """Build an instance from ``kwargs``, silently dropping unknown keys.

        Only keys that match a parameter of the class constructor are
        forwarded, so a superset of options (e.g. a full parsed namespace)
        can be passed in safely.
        """
        # Look the constructor signature up once instead of once per key.
        accepted = inspect.signature(cls).parameters
        return cls(**{k: v for k, v in kwargs.items() if k in accepted})


@dataclass(frozen=True)
class TrainerState(ABC, Generic[StateT]):
    """Abstract, frozen base for the state threaded through training steps.

    Concrete subclasses add their own fields and expose the trained policy
    through the ``policy`` property.
    """

    # Forwarded to ``evaluate`` as its ``state`` argument.
    eval_state: StateT

    @property
    @abstractmethod
    def policy(self) -> Policy[StateT]:
        """Return the policy being trained; subclasses must override this."""


# Type variables tying together the argument class, the policy class and the
# trainer-state class used by one algorithm.
ArgsT_contra = TypeVar("ArgsT_contra", bound=Arguments, contravariant=True)
PolicyT = TypeVar("PolicyT", bound=Policy)
TstateT = TypeVar("TstateT", bound=TrainerState)

# InitFn: (args, dataset) -> initial trainer state.
InitFn = Callable[[ArgsT_contra, OfflineDataWithInfos], TstateT]
# TrainFn: (step, args, current state) -> next trainer state.
TrainFn = Callable[[int, ArgsT_contra, TstateT], TstateT]


def build_argument_parser(**kwargs):
    """Create the command-line parser with the options common to all runs.

    Any keyword arguments are forwarded unchanged to ``ArgumentParser``.
    The populated parser is returned so callers can add their own options.
    """
    parser = ArgumentParser(**kwargs)

    # Each normalization step gets a mutually exclusive on/off flag pair
    # writing to the same destination; normalization is on by default.
    normalization_toggles = (
        (
            "--normalize-observations",
            "--do-not-normalize-observations",
            "normalize_observations",
        ),
        (
            "--normalize-rewards",
            "--do-not-normalize-rewards",
            "normalize_rewards",
        ),
    )
    for enable_flag, disable_flag, destination in normalization_toggles:
        toggle = parser.add_mutually_exclusive_group()
        toggle.add_argument(
            enable_flag, action="store_const", const=True, default=True
        )
        toggle.add_argument(
            disable_flag, action="store_false", dest=destination
        )

    dataset_names = tuple(get_registered_datasets())
    parser.add_argument(
        "--dataset", choices=dataset_names, default="hopper-medium-v2"
    )

    # Saving is enabled unless explicitly switched off.
    parser.add_argument("--do-not-save", action="store_false", dest="save")
    parser.add_argument(
        "--do-not-save-eval-results",
        action="store_false",
        dest="save_eval_results",
    )

    parser.add_argument("--eval-episodes", type=int, default=10)
    parser.add_argument("--eval-freq", type=int, default=5000)
    parser.add_argument("--reward-multiplier", type=float, default=1000)
    parser.add_argument("--save-checkpoints", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--skip-eval", action="store_false", dest="eval")
    parser.add_argument("--total-steps", type=int, default=1000000)

    # ``--unsquash`` / ``--squash`` both write to ``unsquash``.
    squash_toggle = parser.add_mutually_exclusive_group()
    squash_toggle.add_argument("--unsquash", action="store_true")
    squash_toggle.add_argument(
        "--squash", action="store_const", const=False, dest="unsquash"
    )
    return parser


def default_load_fn(
    step: int | None,
    logger: Logger,
    model_fn: Callable[[], PolicyT],
    poi: Filter | None = None,
) -> PolicyT:
    """Restore a saved policy through ``logger.restore_model``.

    Args:
        step: Checkpoint step to load, or ``None`` for the final policy.
        logger: Logger whose ``restore_model`` performs the restore.
        model_fn: Zero-argument factory forwarded to ``restore_model``.
        poi: Optional filter forwarded to ``restore_model``.

    Returns:
        Whatever ``logger.restore_model`` returns for the selected location.
    """
    # The final policy is stored under "policy"; intermediate ones under
    # "checkpoints"/"policy_<step>".
    if step is None:
        location = ("policy",)
    else:
        location = ("checkpoints", f"policy_{step}")
    return logger.restore_model(*location, model_fn=model_fn, poi=poi)


def _evaluate_policy(*, act_fn, args, env, get_normalized_score, state):
    """Put the policy in eval mode and return the result of ``evaluate``.

    The policy is left in eval mode; a caller that continues training must
    switch it back with ``state.policy.train()``.
    """
    state.policy.eval()
    return evaluate(
        act_fn=act_fn,
        env=env,
        eval_episodes=args.eval_episodes,
        get_normalized_score=get_normalized_score,
        policy=state.policy,
        seed=args.seed,
        state=state.eval_state,
    )


def _run(
    args: ArgsT_contra,
    init_fn: InitFn[ArgsT_contra, TstateT],
    skip_reward_normalization: bool,
    train_fn: TrainFn[ArgsT_contra, TstateT],
):
    """Load the dataset, preprocess it, then train and evaluate a policy.

    Args:
        args: Run settings (dataset, normalization flags, schedule, logger).
        init_fn: Builds the initial trainer state from ``args`` and the
            preprocessed dataset.
        skip_reward_normalization: When true, rewards are left untouched even
            if ``args.normalize_rewards`` is set.
        train_fn: Performs one training step and returns the new state.
    """
    env, data, get_normalized_score = make_env_and_load_data(args.dataset)
    if args.normalize_observations:
        mean = np.mean(data.data.observations, axis=0, keepdims=True)
        std = np.std(data.data.observations, axis=0, keepdims=True)
        # Floor the std at EPS. ``np.clip(std, EPS, std.max())`` gives the
        # same result when ``std.max() >= EPS`` but returns ``std.max()``
        # everywhere (possibly 0) when every feature is below EPS, which
        # would make the division below blow up.
        std = np.maximum(std, EPS)
        args.logger.save_numpy("stats.npz", mean=mean, std=std)
        observations = (data.data.observations - mean) / std
        data = data._replace(data=data.data._replace(observations=observations))
    else:
        # Identity statistics so ``compile_act`` receives consistent inputs.
        mean, std = 0, 1

    if args.unsquash:
        data = data._replace(
            data=data.data._replace(actions=unsquash_actions(data.data.actions))
        )

    if not skip_reward_normalization and args.normalize_rewards:
        coefficient = args.reward_multiplier * normalize_rewards(data.data, EPS)
        data = data._replace(
            data=data.data._replace(rewards=coefficient * data.data.rewards)
        )

    state = init_fn(args, data)
    assert isinstance(env.action_space, Box)
    if type(state.policy).__call__ == Policy.__call__:
        # __call__ not implemented
        act_fn = None
    else:
        act_fn = compile_act(
            action_space=env.action_space,
            mean=mean,
            std=std,
            unsquash=args.unsquash,
        )

    # Evaluation needs both an acting function and the user's consent.
    can_evaluate = act_fn is not None and args.eval

    state.policy.train()

    for step in trange(args.total_steps, desc="Train", leave=True):
        state = train_fn(step, args, state)
        if args.eval_freq > 0 and (step + 1) % args.eval_freq == 0:
            if can_evaluate:
                results = _evaluate_policy(
                    act_fn=act_fn,
                    args=args,
                    env=env,
                    get_normalized_score=get_normalized_score,
                    state=state,
                )
                # Back to training mode before the next step.
                state.policy.train()
                log_evaluation_results(
                    results=results, step=step, writer=args.logger.writer
                )
                if args.save_eval_results:
                    args.logger.save_numpy("eval", f"{step}.npz", **results)
            if args.save_checkpoints:
                state.policy.save(step, args.logger)

    # Final evaluation: results are optionally saved, but (unlike the
    # periodic ones) are not passed to ``log_evaluation_results``.
    if can_evaluate:
        results = _evaluate_policy(
            act_fn=act_fn,
            args=args,
            env=env,
            get_normalized_score=get_normalized_score,
            state=state,
        )
        if args.save_eval_results:
            args.logger.save_numpy("eval", "final.npz", **results)
    if args.save:
        state.policy.save(None, args.logger)


def run(
    arguments_class: Type[ArgsT_contra],
    init_fn: InitFn[ArgsT_contra, TstateT],
    logger: Logger,
    train_fn: TrainFn[ArgsT_contra, TstateT],
    skip_reward_normalization: bool = False,
    **kwargs,
):
    """Entry point: build the arguments, seed everything and run training.

    Args:
        arguments_class: Class instantiated with ``logger`` and ``kwargs``.
        init_fn: Builds the initial trainer state.
        logger: Logger stored on the arguments; ``cleanup`` is called on it
            if anything fails.
        train_fn: Performs one training step.
        skip_reward_normalization: Forwarded to ``_run``.
        **kwargs: Remaining constructor arguments for ``arguments_class``.

    Raises:
        Any exception (or ``KeyboardInterrupt``) raised during setup or
        training is re-raised after ``logger.cleanup()``.
    """
    try:
        args = arguments_class(logger=logger, **kwargs)
        set_seed(args.seed)
        suppress_absl_warnings()
        suppress_gymnasium_warnings()
        _run(
            args=args,
            init_fn=init_fn,
            skip_reward_normalization=skip_reward_normalization,
            train_fn=train_fn,
        )
        # NOTE(review): presumably blocks until pending logger work has
        # finished -- confirm against ``Logger.wait``.
        logger.wait()
    except (Exception, KeyboardInterrupt):
        logger.cleanup()
        # Bare ``raise`` propagates the active exception as-is.
        raise
