import os, sys
import cv2
import numpy as np
import imageio
import json
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange

import matplotlib.pyplot as plt

import pdb
from progressbar import ProgressBar

from nerf_helpers import *
from config import gen_args
from data import DynamicsDataset
from models import DynamicsModel
from utils import set_seed, count_parameters, get_lr, AverageMeter



def _tensor_to_bgr_uint8(x):
    """Convert a float tensor whose last axis holds RGB channels to a BGR uint8 array.

    Values are clipped to [0, 1] and scaled by 255 before the cast. The last
    axis is reversed because ``cv2.imwrite`` expects BGR channel order.
    """
    arr = x.detach().cpu().numpy().clip(0., 1.) * 255
    return arr.astype(np.uint8)[..., ::-1]


def _log_line(log_fout, line):
    """Print ``line`` to stdout and append it to ``log_fout``, flushing immediately."""
    print(line)
    log_fout.write(line + '\n')
    log_fout.flush()


def _build_model(args, use_gpu):
    """Create a DynamicsModel, load its checkpoint from ``args.outf`` and set render bounds.

    ``args.eval_epoch == -1`` selects ``net_best.pth``. Any other value selects
    ``net_epoch_<eval_epoch>_iter_<eval_iter>.pth``. The state dict is loaded
    with ``strict=False``, so missing or unexpected keys are tolerated.

    Returns the model, moved to the GPU when ``use_gpu`` is true.
    """
    model = DynamicsModel(args)
    print("model #params: %d" % count_parameters(model))

    if args.eval_epoch == -1:
        model_path = os.path.join(args.outf, 'net_best.pth')
    else:
        model_path = os.path.join(args.outf, 'net_epoch_%d_iter_%d.pth' % (
            args.eval_epoch, args.eval_iter))

    print("Loading saved ckp from %s" % model_path)
    # Map to CPU when no GPU is present so GPU-saved checkpoints still load.
    pretrained_dict = torch.load(
        model_path, map_location=None if use_gpu else 'cpu')
    model.load_state_dict(pretrained_dict, strict=False)

    print(model)
    print(model.parameters)

    if args.nerf_loss == 1:
        # TODO: double-check these near/far rendering bounds.
        bds_dict = {
            'near': args.near,
            'far': args.far,
        }
        model.decoder.render_kwargs_train.update(bds_dict)
        model.decoder.render_kwargs_test.update(bds_dict)

    if use_gpu:
        model = model.cuda()
    return model


def _eval_batch(model, args, data, use_gpu, store_folder):
    """Encode one validation batch, decode the non-encoder views and save images.

    Images are written for the first sample of the batch only, as
    ``pred_<j>_<k>.png`` / ``gt_<j>_<k>.png`` in ``store_folder``.

    Returns ``(loss_nerf, loss_auto)``. An entry is ``None`` when that decoder
    was not run for this batch. At most one of them runs, and ``nerf_loss``
    takes precedence.
    """
    # imgs: B x (n_his + n_roll) x n_view x 3 x H x W
    # poses: B x (n_his + n_roll) x n_view x 4 x 4
    # actions: B x (n_his + n_roll) x action_dim  (unused during evaluation)
    # hwf: B x 3
    imgs, poses, _actions, hwf = data
    hwf = hwf[0]
    B, N, _, C, H, W = imgs.size()

    # Use different views for encoding and decoding.
    n_view_enc = args.n_view_enc
    n_view_dec = args.n_view - n_view_enc

    mask_nodes = torch.ones(B, N, n_view_enc)  # assume complete mask for now

    if use_gpu:
        imgs = imgs.cuda()
        poses = poses.cuda()
        mask_nodes = mask_nodes.cuda()

    loss_nerf = None
    loss_auto = None

    with torch.no_grad():
        # The loaded tensor interleaves 2 samples per view; keep index 0
        # (negative examples are not considered for evaluation).
        imgs = imgs.view(B, N, args.n_view, 2, C, H, W)[:, :, :, 0]

        # img_embeds: B x N x n_view_enc x (nf_hidden + 16)
        img_embeds = model.encode_img(imgs[:, :, :n_view_enc])

        if not (args.nerf_loss or args.auto_loss):
            return loss_nerf, loss_auto

        # Aggregate the per-view embeddings.
        # state_embeds: B x N x nf_hidden
        state_embeds = model.encode_state(
            img_embeds, poses[:, :, :n_view_enc], mask_nodes)

        camera_info = {
            'hwf': hwf,
            'poses': poses,
            'near': args.near,
            'far': args.far}

        if args.nerf_loss:
            # NOTE(review): no NeRF reconstruction loss is computed here; a
            # zero placeholder is reported and the renders are only saved.
            loss_nerf = torch.zeros(1)

            for j in range(N):
                print("%d/%d" % (j, N))
                # Ground truth for step j does not depend on k; convert once.
                gt_np = _tensor_to_bgr_uint8(imgs[:, j].permute(0, 1, 3, 4, 2))

                for k in range(n_view_dec):
                    view = n_view_enc + k
                    # Render a single decoder view at a time.
                    # rgb: B x n_view x H x W x 3
                    camera_info['poses'] = poses[:, j, view:view + 1]
                    rgb, _ = model.render_imgs(state_embeds[:, j], camera_info)

                    rgb_np = _tensor_to_bgr_uint8(rgb)
                    print(rgb_np.sum(), "rgb_np sum")

                    cv2.imwrite(
                        os.path.join(store_folder, 'pred_%d_%d.png' % (j, k)),
                        rgb_np[0, 0])
                    cv2.imwrite(
                        os.path.join(store_folder, 'gt_%d_%d.png' % (j, k)),
                        gt_np[0, view])
        else:  # args.auto_loss is set
            camera_info['poses'] = poses[:, :, -n_view_dec:]
            imgs_pred = model.decode_img(state_embeds, camera_info)
            imgs_to_render = imgs[:, :, -n_view_dec:]

            # Mean squared error between predicted and ground-truth images.
            loss_auto = torch.mean((imgs_pred - imgs_to_render) ** 2)
            print('loss_auto:', loss_auto.item())

            for j in range(N):
                print("%d/%d" % (j, N))

                for k in range(n_view_dec):
                    rgb_np = _tensor_to_bgr_uint8(
                        imgs_pred[0, j, k].permute(1, 2, 0))
                    gt_np = _tensor_to_bgr_uint8(
                        imgs_to_render[0, j, k].permute(1, 2, 0))

                    cv2.imwrite(
                        os.path.join(store_folder, 'pred_%d_%d.png' % (j, k)),
                        rgb_np)
                    cv2.imwrite(
                        os.path.join(store_folder, 'gt_%d_%d.png' % (j, k)),
                        gt_np)

        # TODO: forward-prediction rollout evaluation
        # (model.dynamics_prediction) is not run here.

    return loss_nerf, loss_auto


