#!/usr/bin/env python


import argparse
import json
import matplotlib
# Toggle for interactive plotting. When False, the 'Agg' backend is selected
# here, before matplotlib.pyplot is imported below, so no display is required.
display = False
if not display:
    matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL.Image
import pprint
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from tqdm import tqdm

from .ataritools.envs import MsPacmanEnv
from .ataritools.replaybuffer import ReplayBuffer
from . import utils
from .autoencoder import Autoencoder

# Without a display, make sure the 'images' directory exists (presumably where
# figures get written instead of being shown); with one, turn on matplotlib's
# interactive mode.
if not display:
    os.makedirs('images', exist_ok=True)
else:
    plt.ion()

def after_each_interval(epoch, interval):
    """Return True when ``epoch`` is an exact multiple of ``interval``.

    Epoch 0 counts as a multiple, so periodic actions also fire on the very
    first iteration.
    """
    remainder = epoch % interval
    return remainder == 0

def train(seed, lr, alpha, beta, n_steps, norm, quick):
    """Train an Autoencoder on MsPacman observations and log its losses.

    Fills a training and an evaluation replay buffer from the environment,
    then repeatedly samples batches and updates the network. At fixed
    intervals it reports running-average training losses, measures the
    reconstruction error on the evaluation buffer, and saves the model.

    Args:
        seed: Random seed passed to ``utils.reset_seeds`` and used to fill
            the training replay buffer.
        lr: Learning rate forwarded to the Autoencoder.
        alpha: Coefficient for the temporal smoothing loss.
        beta: Coefficient for the cross-correlation loss.
        n_steps: Forwarded to ``replay.sample`` as ``n_steps``; the CLI
            describes it as the number of timesteps for the temporal
            smoothing loss.
        norm: Name of the norm to use for the temporal smoothing loss.
        quick: If True, use much smaller batch/buffer/epoch settings and do
            not write train/eval log records or save the model (the log
            files and ``params.txt`` are still created).
    """
    os.makedirs('models', exist_ok=True)

    if quick:
        print('Using quick training settings.')
    batch_size = 32 if not quick else 2
    input_shape = (1,84,84)
    grayscale = True
    replay_capacity = 25000 if not quick else 2500 # ~500 frames/episode
    eval_capacity = 2000 if not quick else 500
    n_epochs = 200000 if not quick else 1000

    loss_print_interval = 200 if not quick else 10
    eval_interval = 2000 if not quick else 100
    save_interval = 10000 if not quick else 500

    # Number of pixels per frame (84 * 84); the 'pixerr' values reported below
    # are an error term multiplied by this count.
    n_pixels = input_shape[1] * input_shape[2]

    params = {
        'alpha': alpha,
        'beta': beta,
        'batch_size': batch_size,
        'eval_capacity': eval_capacity,
        'eval_interval': eval_interval,
        'input_shape': input_shape,
        'grayscale': grayscale,
        'lr': lr,
        'loss_print_interval': loss_print_interval,
        'n_epochs': n_epochs,
        'n_steps': n_steps,
        'norm': norm,
        'quick': quick,
        'replay_capacity': replay_capacity,
        'save_interval': save_interval,
        'seed': seed,
    }

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    utils.reset_seeds(seed=seed)

    net = Autoencoder(input_shape=input_shape, learning_rate=lr, alpha=alpha, beta=beta, norm=norm)
    net.summary()

    env = MsPacmanEnv(grayscale=grayscale)

    replay = ReplayBuffer(capacity=replay_capacity)
    eval_replay = ReplayBuffer(capacity=eval_capacity)

    # fill() returns a value that is then used as the seed for the eval buffer.
    eval_seed = replay.fill(env, seed)
    eval_replay.fill(env, eval_seed)
    eval_obs = torch.stack(eval_replay.obs_buffer).to(device)

    # Mean over all training observations, handed to the network before it is
    # moved to the target device.
    mu = torch.mean(torch.stack(replay.obs_buffer), dim=0)
    net.mu = mu
    net.to(device)

    log_dir = utils.create_log_dir(name=net.name)
    with open(os.path.join(log_dir, 'params.txt'), 'w') as param_log:
        param_log.write(repr(params)+'\n')

    # Context managers guarantee both logs are closed even if training raises.
    with open(os.path.join(log_dir, 'train.txt'), 'w') as train_log, \
            open(os.path.join(log_dir, 'eval.txt'), 'w') as eval_log:
        running_losses = None
        n_updates = 0  # updates accumulated since the last loss report
        for epoch in tqdm(range(n_epochs+1), desc='epoch'):
            obs, next_obs = replay.sample(n_steps=n_steps, batch_size=batch_size)
            obs = torch.stack(tuple(obs)).to(device)
            next_obs = torch.stack(tuple(next_obs)).to(device)

            losses = net.train_batch(obs, next_obs)
            if running_losses is None:
                running_losses = [0 for _ in losses]
            for i, loss in enumerate(losses):
                running_losses[i] += loss
            n_updates += 1

            if after_each_interval(epoch, loss_print_interval):
                # Average over the updates actually accumulated: the first
                # report (epoch 0) covers a single update, not a full interval.
                loss_info = {
                    'epoch': epoch,
                    'total_loss': running_losses[0] / n_updates,
                    'mse': (running_losses[1]+running_losses[2]) / n_updates,
                    'temporal': running_losses[3] / n_updates,
                    'x_corr': running_losses[4] / n_updates,
                    'pixerr': running_losses[1] / n_updates * n_pixels,
                }
                json_str = json.dumps(loss_info)
                tqdm.write(pprint.pformat(loss_info))
                if not quick:
                    train_log.write(json_str+'\n')
                    train_log.flush()
                running_losses = [0 for _ in running_losses]
                n_updates = 0

            if after_each_interval(epoch, eval_interval):
                with torch.no_grad():
                    mse = net.reconstruction_error(eval_obs)
                eval_info = {
                    'epoch': epoch,
                    'mse': mse,
                    'pixerr': mse * n_pixels,
                }
                json_str = json.dumps(eval_info)
                if not quick:
                    eval_log.write(json_str+'\n')
                    eval_log.flush()

            # Skip the save at epoch 0 and never save in quick mode.
            if epoch > 0 and not quick and after_each_interval(epoch, save_interval):
                net.save()

