# Copyright (C) king.com Ltd 2025
# License: Apache 2.0
from ast import parse
import gym
import numpy as np
import torch
import wandb

import argparse
import random
import itertools

from cql.cql import CQL
from cql.trainer import CQLTrainer, get_batch_cql
from prompt_dt.prompt_utils import get_env_list
from prompt_dt.prompt_utils import process_total_data_mean, load_data_prompt, process_info

from collections import namedtuple
import json, pickle, os


def experiment_mix_env(
        exp_prefix,
        variant,
):
    """Train a CQL agent on a single task of the selected environment family.

    All settings are read from ``variant`` (the parsed command-line options as
    a dict), so the function does not depend on a module-level ``args`` object
    and can be called from other modules.

    Args:
        exp_prefix: Prefix used to build the wandb group and run names.
        variant: Dict of experiment settings. Keys read here: ``device``,
            ``log_to_wandb``, ``seed``, ``env``, ``train_task_idx``,
            ``eval_train_task``, ``batch_size``, ``dataset_mode``,
            ``train_prompt_mode``, ``average_state_mean``, ``extra_exp_str``,
            ``policy_BC_steps``, ``evaluation``, ``max_iters``,
            ``num_steps_per_iter``, ``train_eval_interval``,
            ``num_eval_episodes``, ``save_interval`` and, optionally,
            ``pct_traj`` (default ``1.``) and ``mode`` (default ``'normal'``).
            When ``average_state_mean`` is truthy the keys
            ``total_state_mean`` and ``total_state_std`` are written into it.

    Raises:
        KeyError: If ``variant['env']`` has no entry in the config table.
        NotImplementedError: If ``variant['evaluation']`` is truthy.
    """
    device = variant['device']
    log_to_wandb = variant['log_to_wandb']
    seed = variant['seed']
    env_family = variant['env']

    # Seed every random number generator used by the run.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    ######
    # construct train and test environments
    ######

    cur_dir = os.getcwd()
    config_save_path = os.path.join(cur_dir, 'config')
    data_save_path = os.path.join(cur_dir, 'data')
    # The trailing separator is kept on purpose: the wandb run name is
    # appended to this string further down.
    save_path = os.path.join(cur_dir, 'model_saved/')
    os.makedirs(save_path, exist_ok=True)

    config_path_dict = {
        'cheetah_vel': "cheetah_vel/cheetah_vel_40.json",
        'cheetah_dir': "cheetah_dir/cheetah_dir_2.json",
        'ant_dir': "ant_dir/ant_dir_50.json",
        'ML1-pick-place-v2': "ML1-pick-place-v2/ML1_pick_place.json",
    }

    task_config_path = os.path.join(config_save_path, config_path_dict[env_family])
    with open(task_config_path, 'r') as f:
        # Turn every JSON object into a namedtuple so fields read as attributes.
        task_config = json.load(f, object_hook=lambda d: namedtuple('X', d.keys())(*d.values()))
    train_env_name_list_all = [f'{env_family}-{task_ind}' for task_ind in task_config.train_tasks]
    test_env_name_list_all = [f'{env_family}-{task_ind}' for task_ind in task_config.test_tasks]

    # CQL just trains on a single task
    task_idx = variant["train_task_idx"]
    if variant["eval_train_task"]:
        candidate_env_names = train_env_name_list_all
    else:
        candidate_env_names = test_env_name_list_all
    env_name_list = [candidate_env_names[task_idx]]

    # training envs
    info, env_list = get_env_list(env_name_list, config_save_path, device, seed)

    batch_size = variant['batch_size']
    pct_traj = variant.get('pct_traj', 1.)
    mode = variant.get('mode', 'normal')
    dataset_mode = variant['dataset_mode']
    train_prompt_mode = variant['train_prompt_mode']

    # load train dataset; the helper expects an argparse-style namespace, so
    # rebuild one from `variant` instead of relying on a global `args`.
    trajectories_list, _ = load_data_prompt(
        env_name_list, data_save_path, dataset_mode, train_prompt_mode,
        argparse.Namespace(**variant),
    )

    # change to total train trajectory
    if variant['average_state_mean']:
        total_traj_list = list(itertools.chain.from_iterable(trajectories_list))
        total_state_mean, total_state_std = process_total_data_mean(total_traj_list, mode)
        variant['total_state_mean'] = total_state_mean
        variant['total_state_std'] = total_state_std

    # process train info
    info = process_info(env_name_list, trajectories_list, info, mode, dataset_mode, pct_traj, variant)

    # Build the wandb group name and the per-seed run name.
    num_env = len(env_name_list)
    group_name = f'{exp_prefix}-{env_family}-{task_idx}-{num_env}-Env-{dataset_mode}'
    exp_prefix = f'{group_name}-seed-{seed}'
    if variant['extra_exp_str']:
        exp_prefix = f'{exp_prefix}-{variant["extra_exp_str"]}'

    state_dim = env_list[0].observation_space.shape[0]
    act_dim = env_list[0].action_space.shape[0]

    model = CQL(
        state_dim=state_dim,
        act_dim=act_dim,
        device=device,
        policy_bc_loss_steps=variant['policy_BC_steps'],
    )
    model = model.to(device=device)

    # Sanity checks: exactly one task / environment is handled by this script.
    assert len(env_name_list) == 1
    assert len(env_list) == 1
    env_name = env_name_list[0]
    env = env_list[0]
    info = info[env_name]

    trainer = CQLTrainer(
        model=model,
        get_prompt_batch=get_batch_cql(trajectories=trajectories_list[0], info=info, batch_size=batch_size),
    )

    if variant['evaluation']:
        raise NotImplementedError("CQL shouldn't be evaluated on other envs, only the training task.")

    ######
    # start training
    ######
    if log_to_wandb:
        wandb.init(
            name=exp_prefix,
            group=group_name,
            project='prompt-decision-transformer-CQL',
            config=variant
        )
        # Checkpoints of a logged run go into a sub-folder named after the run.
        save_path += wandb.run.name
        os.makedirs(save_path, exist_ok=True)

    max_iters = variant['max_iters']
    num_steps_per_iter = variant['num_steps_per_iter']
    train_eval_interval = variant['train_eval_interval']
    num_eval_episodes = variant['num_eval_episodes']
    save_interval = variant['save_interval']

    for iteration in range(max_iters):
        outputs = trainer.pure_train_iteration_mix(
            num_steps=num_steps_per_iter,
        )

        # Periodically evaluate on the task being trained on.
        if iteration % train_eval_interval == 0:
            train_eval_logs = trainer.eval_iteration_multienv(
                env_name=env_name, env=env, info=info, n_episodes=num_eval_episodes)
            outputs.update(train_eval_logs)

        # Periodically checkpoint the model.
        if iteration % save_interval == 0:
            trainer.save_model(
                env_name=env_family,
                postfix='iter_' + str(iteration),
                folder=save_path)

        outputs.update({"global_step": iteration})

        if log_to_wandb:
            wandb.log(outputs)

    # Final checkpoint, tagged with the last completed iteration. Skipped when
    # no iteration ran, since there is no iteration index to tag it with.
    if max_iters > 0:
        trainer.save_model(env_name=env_family, postfix='iter_' + str(max_iters - 1), folder=save_path)


