import os
import gym
import argparse
import torch as th
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import numpy as np
import os
import time
import random
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.sac.sac import SAC, SAC1
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.logger import configure

from experiments.common.setup_experiment import setup_experiment, flush_logs, get_value_logger
from core.constraints import BaseConstraint, BoxConstraint
from core.flow.real_nvp import RealNvp
from core.flow.train_flow import update_flow_batch
from core.flow.constrained_distribution import ConstrainedDistribution
import copy
from robosuite.wrappers import GymWrapper
import robosuite as suite


def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator below.

    Side effects:
        Sets ``os.environ['PYTHONHASHSEED']``. Python reads this variable at
        interpreter start-up, so it only influences child interpreters that
        are launched afterwards. It also sets the global
        ``torch.backends.cudnn`` flags: ``deterministic=True`` and
        ``benchmark=False``.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    for torch_seeder in (th.manual_seed, th.cuda.manual_seed):
        torch_seeder(seed)

    cudnn_backend = th.backends.cudnn
    # Disable autotuning and pick deterministic kernels so that reruns match.
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

class decoder_network(nn.Module):
    """MLP that maps a target-domain (state, action) pair to a source-domain vector.

    The target state and target action are each embedded by a two-layer ELU
    MLP into ``hidden_size // 2`` features. The two embeddings are
    concatenated and passed through a three-layer trunk, which outputs
    ``source_state_dim + source_action_dim`` features. If ``useTanh`` is set,
    the output is squashed into (-1, 1).

    NOTE(review): nothing here splits the output. Presumably the caller
    treats the first ``source_state_dim`` entries as the state and the rest
    as the action — confirm against the consumer.

    Args:
        source_state_dim: size of the source-state part of the output.
        source_action_dim: size of the source-action part of the output.
        target_state_dim: feature size of the ``target_state`` input.
        target_action_dim: feature size of the ``target_action`` input.
        device: stored as ``self.device`` but not used here. Move the module
            itself with ``.to(device)``.
        useTanh: if True, apply ``tanh`` to the output.
        hidden_size: trunk width. An odd value is rounded down to the even
            width ``2 * (hidden_size // 2)`` so the concatenated embeddings
            always match the trunk's input layer.
    """

    def __init__(self, source_state_dim, source_action_dim, target_state_dim, target_action_dim, device, useTanh=True, hidden_size=256):
        super(decoder_network, self).__init__()
        half = hidden_size // 2
        # Both embeddings are `half` wide, so their concatenation is 2*half wide.
        # Using this value (rather than `hidden_size`) keeps odd sizes working.
        trunk_width = 2 * half

        self.state_emb = nn.Sequential(
            nn.Linear(target_state_dim, half),
            nn.ELU(),
            nn.Linear(half, half),
            nn.ELU(),
        )
        self.action_emb = nn.Sequential(
            nn.Linear(target_action_dim, half),
            nn.ELU(),
            nn.Linear(half, half),
            nn.ELU(),
        )

        self.out_layer = nn.Sequential(
            nn.Linear(trunk_width, trunk_width),
            nn.ELU(),
            nn.Linear(trunk_width, trunk_width),
            nn.ELU(),
            nn.Linear(trunk_width, source_state_dim + source_action_dim),
        )
        self.useTanh = useTanh
        self.device = device
        self.source_state_dim = source_state_dim
        self.source_action_dim = source_action_dim

    def forward(self, target_state, target_action):
        """Decode a batch of target-domain transitions.

        Args:
            target_state: tensor whose last dimension is ``target_state_dim``.
                It is cast to float32.
            target_action: tensor whose last dimension is ``target_action_dim``.
                It is cast to float32.

        Returns:
            Tensor whose last dimension is
            ``source_state_dim + source_action_dim``. Values lie in [-1, 1]
            when ``useTanh`` is set.
        """
        state_feat = self.state_emb(target_state.float())
        action_feat = self.action_emb(target_action.float())
        # The embeddings are joined along dim=1, so the inputs are expected to
        # be 2-D (batch, features).
        output = self.out_layer(th.cat((state_feat, action_feat), dim=1))
        if self.useTanh:
            output = th.tanh(output)
        return output
    