if __name__ == '__main__':
    # Start from the parser supplied by utils.get_parser() and register the
    # options this training script needs.
    cli = utils.get_parser()
    cli.add_argument(
        '-q', '--quick',
        action='store_true',
        help="Use quick training settings",
    )
    cli.set_defaults(quick=False)
    cli.add_argument(
        '-lr', '--learning_rate',
        type=float,
        default=0.0001,
        help='Learning rate for Adam optimizer',
    )
    cli.add_argument(
        '-a', '--alpha',
        type=float,
        default=0,
        help='Coefficient for temporal smoothing loss',
    )
    cli.add_argument(
        '-b', '--beta',
        type=float,
        default=0,
        help='Coefficient for cross-correlation loss',
    )
    cli.add_argument(
        '-n', '--nsteps',
        type=int,
        default=1,
        help='Number of timesteps for temporal smoothing loss',
    )
    cli.add_argument(
        '--norm',
        type=str,
        choices=['l0','lq','l1','l2'],
        default='l1',
        help='Which norm to use for temporal smoothing loss',
    )
    cli.add_argument(
        '-s', '--seed',
        type=int,
        default=0,
        help='Random seed',
    )
    opts = cli.parse_args()

    # Map the parsed command-line options onto train()'s keyword arguments.
    train(
        seed=opts.seed,
        lr=opts.learning_rate,
        alpha=opts.alpha,
        beta=opts.beta,
        n_steps=opts.nsteps,
        norm=opts.norm,
        quick=opts.quick,
    )
