
import os
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
os.environ['MUJOCO_GL'] = 'egl'
import datetime
from absl import app, flags
from ml_collections import config_flags
from dataclasses import dataclass
import wrappers
from dataset_utils import SNSD4RLDataset,SNSD4RLMixedDataset, split_into_trajectories
import sys
import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)
from pathlib import Path
import hydra
from omegaconf import DictConfig
import numpy as np
import torch
from trainer import TrainerSNS as TrainerDILO
import dilo_utils
from models import  TwinQ, ValueFunction, TwinV

from dilo import DILO
import json
from  policy import GaussianPolicy, DeterministicPolicy
import time
import gym
import d4rl
import d4rl.gym_mujoco
import d4rl.kitchen
from logging_utils.logx import EpochLogger
import matplotlib.pyplot as plt
torch.backends.cudnn.benchmark = True 

FLAGS = flags.FLAGS

# Command-line flags for this training script.
# NOTE(review): the default env_name 'hopper-medium-v2' has no expert source in
# make_env_and_dataset (only 'hopper-random' is handled for hopper); pass a
# supported env_name explicitly.
flags.DEFINE_string('env_name', 'hopper-medium-v2', 'Environment name.')
flags.DEFINE_string('save_dir', './tmp/', 'Tensorboard logging dir.')
flags.DEFINE_string('exp_name', 'dump', 'Experiment name; used in the results/log folder path.')
flags.DEFINE_integer('seed', 42, 'Random seed.')
flags.DEFINE_integer('expert_trajectories', 200, 'Number of expert trajectories')
flags.DEFINE_integer('eval_episodes', 10,
                     'Number of episodes used for evaluation.')
flags.DEFINE_integer('log_interval', 5000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 5000, 'Eval interval.')
flags.DEFINE_integer('batch_size', 1024, 'Mini batch size.')
flags.DEFINE_float('temp', 1.0, 'Loss temperature')
flags.DEFINE_boolean('double', True, 'Use double q-learning')
flags.DEFINE_integer('max_steps', int(3e5), 'Number of training steps.')
flags.DEFINE_boolean('tqdm', True, 'Use tqdm progress bar.')
flags.DEFINE_integer('sample_random_times', 0, 'Number of random actions to add to smooth dataset')
flags.DEFINE_boolean('grad_pen', False, 'Add a gradient penalty to critic network')
flags.DEFINE_float('lambda_gp', 1, 'Gradient penalty coefficient')
flags.DEFINE_float('max_clip', 7., 'Loss clip value')
flags.DEFINE_integer('num_v_updates', 1, 'Number of value updates per iter')
flags.DEFINE_boolean('log_loss', False, 'Use log gumbel loss')
# NOTE(review): 'alpha' is passed to DILO as ``tau`` in main().
flags.DEFINE_float('alpha', 0.8, 'f-maximization strength')
# NOTE(review): main() hard-codes beta=0.5 when building DILO, so this flag is
# not read in this file — confirm whether it should be wired in.
flags.DEFINE_float('beta', 0.1, 'imitation strength vs bellman strength')
flags.DEFINE_string('maximizer', 'smoothed_chi', 'Which f divergence to use')
# main() compares this against 'semi' to choose the policy/value networks.
flags.DEFINE_string('grad', 'full', 'Gradient type passed to DILO (e.g. "full" or "semi").')

flags.DEFINE_boolean('noise', False, 'Add noise to actions')
flags.DEFINE_float('noise_std', 0.1, 'Noise std for actions')

config_flags.DEFINE_config_file(
    'config',
    'configs/mujoco_config.py',
    'File path to the training hyperparameter configuration.',
    lock_config=False)



@dataclass(frozen=True)
class ConfigArgs:
    """Immutable bundle of training options mirrored from the absl flags of the same names."""

    # Number of random actions to add to smooth the dataset (--sample_random_times).
    sample_random_times: int
    # Whether to add a gradient penalty to the critic network (--grad_pen).
    grad_pen: bool
    # Whether to add noise to actions (--noise).
    noise: bool
    # Standard deviation of the action noise (--noise_std).
    noise_std: float
    # Gradient-penalty coefficient (--lambda_gp); float to match its DEFINE_float flag.
    lambda_gp: float
    # Loss clip value (--max_clip).
    max_clip: float
    # Number of value updates per iteration (--num_v_updates).
    num_v_updates: int
    # Whether to use the log gumbel loss (--log_loss).
    log_loss: bool



