import datetime
import os
import pprint
import re
import sys
import time
import uuid
from copy import deepcopy
import argparse

import absl.app
import absl.flags
import d4rl
import gym
import numpy as np
import torch
import wandb
from tqdm import trange
import GPUtil

from b4mrl import get_new_density_env, get_new_friction_env, get_new_gravity_env
from mixed_replay_buffer import MixedReplayBuffer
from model import FullyConnectedQFunction, SamplerPolicy, TanhGaussianPolicy
from sampler import StepSampler, TrajSampler
from sim2real_sac import Sim2realSAC
from utils_h2o import (Timer, WandBLogger, define_flags_with_default,
                get_user_flags, prefix_metrics, print_flags,
                set_random_seed)
from configs.train_h20_defaults import get_cfg_defaults
from sim_envs.get_sim import get_sim_env
from utils.env_utils import get_gt_env, get_d4rl_normalized_score, get_eval_statistics
from utils.utils import get_date_time_str, save_config


sys.path.append("..")

from Network.Weight_net import ConcatDiscriminator


def _split_batch_size(batch_size, sim_ratio):
    """Split a training batch into its real-data and simulated-data parts.

    The simulated share is ``int(batch_size * sim_ratio)`` and the real share
    is the remainder, so the two parts always add up to ``batch_size``.
    Computing the real share as ``int(batch_size * (1 - sim_ratio))`` instead
    can lose a sample to floating-point truncation
    (e.g. ``int(10 * (1 - 0.9)) == 0``).

    Args:
        batch_size: total number of samples per training step.
        sim_ratio: fraction of the batch taken from simulated data.

    Returns:
        Tuple ``(real_batch_size, sim_batch_size)``.
    """
    sim_batch_size = int(batch_size * sim_ratio)
    return batch_size - sim_batch_size, sim_batch_size


def _save_checkpoint(sac, epoch, save_path):
    """Write the agent and the epoch index to ``model.pkl`` inside *save_path*."""
    torch.save({'sac': sac, 'epoch': epoch}, os.path.join(save_path, 'model.pkl'))


