import os, sys
import cv2
import numpy as np
import imageio
import json
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange

import matplotlib.pyplot as plt

import pdb
from progressbar import ProgressBar

from nerf_helpers import *
from config import gen_args
from data import DynamicsDataset
from models import DynamicsModel
from utils import set_seed, count_parameters, get_lr, AverageMeter



def _ensure_dir(path):
    """Create directory *path* (including parents) if it does not exist yet.

    An empty path is ignored so that a relative log location in the current
    working directory keeps working.
    """
    if path:
        os.makedirs(path, exist_ok=True)


def _checkpoint_path(args):
    """Return the checkpoint file selected by ``args.eval_epoch``.

    ``eval_epoch == -1`` selects ``net_best.pth``; any other value selects
    ``net_epoch_<eval_epoch>_iter_<eval_iter>.pth``. Both are looked up in
    ``args.outf``.
    """
    if args.eval_epoch == -1:
        return os.path.join(args.outf, 'net_best.pth')
    return os.path.join(
        args.outf,
        'net_epoch_%d_iter_%d.pth' % (args.eval_epoch, args.eval_iter))


def _to_bgr_uint8(img):
    """Clip a float array to [0, 1], scale it to 8 bit and reverse the last axis.

    Reversing the last (channel) axis is what the images need before being
    handed to ``cv2.imwrite`` (presumably RGB -> BGR; confirm against the
    dataset's channel order).
    """
    return (img.clip(0., 1.) * 255).astype(np.uint8)[..., ::-1]


def _write_pred_gt(store_folder, j, k, pred_np, gt_np):
    """Write prediction / ground truth of rollout step *j*, decoding view *k*."""
    cv2.imwrite(os.path.join(store_folder, 'pred_%d_%d.png' % (j, k)), pred_np)
    cv2.imwrite(os.path.join(store_folder, 'gt_%d_%d.png' % (j, k)), gt_np)


