import os
import numpy as np
import click
import json
import random
import torch
import torch.optim as optim

from rlkit.envs import ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv
from meta_test_algo.ppo import PPOAgent
from meta_test_algo.pg import policy_gradient
from meta_test_algo.es import EvolutionStrategies
from meta_test_algo.es_cem import EvolutionStrategies_CEM
from meta_test_algo.es_mirror2 import EvolutionStrategies_MIRROR2
from meta_test_algo.es_mirror2_q import EvolutionStrategies_MIRROR2_Q
from meta_test_algo.esq import ESQ
from meta_test_algo.ddpg import DDPG
from meta_test_algo.q_finetuning import Q_finetuning
from meta_test_algo.es_mirror2_nonlinear import EvolutionStrategies_MIRROR2_nonlinear
from meta_test_algo.es_cma2 import EvolutionStrategies_CMA
from meta_test_algo.es_full2 import EvolutionStrategies_FULL
from configs.default import default_config

import wandb
from tqdm import tqdm
import re

class launcher():
    """Meta-test driver.

    Builds the environment, fixes the test task, seeds every RNG, creates one
    agent per evaluation task (optionally loading a shared, frozen backbone)
    and, unless ``debug`` is set, opens a wandb run.  ``learn`` then
    alternates data collection/training and evaluation for every task.
    """

    # Per-task environment-step budget after which ``learn`` stops training.
    MAX_STEPS_PER_TASK = 200000

    def __init__(self,test_env,train_env,test_file,variant,seed,
                 debug=False,
                 render=False,
                 scratch=False,
                 action_noise=False,
                 no_backbone=False):
        """Set up environment, seeds, hyper-parameters, agents and logging.

        Args:
            test_env: name of the environment/task family evaluated on.
            train_env: name of the environment the shared policy was trained on.
            test_file: checkpoint number passed to ``load_shared_policy``.
            variant: nested configuration dict.
            seed: seed applied to ``random``, NumPy, torch and the environment.
            debug: when true, wandb is neither initialised nor logged to.
            render: forwarded to ``agent.evaluation`` as ``rendering``.
            scratch: when true, no checkpoint is loaded and the shared layer
                is left trainable.
            action_noise: forwarded to the ``esq`` agent.
            no_backbone: when true, the checkpoint is read but not loaded into
                the agent.

        Raises:
            ValueError: if ``variant['meta_test_params']['method']`` is not a
                known method (raised while building the first agent).
        """
        self.env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))
        self._configure_task(test_env, train_env)

        # set seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        self.env.set_seed(seed)

        self.tasks = self.env.get_all_task_idx()
        # the last ``n_eval_tasks`` task indices are the evaluation tasks
        self.eval_tasks = list(self.tasks[-variant['n_eval_tasks']:])
        self.obs_dim = int(np.prod(self.env.observation_space.shape))
        self.action_dim = int(np.prod(self.env.action_space.shape))
        self.test_env = test_env
        self.train_env = train_env
        self.test_file = test_file

        # device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # meta test parameters
        self.method = variant['meta_test_params']['method']
        self.no_pretrained = variant['meta_test_params']['no_pretrained']

        self.total_step = variant['meta_test_params']['total_step']
        self.eval_step = variant['sr_params']['eval_step']
        # 'es_mirror_q' reads its ES hyper-parameters from a dedicated section
        es_section = 'es_q_params' if self.method == 'es_mirror_q' else 'es_params'
        self.n_rollouts = variant[es_section]['n_rollouts']
        self.noise_sigma = variant[es_section]['noise_sigma']
        self.lr = variant[es_section]['lr']
        self.elite_frac = variant[es_section]['elite_frac']

        # instantiate networks
        self.latent_action_dim = variant['sr_params']['latent_action_dim']
        self.net_size = variant['net_size']
        self.agents = []

        self.debug = debug
        self.render = render
        self.scratch = scratch
        self.action_noise = action_noise
        # one independent agent per evaluation task
        for _ in self.eval_tasks:
            agent = self._build_agent(variant)
            self._prepare_shared_layer(agent, no_backbone)
            self.agents.append(agent)

        self.exp_noise_sigma = variant['esq_params']['exp_noise_sigma']
        self.w = variant['esq_params']['w']
        if not self.debug:
            self._init_wandb(seed, action_noise, no_backbone)

    def _configure_task(self, test_env, train_env):
        """Fix the environment's test task.

        The task depends on whether evaluation happens on the training
        environment (``test_env == train_env``) or on a different one.
        """
        same_env = test_env == train_env
        # 1.5 pi when testing on the training environment, 0.5 pi otherwise
        angle = (1.5 if same_env else 0.5)*np.pi
        if test_env == 'cheetah-vel':
            self.env.set_velocity(-2.0 if same_env else +2.0) # velocity -2 / +2
        elif test_env == 'cheetah-dir':
            self.env.set_direction(+1 if same_env else -1) # forward / backward
        elif test_env == 'ant-goal':
            self.env.set_goal_position(angle,3) # goal at (angle, radius = 3)
        elif test_env in ('ant-dir', 'humanoid-dir'):
            self.env.set_direction(angle)
        elif 'params' in test_env:
            self.env.set_test_task()

    def _build_agent(self, variant):
        """Create one agent for ``self.method``.

        Raises:
            ValueError: if ``self.method`` is not one of the known methods.
        """
        common = dict(obs_dim=self.obs_dim,
                      action_dim=self.action_dim,
                      net_size=self.net_size,
                      latent_action_dim=self.latent_action_dim,
                      device=self.device)
        meta = variant['meta_test_params']
        method = self.method

        # agents that only take the common arguments
        plain = {'pg': policy_gradient,
                 'ppo': PPOAgent}
        # agents configured from variant['es_params']
        es_family = {'es': EvolutionStrategies,
                     'es_cem': EvolutionStrategies_CEM,
                     'es_mirror': EvolutionStrategies_MIRROR2,
                     'es_mirror_nonlinear': EvolutionStrategies_MIRROR2_nonlinear,
                     'es_full': EvolutionStrategies_FULL,
                     'es_cma': EvolutionStrategies_CMA}
        # agents configured from variant['es_q_params']
        es_q_family = {'es_mirror_q': EvolutionStrategies_MIRROR2_Q,
                       'q_finetuning': Q_finetuning}

        if method in plain:
            return plain[method](**common, **meta)
        if method in es_family:
            return es_family[method](es_params=variant['es_params'], **common, **meta)
        if method in es_q_family:
            return es_q_family[method](es_params=variant['es_q_params'], **common, **meta)
        if method == 'esq':
            return ESQ(esq_params=variant['esq_params'],
                       scratch=self.scratch,
                       action_noise=self.action_noise,
                       **common, **meta)
        if method == 'ddpg':
            return DDPG(esq_params=variant['esq_params'],
                        scratch=self.scratch,
                        **common, **meta)
        # previously an unknown method left ``agent`` unbound
        raise ValueError(f"Unknown meta-test method: {method!r}")

    def _prepare_shared_layer(self, agent, no_backbone):
        """Load (unless training from scratch) and freeze/unfreeze the shared layer."""
        if self.scratch:
            trainable = True
        else:
            # The checkpoint is read once per agent, as before; whether agents
            # may safely share one loaded object is not visible from here.
            self.shared_param = self.load_shared_policy(self.test_file)
            if not no_backbone:
                agent.load_shared_network(self.shared_param)
            trainable = False
        for param in agent.policy.shared_layer.parameters():
            param.requires_grad = trainable

    def _init_wandb(self, seed, action_noise, no_backbone):
        """Open the wandb run; project/name/group encode the experiment setup."""
        if self.test_env == self.train_env:
            project = f'Meta Test {self.test_env}'
        else:
            project = f'Meta Test {self.train_env} -> {self.test_env}'
        if no_backbone:
            name = f'(no backbone seed: {seed})'
            group = 'no backbone'
        else:
            name = f'(from scratch: {self.scratch}, action noise: {action_noise}, seed: {seed})'
            group = f'exp noise sigma: {self.exp_noise_sigma}, w: {self.w}'
        wandb.init(project=project, name=name, group=group)

    def load_shared_policy(self,epoch=None):
        """Load a shared-policy checkpoint for ``self.train_env``.

        Args:
            epoch: checkpoint number; when ``None`` the highest-numbered
                ``shared_policy(N).pt`` in the checkpoint directory is used.

        Returns:
            Whatever ``torch.load`` returns for the checkpoint file.

        Raises:
            FileNotFoundError: if ``epoch`` is ``None`` and the directory
                holds no ``shared_policy(N).pt`` file (or does not exist).
        """
        # 'cheetah-dir' reads its checkpoints from the 'cheetah-vel' directory
        env_dir = 'cheetah-vel' if self.train_env == 'cheetah-dir' else self.train_env
        path = f'./shared_policy/{env_dir}'
        if epoch is None:
            # scan the directory the checkpoint is actually read from
            pattern = re.compile(r'shared_policy\((\d+)\)\.pt')
            numbers = [int(match.group(1))
                       for match in map(pattern.fullmatch, os.listdir(path))
                       if match]
            if not numbers:
                raise FileNotFoundError("There is no file to read.")
            epoch = max(numbers)
        shared_policy_params = torch.load(f'{path}/shared_policy({epoch}).pt')
        print('Shared policy ({}) is loaded.'.format(epoch))
        return shared_policy_params

    def learn(self):
        """Adapt every agent to its task, evaluating after each training step.

        Logs the pre-training returns at step 0, then runs
        ``total_step / max_path_length`` iterations.  In each iteration the
        tasks are visited in order until one has exceeded
        ``MAX_STEPS_PER_TASK`` environment steps; the loop ends early once no
        task was trained in an iteration.
        """
        steps = np.zeros(len(self.eval_tasks),dtype=int)
        iters = int(self.total_step/(self.agents[0].max_path_length))

        initial_returns = self.eval()
        mean_ret = np.array(initial_returns).mean()
        if not self.debug:
            wandb.log({"Task mean return":mean_ret},step=0)
            for i, task in enumerate(self.eval_tasks):
                wandb.log({f"Task {task} returns":initial_returns[i]},step=0)

        for _ in tqdm(range(iters)):
            returns = []
            max_ret = -np.inf
            for idx,task in enumerate(self.eval_tasks):
                if steps[idx] > self.MAX_STEPS_PER_TASK:
                    break
                agent = self.agents[idx]
                self.env.reset_task(task)
                env_steps = agent.collet_data_and_train_filter(self.env)
                steps[idx] += env_steps

                ret = agent.evaluation(self.env,self.test_env,rendering=self.render)
                print(ret)
                # after the budget is used up, re-evaluate the best task so
                # far with 'rgb_array' rendering
                if (self.render
                        and steps[idx] > self.MAX_STEPS_PER_TASK
                        and ret > max_ret
                        and not self.debug):
                    print('rendering....')
                    ret = agent.evaluation(self.env,self.test_env,rendering=self.render,render_type='rgb_array')
                max_ret = max(max_ret,ret)
                returns.append(ret)

                if not self.debug:
                    wandb.log({f"Task {task} returns":ret},step=steps[idx])
            if not returns:
                # The first task is over budget, so no later iteration can
                # train anything; stop instead of logging a NaN mean.
                break
            if not self.debug:
                wandb.log({"Task mean return":np.array(returns).mean()},step=steps[idx])

    def eval(self):
        """Evaluate each agent once on its own task and return the list of returns."""
        returns = []
        for idx, task in enumerate(self.eval_tasks):
            agent = self.agents[idx]
            self.env.reset_task(task)
            ret = agent.evaluation(self.env,self.test_env,rendering=self.render)
            returns.append(ret)
        return returns


