import os
import uuid
import types
import copy
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import gym as gym_org
import numpy as np
import pyrallis
import torch
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger
from fsrl.utils import TensorboardLogger
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa

from examples.configs.sa_encoder_configs import SAEncoder_DEFAULT_CONFIG, SAEncoderTrainConfig
from osrl.algorithms import State_AE, Action_AE, inverse_dynamics_model, ActionAETrainer, StateAETrainer
from osrl.common import SequenceDataset, TransitionDataset
from osrl.common.exp_util import auto_name, seed_all


@pyrallis.wrap()
def train(args: SAEncoderTrainConfig):
    """Train the state auto-encoder and, when required, the action auto-encoder.

    The routine runs in two phases:

    1. If the environment's action dimension differs from
       ``args.action_encode_dim``, an ``Action_AE`` is trained for
       ``args.action_epoch_num`` epochs and the copy with the lowest
       evaluation loss is kept.
    2. A ``State_AE`` is trained for ``args.state_epoch_num`` epochs through
       ``StateAETrainer``, which is also handed an ``inverse_dynamics_model``.
       When phase 1 ran, actions are first encoded (without gradients) by the
       best action encoder from that phase.

    After every epoch a checkpoint is saved; an additional checkpoint with the
    suffix ``"best"`` is saved whenever the evaluation loss improves.

    Args:
        args: Training configuration. Fields that differ from the
            ``SAEncoderTrainConfig`` defaults override the task-specific
            defaults looked up in ``SAEncoder_DEFAULT_CONFIG[args.task]``.

    Raises:
        RuntimeError: If the action encoder is required but no evaluation
            produced a usable loss (e.g. ``action_epoch_num`` is 0 or every
            evaluation loss was NaN), so there is no encoder to use in phase 2.
    """
    # update config: task defaults first, then every value the user changed
    cfg, old_cfg = asdict(args), asdict(SAEncoderTrainConfig())
    differing_values = {
        key: value for key, value in cfg.items() if value != old_cfg[key]
    }
    cfg = asdict(SAEncoder_DEFAULT_CONFIG[args.task]())
    cfg.update(differing_values)
    args = types.SimpleNamespace(**cfg)

    # setup logger
    default_cfg = asdict(SAEncoder_DEFAULT_CONFIG[args.task]())
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
    if args.group is None:
        args.group = args.task + "-cost-" + str(int(args.cost_limit))
    if args.logdir is not None:
        args.logdir = os.path.join(args.logdir, args.group, args.name)
    # logger = WandbLogger(cfg, args.project, args.group, args.name, args.logdir)
    logger = TensorboardLogger(args.logdir, log_txt=True, name=args.name+"_state_AE")
    logger.save_config(cfg, verbose=args.verbose)

    # set seed
    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    # initialize environment ("Metadrive" tasks go through the original gym API)
    if "Metadrive" in args.task:
        env = gym_org.make(args.task)
    else:
        env = gym.make(args.task)

    # pre-process offline dataset
    data = env.get_dataset()
    env.set_target_cost(args.cost_limit)

    cbins, rbins, max_npb, min_npb = None, None, None, None
    if args.density != 1.0:
        density_cfg = DENSITY_CFG[args.task + "_density" + str(args.density)]
        cbins = density_cfg["cbins"]
        rbins = density_cfg["rbins"]
        max_npb = density_cfg["max_npb"]
        min_npb = density_cfg["min_npb"]
    data = env.pre_process_data(data,
                                args.outliers_percent,
                                args.noise_scale,
                                args.inpaint_ranges,
                                args.epsilon,
                                args.density,
                                cbins=cbins,
                                rbins=rbins,
                                max_npb=max_npb,
                                min_npb=min_npb)

    # wrapper
    env = wrap_env(
        env=env,
        reward_scale=args.reward_scale,
    )
    env = OfflineEnvWrapper(env)

    # models
    state_encoder = State_AE(
        state_dim=env.observation_space.shape[0],
        encode_dim=args.state_encode_dim,
        hidden_sizes=args.state_encoder_hidden_sizes
    ).to(args.device)
    idm = inverse_dynamics_model(
        state_dim=args.state_encode_dim,
        action_dim=args.action_encode_dim,
        hidden_sizes=args.inverse_dynamics_model_hidden_sizes
    ).to(args.device)

    # An action auto-encoder is only needed when the raw action dimension does
    # not already match the requested encoding dimension.
    train_action_encoder = env.action_space.shape[0] != args.action_encode_dim
    if train_action_encoder:
        action_logger = TensorboardLogger(args.logdir, log_txt=True, name=args.name+"_action_AE")
        action_logger.save_config(cfg, verbose=args.verbose)
        action_encoder = Action_AE(
            action_dim=env.action_space.shape[0],
            encode_dim=args.action_encode_dim,
            hidden_sizes=args.action_encoder_hidden_sizes
        ).to(args.device)

    # checkpointing
    def checkpoint_fn_state():
        return {"model_state": state_encoder.state_dict()}

    logger.setup_checkpoint_fn(checkpoint_fn_state)
    if train_action_encoder:
        def checkpoint_fn_action():
            return {"model_state": action_encoder.state_dict()}

        action_logger.setup_checkpoint_fn(checkpoint_fn_action)

    # trainer
    state_trainer = StateAETrainer(
        state_encoder,
        idm,
        logger=logger,
        learning_rate=args.learning_rate,
        device=args.device,
        idm_loss_weight=args.idm_loss_weight
    )
    if train_action_encoder:
        action_trainer = ActionAETrainer(
            action_encoder,
            logger=action_logger,
            learning_rate=args.learning_rate,
            device=args.device,
            add_noise=args.add_noise,
            noise_scale=args.noise_scale
        )

    dataset = TransitionDataset(data,
                                reward_scale=args.reward_scale,
                                cost_scale=args.cost_scale)

    def make_loader(batch_size):
        """Build a DataLoader over ``dataset`` with the shared loader settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            pin_memory=True,
            num_workers=args.num_workers,
        )

    def to_device(batch):
        """Move every tensor of a batch to the configured device."""
        return [b.to(args.device) for b in batch]

    # NOTE(review): the iterators below are advanced with next() and never
    # re-created, so the dataset is assumed to yield batches indefinitely --
    # TODO confirm against TransitionDataset.
    state_trainloader_iter = iter(make_loader(args.batch_size))
    if train_action_encoder:
        action_trainloader_iter = iter(make_loader(args.batch_size))
    # evaluation uses a single, much larger batch per epoch
    testloader_iter = iter(make_loader(args.batch_size*100))

    # ---- phase 1: action auto-encoder -------------------------------------
    best_action_encoder = None
    if train_action_encoder:
        best_loss = np.inf
        # 0 means "no best epoch yet"; keeps the logging below well-defined
        # even if the first evaluation loss is NaN.
        best_idx = 0
        for epoch in range(args.action_epoch_num):
            for _ in trange(args.steps_per_epoch, desc="Training"):
                _, _, actions, _, _, _ = to_device(next(action_trainloader_iter))
                action_trainer.train_one_step(actions)
            action_logger.save_checkpoint()
            _, _, actions, _, _, _ = to_device(next(testloader_iter))
            eval_loss = action_trainer.eval_one_step(actions)
            if eval_loss < best_loss:
                best_loss = eval_loss
                best_idx = epoch+1
                # snapshot the weights; training continues on the original
                best_action_encoder = copy.deepcopy(action_trainer.model)
                action_logger.save_checkpoint(suffix="best")
            action_logger.store(tab="train", best_idx=best_idx)
            action_logger.write(epoch+1, display=False)

        if best_action_encoder is None:
            # Fail with an explicit message instead of an AttributeError on None.
            raise RuntimeError(
                "No action encoder is available for state-encoder training: "
                "action_epoch_num is 0 or no evaluation loss improved on inf "
                "(e.g. the loss was NaN)."
            )
        best_action_encoder.eval()

    def encode_actions(actions):
        """Return actions in the space the state trainer expects."""
        if not train_action_encoder:
            return actions
        with torch.no_grad():
            return best_action_encoder.encode(actions)

    # ---- phase 2: state auto-encoder --------------------------------------
    best_loss = np.inf
    best_idx = 0  # do not carry over the best epoch of the action phase
    for epoch in range(args.state_epoch_num):
        for _ in trange(args.steps_per_epoch, desc="Training"):
            observations, next_observations, actions, _, _, _ = to_device(
                next(state_trainloader_iter)
            )
            state_trainer.train_one_step(
                observations, encode_actions(actions), next_observations
            )
        logger.save_checkpoint()
        observations, next_observations, actions, _, _, _ = to_device(
            next(testloader_iter)
        )
        eval_loss = state_trainer.eval_one_step(
            observations, encode_actions(actions), next_observations
        )
        if eval_loss < best_loss:
            best_loss = eval_loss
            best_idx = epoch+1
            logger.save_checkpoint(suffix="best")
        logger.store(tab="train", best_idx=best_idx)
        logger.write(epoch+1, display=False)


# Script entry point. train() is called without arguments; the
# @pyrallis.wrap() decorator on train is expected to supply the
# SAEncoderTrainConfig (presumably parsed from the command line -- confirm
# against the pyrallis documentation).
if __name__ == "__main__":
    train()
