import os
import os.path as osp

import gym

import diffgro
from diffgro.environments import make_env as _make_env
from diffgro.common.evaluations import evaluate, evaluate_complex
from diffgro.utils.config import load_config, save_config
from diffgro.utils import Parser, make_dir, print_r, print_y, print_b
from train import *


def train(args):
    """Train and/or evaluate a LangDT planner as selected by ``args``.

    ``args.env_name`` must have the form ``"<domain>.<task>"``. A dummy
    environment for the domain is created first, the replay buffer and task
    list are built with ``make_buff``, and then, depending on the flags:

    * ``args.train``: a ``diffgro.LangDTPlanner`` is built from
      ``./config/algos/langdt.yml``, trained, and saved to
      ``<save_path>/planner``.
    * ``args.test``: the saved planner is loaded and evaluated on every task
      in the task list (or only on the task from ``args.env_name`` when that
      task is not ``"all"``). When ``args.guide`` is set, evaluation is run
      once per context returned by ``make_context`` and results are written
      under ``<save_path>/context/<guide>``.

    Attributes read from ``args``: ``env_name``, ``phase``, ``tag``,
    ``train``, ``test``, ``seed``, ``dataset_path``, ``guide``, ``verbose``,
    ``n_episodes``, ``video``, ``multimodal``, ``prompt_idx``.
    ``args.env_name`` is overwritten for each evaluated task.

    NOTE(review): ``make_buff``, ``make_env``, ``make_context`` and
    ``eval_save`` are not defined in this file; presumably they come from
    ``from train import *`` -- confirm.

    Raises:
        ValueError: if ``args.env_name`` does not split into exactly two
            parts on ``"."``.
        NotImplementedError: if the domain or ``args.phase`` is unsupported.
    """
    # 0. Make Dummy Environment
    domain_name, task_name = args.env_name.split(".")
    # Representative task per domain, used only to build the dummy env.
    dummy_tasks = {
        'metaworld': 'push-variant-v2',
        'metaworld_complex': 'puck-drawer-button-stick-variant-v2',
    }
    if domain_name not in dummy_tasks:
        # Fail early with a clear message instead of hitting an unbound
        # ``env`` further down.
        raise NotImplementedError(
            f"Unsupported domain '{domain_name}'; "
            f"expected one of {sorted(dummy_tasks)}"
        )
    env = _make_env(domain_name, dummy_tasks[domain_name])
    print_y(f"Obs Space: {env.observation_space.shape}, Act Space: {env.action_space.shape}")

    # 1. Save Path
    if args.phase == 0:
        save_path = osp.join("./results/langdt", domain_name, task_name, args.tag)
    elif args.phase == 1:
        # NOTE(review): domain_name is joined twice here (no task_name);
        # kept as in the original -- confirm this layout is intended.
        save_path = osp.join("./results/langdt", domain_name, domain_name, args.tag)
    else:
        raise NotImplementedError(f"Unsupported phase: {args.phase!r}")
    # Computed once, before save_path may be redirected to the context dir.
    planner_path = save_path + "/planner"

    # 2. Make Buffer
    buff, task_list = make_buff(args, env)
    num_task = len(task_list)
    print_r(f"Number of tasks {num_task}")

    config = load_config('./config/algos/langdt.yml', domain_name)
    if args.train:
        # 3. Load Config
        # Multi-task: scale the configured batch size by the number of tasks.
        # Single-task: use a fixed batch size of 64.
        config['planner']['params']['batch_size'] = (
            config['planner']['params']['batch_size'] * num_task
            if num_task > 1 else 64
        )
        config['planner']['training']['total_timesteps'] = (
            1_000_000 if num_task > 1 else 200_000
        )
        config['planner']['params']['seed'] = args.seed
        config['planner']['datasets'] = args.dataset_path

        # 4. Make Models
        model = diffgro.LangDTPlanner(
            env=env,
            replay_buffer=buff,
            **config['planner']['params'],
            verbose=1,
            policy_kwargs=config['planner']['policy_kwargs'],
        )

        # 5. Training & Evaluation
        make_dir(save_path)
        save_config(save_path, config)  # save configs
        model.learn(**config['planner']['training'])
        model.save(path=planner_path)

    if args.test:
        print_y("Loading ... LangDT ...")
        planner = diffgro.LangDTPlanner.load(planner_path)
        if task_name != "all":
            task_list = [task_name]  # evaluate one task
        # Guided runs write their results into a per-guide context directory.
        if args.guide is not None:
            save_path = os.path.join(save_path, 'context', f"{args.guide}")
            make_dir(save_path)

        tot_success = []
        for task in task_list:
            args.env_name = f"{domain_name}.{task}"
            # make_env rebinds env/domain_name/task_name for the current task.
            env, domain_name, task_name = make_env(args)
            model = diffgro.LangDT(
                env,
                planner,
                verbose=args.verbose,
            )
            # Only the plain 'metaworld' domain uses `evaluate`.
            run_eval = evaluate if domain_name == "metaworld" else evaluate_complex

            if args.guide is None:
                success = run_eval(
                    model, env, domain_name, task_name,
                    args.n_episodes, True, args.video, save_path)
                tot_success.extend(success)
                continue

            # Retrieve contexts from configuration.
            contexts = make_context(domain_name, task_name, args.multimodal)
            if args.prompt_idx is not None:
                # Restrict evaluation to the single selected prompt.
                contexts = contexts[args.prompt_idx:args.prompt_idx + 1]
            print_y(f"Evaluating Contexts {contexts}")
            for context in contexts:
                print_r(f"Guide: {args.guide} with Context: {context}")
                # Record which context the following evaluation belongs to.
                with open(os.path.join(save_path, 'evaluation.txt'), 'a') as f:
                    f.write(f'Context: {context}\n')
                # With a context, the evaluator returns a pair; only the
                # success list is aggregated here.
                success, _context_rew = run_eval(
                    model, env, domain_name, task_name,
                    args.n_episodes, True, args.video, save_path,
                    context=context)
                tot_success.extend(success)

        # Aggregate results only when more than one task was evaluated.
        if len(task_list) > 1:
            eval_save(tot_success, save_path)
           

if __name__ == "__main__":
    # Script entry point: parse the command-line options with the "train"
    # parser, then hand them to train().
    cli_parser = Parser("train")
    cli_args = cli_parser.parse_args()
    train(cli_args)
