import os, sys
import cv2
import numpy as np
import imageio
import json
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm, trange

import matplotlib.pyplot as plt

import pdb
from progressbar import ProgressBar

from nerf_helpers import *
from config import gen_args
from data import DynamicsDataset
from models import DynamicsModel
from utils import set_seed, count_parameters, get_lr, AverageMeter



# Filenames of the pretrained autoencoder checkpoints used to initialise the
# encoders/decoder when training starts from scratch. Keyed first by the loss
# configuration (nerf_loss, ct_loss, auto_loss), then by args.env.
_AE_CKPT_NAMES = {
    # NeRF + TCN
    (1, 1, 0): {
        'FluidManipClip': 'net_epoch_0_iter_134000.pth',
        'FluidShakeWithIce_1000': 'net_epoch_1_iter_95000.pth',
        'RigidDrop': 'net_epoch_4_iter_22000.pth',
        'RigidFall': 'net_epoch_13_iter_18000.pth',
        'FluidShake': 'net_epoch_2_iter_113000.pth',
        # added 2021/06/16
        'FluidManipClip_wKuka_wColor': 'net_epoch_2_iter_125000.pth',
        'FluidShakeWithIce_wKuka_wColor_wGripper': 'net_epoch_2_iter_110000.pth',
    },
    # NeRF only
    (1, 0, 0): {
        'FluidManipClip': 'net_epoch_0_iter_129000.pth',
        'FluidShakeWithIce_1000': 'net_epoch_1_iter_57000.pth',
        'RigidDrop': 'net_best.pth',
        # added 2021/06/16
        'FluidManipClip_wKuka_wColor': 'net_epoch_2_iter_115000.pth',
        'FluidShakeWithIce_wKuka_wColor_wGripper': 'net_epoch_2_iter_115000.pth',
    },
    # TCN only
    (0, 1, 0): {
        'FluidManipClip': 'net_epoch_5_iter_120000.pth',
        'FluidShakeWithIce_1000': 'net_epoch_3_iter_94000.pth',
        'RigidDrop': 'net_epoch_21_iter_20000.pth',
        # added 2021/06/16
        'FluidManipClip_wKuka_wColor': 'net_epoch_3_iter_30000.pth',
        'FluidShakeWithIce_wKuka_wColor_wGripper': 'net_epoch_3_iter_20000.pth',
    },
    # TCN + autoencoder
    (0, 1, 1): {
        'FluidManipClip': 'net_epoch_4_iter_120000.pth',
        'FluidShakeWithIce_1000': 'net_epoch_4_iter_72000.pth',
        'RigidDrop': 'net_epoch_53_iter_0.pth',
        # added 2021/06/16
        'FluidManipClip_wKuka_wColor': 'net_epoch_28_iter_28000.pth',
        'FluidShakeWithIce_wKuka_wColor_wGripper': 'net_epoch_27_iter_12000.pth',
    },
}


def _pretrained_ae_ckpt_path(args):
    """Return the path of the pretrained autoencoder checkpoint for `args`.

    The checkpoint directory is derived from `args.outf` by replacing 'dy'
    with 'ae' and the history tag 'nHis<n_his>' with 'nHis1'.

    Raises:
        AssertionError: if the loss combination is unsupported, or if no
            checkpoint is registered for `args.env` under that combination.
    """
    loss_key = (args.nerf_loss, args.ct_loss, args.auto_loss)
    if loss_key not in _AE_CKPT_NAMES:
        raise AssertionError("Unsupported combination nerf_loss = %d, ct_loss = %d, auto_loss = %d" % (
            args.nerf_loss, args.ct_loss, args.auto_loss))

    ckpt_names = _AE_CKPT_NAMES[loss_key]
    if args.env not in ckpt_names:
        # Fail with a clear message instead of an unbound-variable error later.
        raise AssertionError(
            "No pretrained autoencoder checkpoint registered for env = %s with "
            "nerf_loss = %d, ct_loss = %d, auto_loss = %d" % (
                args.env, args.nerf_loss, args.ct_loss, args.auto_loss))

    ae_dir = args.outf.replace('dy', 'ae').replace('nHis%d' % args.n_his, 'nHis1')
    return os.path.join(ae_dir, ckpt_names[args.env])


