import json
import os.path as osp, os
os.environ['MUJOCO_GL'] = "osmesa"

import tensorflow as tf
import numpy as np

from sac_models import StochasticActor, Critic, SAC
from samplers import Sampler
from PIL import Image

from utils import load_expert_trajectories, load_learner_trajectories
from utils import log_trajectory_statistics
from envs.envs import (ExpertInvertedPendulumEnv, ExpertInvertedDoublePendulumEnv)
from envs.more_envs import CustomReacher2Env, CustomReacher3Env
from envs.manipulation_envs import PusherEnv, PusherHumanSimEnv, ReachEnv, ReachHumanSimEnv
from buffers import LearnerAgentReplayBuffer, DemonstrationsReplayBuffer
from model_network import Encoder, Decoder, WGANdiscriminator, Labelnet, Labelnet_frame
from diffil_model import DIFFIL
import csv

# from utils import save_expert_trajectories, save_learner_trajectories

import argparse

# ==================================================
def parse_arguments():
    """Parse the command-line options of a DIFF-IL experiment.

    Returns:
        argparse.Namespace with one attribute per option declared below,
        plus ``algo`` fixed to ``'DIFF-IL'``.
    """
    parser = argparse.ArgumentParser(description='Run experiment using DIFF-IL with given parameters file.')
    # Required: run_experiment cannot select any environment without it.
    parser.add_argument('--env_name', help="The source environment name.", required=True)
    parser.add_argument('--env_type', help="The domain difference in the target environment.", default=None)
    # String default so the attribute has the same type whether or not the
    # flag is given (argparse does not apply `type` to non-string defaults).
    parser.add_argument('--exp_name', help="The experiment name for logging", type=str, default='0')
    parser.add_argument('--gpu_id', help="Select positive number if using multiple gpus", type=int, default=0)
    parser.add_argument('--epochs', help="The number of training epoch", type=int, default=1000)
    parser.add_argument('--recon', help="Scale of recon loss", type=float, default=1)
    parser.add_argument('--fcon', help="Scale of feature consistency loss", type=float, default=1)
    parser.add_argument('--label_source', help="Scale of sequence source label", type=float, default=10)
    parser.add_argument('--label_target', help="Scale of sequence target label", type=float, default=0.001)
    parser.add_argument('--label_frame', help="Scale of frame label", type=float, default=10)
    parser.add_argument('--fwgan_gen', help="Scale of fwgan generator", type=float, default=1)
    parser.add_argument('--fwgan_disc', help="Scale of fwgan discriminator", type=float, default=1)
    parser.add_argument('--fwgan_alpha', help="WGAN control coefficient", type=float, default=0.5)
    parser.add_argument('--model_num_per_epoch', help="Training Model per epoch", type=int, default=100)
    parser.add_argument('--model_RL_per_step', help="Training RL step per epoch", type=int, default=1000)

    args = parser.parse_args()
    args.algo = 'DIFF-IL'
    return args

def down_up_scale(img, downsize_shape):
    """Blur an image by downsampling it and then upsampling it back.

    Both resizes use bicubic interpolation. The result has the pixel
    dimensions of ``img`` but only the detail of a ``downsize_shape`` image.

    Args:
        img: image array accepted by ``PIL.Image.fromarray``.
        downsize_shape: ``(width, height)`` of the intermediate low-res image.

    Returns:
        numpy array holding the re-upsampled image.
    """
    original = Image.fromarray(img)
    full_size = original.size  # PIL sizes are (width, height)
    low_res = original.resize(downsize_shape, Image.BICUBIC)
    restored = low_res.resize(full_size, Image.BICUBIC)
    return np.asarray(restored)

def _dm_control_env(class_name):
    """Return a zero-argument factory that builds ``envs.dmcontrol_envs.<class_name>``.

    The module is imported only when the factory is called. DM Control
    dependencies are therefore needed only by experiments that use them.
    """
    def factory():
        from envs import dmcontrol_envs
        return getattr(dmcontrol_envs, class_name)()
    return factory


