'''
Author: Wenhao Ding
Email: wenhaod@andrew.cmu.edu
Date: 2020-10-12 22:14:36
LastEditTime: 2021-05-31 22:57:05
Description: 
'''

import numpy as np
import argparse
from skimage.io import imsave

import torch
import torch.utils.data
from torch.autograd import Variable
from utils import CUDA, CPU, kaiming_init, make_gif

from data_loader import ObjectGridDataset, ObjectTreeDataset, ObjectGrammarDataset
from model import TreeVAE, GridVAE, GrammarVAE
from renderer import Renderer


# ---------------------------------------------------------------------------
# Global hyper-parameters shared by the experiment entry points below.
# ---------------------------------------------------------------------------
# Scale forwarded to the datasets and to ``model.decode``; ``get_oracle_z``
# also multiplies normalized x/y poses by it before rendering.
position_scale: int = 10
# Dimensionality of the VAE latent code z.
z_dim: int = 64
# Number of optimization steps inside each search trial.
iteration: int = 150
# Number of independent restarts per search experiment.
trials: int = 10
# ``train`` prints averaged losses every ``print_iter`` epochs.
print_iter: int = 1


def train(model_name, train_w_gt):
    """Train one of the scene VAEs and periodically checkpoint it.

    The KL weight ``beta`` follows a logistic ramp from ~0 up to ``beta_max``.
    The model is saved to ``./models/<model_name>.model.pth`` every
    ``save_iter`` epochs.

    Args:
        model_name: ``'grid'``, ``'tree'`` or ``'grammar'``; selects the model,
            its dataset files and its hyper-parameters.
        train_w_gt: forwarded to the dataset constructor (per the CLI help:
            whether oracle data points are added to the training dataset).

    Raises:
        ValueError: if ``model_name`` is not a supported model.
    """
    # For adjusting the weight of the KL divergence from 0 to 1: a logistic
    # curve centred at ``epoch_th``; a larger ``gamma`` gives a slower ramp.
    def beta_scheduler(epoch, gamma=10.0, epoch_th=100):
        return 1.0 / (1.0 + np.exp(-(epoch - epoch_th) / gamma))

    # hyper-parameters shared by every model
    gamma = 10.0
    batch_size = 128
    lr = 0.001
    if model_name == 'grid':
        model = CUDA(GridVAE(z_dim))
        epochs = 10000
        beta_max = 0.1
        save_iter = 20
        num_workers = 8
        epoch_th = 200
    elif model_name == 'tree':
        model = CUDA(TreeVAE(z_dim))
        epochs = 1000
        beta_max = 0.2
        save_iter = 1
        num_workers = 0
        epoch_th = 100
    elif model_name == 'grammar':
        model = CUDA(GrammarVAE(z_dim, max_length=25, rule_dim=9, attri_dim=6))
        epochs = 10000
        beta_max = 0.1
        save_iter = 20
        num_workers = 0
        epoch_th = 200
    else:
        # fail fast instead of hitting an unbound ``model`` further down
        raise ValueError("Unknown model name '{}': expected 'grid', 'tree' or 'grammar'".format(model_name))

    model.apply(kaiming_init)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    if model_name == 'grid':
        dataset = ObjectGridDataset('./data/object_grid_dataset.npy', './data/object_oracle_grid_dataset.npy', position_scale, train_w_gt)
        train_iter = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    elif model_name == 'tree':
        dataset = ObjectTreeDataset('./data/object_tree_dataset.npy', './data/object_oracle_tree_dataset.npy', position_scale, train_w_gt)
        # tree samples need the dataset's own collate_fn to form a batch
        train_iter = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=dataset.collate_fn)
    else:
        dataset = ObjectGrammarDataset('./data/object_grammar_dataset.npy', './data/object_oracle_grammar_dataset.npy', position_scale, train_w_gt)
        train_iter = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    for e_i in range(epochs):
        total_loss_list = []
        kldiv_loss_list = []
        recon_loss_list = []
        beta = beta_max * beta_scheduler(e_i, gamma, epoch_th)
        for batch_data in train_iter:
            # grid/grammar batches are tensors moved to the GPU; tree batches
            # (from the custom collate_fn) are passed to the model unchanged
            if model_name in ('grid', 'grammar'):
                batch_data = CUDA(batch_data.float())

            recon_loss, kldiv_loss, _ = model(batch_data)
            total_loss = recon_loss + beta * kldiv_loss

            total_loss_list.append(total_loss.item())
            kldiv_loss_list.append(kldiv_loss.item())
            recon_loss_list.append(recon_loss.item())

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        avg_total_loss = np.mean(total_loss_list)
        avg_kldiv_loss = np.mean(kldiv_loss_list)
        avg_recon_loss = np.mean(recon_loss_list)
        if (e_i + 1) % print_iter == 0:
            print('Epoch: {}, Beta: {}, Total loss: {}, KLD: {}, Recon loss: {}'.format(e_i, beta, avg_total_loss, avg_kldiv_loss, avg_recon_loss))

        if (e_i + 1) % save_iter == 0:
            model.save_model('./models/' + model_name + '.model.pth')