def evaluate(agent, env, num_episodes, device, verbose: bool = False, normalization_stats=None):
    """Run ``num_episodes`` deterministic rollouts of ``agent`` in ``env``.

    Args:
        agent: Object exposing ``act(obs_tensor, deterministic=True)`` that
            returns a torch tensor action.
        env: Gym-style env whose ``step`` returns ``(obs, reward, done, info)``
            and whose final ``info`` carries ``info['episode']['return']`` and
            ``info['episode']['length']`` (e.g. an EpisodeMonitor-wrapped env).
        num_episodes: Number of episodes to roll out.
        device: Torch device the observation tensor is moved to.
        verbose: If True, print each episode's return and length.
        normalization_stats: Optional dict. When it contains ``'obs_mean'``
            (and ``'obs_std'``), observations are standardised with them before
            being fed to ``agent``. ``None`` means no normalisation.

    Returns:
        Dict mapping ``'return'`` and ``'length'`` to their mean over episodes.
    """
    # The default ``None`` previously crashed on the membership test below.
    norm = normalization_stats or {}
    normalize_obs = 'obs_mean' in norm
    if normalize_obs:
        obs_mean = norm['obs_mean']
        obs_std = norm['obs_std']

    stats = {'return': [], 'length': []}
    # Pure inference: no autograd graph is needed for the rollouts.
    with torch.no_grad():
        for _ in range(num_episodes):
            observation, done = env.reset(), False
            while not done:
                policy_input = (observation - obs_mean) / obs_std if normalize_obs else observation
                action = agent.act(torch.FloatTensor(policy_input).to(device), deterministic=True)
                observation, _, done, info = env.step(action.detach().cpu().numpy())

            # ``info`` from the final step holds the episode summary.
            for k in stats:
                v = info['episode'][k]
                stats[k].append(v)
                if verbose:
                    print(f'{k}:{v}')

    return {k: np.mean(v) for k, v in stats.items()}


def normalize(dataset):
    """Rescale ``dataset.rewards`` in place by ``1000 / (max_return - min_return)``.

    Trajectory returns are the summed rewards of each trajectory produced by
    ``split_into_trajectories``.

    Raises:
        ValueError: if the dataset yields no trajectories, or if every
            trajectory has the same return (the scale would divide by zero and
            silently turn the rewards into inf/nan).
    """
    trajs = split_into_trajectories(dataset.observations, dataset.actions,
                                    dataset.rewards, dataset.masks,
                                    dataset.dones_float,
                                    dataset.next_observations)

    # Each trajectory step is a 6-tuple whose third field is the reward.
    returns = [sum(rew for _, _, rew, _, _, _ in traj) for traj in trajs]
    if not returns:
        raise ValueError("normalize: dataset contains no trajectories")

    # min/max directly instead of sorting the trajectories just to read both ends.
    return_range = max(returns) - min(returns)
    if return_range == 0:
        raise ValueError("normalize: all trajectory returns are equal; cannot rescale rewards")

    dataset.rewards /= return_range
    dataset.rewards *= 1000.0
    dataset.rewards *= 1000.0


def _wrap_env(env):
    """Wrap ``env`` with EpisodeMonitor and then SinglePrecision, as used throughout this script."""
    return wrappers.SinglePrecision(wrappers.EpisodeMonitor(env))


def _expert_source(env_name):
    """Return ``(expert_env_id, expert_transitions)`` for the offline ``env_name``.

    Matching is by substring and the first matching entry wins, so the table
    order mirrors the original if/elif chain.

    Raises:
        ValueError: if no entry matches ``env_name``.
    """
    table = (
        ('kitchen', 'kitchen-complete-v0', 183),
        ('halfcheetah-random', 'halfcheetah-expert-v2', 1000),
        ('halfcheetah-medium', 'halfcheetah-expert-v2', 1000),
        ('hopper-random', 'hopper-expert-v2', 1000),
        ('walker2d-random', 'walker2d-expert-v2', 1000),
        ('ant-random', 'ant-expert-v2', 1000),
        ('ant-medium', 'ant-expert-v2', 1000),
        ('door-human', 'door-expert-v0', 200),
        ('door-cloned', 'door-expert-v0', 200),
        ('hammer-human', 'hammer-expert-v0', 200),
        ('hammer-cloned', 'hammer-expert-v0', 200),
    )
    for key, expert_env_id, transitions in table:
        if key in env_name:
            return expert_env_id, transitions
    raise ValueError(
        f"No expert dataset configured for env_name={env_name!r}; "
        f"supported substrings: {[key for key, _, _ in table]}")