def make_robosuite_env(env_name, robot, monitor=False):
    """Build one robosuite task exposed through ``GymWrapper``.

    Args:
        env_name: task name passed to ``suite.make`` (e.g. ``"Door"``).
        robot: robot name passed as ``robots=`` to ``suite.make``.
        monitor: if True, wrap the env in SB3's ``Monitor``. This is used for
            the evaluation env handed to ``EvalCallback``.

    Returns:
        The wrapped environment.
    """
    env = GymWrapper(
        suite.make(
            env_name,
            robots=robot,
            use_camera_obs=False,  # no image observations
            has_offscreen_renderer=False,
            has_renderer=False,
            reward_shaping=True,
            control_freq=20,
        )
    )
    return Monitor(env) if monitor else env


parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, nargs='?', default=2)
parser.add_argument("--decoder_lr", type=float, nargs='?', default=1e-4)
parser.add_argument("--device", type=str, nargs='?', default="cuda:0")
parser.add_argument("--folder", type=str, nargs='?', default='')
parser.add_argument("--src_robot", type=str, nargs='?', default='Panda')
parser.add_argument("--tar_robot", type=str, nargs='?', default='UR5e')
parser.add_argument("--tar_env", type=str, nargs='?', default='Door')
parser.add_argument("--src_env", type=str, nargs='?', default='Door')
# Former hard-coded training settings; the defaults keep the old behaviour.
parser.add_argument("--total_timesteps", type=int, nargs='?', default=int(1e5))
parser.add_argument("--gamma", type=float, nargs='?', default=0.9)
args = parser.parse_args()
seed_everything(args.seed)

# Run directories are keyed by task and robot name, and the parser only
# defines src_*/tar_* options.
# NOTE(review): confirm this layout matches what the source-training and
# flow-training scripts write.
source_log_folder = f'{args.folder}/source/logs/{args.src_env}/{args.src_robot}/{args.seed}/'
flow_folder = f'{args.folder}/source/data/{args.src_env}/{args.seed}/'
target_log_folder = f'{args.folder}/target/logs/{args.tar_env}/{args.tar_robot}/{args.seed}/'

state_flow = RealNvp.load_module(f"{flow_folder}stateFlow/final.pt").to(args.device)
action_flow = RealNvp.load_module(f"{flow_folder}actionFlow/final.pt").to(args.device)
# Open the statistics archive once, and close it, instead of loading it twice
# and leaving both NpzFile handles open.
with np.load(f"{flow_folder}stateFlow/state_info.npz") as state_info:
    src_state_mean = th.tensor(state_info['mean']).to(args.device)
    src_state_std = th.tensor(state_info['std']).to(args.device)

# source_env is only used below to read the source observation/action sizes.
source_env = DummyVecEnv([lambda: make_robosuite_env(args.src_env, args.src_robot)])
target_env = DummyVecEnv([lambda: make_robosuite_env(args.tar_env, args.tar_robot)])
eval_env = DummyVecEnv([lambda: make_robosuite_env(args.tar_env, args.tar_robot, monitor=True)])

src_state_dim = source_env.observation_space.shape[0]
src_action_dim = source_env.action_space.shape[0]
tar_state_dim = target_env.observation_space.shape[0]
tar_action_dim = target_env.action_space.shape[0]

print("-" * 50 + '\n')
print(args)
print(
    "-" * 50
    + f"\nsource domain state dimension: {src_state_dim}, action dimension: {src_action_dim}"
    + f"\ntarget domain state dimension: {tar_state_dim}, action dimension: {tar_action_dim} \n"
    + "-" * 50
)

target_env.seed(seed=args.seed)
eval_env.seed(seed=args.seed)
set_random_seed(seed=args.seed)

new_logger = configure(target_log_folder, ["stdout", "csv"])
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path=target_log_folder,
    log_path=target_log_folder,
    eval_freq=2000,
    deterministic=True,
    render=False,
    n_eval_episodes=5,
)

source_model = SAC.load(f'{source_log_folder}best_model', device=args.device)
decoder = decoder_network(
    source_state_dim=src_state_dim,
    source_action_dim=src_action_dim,
    target_state_dim=tar_state_dim,
    target_action_dim=tar_action_dim,
    device=args.device,
    useTanh=True,
).to(args.device)

target_model = SAC1(
    "MlpPolicy",
    env=target_env,
    SACmodel=source_model,
    decoder=decoder,
    state_flow=state_flow,
    action_flow=action_flow,
    decoder_lr=args.decoder_lr,
    src_state_std=src_state_std,
    src_state_mean=src_state_mean,
    verbose=1,
    device=args.device,
    seed=args.seed,
    gamma=args.gamma,
)
target_model.set_logger(new_logger)
target_model.learn(total_timesteps=args.total_timesteps, progress_bar=True, callback=eval_callback)