def _run_eval(args, log_fout):
    """Evaluate the dynamics model on the 'valid' split and log to *log_fout*.

    For every evaluated batch the latent state is rolled forward
    ``args.n_roll`` steps, the per-step MSE between predicted and encoded
    state is accumulated, and (depending on ``args.nerf_loss`` /
    ``args.auto_loss``) predicted and ground-truth images of the first
    sample in the batch are written to ``<args.evalf>/<batch index>/``.
    """
    use_gpu = torch.cuda.is_available()
    phase = 'valid'

    # Leading batches that are never evaluated.
    n_skip = 80

    ### dataloader
    dataset = DynamicsDataset(args, phase)

    def worker_init_fn(worker_id):
        np.random.seed(args.seed + worker_id)

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        worker_init_fn=worker_init_fn)

    ### create model and restore the requested checkpoint
    model = DynamicsModel(args)
    print("model #params: %d" % count_parameters(model))

    model_path = _checkpoint_path(args)
    print("Loading saved ckp from %s" % model_path)
    # Without a GPU, remap the stored tensors to the CPU so that a checkpoint
    # written on a CUDA machine can still be loaded.
    pretrained_dict = torch.load(
        model_path, map_location=None if use_gpu else 'cpu')
    model.load_state_dict(pretrained_dict)

    near = args.near
    far = args.far

    if args.nerf_loss == 1:
        # !!! need to double check these parameters
        bds_dict = {
            'near': near,
            'far': far,
        }
        model.decoder.render_kwargs_train.update(bds_dict)
        model.decoder.render_kwargs_test.update(bds_dict)

    if use_gpu:
        model = model.cuda()

    # Evaluation only: no optimizer, no gradient tracking.
    model.eval()

    img2mse = lambda x, y: torch.mean((x - y) ** 2)

    meter_loss = AverageMeter()
    meter_loss_dy = AverageMeter()

    n_batches = len(dataloader)
    bar = ProgressBar(max_value=n_batches)

    for i, data in bar(enumerate(dataloader)):

        # ignore the first `n_skip` data points, then only evaluate every
        # `eval_skip_frame`-th batch
        if i < n_skip or i % args.eval_skip_frame != 0:
            continue

        store_folder = os.path.join(args.evalf, '%d' % i)
        _ensure_dir(store_folder)

        # imgs: B x (n_his + n_roll) x n_view x 3 x H x W
        # poses: B x (n_his + n_roll) x n_view x 4 x 4
        # actions: B x (n_his + n_roll) x action_dim
        # hwf: B x 3
        imgs, poses, actions, hwf = data
        hwf = hwf[0]

        B, N, n_view, _, _, _ = imgs.size()

        # use different views for encoding and decoding
        n_view_enc = args.n_view_enc
        n_view_dec = n_view - n_view_enc

        mask_nodes = torch.ones(B, N, n_view_enc)   # assume complete mask for now

        if use_gpu:
            imgs = imgs.cuda()
            poses = poses.cuda()
            actions = actions.cuda()
            mask_nodes = mask_nodes.cuda()

        with torch.no_grad():
            ### encode images
            # img_embeds: B x N x n_view_enc x (nf_hidden + 16)
            img_embeds = model.encode_img(imgs[:, :, :n_view_enc])

            ### aggregate image embeddings from multiple views
            # state_embeds: B x N x nf_hidden
            if args.nerf_loss == 1 or args.auto_loss == 1:
                state_embeds = model.encode_state(
                    img_embeds,
                    poses[:, :, :n_view_enc],
                    mask_nodes)

            elif args.nerf_loss == 0 and args.auto_loss == 0 and args.ct_loss == 1:
                # average over the encoding views, then L2-normalise
                state_embeds = torch.mean(img_embeds, 2)
                state_embeds_norm = state_embeds.norm(dim=2, keepdim=True) + 1e-8
                state_embeds = state_embeds / state_embeds_norm

            else:
                raise AssertionError("Unsupported loss combination")

            ### forward prediction
            loss_dy = 0.
            state_cur = state_embeds[:, :args.n_his, :]

            print()
            for j in range(args.n_roll):
                t = j + args.n_his  # index of the frame being predicted
                action_cur = actions[:, j:t, :]

                # state_cur: B x n_his x nf_hidden
                # actions_cur: B x n_his x act_dim
                # state_pred: B x nf_hidden
                state_pred = model.dynamics_prediction(state_cur, action_cur)

                # slide the history window: drop the oldest state, append the prediction
                state_cur = torch.cat([state_cur[:, 1:], state_pred[:, None]], 1)

                loss_dy_cur = F.mse_loss(state_pred, state_embeds[:, t])
                loss_dy += loss_dy_cur

                print("rolling %d/%d: %.6f" % (j, args.n_roll, loss_dy_cur.item()))

                camera_info = {
                    'hwf': hwf,
                    'poses': poses,
                    'near': near,
                    'far': far}

                if args.nerf_loss == 1:

                    for k in range(n_view_dec):
                        view = n_view_enc + k

                        # render the predicted image
                        # rgb: B x n_view x H x W x 3
                        camera_info['poses'] = poses[:, t, view:view + 1]
                        rgb, _ = model.render_imgs(state_pred, camera_info)

                        # Only the first sample is stored, so slice before
                        # moving data to the host.
                        pred_np = _to_bgr_uint8(rgb[0, 0].detach().cpu().numpy())
                        gt_np = _to_bgr_uint8(
                            imgs[0, t, view].permute(1, 2, 0).detach().cpu().numpy())

                        _write_pred_gt(store_folder, j, k, pred_np, gt_np)

                elif args.auto_loss == 1:

                    camera_info['poses'] = poses[:, t:t + 1, -n_view_dec:]
                    imgs_pred = model.decode_img(state_pred[:, None, :], camera_info)
                    imgs_to_render = imgs[:, t:t + 1, -n_view_dec:]

                    loss_auto = img2mse(imgs_pred, imgs_to_render)
                    print('loss_auto:', loss_auto.item())

                    for k in range(n_view_dec):
                        pred_np = _to_bgr_uint8(
                            imgs_pred[0, 0, k].permute(1, 2, 0).detach().cpu().numpy())
                        gt_np = _to_bgr_uint8(
                            imgs_to_render[0, 0, k].permute(1, 2, 0).detach().cpu().numpy())

                        _write_pred_gt(store_folder, j, k, pred_np, gt_np)

            loss_dy = loss_dy / args.n_roll

        # only the dynamics loss is evaluated
        loss = loss_dy

        meter_loss_dy.update(loss_dy.item(), B)
        meter_loss.update(loss.item(), B)

        ### log
        if i % args.log_per_iter == 0:
            log = '%s [%d/%d] Loss: %.6f (%.6f), loss_dy: %.6f (%.6f)' % (
                phase, i, n_batches,
                loss.item(), meter_loss.avg,
                loss_dy.item(), meter_loss_dy.avg)

            print()
            print(log)
            log_fout.write(log + '\n')
            log_fout.flush()

    # Summary over all evaluated batches, written once after the loop.
    log = '%s Loss: %.4f' % (phase, meter_loss.avg)
    print(log)
    log_fout.write(log + '\n')
    log_fout.flush()


def eval():
    """Evaluate a saved dynamics model on the validation split.

    Reads the configuration from ``gen_args()``, seeds the RNGs, creates the
    output directory ``args.evalf`` and writes the configuration plus all
    loss logs to ``<args.evalf>/log.txt``. The log file is closed when the
    evaluation finishes or fails.

    NOTE: the function name shadows the builtin ``eval``; it is kept because
    it is this script's entry point.
    """
    args = gen_args()

    set_seed(args.seed)

    ### make log dir
    _ensure_dir(args.evalf)

    with open(os.path.join(args.evalf, 'log.txt'), 'w') as log_fout:
        print(args)
        print(args, file=log_fout)

        _run_eval(args, log_fout)




if __name__ == '__main__':
    # Make newly created tensors live on the GPU by default, but only when a
    # CUDA device is actually present; eval() itself checks
    # torch.cuda.is_available(), so the script should not fail here on a
    # CPU-only machine.
    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    eval()