def _load_initial_weights(model, args):
    """Load weights into `model` according to `args.resume`.

    resume == 0: load only the encoder/decoder parameters from the pretrained
    autoencoder checkpoint. resume == 1: load the full state dict saved by a
    previous run of this script. Any other value leaves the model untouched.
    """
    if args.resume == 0:
        model_path = _pretrained_ae_ckpt_path(args)
        print("Loading save ckp from autoencoding %s" % model_path)

        pretrained_dict = torch.load(model_path)
        model_dict = model.state_dict()

        # only load parameters in encoder-decoder
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if ('img_encoder' in k or 'tf_encoder' in k or 'decoder' in k) and k in model_dict}
        model.load_state_dict(pretrained_dict, strict=False)

    elif args.resume == 1:
        model_path = os.path.join(args.outf, 'net_epoch_%d_iter_%d.pth' % (
            args.resume_epoch, args.resume_iter))
        print("Loading saved ckp from %s" % model_path)

        pretrained_dict = torch.load(model_path)
        model.load_state_dict(pretrained_dict)


def _states_and_actions(model, args, data, use_gpu, train_mode):
    """Turn one dataloader batch into `(state_embeds, actions, batch_size)`.

    With `args.prestored == 0` the batch holds raw observations
    (imgs, poses, actions, hwf) and the state embeddings are computed by the
    model's encoders; otherwise the batch already is (state_embeds, actions).

    Raises:
        AssertionError: if the loss flags select no known state encoding.
    """
    if args.prestored == 0:
        # Layout documented by the original script:
        # imgs: B x (n_his + n_roll) x n_view x 3 x H x W
        # poses: B x (n_his + n_roll) x n_view x 4 x 4
        # actions: B x (n_his + n_roll) x action_dim
        # hwf: B x 3 (unused here; only needed for rendering)
        imgs, poses, actions, _hwf = data

        B, N, _n_view, _C, _H, _W = imgs.size()

        # only the first n_view_enc views are fed to the encoder
        n_view_enc = args.n_view_enc

        mask_nodes = torch.ones(B, N, n_view_enc)   # assume complete mask for now

        if use_gpu:
            imgs = imgs.cuda()
            poses = poses.cuda()
            actions = actions.cuda()
            mask_nodes = mask_nodes.cuda()

        # NOTE(review): gradients are tracked through the encoders during
        # training although only the dynamics parameters are optimised.
        with torch.set_grad_enabled(train_mode):
            ### encode images
            # img_embeds: B x N x n_view_enc x nf_hidden
            img_embeds = model.encode_img(imgs[:, :, :n_view_enc])

            ### aggregate image embeddings from multiple views
            # state_embeds: B x N x nf_hidden
            if args.nerf_loss == 1 or args.auto_loss == 1:
                state_embeds = model.encode_state(
                    img_embeds,
                    poses[:, :, :n_view_enc],
                    mask_nodes)

            elif args.nerf_loss == 0 and args.auto_loss == 0 and args.ct_loss == 1:
                # average over views, then L2-normalise each embedding
                state_embeds = torch.mean(img_embeds, 2)
                state_embeds_norm = state_embeds.norm(dim=2, keepdim=True) + 1e-8
                state_embeds = state_embeds / state_embeds_norm

            else:
                raise AssertionError("Unsupported loss combination")

    else:
        # state_embeds: B x (n_his + n_roll) x nf_hidden
        # actions: B x (n_his + n_roll) x action_dim
        state_embeds, actions = data

        B, _N, _ = state_embeds.size()

        if use_gpu:
            state_embeds = state_embeds.cuda()
            actions = actions.cuda()

    return state_embeds, actions, B


def _rollout_loss(model, args, state_embeds, actions):
    """Return the mean MSE of an `args.n_roll`-step latent rollout.

    The first `args.n_his` embeddings seed the history window; each predicted
    state is appended to the window and compared with the embedding at the
    corresponding time step.
    """
    loss_dy = 0.

    state_cur = state_embeds[:, :args.n_his, :]
    for j in range(args.n_roll):
        action_cur = actions[:, j:j + args.n_his, :]

        # state_cur: B x n_his x nf_hidden
        # action_cur: B x n_his x act_dim
        # state_pred: B x nf_hidden
        state_pred = model.dynamics_prediction(state_cur, action_cur)

        # slide the history window: drop the oldest state, append the prediction
        state_cur = torch.cat([state_cur[:, 1:], state_pred[:, None]], 1)

        loss_dy += F.mse_loss(state_pred, state_embeds[:, j + args.n_his])

    return loss_dy / args.n_roll


