# Adapted from https://raw.githubusercontent.com/nikhilbarhate99/min-decision-transformer/refs/heads/master/scripts/train.py
# Original work Copyright (c) 2022 Nikhil Barhate
# Modifications Copyright (c) 2025 King.com Ltd

import argparse
import os
import random
import csv
from datetime import datetime

import numpy as np
import gymnasium as gym

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import envs # register custom envs
from misc.utils import get_device, TrajectoryDataset, evaluate_on_pointDirEnv
from models.transformer import DecisionTransformer


def _action_dim_from_space(action_space):
    """Return the per-step action dimension for a gymnasium action space.

    Spaces with an empty shape (e.g. ``Discrete``) store one value per step
    and count as one dimension; otherwise the first entry of
    ``action_space.shape`` is used.

    Raises:
        NotImplementedError: if ``action_space.shape`` is ``None`` (e.g. composite spaces).
    """
    shape = action_space.shape
    if shape is None:
        raise NotImplementedError("Could not determine action dim from eval env")
    return shape[0] if len(shape) > 0 else 1


def _next_batch(data_iters, data_loaders, ds_idx):
    """Return the next batch of dataset ``ds_idx``, restarting its loader once exhausted.

    ``data_iters[ds_idx]`` is replaced in place by a fresh iterator when the
    current pass over that dataset ends.
    """
    try:
        return next(data_iters[ds_idx])
    except StopIteration:
        data_iters[ds_idx] = iter(data_loaders[ds_idx])
        return next(data_iters[ds_idx])


def _sample_joint_batch(data_iters, data_loaders, device):
    """Draw one batch from every training dataset and concatenate them along dim 0.

    Returns a 9-tuple ``(timesteps, states, actions, returns_to_go, traj_mask,
    traj_prompt_timesteps, traj_prompt_states, traj_prompt_actions,
    traj_prompt_rtgs)`` moved to ``device``. Both return-to-go tensors are cast
    to float32 and given a trailing singleton dimension.
    """
    per_dataset = []
    for ds_idx in range(len(data_iters)):
        (timesteps, states, actions, returns_to_go, traj_mask,
         prompt_timesteps, prompt_states, prompt_actions,
         prompt_rtgs) = _next_batch(data_iters, data_loaders, ds_idx)
        per_dataset.append((
            timesteps.to(device),
            states.to(device),
            actions.to(device),
            returns_to_go.to(torch.float32).to(device).unsqueeze(dim=-1),
            traj_mask.to(device),
            prompt_timesteps.to(device),
            prompt_states.to(device),
            prompt_actions.to(device),
            prompt_rtgs.to(torch.float32).to(device).unsqueeze(dim=-1),
        ))
    # zip(*...) groups the i-th field of every dataset so each field is concatenated on its own.
    return tuple(torch.cat(field, dim=0) for field in zip(*per_dataset))


