# m1

import warnings
warnings.filterwarnings('ignore')
import os
import time
import pickle
import argparse
import random
import numpy as np
import torch
import metaworld
import metaworld

from stable_baselines3.common.torch_layers import CombinedExtractor
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback

from bc.bc import BC
from bc.buffers import BCReplayBuffer
from bc.features_extractor import (
    ClipExtractor, PromptAttentionExtractor, CURLExtractor, ATCExtractor, ACPExtractor, ConPEExtractor, ATTEMPTExtractor, SESoMExtractor, LUSRExtractor
)
from utils import check_env, make_env, make_vec_env, evaluate_env, evaluate_dfs_env, evaluate_short_dfs_env, render_env, get_expert

def seed_fix(n):
    """Seed every RNG this script relies on and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch with ``n``.
    When cuDNN is enabled, it also turns off cuDNN benchmark autotuning and
    requests deterministic cuDNN kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(n)

    cudnn = torch.backends.cudnn
    if not cudnn.enabled:
        return
    # Autotuning may pick different kernels from run to run; disable it.
    cudnn.benchmark = False
    cudnn.deterministic = True

# Seed passed to metaworld.MT50 and to both make_vec_env calls in main().
SEED = 0
# Global RNG seeding (random / numpy / torch / cuDNN) is currently disabled:
# seed_fix(4)

# Pretrained source policy passed to the transfer variants (placeholder path).
_SOURCE_MODEL_PATH = "/path/to/metaworld/models/bc/door-open-v2/embclip/policy.pth"

# Model variants recognised inside --tag, in priority order. The first
# substring found in the tag wins, so "conpe-tl" must come before "conpe".
# Keep this in sync with the table in _build_model().
_VARIANT_TAGS = (
    "conpe-tl", "conpe", "attempt", "sesom",
    "embclip", "curl", "atc", "acp", "lusr",
)


def _match_variant(tag):
    """Return the first entry of ``_VARIANT_TAGS`` that occurs in ``tag``, or ``None``."""
    return next((key for key in _VARIANT_TAGS if key in tag), None)


def _build_model(args, env, path):
    """Build the online BC model for the variant selected by ``args.tag``.

    All variants share the same BC hyper-parameters. They differ only in:
    - the features extractor class;
    - whether the extractor also receives the prompt task (``args.prompt_env``);
    - whether a pretrained source policy is passed as ``source_model``.

    Args:
        args: CLI namespace; reads ``tag``, ``env``, ``prompt_env`` and ``use_dist``.
        env: training vec-env the model is attached to.
        path: output directory, also used as the tensorboard log directory.

    Returns:
        A freshly constructed ``BC`` model.

    Raises:
        ValueError: if ``args.tag`` contains none of ``_VARIANT_TAGS``.
    """
    # variant -> (features extractor class, pass prompt_env?, pass source model?)
    variants = {
        "conpe-tl": (ConPEExtractor, True, True),
        "conpe": (PromptAttentionExtractor, True, False),
        "attempt": (ATTEMPTExtractor, True, True),
        "sesom": (SESoMExtractor, True, True),
        "embclip": (ClipExtractor, False, False),
        "curl": (CURLExtractor, False, False),
        "atc": (ATCExtractor, False, False),
        "acp": (ACPExtractor, False, False),
        "lusr": (LUSRExtractor, False, False),
    }
    variant = _match_variant(args.tag)
    if variant is None:
        # Previously an unknown tag left `model` unbound and crashed much later.
        raise ValueError(
            f"Unrecognised --tag {args.tag!r}; it must contain one of: "
            + ", ".join(_VARIANT_TAGS)
        )
    extractor_cls, use_prompt, use_source = variants[variant]

    # Task names drop the last three characters, e.g. 'reach-v2' -> 'reach'.
    extractor_kwargs = {'env': args.env[:-3]}
    if use_prompt:
        extractor_kwargs['prompt_env'] = args.prompt_env[:-3]
    # Only the transfer variants are given a pretrained source policy.
    extra_kwargs = {'source_model': _SOURCE_MODEL_PATH} if use_source else {}

    return BC(
        policy='MultiInputPolicy',
        env=env,
        learning_rate=3e-4,
        batch_size=64,
        verbose=1,
        train_freq=1,
        buffer_size=10_000,
        replay_buffer_class=BCReplayBuffer,
        device="cuda",
        without_exploration=False,
        tensorboard_log=path,
        policy_kwargs=dict(
            net_arch=[256, 256],
            use_dist=args.use_dist,
            features_extractor_class=extractor_cls,
            features_extractor_kwargs=extractor_kwargs,
        ),
        **extra_kwargs,
    )


def _run_evaluation(args, eval_env, policy, path):
    """Evaluate ``policy`` on ``eval_env`` in the mode chosen by the --eval_* flags.

    The first flag that is set wins, checked in this order: source, short,
    seen-random, unseen-random. If none is set, the policy is evaluated for
    10 episodes with ``set_source = False``.
    """
    inner_env = eval_env.envs[0].env
    if args.eval_source:
        print("Evaluating Source!!")
        inner_env.set_source = True
        evaluate_env(eval_env, policy, n_eval_episodes=10, verbose=True)
    elif args.eval_short:
        print("Evaluating Short!!")
        inner_env.set_source = False
        evaluate_short_dfs_env(eval_env, policy, path, n_eval_episodes=50, verbose=True)
    elif args.eval_seen_random:
        print("Evaluating Random!!")
        inner_env.set_source = False
        inner_env.set_seen_random = True
        evaluate_env(eval_env, policy, n_eval_episodes=50, verbose=True)
    elif args.eval_unseen_random:
        print("Evaluating Random!!")
        inner_env.set_source = False
        inner_env.set_unseen_random = True
        evaluate_env(eval_env, policy, n_eval_episodes=50, verbose=True)
    else:
        print("Evaluating !!")
        inner_env.set_source = False
        evaluate_env(eval_env, policy, n_eval_episodes=10, verbose=True)


def main(args):
    """Train and/or evaluate an online behaviour-cloning agent on one MetaWorld task.

    Steps:
    1. Build the MT50 benchmark and validate ``args.env``.
    2. Fetch the expert.
    3. Create the training and evaluation vec-envs.
    4. Depending on the flags, do any of: train (``--train``), evaluate a
       saved checkpoint (``--eval``), render (``--render``).

    With ``--expert --eval``, only the expert is evaluated and the function
    returns early. Outputs go to ``./models/bc/<env>/<tag>``.

    Args:
        args: command-line namespace built in the ``__main__`` section
            (must also carry ``prompt_env``).

    Raises:
        ValueError: if ``args.tag`` names no known model variant.
    """
    benchmark = metaworld.MT50(seed=SEED)

    print(f"<=======Env {args.env}========>")
    if not check_env(args, benchmark):
        print(f"Not a valid environment {args.env}")
        return

    expert = get_expert(args)

    # Only the training env receives the expert (needed for online BC).
    env = make_vec_env(args, benchmark, seed=SEED, expert=expert, dict_obs=True, training=False)
    eval_env = make_vec_env(args, benchmark, seed=SEED, dict_obs=True, training=False, norm_reward=False)
    if args.expert and args.eval:
        evaluate_short_dfs_env(eval_env, expert, path='./', n_eval_episodes=1, verbose=True)
        return

    # NOTE: loading the offline dataset
    # (./datasets/<task>/<domain_factor>/<env>_0_v0.pkl) into an SB3
    # ReplayBuffer is disabled; this is online BC with a BCReplayBuffer.

    path = f'./models/bc/{args.env}/{args.tag}'
    os.makedirs(path, exist_ok=True)

    model = _build_model(args, env, path)

    if args.train:
        # Checkpoint and evaluate roughly ten times per run. Clamp to >= 1:
        # CheckpointCallback takes `n_calls % save_freq`, so 0 would crash.
        eval_freq = max(1, args.timesteps // 10)
        callbacks = [
            CheckpointCallback(eval_freq, path, verbose=1),
            EvalCallback(eval_env, n_eval_episodes=10, eval_freq=eval_freq, log_path=path),
        ]

        print("==== Before Training ====")
        start_time = time.time()
        model.learn(total_timesteps=args.timesteps, callback=callbacks, log_interval=100)
        print(f"Time Elapsed: {time.time() - start_time}")
        print("==== After Training =====")
        evaluate_env(eval_env, model.policy, n_eval_episodes=10, verbose=True)
        model.save(path=path + '/rl_model_final_steps.zip')
        env.save(path + '/env.zip')

    # Policy used for rendering: the in-memory model's policy unless --eval
    # replaces it with the policy loaded from a checkpoint below.
    policy = model.policy
    if args.eval:
        model = model.load(path + f'/rl_model_{args.checkpoint}_steps.zip')
        # NOTE(review): the normalization stats saved to env.zip during
        # training are not reloaded here; updates are only switched off.
        eval_env.training, eval_env.norm_reward = False, False
        # clear features extractor noise
        model.policy.actor.features_extractor.noise_std = 0.0
        policy = model.policy
        if not args.render:
            _run_evaluation(args, eval_env, policy, path)

    if args.render:
        render_env(args, eval_env, policy, args.fov, path, verbose=True)

def _parse_args(argv=None):
    """Parse this script's command-line options.

    Besides the declared options, this sets ``prompt_env``: ``'reach-v2'``
    when ``--reach_prompt`` is given, otherwise the same task as ``--env``.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # env
    parser.add_argument("--env", type=str, default='reach-v2')
    # training
    parser.add_argument("--timesteps", type=int, default=10_000)
    parser.add_argument("--tag", type=str, default='v0')
    parser.add_argument("--source", action="store_true", default=False)
    # evaluation
    parser.add_argument("--eval_source", action="store_true", default=False)
    parser.add_argument("--eval_short", action="store_true", default=False)
    parser.add_argument("--eval_seen_random", action="store_true", default=False)
    parser.add_argument("--eval_unseen_random", action="store_true", default=False)
    # misc
    # type=int: previously type=str, so an explicit --max_step produced a str
    # while the default stayed the int 500.
    parser.add_argument("--max_step", type=int, default=500)
    parser.add_argument("--terminate", action="store_true", default=False)
    # model
    parser.add_argument("--use_dist", action="store_true", default=False)
    parser.add_argument("--reach_prompt", action="store_true", default=False)
    # domain factors
    parser.add_argument("--domain_factor", type=str, default='GRAV')
    parser.add_argument("--fov", type=str, default='cam0-0')
    parser.add_argument("--xwind_id", type=int, default=0)
    parser.add_argument("--gravity_id", type=int, default=0)
    parser.add_argument("--bright_id", type=int, default=0)
    parser.add_argument("--contrast_id", type=int, default=0)
    parser.add_argument("--saturation_id", type=int, default=0)
    parser.add_argument("--hue_id", type=int, default=0)
    # run control
    # NOTE(review): main() itself seeds with the module-level SEED; --seed
    # only matters if a utils helper reads args.seed — verify.
    parser.add_argument("--seed", type=int, default=777)
    parser.add_argument("--train", action="store_true", default=False)
    parser.add_argument("--eval", action="store_true", default=False)
    parser.add_argument("--expert", action="store_true", default=False)
    parser.add_argument("--render", action="store_true", default=False)
    parser.add_argument("--collect", action="store_true", default=False)
    parser.add_argument("--checkpoint", type=str, default='final')
    args = parser.parse_args(argv)

    # setting prompt env
    args.prompt_env = 'reach-v2' if args.reach_prompt else args.env
    return args


if __name__ == '__main__':
    main(_parse_args())
