import os
import os.path as osp

import gym

import diffgro
from diffgro.environments import make_env as _make_env
from diffgro.common.evaluations import evaluate, evaluate_complex
from diffgro.utils.config import load_config, save_config
from diffgro.utils import Parser, make_dir, print_r, print_y, print_b
from train import *


def _make_dummy_env(domain_name):
    """Build the representative environment used for setup in ``domain_name``.

    The returned environment is used to report observation/action spaces,
    to build the replay buffer and to construct the planner; the concrete
    task is fixed per domain.

    Raises:
        NotImplementedError: if ``domain_name`` is not a supported domain.
    """
    if domain_name == "metaworld":
        return _make_env("metaworld", "push-variant-v2")
    if domain_name == "metaworld_complex":
        return _make_env("metaworld_complex", "puck-drawer-button-stick-variant-v2")
    # Without this, an unknown domain would leave no environment at all and
    # the caller would fail later with an unrelated UnboundLocalError.
    raise NotImplementedError(f"Unsupported domain: {domain_name}")


def _resolve_save_path(args, domain_name, task_name):
    """Return the results directory for this run.

    Phase 0 stores results under ``<domain>/<task>/<tag>``; phase 1 stores
    them under ``<domain>/<domain>/<tag>``.

    Raises:
        NotImplementedError: for any other ``args.phase``.
    """
    if args.phase == 0:
        return osp.join("./results/lcd", domain_name, task_name, args.tag)
    if args.phase == 1:
        return osp.join("./results/lcd", domain_name, domain_name, args.tag)
    raise NotImplementedError(f"Unsupported phase: {args.phase}")


def _run_evaluation(model, env, domain_name, task_name, args, save_path, **kwargs):
    """Evaluate ``model`` with the evaluator matching ``domain_name``.

    ``"metaworld"`` uses ``evaluate``; every other domain uses
    ``evaluate_complex``. Extra keyword arguments (e.g. ``context``) are
    forwarded unchanged and the evaluator's return value is passed through.
    """
    evaluator = evaluate if domain_name == "metaworld" else evaluate_complex
    return evaluator(
        model,
        env,
        domain_name,
        task_name,
        args.n_episodes,
        True,  # positional flag as passed at every original call site — meaning defined by the evaluator
        args.video,
        save_path,
        **kwargs,
    )


def train(args):
    """Train and/or evaluate an LCD planner according to ``args``.

    ``args.env_name`` must have the form ``"<domain>.<task>"``.

    When ``args.train`` is set, an ``LCDPlanner`` is trained on the buffer
    returned by ``make_buff`` and saved, together with its config, under the
    results directory.

    When ``args.test`` is set, the saved planner is loaded and wrapped in
    ``diffgro.LCD`` for evaluation. The planner is evaluated on the single
    task named in ``args.env_name``, or on every task in the buffer's task
    list when that task is ``"all"``. If a guide other than ``"test"`` is
    given, it is evaluated once per context from ``make_context``.
    ``args.env_name`` is overwritten with each evaluated task.

    Raises:
        NotImplementedError: for an unsupported domain or ``args.phase``.
    """
    # 0. Make Dummy Environment
    domain_name, task_name = args.env_name.split(".")
    env = _make_dummy_env(domain_name)
    print_y(
        f"Obs Space: {env.observation_space.shape}, Act Space: {env.action_space.shape}"
    )

    # 1. Save Path
    save_path = _resolve_save_path(args, domain_name, task_name)
    model_path = osp.join(save_path, "planner")

    # 2. Make Buffer
    buff, task_list = make_buff(args, env)
    num_task = len(task_list)
    print_r(f"Number of tasks {num_task}")

    config = load_config("./config/algos/lcd.yml", domain_name)
    if args.train:
        # 3. Adjust Config: multi-task runs scale the configured batch size by
        # the number of tasks and train longer; single-task runs use fixed values.
        planner_cfg = config["planner"]
        planner_cfg["params"]["batch_size"] = (
            planner_cfg["params"]["batch_size"] * num_task if num_task > 1 else 64
        )
        planner_cfg["training"]["total_timesteps"] = (
            1_000_000 if num_task > 1 else 200_000
        )
        planner_cfg["params"]["seed"] = args.seed
        planner_cfg["datasets"] = args.dataset_path

        # 4. Make Models
        model = diffgro.LCDPlanner(
            env=env,
            replay_buffer=buff,
            **planner_cfg["params"],
            verbose=1,
            policy_kwargs=planner_cfg["policy_kwargs"],
        )

        # 5. Training
        make_dir(save_path)
        save_config(save_path, config)  # save configs
        model.learn(**planner_cfg["training"])
        model.save(path=model_path)

    if args.test:
        # The planner is loaded from the un-suffixed save path, before the
        # per-guide subdirectory below is applied.
        planner = diffgro.LCDPlanner.load(model_path)
        if task_name != "all":
            task_list = [task_name]  # evaluate one task
        # Guided evaluation writes its outputs into a per-guide subdirectory.
        if args.guide is not None:
            save_path = osp.join(save_path, "context", f"{args.guide}")
            make_dir(save_path)

        tot_success = []
        for task in task_list:
            args.env_name = f"{domain_name}.{task}"
            env, domain_name, task_name = make_env(args)

            if args.guide is None or args.guide == "test":
                model = diffgro.LCD(
                    env,
                    planner,
                    delta=args.delta,
                    guide=args.guide,
                    guide_pt=args.prompt,
                    verbose=args.verbose,
                )
                model._setup_guide()
                print_y("Evaluating without context !!")
                success = _run_evaluation(
                    model, env, domain_name, task_name, args, save_path
                )
                tot_success.extend(success)
                continue

            model = diffgro.LCD(
                env,
                planner,
                delta=args.delta,
                guide=args.guide,
                guide_pt=None,
                verbose=args.verbose,
            )
            contexts = make_context(domain_name, task_name, args.multimodal)
            if args.prompt_idx is not None:
                contexts = contexts[args.prompt_idx : args.prompt_idx + 1]
            print_y(f"Evaluating Contexts {contexts}")
            for context in contexts:
                print_r(f"Guide: {args.guide} with Context: {context}")
                # Append a context header to both log files before evaluating it.
                for log_name in ("evaluation.txt", "code.txt"):
                    log_path = osp.join(save_path, log_name)
                    with open(log_path, "a", encoding="utf-8") as f:
                        f.write(f"Context: {context}\n")

                # metaworld stores the context wrapped in a list; other
                # domains store it as-is.
                model.context_info = (
                    [context] if domain_name == "metaworld" else context
                )
                model._setup_guide()
                # The second return value (context reward) is not used here.
                success, _ = _run_evaluation(
                    model,
                    env,
                    domain_name,
                    task_name,
                    args,
                    save_path,
                    context=context,
                )
                tot_success.extend(success)

        if len(task_list) > 1:
            eval_save(tot_success, save_path)


if __name__ == "__main__":
    # Script entry point: parse the "train" CLI options and run the driver.
    train(Parser("train").parse_args())
