import os
import json
import random
import argparse
from os import path
from datetime import datetime
from tqdm import tqdm
from stable_baselines3 import A2C, DQN, SAC, DDPG, TD3, PPO
import utils
import wrappers
import torch
import gym

from networks import Encoder, count_parameters, TransModelNormal
from symmetryreg_wrapper import TrainTransitionWrapper, process_samples
from networks import Encoder

# TODO: implement active and passive decomposition
# TODO: add my own modified envs to gym? e.g., change dt and gravity for pendulum.

if __name__ == '__main__':
    # ---- Command-line configuration -------------------------------------
    # Only some of these options are read directly in this script; the whole
    # `args` namespace is also handed to TransModelNormal and
    # TrainTransitionWrapper further down, which presumably consume the rest
    # (learning rate, batch size, warmup steps, ...) -- verify against them.
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--img-w', type=int, default=84)
    parser.add_argument('--img-h', type=int, default=84)
    parser.add_argument('--code-size', type=int, default=128)
    parser.add_argument('--num-epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--frame-stack', type=int, default=4)
    parser.add_argument('--buffer-size', type=int, default=int(1e6))
    parser.add_argument('--num-channels', type=int, default=64)
    parser.add_argument('--warmup-steps', type=int, default=int(1e5))
    parser.add_argument('--mlp-hidden-dim', type=int, default=128)
    parser.add_argument('--steps-per-epoch', type=int, default=100)
    parser.add_argument('--update-prob', type=float, default=0.01)
    parser.add_argument('--weight-decay', type=float, default=1e-7)
    parser.add_argument('--learning-rate', type=float, default=1e-3)
    parser.add_argument('--env-name', type=str, default='Pong')
    parser.add_argument('--grayscale', action='store_true')
    # `--cpu` clears `args.gpu` (which defaults to True).
    parser.add_argument('--cpu', action='store_false', dest='gpu',
                        help='run on the CPU even if CUDA is available')
    parser.add_argument('--agent-save-dir', type=str, default="checkpoints/RL")
    parser.add_argument('--model-type', type=str, default='vanilla', choices=['symmetrypretrained', 'vanilla'])
    args = parser.parse_args()

    # Checkpoints are namespaced as <save-dir>/<env-name>/<model-type>.
    args.agent_save_dir = os.path.join(args.agent_save_dir, args.env_name, args.model_type)

    print(datetime.now())
    print('args = %s' % json.dumps(vars(args), sort_keys=True, indent=4))

    utils.set_seed(args.seed)
    utils.virtual_display()
    # Honour --cpu: previously the flag was parsed but never consulted, so
    # CUDA was always selected whenever it was available.
    use_cuda = args.gpu and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print('Device:', device)

    # Base environment: preprocessing wrapper (target width/height, optional
    # grayscale) followed by stacking of `--frame-stack` observations.
    env = gym.make(utils.get_env_name(args.env_name))
    env = wrappers.PreprocessObsWrapper(env, args.img_w, args.img_h, args.grayscale)
    env = wrappers.FrameStack(env, num_stack=args.frame_stack)

    # Build a DQN whose Q-network uses the project Encoder as its feature
    # extractor; only that encoder is used below, the agent itself is not
    # trained in this script.
    policy_kwargs = dict(
        features_extractor_class=Encoder,
        features_extractor_kwargs=dict(features_dim=args.code_size, channels_dim=args.num_channels),
    )

    model = DQN(
        'CnnPolicy', env,
        policy_kwargs=policy_kwargs,
        verbose=1
    )

    enc = model.policy.q_net.features_extractor
    enc.to(device)
    print(enc)
    print('%d parameters' % count_parameters(enc))

    # Restore encoder weights if a checkpoint for this seed exists.
    enc_ckpt = os.path.join(args.agent_save_dir, f'model_final_{args.seed}.tar')
    if os.path.exists(enc_ckpt):
        # map_location remaps the stored tensors onto the active device, so a
        # checkpoint written on a GPU machine still loads on a CPU-only one.
        enc.load_state_dict(torch.load(enc_ckpt, map_location=device))
        print('Encoder Checkpoint Loaded for Evaluation')

    env_train = wrappers.TransitionBufferWrapper(env, buffer_size=args.buffer_size)

    # Transition model over the encoder's code space, conditioned on one of
    # `env.action_space.n` discrete actions.
    trans = TransModelNormal(args, env.action_space.n)
    trans.to(device)
    trans_ckpt = os.path.join(args.agent_save_dir, f'model_transition_final_{args.seed}.tar')
    if os.path.exists(trans_ckpt):
        trans.load_state_dict(torch.load(trans_ckpt, map_location=device))
        print('Transition Function Checkpoint Loaded for Evaluation')
    # Wrap the buffered env so that stepping it also drives training of
    # `trans` (and reports progress through `info`, see the loop below).
    env_train = TrainTransitionWrapper(env_train, args, enc, trans, verbose=True)

    # ---- Train the transition model --------------------------------------
    # Training happens as a side effect of stepping `env_train` with random
    # actions; the wrapper reports the number of finished epochs through
    # info['epoch_cnt'].
    print('[%s] Started training the transition model' % datetime.now())
    print('Training for %d epochs...' % args.num_epochs)
    env_train.reset()
    while True:
        _, _, done, info = env_train.step(env_train.action_space.sample())

        if done:
            env_train.reset()

        # `>=` rather than `==`: if the counter ever advances past the target
        # between two checks, the loop must still terminate.
        if info['epoch_cnt'] >= args.num_epochs:
            break

    print('[%s] Finished training the transition model' % datetime.now())
    # torch.save does not create missing parent directories.
    os.makedirs(args.agent_save_dir, exist_ok=True)
    save_path = path.join(args.agent_save_dir, f'model_transition_final_{args.seed}.tar')
    print('saving transition model to %s' % save_path)
    torch.save(trans.state_dict(), save_path)

    # ---- Collect evaluation transitions with a random policy -------------
    # Episodes are cut off once the step counter exceeds 80, and only
    # transitions taken after step 58 of an episode are stored.
    print('collecting samples ..')
    buffer = wrappers.Buffer(10000)
    obs = env.reset()
    episode_num, steps = 0, 0
    while True:
        action = env.action_space.sample()
        next_obs, _, done, info = env.step(action)
        steps += 1

        if not done and steps <= 80:
            if steps > 58:
                buffer.push([obs, action, next_obs])

            # Reached on the first ordinary step after the 100th episode
            # has been closed out in the branch below.
            if episode_num == 100:
                break

            obs = next_obs
        else:
            # Episode finished or hit the step cap: start a fresh one. The
            # step that triggered this is not stored.
            obs = env.reset()
            steps = 0
            episode_num += 1
            print("Episode ", episode_num)

    # ---- Ranking evaluation (Hits@k / MRR) -------------------------------
    # For every collected transition, the predicted next code is compared
    # against the encoded next states of ALL evaluated transitions; the rank
    # of the true next state among them (0 = closest) yields Hits@1, Hits@5
    # and the mean reciprocal rank.
    print('[%s] Ranking Metric Evaluation' % datetime.now())

    print('[%s] Ranking Metric Evaluation like the paper' % datetime.now())
    import numpy as np
    enc.eval()
    trans.eval()
    T = 10             # transitions per forward pass
    max_batches = 100  # evaluate at most max_batches * T transitions

    # Shuffle a list we actually keep: shuffling the temporary returned by
    # `buffer.tolist()` and then calling `tolist()` again discards the
    # permutation whenever `tolist()` returns a copy.
    samples = list(buffer.tolist())
    random.shuffle(samples)

    # Use only complete batches so process_samples always receives exactly
    # T transitions, even if fewer than max_batches * T were collected.
    num_batches = min(max_batches, len(samples) // T)
    if num_batches == 0:
        raise RuntimeError(
            'not enough collected transitions for evaluation: have %d, need at least %d'
            % (len(samples), T))

    next_state = []
    pred_state = []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for i in tqdm(range(num_batches)):
            data = samples[i * T:(i + 1) * T]
            s, a, s_ = process_samples(data, T, env.action_space.n, device)
            x, x_ = enc(s), enc(s_)
            next_state.append(x_.detach().cpu())
            x_next = trans.transition(x, a)
            pred_state.append(x_next.detach().cpu())

    next_state = torch.cat(next_state, dim=0)
    pred_state = torch.cat(pred_state, dim=0)
    total_samples = next_state.shape[0]
    ranks = np.empty(total_samples)
    for i in tqdm(range(total_samples)):
        # L2 distance from prediction i to every true next-state code.
        dist = torch.linalg.norm(pred_state[i][None, :] - next_state, ord=2,
                                 dim=(-1)).numpy()
        inds = np.argsort(dist)
        # Zero-based position of the matching next state in the ordering.
        ranks[i] = np.where(inds == i)[0][0]

    # Ranks are zero-based, so Hits@k counts ranks strictly below k
    # (the previous `< 2` / `< 6` thresholds reported Hits@2 / Hits@6).
    h1 = float(np.mean(ranks < 1))
    h5 = float(np.mean(ranks < 5))
    mrr = np.mean(1.0 / (ranks + 1))

    print("\n\n H@1:", h1)
    print("\n H@5:", h5)
    print("\n MRR:", mrr)
