from typing import List, Callable, Union, Any, TypeVar, Tuple, overload
import yaml
import sys
import os
from datetime import datetime
from abc import abstractmethod
import argparse

import wandb
import numpy as np

import torch as th
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import stable_baselines3 as sb3
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.logger import configure

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from model.vae.encoder import Encoder
from model.vae.decoder import Decoder
from model.reward import Reward
from model.discriminator import Discriminator
from model.sac.sac import SACCustomReward

import wrapper
from buffer import ExpertBuffer, LearnerBuffer
import env_utils
from utils import *


class TrainTraIRLTransferBase():
    """Shared setup and outer training loop for transferring a source model to a target env.

    The constructor builds the source/target encoders and decoders, the reward
    and discriminator networks, the target vectorised environment, the SAC
    policy, the loggers and the expert/learner buffers, and then loads the
    pretrained source weights via ``load_source``.

    Subclasses must implement ``update_target_vae``.

    NOTE(review): ``@abstractmethod`` is not enforced here because the class
    does not derive from ``abc.ABC``; the ``NotImplementedError`` in the base
    method is what actually guards against a missing override.
    """

    def __init__(self, config: dict) -> None:
        """Build every model, environment, logger and buffer described by ``config``.

        Args:
            config: Experiment configuration. All keys read below are required,
                except ``seed`` which is optional and defaults to 1234.
        """
        self.source_env_name = config['source_base_env_id']
        self.target_env_name = config['target_base_env_id']
        self.config = config

        device = self.config['device']
        # Optional key; 1234 is the value that used to be hard-coded.
        seed = self.config.get('seed', 1234)

        # init encoders (only the target encoder gets an optimizer)
        self.source_encoder = Encoder(in_dim=self.config['source_state_dim'], out_dim=self.config['abstraction_dim'],
                                      hidden_dims=self.config['source_encoder_hidden_dims'], device=device).to(device)
        self.target_encoder = Encoder(in_dim=self.config['target_state_dim'], out_dim=self.config['abstraction_dim'],
                                      hidden_dims=self.config['target_encoder_hidden_dims'], device=device).to(device)
        self.target_encoder_optimizer = Adam(self.target_encoder.parameters(), self.config['target_encoder_lr'])

        # init decoders: one for the source env and one for the target env
        # (only the target decoder gets an optimizer)
        self.source_decoder = Decoder(in_dim=self.config['abstraction_dim'], out_dim=self.config['source_state_dim'],
                                      hidden_dims=self.config['source_decoder_hidden_dims'], device=device).to(device)
        self.target_decoder = Decoder(in_dim=self.config['abstraction_dim'], out_dim=self.config['target_state_dim'],
                                      hidden_dims=self.config['target_decoder_hidden_dims'], device=device).to(device)
        self.target_decoder_optimizer = Adam(self.target_decoder.parameters(), self.config['target_decoder_lr'])

        # init reward net, there is only one reward net; it is built with
        # use_encoder=True here
        self.reward_net = Reward(self.config['reward_in_dim'], self.config['reward_hidden_dims'],
                                 current_obs_only=self.config['current_obs_only'], use_encoder=True,
                                 device=device).to(device)

        # init disc net, there is only one disc net
        self.disc_net = Discriminator(self.config['disc_in_dim'], self.config['disc_hidden_dims'],
                                      self.config['current_obs_only']).to(device)

        # The wrapper always receives the reward net and target encoder;
        # config-provided wrapper kwargs are merged on top (and win on clashes).
        target_env_wrapper_kwargs_dict = {
            'reward_net': self.reward_net,
            'encoder': self.target_encoder,
        }
        target_env_wrapper_kwargs_dict.update(self.config['target_env_wrapper_kwargs'])

        wrapper_name = self.config['target_env_wrapper']
        self.target_env = make_vec_env(env_id=self.config['target_env_id'], n_envs=self.config['target_env_n_envs'], seed=seed,
                                       env_kwargs=self.config['target_env_kwargs'],
                                       wrapper_class=getattr(wrapper, wrapper_name) if wrapper_name else None,
                                       wrapper_kwargs=target_env_wrapper_kwargs_dict)

        # Separate env used by update_learner_buffer. The init function gets
        # this (still partially constructed) trainer as its second argument.
        self.target_envs_sample = getattr(env_utils, self.config['target_env_init_func'])(
            self.target_env_name, self, **self.config['target_env_init_func_kwargs'])

        action_dim = self.config['target_env_action_dim']
        action_noise = getattr(sb3.common.noise, self.config['action_noise']['type'])(
            mean=np.zeros(action_dim),
            sigma=np.ones(action_dim) * self.config['action_noise']['std'],
        )

        self.target_policy = SACCustomReward(policy=self.config['policy_type'], env=self.target_env,
                                             policy_kwargs=self.config['policy_kwargs'],
                                             learning_rate=self.config['policy_lr'], action_noise=action_noise,
                                             reward_net=self.reward_net,
                                             encoder=self.target_encoder,
                                             reward_use_encoder=True,
                                             verbose=1,
                                             stats_window_size=20,
                                             learning_starts=self.config['policy_learning_starts'],
                                             tau=self.config['policy_tau'],
                                             buffer_size=self.config['policy_buffer_size'],
                                             gamma=self.config['policy_gamma'],
                                             seed=seed,)

        wandb_id = f'few_shot_{datetime.now().strftime("%Y_%m_%d_%H_%M_%S")}'
        logger_path = f"runs/trairl_few_shot/{config['target_env_id']}"
        wandb.tensorboard.patch(root_logdir=f'{logger_path}/{wandb_id}/log')

        # Start a new wandb run to track this script.
        self.wandb = wandb.init(
            project="TraIRL",
            config=config,
            sync_tensorboard=True,
            save_code=True,
            id=wandb_id
        )

        new_logger = configure(f'{logger_path}/{wandb_id}/log/sb3/{self.target_env_name}', ["stdout", "csv", "tensorboard", "log"])
        self.target_policy.set_logger(new_logger)

        self.run_dir = f'{logger_path}/{wandb_id}'
        self.writer = SummaryWriter(f'{self.run_dir}/log/tb', flush_secs=30)
        self.log_dict = {}

        # init buffers
        self.expert_buffer = ExpertBuffer(self.config['target_expert_file'], device=device)
        self.learner_buffer = LearnerBuffer(self.config['learner_buffer_size'], obs_dim=self.config['target_state_dim'],
                                            action_dim=self.config['target_env_action_dim'])

        self.load_source()

    def set_policy_training(self, mode: bool = True) -> None:
        """Switch the target policy between training (True) and evaluation (False) mode."""
        self.target_policy.policy.set_training_mode(mode)

    def update_policies(self) -> None:
        """Run ``policy_update_steps`` more policy learning steps without resetting the timestep counter."""
        self.target_policy.learn(int(self.config['policy_update_steps']), reset_num_timesteps=False)

    def save_target(self) -> None:
        """Save the policy and all target-side networks under ``<run_dir>/saved_model/<timestep>/``."""
        timestep = self.target_policy.num_timesteps
        save_folder = f'{self.run_dir}/saved_model/{timestep}'
        os.makedirs(save_folder, exist_ok=True)

        self.target_policy.save(f'{save_folder}/policy')
        th.save(self.target_encoder.state_dict(), f'{save_folder}/encoder.pth')
        th.save(self.target_decoder.state_dict(), f'{save_folder}/decoder.pth')
        th.save(self.reward_net.state_dict(), f'{save_folder}/reward_net.pth')
        th.save(self.disc_net.state_dict(), f'{save_folder}/disc_net.pth')

    def load_target(self, path: str) -> None:
        """Restore a checkpoint folder written by ``save_target``.

        Args:
            path: Folder that contains ``policy`` (as saved by the policy's
                ``save``) plus ``encoder.pth``, ``decoder.pth``,
                ``reward_net.pth`` and ``disc_net.pth``.
        """
        device = self.config['device']

        # save_target stores the policy at '<folder>/policy', so that is the
        # path to hand back to the policy -- not the folder itself.
        self.target_policy.set_parameters(f'{path}/policy', exact_match=True)

        # map_location lets a checkpoint written on another device be loaded
        # onto the configured one.
        self.target_encoder.load_state_dict(th.load(f'{path}/encoder.pth', map_location=device))
        self.target_decoder.load_state_dict(th.load(f'{path}/decoder.pth', map_location=device))
        self.reward_net.load_state_dict(th.load(f'{path}/reward_net.pth', map_location=device))
        self.disc_net.load_state_dict(th.load(f'{path}/disc_net.pth', map_location=device))

        self.reward_net.eval()
        self.disc_net.eval()

    def load_source(self) -> None:
        """Load the pretrained source encoder/decoder, reward net and disc net, and put them in eval mode."""
        print('Loading source model...')
        device = self.config['device']

        # map_location lets weights saved on another device be loaded onto the
        # configured one.
        self.source_encoder.load_state_dict(th.load(self.config['source_env_encoder_path'], map_location=device))
        self.source_decoder.load_state_dict(th.load(self.config['source_env_decoder_path'], map_location=device))
        self.reward_net.load_state_dict(th.load(self.config['source_env_reward_path'], map_location=device))
        self.disc_net.load_state_dict(th.load(self.config['source_env_disc_path'], map_location=device))

        self.source_encoder.eval()
        self.source_decoder.eval()
        self.reward_net.eval()
        self.disc_net.eval()

    def update_learner_buffer(self) -> None:
        """Roll the current policy in the sampling env and store ``learner_buffer_update_size`` transitions."""
        env = self.target_envs_sample
        obs, info = env.reset()
        for _ in range(self.config['learner_buffer_update_size']):
            action, _ = self.target_policy.predict(obs)
            next_obs, reward, terminated, truncated, info = env.step(action)
            self.learner_buffer.add(obs, action, reward, next_obs, terminated, truncated)
            obs = next_obs
            # Start a fresh episode as soon as the current one ends either way.
            if terminated or truncated:
                obs, info = env.reset()

    @abstractmethod
    def update_target_vae(self) -> None:
        """Update the target encoder/decoder; must be provided by subclasses."""
        raise NotImplementedError

    def _update_tb_logs(self, loss_name: str, loss: th.Tensor, env_id: str) -> None:
        """Append one value to the pending log series ``log_dict[loss_name][env_id]``.

        Args:
            loss_name: Name of the scalar group.
            loss: Value to record; it is detached, moved to CPU and given a
                leading dimension before being concatenated to the series.
            env_id: Name of the series inside the group.
        """
        series = self.log_dict.setdefault(loss_name, {})
        if env_id not in series:
            series[env_id] = th.tensor([])
        series[env_id] = th.cat((series[env_id], loss.detach().cpu().unsqueeze(0)))

    def _dump_tb_logs(self) -> None:
        """Write the mean of every pending series to tensorboard, then clear the pending logs.

        self.log_dict is a dictionary that contains the logs to be dumped to tensorboard
        {
            kld_loss: {
                env_id_1: th.Tensor,
                env_id_2: th.Tensor,
            },
            recon_loss: {
                env_id_1: th.Tensor,
                env_id_2: th.Tensor,
            },
        }
        The policy's current ``num_timesteps`` is used as the global step.
        """
        timestep = self.target_policy.num_timesteps
        for loss_name, loss_dict in self.log_dict.items():
            self.writer.add_scalars(f'{loss_name}',
                                    {
                                        env_id: loss.mean() for env_id, loss in loss_dict.items()
                                    },
                                    timestep)

        self.log_dict = {}

    def train(self, epochs=1_000_000) -> None:
        """Alternate policy updates, learner-buffer refills and VAE updates.

        Args:
            epochs: Number of outer iterations. A checkpoint is written every
                20th epoch, starting with epoch 0.
        """
        for epoch in range(epochs):
            self.set_policy_training(True)

            self.update_policies()

            self.update_learner_buffer()

            self.set_policy_training(False)

            self.update_target_vae()

            self._dump_tb_logs()

            if epoch % 20 == 0:
                self.save_target()