def train(args):
    """Train a (prompt-)Decision Transformer on offline trajectory datasets.

    Loads ``trajectories.pkl`` from every directory in ``args.dataset_dirs``.
    Training uses the datasets selected by ``args.train_dataset_subset_idxs``,
    drawing one batch of ``args.batch_size`` from each selected dataset per
    update. Every ``eval_every`` iterations, the model is evaluated on
    ``args.env`` once per return-to-go target.

    Losses and evaluation results are logged to TensorBoard and ``log.csv``.
    ``args.txt``, checkpoints and the state normalisation statistics go into
    a fresh run directory.

    Args:
        args: parsed command-line namespace (see the argparse options at the
            bottom of this file). ``eval_every``/``save_every`` are optional
            and default to 1 and 25 respectively.

    Returns:
        str: the run directory ``PDT_runs/<env>_<start time>_seed<seed>_<exp_str>``.

    Raises:
        ValueError: if no dataset directory is given, no dataset is selected
            for training, env and dataset dimensions disagree, or
            cross-entropy loss is requested for a non-discrete action space.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    rtg_scale = args.rtg_scale
    env_name = args.env
    traj_file_name = "trajectories"
    rtg_targets = [int(rtg) for rtg in args.rtg_targets]
    max_eval_ep_len = args.max_eval_ep_len  # max len of one episode
    num_eval_ep = args.num_eval_ep          # num of evaluation episodes

    batch_size = args.batch_size            # training batch size (per dataset)
    lr = args.lr                            # learning rate
    wt_decay = args.wt_decay                # weight decay
    warmup_steps = args.warmup_steps        # warmup steps for lr scheduler

    # total updates = max_train_iters x num_updates_per_iter
    max_train_iters = args.max_train_iters
    num_updates_per_iter = args.num_updates_per_iter
    # Checkpoint / evaluation intervals in training iterations. The defaults
    # equal the former hard-coded values, so namespaces without these fields still work.
    save_every = max(int(getattr(args, "save_every", 25)), 1)
    eval_every = max(int(getattr(args, "eval_every", 1)), 1)

    context_len = args.context_len      # K in decision transformer
    n_blocks = args.n_blocks            # num of transformer blocks
    embed_dim = args.embed_dim          # embedding (hidden) dim of transformer
    n_heads = args.n_heads              # num of transformer heads
    transformer_dropout_p = args.transformer_dropout_p  # dropout probability for the transformer layers
    mlp_dropout_p = args.pred_mlp_dropout_p             # dropout probability for the prediction MLPs
    prompt_len = args.traj_prompt_j * args.traj_prompt_h  # total steps in the trajectory prompt

    # Loop-invariant CLI lists, parsed once (argparse may deliver them as strings).
    use_state_dims = [int(dim) for dim in args.use_state_dims]
    train_dataset_subset_idxs = {int(idx) for idx in args.train_dataset_subset_idxs}
    state_dim = len(use_state_dims)

    dataset_paths = [f'{dataset_dir}/{traj_file_name}.pkl' for dataset_dir in args.dataset_dirs]
    if not dataset_paths:
        raise ValueError("no dataset directories given (--dataset_dirs)")

    device = get_device(use_cuda=args.cuda)

    start_time = datetime.now().replace(microsecond=0)
    start_time_str = start_time.strftime("%y-%m-%d_%H-%M-%S")

    exp_dir = f"PDT_runs/{args.env}_{start_time_str}_seed{args.seed}_{args.exp_str}"
    os.makedirs(exp_dir, exist_ok=True)

    with open(f"{exp_dir}/args.txt", 'w') as f:
        for arg_name, arg_value in vars(args).items():
            f.write(f"{arg_name}: {arg_value}\n")

    save_model_path = f"{exp_dir}/model.pt"
    log_csv_path = f"{exp_dir}/log.csv"

    writer = SummaryWriter(exp_dir)
    csv_file = None
    env = None
    try:
        # Line-buffered so each row reaches disk immediately; newline='' as the csv module requires.
        csv_file = open(log_csv_path, 'a', 1, newline='')
        csv_writer = csv.writer(csv_file)

        csv_header = ["duration", "num_updates", "action_loss", "state_loss", "rtg_loss"]
        for rtg in rtg_targets:
            csv_header.append(f"eval_avg_return_{rtg}")
            csv_header.append(f"eval_avg_len_{rtg}")
        csv_writer.writerow(csv_header)

        print("=" * 60, flush=True)
        print("start time: " + start_time_str, flush=True)
        print("=" * 60, flush=True)

        print("device set to: " + str(device), flush=True)
        print("model save path: " + save_model_path, flush=True)
        print("log csv save path: " + log_csv_path, flush=True)

        # All datasets are loaded (they are needed for evaluation prompts and
        # statistics), but only the selected subset gets a training loader.
        data_loaders = []
        data_iters = []
        all_datasets = []
        for ds_idx, dataset_path in enumerate(dataset_paths):
            print("loading dataset: " + dataset_path, flush=True)
            traj_dataset = TrajectoryDataset(
                dataset_path,
                context_len,
                rtg_scale,
                traj_prompt_j=args.traj_prompt_j,
                traj_prompt_h=args.traj_prompt_h,
                use_state_dims=list(use_state_dims),
                use_sparse_reward=args.pdt_use_sparse_reward,
                use_every_nth_traj=args.dataset_every_nth_traj,
            )
            all_datasets.append(traj_dataset)

            if ds_idx in train_dataset_subset_idxs:
                traj_data_loader = DataLoader(
                    traj_dataset,
                    batch_size=batch_size,
                    shuffle=True,
                    pin_memory=True,
                    drop_last=True,
                )
                data_loaders.append(traj_data_loader)
                data_iters.append(iter(traj_data_loader))
        print("Done loading datasets", flush=True)

        num_datasets = len(data_loaders)
        if num_datasets == 0:
            raise ValueError(
                f"train_dataset_subset_idxs {sorted(train_dataset_subset_idxs)} selects none of the "
                f"{len(dataset_paths)} datasets"
            )

        # Shared normalisation statistics: an unweighted average of the
        # per-dataset means and stds (not pooled statistics of the union).
        per_dataset_stats = [dataset.get_state_stats() for dataset in all_datasets]
        state_mean = np.mean(np.array([mean for mean, _ in per_dataset_stats]), axis=0)
        state_std = np.mean(np.array([std for _, std in per_dataset_stats]), axis=0)

        if args.norm_obs:
            for dataset in all_datasets:
                dataset.state_mean = state_mean
                dataset.state_std = state_std
                dataset.normalize_states()
        # NOTE(review): state_mean/state_std are saved and handed to evaluation even when
        # norm_obs is off - confirm evaluate_on_pointDirEnv does not normalise in that case.

        dataset_state_dim = all_datasets[0].trajectories[0]["observations"][0].shape[0]
        dataset_act_dim = all_datasets[0].trajectories[0]["actions"][0].shape[0]

        np.save(f"{exp_dir}/state_mean.npy", state_mean)
        np.save(f"{exp_dir}/state_std.npy", state_std)

        env = gym.make(env_name)
        env.reset(seed=args.seed)
        act_dim = _action_dim_from_space(env.action_space)

        if state_dim != dataset_state_dim:
            raise ValueError(
                f"eval env and dataset state dim mismatch ({state_dim} vs {dataset_state_dim})"
            )
        if act_dim != dataset_act_dim:
            raise ValueError(
                f"eval env and dataset action dim mismatch ({act_dim} vs {dataset_act_dim})"
            )

        use_cross_entropy = args.loss_fn == "cross_entropy"
        if use_cross_entropy:
            # Honour an explicit ``args.action_space`` if a caller set one; the
            # CLI defines no such option, so otherwise inspect the eval env.
            declared_space = getattr(args, "action_space", None)
            if declared_space is not None:
                is_discrete = declared_space == "discrete"
            else:
                is_discrete = isinstance(env.action_space, gym.spaces.Discrete)
            if not is_discrete:
                raise ValueError("cross entropy loss only for discrete action space")

        model = DecisionTransformer(
            state_dim=state_dim,
            act_dim=act_dim,
            n_blocks=n_blocks,
            h_dim=embed_dim,
            context_len=context_len,
            n_heads=n_heads,
            transformer_drop_p=transformer_dropout_p,
            mlp_drop_p=mlp_dropout_p,
            mlp_num_layers=args.pred_mlp_num_layers,
            which_model=args.model,
            traj_prompt_j=args.traj_prompt_j,
            traj_prompt_h=args.traj_prompt_h,
        ).to(device)

        if args.load_path:
            # NOTE: weights_only=False unpickles arbitrary objects - only load trusted checkpoints.
            model.load_state_dict(torch.load(args.load_path, weights_only=False, map_location=torch.device('cpu')))

        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wt_decay)
        # Linear warm-up to the base lr; clamped to >= 1 so warmup_steps=0 means "no warm-up".
        warmup = max(warmup_steps, 1)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: min((step + 1) / warmup, 1))

        total_batch = num_datasets * batch_size
        total_updates = 0

        for i_train_iter in range(max_train_iters):
            print("starting training iteration: " + str(i_train_iter), flush=True)

            log_action_losses = []
            log_state_losses = []
            log_rtg_losses = []
            model.train()

            for _ in range(num_updates_per_iter):
                (timesteps, states, actions, returns_to_go, traj_mask,
                 traj_prompt_timesteps, traj_prompt_states, traj_prompt_actions,
                 traj_prompt_rtgs) = _sample_joint_batch(data_iters, data_loaders, device)

                assert timesteps.shape == (total_batch, context_len)
                assert states.shape == (total_batch, context_len, state_dim)
                assert actions.shape == (total_batch, context_len, act_dim)
                assert returns_to_go.shape == (total_batch, context_len, 1)
                assert traj_mask.shape == (total_batch, context_len)
                assert traj_prompt_timesteps.shape == (total_batch, prompt_len)
                assert traj_prompt_states.shape == (total_batch, prompt_len, state_dim)
                assert traj_prompt_actions.shape == (total_batch, prompt_len, act_dim)
                assert traj_prompt_rtgs.shape == (total_batch, prompt_len, 1)

                # Detached copies of the inputs serve as prediction targets.
                action_target = actions.detach().clone()
                state_target = states.detach().clone()
                rtg_pred_target = returns_to_go.detach().clone()

                # The last three outputs are predictions for the prompt tokens; they are not trained on.
                state_preds, action_preds, return_preds, action_logits, _, _, _ = model(
                    timesteps=timesteps,
                    states=states,
                    actions=actions,
                    returns_to_go=returns_to_go,
                    traj_prompt_timesteps=traj_prompt_timesteps,
                    traj_prompt_states=traj_prompt_states,
                    traj_prompt_actions=traj_prompt_actions,
                    traj_prompt_rtgs=traj_prompt_rtgs,
                )

                # only consider non padded elements
                valid = traj_mask.reshape(-1) > 0

                action_preds = action_preds.reshape(-1, act_dim)[valid]
                action_targets = action_target.reshape(-1, act_dim)[valid]

                state_preds = state_preds.reshape(-1, state_dim)[valid]
                state_targets = state_target.reshape(-1, state_dim)[valid]

                rtg_preds = return_preds.reshape(-1, 1)[valid]
                rtg_pred_targets = rtg_pred_target.reshape(-1, 1)[valid]

                # main loss is for predicting the actions in sequence
                if use_cross_entropy:
                    # Assumes action_logits has one logit per discrete action on its last
                    # axis and the targets hold integer action indices (act_dim == 1) - confirm
                    # against DecisionTransformer.
                    n_actions = action_logits.shape[-1]
                    masked_logits = action_logits.reshape(-1, n_actions)[valid]
                    class_targets = action_targets.squeeze(-1).long()
                    action_loss = F.cross_entropy(masked_logits, class_targets, reduction='mean')
                else:
                    action_loss = F.mse_loss(action_preds, action_targets, reduction='mean')
                loss_terms = [action_loss]
                log_action_losses.append(action_loss.item())

                # optional loss for rtgs in sequence
                if args.use_rtg_loss:
                    rtg_loss = F.mse_loss(rtg_preds, rtg_pred_targets, reduction='mean')
                    loss_terms.append(rtg_loss)
                    log_rtg_losses.append(rtg_loss.item())

                # optional loss for states in sequence
                if args.use_state_loss:
                    state_loss = F.mse_loss(state_preds, state_targets, reduction='mean')
                    loss_terms.append(state_loss)
                    log_state_losses.append(state_loss.item())

                # Average the active terms (sum / count) so each is weighted equally.
                loss = torch.stack(loss_terms).mean()

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                optimizer.step()
                scheduler.step()

            if i_train_iter % save_every == 0:
                i_save_path = f"{exp_dir}/model_{i_train_iter}.pt"
                print("saving current model at: " + i_save_path, flush=True)
                torch.save(model.state_dict(), i_save_path)

            eval_dict = {}
            if i_train_iter % eval_every == 0:
                for rtg_target in rtg_targets:
                    results = evaluate_on_pointDirEnv(
                        model=model,
                        device=device,
                        context_len=context_len,
                        env=env,
                        rtg_target=rtg_target,
                        rtg_scale=rtg_scale,
                        num_eval_ep=num_eval_ep,
                        max_test_ep_len=max_eval_ep_len,
                        state_mean=state_mean,
                        state_std=state_std,
                        state_dim=state_dim,
                        act_dim=act_dim,
                        env_id=env_name,
                        n_traj_prompt_segments=args.traj_prompt_j,
                        traj_prompt_seg_len=args.traj_prompt_h,
                        train_datasets=all_datasets,
                        use_state_dims=use_state_dims,
                    )
                    eval_dict[f"{rtg_target}_return"] = results["eval/avg_reward"]
                    eval_dict[f"{rtg_target}_len"] = results["eval/avg_ep_len"]
            else:
                print(f"skipping DT rtg evaluation after epoch {i_train_iter}", flush=True)
                # -1 placeholders keep the CSV/TensorBoard columns aligned.
                for rtg_target in rtg_targets:
                    eval_dict[f"{rtg_target}_return"] = -1
                    eval_dict[f"{rtg_target}_len"] = -1

            mean_action_loss = np.mean(log_action_losses) if log_action_losses else 0
            mean_state_loss = np.mean(log_state_losses) if log_state_losses else 0
            mean_rtg_loss = np.mean(log_rtg_losses) if log_rtg_losses else 0
            time_elapsed = str(datetime.now().replace(microsecond=0) - start_time)

            total_updates += num_updates_per_iter

            log_lines = [
                "=" * 60,
                "train iter: " + str(i_train_iter),
                "time elapsed: " + time_elapsed,
                "num of updates: " + str(total_updates),
                "action loss: " + format(mean_action_loss, ".5f"),
                "state_loss: " + format(mean_state_loss, ".5f"),
                "rtg_loss: " + format(mean_rtg_loss, ".5f"),
            ]
            writer.add_scalar("dt/action_loss", mean_action_loss, total_updates)
            writer.add_scalar("dt/rtg_loss", mean_rtg_loss, total_updates)
            writer.add_scalar("dt/state_loss", mean_state_loss, total_updates)
            for rtg_target in rtg_targets:
                eval_return = eval_dict[f"{rtg_target}_return"]
                eval_len = eval_dict[f"{rtg_target}_len"]
                log_lines.append(f"eval avg return ({rtg_target}): " + format(eval_return, ".5f"))
                writer.add_scalar(f"dt/eval_rtg{rtg_target}", eval_return, total_updates)
                log_lines.append(f"eval avg ep len ({rtg_target}): " + str(eval_len))
                writer.add_scalar(f"dt/eval_rtg{rtg_target}_len", eval_len, total_updates)

            print("\n".join(log_lines) + "\n", flush=True)

            log_data = [time_elapsed, total_updates, mean_action_loss, mean_state_loss, mean_rtg_loss]
            for rtg in rtg_targets:
                log_data.append(eval_dict[f"{rtg}_return"])
                log_data.append(eval_dict[f"{rtg}_len"])
            csv_writer.writerow(log_data)

            print("saving current model at: " + save_model_path, flush=True)
            torch.save(model.state_dict(), save_model_path)
    finally:
        # Release the log file, flush pending TensorBoard events and close the env even on error.
        if csv_file is not None:
            csv_file.close()
        writer.close()
        if env is not None:
            env.close()

    print("=" * 60, flush=True)
    print(f"finished training after {max_train_iters} epochs!", flush=True)
    print("=" * 60, flush=True)
    end_time = datetime.now().replace(microsecond=0)
    time_elapsed = str(end_time - start_time)
    end_time_str = end_time.strftime("%y-%m-%d-%H-%M-%S")
    print("started training at: " + start_time_str, flush=True)
    print("finished training at: " + end_time_str, flush=True)
    print("total training time: " + time_elapsed, flush=True)
    print("saved last updated model at: " + save_model_path, flush=True)
    print("=" * 60, flush=True)

    return exp_dir


def build_parser():
    """Return the command-line parser for PDT training.

    Three options are on by default: ``--norm_obs``, ``--pdt_use_sparse_reward``
    and ``--cuda``. Each has a ``--no_*`` counterpart that stores ``False`` to
    the same destination, because a ``store_true`` flag whose default is
    already ``True`` can never switch the option off.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--exp_str', type=str, default="default")
    parser.add_argument('--seed', type=int, default=0)

    parser.add_argument('--model', type=str, default='traj_pdt', help="DT for decision transformer, traj_pdt for PDT")
    parser.add_argument("--load_path", type=str, default="", help="Path to load model from")
    parser.add_argument('--traj_prompt_j', type=int, default=1, help="The number J of episode segments in the trajectory prompt")
    parser.add_argument('--traj_prompt_h', type=int, default=3, help="The number H of steps per episode segment in the trajectory prompt")
    parser.add_argument('--state_loss', action='store_true', dest="use_state_loss", help="Whether to include loss for state_prediction", default=False)
    parser.add_argument('--rtg_loss', action='store_true', dest='use_rtg_loss', help="Whether to include loss for rtg_prediction", default=False)
    parser.add_argument('--norm_obs', action='store_true', dest='norm_obs', default=True, help="Whether to normalize the observations in the dataset")
    parser.add_argument('--no_norm_obs', action='store_false', dest='norm_obs', help="Do not normalize the observations in the dataset")

    parser.add_argument('--env', type=str, default='CircleStopEnv-randomAngle_randomRad-v0')
    parser.add_argument('--rtg_targets', nargs='+', type=int, default=[0, -10], help="The return targets to evaluate the DT on")
    parser.add_argument('--rtg_scale', type=int, default=1, help="Divide the returns to go by this factor")
    parser.add_argument('--num_eval_ep', type=int, default=10, help="Number of episodes to evaluate the DT after each training iteration")

    parser.add_argument('--loss_fn', type=str, default='mse', choices=['mse', 'cross_entropy'], help="Action loss; 'cross_entropy' requires a discrete action space")
    parser.add_argument('--max_eval_ep_len', type=int, default=100)
    parser.add_argument('--dataset_dirs', nargs='+', default=[
        "PPO_runs/CircleStopEnv-angle0.0_radius0.9-v0_2_25-02-12_20-50-23_seed2",
        "PPO_runs/CircleStopEnv-angle0.0_radius1.9-v0_2_25-02-12_20-51-28_seed2",
        "PPO_runs/CircleStopEnv-angle0.0_radius2.9-v0_2_25-02-12_20-52-27_seed2",
    ], help="The per-task datasets to train the PDT on")
    parser.add_argument('--train_dataset_subset_idxs', nargs='+', type=int, default=[0, 1, 2], help="Indices of datasets to use for training. This can be used to train on a subset but evaluate on all tasks.")

    parser.add_argument('--use_state_dims', nargs='+', type=int, default=[0, 1], help="The indices of the state dimensions to use with the PDT. Can be used to train on a subset of the state dimensions, to hide certain information from the model.")
    parser.add_argument('--pdt_use_sparse_reward', action='store_true', dest='pdt_use_sparse_reward', default=True, help="If true, use the 'sparse_reward' field from info to overwrite the reward in the dataset")
    parser.add_argument('--no_pdt_use_sparse_reward', action='store_false', dest='pdt_use_sparse_reward', help="Keep the dataset rewards instead of overwriting them with the 'sparse_reward' info field")
    parser.add_argument('--dataset_every_nth_traj', type=int, default=1, help="Only use every n-th trajectory in the dataset, for making smaller datasets if needed")

    parser.add_argument('--context_len', type=int, default=5)
    parser.add_argument('--n_blocks', type=int, default=3)
    parser.add_argument('--embed_dim', type=int, default=128)
    parser.add_argument('--n_heads', type=int, default=1)
    parser.add_argument('--transformer_dropout_p', type=float, default=0.1)
    parser.add_argument('--pred_mlp_dropout_p', type=float, default=0.0)
    parser.add_argument('--pred_mlp_num_layers', type=int, default=2)

    parser.add_argument('--batch_size', type=int, default=64)  # per dataset
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--wt_decay', type=float, default=1e-4)
    parser.add_argument('--warmup_steps', type=int, default=10000)

    parser.add_argument('--max_train_iters', type=int, default=100)
    parser.add_argument('--num_updates_per_iter', type=int, default=100)
    parser.add_argument('--save_every', type=int, default=25, help="Save an extra numbered checkpoint every N training iterations")
    parser.add_argument('--eval_every', type=int, default=1, help="Evaluate on the env every N training iterations (skipped iterations log -1)")

    parser.add_argument('--cuda', action='store_true', default=True)
    parser.add_argument('--no_cuda', action='store_false', dest='cuda', help="Pass use_cuda=False to get_device")

    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    exp_dir = train(args)

    # EXTREMELY IMPORTANT, last print statement should be experiment dir for capturing it from bash
    print(exp_dir)
