import os, io, csv, math, random
import numpy as np
from einops import rearrange
import random
import torch
from decord import VideoReader
import cv2
from scipy.ndimage import distance_transform_edt
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
# from utils.util import zero_rank_print
#from torchvision.io import read_image
from PIL import Image
import torchvision.transforms as T
import torch.nn.functional as F
import numpy as np
import json
import re
import time


def pil_image_to_numpy(image, is_maks = False, index = 1,size=(1024,576),is_normal = False):
    """Resize a PIL image and return it as a NumPy array.

    Args:
        image: PIL image to convert.
        is_maks: treat the image as a mask (keyword spelling kept for callers).
            Masks keep their original mode unless ``is_normal`` is set, and are
            returned as ``int16``.
        index: unused; kept for backward compatibility.
        size: target size forwarded to ``image.resize``.
        is_normal: only meaningful together with ``is_maks``; forces RGB
            conversion of the mask.

    Returns:
        ``np.ndarray`` of the resized image: ``int16`` for masks, otherwise the
        dtype NumPy derives from the RGB image.
    """
    # Non-mask images are always RGB; masks only when flagged as "normal".
    wants_rgb = is_normal or not is_maks
    if wants_rgb and image.mode != 'RGB':
        image = image.convert('RGB')

    pixels = np.array(image.resize(size))
    return pixels.astype(np.int16) if is_maks else pixels

    
def numpy_to_pt(images: np.ndarray, is_mask=False) -> torch.FloatTensor:
    """Move the last axis of a NumPy batch forward and return a float tensor.

    * 3-D input gets a trailing singleton axis first, then is treated as 4-D.
    * 4-D input: axes ``(0, 1, 2, 3)`` -> ``(0, 3, 1, 2)``.
    * 5-D input: axes ``(0, 1, 2, 3, 4)`` -> ``(0, 1, 4, 2, 3)``.

    Args:
        images: array to convert.
        is_mask: when true, values are kept as-is; otherwise they are divided
            by 255.

    Returns:
        A ``float32`` tensor.
    """
    if images.ndim == 3:
        images = np.expand_dims(images, axis=-1)

    # 4-D covers plain frames; anything else is assumed to be 5-D.
    order = (0, 3, 1, 2) if images.ndim == 4 else (0, 1, 4, 2, 3)
    tensor = torch.from_numpy(images.transpose(order)).float()
    return tensor if is_mask else tensor / 255


def find_largest_inner_rectangle_coordinates(mask_gray):
    """Locate the largest disc that fits inside the non-zero region of a mask.

    Despite the name, this returns an inscribed *circle*, not a rectangle: the
    pixel with the largest L2 distance to the nearest zero pixel, and that
    distance.

    Args:
        mask_gray: single-channel mask; it is cast to ``uint8`` because
            ``cv2.distanceTransform`` requires an 8-bit single-channel input.

    Returns:
        ``(maxLoc, radius)`` where ``maxLoc`` is the ``(x, y)`` location returned
        by ``cv2.minMaxLoc`` and ``radius`` is the maximum distance truncated to
        an int.
    """
    # The 4th positional argument of cv2.distanceTransform is the output array
    # `dst`, not a label type (label types belong to
    # cv2.distanceTransformWithLabels), so only distanceType and maskSize are
    # passed.
    refine_dist = cv2.distanceTransform(mask_gray.astype(np.uint8), cv2.DIST_L2, 5)
    _, maxVal, _, maxLoc = cv2.minMaxLoc(refine_dist)
    radius = int(maxVal)

    return maxLoc, radius



