import os, sys
import cv2
import numpy as np
import imageio
import json
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm, trange

import matplotlib.pyplot as plt

import pdb
from progressbar import ProgressBar

from nerf_helpers import *
from config import gen_args
from data import DynamicsDataset
from models import DynamicsModel
from utils import set_seed, count_parameters, get_lr, AverageMeter



def _to_bgr_uint8(img):
    """Convert a C x H x W image tensor into an H x W x C uint8 array for cv2.

    Values are clipped to [0, 1], scaled by 255, truncated to uint8, and the
    channel axis is reversed (cv2.imwrite expects BGR channel order).
    """
    img_np = img.permute(1, 2, 0).detach().cpu().numpy().clip(0., 1.) * 255
    return img_np.astype(np.uint8)[..., ::-1]


def eval():
    """Dump nearest-neighbour retrievals in embedding space for the 'valid' split.

    Loads the encoder/decoder weights of a saved checkpoint, encodes every
    frame of each evaluated batch, and, for the first sample of the batch,
    writes to ``<args.evalf>/<batch index>/``:

    * ``img_<k>.png``     -- the query frame at time step ``k`` (view 1);
    * ``img_<k>_<l>.png`` -- the frame of view ``l`` whose embedding has the
      smallest mean squared distance to the query embedding.

    Only every ``args.eval_skip_frame``-th batch is evaluated. All settings
    come from ``gen_args()``; the function takes no arguments and returns
    nothing.

    NOTE: the name shadows the builtin ``eval``; it is kept unchanged so that
    existing callers keep working.
    """
    args = gen_args()

    set_seed(args.seed)

    use_gpu = torch.cuda.is_available()

    ### make log dir
    # os.makedirs instead of shelling out to `mkdir -p`: no shell quoting
    # problems with unusual paths and it works on every platform.
    os.makedirs(args.evalf, exist_ok=True)

    print(args)
    # The log only records the arguments, so close it right away instead of
    # leaking the handle for the rest of the run.
    with open(os.path.join(args.evalf, 'log.txt'), 'w') as log_fout:
        print(args, file=log_fout)

    ### dataloader
    phase = 'valid'
    dataset = DynamicsDataset(args, phase)

    def worker_init_fn(worker_id):
        # give every loader worker its own, reproducible numpy seed
        np.random.seed(args.seed + worker_id)

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        worker_init_fn=worker_init_fn)

    ### create model
    model = DynamicsModel(args)
    print("model #params: %d" % count_parameters(model))

    # pick the checkpoint to evaluate: the best one, or a specific epoch/iter
    ckp_dir = args.outf.replace('nn', 'ae')
    if args.eval_epoch == -1:
        ckp_name = 'net_best.pth'
    else:
        ckp_name = 'net_epoch_%d_iter_%d.pth' % (args.eval_epoch, args.eval_iter)
    model_path = os.path.join(ckp_dir, ckp_name)

    print("Loading saved ckp from %s" % model_path)
    # map_location='cpu' lets a checkpoint saved on GPU be read on a CPU-only
    # host; load_state_dict copies the values onto the model's own device.
    pretrained_dict = torch.load(model_path, map_location='cpu')

    model_dict = model.state_dict()

    # only load parameters in encoder-decoder
    pretrained_dict = {
        k: v for k, v in pretrained_dict.items()
        if ('img_encoder' in k or 'tf_encoder' in k or 'decoder' in k) and k in model_dict}
    model.load_state_dict(pretrained_dict, strict=False)

    # !!! need to double check these parameters
    if args.nerf_loss == 1:
        bds_dict = {
            'near': args.near,
            'far': args.far,
        }
        model.decoder.render_kwargs_train.update(bds_dict)
        model.decoder.render_kwargs_test.update(bds_dict)

    if use_gpu:
        model = model.cuda()

    ### evaluation
    model.eval()

    # Index along the view axis used as the retrieval query; the data must
    # therefore provide at least two views.
    query_view = 1

    bar = ProgressBar(max_value=len(dataloader))

    for i, data in bar(enumerate(dataloader)):

        if i % args.eval_skip_frame != 0:
            continue

        store_folder = os.path.join(args.evalf, '%d' % i)
        os.makedirs(store_folder, exist_ok=True)

        # imgs: B x (n_his + n_roll) x n_view x 3 x H x W
        # poses: B x (n_his + n_roll) x n_view x 4 x 4
        # actions: B x (n_his + n_roll) x action_dim
        # hwf: B x 3
        # (actions and hwf are not needed for the retrieval below)
        imgs, poses, _actions, _hwf = data

        B, N, n_view, C, H, W = imgs.size()

        if use_gpu:
            imgs = imgs.cuda()
            poses = poses.cuda()

        # assume complete mask for now
        mask_nodes = torch.ones(B, N, n_view, device=imgs.device)

        with torch.no_grad():
            ### encode images
            # img_embeds: B x N x n_view x (nf_hidden + 16)
            img_embeds = model.encode_img(imgs)

            print(B, N, n_view)

            img_embeds = model.encode_state(
                img_embeds.view(B, N * n_view, 1, -1),
                poses.view(B, N * n_view, 1, 4, 4),
                mask_nodes.view(B, N * n_view, 1, -1))
            img_embeds = img_embeds.view(B, N, n_view, -1)

        # Only the first sample of the batch is visualised. Move its
        # embeddings and frames to the host once instead of once per
        # (k, l) pair inside the loops below.
        embeds_np = img_embeds[0].detach().cpu().numpy()    # N x n_view x D
        frames = imgs[0].detach().cpu()                     # N x n_view x 3 x H x W

        for k in range(N):
            # embed_q: embedding of the query frame
            embed_q = embeds_np[k, query_view]

            cv2.imwrite(
                os.path.join(store_folder, 'img_%d.png' % k),
                _to_bgr_uint8(frames[k, query_view]))

            for l in range(n_view):
                # embed_t: N x D, embeddings of every time step seen from view l
                embed_t = embeds_np[:, l]

                # mean squared distance to the query; keep the closest time step
                dis = np.mean((embed_t - embed_q) ** 2, 1)
                idx = np.argmin(dis)

                print(idx, dis[idx])

                cv2.imwrite(
                    os.path.join(store_folder, 'img_%d_%d.png' % (k, l)),
                    _to_bgr_uint8(frames[idx, l]))




if __name__=='__main__':
    # Make newly created tensors live on the GPU by default, but only when
    # CUDA is actually present: eval() already falls back to the CPU through
    # torch.cuda.is_available(), and selecting the CUDA tensor type on a
    # CPU-only machine would make tensor creation fail.
    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    eval()
