import random
from functools import partial
from pathlib import Path

import torch
import wandb
from torch.utils.data import DataLoader

from input_args import single_parse_args
from transfer.datasets.single_task_dataset import SingleTaskDataset
from transfer.envs.gym import (
    get_gym_env,
    get_reacher_env,
)
from transfer.envs.metaworld import (
    MT50,
    get_single_env,
)
from transfer.evaluation.evaluate_episodes import evaluate_k_episodes
from transfer.models.decision_transformer import DecisionTransformer
from transfer.models.mlp import MLPModel
from transfer.training.seq_trainer import SequenceTrainer
from transfer.utils.utils import set_seed


def experiment(variant):
    """Train a sequence model on one task's offline dataset and save its weights.

    Builds the environment, dataset, model, optimizer, LR scheduler and
    trainer from ``variant``, runs ``variant["max_iters"]`` training
    iterations (optionally logging each iteration's outputs to Weights &
    Biases), and finally writes the model's ``state_dict`` to ``model.pth``
    in the current working directory.

    Args:
        variant: Mapping of experiment settings. Keys read here:

            * Required: ``env``, ``use_actions``, ``use_returns``,
              ``model_type`` (``"dt"`` or ``"mlp"``), ``dataset``,
              ``exp_prefix``, ``seed``, ``K``, ``batch_size``,
              ``num_eval_episodes``, ``dataset_path`` (``str`` or
              ``pathlib.Path`` directory containing ``<dataset>.pkl``),
              ``embed_dim``, ``n_layer``, ``warmup_steps``,
              ``learning_rate``, ``weight_decay``, ``max_iters``,
              ``num_steps_per_iter``.
            * Required only for ``model_type == "dt"``: ``n_head``,
              ``activation_function``, ``dropout``.
            * Optional: ``device`` (default ``"cuda"``), ``log_to_wandb``
              (default ``False``), ``pct_traj`` (default ``1.0``), ``mode``
              (default ``"normal"``; ``"delayed"`` is forwarded to the
              dataset as ``delayed=True``).

    Raises:
        NotImplementedError: If ``env`` or ``model_type`` is not supported.
        AssertionError: If the environment's observation/action sizes do not
            match the dataset's.
    """
    device = variant.get("device", "cuda")
    log_to_wandb = variant.get("log_to_wandb", False)

    env_name = variant["env"]
    use_actions = variant["use_actions"]
    use_returns = variant["use_returns"]
    model_type = variant["model_type"]
    dataset_name = variant["dataset"]

    exp_prefix = variant["exp_prefix"]
    group_name = f"{exp_prefix}-{dataset_name}"
    # NOTE: the random suffix is drawn before set_seed() is called, so it is
    # not pinned by the experiment seed (assuming set_seed seeds `random`).
    run_name = f"{group_name}-{random.randint(100_000, 999_999)}"

    set_seed(variant["seed"])

    # Per-environment settings: episode length cap, the return targets used
    # to condition evaluation, and the reward/return normalization scale.
    gym_env_targets = {
        "Hopper-v3": [3600, 1800],
        "HalfCheetah-v3": [12000, 6000],
        "Walker2d-v3": [5000, 2500],
    }
    if env_name in gym_env_targets:
        env = get_gym_env(env_name)
        max_ep_len = 1000
        env_targets = gym_env_targets[env_name]
        scale = 1000.0
    elif env_name == "reacher2d":
        env = get_reacher_env()
        max_ep_len = 100
        env_targets = [76, 40]
        scale = 10.0
    elif env_name in MT50.train_classes.keys() or env_name in MT50.test_classes.keys():
        env = get_single_env(env_name)
        max_ep_len = 200
        env_targets = [3000, 1500]
        scale = 1000.0
    else:
        raise NotImplementedError(f"Unsupported environment: {env_name!r}")

    if use_returns is False:
        # When targets are ignored there is no need for different evaluations.
        env_targets = env_targets[:1]

    state_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    K = variant["K"]
    batch_size = variant["batch_size"]
    num_eval_episodes = variant["num_eval_episodes"]
    pct_traj = variant.get("pct_traj", 1.0)
    mode = variant.get("mode", "normal")

    # Load the dataset. Coerce to Path so a plain-string directory also works.
    dataset_path = Path(variant["dataset_path"]) / f"{dataset_name}.pkl"
    dataset = SingleTaskDataset(
        dataset_path,
        context_size=K,
        pct_traj=pct_traj,
        normalize_inputs=True,
        delayed=(mode == "delayed"),
        max_ep_len=max_ep_len,
        scale=scale,
    )
    weighted_sampler = dataset.create_weighted_sampler()
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=weighted_sampler,
        drop_last=True,
        pin_memory=True,
    )

    assert state_dim == dataset.state_dim, (
        f"env state_dim {state_dim} != dataset state_dim {dataset.state_dim}"
    )
    assert act_dim == dataset.act_dim, (
        f"env act_dim {act_dim} != dataset act_dim {dataset.act_dim}"
    )

    if model_type == "dt":
        model = DecisionTransformer(
            state_dim=state_dim,
            act_dim=act_dim,
            max_length=K,
            max_ep_len=max_ep_len,
            hidden_size=variant["embed_dim"],
            n_layer=variant["n_layer"],
            n_head=variant["n_head"],
            n_inner=4 * variant["embed_dim"],
            activation_function=variant["activation_function"],
            n_positions=1024,
            resid_pdrop=variant["dropout"],
            attn_pdrop=variant["dropout"],
            use_actions=use_actions,
            use_returns=use_returns,
        )
    elif model_type == "mlp":
        model = MLPModel(
            state_dim=state_dim,
            act_dim=act_dim,
            max_length=K,
            hidden_size=variant["embed_dim"],
            n_layer=variant["n_layer"],
            use_actions=use_actions,
            use_returns=use_returns,
        )
    else:
        raise NotImplementedError(f"Unsupported model_type: {model_type!r}")

    model = model.to(device=device)

    warmup_steps = variant["warmup_steps"]
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=variant["learning_rate"],
        weight_decay=variant["weight_decay"],
    )

    def warmup_multiplier(steps):
        """Return the LR multiplier: linear ramp to 1 over ``warmup_steps``."""
        # A non-positive warmup means "no warmup"; without this guard
        # warmup_steps == 0 raises ZeroDivisionError as soon as LambdaLR is
        # constructed, and a negative value yields a negative learning rate.
        if warmup_steps <= 0:
            return 1.0
        return min((steps + 1) / warmup_steps, 1)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_multiplier)

    eval_episodes = partial(
        evaluate_k_episodes,
        env=env,
        num_eval_episodes=num_eval_episodes,
        max_ep_len=max_ep_len,
        state_dim=state_dim,
        act_dim=act_dim,
        scale=scale,
        state_mean=dataset.state_mean,
        state_std=dataset.state_std,
        mode=mode,
        device=device,
    )
    trainer = SequenceTrainer(
        model=model,
        optimizer=optimizer,
        dataloader=dataloader,
        device=device,
        scheduler=scheduler,
        # Only the action prediction is penalized (mean squared error); the
        # state and reward arguments are accepted but unused.
        loss_fn=lambda s_hat, a_hat, r_hat, s, a, r: torch.mean((a_hat - a) ** 2),
        # One evaluation entry per conditioning target.
        eval_fns=[eval_episodes(target_rew=tar) for tar in env_targets],
    )

    if log_to_wandb:
        wandb.init(name=run_name, group=group_name, entity="<PLACEHOLDER>", project="offlinerl_pretraining", config=variant)
        # wandb.watch(model)  # wandb has some bug

    # Iteration numbers passed to the trainer are 1-based.
    for iter_num in range(1, variant["max_iters"] + 1):
        outputs = trainer.train_iteration(num_steps=variant["num_steps_per_iter"], iter_num=iter_num, print_logs=True)
        if log_to_wandb:
            wandb.log(outputs)

    torch.save(model.state_dict(), "model.pth")


if __name__ == "__main__":
    # Script entry point: parse the command-line arguments and hand them to
    # experiment() as a plain dict.
    experiment(variant=vars(single_parse_args()))
