import os, sys
import cv2
import numpy as np
import imageio
import json
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm, trange

import matplotlib.pyplot as plt

import pdb
from progressbar import ProgressBar

from nerf_helpers import *
from config import gen_args
from data import DynamicsDataset
from models import DynamicsModel
from utils import set_seed, count_parameters, get_lr, AverageMeter



def my_collate_fn(batch):
    """Collate a list of dataset samples into batched tensors.

    ``batch`` is a list of tuples, each one being the result of the dataset's
    ``__getitem__``; fields 0..3 are ``(imgs, poses, actions, hwf)``. Each
    field is stacked along a new leading batch dimension.

    Args:
        batch: non-empty list of per-sample tuples of tensors.

    Returns:
        Tuple ``(imgs, poses, actions, hwf)`` of stacked tensors.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("my_collate_fn: cannot collate an empty batch")

    # torch.stack already allocates a new tensor; wrapping it in torch.tensor()
    # would only make a second copy (and emit a PyTorch UserWarning).
    imgs = torch.stack([sample[0] for sample in batch])
    poses = torch.stack([sample[1] for sample in batch])
    actions = torch.stack([sample[2] for sample in batch])
    hwf = torch.stack([sample[3] for sample in batch])

    return imgs, poses, actions, hwf


def _img2mse(x, y):
    """Return the mean squared error between tensors ``x`` and ``y``."""
    return torch.mean((x - y) ** 2)


def _emit_log(log_fout, log):
    """Print ``log`` and append it, newline-terminated, to ``log_fout``.

    The file is flushed immediately so the log is on disk even if training
    crashes afterwards.
    """
    print(log)
    log_fout.write(log + '\n')
    log_fout.flush()


def _open_log_file(args):
    """Open (truncating) the text log file inside ``args.outf``.

    ``args.resume`` 0 (fresh run) and 2 (load fixed checkpoint) write
    ``log.txt``; any other resume mode writes a file named after
    ``args.resume_epoch`` / ``args.resume_iter``.
    """
    if args.resume in (0, 2):
        name = 'log.txt'
    else:
        name = 'log_resume_epoch_%d_iter_%d.txt' % (
            args.resume_epoch, args.resume_iter)
    return open(os.path.join(args.outf, name), 'w')


def _build_dataloaders(args, phases):
    """Create one ``DynamicsDataset``-backed ``DataLoader`` per phase.

    Only the ``'train'`` loader shuffles. Each worker seeds NumPy with
    ``args.seed + worker_id`` so worker RNG streams differ from each other but
    are reproducible across runs.
    """
    # Build every dataset before any loader (dataset constructors may consume
    # RNG state, so keep this ordering stable).
    datasets = {phase: DynamicsDataset(args, phase) for phase in phases}

    def worker_init_fn(worker_id):
        np.random.seed(args.seed + worker_id)

    return {
        phase: DataLoader(
            datasets[phase],
            batch_size=args.batch_size,
            shuffle=(phase == 'train'),
            num_workers=args.num_workers,
            worker_init_fn=worker_init_fn,
        )
        for phase in phases
    }


def _build_model(args, use_gpu):
    """Construct the ``DynamicsModel``, optionally load weights, and place it.

    ``args.resume == 2`` loads a state dict from a hard-coded checkpoint path.
    When the NeRF loss is enabled, ``args.near`` / ``args.far`` are pushed
    into the decoder's train/test render kwargs.
    """
    model = DynamicsModel(args)
    print("model #params: %d" % count_parameters(model))

    if args.resume == 0:
        print("Randomly initialize the model's parameters")

    elif args.resume == 2:
        # NOTE(review): this path is hard-coded and ignores args.outf /
        # args.resume_epoch / args.resume_iter. The args-based path was:
        #   os.path.join(args.outf, 'net_epoch_%d_iter_%d.pth' % (
        #       args.resume_epoch, args.resume_iter))
        model_path = "../nerf_dy/dump/dump_FluidManipClip_wKuka_wColor/files_ae_nHis1_ct_nerf/net_epoch_2_iter_125000.pth"
        print("Loading saved ckp from %s" % model_path)

        pretrained_dict = torch.load(model_path)

        # Alternative (disabled): load only the encoder/decoder parameters.
        #   model_dict = model.state_dict()
        #   pretrained_dict = {
        #       k: v for k, v in pretrained_dict.items()
        #       if ('img_encoder' in k or 'tf_encoder' in k or 'decoder' in k)
        #       and k in model_dict}
        #   model.load_state_dict(pretrained_dict, strict=False)
        model.load_state_dict(pretrained_dict)

    # NOTE(review): any other args.resume value (e.g. 1) loads no weights
    # here, even though a "log_resume_..." log file is opened -- confirm.

    if args.nerf_loss:
        # !!! need to double check these near/far bounds
        bds_dict = {
            'near': args.near,
            'far': args.far,
        }
        model.decoder.render_kwargs_train.update(bds_dict)
        model.decoder.render_kwargs_test.update(bds_dict)

    if use_gpu:
        model = model.cuda()

    return model


def _compute_losses(args, model, imgs, poses, hwf, mask_nodes):
    """Run the encoder (and renderer/decoder) and return every enabled loss.

    Args:
        imgs: rank-6 tensor ``B x N x (n_view * 2) x C x H x W``; views come in
            pairs (two image streams per view, per the ``view(..., 2, ...)``
            reshapes below). The first ``n_view_enc * 2`` entries feed the
            encoder, the last ``n_view_dec * 2`` are render targets.
        poses: camera poses, ``B x N x n_view x ...``.
        hwf: camera intrinsics for the batch (forwarded to the model as-is).
        mask_nodes: ``B x N x n_view_enc`` view mask.

    Returns:
        Dict mapping log label (``'loss_ct'``, ``'loss_nerf'``,
        ``'loss_auto'``) to a scalar loss tensor, for enabled terms only, in
        that order.

    Raises:
        ValueError: if the contrastive loss is enabled with ``n_view_enc != 4``.
    """
    B, N, _, C, H, W = imgs.size()
    n_view_enc = args.n_view_enc
    # NOTE(review): assumes n_view_dec > 0; with 0, the ``-n_view_dec`` slices
    # below would select every view.
    n_view_dec = args.n_view - n_view_enc
    losses = {}

    ### encode images
    # img_embeds: B x N x (n_view_enc * 2) x nf_hidden
    img_embeds = model.encode_img(imgs[:, :, :n_view_enc * 2])

    # img_embeds_x: B x N x n_view_enc x nf_hidden
    paired = img_embeds.view(B, N, n_view_enc, 2, -1)
    img_embeds_a = paired[:, :, :, 0]
    img_embeds_b = paired[:, :, :, 1]

    if args.ct_loss:
        # Time-contrastive triplet hinge: the first two encoder views of
        # stream "a" are anchors, the last two views of stream "a" are
        # positives, and the first two views of stream "b" are negatives.
        if n_view_enc != 4:
            raise ValueError(
                "ct_loss requires n_view_enc == 4, got %d" % n_view_enc)
        margin = 2.0
        anchor = img_embeds_a[:, :, :2]
        positive = img_embeds_a[:, :, -2:]
        negative = img_embeds_b[:, :, :2]
        d_positive = torch.mean((anchor - positive) ** 2, -1)
        d_negative = torch.mean((anchor - negative) ** 2, -1)
        losses['loss_ct'] = torch.clamp(
            margin + d_positive - d_negative, min=0.0).mean()

    if args.nerf_loss or args.auto_loss:
        ### aggregate image embeddings from multiple views
        # state_embeds: B x N x nf_hidden
        state_embeds = model.encode_state(
            img_embeds_a,
            poses[:, :, :n_view_enc],
            mask_nodes)

        ### decode the state embedding to render the held-out views
        camera_info = {
            'hwf': hwf,
            'poses': poses[:, :, -n_view_dec:],
            'near': args.near,
            'far': args.far}

        imgs_to_render = imgs[:, :, -n_view_dec * 2:].view(
            B, N, n_view_dec, 2, C, H, W)[:, :, :, 0]

        if args.nerf_loss:
            # rgb / target_s: N_rand x 3 sampled rays
            rgb, extras, target_s = model.render_rays(
                state_embeds, camera_info, imgs_to_render)
            loss_nerf = _img2mse(rgb, target_s)
            # Coarse-network output, when the renderer returns one.
            if 'rgb0' in extras:
                loss_nerf = loss_nerf + _img2mse(extras['rgb0'], target_s)
            losses['loss_nerf'] = loss_nerf

        # Independent of nerf_loss: both terms are computed when both flags
        # are set (an ``elif`` here would leave loss_auto undefined).
        if args.auto_loss:
            imgs_pred = model.decode_img(state_embeds, camera_info)
            losses['loss_auto'] = _img2mse(imgs_pred, imgs_to_render)

        # TODO: forward-prediction (dynamics) loss, currently disabled:
        #   loss_pred = 0.
        #   state_cur = state_embeds[:, :args.n_his, :]
        #   for j in range(args.n_roll):
        #       action_cur = actions[:, j:j + args.n_his, :]
        #       # state_cur: B x n_his x nf_hidden
        #       # actions_cur: B x n_his x act_dim
        #       # state_pred: B x nf_hidden
        #       state_pred = model.dynamics_prediction(state_cur, action_cur)
        #       state_cur = torch.cat([state_cur[:, 1:], state_pred[:, None]], 1)
        #       loss_pred += F.mse_loss(state_pred, state_embeds[:, j + args.n_his])
        #   loss_pred = loss_pred / args.n_roll

    return losses


def _run_phase(args, model, optimizer, dataloader, phase, epoch, use_gpu,
               log_fout):
    """Run one pass over ``dataloader`` for ``phase`` and return the mean loss.

    In the ``'train'`` phase gradients are enabled and the optimizer steps
    every batch; checkpoints are saved every ``args.ckp_per_iter`` batches.
    Progress is logged every ``args.log_per_iter`` batches.
    """
    print("---------------------- %s --------------" % phase)
    is_train = phase == 'train'
    model.train(is_train)

    meter_loss = AverageMeter()
    enabled = [
        name for name, flag in (
            ('loss_ct', args.ct_loss),
            ('loss_nerf', args.nerf_loss),
            ('loss_auto', args.auto_loss))
        if flag]
    loss_meters = {name: AverageMeter() for name in enabled}

    n_batches = len(dataloader)
    for i, data in tqdm(enumerate(dataloader), total=n_batches):
        # imgs: B x (n_his + n_roll) x (n_view * 2) x 3 x H x W
        # poses: B x (n_his + n_roll) x n_view x 4 x 4
        # actions: B x (n_his + n_roll) x action_dim
        # hwf: B x 3
        imgs, poses, actions, hwf = data
        # Intrinsics of the first sample are used for the whole batch.
        hwf = hwf[0]

        B, N = imgs.size(0), imgs.size(1)
        mask_nodes = torch.ones(B, N, args.n_view_enc)  # assume complete mask for now

        if use_gpu:
            imgs = imgs.cuda()
            poses = poses.cuda()
            actions = actions.cuda()
            mask_nodes = mask_nodes.cuda()

        with torch.set_grad_enabled(is_train):
            losses = _compute_losses(args, model, imgs, poses, hwf, mask_nodes)

        ### optimization
        loss = sum(losses.values())
        for name, value in losses.items():
            loss_meters[name].update(value.item(), B)
        meter_loss.update(loss.item(), B)

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        ### log and save ckp
        if i % args.log_per_iter == 0:
            log = '%s [%d/%d][%d/%d] LR: %.6f, Loss: %.6f (%.6f)' % (
                phase, epoch, args.n_epoch, i, n_batches,
                get_lr(optimizer),
                loss.item(), meter_loss.avg)
            for name, value in losses.items():
                log += ', %s: %.6f (%.6f)' % (
                    name, value.item(), loss_meters[name].avg)
            print()
            _emit_log(log_fout, log)

        if is_train and i % args.ckp_per_iter == 0:
            torch.save(model.state_dict(), '%s/net_epoch_%d_iter_%d.pth' % (
                args.outf, epoch, i))

    return meter_loss.avg


def train():
    """Train ``DynamicsModel`` using the options returned by ``gen_args()``.

    Every epoch runs a ``'train'`` and a ``'valid'`` pass. The enabled loss
    terms (``ct_loss``, ``nerf_loss``, ``auto_loss``) are summed. Logs go to a
    text file in ``args.outf``. Periodic checkpoints are written there too,
    along with ``net_best.pth`` whenever the validation loss improves.

    Raises:
        ValueError: if no loss term is enabled.
    """
    args = gen_args()

    if not (args.ct_loss or args.nerf_loss or args.auto_loss):
        raise ValueError(
            "at least one of ct_loss / nerf_loss / auto_loss must be enabled")

    set_seed(args.seed)

    use_gpu = torch.cuda.is_available()

    ### make log dir (no shell involved, unlike `mkdir -p`)
    os.makedirs(args.outf, exist_ok=True)

    with _open_log_file(args) as log_fout:
        print(args)
        print(args, file=log_fout)

        ### dataloaders
        phases = ['train', 'valid']
        dataloaders = _build_dataloaders(args, phases)

        ### create model
        model = _build_model(args, use_gpu)

        ### optimizer
        optimizer = torch.optim.Adam(
            model.parameters(), lr=args.lrate, betas=(0.9, 0.999))

        # start training
        st_epoch = args.resume_epoch if args.resume_epoch > 0 else 0
        best_valid_loss = np.inf

        for epoch in range(st_epoch, args.n_epoch):
            for phase in phases:
                avg_loss = _run_phase(
                    args, model, optimizer, dataloaders[phase], phase, epoch,
                    use_gpu, log_fout)

                log = '%s [%d/%d] Loss: %.6f, Best valid: %.6f' % (
                    phase, epoch, args.n_epoch, avg_loss, best_valid_loss)
                _emit_log(log_fout, log)

                if phase == 'valid' and avg_loss < best_valid_loss:
                    best_valid_loss = avg_loss
                    torch.save(model.state_dict(), '%s/net_best.pth' % (args.outf))




if __name__ == '__main__':
    # Make new float tensors default to the GPU. Guarded because this call
    # fails on CPU-only builds/machines, whereas train() itself checks
    # torch.cuda.is_available() and can run without a GPU.
    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    train()