class TrainTraIRLOneShot(TrainTraIRLTransferBase):
    """Transfer trainer that adapts the target encoder/decoder using target expert data."""

    def __init__(self, config: dict) -> None:
        super().__init__(config)

    def update_target_vae(self) -> None:
        """Run ``vae_update_steps`` gradient steps on the target encoder and decoder.

        Each step combines four terms, weighted by config values: the
        discriminator loss, its gradient penalty, the learner VAE
        reconstruction/KLD losses, and a cycle-consistency KL term computed
        through the source decoder and source encoder. Only the target
        encoder/decoder optimizers are stepped; no optimizer for the
        discriminator, reward net or source networks is used here.
        """
        self.target_encoder.train()
        self.target_decoder.train()

        for step in range(self.config['vae_update_steps']):

            total_loss = th.zeros(1, device=self.config['device'])

            expert_obs, expert_action, expert_next_obs, expert_reward, expert_termination, expert_truncated = self.expert_buffer.sample(self.target_env_name)
            learner_obs, learner_action, learner_next_obs, learner_reward, learner_termination, learner_truncated = self.learner_buffer.sample()

            # add noise to expert_obs (per-feature std of the batch / 20).
            # Done out of place so the tensor returned by the buffer is never
            # modified: if the buffer hands out a view of its storage, an
            # in-place add would accumulate noise in the stored demonstrations.
            expert_obs = expert_obs + th.normal(th.zeros_like(expert_obs), expert_obs.std(dim=0, keepdim=True) / 20)

            # encode obs; the encoder returns four values and the fourth
            # replaces the local obs tensor from here on
            expert_obs_z, expert_obs_mu, expert_obs_logvar, expert_obs = self.target_encoder.forward(expert_obs)
            learner_obs_z, learner_obs_mu, learner_obs_logvar, learner_obs = self.target_encoder.forward(learner_obs)

            # decode obs
            expert_obs_recon = self.target_decoder.forward(expert_obs_z)
            learner_obs_recon = self.target_decoder.forward(learner_obs_z)

            # obs vae loss
            # NOTE(review): the expert VAE losses are computed but are not part
            # of total_loss below (only the learner ones are) -- confirm intent.
            expert_obs_recon_loss, expert_obs_kld_loss = vae_loss(expert_obs_recon, expert_obs, expert_obs_mu, expert_obs_logvar)
            learner_obs_recon_loss, learner_obs_kld_loss = vae_loss(learner_obs_recon, learner_obs, learner_obs_mu, learner_obs_logvar)

            if self.config['current_obs_only']:  # only use current obs
                expert_next_obs_z, expert_next_obs_mu, expert_next_obs_logvar, expert_next_obs = None, None, None, None
                learner_next_obs_z, learner_next_obs_mu, learner_next_obs_logvar, learner_next_obs = None, None, None, None

                expert_next_obs_recon = None
                learner_next_obs_recon = None
            else:
                # encode next obs
                expert_next_obs_z, expert_next_obs_mu, expert_next_obs_logvar, expert_next_obs = self.target_encoder.forward(expert_next_obs)
                learner_next_obs_z, learner_next_obs_mu, learner_next_obs_logvar, learner_next_obs = self.target_encoder.forward(learner_next_obs)

                # decode next obs
                expert_next_obs_recon = self.target_decoder.forward(expert_next_obs_z)
                learner_next_obs_recon = self.target_decoder.forward(learner_next_obs_z)

            # next obs vae loss, if expert_next_obs_recon is None, then the loss is 0
            expert_next_obs_recon_loss, expert_next_obs_kld_loss = vae_loss(expert_next_obs_recon, expert_next_obs, expert_next_obs_mu, expert_next_obs_logvar)
            learner_next_obs_recon_loss, learner_next_obs_kld_loss = vae_loss(learner_next_obs_recon, learner_next_obs, learner_next_obs_mu, learner_next_obs_logvar)

            # disc output
            expert_disc_logit = self.disc_net.forward(expert_obs_z, expert_next_obs_z)
            learner_disc_logit = self.disc_net.forward(learner_obs_z, learner_next_obs_z)

            # calculate reward (only used for the logged covariance statistics)
            expert_reward, expert_reward_logit = self.reward_net.forward(expert_obs, expert_next_obs, encoder=self.target_encoder)
            learner_reward, learner_reward_logit = self.reward_net.forward(learner_obs, learner_next_obs, encoder=self.target_encoder)

            # covariance between disc and reward logits; logged only, not in the loss.
            # The expert value is a raw covariance, the learner value is
            # normalised by both standard deviations.
            expert_reward_cov = (expert_disc_logit * expert_reward_logit).mean() - expert_disc_logit.mean() * expert_reward_logit.mean()
            learner_reward_cov = th.mean(((learner_disc_logit * learner_reward_logit).mean(0) - learner_disc_logit.mean(0) * learner_reward_logit.mean(0)) /
                                         ((learner_disc_logit.std(0) + 1e-8) * (learner_reward_logit.std(0) + 1e-8)))

            # disc loss
            disc_loss = - expert_disc_logit.mean() + learner_disc_logit.mean()
            disc_gradient_penalty_loss = gradient_penalty(self.disc_net, expert_obs_z, learner_obs_z, expert_next_obs_z, learner_next_obs_z)

            total_loss += self.config['disc_loss_weight'] * disc_loss + \
                          self.config['disc_gradient_penalty_weight'] * disc_gradient_penalty_loss + \
                          self.config['vae_recon_weight'] * (learner_obs_recon_loss + learner_next_obs_recon_loss) + \
                          self.config['vae_kld_weight'] * (learner_obs_kld_loss + learner_next_obs_kld_loss)

            # cycle consistency loss, computed on freshly sampled batches
            expert_obs, expert_action, expert_next_obs, expert_reward, expert_termination, expert_truncated = self.expert_buffer.sample(self.target_env_name)
            learner_obs, learner_action, learner_next_obs, learner_reward, learner_termination, learner_truncated = self.learner_buffer.sample()

            # Encode observations to get z
            expert_z, expert_mu, expert_logvar, _ = self.target_encoder.forward(expert_obs)
            learner_z, learner_mu, learner_logvar, _ = self.target_encoder.forward(learner_obs)

            # Decode z using the decoder of the source environment
            expert_recon_obs_cycle = self.source_decoder.forward(expert_z)
            learner_recon_obs_cycle = self.source_decoder.forward(learner_z)

            # Encode recon_obs_cycle to get z_cycle
            expert_z_cycle, expert_mu_cycle, expert_logvar_cycle, _ = self.source_encoder.forward(expert_recon_obs_cycle)
            learner_z_cycle, learner_mu_cycle, learner_logvar_cycle, _ = self.source_encoder.forward(learner_recon_obs_cycle)

            expert_kl_div = kl_div(expert_mu_cycle, expert_logvar_cycle, expert_mu, expert_logvar)
            learner_kl_div = kl_div(learner_mu_cycle, learner_logvar_cycle, learner_mu, learner_logvar)

            total_loss += self.config['cycle_consistency_weight'] * (expert_kl_div + learner_kl_div)

            # update the target encoder & decoder (the only optimizers stepped here)
            self.target_encoder_optimizer.zero_grad()
            self.target_decoder_optimizer.zero_grad()

            total_loss.backward()

            self.target_encoder_optimizer.step()
            self.target_decoder_optimizer.step()

            # log
            self._update_tb_logs('disc_loss', disc_loss, self.target_env_name)
            self._update_tb_logs('vae_recon_loss', learner_obs_recon_loss + learner_next_obs_recon_loss, self.target_env_name)
            self._update_tb_logs('vae_kld_loss', learner_obs_kld_loss + learner_next_obs_kld_loss, self.target_env_name)
            self._update_tb_logs('disc_gradient_penalty_loss', disc_gradient_penalty_loss, self.target_env_name)
            # 'cycle_consistency_loss' keeps its existing meaning (learner part
            # only); the expert part, which is also in the loss, gets its own key.
            self._update_tb_logs('cycle_consistency_loss', learner_kl_div, self.target_env_name)
            self._update_tb_logs('expert_cycle_consistency_loss', expert_kl_div, self.target_env_name)

            self._update_tb_logs('expert_disc_logit', expert_disc_logit.mean(), self.target_env_name)
            self._update_tb_logs('learner_disc_logit', learner_disc_logit.mean(), self.target_env_name)
            self._update_tb_logs('expert_reward_logit', expert_reward_logit.mean(), self.target_env_name)
            self._update_tb_logs('learner_reward_logit', learner_reward_logit.mean(), self.target_env_name)
            self._update_tb_logs('expert_reward_cov', expert_reward_cov, self.target_env_name)
            self._update_tb_logs('learner_reward_cov', learner_reward_cov, self.target_env_name)

        self.target_encoder.eval()
        self.target_decoder.eval()


