
import gym
import numpy as np
import torch
import wandb
from algorithms.a2c import A2C
from algorithms.ppo import PPO
from environments.parallel_envs import make_vec_envs
from models.policy import Policy
from utils import helpers as utl
from vae import VaribadVAE
import argparse
from torch.nn import functional as F

from pathlib import Path
import sys
# Put the directory one level above this script's folder on the import path
# (presumably where the `rlkit` package lives). This must stay before the
# rlkit imports below so they can be resolved from there.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
from rlkit.envs import ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv
# Torch device used for the policy and encoder: first CUDA GPU if available, else CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from config.mujoco import args_cheetah_dir_varibad, args_cheetah_vel_varibad, args_ant_dir_varibad, \
                          args_ant_goal_varibad, args_walker_varibad, args_hopper_varibad

class MetaLearner:
    """
    Meta-test runner for a pretrained variBAD agent.

    Builds the rlkit environment fixed to a single test task and restores the
    pretrained policy and VAE encoder from ``varibad_policy/<env>/``. It then
    rolls the policy out with deterministic actions, printing per-episode
    returns and logging them to wandb (unless ``args.debug`` is set). No
    training happens here.
    """

    # args.env_name -> (rlkit ENVS key, task_dim, env id of the checkpoint folder).
    # cheetah-dir is evaluated with the policy trained on cheetah-vel
    # (hence the "cheetah-vel -> cheetah-dir" wandb project name).
    _ENV_CONFIGS = {
        'HalfCheetahVel-v0': ('cheetah-vel', 1, 'HalfCheetahVel-v0'),
        'HalfCheetahDir-v0': ('cheetah-dir', 1, 'HalfCheetahVel-v0'),
        'AntDir2D-v0': ('ant-dir', 1, 'AntDir2D-v0'),
        'AntGoal-v0': ('ant-goal', 2, 'AntGoal-v0'),
        'Walker2DRandParams-v0': ('walker-rand-params', 65, 'Walker2DRandParams-v0'),
        'HopperRandParams-v0': ('hopper-rand-params', 41, 'HopperRandParams-v0'),
    }
    # maximum number of env steps per evaluation episode
    _MAX_EPISODE_STEPS = 200
    # evaluation stops after the episode in which this many env steps have been taken
    _TOTAL_EVAL_STEPS = 200000

    def __init__(self, args):
        """
        Set up the env, logging, VAE and policy, then load the pretrained weights.

        Args:
            args: config namespace from one of the ``config.mujoco`` modules.
                This method also writes ``task_dim``, ``max_trajectory_len``,
                ``state_dim``, ``belief_dim``, ``num_states`` and ``action_dim``
                onto it.

        Raises:
            ValueError: if ``args.env_name`` is not a key of ``_ENV_CONFIGS``.
        """
        self.args = args
        env_id = self.args.env_name
        if env_id not in self._ENV_CONFIGS:
            raise ValueError(
                f'Unsupported env_name {env_id!r}; expected one of {sorted(self._ENV_CONFIGS)}'
            )
        # self.env_name is the short name used for the ENVS lookup and wandb logging;
        # checkpoint_env names the folder whose pretrained weights are loaded.
        self.env_name, task_dim, checkpoint_env = self._ENV_CONFIGS[env_id]
        self.args.task_dim = task_dim

        self.env = NormalizedBoxEnv(ENVS[self.env_name]())
        self._set_test_task()
        self.env.set_seed(self.args.seed)

        self._init_wandb()
        utl.seed(self.args.seed, self.args.deterministic_execution)

        # number of updates is passed to the policy trainer as train_steps
        self.num_updates = int(args.num_frames) // args.policy_num_steps // args.num_processes
        self.iter_idx = -1

        # maximum trajectory length: episode length times rollouts per task
        self.args.max_trajectory_len = self._MAX_EPISODE_STEPS * self.args.max_rollouts_per_task

        # policy input dimensions: env observation plus the done flag appended in eval()
        self.args.state_dim = self.env.observation_space.shape[0] + 1
        self.args.belief_dim = 0
        self.args.num_states = None

        # policy output (action) dimensions
        self.args.action_dim = self.env.action_space.shape[0]

        # initialise VAE and policy, then overwrite them with the pretrained weights
        self.vae = VaribadVAE(self.args, None, lambda: self.iter_idx)
        self.policy = self.initialise_policy()
        self._load_pretrained(checkpoint_env)

        # initialize max return
        self.max_return = -np.inf

    def _set_test_task(self):
        """Fix the environment to the single task used for this meta-test."""
        if self.env_name == 'cheetah-vel':
            self.env.set_velocity(-2)  # set velocity (-2)
        elif self.env_name == 'cheetah-dir':
            self.env.set_direction(-1)  # set direction (backward)
        elif self.env_name == 'ant-goal':
            self.env.set_goal_position(1.5 * np.pi, 3)  # set goal (angle = 1.5 pi, radius = 3)
        elif self.env_name == 'ant-dir':
            self.env.set_direction(1.5 * np.pi)  # set direction (angle = 1.5 pi)
        elif 'params' in self.env_name:
            self.env.set_test_task()

    def _init_wandb(self):
        """Start the wandb run for this meta-test; skipped when args.debug is set."""
        if self.args.debug:
            return
        if self.env_name == 'cheetah-dir':
            # cheetah-dir is evaluated with the cheetah-vel policy
            project = 'Meta Test cheetah-vel -> cheetah-dir'
        else:
            project = f'Meta Test {self.env_name}'
        wandb.init(project=project,
                   name=f'VariBad ({self.args.seed})',
                   group='VariBad')

    def _load_pretrained(self, checkpoint_env):
        """Load the pretrained actor-critic and encoder from varibad_policy/<checkpoint_env>/."""
        # NOTE(review): both checkpoints appear to be whole pickled module objects
        # (the policy is used directly, the encoder via .state_dict()), not state
        # dicts. torch.load therefore unpickles arbitrary code, so only load trusted
        # files. Newer PyTorch defaults to weights_only=True, which rejects such
        # files; pass weights_only=False there if loading fails.
        policy_device = next(self.policy.actor_critic.parameters()).device
        loaded = torch.load(f'varibad_policy/{checkpoint_env}/policy.pt', map_location=policy_device)
        self.policy.actor_critic = loaded.to(device)

        encoder_device = next(self.vae.encoder.parameters()).device
        loaded = torch.load(f'varibad_policy/{checkpoint_env}/encoder.pt', map_location=encoder_device)
        self.vae.encoder.load_state_dict(loaded.state_dict())

    def initialise_policy(self):
        """
        Build the policy network and wrap it in the configured trainer.

        Returns:
            An ``A2C`` or ``PPO`` trainer around a freshly built ``Policy``. Its
            ``actor_critic`` is later replaced by the pretrained network in
            ``__init__``.

        Raises:
            NotImplementedError: if ``args.policy`` is neither 'a2c' nor 'ppo'.
        """
        policy_net = Policy(
            args=self.args,
            pass_state_to_policy=self.args.pass_state_to_policy,
            pass_latent_to_policy=self.args.pass_latent_to_policy,
            pass_belief_to_policy=self.args.pass_belief_to_policy,
            pass_task_to_policy=self.args.pass_task_to_policy,
            dim_state=self.args.state_dim,
            dim_latent=self.args.latent_dim * 2,  # latent mean and logvar are concatenated
            dim_belief=self.args.belief_dim,
            dim_task=self.args.task_dim,
            #
            hidden_layers=self.args.policy_layers,
            activation_function=self.args.policy_activation_function,
            policy_initialisation=self.args.policy_initialisation,
            #
            action_space=self.env.action_space,
            init_std=self.args.policy_init_std,
        ).to(device)

        # keyword arguments shared by both trainers
        trainer_kwargs = dict(
            policy_optimiser=self.args.policy_optimiser,
            policy_anneal_lr=self.args.policy_anneal_lr,
            train_steps=self.num_updates,
            optimiser_vae=self.vae.optimiser_vae,
            lr=self.args.lr_policy,
            eps=self.args.policy_eps,
        )
        if self.args.policy == 'a2c':
            return A2C(
                self.args,
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                **trainer_kwargs,
            )
        if self.args.policy == 'ppo':
            return PPO(
                self.args,
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                ppo_epoch=self.args.ppo_num_epochs,
                num_mini_batch=self.args.ppo_num_minibatch,
                use_huber_loss=self.args.ppo_use_huberloss,
                use_clipped_value_loss=self.args.ppo_use_clipped_value_loss,
                clip_param=self.args.ppo_clip_param,
                **trainer_kwargs,
            )
        raise NotImplementedError(f'Unknown policy type {self.args.policy!r}')

    @staticmethod
    def _get_latent_for_policy(args, latent_sample=None, latent_mean=None, latent_logvar=None):
        """
        Build the latent input for the policy from the encoder outputs.

        Returns None when no latent tensors are given. Otherwise returns the
        squeezed sampled latent (if ``args.sample_embeddings``) or the squeezed
        concatenation of mean and log-variance along the last dim. The inputs are
        first passed through ReLU when ``args.add_nonlinearity_to_latent`` is set.
        """
        if latent_sample is None and latent_mean is None and latent_logvar is None:
            return None

        if args.add_nonlinearity_to_latent:
            latent_sample = F.relu(latent_sample)
            latent_mean = F.relu(latent_mean)
            latent_logvar = F.relu(latent_logvar)
        if args.sample_embeddings:
            latent = latent_sample
        else:
            latent = torch.cat((latent_mean, latent_logvar), dim=-1)
        return latent.squeeze()

    @staticmethod
    def _squash_action(action, args):
        """Apply tanh to the action when ``args.norm_actions_post_sampling`` is set."""
        if args.norm_actions_post_sampling:
            return torch.tanh(action)
        return action

    @staticmethod
    def _update_encoding(encoder, next_obs, action, reward, done, hidden_state):
        """
        Feed one transition to the recurrent encoder and return its new outputs.

        If ``done`` is not None, the hidden state is first reset via
        ``encoder.reset_hidden``. The encoder forward pass runs without gradients.

        Returns:
            (latent_sample, latent_mean, latent_logvar, hidden_state)
        """
        if done is not None:
            hidden_state = encoder.reset_hidden(hidden_state, done)
        with torch.no_grad():
            latent_sample, latent_mean, latent_logvar, hidden_state = encoder(
                actions=action.float(),
                states=next_obs,
                rewards=reward,
                hidden_state=hidden_state,
                return_prior=False,
            )
        return latent_sample, latent_mean, latent_logvar, hidden_state

    def eval(self):
        """
        Roll out the pretrained policy on the fixed test task and report episode returns.

        Runs episodes of at most ``_MAX_EPISODE_STEPS`` steps with deterministic
        actions until at least ``_TOTAL_EVAL_STEPS`` env steps have been taken in
        total. The encoder's hidden state starts at the prior and is never reset
        between episodes. After each episode, the cumulative step count and the
        episode return are printed. Unless ``args.debug`` is set, they are also
        logged to wandb.
        """
        encoder = self.vae.encoder
        if encoder is not None:
            # start from the prior latent state
            latent_sample, latent_mean, latent_logvar, hidden_state = encoder.prior(1)
            latent_sample = latent_sample.squeeze(0)
        else:
            latent_sample = latent_mean = latent_logvar = hidden_state = None

        total_count = 0
        while total_count < self._TOTAL_EVAL_STEPS:
            state = torch.from_numpy(self.env.reset()).float().to(device)
            # append done flag 0 so the state matches args.state_dim (obs dim + 1)
            state = torch.cat((state, torch.zeros(1, device=device)), dim=0)
            episode_return = 0
            for _ in range(self._MAX_EPISODE_STEPS):
                with torch.no_grad():
                    latent = self._get_latent_for_policy(
                        args=self.args,
                        latent_sample=latent_sample,
                        latent_mean=latent_mean,
                        latent_logvar=latent_logvar,
                    )
                    action = self.policy.act(state=state, latent=latent, belief=None,
                                             task=None, deterministic=True)
                    if isinstance(action, (list, tuple)):
                        _value, action = action  # the value estimate is not needed here
                action = self._squash_action(action.to(device), self.args).unsqueeze(0)
                next_obs, reward, done, _infos = self.env.step(action.cpu().detach().numpy()[0])
                # append the done flag, as for the initial state above
                next_obs = np.concatenate((next_obs, [float(done)]))
                next_obs = torch.from_numpy(next_obs).float().to(device).unsqueeze(0)
                reward = torch.from_numpy(np.array([reward])).float().to(device).unsqueeze(0)
                if encoder is not None:
                    # done=None: the hidden state is never reset, so it is carried
                    # across episode boundaries for the whole evaluation
                    latent_sample, latent_mean, latent_logvar, hidden_state = self._update_encoding(
                        encoder=encoder,
                        next_obs=next_obs,
                        action=action,
                        reward=reward,
                        done=None,
                        hidden_state=hidden_state,
                    )
                state = next_obs[0]
                total_count += 1
                episode_return += reward.item()
                if done:
                    break
            print(total_count, episode_return)
            if not self.args.debug:
                wandb.log({'Return avg': episode_return}, step=total_count)
        
def main():
    """
    Entry point for the variBAD meta-test.

    Picks the env config module from ``--env-type``, lets it parse the remaining
    CLI args, and runs the meta-test. An unknown ``--env-type`` is rejected by
    argparse with a usage error (exit status 2). Previously it failed later with
    an AttributeError on ``args.seed``.
    """
    config_modules = {
        'cheetah_dir_varibad': args_cheetah_dir_varibad,
        'cheetah_vel_varibad': args_cheetah_vel_varibad,
        'ant_dir_varibad': args_ant_dir_varibad,
        'ant_goal_varibad': args_ant_goal_varibad,
        'walker_varibad': args_walker_varibad,
        'hopper_varibad': args_hopper_varibad,
    }
    parser = argparse.ArgumentParser()
    parser.add_argument('--env-type', default='cheetah_dir_varibad', choices=list(config_modules))

    args, rest_args = parser.parse_known_args()
    # all remaining CLI args are parsed by the env-specific config module
    args = config_modules[args.env_type].get_args(rest_args)
    # args.seed is indexed here (presumably a list of seeds); only the first one is used
    args.seed = args.seed[0]
    learner = MetaLearner(args)
    learner.eval()


if __name__ == '__main__':
    main()