

from tqdm import tqdm

import click
import os
import yaml
import numpy as np
import torch
import re
import torch.optim as optim
from maml_rl.optimizer import DifferentiableSGD
from rlkit.envs import ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv
import wandb
from maml_rl.trpo import TRPO
from maml_rl.sampler import Sampler
from maml_rl.buffer import Buffer
from meta_test_algo.render_utils import initialize_viewer,render


class launcher():
    """Fine-tune a pretrained MAML policy on one fixed test task.

    The constructor builds the environment, pins it to the test task, restores
    the saved policy weights and wires up the sampler, buffer and optimizer.
    :meth:`learn` then alternates between collecting samples, taking one
    gradient step on the policy loss and evaluating the adapted policy.
    """

    def __init__(self, test_env, test_file, seed, debug):
        """Set up the environment, agent and helpers for a meta-test run.

        Args:
            test_env: Key into ``ENVS``; also the stem of the YAML config read
                from ``maml_rl/configs``.
            test_file: Number identifying the checkpoint
                ``maml_policy(<test_file>).pt`` to restore.
            seed: Seed applied to NumPy, PyTorch and the environment.
            debug: When truthy, Weights & Biases is neither initialised nor
                logged to.
        """
        self.test_env = test_env
        config_path = os.path.join("maml_rl/configs", f"{test_env}.yaml")
        with open(config_path, "r", encoding="utf-8") as file:
            # NOTE(review): FullLoader can construct more than plain data; if
            # the configs only hold scalars/lists/dicts, yaml.safe_load would
            # be the safer choice.
            env_config = yaml.load(file, Loader=yaml.FullLoader)
        maml_params = env_config["maml_params"]

        self.debug = debug
        # The environment is created before the global seeds are set, matching
        # the original construction order.
        self.env = NormalizedBoxEnv(ENVS[test_env]())
        np.random.seed(seed)
        torch.manual_seed(seed)

        # Pin the environment to the task used for meta-testing.
        if self.test_env == 'cheetah-vel':
            self.env.set_velocity(-2)  # target velocity of -2
        elif self.test_env == 'cheetah-dir':
            self.env.set_direction(-1)  # backward direction
        elif self.test_env == 'ant-goal':
            self.env.set_goal_position(1.5 * np.pi, 3)  # angle = 1.5 pi, radius = 3
        elif self.test_env == 'ant-dir':
            self.env.set_direction(1.5 * np.pi)  # angle = 1.5 pi
        elif 'params' in self.test_env:
            self.env.set_test_task()
        self.env.set_seed(seed)

        observ_dim: int = self.env.observation_space.shape[0]
        action_dim: int = self.env.action_space.shape[0]

        self.num_samples = maml_params["num_samples"]
        self.meta_batch_size = maml_params["meta_batch_size"]
        self.max_steps = maml_params["max_steps"]

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.agent = TRPO(observ_dim=observ_dim,
                          action_dim=action_dim,
                          policy_hidden_dim=env_config["policy_hidden_dim"],
                          vf_hidden_dim=env_config["value_function_hidden_dim"],
                          device=device,
                          **env_config["pg_params"])

        # cheetah-dir is tested with the policy that was saved for cheetah-vel.
        checkpoint_env = 'cheetah-vel' if test_env == 'cheetah-dir' else test_env
        checkpoint_path = f'./maml_policy/{checkpoint_env}/maml_policy({test_file}).pt'
        # map_location lets a checkpoint saved from a CUDA device be restored
        # on a CPU-only host (and vice versa).
        pretrained_policy_params = torch.load(checkpoint_path, map_location=device)
        self.agent.policy.load_state_dict(pretrained_policy_params)

        self.optimizer = DifferentiableSGD(self.agent.policy,
                                           lr=maml_params["inner_learning_rate"],
                                           maml=False)
        self.sampler = Sampler(env=self.env,
                               agent=self.agent,
                               action_dim=action_dim,
                               max_step=self.max_steps,
                               device=device)
        self.buffer = Buffer(agent=self.agent,
                             observ_dim=observ_dim,
                             action_dim=action_dim,
                             max_size=self.num_samples,
                             device=device)

        print(f'MAML policy({test_file}) is loaded.')

        if not self.debug:
            if self.test_env == 'cheetah-dir':
                project = 'Meta Test cheetah-vel -> cheetah-dir'
            else:
                project = f'Meta Test {self.test_env}'
            wandb.init(project=project,
                       name=f'MAML({seed})',
                       group='MAML')

    def _log_return(self, ret, step):
        """Send the evaluation return to Weights & Biases unless debugging."""
        if not self.debug:
            wandb.log({"Task mean return": ret}, step=step)

    def learn(self, max_total_steps=200000):
        """Adapt the policy until ``max_total_steps`` transitions are collected.

        The policy is evaluated once before any update and again after every
        gradient step; each result is printed and, outside debug mode, logged.

        Args:
            max_total_steps: Number of collected environment transitions after
                which adaptation stops.

        Raises:
            RuntimeError: If the sampler returns no transitions, because the
                step counter could then never reach the budget.
        """
        ret = self.eval(self.agent)
        self._log_return(ret, step=0)

        total_steps = 0
        while total_steps < max_total_steps:
            # Evaluation leaves the policy deterministic; switch it back
            # before collecting training samples.
            self.agent.policy.is_deterministic = False
            trajs = self.sampler.obtain_samples(max_samples=self.num_samples)
            steps = sum(len(traj['cur_obs']) for traj in trajs)
            if steps == 0:
                raise RuntimeError("Sampler returned no transitions; cannot make progress.")

            self.buffer.add_task_trajs(trajs)
            train_batch = self.buffer.get_task_trajs()
            inner_loss = self.agent.policy_loss(train_batch)
            total_steps += steps

            self.optimizer.zero_grad()
            inner_loss.backward()
            self.optimizer.step()

            self.buffer.clear_task()
            ret = self.eval(self.agent)
            self._log_return(ret, step=total_steps)
            print('steps: ',total_steps, 'Return: ',ret)

    def eval(self, agent, n_eval_epi=10):
        """Return the mean undiscounted episode return of ``agent``.

        Runs ``n_eval_epi`` episodes in ``self.env`` with the policy switched to
        deterministic mode (the flag is left set on return).  An episode ends
        when the environment reports ``done`` or after ``self.max_steps`` steps.

        Args:
            agent: Object exposing ``policy.is_deterministic`` and
                ``get_action(obs)``.
            n_eval_epi: Number of evaluation episodes; must be positive.

        Returns:
            Sum of all rewards divided by ``n_eval_epi``.
        """
        with torch.no_grad():
            agent.policy.is_deterministic = True
            # Float start value so the in-place additions below never need to
            # cast a float reward into an integer accumulator.
            rewards = 0.0

            # To visualise rollouts, call initialize_viewer(self.env, 'human')
            # here and render(self.env, self.test_env, 'human') before each step.
            for _ in range(n_eval_epi):
                obs = self.env.reset()
                done = False
                cur_step = 0
                while not done and cur_step != self.max_steps:
                    action = agent.get_action(obs)
                    next_obs, reward, done, info = self.env.step(action)
                    done = bool(done)
                    rewards += np.array(reward)

                    obs = next_obs
                    cur_step += 1
            rewards /= n_eval_epi

        return rewards
  

@click.command()
@click.option('--test_env', required=True)
@click.option('--file_number',default=None)
@click.option('--seed',default=0)
@click.option('--debug', is_flag=True, default=False)
def main(test_env, file_number, seed, debug):
    """Run meta-testing for ``test_env`` from a saved MAML policy checkpoint.

    When ``--file_number`` is omitted, the checkpoint directory is scanned for
    files named ``maml_policy(<n>).pt`` and the one with the largest ``<n>`` is
    used.  ``cheetah-dir`` reads its checkpoints from the ``cheetah-vel``
    directory.

    Raises:
        FileNotFoundError: If no file number was given and the directory holds
            no matching checkpoint.
    """
    if file_number is not None:
        test_file = file_number
    else:
        if test_env == 'cheetah-dir':
            path = './maml_policy/cheetah-vel'
        else:
            path = f'./maml_policy/{test_env}'

        pattern = re.compile(r'maml_policy\((\d+)\)\.pt')
        matches = (pattern.fullmatch(filename) for filename in os.listdir(path))
        test_file = max((int(m.group(1)) for m in matches if m), default=None)
        if test_file is None:
            raise FileNotFoundError("There is no file to read.")

    launch = launcher(test_env, test_file, seed, debug)
    launch.learn()

# Run the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()