def eval():  # NOTE: shadows the builtin ``eval``; name kept for existing callers.
    """Evaluate a trained DynamicsModel on the validation split.

    Configuration comes from ``gen_args()``. The checkpoint is loaded from
    ``args.outf``. Every ``args.eval_skip_frame``-th validation batch is
    encoded and, when ``args.nerf_loss`` or ``args.auto_loss`` is set, decoded.
    Predicted and ground-truth images are written to
    ``<args.evalf>/<batch index>/``. Losses are printed and appended to
    ``<args.evalf>/log.txt``, followed by one summary line at the end.
    """
    args = gen_args()

    set_seed(args.seed)

    use_gpu = torch.cuda.is_available()

    # Create the log dir without shelling out (no shell parsing of the path,
    # and failures raise instead of being silently ignored).
    os.makedirs(args.evalf, exist_ok=True)

    with open(os.path.join(args.evalf, 'log.txt'), 'w') as log_fout:
        print(args)
        print(args, file=log_fout)

        ### dataloader (validation only, fixed order)
        phase = 'valid'
        dataset = DynamicsDataset(args, phase)

        def worker_init_fn(worker_id):
            # Reproducible, distinct NumPy seed per loader worker.
            np.random.seed(args.seed + worker_id)

        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            worker_init_fn=worker_init_fn)

        ### model (evaluation only: no optimizer, no parameter updates)
        model = _build_model(args, use_gpu)
        model.train(False)

        meter_loss = AverageMeter()
        meter_loss_nerf = AverageMeter()
        meter_loss_auto = AverageMeter()

        for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
            if i % args.eval_skip_frame != 0:
                continue

            store_folder = os.path.join(args.evalf, '%d' % i)
            os.makedirs(store_folder, exist_ok=True)

            loss_nerf, loss_auto = _eval_batch(
                model, args, data, use_gpu, store_folder)
            batch_size = data[0].size(0)

            # Sum only the losses computed for this batch; with neither
            # decoder enabled the total stays 0.
            loss = 0.
            if loss_nerf is not None:
                loss = loss + loss_nerf
                meter_loss_nerf.update(loss_nerf.item(), batch_size)
            if loss_auto is not None:
                loss = loss + loss_auto
                meter_loss_auto.update(loss_auto.item(), batch_size)

            # float() accepts both the plain 0. and a one-element tensor.
            loss_value = float(loss)
            meter_loss.update(loss_value, batch_size)

            ### log
            if i % args.log_per_iter == 0:
                log = '%s [%d/%d] Loss: %.6f (%.6f)' % (
                    phase, i, len(dataloader), loss_value, meter_loss.avg)

                if loss_nerf is not None:
                    log += ', loss_nerf: %.6f (%.6f)' % (
                        loss_nerf.item(), meter_loss_nerf.avg)
                if loss_auto is not None:
                    log += ', loss_auto: %.6f (%.6f)' % (
                        loss_auto.item(), meter_loss_auto.avg)

                print()
                _log_line(log_fout, log)

        # Summary over all evaluated batches, written once after the loop.
        _log_line(log_fout, '%s Loss: %.4f' % (phase, meter_loss.avg))




if __name__ == '__main__':
    # Make new float tensors default to the GPU, but only when CUDA is
    # available: requesting a CUDA default tensor type fails on machines
    # without CUDA, which would defeat eval()'s own use_gpu checks.
    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    eval()
