import os, io, csv, math, random
import numpy as np
from einops import rearrange
import random
import torch
from decord import VideoReader
import cv2
from scipy.ndimage import distance_transform_edt
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
# from utils.util import zero_rank_print
#from torchvision.io import read_image
from PIL import Image
import torchvision.transforms as T
import torch.nn.functional as F
import numpy as np
import json

def pil_image_to_numpy(image, is_maks=False, index=1, size=(1024, 576), is_normal=False):
    """Resize a PIL image and return it as a NumPy array.

    Args:
        image: PIL image to convert.
        is_maks: (sic, "is_mask") When True the image is handled as a
            mask-type input and the result is cast to ``int16``.
        index: Unused; kept so existing callers keep working.
        size: Target size handed to ``Image.resize`` (Pillow expects
            ``(width, height)``).
        is_normal: Only consulted when ``is_maks`` is True. When set, the
            image is converted to RGB and resized with Pillow's default
            filter; otherwise it is resized with nearest-neighbour sampling.

    Returns:
        ``int16`` array on the mask path, otherwise the RGB image as
        produced by ``np.array``.
    """
    if is_maks:
        if is_normal:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image = image.resize(size)
        else:
            # Mask pixels are compared against integer ids with ``==`` later
            # in this file, so resizing must not interpolate between
            # neighbouring ids: nearest-neighbour keeps every value intact.
            image = image.resize(size, Image.NEAREST)
        return np.array(image).astype(np.int16)

    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = image.resize(size)
    return np.array(image)

    
def numpy_to_pt(images: np.ndarray, is_mask=False) -> torch.FloatTensor:
    """Turn a channel-last NumPy batch into a channel-first float tensor.

    A 3-D input first gets a trailing singleton axis. A 4-D array is then
    permuted with axes ``(0, 3, 1, 2)``; anything else is permuted with
    ``(0, 1, 4, 2, 3)`` (i.e. it must be 5-D).

    Args:
        images: Array to convert.
        is_mask: When True the values are only cast to float; otherwise they
            are additionally divided by 255.

    Returns:
        The permuted ``float32`` tensor.
    """
    if images.ndim == 3:
        images = np.expand_dims(images, axis=-1)

    # 4-D batches move the last axis to position 1; the 5-D layout (used for
    # the 3D + normal data) moves it to position 2.
    axis_order = (0, 3, 1, 2) if images.ndim == 4 else (0, 1, 4, 2, 3)
    tensor = torch.from_numpy(images.transpose(axis_order)).float()

    return tensor if is_mask else tensor / 255


def find_largest_inner_rectangle_coordinates(mask_gray):
    """Find the peak of the mask's L2 distance transform.

    Despite the name, no rectangle is computed: the mask is cast to ``uint8``,
    run through ``cv2.distanceTransform`` and the location and value of the
    maximum are reported.

    Args:
        mask_gray: Array whose non-zero entries mark the region of interest.

    Returns:
        ``(maxLoc, radius)`` where ``maxLoc`` is the location returned by
        ``cv2.minMaxLoc`` and ``radius`` is the peak distance truncated to int.
    """
    # NOTE(review): the fourth positional argument is cv2.DIST_LABEL_PIXEL, a
    # label-type constant; confirm which parameter it binds to in the
    # installed OpenCV build.
    distance_map = cv2.distanceTransform(mask_gray.astype(np.uint8), cv2.DIST_L2, 5, cv2.DIST_LABEL_PIXEL)
    _min_value, peak_value, _min_location, peak_location = cv2.minMaxLoc(distance_map)

    return peak_location, int(peak_value)