def deep_update_dict(fr, to):
    ''' Recursively merge ``fr`` into ``to`` in place and return ``to``.

    Nested dicts present on both sides are merged key by key; every other
    value of ``fr`` overwrites (or creates) the entry in ``to``.  A nested
    dict whose key is missing from ``to``, or whose counterpart in ``to`` is
    not a dict, is assigned as-is (not copied) instead of raising.
    '''
    for k, v in fr.items():
        if isinstance(v, dict) and isinstance(to.get(k), dict):
            deep_update_dict(v, to[k])
        else:
            to[k] = v
    return to

def _latest_shared_policy_number(path):
    """Return the largest N among files named ``shared_policy(N).pt`` in *path*."""
    pattern = re.compile(r'shared_policy\((\d+)\)\.pt')
    numbers = [int(match.group(1))
               for match in map(pattern.fullmatch, os.listdir(path))
               if match]
    if not numbers:
        raise Exception("There is no file to read.")
    return max(numbers)


@click.command()
@click.option('--test_env',default=None)
@click.option('--train_env',default=None)
@click.option('--file_number',default=None)
@click.option('--seed',default=0)
@click.option('--debug', is_flag=True, default=False)
@click.option('--render', is_flag=True, default=False)
@click.option('--scratch', is_flag=True, default=False)
@click.option('--action_noise',is_flag=True, default=False)
@click.option('--exp_noise_sigma',default=None)
@click.option('--w',default=None)
@click.option('--no_backbone',is_flag=True,default=False)
def main(test_env,train_env,file_number,seed,debug,render,scratch,action_noise,exp_noise_sigma,w,no_backbone):
    """Build the experiment configuration and run the meta-test.

    The default configuration is overlaid with the JSON file belonging to
    ``--test_env`` (if one is known) and with the ``--exp_noise_sigma`` /
    ``--w`` overrides.  ``--train_env`` defaults to ``--test_env``.  Unless
    ``--file_number`` is given or ``--scratch`` is set, the highest-numbered
    shared-policy checkpoint of the training environment is used.
    """
    if test_env is None:
        # previously this surfaced as an UnboundLocalError on ``config``
        raise click.UsageError('--test_env is required.')

    config_files = {
        'cheetah-dir': './configs/cheetah-dir.json',
        'cheetah-vel': './configs/cheetah-vel.json',
        'ant-goal': './configs/ant-goal.json',
        'ant-dir': './configs/ant-dir.json',
        'humanoid-dir': './configs/humanoid-dir.json',
        'hopper-rand-params': './configs/hopper_rand_params.json',
        'walker-rand-params': './configs/walker_rand_params.json',
    }

    # NOTE: default_config is updated in place, exactly as before.
    variant = default_config
    config = config_files.get(test_env)
    if config:
        with open(config, encoding='utf-8') as f:
            exp_params = json.load(f)
        variant = deep_update_dict(exp_params, variant)
    # command-line overrides apply whether or not a config file was found
    if exp_noise_sigma is not None:
        variant['esq_params']['exp_noise_sigma'] = float(exp_noise_sigma)
    if w is not None:
        variant['esq_params']['w'] = float(w)

    if train_env is None:
        train_env = test_env

    if file_number is not None:
        test_file = file_number
    elif scratch:
        # training from scratch never reads a checkpoint
        test_file = None
    else:
        # Same directory rule as launcher.load_shared_policy: keyed on the
        # training environment, with cheetah-dir reading cheetah-vel's files.
        checkpoint_env = 'cheetah-vel' if train_env == 'cheetah-dir' else train_env
        test_file = _latest_shared_policy_number(f'./shared_policy/{checkpoint_env}')

    launch = launcher(test_env,train_env,test_file,variant,seed,
                      debug=debug,
                      render=render,
                      scratch=scratch,
                      action_noise=action_noise,
                      no_backbone=no_backbone)
    launch.learn()

# Script entry point: main is a click command, so calling it without
# arguments makes click parse the command-line options.
if __name__ == "__main__":
    main()