def main(config):
    """Train a Sim2realSAC agent on simulator rollouts mixed with real data.

    Each epoch rolls the current policy out in the simulated environment
    (filling the replay buffer), runs ``n_train_step_per_epoch`` training
    steps on mixed real/sim batches, and, on epoch 0 and every ``eval_period``
    epochs, evaluates the deterministic policy on the ground-truth
    environment. Metrics are sent to wandb when ``config.wandb.enable`` is set
    and the agent is checkpointed to ``model.pkl`` when ``config.save_model``
    is set.

    Args:
        config: experiment configuration; this function reads
            ``agent_path``, ``save_model`` and the ``wandb``, ``system``,
            ``env``, ``simulator``, ``data`` and ``algorithm`` sections.

    Raises:
        FileExistsError: if the run directory already exists
            (``os.makedirs`` is called without ``exist_ok``).
    """
    # NOTE(review): GPUtil is expected to raise if no GPU satisfies the
    # load/memory limits, so this script cannot run CPU-only - confirm.
    deviceIds = GPUtil.getFirstAvailable(order='memory', maxLoad=0.95, maxMemory=0.95)
    device = torch.device(f'cuda:{deviceIds[0]}')

    # Every run gets its own time-stamped directory holding config + model.
    save_path = os.path.join(config.agent_path, config.wandb.name, get_date_time_str())
    os.makedirs(save_path)
    config_dict = save_config(config, os.path.join(save_path, 'config.yaml'))

    set_random_seed(config.system.seed)

    if config.wandb.enable:
        wandb.config.update(config_dict)

    algo_cfg = config.algorithm

    # Original code changed from here
    real_env = get_gt_env(config.env.type)
    sim_env = get_sim_env(config.simulator.type, config.simulator)

    # a step sampler for "simulated" training
    train_sampler = StepSampler(sim_env, algo_cfg.max_traj_length)
    # a trajectory sampler for "real-world" evaluation
    eval_sampler = TrajSampler(real_env, algo_cfg.max_traj_length)

    # replay buffer
    num_state = real_env.observation_space.shape[0]
    num_action = real_env.action_space.shape[0]
    replay_buffer = MixedReplayBuffer(algo_cfg.reward_scale, algo_cfg.reward_bias,
                                      algo_cfg.clip_action, num_state, num_action,
                                      data_path=config.data.path, use_d4rl=config.data.use_d4rl, device=device,
                                      buffer_ratio=algo_cfg.replaybuffer_ratio,
                                      residual_ratio=algo_cfg.real_residual_ratio,
                                      noise_coeff=config.data.noise_coeff)

    # discriminators: one over (s, a) inputs and one over (s, a, s') inputs
    d_sa = ConcatDiscriminator(
        num_state + num_action, 256, 2, device, dropout=algo_cfg.dis_dropout
    ).float().to(device)
    d_sas = ConcatDiscriminator(
        2 * num_state + num_action, 256, 2, device, dropout=algo_cfg.dis_dropout
    ).float().to(device)

    # agent: network dimensions come from the evaluation sampler's environment
    obs_dim = eval_sampler.env.observation_space.shape[0]
    act_dim = eval_sampler.env.action_space.shape[0]

    policy = TanhGaussianPolicy(
        obs_dim,
        act_dim,
        arch=algo_cfg.policy_arch,
        log_std_multiplier=algo_cfg.policy_log_std_multiplier,
        log_std_offset=algo_cfg.policy_log_std_offset,
        orthogonal_init=algo_cfg.orthogonal_init,
    )

    qf1 = FullyConnectedQFunction(
        obs_dim,
        act_dim,
        arch=algo_cfg.qf_arch,
        orthogonal_init=algo_cfg.orthogonal_init,
    )
    target_qf1 = deepcopy(qf1)

    qf2 = FullyConnectedQFunction(
        obs_dim,
        act_dim,
        arch=algo_cfg.qf_arch,
        orthogonal_init=algo_cfg.orthogonal_init,
    )
    target_qf2 = deepcopy(qf2)

    cql = Sim2realSAC.get_default_config()

    # A non-negative default is replaced by minus the action dimensionality.
    if cql.target_entropy >= 0.0:
        cql.target_entropy = -np.prod(eval_sampler.env.action_space.shape).item()

    sac = Sim2realSAC(cql, policy, qf1, qf2, target_qf1, target_qf2, d_sa, d_sas, replay_buffer,
                      device, dynamics_model=None)
    sac.torch_to_device(device)

    # sampling policy is always the current policy: \pi
    sampler_policy = SamplerPolicy(policy, device)

    # The real/sim split does not change during training, so compute it once.
    real_batch_size, sim_batch_size = _split_batch_size(algo_cfg.batch_size, algo_cfg.batch_ratio)
    n_train_steps = algo_cfg.n_train_step_per_epoch

    # Keeps the final checkpoint well-defined even when n_epochs == 0.
    epoch = -1

    # train and evaluate for n_epochs
    for epoch in trange(algo_cfg.n_epochs):
        metrics = {}

        # Rollout from the simulator
        with Timer() as rollout_timer:
            # rollout and append simulated trajectories to the replay buffer
            train_sampler.sample(
                sampler_policy, algo_cfg.n_rollout_steps_per_epoch,
                deterministic=False, replay_buffer=replay_buffer, joint_noise_std=algo_cfg.joint_noise_std
            )
            metrics['epoch'] = epoch

        # Train from the mixed data
        with Timer() as train_timer:
            for batch_idx in trange(n_train_steps):
                train_metrics = sac.train(real_batch_size, sim_batch_size)
                # Only the last training step of the epoch is logged.
                if batch_idx + 1 == n_train_steps:
                    metrics.update(prefix_metrics(train_metrics, 'sac'))

        # Evaluate in the real world
        with Timer() as eval_timer:
            if epoch == 0 or (epoch + 1) % algo_cfg.eval_period == 0:
                trajs = eval_sampler.sample(
                    sampler_policy, algo_cfg.eval_n_trajs, deterministic=True
                )

                eval_dsa_loss, eval_dsas_loss = sac.discriminator_evaluate()
                metrics['eval_dsa_loss'] = eval_dsa_loss
                metrics['eval_dsas_loss'] = eval_dsas_loss

                all_rewards = [np.sum(t['rewards']) for t in trajs]
                avg_reward, std_reward, avg_norm_reward, std_norm_reward = get_eval_statistics(
                    all_rewards, config.env.type)

                # Progress is reported in training steps, not in epochs.
                current_time_step = epoch * n_train_steps
                eval_episodes = algo_cfg.eval_n_trajs
                print("---------------------------------------")
                print(f"Epoch {current_time_step}: Evaluation over {eval_episodes} episodes: {avg_reward:.3f} +- {std_reward:.3f}, Normalized score = {avg_norm_reward:.3f} +- {std_norm_reward:.3f}")
                print("---------------------------------------")
                if config.wandb.enable:
                    metrics.update(
                        {'eval mean reward': avg_reward, 'avg_norm_reward': avg_norm_reward,
                         'eval std reward': std_reward,
                         'epochs': current_time_step})

                # Overwrites the previous checkpoint after each evaluation.
                if config.save_model:
                    _save_checkpoint(sac, epoch, save_path)

        # Timing metrics are currently disabled:
        # metrics['rollout_time'] = rollout_timer()
        # metrics['train_time'] = train_timer()
        # metrics['eval_time'] = eval_timer()
        # metrics['epoch_time'] = rollout_timer() + train_timer() + eval_timer()

        if config.wandb.enable:
            wandb.log(metrics)

    # Final checkpoint after the last epoch.
    if config.save_model:
        _save_checkpoint(sac, epoch, save_path)


if __name__ == '__main__':
    # Command-line interface: an optional config file plus two overrides
    # (random seed and a switch that turns wandb logging off).
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('--config-file', type=str, default=None)
    # parser.add_argument('--config-list', nargs="+", default=None)
    parser.add_argument('--no-wandb', action='store_true')
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()

    # Command-line values that override entries of the default configuration.
    config_dict = {
        'system.seed': args.seed,
        'wandb.enable': not args.no_wandb,
    }

    print('New run with parameters:')
    print(config_dict)

    # Flatten the overrides into [key1, value1, key2, value2, ...],
    # leaving out any entry whose value is None.
    config_list = []
    for key, value in config_dict.items():
        if value is None:
            continue
        config_list.extend((key, value))

    config = get_cfg_defaults(args.config_file, config_list)

    # The wandb run starts with an empty config; main() fills it in later.
    if config.wandb.enable:
        wandb.init(project=config.wandb.project_name, entity="mechanistic_offline_rl", config={})

    main(config)