def make_env_and_dataset(env_name: str,
                         seed: int, normalize_obs=False):
    """Create the evaluation env, the mixed offline dataset and the expert dataset.

    Args:
        env_name: D4RL env id of the offline data (e.g. ``'hopper-random-v2'``).
            Its expert counterpart is chosen by substring match in ``_expert_source``.
        seed: Seed for the env and its action/observation spaces (the expert env
            is not seeded).
        normalize_obs: If True, standardise ``observations`` and
            ``next_observations`` of both datasets in place, using the mean/std
            of their concatenated ``observations`` (std offset by 1e-3).

    Returns:
        ``(env, dataset, expert_dataset, offline_min, offline_max, normalization_stats)``.
        ``offline_min``/``offline_max`` are always ``None``. ``normalization_stats``
        is empty unless ``normalize_obs`` is set, in which case it holds
        ``'obs_mean'`` and ``'obs_std'``.

    Raises:
        ValueError: if ``env_name`` has no configured expert dataset.
    """
    # Resolve the expert source first, so unsupported names fail fast with a
    # clear message instead of an UnboundLocalError on ``expert_env`` later.
    expert_env_id, expert_transitions = _expert_source(env_name)

    env = _wrap_env(gym.make(env_name))
    env.seed(seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)

    expert_env = _wrap_env(gym.make(expert_env_id))
    expert_dataset = SNSD4RLDataset(expert_env, transitions=expert_transitions)

    offline_min = None
    offline_max = None

    # Kitchen passes expert_trajectories=1; other tasks use --expert_trajectories.
    expert_trajectories = 1 if 'kitchen' in env_name else FLAGS.expert_trajectories
    dataset = SNSD4RLMixedDataset(env, expert_env, expert_trajectories=expert_trajectories,
                                  env_name=env_name)
    print("Expert dataset size: {} Offline dataset size: {}".format(expert_dataset.observations.shape[0],dataset.observations.shape[0]))

    normalization_stats = {}
    if normalize_obs:
        # Compute the concatenation once and reuse it for both statistics.
        all_obs = np.concatenate([expert_dataset.observations, dataset.observations])
        obs_mean = all_obs.mean(axis=0)
        obs_std = all_obs.std(axis=0) + 1e-3  # offset avoids division by zero for constant dims
        dataset.observations = (dataset.observations - obs_mean) / obs_std
        expert_dataset.observations = (expert_dataset.observations - obs_mean) / obs_std
        dataset.next_observations = (dataset.next_observations - obs_mean) / obs_std
        expert_dataset.next_observations = (expert_dataset.next_observations - obs_mean) / obs_std
        normalization_stats['obs_mean'] = obs_mean
        normalization_stats['obs_std'] = obs_std

    return env, dataset, expert_dataset, offline_min, offline_max, normalization_stats


def _write_run_config(log_folder):
    """Write this script's flag values plus ``FLAGS.config`` to ``<log_folder>/config.json``.

    Keys for the script's own flags are ``'--<flag name>'``. The ``--config``
    flag itself is skipped; its fields are merged in flat instead. The folder
    is created if it does not exist yet.
    """
    write_config = {}
    for flag in FLAGS.flags_by_module_dict()[sys.argv[0]]:
        # Build the key from the flag name rather than ``flag.serialize()``.
        # absl serializes a False boolean as '--no<name>' and a None value as '',
        # which gave inconsistent or empty keys.
        key = '--' + flag.name
        if 'config' in key:
            continue
        write_config[key] = flag.value
    write_config.update(dict(FLAGS.config))
    os.makedirs(log_folder, exist_ok=True)
    with open(os.path.join(log_folder, "config.json"), "w") as outfile:
        json.dump(write_config, outfile, indent=4)


