import os
import pprint
from dataclasses import asdict

import bullet_safety_gym

try:
    import safety_gymnasium
except ImportError:
    print("safety_gymnasium is not found.")
from gym_minigrid.register import env_list
from gym_minigrid.minigrid import Grid, OBJECT_TO_IDX
from gym_minigrid.wrappers import *

# import gymnasium as gym
import gym
import numpy as np
import pyrallis
import torch
import torch.nn as nn
from tianshou.data import VectorReplayBuffer
from tianshou.env import BaseVectorEnv, ShmemVectorEnv, SubprocVectorEnv
from tianshou.utils.net.common import Net
# from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.net.discrete import Actor
# from examples.configs.cvpo_cfg import *
from fsrl.config.sacl_disc_cfg import (
    Bullet1MCfg,
    Bullet5MCfg,
    Bullet10MCfg,
    Mujoco2MCfg,
    Mujoco10MCfg,
    Mujoco20MCfg,
    MujocoBaseCfg,
    TrainCfg,
)
from fsrl.data import FastCollector
from fsrl.policy import SACLagrangianDisc
from fsrl.trainer import OffpolicyTrainer
from fsrl.utils import TensorboardLogger, WandbLogger
from fsrl.utils.exp_util import auto_name, seed_all
from fsrl.utils.net.common import ActorCritic, CNNActor
from fsrl.utils.net.discrete import DoubleCritic, DoubleCriticCNN

# Maps a gym_minigrid task index (the same index used with ``env_list``) to the
# config class used when ``use_default_cfg`` is set. Every index 0..15 currently
# shares the generic ``TrainCfg``; override individual entries for per-task
# defaults.
TASK_TO_CFG = {task_id: TrainCfg for task_id in range(16)}


def _make_env(env_name, reward_shaping):
    """Build ``gym.make(env_name)`` wrapped in ``StandardSafetyGymWrapper``.

    Args:
        env_name: environment id passed to ``gym.make``.
        reward_shaping: forwarded to ``StandardSafetyGymWrapper``. Training
            envs pass True and test envs pass False.
    """
    return StandardSafetyGymWrapper(gym.make(env_name), reward_shaping=reward_shaping)