class YoutubeVos(Dataset):
    def __init__(
            self, video_folder, ann_folder, depth_folder, normal_map_folder,
            sample_size=(1024, 576), sample_stride=4, sample_n_frames=14,
            trajectory_folder="./data/VIPSeg/SpaTracker_4P",
        ):
        """Index the videos that have a normal map, an annotation and a trajectory.

        Args:
            video_folder: Root directory with one sub-directory of frames per video.
            ann_folder: Root directory with one sub-directory of masks per video.
            depth_folder: Accepted for interface compatibility; currently unused.
            normal_map_folder: Directory listed to find the available video ids.
            sample_size: Target size forwarded to the image loader and used
                to size the generated 2D/3D maps.
            sample_stride: Stored on the instance; not used in this class.
            sample_n_frames: Maximum number of frames loaded per video.
            trajectory_folder: Directory holding one ``<video id>.json``
                trajectory file per video. Defaults to the previously
                hard-coded relative path.
        """
        self.trajectory_CoTracker = trajectory_folder

        # Only ids present in all three listings can be served.
        normal_ids = {name.replace(".pth", "") for name in os.listdir(normal_map_folder)}
        ann_ids = {name.replace(".pth", "") for name in os.listdir(ann_folder)}
        trajectory_ids = {name.replace(".json", "") for name in os.listdir(self.trajectory_CoTracker)}

        # Set iteration order for strings varies between interpreter runs;
        # sorting first makes the shuffle below reproducible under a fixed seed.
        self.dataset = sorted(normal_ids & ann_ids & trajectory_ids)

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        print(f"train image size: {sample_size}")
        random.shuffle(self.dataset)
        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.ann_folder = ann_folder
        self.heatmap = self.gen_gaussian_heatmap()
        self.normal_map_folder = normal_map_folder
        self.sample_size = sample_size

        self.max_id = 15

        print("length", len(self.dataset))
        print("sample size", sample_size)
        
    def center_crop(self,img):
        h, w = img.shape[-2:]  # Assuming img shape is [C, H, W] or [B, C, H, W]
        min_dim = min(h, w)
        top = (h - min_dim) // 2
        left = (w - min_dim) // 2
        return img[..., top:top+min_dim, left:left+min_dim]
        
    def gen_gaussian_heatmap(self, imgSize=200):
        """Build a disk-masked isotropic Gaussian as an 8-bit image.

        The Gaussian has a standard deviation of 40 pixels and is centred at
        ``imgSize / 2`` on both axes. It is multiplied by a filled circle of
        radius ``imgSize // 2`` and rescaled so that its maximum becomes 255.

        Args:
            imgSize: Side length of the square output.

        Returns:
            ``uint8`` array of shape ``(imgSize, imgSize)``.
        """
        sigma = 40
        disk = cv2.circle(
            np.zeros((imgSize, imgSize), np.float32),
            (imgSize // 2, imgSize // 2), imgSize // 2, 1, -1,
        )

        # Per-axis squared offsets, already divided by sigma^2; broadcasting
        # the two axes replaces the former pixel-by-pixel Python loop.
        axis_term = (np.arange(imgSize) - imgSize / 2) ** 2 / (sigma ** 2)
        exponent = -1 / 2 * (axis_term[:, None] + axis_term[None, :])
        gaussian = (1 / 2 / np.pi / (sigma ** 2) * np.exp(exponent)).astype(np.float32)

        gaussian = gaussian * disk
        gaussian = (gaussian / np.max(gaussian)).astype(np.float32)
        return (gaussian * 255).astype(np.uint8)
    
    def create_sphere_matrix(self, center, radius, shape):
        # Creates a coordinate matrix with the specified shape
        x_coords, y_coords, z_coords = np.ogrid[:shape[0], :shape[1], :shape[2]]
        
        # Calculate the distance from each position to the center of the sphere
        distances = np.sqrt((x_coords - center[0])**2 + (y_coords - center[1])**2 + (z_coords - center[2])**2)

        # Create a Boolean matrix that represents whether each position is inside the sphere
        sphere_mask = distances <= radius
        
        # Converts a Boolean matrix to an integer matrix, with the sphere being 1 inside and 0 outside
        sphere_matrix = sphere_mask.astype(int) * 1
        
        return sphere_matrix

    def create_ellipse_matrix(self, center, radius, shape):
        # Creates a coordinate matrix with the specified shape
        x_coords, y_coords, z_coords = np.ogrid[:shape[0], :shape[1], :shape[2]]
        
        # Extract the semi-axes lengths for the ellipse
        a = radius[0]
        b = radius[1]
        c = radius[2]
        
        # Calculate the distance from each position to the center of the ellipse
        # and normalize it by the corresponding semi-axes
        distances = np.sqrt((x_coords - center[0])**2 / a**2 +
                            (y_coords - center[1])**2 / b**2 +
                            (z_coords - center[2])**2 / c**2)
        
        # Create a Boolean matrix that represents whether each position is inside the ellipse
        ellipse_mask = distances <= 1
        
        # Converts a Boolean matrix to an integer matrix, with the ellipse being 1 inside and 0 outside
        ellipse_matrix = ellipse_mask.astype(int)
        
        return ellipse_matrix


    def fill_and_concatenate(self, a, b):
        """
        Fill the non-zero regions of a 3D matrix with values from array a and concatenate the results.
        
        :param a: A 1D numpy array with 3 values.
        :param b: A 3D numpy matrix with shape (100, 100, 100).
        :return: A 4D numpy array with shape (3, 100, 100, 100).
        """
        # if len(a) != 3:
        #     raise ValueError("Array a must contain exactly 3 elements.")
        

        output = np.zeros((a.shape[0],) + b.shape, dtype=b.dtype)
        
        for i, value in enumerate(a):
            filled_b = b.copy()  
            filled_b[b != 0] = value  
            output[i] = filled_b  
        
        return output

    def find_point_on_line_away_from_p1(self, p1, p2, r):
        """Return the point at distance ``r`` from ``p1`` in the direction of ``p2``.

        NOTE(review): earlier documentation described the result as lying at
        distance ``r`` from ``p2`` on the side opposite ``p1``; the
        computation below starts from ``p1``. Confirm which one is intended.

        Args:
            p1: Start point ``(x, y)``.
            p2: Point ``(x, y)`` that defines the direction.
            r: Distance to travel from ``p1``.

        Returns:
            ``[x, y]`` of the computed point, or ``p2`` unchanged when the
            two points coincide.
        """
        direction = np.array([p2[0] - p1[0], p2[1] - p1[1]])
        length = np.linalg.norm(direction)

        # Coincident points define no direction; bail out before dividing.
        if length == 0:
            return p2

        step = r / length
        return [p1[0] + step * direction[0],
                p1[1] + step * direction[1]]


    def angle_with_x_axis(self, p):
        """Return the angle, in degrees, between vector ``p`` and the positive x axis.

        Args:
            p: Pair ``(x, y)``.

        Returns:
            ``atan2(y, x)`` converted from radians to degrees.
        """
        x_component, y_component = p
        return math.degrees(math.atan2(y_component, x_component))

    # rotation
    def determine_rotation_direction(self, p1, p2):
        """Classify the change in polar angle from ``p1`` to ``p2``.

        Args:
            p1: First ``(x, y)`` vector.
            p2: Second ``(x, y)`` vector.

        Returns:
            ``0`` when the two angles differ by less than one degree;
            ``1`` when ``p1`` has positive x and its angle is larger than
            ``p2``'s, or has negative x and its angle is smaller;
            ``2`` in every other case.
        """
        angle_1 = self.angle_with_x_axis(p1)
        angle_2 = self.angle_with_x_axis(p2)

        # Angles closer than one degree count as "no rotation".
        if abs(angle_1 - angle_2) * 100 < 100:
            return 0

        right_side_decreasing = p1[0] > 0 and angle_1 > angle_2
        left_side_increasing = p1[0] < 0 and angle_1 < angle_2
        return 1 if (right_side_decreasing or left_side_increasing) else 2
    
    def calculate_center_coordinates(self, numpy_images, masks, ids,json_trajectory_content,original_size,sample_size,
                                      side=20, depth_1=20, ratio_d = 0.4):
        """Render per-frame 2D and 3D control maps from masks and trajectories.

        For every frame mask and every id, the object's inscribed radius is
        measured on the mask and its centre / reference points are read from
        ``json_trajectory_content``. The 3D map receives the intersection of
        two ellipsoids (one around the centre, one around a point displaced
        along an estimated surface normal); the 2D map receives the
        intersection of the two matching filled circles.

        Args:
            numpy_images: Unused; kept for interface compatibility.
            masks: Iterable of per-frame label maps.
            ids: Keys of ``json_trajectory_content``; each must be
                convertible with ``int`` for comparison against mask pixels.
            json_trajectory_content: Mapping ``id -> dict`` providing the
                per-frame lists ``center_coordinate``,
                ``reference_point_left`` / ``_right`` / ``_up`` / ``_down``,
                each entry holding three numbers.
            original_size: Two-element size of the source masks, used to
                rescale trajectory coordinates.
            sample_size: Unused; ``self.sample_size`` is read instead.
            side: Ignored; the value is recomputed from the mask radius.
            depth_1: Number of slices along the depth axis of the 3D map.
            ratio_d: Depth semi-axis of the centre ellipsoid as a fraction
                of ``depth_1``.

        Returns:
            ``(List_2D, List_3D)``: one ``float32`` array of shape
            ``(sample_size[1], sample_size[0], 3)`` with values 0/250 and one
            ``uint8`` array of shape ``(sample_size[1], sample_size[0],
            depth_1)`` with values 0/255 per frame.
        """
        out_w, out_h = self.sample_size[0], self.sample_size[1]
        small_size = (int(out_w / 8), int(out_h / 8))
        scale = 1000

        def to_scaled_point(point):
            # NOTE(review): component 0 is rescaled with index 1 of the sizes
            # and component 1 with index 0 -- confirm this pairing against
            # the axis order stored in the trajectory files.
            return (point[0] / original_size[1] * out_h,
                    point[1] / original_size[0] * out_w,
                    point[2] * scale)

        List_2D = []
        List_3D = []
        for index_mask, mask in enumerate(masks):
            img_3D = np.zeros((out_h, out_w, depth_1), np.uint8)
            img_2D = np.zeros((out_h, out_w, 3), np.float32)
            mask_labels = np.asarray(mask)

            for index in ids:
                mask_array = (mask_labels == int(index)) * 1

                # Skip ids that are (almost) invisible in this frame, judged
                # on a downsized copy of the binary mask.
                mask_32 = cv2.resize(mask_array.astype(np.uint8), small_size)
                if mask_32.sum() <= 20:
                    continue

                # Only the radius is used; the centre comes from the trajectory.
                _, radius = find_largest_inner_rectangle_coordinates(mask_array)

                track = json_trajectory_content[index]
                center_coordinate_ = track['center_coordinate'][index_mask]
                center_coordinate = [
                    float(center_coordinate_[1] / original_size[0] * out_w),
                    float(center_coordinate_[0] / original_size[1] * out_h),
                ]
                side = min(int(radius), 100)
                depth_value = float(center_coordinate_[2])
                depth_value = max(depth_value, 0)
                depth_value = min(depth_value, 30)

                points = np.array([
                    to_scaled_point(track['reference_point_left'][index_mask]),
                    to_scaled_point(track['reference_point_right'][index_mask]),
                    to_scaled_point(track['reference_point_up'][index_mask]),
                    to_scaled_point(track['reference_point_down'][index_mask]),
                ])

                # Drop the reference point farthest from the centre and span
                # a plane with the remaining three.
                x, y, z = to_scaled_point(center_coordinate_)
                distances_squared = np.sum((points - [x, y, z])**2, axis=1)
                points = np.delete(points, np.argmax(distances_squared), axis=0)

                normal_vector = np.cross(points[1] - points[0], points[2] - points[0])
                normal_length = np.linalg.norm(normal_vector)

                # Degenerate planes and ids below 100 fall back to the depth axis.
                if normal_length == 0 or int(index) < 100:
                    unit_normal = np.array([0, 0, 1])
                else:
                    unit_normal = normal_vector / normal_length

                # Reference point: the centre displaced along the normal.
                refer_side = int(radius * 5)
                refer_depth_side = int(depth_1 * ratio_d * 5)
                x_prime = x + refer_side * unit_normal[0]
                y_prime = y + refer_side * unit_normal[1]
                z_prime = depth_value + refer_depth_side * unit_normal[2]

                reference_coordinate = [float(y_prime), float(x_prime)]
                refer_depth_value = float(z_prime)

                volume_shape = (out_h, out_w, depth_1)
                try:
                    sphere_matrix = self.create_ellipse_matrix(
                        (int(center_coordinate[0]), int(center_coordinate[1]), int(depth_value / 30 * depth_1)),
                        (side, side, int(depth_1 * ratio_d)),
                        volume_shape)

                    ref_sphere_matrix = self.create_ellipse_matrix(
                        (int(reference_coordinate[0]), int(reference_coordinate[1]), int(refer_depth_value / 30 * depth_1)),
                        (refer_side, refer_side, refer_depth_side),
                        volume_shape)

                    sphere_matrix = sphere_matrix * ref_sphere_matrix
                    inside = sphere_matrix != 0
                    img_3D[inside] = sphere_matrix[inside]
                except Exception:
                    # Best effort: an object whose coordinates cannot be
                    # rasterised (e.g. non-finite values) is skipped, while
                    # KeyboardInterrupt / SystemExit still propagate.
                    continue

                # 2D counterpart of the volume above (used for debugging).
                circle_mask = cv2.circle(
                    np.zeros((out_h, out_w), np.float32),
                    (int(center_coordinate[1]), int(center_coordinate[0])), side, 1, -1)
                ref_circle_mask = cv2.circle(
                    np.zeros((out_h, out_w), np.float32),
                    (int(reference_coordinate[1]), int(reference_coordinate[0])), refer_side, 1, -1)
                circle_mask = circle_mask * ref_circle_mask

                # Boolean-mask assignment paints all three channels at once.
                img_2D[circle_mask != 0] = 250

            img_3D = (img_3D != 0).astype(np.uint8) * 255
            List_3D.append(img_3D)
            List_2D.append(img_2D)
        return List_2D, List_3D
    
    def get_batch(self, idx):
        """Load frames, masks and generated control maps for one video.

        If the annotation directory of the selected video is missing, or its
        trajectory file contains exactly one entry, a message is printed and
        another video is drawn at random until one can be served.

        Args:
            idx: Index into ``self.dataset``.

        Returns:
            Tuple ``(pixel_values, mask_pixel_values, motion_values, List_2D,
            List_3D, videoid)``. ``motion_values`` is the constant ``180``;
            the tensors are produced by ``numpy_to_pt``.
        """
        def sort_frames(frame_name):
            # Frame files are ordered by the integer before the first dot.
            return int(frame_name.split('.')[0])

        while True:
            videoid = self.dataset[idx]

            preprocessed_dir = os.path.join(self.video_folder, videoid)
            ann_folder = os.path.join(self.ann_folder, videoid)

            if not os.path.exists(ann_folder):
                idx = random.randint(0, len(self.dataset) - 1)
                print("os.path.exists({}), error".format(ann_folder))
                continue

            # Keep at most ``sample_n_frames`` frames, in numeric order.
            image_files = sorted(os.listdir(preprocessed_dir), key=sort_frames)[:self.sample_n_frames]
            mask_files = sorted(os.listdir(ann_folder), key=sort_frames)[:self.sample_n_frames]

            json_trajectory = os.path.join(self.trajectory_CoTracker, videoid + ".json")
            with open(json_trajectory, 'r', encoding='utf-8') as file:
                json_trajectory_content = json.load(file)

            numpy_images = np.array([
                pil_image_to_numpy(Image.open(os.path.join(preprocessed_dir, img)), size=self.sample_size)
                for img in image_files
            ])
            pixel_values = numpy_to_pt(numpy_images)
            masks = np.array([
                pil_image_to_numpy(Image.open(os.path.join(ann_folder, df)), True, size=self.sample_size)
                for df in mask_files
            ])

            # Only the size of the first mask is needed here; the context
            # manager closes the file that would otherwise stay open.
            with Image.open(os.path.join(ann_folder, mask_files[0])) as first_mask:
                width, height = first_mask.size
            original_size = [width, height]

            ids = list(json_trajectory_content)
            if len(ids) == 1:
                idx = random.randint(0, len(self.dataset) - 1)
                print("len(ids), error")
                continue

            List_2D, List_3D = self.calculate_center_coordinates(
                numpy_images, masks, ids, json_trajectory_content,
                original_size, self.sample_size)
            List_2D = np.array(List_2D)
            List_3D = np.array(List_3D)

            mask_pixel_values = numpy_to_pt(masks, True)
            List_2D = numpy_to_pt(List_2D, True)
            List_3D = numpy_to_pt(List_3D, True)

            motion_values = 180
            return pixel_values, mask_pixel_values, motion_values, List_2D, List_3D, videoid

        
        
    
    def __len__(self):
        """Return the number of indexed videos, as stored in ``self.length``."""
        return self.length
    
    def coordinates_normalize(self,center_coordinates):
        first_point = center_coordinates[0]
        center_coordinates = [one-first_point for one in center_coordinates]
        
        return center_coordinates
    
    def normalize(self, images):
        """
        Normalize an image array to [-1,1].
        """
        return 2.0 * images - 1.0
    
    def normalize_sam(self, images):
        """
        Normalize an image array to [-1,1].
        """
        return (images - torch.tensor([0.485, 0.456, 0.406]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1))/torch.tensor([0.229, 0.224, 0.225]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
    
    def __getitem__(self, idx):
        

        pixel_values, mask_pixel_values,motion_values,List_2D,List_3D, videoid = self.get_batch(idx)
        

        pixel_values = self.normalize(pixel_values)
        
        sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values,
                      motion_values=motion_values,List_2D=List_2D,List_3D=List_3D,videoid=videoid)
        return sample



if __name__ == "__main__":
    # Visual smoke test: load the first batch and dump its frames, masks,
    # 2D overlays and 3D volumes into ./vis for inspection.

    import mrcfile
    from tqdm import tqdm
    dataset = YoutubeVos(
        video_folder = "/ai/aidata/VideoGeneration/Drag3D/data/VIPSeg/imgs",
        ann_folder = "/ai/aidata/VideoGeneration/Drag3D/data/VIPSeg/panomasks",
        depth_folder = "/ai/aidata/VideoGeneration/Drag3D/data/VIPSeg/depth",
        normal_map_folder = "/ai/aidata/VideoGeneration/Drag3D/data/VIPSeg/NormalMap",
        sample_size=(576,320),
        sample_stride=1, sample_n_frames=15
    )
    # Inverse of the mean/std standardisation; not applied below.
    inverse_process = transforms.Compose([
        transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], std=[1/0.229, 1/0.224, 1/0.225]),
    ])

    # cv2.imwrite does not raise when the target directory is missing and
    # mrcfile cannot create files in it, so make sure it exists up front.
    os.makedirs("./vis", exist_ok=True)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=10,)
    for idx, batch in tqdm(enumerate(dataloader)):
        print(batch["videoid"])

        # Undo the 2*x-1 normalisation and move channels last for OpenCV.
        images = ((batch["pixel_values"][0].permute(0,2,3,1)+1)/2)*255
        masks = batch["mask_pixel_values"][0].permute(0,2,3,1)*255
        List_2D = batch["List_2D"][0].permute(0,2,3,1)
        List_3D = batch["List_3D"][0].permute(0,2,3,1)

        print(images.shape)
        print(List_3D.shape)

        for i in range(images.shape[0]):
            image = images[i].numpy().astype(np.uint8)

            mask = masks[i].numpy()
            heatmap = List_3D[i].numpy()
            Image2D = List_2D[i].numpy()

            cv2.imwrite("./vis/image_{}.jpg".format(i), image)
            cv2.imwrite("./vis/mask_{}.jpg".format(i), mask.astype(np.uint8))

            # 50/50 blend of the 2D control map and the frame.
            cv2.imwrite("./vis/cc{}.jpg".format(i), Image2D.astype(np.uint8)*0.5+image*0.5)

            # One MRC volume per frame (mrc_mode=2) holding the 3D map.
            with mrcfile.new_mmap('./vis/{}.mrc'.format(i), overwrite=True, shape=heatmap.shape, mrc_mode=2) as mrc:
                mrc.data[:] = heatmap

        # Only the first batch is visualised.
        break