import os
import math
import json
import torch
import numpy as np


from torch.utils.data import Dataset

from util import *
from dataset.prepross import *

class zoom_get_ma(Dataset):
    """Dataset of Gaussian-splat files sampled at a base view and three zoom levels.

    Each item loads one file listed in ``<cfg.data.root>/<dataset_name>.json``,
    centers the splat, and returns ``_NUM_POINTS``-point subsets (via ``fps``)
    of the whole cloud and of the points inside three cameras that share one
    pose but use fields of view of 20, 10 and 5 degrees.
    """

    # Split name -> JSON file (relative to cfg.data.root) holding the list of
    # sample paths for that split.
    _SPLIT_FILES = {'train': 'train.json', 'test': 'test.json'}
    # Number of points requested from fps for every returned subset.
    _NUM_POINTS = 1024
    # Maximum number of random cameras tried before accepting the last one.
    _MAX_CAMERA_TRIES = 360

    def __init__(self, cfg, dataset_name, transform):
        """
        Args:
            cfg: config object; reads ``cfg.data.sh_step`` and ``cfg.data.root``.
            dataset_name: ``'train'`` or ``'test'``; selects the JSON file list.
            transform: optional iterable of callables applied, in order, to the
                dict returned by ``loadply``; pass a falsy value to disable.

        Raises:
            ValueError: if ``dataset_name`` is not a known split.
        """
        super().__init__()
        self.cfg = cfg
        self.max_sh_degree = self.cfg.data.sh_step
        self.transform = transform

        self.root = self.cfg.data.root

        self.dataset_name = dataset_name

        # Fail fast: an unknown split would otherwise leave file_list unset and
        # surface much later as an AttributeError in __len__/__getitem__.
        if dataset_name not in self._SPLIT_FILES:
            raise ValueError(
                f"dataset_name must be one of {sorted(self._SPLIT_FILES)}, "
                f"got {dataset_name!r}"
            )
        list_path = os.path.join(self.root, self._SPLIT_FILES[dataset_name])
        with open(list_path, 'r', encoding='utf-8') as f:
            self.file_list = json.load(f)

    def __len__(self):
        """Return the number of entries in the split's file list."""
        return len(self.file_list)

    def __getitem__(self, index):
        """Load sample ``index`` and return its cameras and point subsets.

        The file-list entry's first ``/``-separated component is parsed as the
        integer class label (entries are expected to look like ``"<label>/..."``).

        Returns:
            dict with keys ``'cam'``, ``'cam1'``, ``'cam2'`` (cameras at the
            base and two zoom levels), ``'scale_base'``, ``'scale0'``,
            ``'scale1'``, ``'scale2'`` (sampled point tensors), ``'label'``
            (int) and ``'path'`` (full file path).
        """
        s = self.file_list[index]
        label_idx = int(s.split('/')[0])
        file_path = os.path.join(self.root, s)
        gs_zoom_base, gs_zoom0, gs_zoom1, gs_zoom2, cam_random, zoom1_cam, zoom2_cam = self.get_gs(file_path)

        sample = {
            'cam': cam_random,
            'cam1': zoom1_cam,
            'cam2': zoom2_cam,
            'scale_base': gs_zoom_base,
            'scale0': gs_zoom0,
            'scale1': gs_zoom1,
            'scale2': gs_zoom2,
            'label': label_idx,
            'path': file_path,
        }
        return sample

    def get_gs(self, path):
        """Load one splat file and build the base and zoomed point subsets.

        Args:
            path: path of the file to pass to ``loadply``.

        Returns:
            ``(gs_zoom_base, gs_zoom0, gs_zoom1, gs_zoom2, cam_random,
            zoom1_cam, zoom2_cam)``: ``gs_zoom_base`` samples the whole centered
            cloud; ``gs_zoom0/1/2`` sample the points inside the 20/10/5 degree
            cameras respectively.
        """
        # NOTE(review): creating CUDA tensors inside __getitem__ can fail with
        # DataLoader(num_workers>0) under the fork start method; consider
        # building on CPU and moving to the GPU in the training loop.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        gs = loadply(path, self.max_sh_degree)
        if self.transform:
            for transform in self.transform:
                gs = transform(gs)
        rgb = SH2RGB(gs["sh"]).squeeze(-1)
        # Per-point feature columns, in order: xyz, opacity, scale,
        # q (rotation — presumably a quaternion; confirm in loadply),
        # and RGB converted from the SH coefficients.
        features = np.concatenate(
            [gs['xyz'], gs['opacity'], gs['scale'], gs['q'], rgb], axis=1
        )
        gs_tensor = torch.from_numpy(features).to(device).float().contiguous()

        gs_cen, max_xyz = center_gs(gs_tensor)
        gs_zoom_base = fps(gs_cen, self._NUM_POINTS)

        # Loop-invariant inputs for camera generation / frustum selection.
        xyz = gs_cen[:, :3]
        cam_distance = max(max_xyz)

        # Retry random cameras until all three frustums contain enough points.
        # Best effort: if no attempt succeeds, the last attempt is used as-is
        # and fps receives fewer than _NUM_POINTS points.
        for _ in range(self._MAX_CAMERA_TRIES):
            base_cam = generate_external_lookat_camera(xyz, k=cam_distance, fov_deg=20)
            cam_random = base_cam
            # Shallow copies: only "fov_deg" differs; nested values are shared.
            zoom1_cam = base_cam.copy()
            zoom1_cam["fov_deg"] = 10
            zoom2_cam = base_cam.copy()
            zoom2_cam["fov_deg"] = 5
            idx0 = select_choice_frustum_points(xyz, cam_random)
            idx1 = select_choice_frustum_points(xyz, zoom1_cam)
            idx2 = select_choice_frustum_points(xyz, zoom2_cam)
            if (idx0.shape[0] >= self._NUM_POINTS
                    and idx1.shape[0] >= self._NUM_POINTS
                    and idx2.shape[0] >= self._NUM_POINTS):
                break

        gs_zoom0 = fps(gs_cen[idx0], self._NUM_POINTS)
        gs_zoom1 = fps(gs_cen[idx1], self._NUM_POINTS)
        gs_zoom2 = fps(gs_cen[idx2], self._NUM_POINTS)

        return gs_zoom_base, gs_zoom0, gs_zoom1, gs_zoom2, cam_random, zoom1_cam, zoom2_cam


    

    


















