import random

from functools import partial

import numpy as np
import torch

from torch.utils.data import DataLoader

import wandb

from input_args import mt_parse_args
from transfer.datasets.multi_task_dataset import MultiTaskDataset
from transfer.datasets.single_task_dataset import SingleTaskDataset
from transfer.envs.metaworld import get_single_env
from transfer.evaluation.evaluate_episodes import evaluate_k_episodes
from transfer.models.decision_transformer import DecisionTransformer
from transfer.models.mlp import MLPModel
from transfer.training.seq_trainer import SequenceTrainer
from transfer.utils.utils import set_seed


def _load_multitask_dataset(variant, dataset_names, K, max_ep_len, scale, mode, add_one_hot):
    """Load one SingleTaskDataset per name and wrap them in a MultiTaskDataset.

    Dataset ``i`` is read from ``variant["dataset_path"] / f"{name}.pkl"`` and, when
    ``add_one_hot`` is set, tagged with one-hot index ``i`` out of ``len(dataset_names)``.
    """
    num_tasks = len(dataset_names)
    pct_traj = variant.get("pct_traj", 1.0)
    single_datasets = [
        SingleTaskDataset(
            variant["dataset_path"] / f"{name}.pkl",
            context_size=K,
            pct_traj=pct_traj,
            normalize_inputs=True,
            delayed=mode == "delayed",
            max_ep_len=max_ep_len,
            scale=scale,
            add_one_hot=add_one_hot,
            one_hot_idx=i,
            one_hot_len=num_tasks,
        )
        for i, name in enumerate(dataset_names)
    ]
    return MultiTaskDataset(single_datasets)


def _build_model(variant, model_type, state_dim, act_dim, max_ep_len):
    """Instantiate the network selected by ``model_type`` ("dt" or "mlp").

    Raises:
        NotImplementedError: if ``model_type`` is not one of the supported values.
    """
    K = variant["K"]
    use_actions = variant["use_actions"]
    use_returns = variant["use_returns"]

    if model_type == "dt":
        return DecisionTransformer(
            state_dim=state_dim,
            act_dim=act_dim,
            max_length=K,
            max_ep_len=max_ep_len,
            hidden_size=variant["embed_dim"],
            n_layer=variant["n_layer"],
            n_head=variant["n_head"],
            n_inner=4 * variant["embed_dim"],
            activation_function=variant["activation_function"],
            n_positions=1024,
            resid_pdrop=variant["dropout"],
            attn_pdrop=variant["dropout"],
            use_actions=use_actions,
            use_returns=use_returns,
        )
    if model_type == "mlp":
        return MLPModel(
            state_dim=state_dim,
            act_dim=act_dim,
            max_length=K,
            hidden_size=variant["embed_dim"],
            n_layer=variant["n_layer"],
            use_actions=use_actions,
            use_returns=use_returns,
        )
    raise NotImplementedError(f"Unknown model_type: {model_type!r}")