def _build_agent(obs_dim, act_dim, device):
    """Construct the DILO agent selected by ``--grad`` and move it to ``device``.

    ``--grad == 'semi'`` uses a GaussianPolicy with a TwinV value function.
    Any other value uses a DeterministicPolicy with ValueFunction and ``use_twinV=True``.

    NOTE(review): ``qf`` receives ``act_dim=obs_dim``. Presumably the critic
    scores (s, s') pairs, since results are stored under ``observations/``;
    confirm against ``models.TwinQ``.
    NOTE(review): ``--alpha`` is passed as ``tau``, and ``beta``/``alpha`` are
    hard-coded (``--beta`` is not read); confirm this is intended.
    """
    common = dict(optimizer_factory=torch.optim.Adam, tau=FLAGS.alpha,
                  maximizer=FLAGS.maximizer, gradient_type=FLAGS.grad,
                  beta=0.5, lr=3e-4, discount=0.99, alpha=0.005)
    # Networks are created in the order qf, vf, policy so the torch RNG
    # stream used for initialisation matches the original construction.
    if FLAGS.grad == 'semi':
        agent = DILO(qf=TwinQ(state_dim=obs_dim, act_dim=obs_dim),
                     vf=TwinV(state_dim=obs_dim),
                     policy=GaussianPolicy(obs_dim, act_dim),
                     **common)
    else:
        agent = DILO(qf=TwinQ(state_dim=obs_dim, act_dim=obs_dim),
                     vf=ValueFunction(state_dim=obs_dim),
                     policy=DeterministicPolicy(obs_dim, act_dim),
                     use_twinV=True,
                     **common)
    return agent.to(device)


def main(_):
    """Train DILO on ``FLAGS.env_name`` and log evaluation every ``eval_interval`` steps.

    Logs and ``config.json`` are written to
    ``results/offline_imitation_sns/<env>/<N>_expert/observations/<exp_name>/<exp_name>_s<seed>``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # --seed previously only reached the environment. Seed NumPy and torch too,
    # so network initialisation (and any NumPy-based sampling) is reproducible.
    np.random.seed(FLAGS.seed)
    torch.manual_seed(FLAGS.seed)

    ts_str = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d_%H-%M-%S")
    save_dir = os.path.join(FLAGS.save_dir, ts_str)
    exp_id = (f"results/offline_imitation_sns/{FLAGS.env_name}/"
              f"{FLAGS.expert_trajectories}_expert/observations/{FLAGS.exp_name}")
    log_folder = f"{exp_id}/{FLAGS.exp_name}_s{FLAGS.seed}"

    e_logger = EpochLogger(output_dir=log_folder, exp_name=FLAGS.exp_name)
    _write_run_config(log_folder)
    os.makedirs(save_dir, exist_ok=True)

    env, dataset, expert_dataset, _, _, normalization_stats = make_env_and_dataset(
        FLAGS.env_name, FLAGS.seed, normalize_obs=False)

    # NOTE(review): ``args`` is built from the flags but not passed to the
    # agent or trainer below; confirm whether it should be wired in.
    args = ConfigArgs(sample_random_times=FLAGS.sample_random_times,
                      grad_pen=FLAGS.grad_pen,
                      lambda_gp=FLAGS.lambda_gp,
                      noise=FLAGS.noise,
                      max_clip=FLAGS.max_clip,
                      num_v_updates=FLAGS.num_v_updates,
                      log_loss=FLAGS.log_loss,
                      noise_std=FLAGS.noise_std)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    agent = _build_agent(obs_dim, act_dim, device)
    trainer = TrainerDILO()

    best_eval_returns = -np.inf
    eval_returns = []  # (iteration, average return) pairs
    for i in range(1, FLAGS.max_steps + 1):
        batch = dataset.sample(FLAGS.batch_size)
        expert_batch = expert_dataset.sample(FLAGS.batch_size)
        update_info, _ = trainer.update(agent, batch, expert_batch)

        if i % FLAGS.eval_interval != 0:
            continue

        eval_stats = evaluate(agent.policy, env, FLAGS.eval_episodes, device,
                              normalization_stats=normalization_stats)
        if eval_stats['return'] >= best_eval_returns:
            best_eval_returns = eval_stats['return']

        e_logger.log_tabular('Iterations', i)
        e_logger.log_tabular('AverageNormalizedReturn', eval_stats['return'])
        e_logger.log_tabular('SeenExpertV', update_info['expert_v_val'])
        e_logger.log_tabular('SeenReplayV', update_info['replay_v_val'])
        e_logger.log_tabular('UnseenExpertV', update_info['unseen_expert_v_val'])
        e_logger.log_tabular('UnseenReplayV', update_info['unseen_replay_v_val'])
        e_logger.log_tabular('UnseenExpertPolW', update_info['unseen_expert_pol_weight'])
        e_logger.log_tabular('UnseenReplayPolW', update_info['unseen_replay_pol_weight'])
        e_logger.log_tabular('Policy Loss', update_info['policy_loss'])
        e_logger.dump_tabular()
        eval_returns.append((i, eval_stats['return']))
        print("Iterations: {} Average Return: {}".format(i,eval_stats['return']))




if __name__ == '__main__':
    # Hand control to absl: app.run parses the flags defined above, then calls main().
    app.run(main)