def _transfer(env_factory, agent_location, exp_name, use_resolution=False):
    """Describe one source -> target transfer setting.

    Args:
        env_factory: zero-argument callable building the target environment.
        agent_location: name of the target-domain prior data. The learner
            data lives under the same name with a ``'_learner'`` suffix.
        exp_name: sub-directory of ``experiments_data/`` used for the logs.
        use_resolution: if True, the expert and source-prior images are
            blurred by down/up-scaling before training.

    Returns:
        dict with the keys consumed by ``run_experiment``.
    """
    return {
        'env_factory': env_factory,
        'agent_prior_location': agent_location,
        'agent_learner_location': agent_location + '_learner',
        'exp_name': exp_name,
        'use_resolution': use_resolution,
    }


# Source environment -> (image side, episode limit, random-episode limit,
# discriminator batch size). A source environment's expert prior data is
# stored under the environment's own name.
_SOURCE_ENV_SETTINGS = {
    'InvertedPendulum-v2': (32, 1000, 50, 128),
    'InvertedDoublePendulum-v2': (32, 1000, 50, 128),
    'Reacher3-v2': (48, 50, 50, 64),
    'Reacher2-v2': (48, 50, 50, 64),
    'DMPendulum': (32, 1000, 200, 128),
    'DMCartPoleSwingUp': (32, 1000, 200, 128),
    'DMWalker': (64, 200, 200, 64),
    'DMCheetah': (64, 200, 200, 64),
    'DMHopper': (64, 200, 200, 64),
    'Pusher-v2': (48, 200, 200, 64),
    'PusherHumanSim-v2': (48, 200, 200, 64),
    'Reach-v2': (48, 200, 200, 64),
    'ReachHumanSim-v2': (48, 200, 200, 64),
}