@pyrallis.wrap()
def train(args: TrainCfg):
    """Train a ``SACLagrangianDisc`` policy on a gym_minigrid task with shaped rewards.

    Training envs are created with ``reward_shaping=True`` and test envs with
    ``reward_shaping=False``. The Lagrangian cost constraint is disabled
    (``use_lagrangian=False``).

    Args:
        args: training config. ``pyrallis.wrap()`` supplies it when ``train()``
            is called without arguments. If ``args.use_default_cfg`` is set, it
            is replaced by the task's default config from ``TASK_TO_CFG``. In that
            case only task/seed/device/logdir/project/group/suffix are kept from
            the input.

    Side effects:
        Logs through ``TensorboardLogger`` under the resolved ``args.logdir``.
        Registers a checkpoint function when ``args.save_ckpt`` is set.
        Prints per-epoch info. When run as a script, also prints a final
        evaluation.
    """
    # set seed and computing
    seed_all(args.seed)
    torch.set_num_threads(args.thread)

    task = args.task
    # ``args.task`` is an integer index into gym_minigrid's ``env_list``.
    env_name = env_list[task]
    default_cfg = TASK_TO_CFG.get(task, TrainCfg)()
    # Use the task's default config instead of the input args, but keep the
    # run-identification / bookkeeping fields supplied by the caller.
    if args.use_default_cfg:
        for field_name in ("task", "seed", "device", "logdir", "project", "group", "suffix"):
            setattr(default_cfg, field_name, getattr(args, field_name))
        args = default_cfg

    # setup logger. ``cfg`` is snapshotted before name/group/logdir are resolved
    # below, so ``auto_name`` compares the raw config against the defaults.
    cfg = asdict(args)
    default_cfg = asdict(default_cfg)
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
    if args.group is None:
        args.group = env_name + "-cost-" + str(int(args.cost_limit))
    if args.logdir is not None:
        # Keep reward-shaping runs in their own directory next to the group.
        args.logdir = os.path.join(args.logdir, args.project, args.group + "_reward_shaping")
    # To log to W&B instead:
    # WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
    logger = TensorboardLogger(args.logdir, log_txt=True, name=args.name)
    logger.save_config(cfg, verbose=args.verbose)

    training_num = min(args.training_num, args.episode_per_collect)
    # NOTE(review): ``eval`` on a config string. ``args.worker`` is expected to
    # name a vector-env class in scope (e.g. "ShmemVectorEnv"). Only run with
    # trusted configs.
    worker = eval(args.worker)
    train_envs = worker([lambda: _make_env(env_name, reward_shaping=True) for _ in range(training_num)])
    test_envs = worker([lambda: _make_env(env_name, reward_shaping=False) for _ in range(args.testing_num)])

    # model
    # Probe env: read here only for its observation/action spaces.
    env = _make_env(env_name, reward_shaping=True)
    state_shape = env.observation_space.shape or env.observation_space.n
    action_shape = env.action_space.shape or env.action_space.n

    actor = CNNActor(
        state_shape[0],
        action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    # NOTE(review): the critic MLPs take ``state_shape[0] - 25 + 12`` inputs.
    # The constants are unexplained here; presumably DoubleCriticCNN re-encodes
    # part of the observation before the MLP. Confirm.
    critic_input_dim = state_shape[0] - 25 + 12
    # Two double-critics are passed to SACLagrangianDisc (presumably one for
    # reward and one for cost; confirm against the policy).
    critics = []
    for _ in range(2):
        net1 = Net(critic_input_dim, hidden_sizes=args.hidden_sizes, device=args.device)
        net2 = Net(critic_input_dim, hidden_sizes=args.hidden_sizes, device=args.device)
        critics.append(
            DoubleCriticCNN(net1, net2, last_size=action_shape, device=args.device).to(args.device)
        )

    critic_optim = torch.optim.Adam(nn.ModuleList(critics).parameters(), lr=args.critic_lr)

    actor_critic = ActorCritic(actor, critics)
    # orthogonal initialization of every Linear layer in actor and critics
    for m in actor_critic.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            if m.bias is not None:  # Linear(bias=False) has no bias tensor
                torch.nn.init.zeros_(m.bias)

    if args.last_layer_scale:
        # Last policy layer scaling, see https://arxiv.org/abs/2006.05990, Fig.24.
        # NOTE(review): this assumes the actor exposes a ``mu`` submodule (as
        # tianshou's continuous ActorProb does). Confirm CNNActor provides it.
        for m in actor.mu.modules():
            if isinstance(m, torch.nn.Linear):
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)
                m.weight.data.copy_(0.01 * m.weight.data)

    if args.auto_alpha:
        # Target entropy = target_entropy_ratio * log(num_actions).
        target_entropy = -np.log(1.0 / env.action_space.n) * args.target_entropy_ratio
        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        args.alpha = (target_entropy, log_alpha, alpha_optim)

    policy = SACLagrangianDisc(
        actor=actor,
        critics=critics,
        actor_optim=actor_optim,
        critic_optim=critic_optim,
        logger=logger,
        alpha=args.alpha,
        tau=args.tau,
        gamma=args.gamma,
        exploration_noise=None,
        n_step=args.n_step,
        # Constraint handling is off in this reward-shaping script.
        use_lagrangian=False,
        lagrangian_pid=(0.0, 0.0, 0.0),
        # NOTE(review): hard-coded, while stop_fn/trainer use args.cost_limit.
        # Presumably unused with use_lagrangian=False. Confirm.
        cost_limit=100,
        rescaling=args.rescaling,
        reward_normalization=False,
        deterministic_eval=args.deterministic_eval,
        action_scaling=args.action_scaling,
        max_cost_q=args.max_cost_q,
        action_bound_method=args.action_bound_method,
        observation_space=env.observation_space,
        action_space=env.action_space,
        lr_scheduler=None,
        lagrangian_max=args.lagrangian_max,
    )

    # collector
    train_collector = FastCollector(
        policy,
        train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)),
        exploration_noise=True,
    )
    test_collector = FastCollector(policy, test_envs)

    def stop_fn(reward, cost):
        """Stop early once reward beats the threshold while cost stays under the limit."""
        return reward > args.reward_threshold and cost < args.cost_limit

    def checkpoint_fn():
        """Return the state the logger should checkpoint."""
        return {"model": policy.state_dict()}

    if args.save_ckpt:
        logger.setup_checkpoint_fn(checkpoint_fn)

    # trainer
    trainer = OffpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=args.epoch,
        batch_size=args.batch_size,
        cost_limit=args.cost_limit,
        step_per_epoch=args.step_per_epoch,
        update_per_step=args.update_per_step,
        episode_per_test=args.testing_num,
        episode_per_collect=args.episode_per_collect,
        stop_fn=stop_fn,
        logger=logger,
        resume_from_log=args.resume,
        save_model_interval=args.save_interval,
        verbose=args.verbose,
    )

    # Pre-bind ``info`` so the final report does not raise NameError if the
    # trainer yields no epochs.
    info = {}
    for epoch, epoch_stat, info in trainer:
        logger.store(tab="train", cost_limit=args.cost_limit)
        print(f"Epoch: {epoch}")
        print(info)

    # Only watch the final policy when this file is run as a script, not when
    # ``train`` is invoked from another module.
    if __name__ == "__main__":
        pprint.pprint(info)
        # Let's watch its performance!
        # NOTE(review): this env uses reward_shaping=True, while test_envs use
        # False. The rewards printed below are shaped and not directly
        # comparable to the logged test reward.
        env = _make_env(env_name, reward_shaping=True)
        policy.eval()
        collector = FastCollector(policy, env)
        result = collector.collect(n_episode=10, render=args.render)
        rews, lens, cost = result["rew"], result["len"], result["cost"]
        print(f"Final eval reward: {rews.mean()}, cost: {cost}, length: {lens.mean()}")

        policy.train()
        collector = FastCollector(policy, env)
        result = collector.collect(n_episode=10, render=args.render)
        rews, lens, cost = result["rew"], result["len"], result["cost"]
        print(f"Final train reward: {rews.mean()}, cost: {cost}, length: {lens.mean()}")


if __name__ == "__main__":
    # ``train`` is wrapped by ``pyrallis.wrap()``, which supplies the TrainCfg
    # argument, so it is called here with no arguments.
    train()
