import os
import gym
import argparse
import torch as th
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import numpy as np
import os
import time
import random
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3 import SAC, SAC1
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.logger import configure

from experiments.common.setup_experiment import setup_experiment, flush_logs, get_value_logger
from core.constraints import BaseConstraint, BoxConstraint
from core.flow.real_nvp import RealNvp
from core.flow.train_flow import update_flow_batch
from core.flow.constrained_distribution import ConstrainedDistribution

def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch
    (CPU generator plus the current CUDA device). It also puts cuDNN into
    deterministic, non-benchmarking mode.

    Args:
        seed: integer seed applied to every generator.
    """
    # Python-level and NumPy global generators, in that order.
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    # NOTE: per the Python docs, hash randomization is fixed at interpreter
    # start-up, so this only influences child processes spawned later.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Torch CPU generator, then the current CUDA device (silently ignored
    # when CUDA is unavailable).
    th.manual_seed(seed)
    th.cuda.manual_seed(seed)
    # Prefer reproducible cuDNN kernels over autotuned (faster) ones.
    cudnn = th.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

class action_decoder_network(nn.Module):
    """Three-layer MLP that maps a target-domain action to a source-domain action.

    Layers: Linear -> LeakyReLU(0.2) -> Linear -> LeakyReLU(0.2) -> Linear
    -> Tanh. The Tanh output is halved, so every output component lies in
    [-0.5, 0.5].

    Args:
        source_action_dim: width of the produced (source) action.
        target_action_dim: width of the incoming (target) action.
        hidden_size: width of both hidden layers.
        device: stored as ``self.device``; ``__init__`` does not move the
            module there.
    """

    def __init__(self, source_action_dim, target_action_dim, hidden_size, device):
        super().__init__()
        layers = [
            nn.Linear(target_action_dim, hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, source_action_dim),
            nn.Tanh(),
        ]
        # Kept as ``self.model`` so state_dict keys stay ``model.<idx>.*``.
        self.model = nn.Sequential(*layers)
        self.device = device
        self.target_action_dim = target_action_dim

    def forward(self, target_action):
        """Cast ``target_action`` to float32 and decode it into a source action."""
        squashed = self.model(target_action.float())
        return 0.5 * squashed

class decoder_network(nn.Module):
    """Three-layer MLP that maps a target-domain state to a source-domain state.

    Layers: Linear -> LeakyReLU(0.2) -> Linear -> LeakyReLU(0.2) -> Linear
    -> Tanh. The Tanh output is halved, so every output component lies in
    [-0.5, 0.5].

    Args:
        source_state_dim: width of the produced (source) state.
        target_state_dim: width of the incoming (target) state.
        hidden_size: width of both hidden layers.
        device: stored as ``self.device``; ``__init__`` does not move the
            module there.
    """

    def __init__(self, source_state_dim, target_state_dim, hidden_size, device):
        super().__init__()
        layers = [
            nn.Linear(target_state_dim, hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, source_state_dim),
            nn.Tanh(),
        ]
        # Kept as ``self.model`` so state_dict keys stay ``model.<idx>.*``.
        self.model = nn.Sequential(*layers)
        self.device = device
        self.target_state_dim = target_state_dim

    def forward(self, target_state):
        """Cast ``target_state`` to float32 and decode it into a source state."""
        squashed = self.model(target_state.float())
        return 0.5 * squashed

# Script entry point.
# 1. Parse the experiment configuration.
# 2. Build the source, target and evaluation environments; the target and
#    eval envs load the MuJoCo model given by --xml_file.
# 3. Load the two RealNVP flow models and the source SAC model
#    ("best_model").
# 4. Train SAC1 on the target environment with the decoders defined above.

# Options take exactly one value. (The previous nargs='?' let a bare flag
# such as `--seed` silently become None.)
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=2)
parser.add_argument("--decoder_lr", type=float, default=0.01)
parser.add_argument("--flow_lr", type=float, default=0.01)
parser.add_argument("--action_decoder_lr", type=float, default=0.01)
parser.add_argument("--source_state_dim", type=int, default=11)
parser.add_argument("--target_state_dim", type=int, default=13)
parser.add_argument("--source_action_dim", type=int, default=3)
parser.add_argument("--target_action_dim", type=int, default=4)
parser.add_argument("--hidden_size", type=int, default=256)
parser.add_argument("--device", type=str, default="cuda:3")
parser.add_argument("--train_sample_count", type=int, default=25000,
                    help="number of leading saved states used to compute the state bounds")
parser.add_argument("--folder", type=str, default='/home/',
                    help="root directory holding source/, target_domain/ and assets/")
parser.add_argument("--xml_file", type=str, default='hopper_target.xml',
                    help="MuJoCo model file under <folder>/assets/ for the target env")
parser.add_argument("--env", type=str, default='Hopper-v3')
parser.add_argument("--total_timesteps", type=int, default=1000000)
args = parser.parse_args()

seed_everything(args.seed)

# Every artefact lives under --folder. os.path.join also works when --folder
# has no trailing separator. The final "" keeps a trailing "/" on the log
# folders, as the old string-concatenated paths had.
seed_dir = f"seed{args.seed}"
source_root = os.path.join(args.folder, "source", "stable-baselines3-1.7.0")
source_log_folder = os.path.join(source_root, "logs", "source", seed_dir, "")
target_log_folder = os.path.join(args.folder, "target_domain", "logs", "bellman", seed_dir, "")
flow_model = RealNvp.load_module(
    os.path.join(source_root, "model", f"flow_{seed_dir}.pt")
).to(args.device)
action_flow_model = RealNvp.load_module(
    os.path.join(source_root, "model_action", f"flow_{seed_dir}.pt")
).to(args.device)

target_xml = os.path.join(args.folder, "assets", args.xml_file)
# NOTE(review): source_env is only seeded below and never passed to a model;
# confirm it is still needed.
source_env = DummyVecEnv([lambda: gym.make(args.env)])
target_env = DummyVecEnv([lambda: gym.make(args.env, xml_file=target_xml)])
eval_env = DummyVecEnv([lambda: Monitor(gym.make(args.env, xml_file=target_xml))])

source_env.seed(seed=args.seed)
target_env.seed(seed=args.seed)
eval_env.seed(seed=args.seed)
set_random_seed(seed=args.seed)
new_logger = configure(target_log_folder, ["stdout", "csv", "tensorboard"])
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path=target_log_folder,
    log_path=target_log_folder,
    eval_freq=500,
    deterministic=True,
    render=False,
    n_eval_episodes=5,
)

# Per-dimension bounds of the first --train_sample_count saved source states:
# row 0 holds the maxima, row 1 the minima.
# - A single vectorised NumPy reduction replaces the old Python-level loop
#   over (GPU) tensor elements.
# - Max/min are exact, so computing in the stored dtype and widening to the
#   float64 `bound` array gives the same values as the old .double() path.
data_file = os.path.join(source_root, "data", seed_dir, "state.npy")
train_data = np.load(data_file)[: args.train_sample_count]
if (
    train_data.ndim != 2
    or train_data.shape[0] == 0
    or train_data.shape[1] != args.source_state_dim
):
    # Fail loudly instead of silently leaving zero bounds or indexing past
    # the end of `bound`.
    raise ValueError(
        f"expected a non-empty (N, {args.source_state_dim}) state array in "
        f"{data_file}, got shape {train_data.shape}"
    )
bound = np.zeros((2, args.source_state_dim))
bound[0] = train_data.max(axis=0)
bound[1] = train_data.min(axis=0)

decoder = decoder_network(
    args.source_state_dim, args.target_state_dim, args.hidden_size, args.device
).to(args.device)
action_decoder = action_decoder_network(
    args.source_action_dim, args.target_action_dim, args.hidden_size, args.device
).to(args.device)

source_model = SAC.load(os.path.join(source_log_folder, "best_model"), device=args.device)
target_model = SAC1(
    "MlpPolicy",
    env=target_env,
    SACmodel=source_model,
    decoder=decoder,
    flow_model=flow_model,
    action_decoder=action_decoder,
    decoder_lr=args.decoder_lr,
    action_decoder_lr=args.action_decoder_lr,
    flow_lr=args.flow_lr,
    action_flow_model=action_flow_model,
    bound=bound,
    verbose=1,
    device=args.device,
    seed=args.seed,
)
target_model.set_logger(new_logger)
target_model.learn(total_timesteps=args.total_timesteps, progress_bar=True, callback=eval_callback)