def sample(model_name):
    """Decode 10 random latent codes with a trained VAE and save the renders.

    Each code is drawn from N(0, I). The decoded scene is rendered and written
    to ``./samples/<model_name>_samples_<i>.png``.

    Args:
        model_name: ``'grid'``, ``'tree'`` or ``'grammar'``.

    Raises:
        ValueError: if ``model_name`` is not a supported model.
    """
    renderer = Renderer('./data/cube.obj', './data/circle.obj', './data/ref.png')
    if model_name == 'grid':
        model = CUDA(GridVAE(z_dim))
        model.load_model('./models/grid.model.pth')
    elif model_name == 'tree':
        model = CUDA(TreeVAE(z_dim))
        model.load_model('./models/tree.model.pth')
    elif model_name == 'grammar':
        model = CUDA(GrammarVAE(z_dim, max_length=25, rule_dim=9, attri_dim=6))
        model.load_model('./models/grammar.model.pth')
    else:
        raise ValueError("Unknown model name '{}': expected 'grid', 'tree' or 'grammar'".format(model_name))
    # decode in inference mode, as get_oracle_z / latent_search do
    model.eval()

    z = Variable(CUDA(torch.zeros(1, z_dim)))
    for s_i in range(10):
        # random initialize z
        z_init = np.random.normal(0.0, 1.0, size=(1, z_dim))
        z.data.copy_(CUDA(torch.from_numpy(z_init)))

        # generate unconditioned image
        scene_parameters = model.decode(z, position_scale)
        renderer(scene_parameters['box_poses'], scene_parameters['box_colors'],
                 scene_parameters['plane_poses'], scene_parameters['plane_colors'])
        imsave('./samples/' + model_name + '_samples_' + str(s_i) + '.png', renderer.render())


def get_oracle_z(model_name):
    """Encode the first oracle scene and save its latent code.

    The first batch (batch_size=1) of the oracle dataset is encoded with the
    trained model. The reconstruction is rendered to
    ``./samples/<model_name>_oracle_reconstructed.png``; for grid and grammar
    the original scene is also rendered. The first latent returned by the
    model is saved to ``./data/<model_name>.oracle.z.npy``.

    Args:
        model_name: ``'grid'``, ``'tree'`` or ``'grammar'``.

    Raises:
        ValueError: if ``model_name`` is unsupported or the oracle dataset is empty.
    """
    if model_name == 'grid':
        model = CUDA(GridVAE(z_dim))
        model.load_model('./models/grid.model.pth')
    elif model_name == 'tree':
        model = CUDA(TreeVAE(z_dim))
        model.load_model('./models/tree.model.pth')
    elif model_name == 'grammar':
        model = CUDA(GrammarVAE(z_dim, max_length=25, rule_dim=9, attri_dim=6))
        model.load_model('./models/grammar.model.pth')
    else:
        raise ValueError("Unknown model name '{}': expected 'grid', 'tree' or 'grammar'".format(model_name))
    model.eval()

    # here we only load the oracle dataset
    if model_name == 'grid':
        dataset = ObjectGridDataset('./data/object_oracle_grid_dataset.npy', None, position_scale, train_w_gt=False)
        train_iter = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    elif model_name == 'tree':
        dataset = ObjectTreeDataset('./data/object_oracle_tree_dataset.npy', None, position_scale, train_w_gt=False)
        train_iter = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=dataset.collate_fn)
    else:
        dataset = ObjectGrammarDataset('./data/object_oracle_grammar_dataset.npy', None, position_scale, train_w_gt=False)
        train_iter = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

    renderer = Renderer('./data/cube.obj', './data/circle.obj', './data/ref.png')

    # we only use the first sample of the oracle dataset
    try:
        batch_data = next(iter(train_iter))
    except StopIteration:
        raise ValueError('The oracle dataset for {} is empty'.format(model_name)) from None
    if model_name in ('grid', 'grammar'):
        batch_data = CUDA(batch_data.float())

    # save the original image (only works for grid-VAE and grammar-VAE)
    if model_name == 'grid':
        # rows 0:2 are planes, rows 2:10 boxes; cols 0:3 pose, cols 3:6 color.
        # Normalized x/y are scaled by position_scale and the angle by 2*pi.
        pose_scale = CUDA(torch.tensor([position_scale, position_scale, 2 * np.pi]))[None][None]
        plane_poses = batch_data[:, 0:2, 0:3] * pose_scale
        plane_colors = batch_data[:, 0:2, 3:6]
        box_poses = batch_data[:, 2:10, 0:3] * pose_scale
        box_colors = batch_data[:, 2:10, 3:6]
        renderer(box_poses, box_colors, plane_poses, plane_colors)
        imsave('./samples/grid_oracle_original.png', renderer.render())
    elif model_name == 'grammar':
        x_ = batch_data.transpose(-2, -1)
        scene_param = model.decode_from_x(x_, position_scale)
        renderer(scene_param['box_poses'], scene_param['box_colors'], scene_param['plane_poses'], scene_param['plane_colors'])
        imsave('./samples/grammar_oracle_original.png', renderer.render())

    # encode
    _, _, z_list = model(batch_data)
    oracle_z = z_list[0]

    # save the decoded image
    scene_param = model.decode(oracle_z, position_scale)
    renderer(scene_param['box_poses'], scene_param['box_colors'], scene_param['plane_poses'], scene_param['plane_colors'])
    imsave('./samples/' + model_name + '_oracle_reconstructed.png', renderer.render())

    np.save('./data/' + model_name + '.oracle.z.npy', CPU(oracle_z))


