import os
import math
import json
import torch
import numpy as np

from torch.utils.data import Dataset

from util import *
from dataset.prepross import *

def load_labels_from_txt(txt_path):
    """Build a class-name -> index mapping from a plain-text label file.

    Each line holds one class name (surrounding whitespace is stripped) and
    its 0-based line number becomes the class index. A blank line maps the
    empty string to its line number, and a name that appears more than once
    keeps the index of its last occurrence.

    Args:
        txt_path: path of the label file to read.

    Returns:
        dict mapping each stripped line to its line index.
    """
    label_dict = {}
    with open(txt_path, 'r') as handle:
        for line_no, raw_line in enumerate(handle):
            label_dict[raw_line.strip()] = line_no
    return label_dict

class zoom_get_mo(Dataset):
    """Dataset of 3D Gaussian Splatting scenes sampled at several zoom levels.

    Each sample directory under ``<cfg.data.root>/<dataset_name>/`` is expected
    to contain a ``point_cloud.ply``. Class labels come from
    ``<cfg.data.root>/label.txt`` and are looked up by the sample directory
    name with its last five characters removed.
    """

    def __init__(self, cfg,
                 dataset_name, transform):
        """Index every ``point_cloud.ply`` of one dataset sub-directory.

        Args:
            cfg: config object; ``cfg.data.root`` is the dataset root and
                ``cfg.data.sh_step`` is passed to ``loadply`` as the SH degree.
            dataset_name: sub-directory of the root whose sample folders are
                indexed (presumably a split name such as "train" — confirm).
            transform: iterable of callables applied in order to the dict
                returned by ``loadply``; any falsy value disables transforms.
        """
        super().__init__()
        self.cfg = cfg
        self.max_sh_degree = self.cfg.data.sh_step
        self.transform = transform
        self.root = self.cfg.data.root
        self.dataset_name = dataset_name
        self.label_dict = load_labels_from_txt(os.path.join(self.root, "label.txt"))
        split_dir = os.path.join(self.root, self.dataset_name)
        # Sorted so the index -> sample mapping is reproducible (os.listdir
        # order is arbitrary); plain files cannot hold a point_cloud.ply, so
        # only sub-directories are indexed.
        self.ply_list = [
            os.path.join(split_dir, entry, "point_cloud.ply")
            for entry in sorted(os.listdir(split_dir))
            if os.path.isdir(os.path.join(split_dir, entry))
        ]

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.ply_list)

    def __getitem__(self, index):
        """Load sample ``index`` and return its point sets, cameras and label.

        Returns:
            dict with keys 'cam', 'cam1', 'cam2' (base, zoom-1, zoom-2
            cameras), 'scale_base', 'scale0', 'scale1', 'scale2' (sampled
            Gaussians), 'label' (int from label.txt) and 'path' (ply path).
        """
        # ply_list entries already start with self.root; joining the root
        # again would duplicate the prefix whenever the root is relative.
        ply_path = self.ply_list[index]
        sample_dir = os.path.basename(os.path.dirname(ply_path))
        # The class name is the folder name minus its last 5 characters
        # (presumably an instance suffix such as "_0001" — confirm).
        label_idx = self.label_dict[sample_dir[:-5]]
        (gs_zoom_base, gs_zoom0, gs_zoom1, gs_zoom2,
         cam_random, zoom1_cam, zoom2_cam) = self.get_gs(ply_path)

        return {
            'cam': cam_random,
            'cam1': zoom1_cam,
            'cam2': zoom2_cam,
            'scale_base': gs_zoom_base,
            'scale0': gs_zoom0,
            'scale1': gs_zoom1,
            'scale2': gs_zoom2,
            'label': label_idx,
            'path': ply_path,
        }

    def get_gs(self, path):
        """Load one Gaussian scene and subsample it for a base view and zoom views.

        The scene is loaded with ``loadply``, passed through ``self.transform``,
        packed into one float tensor per point (xyz | opacity | scale | q | RGB),
        and centred with ``center_gs``. ``fps`` then draws 1024 points from
        the whole scene, and from the points ``select_near_half_frustum_points``
        keeps for cameras with 40, 20 and 10 degree fields of view.

        Args:
            path: path of the ``point_cloud.ply`` to load.

        Returns:
            (gs_zoom_base, gs_zoom0, gs_zoom1, gs_zoom2,
             cam_random, zoom1_cam, zoom2_cam)
        """
        # NOTE(review): tensors are moved to CUDA inside the Dataset; with a
        # DataLoader using num_workers > 0 this generally requires the "spawn"
        # start method. Confirm how this dataset is consumed.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        gs = loadply(path, self.max_sh_degree)
        if self.transform:
            for transform in self.transform:
                gs = transform(gs)
        rgb = SH2RGB(gs["sh"]).squeeze(-1)
        # Per-point columns: xyz | opacity | scale | q (rotation) | RGB.
        features = np.concatenate(
            [gs['xyz'], gs['opacity'], gs['scale'], gs['q'], rgb], axis=1)
        gs_tensor = torch.from_numpy(features).to(device).float().contiguous()

        gs_cen, max_xyz = center_gs(gs_tensor)
        gs_zoom_base = fps(gs_cen, 1024)

        xyz = gs_cen[:, :3]
        radius = max(max_xyz)
        # Retry up to 360 camera draws (presumably randomized per call) until
        # every view keeps at least 1024 points for FPS. If no draw qualifies,
        # the last draw is used anyway and fewer than 1024 points may reach fps.
        for _ in range(360):
            cam_random = generate_external_lookat_camera(xyz, k=radius, fov_deg=40)
            # Zoom cameras share the base pose and only narrow the field of view.
            zoom1_cam = cam_random.copy()
            zoom1_cam["fov_deg"] = 20
            zoom2_cam = cam_random.copy()
            zoom2_cam["fov_deg"] = 10
            idx0 = select_near_half_frustum_points(xyz, cam_random)
            idx1 = select_near_half_frustum_points(xyz, zoom1_cam)
            idx2 = select_near_half_frustum_points(xyz, zoom2_cam)
            if min(idx0.shape[0], idx1.shape[0], idx2.shape[0]) >= 1024:
                break

        gs_zoom0 = fps(gs_cen[idx0], 1024)
        gs_zoom1 = fps(gs_cen[idx1], 1024)
        gs_zoom2 = fps(gs_cen[idx2], 1024)

        return gs_zoom_base, gs_zoom0, gs_zoom1, gs_zoom2, cam_random, zoom1_cam, zoom2_cam