import os
import uuid
import types
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple
from torch.distributions import Normal
import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import numpy as np
import pyrallis
import torch
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger
from tqdm.auto import trange  # noqa

from config.cpq_config import CPQ_DEFAULT_CONFIG, CPQTrainConfig
from osrl.algorithms import CPQ, CPQTrainer
from osrl.common import TransitionDataset
from osrl.common.exp_util import auto_name, seed_all
from torch.utils.data import TensorDataset, DataLoader

import pickle


def _save_offline_dataset(results_dir, dataset, hyperparameters):
    """Write the pre-processing hyperparameters and the processed dataset to disk.

    Args:
        results_dir: Directory that receives ``hyperparameters.txt`` and
            ``dataset.pkl``; created if it does not exist yet.
        dataset: Picklable object holding the processed transitions.
        hyperparameters: Mapping of setting name to value, written one
            ``name: value`` pair per line.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(results_dir, exist_ok=True)

    with open(f'{results_dir}/hyperparameters.txt', 'w', encoding="utf-8") as file:
        file.writelines(f'{param}: {value}\n'
                        for param, value in hyperparameters.items())
    with open(f'{results_dir}/dataset.pkl', 'wb') as file:
        pickle.dump(dataset, file)


def _save_model_state(model, save_dir, task):
    """Save the state dicts of the model's networks and optimizers.

    One ``<task>_cpq_<component>.pth`` file is written per component,
    overwriting the file from any previous call.

    Args:
        model: Object exposing ``actor``, ``critic``, ``cost_critic`` and the
            matching ``*_optim`` attributes, each providing ``state_dict()``.
        save_dir: Target directory; created if it does not exist yet.
        task: Task name used as the file-name prefix.
    """
    os.makedirs(save_dir, exist_ok=True)

    components = (
        "actor",
        "critic",
        "cost_critic",
        "actor_optim",
        "critic_optim",
        "cost_critic_optim",
    )
    for component in components:
        torch.save(
            getattr(model, component).state_dict(),
            os.path.join(save_dir, f'{task}_cpq_{component}.pth'),
        )


@pyrallis.wrap()
def train(args: CPQTrainConfig):
    """Train a CPQ agent on an offline dataset and export dataset and weights.

    The function resolves the final configuration, pre-processes the offline
    dataset of the task environment, dumps the processed dataset and its
    pre-processing settings to ``../offline_dataset/<task>``, then runs the
    training loop with periodic evaluation, checkpointing through the logger,
    and export of the network/optimizer state dicts to ``../offline_model``.

    Args:
        args: Training configuration. Fields that differ from the
            ``CPQTrainConfig`` defaults override the task-specific defaults
            taken from ``CPQ_DEFAULT_CONFIG``.
    """
    # update config: start from the task-specific defaults and re-apply only
    # the values that differ from the generic CPQTrainConfig defaults
    cfg, old_cfg = asdict(args), asdict(CPQTrainConfig())
    differing_values = {key: cfg[key] for key in cfg.keys() if cfg[key] != old_cfg[key]}
    cfg = asdict(CPQ_DEFAULT_CONFIG[args.task]())
    cfg.update(differing_values)
    args = types.SimpleNamespace(**cfg)

    # setup logger
    default_cfg = asdict(CPQ_DEFAULT_CONFIG[args.task]())
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
    if args.group is None:
        args.group = args.task + "-cost-" + str(int(args.cost_limit))
    if args.logdir is not None:
        args.logdir = os.path.join(args.logdir, args.group, args.name)
    logger = WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
    # logger = TensorboardLogger(args.logdir, log_txt=True, name=args.name)
    logger.save_config(cfg, verbose=args.verbose)

    # set seed
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    # initialize environment
    env = gym.make(args.task)

    # pre-process offline dataset
    data = env.get_dataset()
    env.set_target_cost(args.cost_limit)

    # the density bins are only looked up when a non-default density is requested
    cbins, rbins, max_npb, min_npb = None, None, None, None
    if args.density != 1.0:
        density_cfg = DENSITY_CFG[args.task + "_density" + str(args.density)]
        cbins = density_cfg["cbins"]
        rbins = density_cfg["rbins"]
        max_npb = density_cfg["max_npb"]
        min_npb = density_cfg["min_npb"]
    data = env.pre_process_data(data,
                                args.outliers_percent,
                                args.noise_scale,
                                args.inpaint_ranges,
                                args.epsilon,
                                args.density,
                                cbins=cbins,
                                rbins=rbins,
                                max_npb=max_npb,
                                min_npb=min_npb)

    # wrapper
    env = wrap_env(
        env=env,
        reward_scale=args.reward_scale,
    )
    env = OfflineEnvWrapper(env)

    # model & optimizer set
    model = CPQ(
        state_dim=env.observation_space.shape[0],
        action_dim=env.action_space.shape[0],
        max_action=env.action_space.high[0],
        a_hidden_sizes=args.a_hidden_sizes,
        c_hidden_sizes=args.c_hidden_sizes,
        vae_hidden_sizes=args.vae_hidden_sizes,
        sample_action_num=args.sample_action_num,
        gamma=args.gamma,
        tau=args.tau,
        beta=args.beta,
        num_q=args.num_q,
        num_qc=args.num_qc,
        qc_scalar=args.qc_scalar,
        cost_limit=args.cost_limit,
        episode_len=args.episode_len,
        device=args.device,
    )
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

    def checkpoint_fn():
        return {"model_state": model.state_dict()}

    logger.setup_checkpoint_fn(checkpoint_fn)

    # trainer
    trainer = CPQTrainer(model,
                         env,
                         logger=logger,
                         actor_lr=args.actor_lr,
                         critic_lr=args.critic_lr,
                         alpha_lr=args.alpha_lr,
                         vae_lr=args.vae_lr,
                         reward_scale=args.reward_scale,
                         cost_scale=args.cost_scale,
                         device=args.device)

    # initialize pytorch dataloader
    dataset = TransitionDataset(data,
                                reward_scale=args.reward_scale,
                                cost_scale=args.cost_scale)

    # export the processed dataset together with the settings that produced it
    # NOTE: the path is relative to the current working directory
    results_dir = f"../offline_dataset/{args.task}"
    hyperparameters = {
        "reward_scale": args.reward_scale,
        "cost_scale": args.cost_scale,
        "outliers_percent": args.outliers_percent,
        "noise_scale": args.noise_scale,
        "inpaint_ranges": args.inpaint_ranges,
        "epsilon": args.epsilon,
        "density": args.density,
        "cbins": cbins,
        "rbins": rbins,
        "max_npb": max_npb,
        "min_npb": min_npb
    }
    _save_offline_dataset(results_dir, dataset.dataset, hyperparameters)

    print("Dataset saved!")

    trainloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )
    trainloader_iter = iter(trainloader)

    # for saving the best
    best_reward = -np.inf
    best_cost = np.inf
    best_idx = 0

    # NOTE(review): both values are hard-coded here rather than read from
    # args - confirm this is intended.
    update_steps = 100_000
    eval_every = 1000

    # NOTE: the path is relative to the current working directory
    save_dir = '../offline_model'

    # offline_training
    for step in trange(update_steps, desc="Training"):
        batch = next(trainloader_iter)
        observations, next_observations, actions, rewards, costs, done = [
            b.to(args.device) for b in batch
        ]
        trainer.train_one_step(observations, next_observations, actions, rewards, costs,
                               done)

        # evaluation
        if (step + 1) % eval_every == 0 or step == update_steps - 1:
            ret, cost, length = trainer.evaluate(args.eval_episodes)
            logger.store(tab="eval", Cost=cost, Reward=ret, Length=length)
            print(ret, cost)

            # save the current weight
            logger.save_checkpoint()
            # save the best weight: lower cost wins, reward breaks ties
            if cost < best_cost or (cost == best_cost and ret > best_reward):
                best_cost = cost
                best_reward = ret
                best_idx = step
                logger.save_checkpoint(suffix="best")

            logger.store(tab="train", best_idx=best_idx)
            logger.write(step, display=False)

            # overwrite the exported weights with the latest ones
            _save_model_state(model, save_dir, args.task)

        else:
            logger.write_without_reset(step)



# Script entry point: train() is called without arguments; the
# @pyrallis.wrap() decorator on train is expected to build its config
# argument (presumably from the command line - see the pyrallis docs).
if __name__ == "__main__":
    train()
