import os
import json
import random
import argparse
from os import path
from datetime import datetime
from tqdm import tqdm
from gym import spaces
from gym.wrappers.time_limit import TimeLimit
from stable_baselines3 import A2C, DQN, SAC, DDPG, TD3, PPO
import utils
import wrappers
import torch
import gym

from networks import count_parameters
import wandb
from wandb.integration.sb3 import WandbCallback
from symmetryreg_wrapper import TransformationCodingWrapper
from networks import Encoder

# TODO: implement active and passive decomposition
# TODO: add my own modified envs to gym? e.g., change dt and gravity for pendulum.

def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. When ``None`` argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        The populated namespace. ``--cpu`` is stored inverted as
        ``args.gpu`` (``True`` unless ``--cpu`` is given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--img-w', type=int, default=84)
    parser.add_argument('--img-h', type=int, default=84)
    parser.add_argument('--code-size', type=int, default=128)
    parser.add_argument('--num-epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--frame-stack', type=int, default=4)
    parser.add_argument('--buffer-size', type=int, default=int(1e6))
    parser.add_argument('--num-channels', type=int, default=64)
    parser.add_argument('--warmup-steps', type=int, default=int(1e5))
    parser.add_argument('--mlp-hidden-dim', type=int, default=128)
    parser.add_argument('--batch-size-incr', type=int, default=0)
    parser.add_argument('--steps-per-epoch', type=int, default=100)
    parser.add_argument('--rl-training-steps', type=int, default=int(1e6))
    parser.add_argument('--num-actions-train', type=int, default=10)
    parser.add_argument('--update-prob', type=float, default=0.01)
    parser.add_argument('--weight-decay', type=float, default=1e-7)
    parser.add_argument('--hinge-thresh', type=float, default=5)
    parser.add_argument('--barrier-coef', type=float, default=1)
    parser.add_argument('--learning-rate', type=float, default=1e-3)
    parser.add_argument('--env-name', type=str, default='Pong')
    parser.add_argument('--rl-algo', type=str, default='DQN', choices=['SAC', 'PPO', 'A2C', 'DDPG', 'TD3', 'DQN'])
    parser.add_argument('--barrier-type', type=str, default='log', choices=['inv', 'log', 'id'])
    parser.add_argument('--grayscale', action='store_true')
    parser.add_argument('--overwrite', action='store_true')
    parser.add_argument('--cosine-sim', action='store_true')
    parser.add_argument('--conformal-map', action='store_true')
    parser.add_argument('--cpu', action='store_false', dest='gpu')
    parser.add_argument('--project', type=str, default="TransformationCoding")
    parser.add_argument('--entity', type=str, default="symmetry_group")
    parser.add_argument('--agent-save-dir', type=str, default="checkpoints/RL")
    parser.add_argument('--model-type', type=str, default='vanilla', choices=['symmetrypretrained', 'vanilla'])
    return parser.parse_args(argv)


def select_device(use_gpu: bool = True) -> torch.device:
    """Return the torch device to run on.

    CUDA is chosen only when it is both requested (``use_gpu``) and
    available; otherwise the CPU is used. This is what makes the
    ``--cpu`` flag effective.
    """
    return torch.device('cuda' if use_gpu and torch.cuda.is_available() else 'cpu')


def pretrain_encoder(env, model, args, device):
    """Pre-train the Q-network's feature extractor and checkpoint it.

    The environment is wrapped with ``ReplayBufferWrapper``,
    ``TransformationCodingWrapper`` and a 100-step ``TimeLimit`` and then
    stepped with random actions until the wrapper reports
    ``args.num_epochs`` epochs through ``info['epoch_cnt']``.
    NOTE(review): the encoder updates presumably happen inside
    ``TransformationCodingWrapper.step`` -- confirm against that class.

    Args:
        env: The preprocessed, frame-stacked environment.
        model: An SB3 model whose policy exposes ``q_net`` (DQN).
        args: Parsed command-line namespace; forwarded to the wrapper.
        device: Device the encoder is moved to and checkpoints are mapped to.

    Raises:
        AttributeError: If the policy has no ``q_net`` attribute.
    """
    q_net = getattr(model.policy, 'q_net', None)
    if q_net is None:
        raise AttributeError(
            "policy of %s has no 'q_net'; symmetry pre-training requires a "
            "Q-network policy (e.g. --rl-algo DQN)" % type(model).__name__
        )
    enc = q_net.features_extractor
    # NOTE(review): only q_net's extractor is pre-trained here; any separate
    # target network keeps its initial weights until the algorithm's own
    # target update -- confirm this is intended.
    enc.to(device)
    print(enc)
    print('%d parameters' % count_parameters(enc))

    save_path = os.path.join(args.agent_save_dir, f'model_final_{args.seed}.tar')
    if os.path.exists(save_path):
        # map_location lets a checkpoint saved on GPU be resumed on CPU.
        enc.load_state_dict(torch.load(save_path, map_location=device))
        print('Encoder Checkpoint Loaded for Training')

    env_pre = wrappers.ReplayBufferWrapper(env, buffer_size=args.buffer_size)
    env_pre = TransformationCodingWrapper(env_pre, args, enc, verbose=True)
    env_pre = TimeLimit(env_pre, max_episode_steps=100)

    print('[%s] Started pre-training the feature extractor' % datetime.now())
    print('Training for %d epochs...' % args.num_epochs)
    env_pre.reset()
    cur_epoch = 0
    while True:
        action = env_pre.action_space.sample()
        _, _, done, info = env_pre.step(action)

        if done:
            env_pre.reset()

        epoch_cnt = info['epoch_cnt']
        if epoch_cnt > cur_epoch:
            # A new epoch was completed: log its loss and checkpoint.
            wandb.log(
                {
                    'epoch': epoch_cnt,
                    'pretraining_loss': info['avg_loss_list'][-1],
                }
            )
            cur_epoch = epoch_cnt
            print('saving model to %s' % save_path)
            torch.save(enc.state_dict(), save_path)

        # ">=" rather than "==" so the loop still ends if the counter
        # ever advances past the target.
        if epoch_cnt >= args.num_epochs:
            break

    print('[%s] Finished pre-training the feature extractor' % datetime.now())

    print('saving model to %s' % save_path)
    torch.save(enc.state_dict(), save_path)


def train_agent(model, run, total_timesteps):
    """Train the RL agent with W&B logging, then finish the W&B run.

    Args:
        model: The SB3 model to train.
        run: The active wandb run; its id names the model save directory.
        total_timesteps: Number of environment steps passed to ``learn``.
    """
    print('[%s] Started training RL agent' % datetime.now())
    print('Training for %d steps...' % total_timesteps)
    model.learn(
        total_timesteps=total_timesteps,
        callback=WandbCallback(
            gradient_save_freq=1000,
            model_save_path=f"models/{run.id}",
            verbose=2,
        ),
    )
    print('[%s] Finished training RL agent' % datetime.now())
    run.finish()


def main():
    """Build the environment and agent, optionally pre-train, then train."""
    args = parse_args()

    # Checkpoints are grouped per environment and per model type.
    args.agent_save_dir = os.path.join(args.agent_save_dir, args.env_name, args.model_type)
    os.makedirs(args.agent_save_dir, exist_ok=True)

    print(datetime.now())
    print('args = %s' % json.dumps(vars(args), sort_keys=True, indent=4))

    utils.set_seed(args.seed)
    utils.virtual_display()
    device = select_device(args.gpu)
    print('Device:', device)

    # W&B is disabled here; set the mode to "offline" to log locally instead.
    os.environ["WANDB_MODE"] = "disabled"

    run = wandb.init(
        project=args.project,
        config=vars(args),
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        monitor_gym=True,  # auto-upload the videos of agents playing the game
    )

    wandb.config.update(vars(args))

    env = gym.make(utils.get_env_name(args.env_name))
    env = wrappers.PreprocessObsWrapper(env, args.img_w, args.img_h, args.grayscale)
    env = wrappers.FrameStack(env, num_stack=args.frame_stack)

    policy_kwargs = dict(
        features_extractor_class=Encoder,
        features_extractor_kwargs=dict(features_dim=args.code_size, channels_dim=args.num_channels),
    )

    # Explicit dispatch table instead of eval() on the algorithm name.
    algorithms = {
        'A2C': A2C,
        'DQN': DQN,
        'SAC': SAC,
        'DDPG': DDPG,
        'TD3': TD3,
        'PPO': PPO,
    }
    model = algorithms[args.rl_algo](
        'CnnPolicy', env,
        policy_kwargs=policy_kwargs,
        verbose=1,
        tensorboard_log=f"runs/{run.id}",
        # Keep the policy on the same device the encoder is moved to.
        device=device,
    )

    if args.model_type == 'symmetrypretrained':
        pretrain_encoder(env, model, args, device)

    train_agent(model, run, args.rl_training_steps)


if __name__ == '__main__':
    main()

