import os 
os.environ["MUJOCO_GL"] = "osmesa"
os.environ["PYOPENGL_PLATFORM"] = "osmesa"

from ast import parse
import gym
import numpy as np
import torch
import wandb
import argparse
import random
import sys, os
import time
import itertools
from datetime import datetime
from tqdm import trange

from prompt_dt.prompt_decision_transformer import ReachAvoidTransformer, GoalTransformer
from prompt_dt.prompt_seq_trainer import PromptSequenceTrainer, PromptSequenceReachAvoidTrainer, PromptSequenceTrainerLossDebug
#from prompt_dt.prompt_utils import get_env_list
#from prompt_dt.prompt_utils import get_prompt_batch, get_prompt, get_batch, get_batch_finetune
#from prompt_dt.prompt_utils import process_total_data_mean, load_data_prompt, process_info
from prompt_dt.prompt_utils import eval_episodes, get_prompt_batch
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts
from raDT.baselines.envs.maps_obstacle import *

# env name to the utils module
# the utils module contains functions for: data&env loader, prompt&sequence batch loader
import envs.mazerunner.utils as mazerunner_utils
import envs.kitchen_toy.utils as kitchen_toy_utils
# import envs.crafter.utils as crafter_utils
import envs.kitchen.utils as kitchen_utils
import envs.reach_obstacle.utils as reach_obstacle_utils
import envs.slide_obstacle.utils as slide_obstacle_utils
import envs.push_obstacle.utils as push_obstacle_utils
import envs.pointmaze_obstacle.utils as pointmaze_obstacle_utils
import envs.cardiogenesis.utils as cardiogenesis_utils
# import envs.pick_and_place_obstacle.utils as pick_and_place_obstacle_utils

from raDT.constants import *

# Registry: environment name (the value of --env) -> that environment's utils module.
# experiment() looks up get_train_test_dataset_envs, get_prompt and
# get_avoid_prompt on the selected module.
CONFIG_DICT = dict(
    mazerunner=mazerunner_utils,
    kitchen_toy=kitchen_toy_utils,
    # crafter=crafter_utils,
    kitchen=kitchen_utils,
    reach_obstacle=reach_obstacle_utils,
    slide_obstacle=slide_obstacle_utils,
    push_obstacle=push_obstacle_utils,
    pointmaze_obstacle=pointmaze_obstacle_utils,
    cardiogenesis=cardiogenesis_utils,
    # pick_and_place_obstacle=pick_and_place_obstacle_utils,
)

def _build_scheduler(optimizer, variant):
    """Build the learning-rate scheduler named by ``variant['scheduler']``.

    Supported names:
        None / 'lambdalr'       -- ``LambdaLR`` with a linear ramp over
                                   ``variant['warmup_steps']`` steps.
        'cosinewarmrestarts'    -- ``torch`` ``CosineAnnealingWarmRestarts``.
        'cosinewarmuprestarts'  -- ``CosineAnnealingWarmupRestarts``.

    Args:
        optimizer: the optimizer whose learning rate is scheduled.
        variant: dict of run options (scheduler hyper-parameters are read from it).

    Returns:
        The constructed scheduler object.

    Raises:
        ValueError: if the scheduler name is not one of the supported names.
    """
    scheduler_name = variant["scheduler"]

    if (scheduler_name is None) or (scheduler_name == 'lambdalr'):
        # Clamp to >= 1 so that warmup_steps == 0 means "no warmup" instead of
        # raising ZeroDivisionError inside the lambda.
        warmup_steps = max(variant['warmup_steps'], 1)
        return torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lambda steps: min((steps + 1) / warmup_steps, 1)
        )

    if scheduler_name == 'cosinewarmrestarts':
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=variant["T_0"],
            T_mult=variant["T_mult"],
            eta_min=variant["eta_min"],
            last_epoch=variant["last_epoch"]
        )

    if scheduler_name == 'cosinewarmuprestarts':
        return CosineAnnealingWarmupRestarts(
            optimizer,
            first_cycle_steps=variant["T_0"],
            cycle_mult=variant["T_mult"],
            max_lr=variant["learning_rate"],
            min_lr=variant["learning_rate_min"],
            warmup_steps=variant["warmup_steps"],
            gamma=variant["cosine_gamma"],
            last_epoch=variant["last_epoch"]
        )

    # Previously an unknown name left `scheduler` unbound and surfaced later
    # as a NameError; fail here with an explicit message instead.
    raise ValueError(
        "unknown scheduler {!r}; expected one of: None, 'lambdalr', "
        "'cosinewarmrestarts', 'cosinewarmuprestarts'".format(scheduler_name)
    )