def latent_search(model_name, use_oracle):
    """ This method searches the latent space with gradient descent. There are three model options:
        - grid: use original VAE
        - tree: use T-VAE
        - grammar: use GVAE

    For each of ``trials`` restarts, z is initialized and then updated for
    ``iteration`` Adam steps to minimize the loss returned by the renderer on
    the decoded scene.

    Args:
        model_name: ``'grid'``, ``'tree'`` or ``'grammar'``.
        use_oracle: if True, initialize z around the saved oracle code
            (``./data/<model_name>.oracle.z.npy``) with std 0.5; otherwise draw
            it from N(0, I) shifted by ``nonoracle_bias``.

    Outputs:
        ``./samples/<model_name>_latent_search_loss_<w|wo>_oracle.npy`` with a
        (trials, iteration) loss array, plus a gif of the last trial.

    Raises:
        ValueError: if ``model_name`` is not a supported model.
    """
    renderer = Renderer('./data/cube.obj', './data/circle.obj', './data/ref.png')
    if model_name == 'grid':
        model = CUDA(GridVAE(z_dim))
        model.load_model('./models/grid.model.pth')
    elif model_name == 'tree':
        model = CUDA(TreeVAE(z_dim))
        model.load_model('./models/tree.model.pth')
    elif model_name == 'grammar':
        model = CUDA(GrammarVAE(z_dim, max_length=25, rule_dim=9, attri_dim=6))
        model.load_model('./models/grammar.model.pth')
    else:
        raise ValueError("Unknown model name '{}': expected 'grid', 'tree' or 'grammar'".format(model_name))
    # every model (grammar included) is searched in eval mode, matching get_oracle_z
    model.eval()

    attack_lr = 0.1
    nonoracle_bias = 1.0
    oracle_name = 'w' if use_oracle else 'wo'
    if use_oracle:
        oracle_z = np.load('./data/' + model_name + '.oracle.z.npy')

    z = Variable(CUDA(torch.zeros(1, z_dim)), requires_grad=True)

    loss_list = []
    for t_ii in range(trials):
        if use_oracle:
            z_init = np.random.normal(oracle_z, 0.5 * np.ones(z_dim))  # initialization with oracle codes
            print('Use oracle initialization')
        else:
            z_init = np.random.normal(0.0, 1.0, size=(1, z_dim)) + nonoracle_bias  # randomly initialize z from a bad initialization
            print('Use non-oracle initialization')
        z.data.copy_(CUDA(torch.from_numpy(z_init)))
        # fresh Adam state per trial so momentum from the previous restart does not leak in
        optimizer = torch.optim.Adam([z], lr=attack_lr)

        loss_list_inner = []
        for t_i in range(iteration):
            scene_parameters = model.decode(z, position_scale)
            loss = renderer(scene_parameters['box_poses'], scene_parameters['box_colors'],
                            scene_parameters['plane_poses'], scene_parameters['plane_colors'])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            print('[{}/{}] [{}/{}], Total loss: {}'.format(t_ii, trials, t_i, iteration, loss_value))
            loss_list_inner.append(loss_value)
            # per-step frames in /tmp are assembled into a gif by make_gif below
            imsave('/tmp/' + str(t_i) + '.png', renderer.render())
        loss_list.append(loss_list_inner)

    loss_list = np.array(loss_list)
    np.save('./samples/' + model_name + '_latent_search_loss_' + oracle_name + '_oracle.npy', loss_list)
    make_gif('./data/' + model_name + '_latent_search_' + oracle_name + '_oracle.gif')