# Source environment -> {env_type: transfer setting}. Environment classes are
# wrapped in lambdas so they are looked up only when an experiment runs.
_TRANSFERS = {
    'InvertedPendulum-v2': {
        'to_two': _transfer(lambda: ExpertInvertedDoublePendulumEnv(),
                            'InvertedDoublePendulum-v2', 'IPtoIDP/'),
    },
    'InvertedDoublePendulum-v2': {
        'to_one': _transfer(lambda: ExpertInvertedPendulumEnv(),
                            'InvertedPendulum-v2', 'IDPtoIP/'),
    },
    'Reacher3-v2': {
        'agent': _transfer(lambda: CustomReacher2Env(), 'Reacher2-v2', 'ThreeReacher_custom/'),
        'to_two': _transfer(lambda: CustomReacher2Env(), 'Reacher2-v2', 'ThreeReacher_custom/'),
    },
    'Reacher2-v2': {
        # NOTE(review): same log directory as the Reacher3-v2 settings; with
        # exist_ok=False the two cannot share an exp_name. Confirm intended.
        'agent': _transfer(lambda: CustomReacher3Env(), 'Reacher3-v2', 'ThreeReacher_custom/'),
        'to_three': _transfer(lambda: CustomReacher3Env(), 'Reacher3-v2', 'ThreeReacher_custom/'),
    },
    'DMPendulum': {
        'to_cartpoleswingup': _transfer(_dm_control_env('DMCartPoleSwingUpEnv'),
                                        'DMCartPoleSwingUp', 'DMPendulum_to_cartpoleswingup/'),
        'to_acrobot': _transfer(_dm_control_env('DMAcrobotEnv'),
                                'DMAcrobot', 'DMPendulum_to_acrobot/'),
    },
    'DMCartPoleSwingUp': {
        'to_pendulum': _transfer(_dm_control_env('DMPendulumEnv'),
                                 'DMPendulum', 'DMCartPoleSwingUp_to_pendulum/'),
        'to_acrobot': _transfer(_dm_control_env('DMAcrobotEnv'),
                                'DMAcrobot', 'DMCartpoleswingup_to_acrobot/'),
    },
    'DMWalker': {
        'to_cheetah': _transfer(_dm_control_env('DMCheetahEnv'), 'DMCheetah', 'DMWalker_to_cheetah/'),
        'to_hopper': _transfer(_dm_control_env('DMHopperEnv'), 'DMHopper', 'DMWalker_to_hopper/'),
    },
    'DMCheetah': {
        'to_walker': _transfer(_dm_control_env('DMWalkerEnv'), 'DMWalker', 'DMCheetah_to_walker/'),
        'to_hopper': _transfer(_dm_control_env('DMHopperEnv'), 'DMHopper', 'DMCheetah_to_hopper/'),
    },
    'DMHopper': {
        'to_walker': _transfer(_dm_control_env('DMWalkerEnv'), 'DMWalker', 'DMHopper_to_walker/'),
        'to_cheetah': _transfer(_dm_control_env('DMCheetahEnv'), 'DMCheetah', 'DMHopper_to_cheetah/'),
    },
    'Pusher-v2': {
        'to_human': _transfer(lambda: PusherHumanSimEnv(), 'PusherHumanSim-v2', 'Pusher_to_Human/'),
        'to_resolution': _transfer(lambda: PusherEnv(), 'Pusher-v2', 'Pusherr_to_highres/',
                                   use_resolution=True),
        # NOTE(review): same log directory as 'to_human'; confirm intended.
        'to_resolution_human': _transfer(lambda: PusherHumanSimEnv(), 'PusherHumanSim-v2',
                                         'Pusher_to_Human/', use_resolution=True),
    },
    'PusherHumanSim-v2': {
        'to_robot': _transfer(lambda: PusherEnv(), 'Pusher-v2', 'Pusherhum_to_robot/'),
        'to_resolution': _transfer(lambda: PusherHumanSimEnv(), 'PusherHumanSim-v2',
                                   'Pusherhum_to_highres/', use_resolution=True),
        'to_resolution_robot': _transfer(lambda: PusherEnv(), 'Pusher-v2',
                                         'Pusherhum_to_highres_robot/', use_resolution=True),
    },
    'Reach-v2': {
        'to_human': _transfer(lambda: ReachHumanSimEnv(), 'ReachHumanSim-v2', 'Reach_to_Human/'),
        'to_resolution': _transfer(lambda: ReachEnv(), 'Reach-v2', 'Reach_to_highres/',
                                   use_resolution=True),
        'to_resolution_human': _transfer(lambda: ReachHumanSimEnv(), 'ReachHumanSim-v2',
                                         'Reach_to_highres_hum/', use_resolution=True),
    },
    'ReachHumanSim-v2': {
        'to_robot': _transfer(lambda: ReachEnv(), 'Reach-v2', 'Reachhum_to_robot/'),
        'to_resolution': _transfer(lambda: ReachHumanSimEnv(), 'ReachHumanSim-v2',
                                   'Reachhum_to_highres/', use_resolution=True),
        'to_resolution_robot': _transfer(lambda: ReachEnv(), 'Reach-v2',
                                         'Reachhum_to_highres_robot/', use_resolution=True),
    },
}


def _resolve_experiment_config(env_name, env_type):
    """Return the merged source + transfer settings for ``(env_name, env_type)``.

    Raises:
        NotImplementedError: if the pair is not a supported DIFF-IL setting.
    """
    if env_name not in _SOURCE_ENV_SETTINGS:
        raise NotImplementedError('Unsupported source environment: {!r}'.format(env_name))
    if env_type == 'expert':
        # 'expert' used to be accepted, but it defined no learner data or log
        # directory. The run then crashed only after every buffer was loaded,
        # so fail fast here instead.
        raise NotImplementedError(
            "env_type 'expert' is not a DIFF-IL transfer setting "
            "(no learner data or experiment directory is defined for it)")
    transfers = _TRANSFERS.get(env_name, {})
    if env_type not in transfers:
        raise NotImplementedError(
            'Unsupported env_type {!r} for {!r}; expected one of {}'.format(
                env_type, env_name, sorted(transfers)))
    im_side, episode_limit, random_epi_limit, batch_size = _SOURCE_ENV_SETTINGS[env_name]
    config = dict(transfers[env_type])
    config.update(im_side=im_side,
                  episode_limit=episode_limit,
                  random_epi_limit=random_epi_limit,
                  batch_size=batch_size)
    return config