def experiment(variant):
    """Train or evaluate a prompt-conditioned transformer on one environment.

    All options are read from ``variant`` (in this file it is ``vars(args)``
    of the CLI parser), so the function no longer depends on a module-global
    ``args`` and can be called from other modules.

    Args:
        variant: dict of run options. It is mutated in place:
            ``state_truncated_dim`` / ``obs_start_index`` are added for the
            obstacle / cardiogenesis environments, string-valued
            ``avoids_readable``, ``initial_state_list_readable`` and ``maze``
            are replaced by their evaluated values, and ``buffer_size`` is
            overwritten by ``bsa_box_size`` unless ``isolated`` is set.

    Raises:
        ValueError: if ``evaluation`` is requested without ``load_path``, or
            if ``scheduler`` names an unsupported scheduler.
        KeyError: if ``variant['env']`` is not registered in ``CONFIG_DICT``.
    """
    env_name = variant['env']
    exp_prefix = variant['exp_name']
    instance_prefix = variant['instance_prefix']
    device = variant['device']
    log_to_wandb = variant['log_to_wandb']

    # Fail fast: evaluation needs a checkpoint. Previously this surfaced as a
    # TypeError from os.path.join(save_path, None) after all the setup work.
    if variant['evaluation'] and not variant['load_path']:
        raise ValueError("evaluation mode requires a checkpoint: pass --load-path")

    # env name -> (state_truncated_dim, obs_start_index)
    state_layout = {
        'reach_obstacle': (10, 0),
        # dim 25 includes object info now; index 3: when calc avoid success,
        # use the object pos not gripper pos
        'slide_obstacle': (25, 3),
        'push_obstacle': (25, 3),
        'pick_and_place_obstacle': (25, 3),
        # original dim without adding avoid information; index 0: when calc
        # avoid success, use point position
        'pointmaze_obstacle': (4, 0),
        'cardiogenesis': (15, 0),
    }
    if env_name in state_layout:
        variant['state_truncated_dim'], variant['obs_start_index'] = state_layout[env_name]

    ######
    # construct train and test environments, datasets
    ######

    save_path = os.path.join(SAVE_PATH, 'model_saved/')
    # makedirs: also creates a missing SAVE_PATH and tolerates an existing dir.
    os.makedirs(save_path, exist_ok=True)

    # SECURITY NOTE(review): the *_readable and maze options are passed through
    # eval(), which executes arbitrary code from the command line. It is kept
    # because the value may be an expression over names from the star imports
    # (TODO confirm); never feed these options from untrusted input.
    # The isinstance(str) guards make a second call with the same (already
    # evaluated) variant a no-op instead of a TypeError.
    additional_kwargs = {}
    if env_name in ['cardiogenesis']:
        if isinstance(variant['avoids_readable'], str):
            variant['avoids_readable'] = eval(variant['avoids_readable'])
        initial_states = variant['initial_state_list_readable']
        if initial_states and isinstance(initial_states, str):
            variant['initial_state_list_readable'] = eval(initial_states)
        additional_kwargs["fixed_interval"] = variant['fixed_interval']
        additional_kwargs["avoids_readable"] = variant['avoids_readable']
        additional_kwargs["fixed_goal_readable"] = variant['fixed_goal_readable']
        additional_kwargs["fixed_start_readable"] = variant['fixed_start_readable']
        additional_kwargs["initial_state_list_readable"] = variant['initial_state_list_readable']
    if variant['dataset_path2'] and variant['d2_percent_mix']:
        additional_kwargs["dataset_path2"] = variant['dataset_path2']
        additional_kwargs["d2_percent_mix"] = variant['d2_percent_mix']
    if variant['maze'] and (variant['num_avoid'] is not None):
        if isinstance(variant['maze'], str):
            variant['maze'] = eval(variant['maze'])
        additional_kwargs["maze"] = variant['maze']
        additional_kwargs["num_avoid"] = variant['num_avoid']
    if variant['bsa_box_size']:
        additional_kwargs["bsa_box_size"] = variant['bsa_box_size']
        if not variant['isolated']:
            # Unless --isolated, the eval-time buffer follows the box size.
            variant['buffer_size'] = variant['bsa_box_size']

    env_utils = CONFIG_DICT[env_name]
    info, env_list, val_trajectories_list, test_info, test_env_list, test_trajectories_list, trajectories_list = \
        env_utils.get_train_test_dataset_envs(
            variant['dataset_path'], device, max_ep_len=variant['max_ep_len'],
            n_test_env=variant['n_test_env'], **additional_kwargs)

    print(f'Env Info: {info} \n\n Test Env Info: {test_info}\n\n\n')
    print(f'Env List: {env_list} \n\n Test Env List: {test_env_list}')

    K = variant['K']
    assert K==variant['max_ep_len'], "currently, training context K should be == max episode length"
    batch_size = variant['batch_size']
    print('Max ep length {}, training context length {}, batch size {}'.format(variant['max_ep_len'], K, batch_size))

    ######
    # construct dt model and trainer
    ######

    exp_prefix = exp_prefix + '-' + env_name
    dataset_name = variant['dataset_path'].split('/')[-1].split('.')[0] # ds filename without .pkl
    time_now = datetime.now().strftime("%Y%m%d%H%M%S")
    group_name = f'{exp_prefix}-{dataset_name}' # wandb group name
    exp_prefix = f'{exp_prefix}-{dataset_name}-{instance_prefix}-{time_now}' # wandb exp name

    state_dim = test_info[0]['state_dim']
    act_dim = test_info[0]['act_dim']
    action_space = test_env_list[0].action_space
    prompt_dim = test_env_list[0].prompt_dim
    print('state {} action {} prompt goal {}'.format(state_dim, act_dim, prompt_dim))

    # Constructor arguments shared by both model classes.
    # NOTE(review): n_positions is hard-coded to 2048 here; the --n_positions
    # CLI option is not consulted.
    model_kwargs = dict(
        state_dim=state_dim,
        act_dim=act_dim,
        action_space=action_space,
        prompt_dim=prompt_dim,
        max_length=K,
        max_ep_len=variant['max_ep_len'],
        hidden_size=variant['embed_dim'],
        n_layer=variant['n_layer'],
        n_head=variant['n_head'],
        n_inner=4 * variant['embed_dim'],
        activation_function=variant['activation_function'],
        n_positions=2048,
        resid_pdrop=variant['dropout'],
        attn_pdrop=variant['dropout'],
    )
    if variant["avoid_prompt"]:
        model = ReachAvoidTransformer(adelta=variant['adelta'], **model_kwargs)
    else:
        model = GoalTransformer(**model_kwargs)

    model = model.to(device=device)

    if variant['load_path']:
        saved_model_path = os.path.join(save_path, variant['load_path'])
        model.load_state_dict(torch.load(saved_model_path, map_location=torch.device(device)), strict=True)
        print('model initialized from: ', saved_model_path)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=variant['learning_rate'],
        weight_decay=variant['weight_decay'],
    )
    scheduler = _build_scheduler(optimizer, variant)

    if variant["avoid_prompt"]:
        # NOTE: PromptSequenceTrainerLossDebug (imported at the top of the file)
        # is an alternative trainer that was previously wired in here.
        kwargs = {}
        # Read with .get() so that a variant lacking these keys simply
        # disables prioritized sampling instead of raising.
        if variant.get("n_prioritized_traj"):
            kwargs["n_prioritized_traj"] = variant["n_prioritized_traj"]
            kwargs["n_prioritized_per_batch"] = variant.get("n_prioritized_per_batch")
        trainer = PromptSequenceReachAvoidTrainer(
            model=model,
            optimizer=optimizer,
            batch_size=batch_size,
            scheduler=scheduler,
            eval_fns=None,
            get_prompt=env_utils.get_prompt,
            get_avoid_prompt=env_utils.get_avoid_prompt,
            get_prompt_batch=get_prompt_batch(
                trajectories_list, test_info[0], variant,
                env_utils.get_prompt, env_utils.get_avoid_prompt, **kwargs),
            alpha1=variant["alpha1"],
            alpha2=variant["alpha2"],
            buffer_size=variant["buffer_size"],
            obs_start_index=variant["obs_start_index"],
            prompt_dim=prompt_dim
        )
    else:
        trainer = PromptSequenceTrainer(
            model=model,
            optimizer=optimizer,
            batch_size=batch_size,
            scheduler=scheduler,
            eval_fns=None,
            get_prompt=env_utils.get_prompt,
            get_prompt_batch=get_prompt_batch(trajectories_list, test_info[0], variant, env_utils.get_prompt)
        )

    if not variant['evaluation']:
        ######
        # start training
        ######
        if log_to_wandb:
            wandb.init(
                name=exp_prefix,
                group=group_name,
                project='goal-dt',
                config=variant
            )
            # Checkpoints of this run go to model_saved/<wandb run name>.
            save_path += wandb.run.name
            os.makedirs(save_path, exist_ok=True)

        model_post_fix = ''

        # Tracks the last completed iteration so the final save is skipped
        # (rather than raising NameError) when max_iters == 0.
        last_iteration = None
        for iteration in trange(variant['max_iters']):
            # train for many batches
            outputs = trainer.train(
                num_steps=variant['num_steps_per_iter'],
                no_prompt=False
            )

            # start evaluation (never at iteration 0)
            if iteration % variant['test_eval_interval'] == 0 and iteration > 0:
                # evaluate on unseen test tasks
                test_eval_logs = trainer.eval_iteration_multienv(test_trajectories_list,
                    eval_episodes, test_info, variant, test_env_list, iter_num=iteration + 1,
                    print_logs=True, no_prompt=False, group='test')
                outputs.update(test_eval_logs)

                # evaluate on some training tasks
                if variant['test_on_training_tasks']:
                    train_eval_logs = trainer.eval_iteration_multienv(val_trajectories_list,
                        eval_episodes, info, variant, env_list, iter_num=iteration + 1,
                        print_logs=True, no_prompt=False, group='train')
                    outputs.update(train_eval_logs)

            if iteration % variant['save_interval'] == 0:
                trainer.save_model(
                    env_name=env_name,
                    postfix=model_post_fix+'_iter_'+str(iteration),
                    folder=save_path)

            outputs.update({"global_step": iteration}) # set global step as iteration

            if log_to_wandb:
                wandb.log(outputs)

            last_iteration = iteration

        # Always persist the final weights, whatever the save interval.
        if last_iteration is not None:
            trainer.save_model(
                env_name=env_name,
                postfix=model_post_fix+'_iter_'+str(last_iteration),
                folder=save_path)

    else:
        ####
        # start evaluating
        ####
        # The checkpoint was already loaded into the model above (load_path is
        # guaranteed to be set in evaluation mode), so it is not loaded again.
        saved_model_path = os.path.join(save_path, variant['load_path'])
        # The iteration number is the integer after the last '_' of the path.
        eval_iter_num = int(saved_model_path.split('_')[-1])

        trainer.eval_iteration_multienv(test_trajectories_list, eval_episodes,
            test_info, variant, test_env_list, iter_num=eval_iter_num, print_logs=True,
            no_prompt=False, group='test', recording_prefix=variant['recording_prefix'])

        