def experiment(variant):
    """Train one model on several tasks and periodically evaluate it on each task's env.

    Args:
        variant: dict of run settings (``vars()`` of the parsed CLI args when run as a
            script). Keys read include ``env`` and ``dataset`` (parallel lists of env /
            dataset names), ``dataset_path``, ``model_type``, ``use_actions``,
            ``use_returns``, ``add_one_hot``, ``exp_prefix``, ``seed``, ``K``,
            ``batch_size``, ``num_eval_episodes``, model hyperparameters, optimizer
            settings, ``max_iters`` and ``num_steps_per_iter``. Optional keys:
            ``device`` (default "cuda"), ``log_to_wandb`` (default False),
            ``pct_traj`` (default 1.0), ``mode`` (default "normal").

    Side effects:
        Writes the trained weights to ``model.pth`` in the current directory and,
        if ``log_to_wandb`` is set, creates and finishes a wandb run.

    Raises:
        ValueError: if ``env`` and ``dataset`` differ in length, or an env name is not
            a substring of its paired dataset name.
        NotImplementedError: for an unsupported ``model_type``.
    """
    device = variant.get("device", "cuda")
    log_to_wandb = variant.get("log_to_wandb", False)

    env_names = variant["env"]
    use_returns = variant["use_returns"]
    model_type = variant["model_type"]
    dataset_names = variant["dataset"]
    add_one_hot = variant["add_one_hot"]

    # Env i is evaluated with dataset i's normalization stats and shares its one-hot
    # index, so the lists must pair up exactly; zip() alone would silently drop extras.
    if len(env_names) != len(dataset_names):
        raise ValueError(
            f"Got {len(env_names)} envs but {len(dataset_names)} datasets; they must be paired one-to-one."
        )
    misaligned = [(e, d) for e, d in zip(env_names, dataset_names) if e not in d]
    if misaligned:
        raise ValueError(f"Env names must appear in their paired dataset names; mismatched pairs: {misaligned}")
    num_tasks = len(env_names)

    envs = [
        get_single_env(name, add_one_hot=add_one_hot, one_hot_idx=i, one_hot_len=num_tasks)
        for i, name in enumerate(env_names)
    ]
    max_ep_len = 200
    env_targets = [3000, 1500]
    scale = 1000.0

    group_name = f"{variant['exp_prefix']}-multitask-{num_tasks}"
    # NOTE: the suffix is drawn before set_seed(), so it is not controlled by the seed.
    exp_prefix = f"{group_name}-{random.randint(int(1e5), int(1e6) - 1)}"

    set_seed(variant["seed"])

    if use_returns is False:
        # Targets are ignored without return conditioning, so one evaluation suffices.
        env_targets = env_targets[:1]

    state_dim = envs[0].observation_space.shape[0]
    act_dim = envs[0].action_space.shape[0]

    K = variant["K"]
    num_eval_episodes = variant["num_eval_episodes"]
    mode = variant.get("mode", "normal")

    dataset = _load_multitask_dataset(variant, dataset_names, K, max_ep_len, scale, mode, add_one_hot)
    dataloader = DataLoader(
        dataset,
        batch_size=variant["batch_size"],
        sampler=dataset.create_weighted_sampler(),
        drop_last=True,
        pin_memory=True,
    )

    assert state_dim == dataset.state_dim, f"env state_dim {state_dim} != dataset {dataset.state_dim}"
    assert act_dim == dataset.act_dim, f"env act_dim {act_dim} != dataset {dataset.act_dim}"

    model = _build_model(variant, model_type, state_dim, act_dim, max_ep_len).to(device=device)

    # LambdaLR evaluates the factor at construction (step 0), so warmup_steps=0 would
    # divide by zero; clamping to 1 makes it mean "no warmup" (factor 1 from the start).
    warmup_steps = max(variant["warmup_steps"], 1)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=variant["learning_rate"],
        weight_decay=variant["weight_decay"],
    )
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: min((step + 1) / warmup_steps, 1))

    eval_episodes = partial(
        evaluate_k_episodes,
        num_eval_episodes=num_eval_episodes,
        max_ep_len=max_ep_len,
        state_dim=state_dim,
        act_dim=act_dim,
        scale=scale,
        mode=mode,
        device=device,
    )
    # One evaluator per (env, target) pair, each using its own task's normalization stats.
    eval_fns = [
        eval_episodes(target_rew=target, env=env, state_mean=task_data.state_mean, state_std=task_data.state_std)
        for env, task_data in zip(envs, dataset.datasets)
        for target in env_targets
    ]
    trainer = SequenceTrainer(
        model=model,
        optimizer=optimizer,
        dataloader=dataloader,
        device=device,
        scheduler=scheduler,
        # MSE between predicted and target actions; state/return predictions are unused.
        loss_fn=lambda s_hat, a_hat, r_hat, s, a, r: torch.mean((a_hat - a) ** 2),
        eval_fns=eval_fns,
    )

    if log_to_wandb:
        wandb.init(name=exp_prefix, group=group_name, entity="<PLACEHOLDER>", project="offlinerl_pretraining", config=variant)
        # NOTE: wandb.watch(model) was left disabled upstream ("wandb has some bug").

    for iter_num in range(1, variant["max_iters"] + 1):
        outputs = trainer.train_iteration(num_steps=variant["num_steps_per_iter"], iter_num=iter_num, print_logs=True)
        if log_to_wandb:
            wandb.log(outputs)

    torch.save(model.state_dict(), "model.pth")

    if log_to_wandb:
        wandb.finish()


if __name__ == "__main__":
    # Parse the multi-task CLI options and pass them to experiment() as a plain dict.
    experiment(variant=vars(mt_parse_args()))