def _run_training(args, use_gpu, log_fout):
    """Build data/model/optimizer and run the train/valid loop, logging to `log_fout`."""
    print(args)
    print(args, file=log_fout)

    ### dataloaders
    phases = ['train', 'valid']

    datasets = {phase: DynamicsDataset(args, phase) for phase in phases}

    def worker_init_fn(worker_id):
        # give every loader worker its own NumPy seed
        np.random.seed(args.seed + worker_id)

    dataloaders = {phase: DataLoader(
        datasets[phase],
        batch_size=args.batch_size,
        shuffle=(phase == 'train'),
        num_workers=args.num_workers,
        worker_init_fn=worker_init_fn)
        for phase in phases}

    ### create model
    model = DynamicsModel(args)
    print("model #params: %d" % count_parameters(model))
    print("image encoder #params: %d" % count_parameters(model.img_encoder))
    print("tf encoder #params: %d" % count_parameters(model.tf_encoder))
    print("dynamics #params: %d" % count_parameters(model.latent_space_dynamics))
    print("decoder #params: %d" % count_parameters(model.decoder))

    # initialise from the autoencoder, or resume a saved model (if given)
    _load_initial_weights(model, args)

    if args.nerf_loss == 1:
        # !!! need to double check these parameters
        bds_dict = {
            'near' : args.near,
            'far' : args.far,
        }
        model.decoder.render_kwargs_train.update(bds_dict)
        model.decoder.render_kwargs_test.update(bds_dict)

    if use_gpu:
        model = model.cuda()

    ### optimizer: only the latent dynamics module is trained
    params = model.latent_space_dynamics.parameters()
    optimizer = torch.optim.Adam(params, lr=args.lrate, betas=(0.9, 0.999))

    # start training
    st_epoch = args.resume_epoch if args.resume_epoch > 0 else 0
    best_valid_loss = np.inf

    for epoch in range(st_epoch, args.n_epoch):

        for phase in phases:
            is_train = phase == 'train'
            model.train(is_train)
            meter_loss = AverageMeter()

            bar = ProgressBar(max_value=len(dataloaders[phase]))

            for i, data in bar(enumerate(dataloaders[phase])):

                state_embeds, actions, B = _states_and_actions(
                    model, args, data, use_gpu, is_train)

                with torch.set_grad_enabled(is_train):
                    ### forward prediction
                    loss = _rollout_loss(model, args, state_embeds, actions)

                meter_loss.update(loss.item(), B)

                ### optimization
                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                ### log and save ckp
                if i % args.log_per_iter == 0:
                    # losses are reported as square roots of the MSE
                    log = '%s [%d/%d][%d/%d] LR: %.6f, Loss: %.6f (%.6f)' % (
                        phase, epoch, args.n_epoch, i, len(dataloaders[phase]),
                        get_lr(optimizer),
                        np.sqrt(loss.item()), np.sqrt(meter_loss.avg))

                    print()
                    print(log)
                    log_fout.write(log + '\n')
                    log_fout.flush()

                if is_train and i % args.ckp_per_iter == 0:
                    torch.save(model.state_dict(), '%s/net_epoch_%d_iter_%d.pth' % (args.outf, epoch, i))

            # "Best valid" is the best value seen before this phase's update below
            log = '%s [%d/%d] Loss: %.6f, Best valid: %.6f' % (
                phase, epoch, args.n_epoch, np.sqrt(meter_loss.avg), np.sqrt(best_valid_loss))
            print(log)
            log_fout.write(log + '\n')
            log_fout.flush()

            if phase == 'valid' and meter_loss.avg < best_valid_loss:
                best_valid_loss = meter_loss.avg
                torch.save(model.state_dict(), '%s/net_best.pth' % (args.outf))


def train():
    """Train the latent-space dynamics module of `DynamicsModel`.

    Reads the configuration from `gen_args()`, creates `args.outf`, opens a
    log file there, and runs alternating train/valid phases for each epoch.
    Checkpoints are written to `args.outf` every `args.ckp_per_iter` training
    iterations, and `net_best.pth` whenever the validation loss improves.

    Raises:
        AssertionError: for an unsupported loss combination, or when no
            pretrained autoencoder checkpoint is registered for `args.env`.
    """
    args = gen_args()

    set_seed(args.seed)

    use_gpu = torch.cuda.is_available()

    ### make log dir (no shell involved, so any path characters are safe)
    if args.outf:
        os.makedirs(args.outf, exist_ok=True)

    if args.resume == 0:
        log_name = 'log.txt'
    else:
        log_name = 'log_resume_epoch_%d_iter_%d.txt' % (
            args.resume_epoch, args.resume_iter)

    # the context manager closes the log even if training raises
    with open(os.path.join(args.outf, log_name), 'w') as log_fout:
        _run_training(args, use_gpu, log_fout)




if __name__=='__main__':
    # Allocate new tensors on the GPU by default, but only when CUDA is
    # actually present: train() already guards its GPU use the same way, and
    # selecting the CUDA tensor type without a GPU fails at startup.
    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    train()