class TrainTraIRLZeroShot(TrainTraIRLTransferBase):
    """Transfer trainer that uses relabelled ("simulated") expert observations.

    One observation feature of the sampled expert batch is replaced by random
    draws, and the discriminator loss is computed against those relabelled
    observations. An extra covariance term between the discriminator logit and
    that same feature is added to the loss.
    """

    def __init__(self, config: dict) -> None:
        super().__init__(config)

    def update_target_vae(self) -> None:
        """Run ``vae_update_steps`` gradient steps on the target encoder and decoder.

        Optional config keys (defaults reproduce the previously hard-coded values):
            criteria_obs_index: Observation column that is relabelled and used
                in the covariance loss (default 8).
            sim_expert_criteria_mean: Mean of the normal distribution the
                relabelled values are drawn from (default 7).
            sim_expert_criteria_std: Its standard deviation (default 1.5).

        Only the target encoder/decoder optimizers are stepped here.
        """
        criteria_idx = self.config.get('criteria_obs_index', 8)
        criteria_mean = self.config.get('sim_expert_criteria_mean', 7)
        criteria_std = self.config.get('sim_expert_criteria_std', 1.5)

        self.target_encoder.train()
        self.target_decoder.train()

        for step in range(self.config['vae_update_steps']):

            total_loss = th.zeros(1, device=self.config['device'])

            expert_obs, expert_action, expert_next_obs, expert_reward, expert_termination, expert_truncated = self.expert_buffer.sample(self.target_env_name)
            learner_obs, learner_action, learner_next_obs, learner_reward, learner_termination, learner_truncated = self.learner_buffer.sample()

            # HER expert_obs: copy the expert batch and overwrite the criteria
            # column with draws from N(criteria_mean, criteria_std)
            sim_expert_obs = expert_obs.clone()
            criteria_column = expert_obs[:, criteria_idx]
            sim_expert_obs[:, criteria_idx] = th.normal(th.ones_like(criteria_column) * criteria_mean,
                                                        th.ones_like(criteria_column) * criteria_std)

            # encode obs; the encoder returns four values and the fourth
            # replaces the local obs tensor from here on
            expert_obs_z, expert_obs_mu, expert_obs_logvar, expert_obs = self.target_encoder.forward(expert_obs)
            learner_obs_z, learner_obs_mu, learner_obs_logvar, learner_obs = self.target_encoder.forward(learner_obs)
            sim_expert_obs_z, sim_expert_obs_mu, sim_expert_obs_logvar, sim_expert_obs = self.target_encoder.forward(sim_expert_obs)

            # decode obs
            expert_obs_recon = self.target_decoder.forward(expert_obs_z)
            learner_obs_recon = self.target_decoder.forward(learner_obs_z)
            sim_expert_obs_recon = self.target_decoder.forward(sim_expert_obs_z)

            # obs vae loss (the sim-expert losses are computed but not part of total_loss)
            learner_obs_recon_loss, learner_obs_kld_loss = vae_loss(learner_obs_recon, learner_obs, learner_obs_mu, learner_obs_logvar)
            sim_expert_obs_recon_loss, sim_expert_obs_kld_loss = vae_loss(sim_expert_obs_recon, sim_expert_obs, sim_expert_obs_mu, sim_expert_obs_logvar)

            if self.config['current_obs_only']:  # only use current obs
                expert_next_obs_z, expert_next_obs_mu, expert_next_obs_logvar, expert_next_obs = None, None, None, None
                learner_next_obs_z, learner_next_obs_mu, learner_next_obs_logvar, learner_next_obs = None, None, None, None
                sim_expert_next_obs_z, sim_expert_next_obs_mu, sim_expert_next_obs_logvar, sim_expert_next_obs = None, None, None, None

                expert_next_obs_recon = None
                learner_next_obs_recon = None
                sim_expert_next_obs_recon = None
            else:
                # encode next obs
                expert_next_obs_z, expert_next_obs_mu, expert_next_obs_logvar, expert_next_obs = self.target_encoder.forward(expert_next_obs)
                learner_next_obs_z, learner_next_obs_mu, learner_next_obs_logvar, learner_next_obs = self.target_encoder.forward(learner_next_obs)
                # NOTE(review): this encodes sim_expert_obs (the current obs)
                # again rather than a relabelled copy of expert_next_obs, so the
                # sim-expert "next obs" is not an actual next observation.
                # Looks unintended, but how the next obs should be relabelled
                # is not defined anywhere in this file -- confirm before changing.
                sim_expert_next_obs_z, sim_expert_next_obs_mu, sim_expert_next_obs_logvar, sim_expert_next_obs = self.target_encoder.forward(sim_expert_obs)

                # decode next obs
                expert_next_obs_recon = self.target_decoder.forward(expert_next_obs_z)
                learner_next_obs_recon = self.target_decoder.forward(learner_next_obs_z)
                sim_expert_next_obs_recon = self.target_decoder.forward(sim_expert_next_obs_z)

            # next obs vae loss, if the recon is None, then the loss is 0
            learner_next_obs_recon_loss, learner_next_obs_kld_loss = vae_loss(learner_next_obs_recon, learner_next_obs, learner_next_obs_mu, learner_next_obs_logvar)
            sim_expert_next_obs_recon_loss, sim_expert_next_obs_kld_loss = vae_loss(sim_expert_next_obs_recon, sim_expert_next_obs, sim_expert_next_obs_mu, sim_expert_next_obs_logvar)

            # disc output
            expert_disc_logit = self.disc_net.forward(expert_obs_z, expert_next_obs_z)
            learner_disc_logit = self.disc_net.forward(learner_obs_z, learner_next_obs_z)
            sim_expert_disc_logit = self.disc_net.forward(sim_expert_obs_z, sim_expert_next_obs_z)

            # disc loss: the relabelled expert batch plays the expert role
            disc_loss = - sim_expert_disc_logit.mean() + learner_disc_logit.mean()

            # disc gradient penalty loss
            gradient_penalty_loss = gradient_penalty(self.disc_net, sim_expert_obs_z, learner_obs_z, sim_expert_next_obs_z, learner_next_obs_z)

            # calculate reward (only used for the logged covariance statistics)
            expert_reward, expert_reward_logit = self.reward_net.forward(expert_obs, expert_next_obs, encoder=self.target_encoder)
            learner_reward, learner_reward_logit = self.reward_net.forward(learner_obs, learner_next_obs, encoder=self.target_encoder)
            sim_expert_reward, sim_expert_reward_logit = self.reward_net.forward(sim_expert_obs, sim_expert_next_obs, encoder=self.target_encoder)

            # reward cov (logged only): raw covariance for the expert batch,
            # std-normalised for the learner and sim-expert batches
            expert_reward_cov = (expert_disc_logit * expert_reward_logit).mean() - expert_disc_logit.mean() * expert_reward_logit.mean()
            learner_reward_cov = th.mean(((learner_disc_logit * learner_reward_logit).mean(0) - learner_disc_logit.mean(0) * learner_reward_logit.mean(0)) / ((learner_disc_logit.std(0) + 1e-8) * (learner_reward_logit.std(0) + 1e-8)))
            sim_expert_reward_cov = th.mean(((sim_expert_disc_logit * sim_expert_reward_logit).mean(0) - sim_expert_disc_logit.mean(0) * sim_expert_reward_logit.mean(0)) / ((sim_expert_disc_logit.std(0) + 1e-8) * (sim_expert_reward_logit.std(0) + 1e-8)))

            # disc & criteria cov loss: covariance between the disc logit and
            # the criteria column of the observations returned by the encoder
            learner_criteria = learner_obs[:, criteria_idx]
            sim_expert_criteria = sim_expert_obs[:, criteria_idx]
            learner_cov_loss = (learner_disc_logit * learner_criteria).mean() - learner_disc_logit.mean() * learner_criteria.mean()
            sim_expert_cov_loss = (sim_expert_disc_logit * sim_expert_criteria).mean() - sim_expert_disc_logit.mean() * sim_expert_criteria.mean()

            # Negated, so minimising the loss increases both covariances.
            cov_loss = - learner_cov_loss - sim_expert_cov_loss
            cov_loss_weight = self.config['cov_loss_weight']

            total_loss += cov_loss_weight * cov_loss + \
                          self.config['disc_loss_weight'] * disc_loss + \
                          self.config['disc_gradient_penalty_weight'] * gradient_penalty_loss + \
                          self.config['vae_recon_weight'] * (learner_obs_recon_loss + learner_next_obs_recon_loss) + \
                          self.config['vae_kld_weight'] * (learner_obs_kld_loss + learner_next_obs_kld_loss)

            # cycle consistency loss (learner batch only), on a freshly sampled batch
            learner_obs, learner_action, learner_next_obs, learner_reward, learner_termination, learner_truncated = self.learner_buffer.sample()

            # Encode observations to get z
            learner_z, learner_mu, learner_logvar, _ = self.target_encoder.forward(learner_obs)

            # Decode z using the decoder of the source environment
            learner_recon_obs_cycle = self.source_decoder.forward(learner_z)

            # Encode recon_obs_cycle to get z_cycle
            learner_z_cycle, learner_mu_cycle, learner_logvar_cycle, _ = self.source_encoder.forward(learner_recon_obs_cycle)

            learner_kl_div = kl_div(learner_mu_cycle, learner_logvar_cycle, learner_mu, learner_logvar)

            total_loss += self.config['cycle_consistency_weight'] * learner_kl_div

            # update the target encoder & decoder (the only optimizers stepped here)
            self.target_encoder_optimizer.zero_grad()
            self.target_decoder_optimizer.zero_grad()

            total_loss.backward()

            self.target_encoder_optimizer.step()
            self.target_decoder_optimizer.step()

            # log
            # disc_loss / disc_gradient_penalty_loss use the same keys as the
            # one-shot trainer so both trainers expose every loss term.
            self._update_tb_logs('disc_loss', disc_loss, self.target_env_name)
            self._update_tb_logs('disc_gradient_penalty_loss', gradient_penalty_loss, self.target_env_name)
            self._update_tb_logs('vae_recon_loss', learner_obs_recon_loss + learner_next_obs_recon_loss, self.target_env_name)
            self._update_tb_logs('vae_kld_loss', learner_obs_kld_loss + learner_next_obs_kld_loss, self.target_env_name)
            self._update_tb_logs('cycle_consistency_loss', learner_kl_div, self.target_env_name)
            self._update_tb_logs('scaled_cov_loss', cov_loss_weight * cov_loss, self.target_env_name)
            self._update_tb_logs('original_cov_loss', cov_loss, self.target_env_name)

            self._update_tb_logs('expert_disc_logit', expert_disc_logit.mean(), self.target_env_name)
            self._update_tb_logs('learner_disc_logit', learner_disc_logit.mean(), self.target_env_name)
            self._update_tb_logs('sim_expert_disc_logit', sim_expert_disc_logit.mean(), self.target_env_name)
            self._update_tb_logs('expert_reward_logit', expert_reward_logit.mean(), self.target_env_name)
            self._update_tb_logs('learner_reward_logit', learner_reward_logit.mean(), self.target_env_name)
            self._update_tb_logs('expert_reward_cov', expert_reward_cov, self.target_env_name)
            self._update_tb_logs('learner_reward_cov', learner_reward_cov, self.target_env_name)
            self._update_tb_logs('sim_expert_reward_cov', sim_expert_reward_cov, self.target_env_name)

        self.target_encoder.eval()
        self.target_decoder.eval()
  
    


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='trairl_one_shot_halfcheetah.yaml')
    args = parser.parse_args()
    config = yaml.safe_load(open(f"./config/{args.config}", "r"))

    # config = yaml.safe_load(open("./config/trairl_one_shot_halfcheetah.yaml", "r"))
    # trainer = TrainTraIRLOneShot(config)

    # config = yaml.safe_load(open("./config/trairl_zero_shot_halfcheetah.yaml", "r"))
    trainer = TrainTraIRLZeroShot(config)

    trainer.train()
