from .model import EncoderProcesserDecoder
import torch.nn as nn
import torch
from .utilities.Normalization import GaussNormalizer
from .utilities.calc_features import calc_C3D4_element_features, calc_C3D4_face_features, calc_C3D8_element_features, calc_C3D8_face_features
import os
import math
from utils import NodeType
import time

def build_mlp(in_size, hidden_size, out_size, lay_norm=True):
    """Build a four-layer ReLU perceptron, optionally followed by LayerNorm.

    Args:
        in_size: Width of the input features.
        hidden_size: Width of each of the three hidden layers.
        out_size: Width of the output features.
        lay_norm: When true, the perceptron is wrapped in an outer
            ``nn.Sequential`` together with ``nn.LayerNorm(out_size)``.

    Returns:
        The bare perceptron (``nn.Sequential`` of 7 layers) when ``lay_norm``
        is false, otherwise ``nn.Sequential(perceptron, LayerNorm)``.
    """
    layers = [
        nn.Linear(in_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, out_size),
    ]
    perceptron = nn.Sequential(*layers)
    if not lay_norm:
        return perceptron
    # The perceptron stays nested as submodule "0" so parameter names
    # ("0.0.weight", ..., "1.weight") are the same as before.
    return nn.Sequential(perceptron, nn.LayerNorm(normalized_shape=out_size))
        
