import json

import torch
import torch.nn as nn

import pickle
import os
import glob

import trimesh
from termcolor import colored

from model.AE_baseline_1 import Sg2ScAEModelBaseline1 as baseline1
from model.AE_Incremental3D import Sg2ScAEModelIncremental3D as increment3dG
from model.AE_baseline_2 import Sg2ScAEModelBaeline2 as baseline2

class AE(nn.Module):
    """Wrapper around one scene-graph autoencoder variant plus its shape-code database.

    ``type`` selects the wrapped network, stored as ``self.ae``:

    * ``'baseline1'``    -> ``Sg2ScAEModelBaseline1`` with 6 graph-conv layers
    * ``'increment3dG'`` -> ``Sg2ScAEModelIncremental3D`` with 3 graph-conv layers
    * ``'baseline2'``    -> ``Sg2ScAEModelBaeline2`` with 3 graph-conv layers
    * ``'baseline3'``    -> accepted by the constructor, but no network is built;
      methods that need the network raise ``NotImplementedError``.

    For the implemented types, every latent code found under
    ``<root>/DEEPSDF_reconstruction/Codes/<id>/sdf.pth`` is loaded into
    ``self.code_dict`` (keyed by ``<id>``) for nearest-neighbour retrieval, and
    retrieved meshes are read from ``<root>/DEEPSDF_reconstruction/Meshes/<id>/sdf.ply``.
    """

    def __init__(self, root="../GT", type='increment3dG', vocab=None,  residual=False, gconv_pooling='avg',  num_box_params=7):
        """Build the selected network and load the latent-code database.

        Args:
            root: dataset root containing ``DEEPSDF_reconstruction``.
            type: one of ``'increment3dG'``, ``'baseline1'``, ``'baseline2'``,
                ``'baseline3'``.
            vocab: vocabulary dict; ``vocab['object_idx_to_name']`` is read and
                the whole dict is passed to the network constructor.
            residual, gconv_pooling, num_box_params: currently NOT forwarded to
                the network -- the values used are fixed below
                (``residual=True``, ``gconv_pooling='avg'``, ``num_box_params=7``).

        Raises:
            AssertionError: if ``type`` is not a known model type.
        """
        super().__init__()
        assert type in ['increment3dG', 'baseline1', 'baseline2','baseline3'], '{} is not included'.format(type)
        self.type_ = type

        self.vocab = vocab
        self.epoch = 0
        self.database = os.path.join(root, "DEEPSDF_reconstruction")

        # (network class, number of graph-conv layers) for each implemented type.
        # Built here rather than at class level so the network classes are only
        # looked up when an AE is actually constructed.
        model_specs = {
            'baseline1': (baseline1, 6),
            'increment3dG': (increment3dG, 3),
            'baseline2': (baseline2, 3),
        }

        if self.type_ in model_specs:
            # NOTE(review): decode_g2sv1 indexes this sorted, de-duplicated list
            # by category id, which is only a faithful id->name mapping if
            # vocab['object_idx_to_name'] is already sorted and unique -- confirm.
            self.classes_ = sorted(set(self.vocab['object_idx_to_name']))
            self.v1code_base = os.path.join(self.database, 'Codes')
            self.v1mesh_base = os.path.join(self.database, 'Meshes')
            self.code_dict = self._load_code_dict()

            model_cls, num_layers = model_specs[self.type_]
            self.ae = model_cls(
                vocab,
                node_embedding_dim=128,
                gconv_pooling='avg',
                gconv_num_layers=num_layers,
                mlp_normalization='batchnorm',
                num_box_params=7,
                shape_input_dim=256,
                residual=True,
            )
        # 'baseline3' intentionally gets no self.ae (no implementation exists);
        # _require_model turns any later use into a clear NotImplementedError.
        self.counter = 0

    def _load_code_dict(self):
        """Return ``{id: latent_code}`` for every sub-directory of ``self.v1code_base``.

        Each ``<id>/sdf.pth`` is loaded on the CPU; the stored code is taken as
        ``loaded[0]`` converted to NumPy and indexed ``[0]`` again (presumably
        stripping leading singleton dimensions -- confirm the file layout).
        """
        code_dict = {}
        for id_name in os.listdir(self.v1code_base):
            latent_code = torch.load(
                os.path.join(self.v1code_base, id_name, 'sdf.pth'), map_location="cpu"
            )[0]
            code_dict[id_name] = latent_code.detach().numpy()[0]
        return code_dict

    def _require_model(self):
        """Return the wrapped network, or raise if this model type has none."""
        model = getattr(self, 'ae', None)
        if model is None:
            raise NotImplementedError(
                "model type '{}' has no network implementation".format(self.type_)
            )
        return model

    @staticmethod
    def _load_pickled_module(path):
        """Load a whole ``nn.Module`` written by :meth:`save`.

        ``save`` pickles the full module object, not just a state dict, so it
        must be unpickled with ``weights_only=False`` on PyTorch versions that
        have that argument (recent releases default it to ``True``, which
        rejects pickled modules). Tensors are mapped to the CPU; the caller
        copies them into its own parameters, so their final device is
        unaffected. Only load trusted checkpoints: full unpickling can run
        arbitrary code.
        """
        import inspect

        kwargs = {'map_location': 'cpu'}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            kwargs['weights_only'] = False
        return torch.load(path, **kwargs)

    def forward_incremental_3D_(self,obj_batch_scene_ids, objs, boxes, triples, new_mask,obj_indices,triple_scene_ids ):
        """Run one batch step of the wrapped network.

        All arguments are forwarded unchanged, in order, to
        ``self.ae.forward_batch_step``, whose result is unpacked into exactly
        two values.

        Returns:
            ``(boxes, shapes)`` as produced by the wrapped network.

        Raises:
            NotImplementedError: for a model type without a network (``'baseline3'``).
        """
        model = self._require_model()
        boxes, shapes = model.forward_batch_step(
            obj_batch_scene_ids, objs, boxes, triples, new_mask, obj_indices, triple_scene_ids
        )
        return boxes, shapes

    def load_networks(self, exp, epoch, strict=True, restart_optim=False):
        """Load weights from ``<exp>/checkpoint/model<epoch>.pth`` into ``self.ae``.

        The file is expected to hold a whole pickled module (as written by
        :meth:`save` with ``outf='checkpoint'``); its ``state_dict()`` is copied
        into the existing network.

        Args:
            exp: experiment directory.
            epoch: epoch number used in the checkpoint file name.
            strict: forwarded to ``load_state_dict``.
            restart_optim: currently unused.

        Raises:
            NotImplementedError: for a model type without a network (``'baseline3'``).
        """
        path = os.path.join(exp, 'checkpoint', 'model{}.pth'.format(epoch))
        print(colored('Loading {} model from {}...'.format(self.type_, path), 'green'))
        model = self._require_model()
        ckpt = self._load_pickled_module(path).state_dict()
        model.load_state_dict(ckpt, strict=strict)

    def get_closest_vec(self, class_name, shape_vec, box_data):
        """Find the stored latent code most cosine-similar to ``shape_vec``.

        Args:
            class_name: key into ``box_data``.
            shape_vec: torch tensor; it is flattened before comparison.
            box_data: mapping ``class_name -> {object_id: ...}``; only the
                object ids (keys) are used, each looked up in ``self.code_dict``.

        Returns:
            ``(object_id, similarity)`` -- the best-matching id and its cosine
            similarity (a scalar score, not the code vector itself).

        Raises:
            ValueError: if ``box_data[class_name]`` has no candidates.
        """
        import numpy as np

        obj_ids = list(box_data[class_name].keys())
        if not obj_ids:
            raise ValueError("no candidate shapes for class '{}'".format(class_name))
        # presumably (N, 256) given shape_input_dim=256 -- confirm against the codes on disk
        codes = np.vstack([self.code_dict[obj_id] for obj_id in obj_ids])

        # L2-normalise both sides (epsilon guards zero vectors) so the dot
        # product below is a cosine similarity.
        codes_norm = codes / (np.linalg.norm(codes, axis=1, keepdims=True) + 1e-8)
        shape_vec_np = shape_vec.detach().cpu().numpy().reshape(-1)
        shape_vec_norm = shape_vec_np / (np.linalg.norm(shape_vec_np) + 1e-8)

        sims = np.dot(codes_norm, shape_vec_norm)  # (N,)
        id_max = np.argmax(sims)
        return obj_ids[id_max], sims[id_max]

    def decode_g2sv1(self, cats, shape_vecs, box_data, retrieval=False):
        """Turn predicted shape vectors into meshes by nearest-code retrieval.

        For each ``(cat, shape_vec)`` pair, the class name ``self.classes_[cat]``
        is looked up; ``'floor'`` and ``'_scene_'`` entries are skipped. The
        closest stored code is found with :meth:`get_closest_vec` and its mesh is
        loaded from ``<v1mesh_base>/<id>/sdf.ply``.

        Args:
            cats: category indices into ``self.classes_``.
            shape_vecs: per-object torch tensors, parallel to ``cats``.
            box_data: candidate ids per class (see :meth:`get_closest_vec`).
            retrieval: only retrieval is implemented; when ``False`` nothing is
                decoded and two empty lists are returned.

        Returns:
            ``(mesh_list, vec_list)``: loaded meshes and, in parallel, the cosine
            similarity of each retrieved code.
        """
        # Initialise up front so the non-retrieval path returns empty results
        # instead of failing on unbound names.
        mesh_list = []
        vec_list = []
        if retrieval:
            for cat, shape_vec in zip(cats, shape_vecs):
                class_name = self.classes_[cat].strip('\n')
                if class_name == 'floor' or class_name == '_scene_':
                    continue
                name_, sim_ = self.get_closest_vec(class_name, shape_vec, box_data)
                vec_list.append(sim_)
                mesh_list.append(trimesh.load(os.path.join(self.v1mesh_base, name_, 'sdf.ply')))

        return mesh_list, vec_list

    def save(self, exp, outf, epoch, counter=None):
        """Pickle the whole wrapped network to ``<exp>/<outf>/model<epoch>.pth``.

        The output directory is created if missing. ``counter`` is unused.

        Raises:
            NotImplementedError: for a model type without a network (``'baseline3'``).
        """
        model = self._require_model()
        out_dir = os.path.join(exp, outf)
        os.makedirs(out_dir, exist_ok=True)
        torch.save(model, os.path.join(out_dir, 'model{}.pth'.format(epoch)))

    def reset_all_scene_cls_states(self):
        """Reset per-scene state of the incremental model; a no-op for other types."""
        if self.type_ == 'increment3dG':
            self.ae.reset_all_scene_states()