def _configure_gpus(gpu_id, allow_growth=True):
    """Expose only GPU ``gpu_id`` to TensorFlow and optionally enable memory growth."""
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if not gpus:
        return
    tf.config.experimental.set_visible_devices(gpus[gpu_id], 'GPU')
    if allow_growth:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            # TensorFlow raises this when devices were already initialized.
            print(e)


def _degrade_buffer_resolution(buffer, downsize_shape):
    """Replace ``buffer.ims`` and ``buffer.first_ims`` with down/up-scaled copies."""
    buffer.ims = np.stack([down_up_scale(img, downsize_shape) for img in buffer.ims])
    buffer.first_ims = np.stack([down_up_scale(img, downsize_shape) for img in buffer.first_ims])


_CSV_FIELDNAMES = ['Iteration', 'Steps', 'n', 'mean', 'max', 'min', 'std']


def _write_progress_row(csv_writer, csv_file, iteration, steps, stats):
    """Append one evaluation row to progress.csv and flush it to disk."""
    row = {'Iteration': iteration, 'Steps': steps}
    row.update({key: stats[key] for key in ('n', 'mean', 'max', 'min', 'std')})
    csv_writer.writerow(row)
    csv_file.flush()


def run_experiment(parsers):
    """Train a DIFF-IL learner for the requested source/target setting.

    Logs evaluation statistics to
    ``experiments_data/<exp_name>/<parsers.exp_name>/progress.csv``.

    Args:
        parsers: argparse namespace as produced by ``parse_arguments``.

    Returns:
        Tuple ``(gail, sampler)``: the trained DIFFIL model and the Sampler
        wrapping the target environment.

    Raises:
        NotImplementedError: if ``env_name``/``env_type`` is not a supported
            transfer setting.
    """
    # Validate the requested setting before touching GPUs or loading any data.
    config = _resolve_experiment_config(parsers.env_name, parsers.env_type)
    _configure_gpus(parsers.gpu_id)

    file_location = 'expert_data'
    prior_file_location = 'prior_data'
    test_runs_per_epoch = 10
    init_random_samples = 1000
    # Per-epoch sampling budget: sampling_num // episode_limit trajectories
    # are collected with the exploration policy each epoch.
    sampling_num = 1000

    # Learner (SAC) parameters
    l_buffer_size = 50000
    l_batch_size = 256
    l_exploration_noise = 0.2
    l_learning_rate = 1e-3
    l_gamma = 0.99
    l_polyak = 0.995
    l_entropy_coefficient = 0.1
    l_tune_entropy_coefficient = True
    l_target_entropy = None
    l_clip_actor_gradients = False

    # Model parameters
    env_name = parsers.env_name
    n_expert_demos = 50000
    n_expert_prior_demos = 50000
    n_agent_prior_demos = 50000
    exp_num = parsers.exp_name
    epochs = parsers.epochs
    RL_num = parsers.model_RL_per_step
    recon_scale = parsers.recon
    feature_consistency_scale = parsers.fcon
    alpha = parsers.fwgan_alpha
    disc_loss_scale = parsers.fwgan_disc
    gen_loss_scale = parsers.fwgan_gen
    label_loss_source_scale = parsers.label_source
    label_loss_target_scale = parsers.label_target
    label_loss_frame = parsers.label_frame
    model_num_per_epoch = parsers.model_num_per_epoch
    feature_size = 32

    # Environment-specific settings
    im_side = config['im_side']
    im_shape = [im_side, im_side]
    episode_limit = config['episode_limit']
    random_epi_limit = config['random_epi_limit']
    d_l_batch_size = config['batch_size']
    expert_prior_location = env_name
    agent_prior_location = config['agent_prior_location']
    agent_learner_location = config['agent_learner_location']
    exp_name = config['exp_name']
    env = config['env_factory']()

    # Buffer setup
    expert_buffer = DemonstrationsReplayBuffer(
        load_expert_trajectories(env_name, file_location, visual_data=True, load_ids=True,
                                 max_demos=n_expert_demos), episode_limit)
    expert_visual_data_shape = expert_buffer.get_random_batch(1)['ims'][0].shape
    print('Visual data shape: {}'.format(expert_visual_data_shape))
    past_frames = expert_visual_data_shape[0]
    print('Past frames: {}'.format(past_frames))

    source_random_buffer = DemonstrationsReplayBuffer(load_expert_trajectories(
        expert_prior_location, prior_file_location, visual_data=True, load_ids=True,
        max_demos=n_expert_prior_demos), random_epi_limit)

    target_random_buffer = DemonstrationsReplayBuffer(load_expert_trajectories(
        agent_prior_location, prior_file_location, visual_data=True, load_ids=True,
        max_demos=n_agent_prior_demos), random_epi_limit)

    action_size = env.action_space.shape[0]

    if config['use_resolution']:
        # Only the expert and source-prior images are blurred; the target
        # prior buffer is left untouched.
        downsize_shape = (32, 32)
        _degrade_buffer_resolution(expert_buffer, downsize_shape)
        _degrade_buffer_resolution(source_random_buffer, downsize_shape)

    # SAC setup
    def make_actor():
        actor = StochasticActor([tf.keras.layers.Dense(256, 'relu', kernel_initializer='orthogonal'),
                                 tf.keras.layers.Dense(256, 'relu', kernel_initializer='orthogonal'),
                                 tf.keras.layers.Dense(action_size * 2,
                                                       kernel_initializer=tf.keras.initializers.Orthogonal(0.01))])
        return actor

    def make_critic():
        critic = Critic([tf.keras.layers.Dense(256, 'relu', kernel_initializer='orthogonal'),
                         tf.keras.layers.Dense(256, 'relu', kernel_initializer='orthogonal'),
                         tf.keras.layers.Dense(1,
                                               kernel_initializer=tf.keras.initializers.Orthogonal(0.01))])
        return critic

    if l_target_entropy is None:
        l_target_entropy = -1 * (np.prod(env.action_space.shape))

    # Model setup
    def make_encoder():
        # NOTE(review): the frame-stack size is hard-coded to 4; presumably it
        # should equal past_frames (printed above). Confirm before changing.
        myim_shape = [None, 4, im_shape[0], im_shape[1], 3]
        return Encoder(feature_size, myim_shape)

    def make_WGANdiscriminator():
        return WGANdiscriminator(feature_size, alpha)

    def make_labelnet_seq():
        return Labelnet(feature_size)

    def make_labelnet_frame():
        return Labelnet_frame(feature_size)

    def make_decoder():
        img_shape = [None, feature_size]
        recon_shape = [None, im_shape[0]]
        source_decoder = Decoder(img_shape, recon_shape)
        target_decoder = Decoder(img_shape, recon_shape)
        return source_decoder, target_decoder

    log_dir = osp.join('experiments_data/', '{}/{}'.format(exp_name, exp_num))
    # exist_ok=False: refuse to overwrite the logs of an earlier run.
    os.makedirs(log_dir, exist_ok=False)

    # The context manager closes progress.csv even if training raises.
    with open(osp.join(log_dir, 'progress.csv'), 'w', newline='') as csv_file:
        csv_writer = csv.DictWriter(csv_file, _CSV_FIELDNAMES)
        csv_writer.writeheader()
        csv_file.flush()

        l_optimizer = tf.keras.optimizers.Adam(l_learning_rate)
        l_agent = SAC(make_actor=make_actor,
                      make_critic=make_critic,
                      make_critic2=make_critic,
                      actor_optimizer=l_optimizer,
                      critic_optimizer=l_optimizer,
                      gamma=l_gamma,
                      polyak=l_polyak,
                      entropy_coefficient=l_entropy_coefficient,
                      tune_entropy_coefficient=l_tune_entropy_coefficient,
                      target_entropy=l_target_entropy,
                      clip_actor_gradients=l_clip_actor_gradients)

        sampler = Sampler(env, episode_limit, init_random_samples, visual_env=True)

        gail = DIFFIL(agent=l_agent,
                      batch_size=d_l_batch_size,
                      make_decoder=make_decoder,
                      make_encoder=make_encoder,
                      make_label=make_labelnet_seq,
                      make_label_frame=make_labelnet_frame,
                      make_fwgan=make_WGANdiscriminator,
                      expert_buffer=expert_buffer,
                      source_random_buffer=source_random_buffer,
                      target_random_buffer=target_random_buffer,
                      feature_size=feature_size,
                      recon_loss=recon_scale,
                      feature_consistency_loss=feature_consistency_scale,
                      disc_loss=disc_loss_scale,
                      gen_loss=gen_loss_scale,
                      seq_label_source=label_loss_source_scale,
                      seq_label_target=label_loss_target_scale,
                      frame_label_loss=label_loss_frame,
                      sampler=sampler)

        agent_buffer = LearnerAgentReplayBuffer(
            gail, l_buffer_size, episode_limit,
            initial_data=load_learner_trajectories(agent_learner_location, prior_file_location,
                                                   visual_data=True, load_ids=True,
                                                   max_demos=l_buffer_size))

        # A forward pass on one real batch, presumably so the Keras model is
        # built before summary() is called.
        test_input = expert_buffer.get_random_batch(1)
        # NOTE(review): 'obs' is not read afterwards; the env.reset() call is
        # kept in case the environment's state depends on it.
        test_input['obs'] = np.expand_dims((env.reset()).astype('float32'), axis=0)
        gail(test_input['ims'])
        gail.summary()

        step_counter = 0
        print('Epoch {}/{} - total steps {}'.format(0, epochs, step_counter))
        # Baseline evaluation of the untrained learner, logged as iteration 0.
        out = sampler.evaluate(l_agent, test_runs_per_epoch, False, get_ims=False)
        _write_progress_row(csv_writer, csv_file, 0, step_counter, out)

        # Training
        for e in range(epochs):
            for _ in range(sampling_num // episode_limit):
                traj_data = sampler.sample_trajectory(l_agent, l_exploration_noise)
                agent_buffer.add(traj_data)
                step_counter += traj_data['n']
            print("buffer update")
            gail.train(agent_buffer=agent_buffer,
                       l_batch_size=l_batch_size,
                       l_updates=RL_num,
                       l_act_delay=1,
                       d_updates=model_num_per_epoch)

            print('Epoch {}/{} - total steps {}'.format(e + 1, epochs, step_counter))
            traj_test = sampler.sample_test_trajectories(l_agent, 0.0, test_runs_per_epoch)
            out = log_trajectory_statistics(traj_test['ret'], False)
            _write_progress_row(csv_writer, csv_file, e + 1, step_counter, out)

            print("="*50)
            print("Iteration : ", e + 1)
            print("Steps : ", step_counter)
            print("Episode_return : ", out['mean'])
            print("Episode_std:", out['std'])
            print("Episode_max : ", out['max'])
            print("Episode_min:", out['min'])
            print("="*50)

    return gail, sampler
# ==================================================
def main():
    """Command-line entry point: parse the options and run one DIFF-IL experiment."""
    cli_args = parse_arguments()
    _trained_model, _sampler = run_experiment(cli_args)


if __name__ == "__main__":
    main()