def str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``0`` to a bool.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the
    raw string, so every non-empty value (including ``"False"``) is True.

    Args:
        value: The raw option value; an actual ``bool`` is returned unchanged.

    Returns:
        The parsed boolean. The empty string maps to False, as it did with
        ``type=bool``.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str,  default='ML1-pick-place-v2')  # ['cheetah_vel', 'ant_dir', 'ML1-pick-place-v2']
    parser.add_argument('--dataset_mode', type=str, default='expert')
    parser.add_argument('--test_dataset_mode', type=str, default='expert')
    parser.add_argument('--train_prompt_mode', type=str, default='expert')
    parser.add_argument('--test_prompt_mode', type=str, default='expert')
    parser.add_argument('--seed', type=int, default=1)

    parser.add_argument('--batch_size', type=int, default=1024)
    # Not supported by this script: experiment_mix_env raises
    # NotImplementedError when this flag is set.
    parser.add_argument('--evaluation', action='store_true', default=False)
    parser.add_argument('--extra_exp_str', type=str, default="CQL")
    parser.add_argument('--max_iters', type=int, default=2000)
    parser.add_argument('--policy_BC_steps', type=int, default=0)
    parser.add_argument('--train_task_idx', type=int, default=0)

    # If set, the task is picked from the in-distribution training tasks.
    # If not set, it is picked from the holdout test tasks.
    parser.add_argument('--eval_train_task', action='store_true', default=False)
    # The length of trajs for PDT. 1 to just sample one-step transitions like in regular RL.
    parser.add_argument('--K', type=int, default=1)
    parser.add_argument('--num_eval_episodes', type=int, default=2)
    parser.add_argument('--load-path', type=str, default=None)
    parser.add_argument('--num_steps_per_iter', type=int, default=10)
    parser.add_argument('--device', type=str, default='cuda')
    # Parsed with str2bool so that `--log_to_wandb False` really disables logging.
    parser.add_argument('--log_to_wandb', '-w', type=str2bool, default=True)
    parser.add_argument('--train_eval_interval', type=int, default=10)
    parser.add_argument('--test_eval_interval', type=int, default=100)
    parser.add_argument('--save-interval', type=int, default=100)
    parser.add_argument('--mode', type=str, default='normal')
    parser.add_argument('--no_state_normalize', action='store_true', default=False)
    # NOTE: store_true with default=True means this option is always True.
    parser.add_argument('--average_state_mean', action='store_true', default=True)

    # Options below are not used for CQL.
    parser.add_argument('--num_traj_prompt_j', type=int, default=1)
    parser.add_argument('--finetune', action='store_true', default=False)
    parser.add_argument('--finetune_steps', type=int, default=10)
    parser.add_argument('--finetune_batch_size', type=int, default=256)
    # NOTE: store_true with default=True means this option is always True.
    parser.add_argument('--finetune_opt', action='store_true', default=True)
    parser.add_argument('--finetune_lr', type=float, default=1e-4)
    parser.add_argument('--render', action='store_true', default=False)

    args = parser.parse_args()
    experiment_mix_env('gym-experiment', variant=vars(args))