def knowledge_latent_search():
    """ This method searches the latent space of the tree VAE with knowledge integration.

    Each outer step takes one Adam step on the renderer loss. It then runs
    ``knowledge_itr`` proximal steps that minimize the model's knowledge loss
    (``decode(..., use_kg=True)``) plus ``alpha`` times the MSE to the
    post-task-step z.

    Outputs:
        ``./samples/tree_latent_search_loss_w_kg.npy`` and
        ``./samples/tree_latent_search_knowledge_loss.npy``, both
        (trials, iteration) arrays, plus a gif of the last trial.
    """
    renderer = Renderer('./data/cube.obj', './data/circle.obj', './data/ref.png')

    model = CUDA(TreeVAE(z_dim))
    model.load_model('./models/tree.model.pth')
    model.eval()
    attack_lr = 0.1
    knowledge_lr = 0.01
    knowledge_itr = 5
    alpha = 0.01
    bias = -1.0
    z = Variable(CUDA(torch.zeros(1, z_dim)), requires_grad=True)

    loss_list = []
    knowledge_loss_list = []
    for t_ii in range(trials):
        z_init = np.random.normal(0.0, 1.0, size=(1, z_dim)) + bias  # randomly initialize z from a bad initialization
        z.data.copy_(CUDA(torch.from_numpy(z_init)))
        # fresh Adam state per trial so momentum from the previous restart does not leak in
        optimizer = torch.optim.Adam([z], lr=attack_lr)

        loss_list_inner = []
        knowledge_loss_inner = []
        for t_i in range(iteration):
            # one step for task optimization
            scene_param = model.decode(z, position_scale)
            loss = renderer(scene_param['box_poses'], scene_param['box_colors'], scene_param['plane_poses'], scene_param['plane_colors'])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # proximal optimization with knowledge, anchored at the post-task-step z
            z_original = z.detach().clone()
            inner_optimizer = torch.optim.Adam([z], lr=knowledge_lr)
            for _ in range(knowledge_itr):
                _, kg_loss = model.decode(z, position_scale, use_kg=True)
                dist_to_original = torch.nn.functional.mse_loss(z_original, z)
                knowledge_loss = alpha * dist_to_original + kg_loss
                inner_optimizer.zero_grad()
                knowledge_loss.backward()
                inner_optimizer.step()

            loss_value = loss.item()
            knowledge_value = knowledge_loss.item()
            print('[{}/{}] [{}/{}], Total loss: {}, Knowledge loss: {}'.format(t_ii, trials, t_i, iteration, loss_value, knowledge_value))
            loss_list_inner.append(loss_value)
            knowledge_loss_inner.append(knowledge_value)
            imsave('/tmp/' + str(t_i) + '.png', renderer.render())
        loss_list.append(loss_list_inner)
        # keep every trial's knowledge losses, not only the last one
        knowledge_loss_list.append(knowledge_loss_inner)

    np.save('./samples/tree_latent_search_loss_w_kg.npy', np.array(loss_list))
    np.save('./samples/tree_latent_search_knowledge_loss.npy', np.array(knowledge_loss_list))
    make_gif('./data/tree_latent_search_w_kg.gif')


