from typing import List, Callable, Union, Any, TypeVar, Tuple, overload
import yaml
import sys
import os
from datetime import datetime
import argparse

import wandb
import numpy as np

import torch as th
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import stable_baselines3 as sb3
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.logger import configure

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from model.vae.encoder import Encoder
from model.vae.decoder import Decoder
from model.reward import Reward
from model.discriminator import Discriminator
from model.sac.sac import SACCustomReward

import wrapper
from buffer import ExpertBuffer, LearnerBuffer
import env_utils
from utils import *


class Train_TraIRL():
    """Trainer that alternates between SB3 policies, VAEs, a discriminator and a reward net.

    Every source environment gets its own SACCustomReward policy, VAE decoder, learner
    buffer and (unless ``use_single_encoder``) its own encoder. One reward net and one
    discriminator are shared across all environments.

    One epoch of ``train`` runs, in order:
      1. ``update_policies``      - SB3 ``learn`` on every policy,
      2. ``update_learner_buffer`` - roll out the policies into the learner buffers,
      3. ``update_disc_vae``      - joint step on discriminator, encoders and decoders
                                    (critic loss + gradient penalty + VAE + cycle consistency),
      4. ``update_reward``        - push the reward net's logits to covary with the
                                    discriminator's logits,
      5. ``_dump_tb_logs``        - write averaged scalars to TensorBoard.
    """

    def __init__(self, config: dict) -> None:
        """Build networks, optimizers, environments, policies, loggers and buffers.

        Args:
            config: configuration dictionary (loaded from YAML in ``__main__``).
        """
        self.source_env_name = config['source_env_name']
        self.config = config

        # Placeholder: only the keys (env names) are used until the real vectorized envs
        # are created further below.
        self.source_envs = {name: None for name in self.source_env_name}

        # init encoder(s): either one encoder shared by every env, or one encoder per env
        if self.config['use_single_encoder']:
            encoder = Encoder(in_dim=self.config['state_dim'], out_dim=self.config['abstraction_dim'],
                              hidden_dims=self.config['encoder_hidden_dims'], device=self.config['device']).to(self.config['device'])
            self.encoders = {env_id: encoder for env_id in self.source_envs}
        else:
            self.encoders = {
                env_id: Encoder(in_dim=self.config['state_dim'][env_id], out_dim=self.config['abstraction_dim'],
                                hidden_dims=self.config['encoder_hidden_dims'], device=self.config['device']).to(self.config['device']) for env_id in self.source_envs
            }
        # One optimizer per env. With a shared encoder they all wrap the same parameters,
        # so only the first env's optimizer is used (see ``_encoder_optimizers``).
        self.encoders_optimizer = {env_id: Adam(encoder.parameters(), self.config['encoder_lr']) for env_id, encoder in self.encoders.items()}

        # init decoders, the # of decoders is equivalent to the # of the source envs
        self.decoders = {
            env_id: Decoder(self.config['abstraction_dim'], self.config['decoder_out_dim'][env_id],
                            self.config['decoder_hidden_dims']).to(self.config['device']) for env_id in self.source_envs
        }
        self.decoders_optimizer = {env_id: Adam(decoder.parameters(), self.config['decoder_lr']) for env_id, decoder in self.decoders.items()}

        # init reward net, there is only one reward net, the input can be original state (use_encoder=False) or abstract state (use_encoder=True)
        self.reward_net = Reward(self.config['reward_in_dim'], self.config['reward_hidden_dims'],
                                 current_obs_only=self.config['current_obs_only'], use_encoder=self.config['reward_use_encoder'],
                                 device=self.config['device']).to(self.config['device'])
        self.reward_net_optimizer = Adam(self.reward_net.parameters(), self.config['reward_net_lr'], weight_decay=self.config['reward_weight_decay'])

        # init disc net, there is only one disc net
        self.disc_net = Discriminator(self.config['disc_in_dim'], self.config['disc_hidden_dims'], self.config['current_obs_only']).to(self.config['device'])
        # NOTE: attribute name misspelling ("optimzier") kept for compatibility with external users.
        self.disc_net_optimzier = Adam(self.disc_net.parameters(), self.config['disc_net_lr'], weight_decay=self.config['disc_weight_decay'])

        # Training envs: the configured wrapper receives the shared reward net and this env's encoder.
        self.source_envs = {name: make_vec_env(env_id=self.config['base_env_id'], n_envs=self.config['n_envs'], seed=1234,
                                               env_kwargs=self.config['env_kwargs'],
                                               wrapper_class=getattr(wrapper, self.config['env_wrapper']),
                                               wrapper_kwargs={
                                                   'reward_net': self.reward_net,
                                                   'encoder': self.encoders[name],
                                                   'joints_status': self.config['joints_status'][name],
                                               }) for name in self.source_env_name}

        # Separate envs used only for filling the learner buffers (update_learner_buffer).
        self.source_envs_sample = {name: getattr(env_utils, self.config['env_init_func'])(name, self, self.config['joints_status'][name]) for name in self.source_env_name}

        action_noise = getattr(sb3.common.noise, self.config['action_noise']['type'])(
            mean = np.zeros(self.config['action_dim']),
            sigma = np.ones(self.config['action_dim']) * self.config['action_noise']['std'],
        )

        self.policies = {env_id: SACCustomReward(policy=self.config['policy_type'], env=env,
                                                 policy_kwargs=self.config['policy_kwargs'],
                                                 learning_rate=self.config['policy_lr'], action_noise=action_noise,
                                                 reward_net=self.reward_net,
                                                 encoder=self.encoders[env_id],
                                                 reward_use_encoder=self.config['reward_use_encoder'],
                                                 verbose=1,
                                                 stats_window_size=20,
                                                 learning_starts=10_000,
                                                 tau=self.config['policy_tau'],)
                                                 for env_id, env in self.source_envs.items()}

        # The run id is a wall-clock timestamp; it names both the wandb run and the log dir.
        wandb_id = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
        logger_path = f"runs/trairl/{config['base_env_id']}"
        wandb.tensorboard.patch(root_logdir=f'{logger_path}/{wandb_id}/log')

        # Start a new wandb run to track this script.
        self.wandb = wandb.init(
            project="TraIRL",
            config=config,
            sync_tensorboard=True,
            save_code=True,
            id=wandb_id
        )

        for env_id in self.policies:
            new_logger = configure(f'{logger_path}/{wandb_id}/log/sb3/{env_id}', ["stdout", "csv", "tensorboard", "log"])
            self.policies[env_id].set_logger(new_logger)

        self.run_dir = f'{logger_path}/{wandb_id}'
        self.writer = SummaryWriter(f'{self.run_dir}/log/tb', flush_secs=30)
        # {loss_name: {env_id: [float, ...]}}, filled by _update_tb_logs, flushed by _dump_tb_logs
        self.log_dict = {}

        # init buffers
        self.expert_buffer = ExpertBuffer(self.config['expert_files'], device=self.config['device'])
        self.learner_buffer = {env_id: LearnerBuffer(self.config['learner_buffer_size'], obs_dim=self.config['state_dim'], action_dim=self.config['action_dim']) for env_id in self.source_envs}

    def set_policy_training(self, mode: bool = True) -> None:
        """Switch every policy network into (``mode=True``) or out of training mode."""
        for policy in self.policies.values():
            policy.policy.set_training_mode(mode)

    def update_policies(self) -> None:
        """Continue SB3 training of every policy for ``policy_update_steps`` timesteps."""
        for policy in self.policies.values():
            policy.learn(int(self.config['policy_update_steps']), reset_num_timesteps=False)

    def save(self) -> None:
        """Save every model under ``{run_dir}/saved_model/{timestep}``.

        ``timestep`` is the ``num_timesteps`` of the first source env's policy. Policies,
        encoders and decoders are saved per env. The shared reward/disc nets are saved
        under the name of the last env in ``self.policies``.
        """
        timestep = self.policies[self.source_env_name[0]].num_timesteps
        save_folder = f'{self.run_dir}/saved_model/{timestep}'
        os.makedirs(save_folder, exist_ok=True)

        for env_id, policy in self.policies.items():
            policy.save(f'{save_folder}/{env_id}')
            th.save(self.encoders[env_id].state_dict(), f'{save_folder}/{env_id}_encoder.pth')
            th.save(self.decoders[env_id].state_dict(), f'{save_folder}/{env_id}_decoder.pth')

        # Shared nets keep the historical "<last env>_*.pth" filenames so existing
        # checkpoints remain loadable with ``load``.
        shared_prefix = list(self.policies)[-1]
        th.save(self.reward_net.state_dict(), f'{save_folder}/{shared_prefix}_reward_net.pth')
        th.save(self.disc_net.state_dict(), f'{save_folder}/{shared_prefix}_disc_net.pth')

    def load(self, path: str) -> None:
        """Load a checkpoint directory written by ``save``.

        Tensors are mapped onto ``config['device']``, so a checkpoint saved on one device
        can be loaded on another.

        Args:
            path: the ``saved_model/<timestep>`` directory.
        """
        device = self.config['device']
        for env_id, policy in self.policies.items():
            policy.set_parameters(f'{path}/{env_id}', exact_match=True)
            self.encoders[env_id].load_state_dict(th.load(f'{path}/{env_id}_encoder.pth', map_location=device))
            self.decoders[env_id].load_state_dict(th.load(f'{path}/{env_id}_decoder.pth', map_location=device))

        # same "<last env>_*.pth" naming as in ``save``
        shared_prefix = list(self.policies)[-1]
        self.reward_net.load_state_dict(th.load(f'{path}/{shared_prefix}_reward_net.pth', map_location=device))
        self.disc_net.load_state_dict(th.load(f'{path}/{shared_prefix}_disc_net.pth', map_location=device))

    def update_learner_buffer(self) -> None:
        """Roll out each policy in its sampling env and store the transitions.

        Collects ``learner_buffer_update_size`` steps per env. The env is reset whenever
        an episode terminates or is truncated. The env is used through the
        ``reset() -> (obs, info)`` / ``step() -> 5-tuple`` interface.
        """
        for env_id, policy in self.policies.items():
            env = self.source_envs_sample[env_id]
            obs, info = env.reset()
            for _ in range(self.config['learner_buffer_update_size']):
                action, _ = policy.predict(obs)
                next_obs, reward, terminated, truncated, info = env.step(action)
                self.learner_buffer[env_id].add(obs, action, reward, next_obs, terminated, truncated)
                obs = next_obs
                if terminated or truncated:
                    obs, info = env.reset()

    def _encoder_optimizers(self) -> List[Adam]:
        """Return the encoder optimizers to zero/step during ``update_disc_vae``.

        With ``use_single_encoder`` all envs share one encoder, so only the first env's
        optimizer is used. Stepping every optimizer would update the shared weights
        several times per iteration.
        """
        if self.config['use_single_encoder']:
            return [self.encoders_optimizer[next(iter(self.source_envs))]]
        return [self.encoders_optimizer[env_id] for env_id in self.source_envs]

    def _set_disc_vae_train_mode(self, mode: bool) -> None:
        """Call ``train()`` (mode=True) or ``eval()`` (mode=False) on the disc net and all encoders/decoders."""
        modules = [self.disc_net]
        for env_id in self.source_envs:
            modules.append(self.encoders[env_id])
            modules.append(self.decoders[env_id])
        for module in modules:
            if mode:
                module.train()
            else:
                module.eval()

    def _disc_vae_env_loss(self, env_id: str) -> th.Tensor:
        """Return the discriminator + VAE loss for one source env and log its components.

        The loss is ``critic loss + w_gp * gradient penalty + w_recon * reconstruction
        + w_kld * KL``. It uses one expert batch and one learner batch of ``env_id``.
        """
        encoder = self.encoders[env_id]
        decoder = self.decoders[env_id]

        expert_obs, _, expert_next_obs, *_ = self.expert_buffer.sample(env_id)
        learner_obs, _, learner_next_obs, *_ = self.learner_buffer[env_id].sample()

        # encode obs; the encoder's 4th output becomes the reconstruction target
        expert_obs_z, expert_obs_mu, expert_obs_logvar, expert_obs = encoder.forward(expert_obs)
        learner_obs_z, learner_obs_mu, learner_obs_logvar, learner_obs = encoder.forward(learner_obs)

        # decode obs
        expert_obs_recon = decoder.forward(expert_obs_z)
        learner_obs_recon = decoder.forward(learner_obs_z)

        # obs vae loss
        expert_obs_recon_loss, expert_obs_kld_loss = vae_loss(expert_obs_recon, expert_obs, expert_obs_mu, expert_obs_logvar)
        learner_obs_recon_loss, learner_obs_kld_loss = vae_loss(learner_obs_recon, learner_obs, learner_obs_mu, learner_obs_logvar)

        if self.config['current_obs_only']:
            # Next obs is ignored: every next-obs quantity is None. vae_loss is expected to
            # return 0 for None inputs, and the discriminator gets None as its second input.
            expert_next_obs_z = expert_next_obs_mu = expert_next_obs_logvar = expert_next_obs = None
            learner_next_obs_z = learner_next_obs_mu = learner_next_obs_logvar = learner_next_obs = None
            expert_next_obs_recon = learner_next_obs_recon = None
        else:
            # encode next obs
            expert_next_obs_z, expert_next_obs_mu, expert_next_obs_logvar, expert_next_obs = encoder.forward(expert_next_obs)
            learner_next_obs_z, learner_next_obs_mu, learner_next_obs_logvar, learner_next_obs = encoder.forward(learner_next_obs)

            # decode next obs
            expert_next_obs_recon = decoder.forward(expert_next_obs_z)
            learner_next_obs_recon = decoder.forward(learner_next_obs_z)

        # next obs vae loss
        expert_next_obs_recon_loss, expert_next_obs_kld_loss = vae_loss(expert_next_obs_recon, expert_next_obs, expert_next_obs_mu, expert_next_obs_logvar)
        learner_next_obs_recon_loss, learner_next_obs_kld_loss = vae_loss(learner_next_obs_recon, learner_next_obs, learner_next_obs_mu, learner_next_obs_logvar)

        # disc output
        expert_disc_logit = self.disc_net.forward(expert_obs_z, expert_next_obs_z)
        learner_disc_logit = self.disc_net.forward(learner_obs_z, learner_next_obs_z)

        # critic loss: raise expert logits, lower learner logits
        disc_loss = - expert_disc_logit.mean() + learner_disc_logit.mean()
        disc_gradient_penalty_loss = gradient_penalty(self.disc_net, expert_obs_z, learner_obs_z, expert_next_obs_z, learner_next_obs_z)

        recon_loss = expert_obs_recon_loss + learner_obs_recon_loss + expert_next_obs_recon_loss + learner_next_obs_recon_loss
        kld_loss = expert_obs_kld_loss + learner_obs_kld_loss + expert_next_obs_kld_loss + learner_next_obs_kld_loss

        # log this env's own values
        self._update_tb_logs('disc_loss', disc_loss, env_id)
        self._update_tb_logs('vae_recon_loss', recon_loss, env_id)
        self._update_tb_logs('vae_kld_loss', kld_loss, env_id)
        self._update_tb_logs('disc_gradient_penalty_loss', disc_gradient_penalty_loss, env_id)

        return disc_loss + self.config['disc_gradient_penalty_weight'] * disc_gradient_penalty_loss + \
            self.config['vae_recon_weight'] * recon_loss + \
            self.config['vae_kld_weight'] * kld_loss

    def _cycle_consistency_loss(self, env_id_1: str, env_id_2: str) -> th.Tensor:
        """Return the unweighted cycle-consistency loss for ``env_id_1 -> env_id_2`` and log it.

        Obs of ``env_id_1`` are encoded by its encoder and decoded by ``env_id_2``'s
        decoder. The result is re-encoded by ``env_id_2``'s encoder. ``kl_div`` then
        compares the re-encoded (mu, logvar) with the original ones, summed over the
        expert and learner batches.
        """
        # Sample observations
        expert_obs, *_ = self.expert_buffer.sample(env_id_1)
        learner_obs, *_ = self.learner_buffer[env_id_1].sample()

        # Encode observations to get z
        expert_z, expert_mu, expert_logvar, _ = self.encoders[env_id_1].forward(expert_obs)
        learner_z, learner_mu, learner_logvar, _ = self.encoders[env_id_1].forward(learner_obs)

        # Decode with env_id_2's decoder. Under no_grad the decoders are not updated.
        # This also means env_id_1's encoder only receives gradient through (mu, logvar)
        # in the KL term, not through z.
        with th.no_grad():
            expert_recon_obs_cycle = self.decoders[env_id_2].forward(expert_z)
            learner_recon_obs_cycle = self.decoders[env_id_2].forward(learner_z)

        # Re-encode with env_id_2's encoder
        _, expert_mu_cycle, expert_logvar_cycle, _ = self.encoders[env_id_2].forward(expert_recon_obs_cycle)
        _, learner_mu_cycle, learner_logvar_cycle, _ = self.encoders[env_id_2].forward(learner_recon_obs_cycle)

        expert_kl_div = kl_div(expert_mu_cycle, expert_logvar_cycle, expert_mu, expert_logvar)
        learner_kl_div = kl_div(learner_mu_cycle, learner_logvar_cycle, learner_mu, learner_logvar)

        cycle_loss = expert_kl_div + learner_kl_div
        self._update_tb_logs('cycle_consistency_loss', cycle_loss, env_id_1)
        return cycle_loss

    def update_disc_vae(self) -> None:
        """Run ``disc_update_steps`` joint gradient steps on the disc net, encoders and decoders.

        Each step sums ``_disc_vae_env_loss`` over all source envs (visiting them in
        reversed order with probability 0.5). It then adds the weighted
        ``_cycle_consistency_loss`` for every ordered pair of distinct envs; with two envs
        that is (a, b) and (b, a). A single env therefore gets no cycle term.
        All modules are left in eval mode afterwards.
        """
        self._set_disc_vae_train_mode(True)

        env_names = list(self.source_envs.keys())
        cycle_pairs = [(env_id_1, env_id_2) for env_id_1 in env_names for env_id_2 in env_names if env_id_1 != env_id_2]
        encoder_optimizers = self._encoder_optimizers()

        for _ in range(self.config['disc_update_steps']):
            total_loss = th.zeros(1, device=self.config['device'])

            step_env_names = env_names[::-1] if np.random.rand() < 0.5 else env_names
            for env_id in step_env_names:
                total_loss = total_loss + self._disc_vae_env_loss(env_id)

            # cycle consistency loss
            for env_id_1, env_id_2 in cycle_pairs:
                total_loss = total_loss + self.config['cycle_consistency_weight'] * self._cycle_consistency_loss(env_id_1, env_id_2)

            # update disc & vae
            self.disc_net_optimzier.zero_grad()
            for optimizer in encoder_optimizers:
                optimizer.zero_grad()
            for env_id in self.source_envs:
                self.decoders_optimizer[env_id].zero_grad()

            total_loss.backward()

            # TODO: clip gradients

            self.disc_net_optimzier.step()
            for optimizer in encoder_optimizers:
                optimizer.step()
            for env_id in self.source_envs:
                self.decoders_optimizer[env_id].step()

        self._set_disc_vae_train_mode(False)

    @staticmethod
    def _covariance(x: th.Tensor, y: th.Tensor) -> th.Tensor:
        """Return the biased sample covariance ``mean(x*y) - mean(x)*mean(y)``."""
        return (x * y).mean() - x.mean() * y.mean()

    def update_reward(self) -> None:
        """Run ``reward_update_steps`` gradient steps on the shared reward net.

        For each source env (reversed order with probability 0.5) it computes the
        covariance between the reward net's logits and the discriminator's logits on one
        expert and one learner batch. The discriminator logits are computed without
        gradient. The loss is the negative sum of these covariances over *all* envs, so
        the reward net is trained to covary with the discriminator.
        """
        self.reward_net.train()

        for _ in range(self.config['reward_update_steps']):
            total_loss = th.zeros(1, device=self.config['device'])

            env_names = list(self.source_envs.keys())
            if np.random.rand() < 0.5:
                env_names = env_names[::-1]
            for env_id in env_names:
                encoder = self.encoders[env_id]
                expert_obs, _, expert_next_obs, *_ = self.expert_buffer.sample(env_id)
                learner_obs, _, learner_next_obs, *_ = self.learner_buffer[env_id].sample()

                if self.config['current_obs_only']:  # only use current obs
                    expert_next_obs = None
                    learner_next_obs = None

                # reward logits (differentiable w.r.t. the reward net)
                _, expert_reward_logit = self.reward_net.forward(expert_obs, expert_next_obs, encoder=encoder)
                _, learner_reward_logit = self.reward_net.forward(learner_obs, learner_next_obs, encoder=encoder)

                # discriminator logits serve as fixed targets: no gradient
                with th.no_grad():
                    expert_obs_z = encoder.forward(expert_obs)[0]
                    learner_obs_z = encoder.forward(learner_obs)[0]

                    if self.config['current_obs_only']:
                        expert_next_obs_z = None
                        learner_next_obs_z = None
                    else:
                        expert_next_obs_z = encoder.forward(expert_next_obs)[0]
                        learner_next_obs_z = encoder.forward(learner_next_obs)[0]

                    expert_disc_logit = self.disc_net.forward(expert_obs_z, expert_next_obs_z)
                    learner_disc_logit = self.disc_net.forward(learner_obs_z, learner_next_obs_z)

                # calculate reward loss, cov between reward & disc
                expert_reward_cov = self._covariance(expert_reward_logit, expert_disc_logit)
                learner_reward_cov = self._covariance(learner_reward_logit, learner_disc_logit)

                # Accumulate across envs. Overwriting here would let only the last
                # visited env train the reward net.
                total_loss = total_loss - expert_reward_cov - learner_reward_cov

                # log
                self._update_tb_logs('reward_cov', expert_reward_cov + learner_reward_cov, env_id)
                self._update_tb_logs('expert_reward_logit', expert_reward_logit.mean(), env_id)
                self._update_tb_logs('learner_reward_logit', learner_reward_logit.mean(), env_id)
                self._update_tb_logs('expert_disc_logit', expert_disc_logit.mean(), env_id)
                self._update_tb_logs('learner_disc_logit', learner_disc_logit.mean(), env_id)

            # update reward net
            self.reward_net_optimizer.zero_grad()
            total_loss.backward()
            self.reward_net_optimizer.step()

        self.reward_net.eval()

    def _update_tb_logs(self, loss_name: str, loss: th.Tensor, env_id: str) -> None:
        """Record one value of ``loss_name`` for ``env_id`` until the next ``_dump_tb_logs``.

        ``loss`` may be a tensor of any shape (multi-element tensors are averaged) or a
        Python number. It is detached and stored as a float.
        """
        value = th.as_tensor(loss).detach().double().mean().item()
        self.log_dict.setdefault(loss_name, {}).setdefault(env_id, []).append(value)

    def _dump_tb_logs(self) -> None:
        '''
        Write the mean of every recorded value to TensorBoard and clear the log buffer.

        self.log_dict maps loss names to per-env lists of recorded floats:
        {
            kld_loss: {
                env_id_1: [float, ...],
                env_id_2: [float, ...],
            },
            recon_loss: {
                env_id_1: [float, ...],
                env_id_2: [float, ...],
            },
        }
        Each loss becomes one ``add_scalars`` call (one curve per env) at the first source
        env policy's ``num_timesteps``.
        '''
        timestep = self.policies[self.source_env_name[0]].num_timesteps
        for loss_name, per_env_values in self.log_dict.items():
            self.writer.add_scalars(
                loss_name,
                {env_id: float(np.mean(values)) for env_id, values in per_env_values.items()},
                timestep,
            )

        self.log_dict = {}

    def train(self, epochs=1_000_000) -> None:
        """Run the alternating TraIRL loop for ``epochs`` epochs.

        A checkpoint is saved every 20 epochs once ``epoch > 20``, so the first one is
        written at epoch 40. To resume from a checkpoint, call ``load`` before ``train``.
        """
        for epoch in range(epochs):
            self.set_policy_training(True)
            self.update_policies()
            self.update_learner_buffer()
            self.set_policy_training(False)

            self.update_disc_vae()
            self.update_reward()

            self._dump_tb_logs()

            if epoch % 20 == 0 and epoch > 20:
                self.save()


if __name__ == "__main__":
    # Entry point: parse --config, load ./config/<file> and run training.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='trairl_ant.yaml', help='config file')
    args = parser.parse_args()

    # Config files are resolved relative to the current working directory.
    with open(f"./config/{args.config}", "r") as config_file:
        config = yaml.safe_load(config_file)

    trainer = Train_TraIRL(config)

    # if continue training from a checkpoint, uncomment the following line
    # trainer.load('runs/trairl/HalfCheetah-v5/2025_03_07_15_47_38/saved_model/305000')

    trainer.train(epochs=10000)

