import os
import json
import random
import argparse
from os import path
from datetime import datetime
from tqdm import tqdm
from gym import spaces
from gym.wrappers.time_limit import TimeLimit
from stable_baselines3 import SAC, PPO, DDPG
from stable_baselines3.common.monitor import Monitor
import utils
import wrappers
import torch
import gym

from networks import Encoder, count_parameters, EncoderDecoupled, get_decoder
import wandb
from wandb.integration.sb3 import WandbCallback
from symmetryreg_wrapper import TransformationCodingWrapper, AutoencoderWrapper, SimCLRWrapper
from stable_baselines3.common.env_checker import check_env

# TODO: implement active and passive decomposition
# TODO: add my own modified envs to gym? e.g., change dt and gravity for pendulum.

if __name__ == '__main__':
    # ---------------------------------------------------------------- #
    # Command-line interface
    # ---------------------------------------------------------------- #
    parser = argparse.ArgumentParser()
    # Integer hyper-parameters
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--img-w', type=int, default=48)
    parser.add_argument('--img-h', type=int, default=48)
    parser.add_argument('--code-size', type=int, default=32)
    parser.add_argument('--num-epochs', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--frame-stack', type=int, default=4)
    parser.add_argument('--buffer-size', type=int, default=int(1e6))
    parser.add_argument('--num-channels', type=int, default=24)
    parser.add_argument('--warmup-steps', type=int, default=int(1e5))
    parser.add_argument('--mlp-hidden-dim', type=int, default=128)
    parser.add_argument('--batch-size-incr', type=int, default=0)
    parser.add_argument('--steps-per-epoch', type=int, default=100)
    parser.add_argument('--rl-training-steps', type=int, default=int(1e5))
    parser.add_argument('--num-actions-train', type=int, default=10)
    parser.add_argument('--num-discretized-actions', type=int, default=100)
    # Float hyper-parameters
    parser.add_argument('--update-prob', type=float, default=0.01)
    parser.add_argument('--weight-decay', type=float, default=1e-7)
    parser.add_argument('--hinge-thresh', type=float, default=5)
    parser.add_argument('--barrier-coef', type=float, default=1)
    parser.add_argument('--learning-rate', type=float, default=1e-3)
    parser.add_argument('--reset-noise-scale', type=float, default=0.1,
                        help='only works with certain mujoco environments')
    parser.add_argument('--normalization-steps', type=int, default=int(1e4))
    # Environment / algorithm selection
    parser.add_argument('--env-name', type=str, default='InvertedPendulum')
    parser.add_argument('--rl-algo', type=str, default='PPO', choices=['SAC', 'PPO', 'DDPG'])
    parser.add_argument('--barrier-type', type=str, default='log', choices=['inv', 'log', 'id'])
    # Boolean switches
    parser.add_argument('--grayscale', action='store_true')
    parser.add_argument('--overwrite', action='store_true')
    parser.add_argument('--cosine-sim', action='store_true')
    parser.add_argument('--conformal-map', action='store_true')
    parser.add_argument('--cpu', action='store_false', dest='gpu',
                        help='run on the CPU even when CUDA is available')
    # Logging / output locations
    parser.add_argument('--project', type=str, default="TransformationCoding")
    parser.add_argument('--entity', type=str, default="symmetry_group")
    parser.add_argument('--agent-save-dir', type=str, default="checkpoints/RL")
    parser.add_argument(
        '--model-type', type=str, default='vanilla',
        choices=['symmetrypretrained', 'vanilla', 'symmetrydecoupled',
                 'autoencoderdecoupled', 'simCLRdecoupled'],
    )
    parser.add_argument('--eval-freq', type=int, default=int(1e3))
    parser.add_argument('--eval-dir', type=str, default="results/RL")
    parser.add_argument('--num-eval-episodes', type=int, default=1)
    parser.add_argument('--temp', type=float, default=5e-1)

    args = parser.parse_args()

    # Derive the run-specific project name and output directories from the
    # environment, model type and RL algorithm (plus the seed for eval logs).
    args.project = args.project + '_mujoco'
    args.agent_save_dir = f'{args.agent_save_dir}/{args.env_name}/{args.model_type}/{args.rl_algo}'
    args.eval_dir = f'{args.eval_dir}/{args.env_name}/{args.model_type}/{args.rl_algo}/seed_{args.seed}'

    # exist_ok avoids the check-then-create race when several seeds are
    # launched at once (the checkpoint directory is shared across seeds).
    os.makedirs(args.agent_save_dir, exist_ok=True)
    os.makedirs(args.eval_dir, exist_ok=True)

    print(datetime.now())
    print('args = %s' % json.dumps(vars(args), sort_keys=True, indent=4))

    utils.set_seed(args.seed)
    utils.virtual_display()
    # Honour --cpu: args.gpu is False when the flag is given, so CUDA is only
    # selected when it is both requested and available.
    device = torch.device('cuda' if args.gpu and torch.cuda.is_available() else 'cpu')
    print('Device:', device)


    def make_env(env_id):
        """Create and seed the gym environment registered for ``env_id``.

        Args:
            env_id: Environment name without its version suffix,
                e.g. ``'InvertedPendulum'``.

        Returns:
            The environment returned by ``gym.make``, seeded with ``args.seed``.

        Raises:
            ValueError: If ``env_id`` is not one of the supported names.
        """
        v3_envs = ('Ant', 'Swimmer', 'Humanoid', 'Walker2d', 'HalfCheetah', 'Hopper')
        v2_envs = ('Reacher', 'InvertedPendulum', 'InvertedDoublePendulum')

        if env_id in v3_envs:
            version = 'v3'
        elif env_id in v2_envs:
            version = 'v2'
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # ``env`` further down.
            raise ValueError(
                'Unsupported environment %r; expected one of: %s'
                % (env_id, ', '.join(v3_envs + v2_envs))
            )

        env = gym.make(env_id + '-' + version)
        env.seed(args.seed)
        return env

    # Weights & Biases logging is switched off for this script; set the mode
    # to "offline" instead of "disabled" to record runs locally.
    os.environ["WANDB_MODE"] = "disabled"

    run_config = vars(args)
    run = wandb.init(
        project=args.project,
        config=run_config,
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    )
    wandb.config.update(run_config)

    # Build the observation pipeline from the inside out: raw env ->
    # RenderObsWrapper -> PreprocessObsWrapper -> FrameStack.
    env = wrappers.FrameStack(
        wrappers.PreprocessObsWrapper(
            wrappers.RenderObsWrapper(make_env(args.env_name)),
            args.img_w, args.img_h, args.grayscale,
        ),
        num_stack=args.frame_stack,
    )

    if args.model_type in ('vanilla', 'symmetrypretrained'):
        # End-to-end setting: the RL policy owns the CNN encoder as its
        # features extractor. For 'symmetrypretrained' that encoder is first
        # pre-trained through the TransformationCodingWrapper stack, then the
        # agent is trained; 'vanilla' goes straight to RL training.
        policy_kwargs = dict(
            features_extractor_class=Encoder,
            features_extractor_kwargs=dict(
                features_dim=args.code_size, channels_dim=args.num_channels
            ),
        )
        env = Monitor(env)
        # Explicit lookup instead of eval() on a command-line string.
        algo_cls = {'SAC': SAC, 'PPO': PPO, 'DDPG': DDPG}[args.rl_algo]
        model = algo_cls(
            'CnnPolicy', env,
            policy_kwargs=policy_kwargs,
            verbose=1,
            tensorboard_log=f"runs/{run.id}"
        )

        if args.model_type == 'symmetrypretrained':
            # PPO keeps the extractor on the policy itself; the off-policy
            # algorithms (SAC, DDPG) keep it on the actor. DDPG was previously
            # unhandled and left ``enc`` undefined.
            if args.rl_algo == 'PPO':
                enc = model.policy.features_extractor
            else:
                enc = model.policy.actor.features_extractor
            # NOTE(review): for SAC/DDPG the target networks may hold their own
            # extractor copies that are not touched by this pre-training --
            # confirm against the installed stable-baselines3 version.
            enc.to(device)
            print(enc)
            print('%d parameters' % count_parameters(enc))

            ckpt_path = os.path.join(args.agent_save_dir, f'model_final_{args.seed}.tar')
            if os.path.exists(ckpt_path):
                # map_location lets a checkpoint saved on GPU load on a CPU run.
                enc.load_state_dict(torch.load(ckpt_path, map_location=device))
                print('Encoder Checkpoint Loaded for Training')

            # The pre-training loop below samples integer actions, so
            # continuous action spaces are discretized first.
            if not isinstance(env.action_space, spaces.Discrete):
                env_pre = wrappers.DiscretizeActionsWrapper(
                    env, num_actions=args.num_discretized_actions
                )
            else:
                env_pre = env
            env_pre = wrappers.ReplayBufferWrapper(env_pre, buffer_size=args.buffer_size)
            env_pre = TransformationCodingWrapper(env_pre, args, enc, verbose=True)
            env_pre = TimeLimit(env_pre, max_episode_steps=100)

            print('[%s] Started pre-training the feature extractor' % datetime.now())
            print('Training for %d epochs...' % args.num_epochs)
            env_pre.reset()
            cur_epoch = 0
            # Drive the wrapped env with uniformly random actions; progress is
            # read back from the ``info`` dict ('epoch_cnt', 'avg_loss_list').
            while True:
                a = random.randrange(env_pre.action_space.n)
                _, _, done, info = env_pre.step(a)

                if done:
                    env_pre.reset()

                if info['epoch_cnt'] > cur_epoch:
                    wandb.log(
                        {
                            'epoch': info['epoch_cnt'],
                            'pretraining_loss': info['avg_loss_list'][-1]
                        }
                    )
                    cur_epoch += 1

                # ``>=`` rather than ``==`` so the loop still terminates if the
                # counter ever moves past num_epochs in a single step.
                if info['epoch_cnt'] >= args.num_epochs:
                    break

            print('[%s] Finished pre-training the feature extractor' % datetime.now())
            print('saving model to %s' % ckpt_path)
            torch.save(enc.state_dict(), ckpt_path)

        # RL training is identical for both model types.
        print('[%s] Started training RL agent' % datetime.now())
        print('Training for %d steps...' % args.rl_training_steps)
        model.learn(
            total_timesteps=args.rl_training_steps,
            callback=WandbCallback(
                gradient_save_freq=1000,
                model_save_path=f"models/{run.id}",
                verbose=2,
            ),
            eval_env=env, eval_freq=args.eval_freq,
            n_eval_episodes=args.num_eval_episodes,
            eval_log_path=args.eval_dir
        )
        print('[%s] Finished training RL agent' % datetime.now())
        # Close the wandb run for both model types (previously only 'vanilla').
        run.finish()
    if args.model_type in ('symmetrydecoupled', 'autoencoderdecoupled', 'simCLRdecoupled'):
        # Decoupled setting: (1) train a standalone encoder on pixel
        # observations, (2) collect normalization statistics for its codes,
        # (3) train an MLP-policy RL agent on the encoded observations.
        # These model types are disjoint from 'vanilla'/'symmetrypretrained',
        # so this test is independent of the branch that handles those.
        enc = EncoderDecoupled(args, features_dim=args.code_size, channels_dim=args.num_channels)
        enc.to(device)
        print(enc)
        print('%d parameters' % count_parameters(enc))

        ckpt_path = os.path.join(args.agent_save_dir, f'model_final_{args.seed}.tar')
        if os.path.exists(ckpt_path):
            # map_location lets a checkpoint saved on GPU load on a CPU run.
            enc.load_state_dict(torch.load(ckpt_path, map_location=device))
            print('Encoder Checkpoint Loaded for Training')
        print('[%s] Started training representation' % datetime.now())
        print('Training for %d epochs...' % args.num_epochs)

        # ------- representation-learning wrapper stack -------#
        if args.model_type == 'autoencoderdecoupled':
            dec = get_decoder(args)
            dec.to(device)
            print(dec)
            print('%d parameters' % count_parameters(dec))
            env = wrappers.ReplayBufferAutoencoderWrapper(env, buffer_size=args.buffer_size)
            env = AutoencoderWrapper(env, args, enc, dec, verbose=True)

        elif args.model_type == 'symmetrydecoupled':
            if not isinstance(env.action_space, spaces.Discrete):
                env = wrappers.DiscretizeActionsWrapper(env, num_actions=args.num_discretized_actions)
            env = wrappers.ReplayBufferWrapper(env, buffer_size=args.buffer_size)
            env = TransformationCodingWrapper(env, args, enc, verbose=True)

        elif args.model_type == 'simCLRdecoupled':
            env = wrappers.TransitionBufferWrapper(env, buffer_size=args.buffer_size)
            env = SimCLRWrapper(env, args, enc, verbose=True)

        # Every variant is capped at 100 steps per episode.
        env = TimeLimit(env, max_episode_steps=100)

        env.reset()

        cur_epoch = 0
        # Drive the wrapped env with random actions; progress is read back
        # from the ``info`` dict ('epoch_cnt', 'avg_loss_list').
        while True:
            a = env.action_space.sample()

            _, _, done, info = env.step(a)

            if done:
                env.reset()

            if info['epoch_cnt'] > cur_epoch:
                wandb.log(
                    {
                        'epoch': info['epoch_cnt'],
                        'pretraining_loss': info['avg_loss_list'][-1]
                    }
                )
                cur_epoch += 1

            # ``>=`` rather than ``==`` so the loop still terminates if the
            # counter ever moves past num_epochs in a single step.
            if info['epoch_cnt'] >= args.num_epochs:
                break

        print('[%s] Finished training representation' % datetime.now())

        print('[%s] Finished pre-training the feature extractor' % datetime.now())
        print('saving model to %s' % ckpt_path)
        torch.save(enc.state_dict(), ckpt_path)

        # ----------- normalize rep. -----------#
        # Rebuild a fresh pixel env, then encode and normalize observations.
        env = make_env(args.env_name)
        env = wrappers.RenderObsWrapper(env)
        env = wrappers.PreprocessObsWrapper(env, args.img_w, args.img_h, args.grayscale)
        env = wrappers.FrameStack(env, num_stack=args.frame_stack)
        env = wrappers.EncodeObsWrapper(env, enc, args.code_size)
        env = wrappers.NormalizeObsWrapper(env)

        print('[%s] Started normalizing representation' % datetime.now())
        env.reset()

        # Random rollout; presumably NormalizeObsWrapper accumulates its
        # statistics while stepping -- confirm in wrappers.py.
        for _ in tqdm(range(args.normalization_steps)):
            a = env.action_space.sample()
            _, _, done, _ = env.step(a)

            if done:
                env.reset()

        print('[%s] Finished normalizing representation' % datetime.now())
        # --------------------------------------#

        # ----------- train RL agent -----------#
        # Explicit lookup instead of eval() on a command-line string.
        algo_cls = {'SAC': SAC, 'PPO': PPO, 'DDPG': DDPG}[args.rl_algo]
        model = algo_cls('MlpPolicy', env, verbose=1)
        print('[%s] Started training RL agent' % datetime.now())
        print('Training for %d steps...' % args.rl_training_steps)
        model.learn(
            total_timesteps=args.rl_training_steps,
            callback=WandbCallback(
                gradient_save_freq=1000,
                model_save_path=f"models/{run.id}",
                verbose=2,
            ),
            eval_env=env, eval_freq=args.eval_freq,
            n_eval_episodes=args.num_eval_episodes,
            eval_log_path=args.eval_dir
        )
        print('[%s] Finished training RL agent' % datetime.now())
        # Close the wandb run explicitly once training is done.
        run.finish()
        # --------------------------------------#