def random_search(use_oracle):
    """ This method searches the optimal scene directly in the physical space.

    No VAE is involved. Box/plane poses and colors are optimized with Adam
    against the renderer loss; colors are passed through a sigmoid first.
    There are two initializations:
      - use_oracle=True: a hand-designed near-optimal layout plus uniform noise.
      - use_oracle=False: a random scene from ``./data/object_grid_dataset.npy``.

    Outputs:
        ``./samples/random_search_loss_<w|wo>_oracle.npy`` with a
        (trials, iteration) loss array, plus a gif of the last trial.
    """

    def generate_close_optimal_initialization():
        pose_noise = np.random.uniform(-1, 1, size=(8, 3))
        color_noise = np.random.uniform(-1, 1, size=(8, 3))
        box_poses = pose_noise + np.array([[-7, -7, 0], [-3, -3, 0], [-3, -7, 0], [-7, -3, 0], [7, 7, 0], [3, 3, 0], [3, 7, 0], [7, 3, 0]], dtype=np.float32)
        box_colors = color_noise + np.array([[1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.float32)

        pose_noise = np.random.uniform(-1, 1, size=(2, 3))
        color_noise = np.random.uniform(-1, 1, size=(2, 3))
        plane_poses = pose_noise + np.array([[3, 3, 0], [-3, -3, 0]], dtype=np.float32)
        plane_colors = color_noise + np.array([[0, 0, 1.0], [0, 0.0, 1.0]], dtype=np.float32)

        return box_poses[None], box_colors[None], plane_poses[None], plane_colors[None]

    def generate_non_collision_initialization(data):
        # ``data`` is an array indexed as data[scene, row, col]: rows 0:2 are
        # planes, rows 2:10 boxes; cols 0:3 pose, cols 3:6 color
        select_index = np.random.randint(0, len(data))
        init_data = data[select_index, :, :]  # [10, 6]
        plane_poses = init_data[0:2, 0:3]
        plane_colors = init_data[0:2, 3:6]
        box_poses = init_data[2:10, 0:3]
        box_colors = init_data[2:10, 3:6]
        return box_poses[None], box_colors[None], plane_poses[None], plane_colors[None]

    # define render
    renderer = Renderer('./data/cube.obj', './data/circle.obj', './data/ref.png')

    attack_lr = 0.1
    oracle_name = 'w' if use_oracle else 'wo'
    # the grid dataset is only needed for the non-oracle case; load it once
    grid_data = None if use_oracle else np.load('./data/object_grid_dataset.npy')

    loss_list = []
    for t_ii in range(trials):
        if use_oracle:
            box_poses, box_colors, plane_poses, plane_colors = generate_close_optimal_initialization()
            print('Use oracle initialization')
        else:
            box_poses, box_colors, plane_poses, plane_colors = generate_non_collision_initialization(grid_data)
            print('Use non-oracle initialization')
        box_poses = Variable(CUDA(torch.from_numpy(box_poses).float()), requires_grad=True)
        box_colors = Variable(CUDA(torch.from_numpy(box_colors).float()), requires_grad=True)
        plane_poses = Variable(CUDA(torch.from_numpy(plane_poses).float()), requires_grad=True)
        plane_colors = Variable(CUDA(torch.from_numpy(plane_colors).float()), requires_grad=True)

        optimizer = torch.optim.Adam([box_poses, box_colors, plane_poses, plane_colors], lr=attack_lr)

        loss_list_inner = []
        for t_i in range(iteration):
            box_colors_new = torch.sigmoid(box_colors)
            plane_colors_new = torch.sigmoid(plane_colors)
            loss = renderer(box_poses, box_colors_new, plane_poses, plane_colors_new)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            print('[{}/{}] [{}/{}], Total loss: {}'.format(t_ii, trials, t_i, iteration, loss_value))
            loss_list_inner.append(loss_value)
            imsave('/tmp/' + str(t_i) + '.png', renderer.render())

        loss_list.append(loss_list_inner)

    loss_list = np.array(loss_list)
    # include oracle_name so w/wo runs do not overwrite each other (as latent_search does)
    np.save('./samples/random_search_loss_' + oracle_name + '_oracle.npy', loss_list)
    make_gif('./data/random_search_' + oracle_name + '_oracle.gif')



def str2bool(value):
    """Convert a command-line string such as ``'False'`` or ``'1'`` to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``'False'``) as True, so booleans given as values need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Box Placement Experiment')
    parser.add_argument('--type', type=int, default=0)
    parser.add_argument('--model', type=str, default='tree', help='tree or grid or grammar')
    parser.add_argument('--oracle', action='store_true', help='use oracle initialization or not')
    parser.add_argument('--train_w_gt', type=str2bool, default=True, help='add oracle data points in training dataset or not')
    args = parser.parse_args()

    if args.type == 0:
        train(args.model, args.train_w_gt)
    elif args.type == 1:
        latent_search(args.model, args.oracle)
    elif args.type == 2:
        # the flag is --oracle, so the attribute is args.oracle
        random_search(args.oracle)
    elif args.type == 3:
        knowledge_latent_search()
    elif args.type == 4:
        get_oracle_z(args.model)
    elif args.type == 5:
        sample(args.model)
    else:
        raise ValueError('No such a type')