class Obyssey(Dataset):
    """Video dataset pairing RGB frames with rasterised 3-D instance volumes.

    One item per ``*.json`` file in ``ann_folder``. An annotation named
    ``<video_dir><sep><number>.json`` (e.g. ``dancing_3.json``) reads its frames
    from ``<video_folder>/<video_dir>/rgbs/<frame key>``.

    Annotation layout, as consumed by ``calculate_center_coordinates``:

    * ``"max_depth"``: divisor used to map point depth into depth bins;
    * every key containing ``"rgb"`` is a frame file name whose value maps
      instance IDs to dicts with ``"center"``, ``"refer_center_1"``,
      ``"refer_center_2"`` (each ``[x, y, depth]``, x/y in source-frame
      pixels) and ``"radius"`` (source-frame pixels).
    """

    def __init__(
        self, video_folder, ann_folder,
        sample_size=(1024, 576), sample_stride=4, sample_n_frames=14,
    ):
        """
        :param video_folder: root directory holding one sub-directory per video.
        :param ann_folder: directory of ``*.json`` annotation files.
        :param sample_size: output ``(width, height)`` of frames and volumes.
        :param sample_stride: stored on the instance; not read by this class.
        :param sample_n_frames: number of RGB frames loaded per sample.
        """
        # Only *.json files are annotations. Strip the suffix so each entry can
        # rebuild both the annotation path and the video directory name.
        # Sorting before the shuffle makes the order reproducible under a seed.
        self.dataset = [
            name[:-len(".json")]
            for name in sorted(os.listdir(ann_folder))
            if name.endswith(".json")
        ]
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        print(f"train image size: {sample_size}")
        random.shuffle(self.dataset)

        self.video_folder = video_folder
        self.ann_folder = ann_folder
        self.sample_size = sample_size
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        # Not read by this class; kept for external users.
        self.max_id = 15

    def center_crop(self, img):
        """Crop the two trailing axes of ``img`` to a centred square of side ``min(H, W)``."""
        h, w = img.shape[-2:]
        min_dim = min(h, w)
        top = (h - min_dim) // 2
        left = (w - min_dim) // 2
        return img[..., top:top + min_dim, left:left + min_dim]

    def gen_gaussian_heatmap(self, imgSize=200, sigma=40):
        """Return an ``(imgSize, imgSize)`` uint8 isotropic Gaussian heatmap.

        The Gaussian (standard deviation ``sigma`` pixels) is centred at
        ``imgSize / 2``. It is zeroed outside the filled disc of radius
        ``imgSize // 2`` drawn by ``cv2.circle``, then rescaled so its maximum
        is 255.

        :param imgSize: side length of the square output.
        :param sigma: Gaussian standard deviation in pixels (default 40).
        """
        circle_img = np.zeros((imgSize, imgSize), np.float32)
        circle_mask = cv2.circle(circle_img, (imgSize // 2, imgSize // 2), imgSize // 2, 1, -1)

        # Evaluate the 2-D Gaussian density over the whole grid at once
        # (rows broadcast against columns) instead of a per-pixel Python loop.
        rows = np.arange(imgSize).reshape(-1, 1)
        cols = np.arange(imgSize).reshape(1, -1)
        gauss = 1 / 2 / np.pi / (sigma ** 2) * np.exp(
            -1 / 2 * ((rows - imgSize / 2) ** 2 / (sigma ** 2) + (cols - imgSize / 2) ** 2 / (sigma ** 2)))

        heatmap = gauss.astype(np.float32) * circle_mask
        heatmap = heatmap / np.max(heatmap)
        return (heatmap * 255).astype(np.uint8)

    def create_ellipse_matrix(self, center, radius, shape):
        """Return an int array of ``shape`` that is 1 inside an axis-aligned ellipsoid, else 0.

        :param center: ellipsoid centre as indices along the three axes.
        :param radius: semi-axis lengths along the three axes.
        :param shape: 3-tuple output shape.
        """
        x_coords, y_coords, z_coords = np.ogrid[:shape[0], :shape[1], :shape[2]]
        a, b, c = radius[0], radius[1], radius[2]
        # Normalised squared distance to the centre; <= 1 means inside
        # (boundary included).
        distances = ((x_coords - center[0]) ** 2 / a ** 2
                     + (y_coords - center[1]) ** 2 / b ** 2
                     + (z_coords - center[2]) ** 2 / c ** 2)
        return (distances <= 1).astype(int)

    def fill_and_concatenate(self, a, b):
        """
        Stack one copy of ``b`` per value in ``a``; copy ``i`` has its non-zero
        cells overwritten by ``a[i]``.

        :param a: numpy array of fill values, one per output slice.
        :param b: numpy array of any shape whose non-zero cells get filled.
        :return: array of shape ``(a.shape[0],) + b.shape`` with ``b.dtype``.
        """
        output = np.zeros((a.shape[0],) + b.shape, dtype=b.dtype)
        nonzero = b != 0
        for i, value in enumerate(a):
            output[i] = b
            output[i][nonzero] = value
        return output

    def find_last_number(self, s):
        """Return the run of digits at the end of ``s`` as a string, or ``None``."""
        match = re.search(r'\d+$', s)
        return match.group(0) if match else None

    def calculate_normal(self, x, y, z, x1, y1, z1, x2, y2, z2):
        """Unit normal of the plane through three points.

        :raises ValueError: if the points are collinear (zero cross product).
        """
        vector1 = np.array([x1 - x, y1 - y, z1 - z])
        vector2 = np.array([x2 - x, y2 - y, z2 - z])
        normal = np.cross(vector1, vector2)

        length = np.linalg.norm(normal)
        if length == 0:
            raise ValueError("The three points do not define a plane; they are collinear.")
        return normal / length

    def plane_cut_matrix(self, x, y, z, x1, y1, z1, x2, y2, z2, w, h, d):
        """Return a ``(w, h, d)`` int array that is 1 where the plane equation is negative.

        The plane passes through ``(x, y, z)``, ``(x1, y1, z1)`` and ``(x2, y2, z2)``.
        Its orientation comes from ``calculate_normal``.

        :raises ValueError: on non-positive dimensions or collinear points.
        """
        if w <= 0 or h <= 0 or d <= 0:
            raise ValueError("Dimensions w, h, and d must be positive integers.")

        normal = self.calculate_normal(x, y, z, x1, y1, z1, x2, y2, z2)
        dd = -normal.dot([x, y, z])

        # Broadcast three 1-D index axes instead of materialising a full
        # meshgrid: identical values at a fraction of the memory.
        X = np.arange(w).reshape(-1, 1, 1)
        Y = np.arange(h).reshape(1, -1, 1)
        Z = np.arange(d).reshape(1, 1, -1)
        plane_equation = normal[0] * X + normal[1] * Y + normal[2] * Z + dd

        return (plane_equation < 0).astype(int)

    def _scale_point(self, point, original_size):
        """Map an ``[x, y, depth]`` annotation point from source-frame pixels to ``sample_size``.

        :param point: sequence whose first three items are x, y and depth.
        :param original_size: ``(width, height)`` of the source frame.
        :return: ``[x, y, depth]`` as floats; depth is not rescaled.
        """
        x, y, depth = float(point[0]), float(point[1]), float(point[2])
        return [
            x / original_size[0] * self.sample_size[0],
            y / original_size[1] * self.sample_size[1],
            depth,
        ]

    def _to_voxel(self, point, max_depth, depth_scale):
        """Convert a scaled ``[x, y, depth]`` point to integer ``(row, col, depth_bin)`` indices."""
        return (int(point[1]), int(point[0]), int(point[2] / max_depth * depth_scale))

    def calculate_center_coordinates(self, image_path, data, depth_scale=40, z_ratio = 0.3):
        """Load RGB frames and rasterise each annotated instance into a 3-D volume.

        For every instance, build a filled ellipsoid centred at ``center``. Its
        row/column semi-axes equal ``radius`` (clamped to ``[5, H / 5]`` and
        truncated to int), and its depth semi-axis is
        ``int(depth_scale * z_ratio)``. The ellipsoid is then cut by the plane
        through ``center``, ``refer_center_1`` and ``refer_center_2``, keeping
        the side where the plane equation is negative. Instances whose points
        are degenerate (collinear, non-finite, or ``max_depth == 0``) are
        skipped.

        :param image_path: directory containing the frame files named by the keys of ``data``.
        :param data: parsed annotation dict (see the class docstring).
        :param depth_scale: number of depth bins; depth maps to
            ``int(depth / max_depth * depth_scale)``.
        :param z_ratio: ellipsoid depth semi-axis as a fraction of ``depth_scale``.
        :return: ``(pixel_values, List_2D, List_3D)``.
            ``pixel_values`` is the frame stack from ``numpy_to_pt``.
            ``List_2D`` holds one all-zero ``(H, W, 3)`` float32 array per frame.
            ``List_3D`` holds one ``(H, W, ceil(depth_scale / 2))`` uint8 volume
            per frame, with values in {0, 255}.
            At most ``self.sample_n_frames`` frames are read.
        :raises ValueError: if ``data`` has no ``"rgb"`` frame keys.
        """
        max_depth = float(data["max_depth"])
        width_out, height_out = self.sample_size
        volume_shape = (height_out, width_out, depth_scale)
        z_radius = int(depth_scale * z_ratio)
        max_side = height_out / 5

        List_2D = []
        List_3D = []
        numpy_images = []

        for frame_name, instances in data.items():
            if "rgb" not in frame_name:
                continue
            # Count loaded frames, not JSON keys, so the position of non-frame
            # keys such as "max_depth" does not change how many frames are read.
            if len(numpy_images) >= self.sample_n_frames:
                break

            # Context manager closes the file once the pixels are decoded.
            with Image.open(os.path.join(image_path, frame_name)) as image:
                original_size = image.size  # (width, height) of the source frame
                numpy_images.append(pil_image_to_numpy(image, size=self.sample_size))

            img_3D = np.zeros(volume_shape, np.uint8)
            img_2D = np.zeros((height_out, width_out, 3), np.float32)

            for instance in instances.values():
                center = self._scale_point(instance["center"], original_size)
                refer_1 = self._scale_point(instance["refer_center_1"], original_size)
                refer_2 = self._scale_point(instance["refer_center_2"], original_size)

                radius = float(instance["radius"]) / original_size[1] * height_out
                radius = int(max(min(radius, max_side), 5))

                try:
                    center_vox = self._to_voxel(center, max_depth, depth_scale)
                    sphere_matrix = self.create_ellipse_matrix(
                        center_vox, (radius, radius, z_radius), volume_shape)
                    ref_sphere_matrix = self.plane_cut_matrix(
                        *center_vox,
                        *self._to_voxel(refer_1, max_depth, depth_scale),
                        *self._to_voxel(refer_2, max_depth, depth_scale),
                        height_out, width_out, depth_scale)
                except (ValueError, ZeroDivisionError, OverflowError):
                    # Degenerate instance: collinear reference points, non-finite
                    # coordinates, or max_depth == 0. Skip it, keep the frame.
                    continue

                carved = sphere_matrix * ref_sphere_matrix
                occupied = carved != 0
                img_3D[occupied] = carved[occupied]

            # Keep every second depth bin and binarise to {0, 255}.
            img_3D = (img_3D[:, :, ::2] != 0).astype(np.uint8) * 255
            List_3D.append(img_3D)
            List_2D.append(img_2D)

        if not numpy_images:
            raise ValueError(f"no 'rgb' frames found in annotation for {image_path}")
        # Convert the frame stack once, after all frames are loaded.
        pixel_values = numpy_to_pt(np.array(numpy_images))
        return pixel_values, List_2D, List_3D

    def _video_dir_name(self, videoid):
        """Strip the trailing number and the single separator before it.

        ``"dancing_12"`` -> ``"dancing"``.

        :raises ValueError: if ``videoid`` does not end with digits.
        """
        number = self.find_last_number(videoid)
        if number is None:
            raise ValueError(f"annotation name {videoid!r} does not end with a number")
        return videoid[:-(len(number) + 1)]

    def get_batch(self, idx):
        """Load sample ``idx``.

        :return: ``(pixel_values, motion_values, List_2D, List_3D)``. The two
            lists are converted with ``numpy_to_pt(..., True)``, i.e. not scaled.
        """
        videoid = self.dataset[idx]
        preprocessed_dir = os.path.join(self.video_folder, self._video_dir_name(videoid), "rgbs")

        ann_path = os.path.join(self.ann_folder, videoid + ".json")
        with open(ann_path, 'r', encoding='utf-8') as file:
            data = json.load(file)

        pixel_values, List_2D, List_3D = self.calculate_center_coordinates(preprocessed_dir, data)

        List_2D = numpy_to_pt(np.array(List_2D), True)
        List_3D = numpy_to_pt(np.array(List_3D), True)

        # Constant motion conditioning value, identical for every sample.
        motion_values = 180
        return pixel_values, motion_values, List_2D, List_3D

    def __len__(self):
        return self.length

    def coordinates_normalize(self, center_coordinates):
        """Return the coordinates re-expressed relative to the first one (empty input -> ``[]``)."""
        if len(center_coordinates) == 0:
            return []
        first_point = center_coordinates[0]
        return [one - first_point for one in center_coordinates]

    def normalize(self, images):
        """
        Map values from [0, 1] to [-1, 1] via ``2 * x - 1``.
        """
        return 2.0 * images - 1.0

    def normalize_sam(self, images):
        """Standardise images with the ImageNet per-channel mean and std.

        Computes ``(images - mean) / std`` with mean ``(0.485, 0.456, 0.406)``
        and std ``(0.229, 0.224, 0.225)``, each shaped ``(1, 3, 1, 1)``. This
        expects the channel axis of ``images`` at dim -3. The result is not
        bounded to [-1, 1].
        """
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        return (images - mean) / std

    def __getitem__(self, idx):
        """Return a dict with ``pixel_values`` (mapped by ``normalize``), ``motion_values``, ``List_2D`` and ``List_3D``."""
        pixel_values, motion_values, List_2D, List_3D = self.get_batch(idx)
        return dict(
            pixel_values=self.normalize(pixel_values),
            motion_values=motion_values,
            List_2D=List_2D,
            List_3D=List_3D,
        )



if __name__ == "__main__":
    # Debug entry point: load one batch and dump frames, overlays and the 3-D
    # volume under ./vis for visual inspection.
    import mrcfile
    from tqdm import tqdm

    dataset = Obyssey(
        video_folder = "/ai/aidata/VideoGeneration/Motion3D/data/Obyssey",
        ann_folder = "/ai/aidata/VideoGeneration/Motion3D/data/Obyssey/Obyssey_Traj_Whole",
        sample_size=(576,320),
        sample_stride=1, sample_n_frames=15
    )

    vis_dir = "./vis"
    # cv2.imwrite returns False instead of raising when the directory is missing.
    os.makedirs(vis_dir, exist_ok=True)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=10,)
    for idx, batch in tqdm(enumerate(dataloader)):
        # Undo `normalize` ([-1, 1] -> [0, 255]) and move channels last.
        images = ((batch["pixel_values"][0].permute(0,2,3,1)+1)/2)*255
        List_2D = batch["List_2D"][0].permute(0,2,3,1)
        List_3D = batch["List_3D"][0].permute(0,2,3,1)

        print(images.shape)
        print(List_3D.shape)

        for i in range(images.shape[0]):
            # Frames come from PIL in RGB order; OpenCV writes BGR.
            image = cv2.cvtColor(images[i].numpy().astype(np.uint8), cv2.COLOR_RGB2BGR)

            heatmap = List_3D[i].numpy()
            Image2D = List_2D[i].numpy()

            cv2.imwrite(os.path.join(vis_dir, "image_{}.jpg".format(i)), image)
            # 50/50 blend, cast back to uint8 so the JPEG writer gets 8-bit data.
            overlay = Image2D.astype(np.uint8) * 0.5 + image * 0.5
            cv2.imwrite(os.path.join(vis_dir, "cc{}.jpg".format(i)), overlay.astype(np.uint8))

            heatmap = heatmap[::2, ::2, :]
            with mrcfile.new_mmap(os.path.join(vis_dir, "{}.mrc".format(i)), overwrite=True, shape=heatmap.shape, mrc_mode=2) as mrc:
                mrc.data[:] = heatmap

        break