class MeshForma(nn.Module):
    """Feature-building wrapper around an ``EncoderProcesserDecoder`` network.

    ``forward`` takes a dict of mesh tensors, optionally perturbs the node
    positions with noise, derives normalized node / element / face /
    face-to-face / scripted features from it, writes them back into the dict
    and runs the wrapped model. Each feature family has its own
    ``GaussNormalizer``; their variables are stored next to the network
    weights by ``save_checkpoint`` and restored by ``load_checkpoint``.
    """

    # Attribute names of the normalizers persisted with the weights. The
    # names double as checkpoint keys, so spelling and order must stay stable.
    _NORMALIZER_NAMES = (
        '_output_normalizer',
        '_node_normalizer',
        '_element_normalizer',
        '_face_normalizer',
        '_scripted_normalizer',
        '_face_face_normalizer',
        'e_f_norm',
        'e_n_norm',
        'f_n_norm',
    )

    # Number of classes used when one-hot encoding the node type.
    _NODE_TYPE_COUNT = 9

    def __init__(self, config, device, model_dir='checkpoint/simulator_normed.pth') -> None:
        """Create the normalizers and the wrapped network.

        Args:
            config: Mapping with a ``"model_features"`` entry providing
                ``node_input_size``, ``node_output_size``, ``element_type``
                (``"C3D4"`` or ``"C3D8"``), ``noise_std``,
                ``element_input_size``, ``face_input_size``,
                ``scripted_input_size``, ``face_face_input_size`` and
                ``message_passing_num``.
            device: Device handed to the normalizers and the network.
            model_dir: Default checkpoint path for save/load.

        Raises:
            ValueError: If ``element_type`` is not ``"C3D4"`` or ``"C3D8"``.
        """
        super(MeshForma, self).__init__()

        self.device = device
        model_config = config["model_features"]
        self.node_input_size = model_config["node_input_size"]
        self.node_output_size = model_config["node_output_size"]
        self.model_dir = model_dir

        element_type = model_config["element_type"]
        if element_type == "C3D4":
            face_num, node_num = 4, 4
        elif element_type == "C3D8":
            face_num, node_num = 6, 8
        else:
            # Previously this fell through and failed later with a NameError.
            raise ValueError(
                "Unsupported element_type %r; expected 'C3D4' or 'C3D8'" % (element_type,)
            )

        self.noise_std = model_config["noise_std"]
        self._output_normalizer = GaussNormalizer(size=self.node_output_size, name='output_normalizer', device=device)
        self._node_normalizer = GaussNormalizer(size=self.node_input_size, name='node_normalizer', device=device)
        self._element_normalizer = GaussNormalizer(size=model_config["element_input_size"], name='element_normalizer', device=device)
        self._face_normalizer = GaussNormalizer(size=model_config["face_input_size"], name='face_normalizer', device=device)
        self._scripted_normalizer = GaussNormalizer(size=model_config["scripted_input_size"], name='scripted_node_normalizer', device=device)
        self._face_face_normalizer = GaussNormalizer(size=model_config["face_face_input_size"], name='face_face_normalizer', device=device)
        # Scalar zero handed to the calc_* feature helpers.
        self._safe_zero = torch.tensor(0, dtype=torch.float32, requires_grad=False, device=device)
        # Sizes of the flattened offset vectors built in forward():
        # element->node, face->node and element->face, three coordinates each.
        self.e_n_norm = GaussNormalizer(size=node_num * 3, name='e_n_norm', device=device)
        self.f_n_norm = GaussNormalizer(size=3 * node_num // face_num * 3, name='f_n_norm', device=device)
        self.e_f_norm = GaussNormalizer(size=face_num * 3, name='e_f_norm', device=device)
        self.model = EncoderProcesserDecoder(
            message_passing_num=model_config["message_passing_num"],
            node_input_size=self.node_input_size,
            node_output_size=self.node_output_size,
            element_input_size=model_config["element_input_size"],
            face_input_size=model_config["face_input_size"],
            face_face_input_size=model_config["face_face_input_size"],
            scripted_input_size=model_config["scripted_input_size"],
            face_num=face_num,
            node_num=node_num,
            device=device,
        ).to(device)

        print('MPM Simulator model initialized')

    def update_node_attr(self, frames, types: torch.Tensor, accumulate):
        """Concatenate per-node tensors with a one-hot node type and normalize.

        Args:
            frames: Sequence of 2-D per-node tensors, concatenated along
                dim 1. The sequence itself is not modified.
            types: Integer node types, one per node (any shape that flattens
                to one entry per node).
            accumulate: Forwarded to the node normalizer.

        Returns:
            Normalized node features as a float32 tensor.
        """
        # reshape(-1) instead of squeeze() so a single-node input stays 1-D.
        node_type = types.long().reshape(-1)
        one_hot = torch.nn.functional.one_hot(node_type, self._NODE_TYPE_COUNT)
        # Build a new list: appending to ``frames`` would mutate the caller's.
        node_feats = torch.cat([*frames, one_hot], dim=1)
        return self._node_normalizer(node_feats, accumulate).to(torch.float)

    def _masked_noise(self, reference, node_types, std, device):
        """Return zero-mean Gaussian noise shaped like ``reference``.

        Rows whose node type is not ``NodeType.NORMAL`` are zeroed.
        """
        noise = torch.normal(std=float(std), mean=0.0, size=reference.shape).to(device)
        frozen = node_types != NodeType.NORMAL
        noise[frozen.reshape(frozen.shape[0]), :] = 0
        return noise

    def get_world_pos_noise(self, data_list, noise_std, device):
        """Sample position noise for ``data_list["world_pos"]``.

        Only nodes of type ``NodeType.NORMAL`` receive non-zero noise.
        """
        return self._masked_noise(data_list["world_pos"], data_list["node_type"], noise_std, device)

    def calc_element_features(self, data_list, noised_cur_pos, accumulate):
        """Compute normalized element features and the element volume.

        Elements with four nodes use the C3D4 helper. All others use the
        C3D8 helper with node positions gathered through
        ``data_list["elements_sequ"]``.

        Returns:
            ``(volume, element_features)`` where ``volume`` is column 0 of
            the un-normalized features.
        """
        elements = data_list["elements"]
        mesh_pos = data_list["mesh_pos"]
        element_ins = elements.shape[1]
        if element_ins == 4:
            flat = elements.reshape(-1)
            elements_pos = noised_cur_pos[flat].reshape(-1, element_ins, 3)
            old_elements_pos = mesh_pos[flat].reshape(-1, element_ins, 3)
            raw = calc_C3D4_element_features(elements, elements_pos, old_elements_pos, self._safe_zero, self.device)
        else:
            flat = data_list["elements_sequ"].reshape(-1)
            seq_elements_pos = noised_cur_pos[flat].reshape(-1, element_ins, 3)
            seq_old_elements_pos = mesh_pos[flat].reshape(-1, element_ins, 3)
            raw = calc_C3D8_element_features(elements, seq_elements_pos, seq_old_elements_pos, self._safe_zero, self.device)
        raw = raw.to(torch.float)
        volume = raw[:, 0]
        return volume, self._element_normalizer(raw, accumulate)

    def calc_face_features(self, data_list, noised_cur_pos, accumulate):
        """Compute normalized face features.

        Faces with three nodes use the C3D4 helper. All others use the C3D8
        helper with node positions gathered through
        ``data_list["faces_sequ"]``.
        """
        faces = data_list["faces"]
        mesh_pos = data_list["mesh_pos"]
        face_ins = faces.shape[1]
        if face_ins == 3:
            flat = faces.reshape(-1)
            faces_pos = noised_cur_pos[flat].reshape(-1, face_ins, 3)
            old_faces_pos = mesh_pos[flat].reshape(-1, face_ins, 3)
            raw = calc_C3D4_face_features(faces_pos, old_faces_pos, self._safe_zero)
        else:
            # NOTE: the current positions used to be gathered with ``faces``
            # while the reference positions used ``faces_sequ``, so the two
            # were compared in different node orders. Both now use
            # ``faces_sequ``, mirroring calc_element_features.
            flat = data_list["faces_sequ"].reshape(-1)
            seq_faces_pos = noised_cur_pos[flat].reshape(-1, face_ins, 3)
            seq_old_faces_pos = mesh_pos[flat].reshape(-1, face_ins, 3)
            raw = calc_C3D8_face_features(seq_faces_pos, seq_old_faces_pos, self._safe_zero)
        return self._face_normalizer(raw.to(torch.float), accumulate)

    def calc_normal(self, x):
        """Return unit normals of triangles given as an ``[N, 3, 3]`` tensor.

        The normal is ``(x0 - x1) x (x0 - x2)`` divided by its length plus
        ``1e-8`` (which keeps degenerate triangles from dividing by zero).
        """
        vec_ab = x[:, 0] - x[:, 1]
        vec_ac = x[:, 0] - x[:, 2]

        normal = torch.cross(vec_ab, vec_ac, dim=-1)
        norm = torch.linalg.norm(normal, dim=-1, keepdim=True)
        return normal / (norm + 1e-8)

    def calc_other_features(self, data_list, noised_cur_pos, device, accumulate):
        """Compute normalized face-to-face and scripted (handle) features.

        Face-to-face features, per row of ``data_list["faces_to_faces"]``
        (column 0 is called the sender below, column 1 the receiver):
        receiver centroid minus sender centroid, sender centroid minus each
        receiver node, receiver centroid minus each sender node, and the two
        face normals. Sender centroid and sender nodes use
        ``data_list["next_pos_need"]``; receiver quantities and both normals
        use ``noised_cur_pos``.

        Scripted features, per face and per face node: the displacement from
        the current to the next position plus a flag of 1 for nodes of type
        ``NodeType.HANDLE``; zeros for all other nodes.

        Returns:
            ``(face_to_face_features, scripted_features)`` as float32.
        """
        future = data_list["next_pos_need"]
        faces_faces = data_list["faces_to_faces"]
        faces = data_list["faces"]
        node_type = data_list["node_type"]
        face_ins = faces.shape[1]
        sender, receiver = faces_faces[:, 0], faces_faces[:, 1]

        faces_pos = noised_cur_pos[faces.reshape(-1)].reshape(-1, face_ins, 3)
        # Advanced indexing copies, so the in-place write below does not
        # touch ``faces_pos`` or ``noised_cur_pos``.
        pair_pos = faces_pos[faces_faces.reshape(-1), :, :].reshape(-1, 2, face_ins, 3)
        centroids = torch.mean(pair_pos, dim=2)
        # Replace the sender side by its next-step positions and centroid.
        pair_pos[:, 0, :, :] = future[faces[sender].reshape(-1)].reshape(-1, face_ins, 3)
        centroids[:, 0, :] = torch.mean(pair_pos[:, 0, :, :], dim=1)
        sender_center, receiver_center = centroids[:, 0], centroids[:, 1]

        pair_features = [receiver_center - sender_center]
        pair_features.extend(
            sender_center - noised_cur_pos[faces[receiver, i]] for i in range(face_ins)
        )
        pair_features.extend(
            receiver_center - future[faces[sender, i]] for i in range(face_ins)
        )
        # Normals come from the first three nodes of each face at the current
        # (noised) positions.
        pair_features.append(self.calc_normal(faces_pos[sender, 0:3, :]))
        pair_features.append(self.calc_normal(faces_pos[receiver, 0:3, :]))

        faces_faces_features = torch.concat(pair_features, dim=1)
        faces_faces_features = self._face_face_normalizer(faces_faces_features, accumulate)

        # Four columns per face node: 3-D displacement followed by a flag.
        scripted_features = torch.zeros((faces.shape[0], face_ins * (3 + 1))).to(device)
        for i in range(face_ins):
            mask = (node_type[faces[:, i]] == NodeType.HANDLE).reshape(-1)
            handle_nodes = faces[mask, i]
            displacement = future[handle_nodes] - noised_cur_pos[handle_nodes]
            flag = torch.ones(mask.sum()).to(device).reshape(-1, 1)
            scripted_features[mask, i * 4: (i * 4 + 4)] = torch.concat([displacement, flag], dim=1).to(torch.float)
        scripted_features = self._scripted_normalizer(scripted_features, accumulate)

        return faces_faces_features.to(torch.float), scripted_features.to(torch.float)

    def pos_to_velocity(self, noised_frames, next_pos):
        """Return the displacement ``next_pos - noised_frames``."""
        return next_pos - noised_frames

    def get_stress_noise(self, data_list, noise_std, device):
        """Sample noise for ``data_list["stress"]``.

        ``noise_std`` is accepted for signature compatibility but not used:
        the standard deviation is fixed at 2000. Only nodes of type
        ``NodeType.NORMAL`` receive non-zero noise.
        """
        return self._masked_noise(data_list["stress"], data_list["node_type"], 2000, device)

    def forward(self, data_list, noise_flag):
        """Build all input features, run the model and normalize the target.

        Reads ``world_pos``, ``last_pos``, ``node_type``, ``target``,
        ``elements``, ``faces``, ``mesh_pos``, ``cells_faces``,
        ``faces_to_faces``, ``next_pos_need``, optionally ``stress``, and
        for non-tetrahedral meshes ``elements_sequ`` / ``faces_sequ``.

        ``data_list`` is modified in place: ``world_pos`` is overwritten with
        the (possibly noised) positions and ``node_features``, ``volume``,
        ``element_features``, ``face_features``, ``e_n_vec``, ``f_n_vec``,
        ``e_f_vec``, ``face_to_face_features`` and ``scripted_features`` are
        added.

        Args:
            data_list: Dict of mesh tensors as described above.
            noise_flag: When true, position noise with ``self.noise_std`` is
                added to ``world_pos``.

        Returns:
            ``(predicted, normalized_target)`` where the target is the
            displacement from the noised positions to ``target[:, :3]``,
            followed by ``target[:, 3:]`` when ``stress`` is present.
        """
        # Forwarded to every normalizer; true only in training mode.
        accumulate = self.training

        cur_pos = data_list["world_pos"]
        if noise_flag:
            noise = self.get_world_pos_noise(data_list, self.noise_std, self.device)
        else:
            noise = torch.zeros(cur_pos.shape).to(self.device)
        noised_cur_pos = cur_pos + noise
        node_type = data_list["node_type"]
        target = data_list["target"]

        has_stress = "stress" in data_list
        # Without "stress" the node features used to be left undefined and
        # the next assignment crashed. Velocity-only features are assumed to
        # be the intended no-stress input -- TODO confirm against the
        # configured node_input_size.
        frames = [noised_cur_pos - data_list["last_pos"]]
        if has_stress:
            frames.append(data_list["stress"])
        node_attr = self.update_node_attr(frames, node_type, accumulate)

        data_list["world_pos"] = noised_cur_pos
        data_list["node_features"] = node_attr

        data_list["volume"], data_list["element_features"] = self.calc_element_features(data_list, noised_cur_pos, accumulate)
        data_list["face_features"] = self.calc_face_features(data_list, noised_cur_pos, accumulate)

        # Offsets of element nodes, face nodes and element faces from the
        # corresponding element / face centroid, flattened per element / face.
        elements, faces = data_list["elements"], data_list["faces"]
        pos = data_list["world_pos"]
        element_count, face_count = elements.shape[0], faces.shape[0]
        element_pos = pos[elements.reshape(-1, 1)].reshape(element_count, -1, 3)
        element_center_pos = torch.mean(element_pos, dim=1, keepdim=True)
        data_list["e_n_vec"] = self.e_n_norm((element_pos - element_center_pos).reshape(element_count, -1), accumulate)
        face_pos = pos[faces.reshape(-1, 1)].reshape(face_count, -1, 3)
        face_center_pos = torch.mean(face_pos, dim=1, keepdim=True)
        data_list["f_n_vec"] = self.f_n_norm((face_pos - face_center_pos).reshape(face_count, -1), accumulate)
        e_f_pos = face_center_pos[data_list["cells_faces"].reshape(-1, 1), :, :].reshape(element_count, -1, 3)
        data_list["e_f_vec"] = self.e_f_norm((e_f_pos - element_center_pos).reshape(element_count, -1), accumulate)

        data_list["face_to_face_features"], data_list["scripted_features"] = self.calc_other_features(data_list, noised_cur_pos, self.device, accumulate)
        predicted = self.model(data_list)

        target_velocity = self.pos_to_velocity(noised_cur_pos, target[:, :3])
        if has_stress:
            extra_targets = target[:, 3:].reshape(target.shape[0], -1)
            target_velocity = torch.concat([target_velocity, extra_targets], dim=1)

        target_velocity_normalized = self._output_normalizer(target_velocity, accumulate)

        return predicted, target_velocity_normalized

    def load_checkpoint(self, ckpdir=None):
        """Restore weights and normalizer variables from a checkpoint file.

        Args:
            ckpdir: Checkpoint path; defaults to ``self.model_dir``.

        The file must hold a dict with a ``'model'`` state dict. Every other
        key names an attribute of this object and maps to a dict of values
        that are set on that attribute.
        """
        if ckpdir is None:
            ckpdir = self.model_dir
        # NOTE(security): torch.load unpickles the file. Only load
        # checkpoints that come from a trusted source.
        dicts = torch.load(ckpdir, map_location=self.device)
        self.load_state_dict(dicts['model'])

        for name, variables in dicts.items():
            if name == 'model':
                continue
            # getattr replaces eval('self.' + name): same lookup, but text
            # read from the file is never executed as code.
            normalizer = getattr(self, name)
            for para, value in variables.items():
                setattr(normalizer, para, value)

        print("Simulator model loaded checkpoint %s"%ckpdir)

    def save_checkpoint(self, savedir=None):
        """Write weights and normalizer variables to a checkpoint file.

        Args:
            savedir: Target file path; defaults to ``self.model_dir``. Its
                parent directory is created when missing.
        """
        if savedir is None:
            savedir = self.model_dir

        # Create the directory of the file actually written (this used to be
        # the directory of self.model_dir), and skip bare filenames, for
        # which dirname is '' and os.makedirs would raise.
        parent = os.path.dirname(savedir)
        if parent:
            os.makedirs(parent, exist_ok=True)

        to_save = {'model': self.state_dict()}
        for name in self._NORMALIZER_NAMES:
            to_save[name] = getattr(self, name).get_variable()

        torch.save(to_save, savedir)
        print('Simulator model saved at %s'%savedir)

    def output_parameter(self):
        """Print every named parameter of the module (debugging aid)."""
        print("run output_parameter")
        print(type(self.named_parameters()))
        # Iterate over the model parameters.
        for name, param in self.named_parameters():
            print(name, param)