def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) becomes ``True``. This converter
    interprets the text instead.

    Args:
        value: a bool (returned unchanged) or a string to interpret.

    Returns:
        The parsed boolean. The empty string maps to ``False``, which matches
        what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp-name', type=str, default='gym-experiment')
    parser.add_argument('--instance_prefix', type=str, default='gym-experiment', help='prefix for specific instance in a group')
    parser.add_argument('--env', type=str, default='mazerunner')
    parser.add_argument('--dataset_path', type=str, default='envs/mazerunner/mazerunner-d15-g1-t50-astar.pkl')
    parser.add_argument('--dataset_path2', type=str, default=None)
    parser.add_argument('--d2_percent_mix', type=float, default=None)
    parser.add_argument('--test_optimal_prompt', action='store_true', default=False) # use 'optimal_prompts' saved in trajectories for test?

    parser.add_argument('--evaluation', action='store_true', default=False)
    parser.add_argument('--render', action='store_true', default=False)
    parser.add_argument('--load-path', type=str, default=None) # choose a model when in evaluation mode

    parser.add_argument('--max_prompt_len', type=int, default=5) # max len of sampled prompt
    parser.add_argument('--avoid_prompt', action='store_true', default=False)
    parser.add_argument('--max_avoid_prompt_len', type=int, default=10) # max len of sampled prompt
    parser.add_argument('--max_ep_len', type=int, default=50) # max episode len in both dataset & env
    parser.add_argument('--K', type=int, default=50) # max Transformer context len (the whole sequence is max_prompt_len+K)
    parser.add_argument('--subsample_trajectory', action='store_true', default=False) # subsample during training?
    parser.add_argument('--subsample_min_len', type=int, default=-1) # subsample traj[0:l], l~U[min_len, traj_len]
    parser.add_argument('--gamma', type=float, default=0.98, help='the discount factor')
    parser.add_argument('--adelta', type=float, default=0.0, help='delta for boosting')

    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--embed_dim', type=int, default=128)
    parser.add_argument('--n_layer', type=int, default=3)
    parser.add_argument('--n_head', type=int, default=1)
    parser.add_argument('--n_positions', type=int, default=1024, help='n_positions argument for transformer max context length')
    parser.add_argument('--activation_function', type=str, default='relu')
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--learning_rate', '-lr', type=float, default=1e-4)
    parser.add_argument('--weight_decay', '-wd', type=float, default=1e-4)
    parser.add_argument('--warmup_steps', type=int, default=10000) # 10000*(number of environments)
    parser.add_argument('--num_eval_episodes', type=int, default=5)
    parser.add_argument('--max_iters', type=int, default=5000)
    parser.add_argument('--num_steps_per_iter', type=int, default=10)
    parser.add_argument('--n_test_env', type=int, default=50)
    parser.add_argument('--device', type=str, default='cuda')
    # str2bool instead of type=bool: bool("False") is True, so `-w False`
    # could never actually disable wandb logging.
    parser.add_argument('--log_to_wandb', '-w', type=str2bool, default=True)
    parser.add_argument('--test_on_training_tasks', action='store_true', default=False) # eval on training tasks in addition to test tasks?
    parser.add_argument('--test_eval_interval', type=int, default=200)
    parser.add_argument('--save-interval', type=int, default=500)

    parser.add_argument('--scheduler', type=str, default=None)

    parser.add_argument('--T_0', type=int, default=10000)
    parser.add_argument('--T_mult', type=int, default=1)
    parser.add_argument('--eta_min', type=float, default=0.0)
    parser.add_argument('--last_epoch', type=int, default=-1)

    parser.add_argument('--learning_rate_min', type=float, default=0)
    parser.add_argument('--cosine_gamma', type=float, default=1)

    parser.add_argument('--alpha1', type=float, default=1, help='how much to weigh success loss relative to action')
    parser.add_argument('--alpha2', type=float, default=1.1, help='how much to weigh neg samples in success loss relative to pos')

    # experiment() reads these two options when --avoid_prompt is set, but they
    # were never declared, which raised AttributeError on every such run.
    # NOTE(review): int is assumed from the option names -- confirm against
    # get_prompt_batch.
    parser.add_argument('--n_prioritized_traj', type=int, default=None, help='number of prioritized trajectories; prioritized sampling is off when unset')
    parser.add_argument('--n_prioritized_per_batch', type=int, default=None, help='number of prioritized trajectories per batch; set together with --n_prioritized_traj')

    parser.add_argument('--buffer_size', type=float, default=0.06, help='how much buffer around avoid boxes at eval time')

    parser.add_argument('--recording_prefix', type=str, default=None, help='prefix for video recording, if None then no recording')

    parser.add_argument('--maze', type=str, default=None, help='if maze env, type of maze')
    parser.add_argument('--num_avoid', type=int, default=None, help='if maze env, number of avoid states')
    parser.add_argument('--bsa_box_size', type=float, default=None, help='box size to use for box size analysis (half of total width)')
    parser.add_argument('--isolated', action='store_true', default=False, help='flag, if present then bsa_box_size and buffer_size are different and should take on spec values')

    parser.add_argument('--fixed_interval', type=int, default=10, help='card fixed interval')
    parser.add_argument('--avoids_readable', type=str, default="[]", help='list of avoid states')
    parser.add_argument('--fixed_goal_readable', type=str, default=None, help='should the goal be fixed instead of sampled in the env, for eval')
    parser.add_argument('--fixed_start_readable', type=str, default=None, help='should the start be fixed instead of sampled in the env, for eval')
    parser.add_argument('--initial_state_list_readable', type=str, default=None, help='list of possible initial states to sample from when not fixed start. if None defaults to attractors')
    args = parser.parse_args()
    # vars(args) is the namespace's own __dict__, so changes experiment() makes
    # to `variant` are visible on `args` as well.
    experiment(variant=vars(args))
