import os, io, csv, math, random
import numpy as np
from einops import rearrange
import random
import torch
from decord import VideoReader
import cv2
from scipy.ndimage import distance_transform_edt
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
# from utils.util import zero_rank_print
#from torchvision.io import read_image
from PIL import Image
import torchvision.transforms as T
import torch.nn.functional as F
import numpy as np
import json
import re
import time
from os import path

def bbox_intersection(bbox1, bbox2):
    """Return the overlap of two ``(x1, y1, x2, y2)`` boxes, or ``None``.

    The overlap corners are truncated with ``int``. Boxes that only touch
    along an edge (zero-area overlap) count as disjoint and yield ``None``.
    """
    ax1, ay1, ax2, ay2 = bbox1
    bx1, by1, bx2, by2 = bbox2

    left, top = max(ax1, bx1), max(ay1, by1)
    right, bottom = min(ax2, bx2), min(ay2, by2)

    # Guard clause: no positive-area overlap.
    if not (left < right and top < bottom):
        return None
    return int(left), int(top), int(right), int(bottom)

def generate_attention_mask(height, width, bbox):
    """Return a flattened ``height * width`` int mask that is 1 inside *bbox*.

    Args:
        height: Number of mask rows.
        width: Number of mask columns.
        bbox: ``(xmin, ymin, xmax, ymax)`` in pixel coordinates. Both max
            edges are inclusive. Values are truncated with ``int`` and the
            box is clipped to the image.

    Returns:
        1-D ``int`` array of length ``height * width`` (row-major).
    """
    xmin, ymin, xmax, ymax = (int(v) for v in bbox)
    attention_mask_2d = np.zeros((height, width), dtype=int)
    # Clamp slice bounds at 0: a negative start/stop would index from the
    # opposite edge of the image instead of clipping the box.
    row_start, row_stop = max(ymin, 0), max(ymax + 1, 0)
    col_start, col_stop = max(xmin, 0), max(xmax + 1, 0)
    attention_mask_2d[row_start:row_stop, col_start:col_stop] = 1
    return attention_mask_2d.flatten()

def calculate_square_bbox(top_midpoint, bottom_midpoint):
    """Build a square box hanging from *top_midpoint* down to *bottom_midpoint*.

    The side length is the vertical distance between the two points. The box
    is centred horizontally on the x coordinate of *top_midpoint*; the x of
    *bottom_midpoint* is ignored. Corners are truncated with ``int``.

    Returns:
        ``(x_min, y_min, x_max, y_max)``.
    """
    center_x, top = top_midpoint
    _, bottom = bottom_midpoint
    half_side = (bottom - top) / 2
    return (
        int(center_x - half_side),
        int(top),
        int(center_x + half_side),
        int(bottom),
    )

def keypoints_to_bbox(keypoints):
    """Return the tight ``(x_min, y_min, x_max, y_max)`` box around *keypoints*.

    *keypoints* is array-like with one row per point. Only columns 0 (x) and
    1 (y) are read; extra columns are ignored. Corners are truncated with ``int``.
    """
    points = np.asarray(keypoints)
    xs = points[:, 0]
    ys = points[:, 1]
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

def scale_bbox(bbox, scale_x, scale_y=None):
    """Scale ``(x_min, y_min, x_max, y_max)`` about its centre.

    Args:
        bbox: Box to scale.
        scale_x: Horizontal scale factor.
        scale_y: Vertical scale factor; defaults to *scale_x*.

    Returns:
        The scaled box, truncated with ``int``. The minimum corner is clamped
        at 0; the maximum corner is not clamped.
    """
    factor_y = scale_x if scale_y is None else scale_y
    left, top, right, bottom = bbox

    mid_x = (left + right) / 2
    mid_y = (top + bottom) / 2
    half_w = (right - left) * scale_x / 2
    half_h = (bottom - top) * factor_y / 2

    return (
        int(max(mid_x - half_w, 0)),
        int(max(mid_y - half_h, 0)),
        int(mid_x + half_w),
        int(mid_y + half_h),
    )

def count_keypoints_in_bbox(keypoints, bbox):
    """Count how many keypoints fall inside *bbox* (edges inclusive).

    Args:
        keypoints: Iterable of points whose first two entries are ``x`` and
            ``y``; any trailing per-point entries are ignored.
        bbox: ``(xmin, ymin, xmax, ymax)``.

    Returns:
        Number of points with ``xmin <= x <= xmax`` and ``ymin <= y <= ymax``.
    """
    xmin, ymin, xmax, ymax = bbox
    # Index instead of unpacking so points with extra values are accepted too.
    return sum(
        1
        for point in keypoints
        if xmin <= point[0] <= xmax and ymin <= point[1] <= ymax
    )


def pil_image_to_numpy(image, is_maks=False, index=1, size=(1024, 576), is_normal=False):
    """Resize a PIL image to *size* and return it as a NumPy array.

    Args:
        image: PIL image to convert.
        is_maks: Treat *image* as a mask. Its mode is kept (unless
            *is_normal*) and the result is cast to ``int16``.
        index: Unused; kept for call compatibility.
        size: ``(width, height)`` target handed to ``Image.resize``.
        is_normal: Together with *is_maks*, still force RGB before resizing.

    Returns:
        The resized pixel array (``int16`` for masks).
    """
    # Regular frames are always forced to RGB; masks only when they are normal maps.
    wants_rgb = is_normal if is_maks else True
    if wants_rgb and image.mode != 'RGB':
        image = image.convert('RGB')
    pixels = np.array(image.resize(size))
    if is_maks:
        return pixels.astype(np.int16)
    return pixels

    
def numpy_to_pt(images: np.ndarray, is_mask=False) -> torch.FloatTensor:
    """Convert a channel-last NumPy image array to a channel-first float tensor.

    Supported layouts (the last axis is moved to the channel position):
      * 3-D ``(H, W, C)``       -> ``(C, H, W)``
      * 4-D ``(N, H, W, C)``    -> ``(N, C, H, W)``
      * 5-D ``(B, N, H, W, C)`` -> ``(B, N, C, H, W)``

    Args:
        images: Array in one of the layouts above.
        is_mask: If true, return values unscaled; otherwise divide by 255.

    Returns:
        ``float32`` tensor sharing the permuted layout described above.

    Raises:
        ValueError: If *images* is not 3-, 4- or 5-dimensional.
    """
    axes_by_ndim = {3: (2, 0, 1), 4: (0, 3, 1, 2), 5: (0, 1, 4, 2, 3)}
    axes = axes_by_ndim.get(images.ndim)
    if axes is None:
        raise ValueError(
            f"numpy_to_pt expects a 3-, 4- or 5-D array, got shape {images.shape}"
        )
    tensor = torch.from_numpy(images.transpose(axes)).float()
    return tensor if is_mask else tensor / 255


def find_largest_inner_rectangle_coordinates(mask_gray):
    """Locate the largest circle that fits inside the non-zero region of a mask.

    Despite the name, no rectangle is computed. The function runs an L2
    distance transform over *mask_gray* and returns the pixel farthest from
    any zero pixel together with that distance.

    Args:
        mask_gray: Single-channel (2-D) array whose non-zero pixels form the
            region. It is cast to ``uint8``, as ``cv2.distanceTransform``
            requires.

    Returns:
        ``(max_loc, radius)``. ``max_loc`` is the ``(x, y)`` point reported by
        ``cv2.minMaxLoc``; ``radius`` is ``int`` of the maximum distance.
    """
    # Only (src, distanceType, maskSize) are needed here. cv2.DIST_LABEL_PIXEL
    # is a labelType for cv2.distanceTransformWithLabels; passed positionally
    # to distanceTransform it lands in the ``dst`` slot.
    refine_dist = cv2.distanceTransform(mask_gray.astype(np.uint8), cv2.DIST_L2, 5)
    _, max_val, _, max_loc = cv2.minMaxLoc(refine_dist)
    return max_loc, int(max_val)



class Obyssey(Dataset):
    """Clip dataset pairing video frames with pose renders, audio and speaking scores.

    Items are read from ``self.data_config``, a JSON list of per-video dicts
    with ``image_dir``, ``pose_dir``, ``speak_info_path`` and
    ``audio_embedding_path`` keys (see ``get_batch``). ``len(self)`` is the
    number of files in ``ann_folder``, which is a separate source.
    ``get_batch`` therefore wraps indices into ``data_config`` so the two may
    differ in length.
    """

    def __init__(self, video_folder, ann_folder,
                 data_config_path='/users/zeyuzhu/ControlSD/Moore-AnimateAnyone/all_data_config/train.json',
                 sample_size=(1024, 576), sample_stride=4, sample_n_frames=14,):
        """Index the annotation folder and load the sampling config.

        Args:
            video_folder: Root video folder; stored but not read by this class.
            ann_folder: Folder of ``*.json`` annotation files. Their names
                without ``.json`` populate ``self.dataset`` and define ``len(self)``.
            data_config_path: JSON file holding the list of per-video metadata
                dicts that ``get_batch`` samples from.
            sample_size: ``(width, height)`` handed to ``PIL.Image.resize``.
            sample_stride: Stored only; the sampling code reads consecutive frames.
            sample_n_frames: Number of consecutive frames per clip.
        """
        # Annotation stems, e.g. "dancing_12.json" -> "dancing_12".
        self.dataset = [name.replace(".json", "") for name in os.listdir(ann_folder)]
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        print(f"train image size: {sample_size}")
        random.shuffle(self.dataset)
        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.ann_folder = ann_folder
        self.sample_size = sample_size
        self.max_id = 15

        print("length", len(self.dataset))
        print("sample size", sample_size)

        with open(data_config_path, 'r') as f:
            self.data_config = json.load(f)

    def center_crop(self, img):
        """Crop the last two axes of *img* to a centred square of side min(H, W)."""
        h, w = img.shape[-2:]  # e.g. [C, H, W] or [B, C, H, W]
        min_dim = min(h, w)
        top = (h - min_dim) // 2
        left = (w - min_dim) // 2
        return img[..., top:top + min_dim, left:left + min_dim]

    def gen_gaussian_heatmap(self, imgSize=200, sigma=40):
        """Return an ``imgSize x imgSize`` uint8 isotropic Gaussian blob.

        The Gaussian (std *sigma* pixels) is centred in the image, zeroed
        outside the inscribed disc, and rescaled so its peak is 255.
        """
        circle_img = np.zeros((imgSize, imgSize), np.float32)
        # cv2.circle draws in place and returns the same array: a filled disc of 1s.
        circle_mask = cv2.circle(circle_img, (imgSize // 2, imgSize // 2), imgSize // 2, 1, -1)

        # Vectorised 2-D Gaussian density; rows index i, columns index j.
        offsets = np.arange(imgSize) - imgSize / 2
        variance = sigma ** 2
        gauss = 1 / 2 / np.pi / variance * np.exp(
            -1 / 2 * (offsets[:, None] ** 2 / variance + offsets[None, :] ** 2 / variance))

        heatmap = gauss.astype(np.float32) * circle_mask
        heatmap = (heatmap / np.max(heatmap)).astype(np.float32)
        return (heatmap / np.max(heatmap) * 255).astype(np.uint8)

    # V1
    def create_ellipse_matrix(self, center, radius, shape):
        """Return an int array of 3-D *shape* that is 1 inside an axis-aligned ellipsoid.

        Args:
            center: Ellipsoid centre as ``(i, j, k)`` index coordinates.
            radius: Semi-axis lengths along axes 0, 1 and 2.
            shape: Output shape ``(n0, n1, n2)``.
        """
        x_coords, y_coords, z_coords = np.ogrid[:shape[0], :shape[1], :shape[2]]
        a, b, c = radius[0], radius[1], radius[2]
        # Normalised squared distance; <= 1 means inside the ellipsoid.
        distances = ((x_coords - center[0]) ** 2 / a ** 2
                     + (y_coords - center[1]) ** 2 / b ** 2
                     + (z_coords - center[2]) ** 2 / c ** 2)
        return (distances <= 1).astype(int)

    def fill_and_concatenate(self, a, b):
        """Stack one copy of *b* per value of *a*, with b's non-zero entries set to that value.

        :param a: 1-D numpy array of N fill values.
        :param b: numpy array of any shape; only its non-zero pattern matters.
        :return: Array of shape ``(N,) + b.shape`` and dtype ``b.dtype`` where
            ``output[i]`` is *b* with every non-zero entry replaced by ``a[i]``.
        """
        output = np.zeros((a.shape[0],) + b.shape, dtype=b.dtype)
        for i, value in enumerate(a):
            filled_b = b.copy()
            filled_b[b != 0] = value
            output[i] = filled_b
        return output

    def activate_score_normalize(self, scores):
        """Return *scores* as a float tensor divided by 3.

        NOTE(review): the divisor presumably maps raw speaking scores into
        [0, 1] — confirm against the annotation source.
        """
        return torch.Tensor(scores) / 3

    def find_last_number(self, s):
        """Return the trailing run of digits in *s* as a string, or ``None``."""
        match = re.search(r'\d+$', s)
        if match:
            return match.group(0)
        else:
            return None

    def calculate_normal(self, x, y, z, x1, y1, z1, x2, y2, z2):
        """Return the unit normal of the plane through three 3-D points.

        Raises:
            ValueError: If the points are collinear (zero-length cross product).
        """
        vector1 = np.array([x1 - x, y1 - y, z1 - z])
        vector2 = np.array([x2 - x, y2 - y, z2 - z])

        normal = np.cross(vector1, vector2)

        if np.linalg.norm(normal) == 0:
            raise ValueError("The three points do not define a plane; they are collinear.")

        normal = normal / np.linalg.norm(normal)

        return normal

    def plane_cut_matrix(self, x, y, z, x1, y1, z1, x2, y2, z2, w, h, d):
        """Return a ``(w, h, d)`` int array that is 1 on the negative side of a plane.

        The plane passes through the three given points. Its orientation (and
        thus which side counts as negative) follows the cross-product order in
        ``calculate_normal``.

        Raises:
            ValueError: If any of *w*, *h*, *d* is not positive, or the points
                are collinear.
        """
        if w <= 0 or h <= 0 or d <= 0:
            raise ValueError("Dimensions w, h, and d must be positive integers.")

        normal = self.calculate_normal(x, y, z, x1, y1, z1, x2, y2, z2)

        dd = -normal.dot([x, y, z])

        X, Y, Z = np.meshgrid(np.arange(w), np.arange(h), np.arange(d), indexing='ij')

        # Signed distance to the plane (normal is unit length).
        plane_equation = normal[0] * X + normal[1] * Y + normal[2] * Z + dd

        matrix = (plane_equation < 0).astype(int)

        return matrix

    def generate_masks(self, height, weight, frame_pose_list):
        """Derive per-person bounding boxes and attention masks from keypoints.

        Args:
            height: Attention-mask height.
            weight: Attention-mask width (forwarded as ``width``).
            frame_pose_list: Per frame, a list with one keypoint sequence per
                person. Keypoints are read at fixed indices (5, 6, 23, 31, 39,
                53), so the layout must match the pose files — confirm.

        Returns:
            ``(attention_mask, full_mask, face_mask, lip_mask)`` arrays with the
            frame and person axes swapped to person-first. Box arrays hold
            ``(x1, y1, x2, y2)``; attention masks are flattened ``height*weight``.
        """
        full_mask_list = []
        face_mask_list = []
        lip_mask_list = []
        attention_mask_list = []
        for pose_list in frame_pose_list:
            per_person_full_mask = []
            per_person_face_mask = []
            per_person_lip_mask = []
            per_person_attention_mask = []
            for pose in pose_list:
                ## full mask
                left_point, right_point = pose[6], pose[5]
                length = (right_point[0] - left_point[0]) * 1.2
                left_up_point = [left_point[0], max(left_point[1] - length, 0)]
                right_up_point = [right_point[0], max(right_point[1] - length, 0)]
                keypoint_list = np.array([left_point, right_point, left_up_point, right_up_point, pose[23], pose[31], pose[39], pose[53]])
                bbox = keypoints_to_bbox(keypoint_list)
                full_bbox = scale_bbox(bbox, 1.4, 1.6)
                ## face mask
                points = np.array([pose[23], pose[31], pose[39], pose[53]])
                face_bbox = keypoints_to_bbox(points)
                face_bbox = scale_bbox(face_bbox, 1.4, 1.6)
                ## lip mask
                top_point, down_point = pose[53], pose[31]
                lip_bbox = calculate_square_bbox(top_point, down_point)
                lip_bbox = scale_bbox(lip_bbox, 1.6, 0.8)
                ## nest boxes: face inside full, lip inside face (fall back to the parent box)
                face_bbox = bbox_intersection(face_bbox, full_bbox)
                if face_bbox is None: face_bbox = full_bbox
                lip_bbox = bbox_intersection(lip_bbox, face_bbox)
                if lip_bbox is None: lip_bbox = face_bbox

                per_person_full_mask.append(full_bbox)
                per_person_face_mask.append(face_bbox)
                per_person_lip_mask.append(lip_bbox)
                ## attention mask
                attention_mask_bbox = scale_bbox(full_bbox, 1.2)
                attention_mask = generate_attention_mask(height, weight, attention_mask_bbox)
                per_person_attention_mask.append(attention_mask)
            full_mask_list.append(per_person_full_mask)
            face_mask_list.append(per_person_face_mask)
            lip_mask_list.append(per_person_lip_mask)
            attention_mask_list.append(per_person_attention_mask)
        attention_mask_list, full_mask_list = np.array(attention_mask_list).transpose((1, 0, 2)), np.array(full_mask_list).transpose((1, 0, 2))
        face_mask_list, lip_mask_list = np.array(face_mask_list).transpose((1, 0, 2)), np.array(lip_mask_list).transpose((1, 0, 2))
        return attention_mask_list, full_mask_list, face_mask_list, lip_mask_list

    def _load_image(self, file_path):
        """Open *file_path* and return it as an RGB array resized to ``self.sample_size``."""
        # The context manager releases the file handle once the pixels are copied out.
        with Image.open(file_path) as img:
            return pil_image_to_numpy(img, size=self.sample_size)

    def _match_poses_to_instances(self, pose_dir, pose_names, infos):
        """Per frame, pick for each annotated person the pose with most keypoints in its bbox."""
        frame_pose_list = []
        for pose_name in pose_names:
            with open(os.path.join(pose_dir, pose_name), 'r') as f:
                pose_list = json.load(f)['instance_info']
            pose_in_order = []
            for info in infos:
                instance_bbox = info['instance_bbox']
                within_bbox_count = [count_keypoints_in_bbox(pose['keypoints'], instance_bbox)
                                     for pose in pose_list]
                pose_in_order.append(pose_list[np.argmax(within_bbox_count)]['keypoints'])
            frame_pose_list.append(pose_in_order)
        return frame_pose_list

    def calculate_center_coordinates(self, image_path, pose_path, activate_path, audio_embeding_path, motion_frame_num=1, depth_scale=80, z_ratio=0.3):
        """Load one random clip of ``self.sample_n_frames`` frames from a single video.

        Args:
            image_path: Folder of frames; sorted file names define frame order.
            pose_path: Folder of pose renders using the same file names as the frames.
            activate_path: JSON list with one entry per annotated person, each
                holding ``instance_bbox`` and a per-frame ``speaking_score`` list.
            audio_embeding_path: ``torch.load``-able tensor, sliced along dim 0
                with the same window as the frames.
            motion_frame_num: Number of motion frames to return.
            depth_scale: Unused; kept for call compatibility.
            z_ratio: Unused; kept for call compatibility.

        Returns:
            ``(frames, reference_image, motion_frames, poses, audio_embeding,
            activate_scores, mask)``. Frames, reference and motion frames are
            scaled to [-1, 1]. Poses are scaled to [0, 1]. Speaking scores are
            divided by 3. ``mask`` is a dict of placeholder ``torch.zeros(1)``.
        """
        frame_names = sorted(os.listdir(image_path))
        n_frames = self.sample_n_frames
        # NOTE(review): randint's exclusive bound len - n - 1 means the window
        # never covers the last two frames. This may be a deliberate margin for
        # audio/score arrays of slightly different length — confirm before widening.
        random_index = np.random.randint(0, len(frame_names) - n_frames - 1)
        window = frame_names[random_index:random_index + n_frames]

        frames = [self._load_image(os.path.join(image_path, name)) for name in window]
        poses = [self._load_image(os.path.join(pose_path, name)) for name in window]
        poses = numpy_to_pt(np.array(poses))
        pixel_values = numpy_to_pt(np.array(frames))

        audio_embeding = torch.load(audio_embeding_path)[random_index:random_index + n_frames, ...]

        with open(activate_path, 'r') as f:
            infos = json.load(f)
        # Keypoint JSONs live in the sibling "pose" folder.
        # NOTE: str.replace rewrites every "frames" substring in the path.
        dataset_path = image_path.replace('frames', 'pose')
        pose_names = sorted(name for name in os.listdir(dataset_path) if name.endswith('.json'))
        pose_names = pose_names[random_index:random_index + n_frames]
        # The matching still runs while masks are disabled, so videos with
        # missing/empty pose files keep raising and get skipped by get_batch.
        # Real masks would be self.generate_masks(height, width, frame_pose_list)
        # with height = self.sample_size[1], since sample_size is (width, height).
        frame_pose_list = self._match_poses_to_instances(dataset_path, pose_names, infos)
        mask = dict(attention_mask=torch.zeros(1),
                    full_mask=torch.zeros(1),
                    face_mask=torch.zeros(1),
                    lip_mask=torch.zeros(1))

        activate_scores = self.activate_score_normalize(
            [info['speaking_score'][random_index:random_index + n_frames] for info in infos])

        frames = self.normalize(pixel_values)

        # Reference: a random frame strictly before the clip (frame 0 if the clip starts at 0).
        if random_index == 0:
            reference_index = 0
        else:
            reference_index = np.random.randint(0, random_index)
        reference_image = self._load_image(os.path.join(image_path, frame_names[reference_index]))
        reference_image = self.normalize(numpy_to_pt(np.array(reference_image)))

        # Motion frames walk backwards from the clip start, clamped at frame 0.
        # NOTE(review): the walk starts at random_index itself, so the first
        # target frame is also motion frame 0 — confirm this is intended.
        motion_frames = [
            self._load_image(os.path.join(image_path, frame_names[max(frame_index, 0)]))
            for frame_index in range(random_index, random_index - motion_frame_num, -1)
        ]
        motion_frames = self.normalize(numpy_to_pt(np.array(motion_frames)))
        return frames, reference_image, motion_frames, poses, audio_embeding, activate_scores, mask

    def get_batch(self, idx, max_retries=100):
        """Load the sample for *idx*, substituting random videos when loading fails.

        Args:
            idx: Dataset index, wrapped modulo ``len(self.data_config)`` because
                ``len(self)`` is derived from ``ann_folder``, not ``data_config``.
            max_retries: Maximum number of load attempts before giving up.

        Returns:
            ``(pixel_values, reference_image, motion_frames, motion_values,
            poses, audio_embeding, activate_scores, mask)``.

        Raises:
            IndexError: If ``data_config`` is empty.
            KeyError: If a metadata entry lacks an expected key.
            RuntimeError: If ``max_retries`` consecutive attempts fail. The last
                loading error is chained as ``__cause__``.
        """
        if not self.data_config:
            raise IndexError("data_config is empty")
        num_videos = len(self.data_config)
        last_error = None
        for _ in range(max_retries):
            video_meta = self.data_config[idx % num_videos]
            preprocessed_dir = video_meta["image_dir"]
            pose_frame_path = video_meta["pose_dir"]
            activate_path = video_meta["speak_info_path"]
            audio_embeding_path = video_meta["audio_embedding_path"]

            try:
                (pixel_values, reference_image, motion_frames, poses,
                 audio_embeding, activate_scores, mask) = self.calculate_center_coordinates(
                    preprocessed_dir, pose_frame_path, activate_path, audio_embeding_path)
            except Exception as e:
                # Best effort: an unreadable or too-short video is replaced by a
                # random one drawn from the same list that is being indexed.
                last_error = e
                idx = random.randint(0, num_videos - 1)
                continue
            # NOTE(review): constant for every sample; presumably a motion-bucket
            # conditioning value consumed by the training loop — confirm.
            motion_values = 180
            return (pixel_values, reference_image, motion_frames, motion_values,
                    poses, audio_embeding, activate_scores, mask)
        raise RuntimeError(
            f"get_batch: {max_retries} consecutive samples failed to load") from last_error

    def __len__(self):
        """Number of annotation files found in ``ann_folder``."""
        return self.length

    def coordinates_normalize(self, center_coordinates):
        """Shift a sequence of points so that the first one becomes the origin."""
        first_point = center_coordinates[0]
        return [one - first_point for one in center_coordinates]

    def normalize(self, images):
        """
        Map an image array from [0, 1] to [-1, 1].
        """
        return 2.0 * images - 1.0

    def normalize_sam(self, images):
        """
        Standardise [0, 1] images with the ImageNet mean/std per channel.

        The statistics are shaped ``(1, 3, 1, 1)``, so *images* must have its
        3 channels on the third-from-last axis, e.g. ``(B, 3, H, W)``.
        """
        mean = torch.tensor([0.485, 0.456, 0.406]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        std = torch.tensor([0.229, 0.224, 0.225]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        return (images - mean) / std

    def __getitem__(self, idx):
        """Return one sample as a dict (pose renders are exposed under ``List_2D``)."""
        pixel_values, reference_image, motion_frames, motion_values, poses, audio_embeding, activate_scores, mask = self.get_batch(idx)
        sample = dict(pixel_values=pixel_values,
                      motion_frames=motion_frames,
                      reference_image=reference_image,
                      motion_values=motion_values,
                      audio_embeding=audio_embeding,
                      activate_scores=activate_scores,
                      List_2D=poses,
                      mask=mask)
        return sample



if __name__ == "__main__":
    # Smoke test: build the dataset, pull one batch and dump visualisations to ./vis.
    from tqdm import tqdm

    dataset = Obyssey(
        video_folder="/ai/aidata/VideoGeneration/Motion3D/data/Obyssey",
        ann_folder="/ai/aidata/VideoGeneration/Motion3D/data/Obyssey/Obyssey_Traj_Whole",
        sample_size=(576, 320),
        sample_stride=1, sample_n_frames=15,
    )

    # cv2.imwrite returns False (no exception) when the folder is missing.
    os.makedirs("./vis", exist_ok=True)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=10,)
    for idx, batch in tqdm(enumerate(dataloader)):
        # pixel_values are in [-1, 1] (Obyssey.normalize); map back to [0, 255].
        images = ((batch["pixel_values"][0].permute(0, 2, 3, 1) + 1) / 2) * 255
        # List_2D pose renders are only scaled to [0, 1] by numpy_to_pt.
        List_2D = batch["List_2D"][0].permute(0, 2, 3, 1)
        # The dataset does not currently emit List_3D; only dump it when present.
        List_3D = batch["List_3D"][0].permute(0, 2, 3, 1) if "List_3D" in batch else None

        for i in range(images.shape[0]):
            # Frames were loaded as RGB by PIL; cv2 writes BGR.
            image = cv2.cvtColor(images[i].numpy().astype(np.uint8), cv2.COLOR_RGB2BGR)
            pose = cv2.cvtColor((List_2D[i].numpy() * 255).round().astype(np.uint8),
                                cv2.COLOR_RGB2BGR)

            cv2.imwrite("./vis/image_{}.jpg".format(i), image)
            overlay = (pose * 0.5 + image * 0.5).astype(np.uint8)
            cv2.imwrite("./vis/cc{}.jpg".format(i), overlay)

            if List_3D is not None:
                import mrcfile
                heatmap = List_3D[i].numpy()[::2, ::2, :]
                with mrcfile.new_mmap('./vis/{}.mrc'.format(i), overwrite=True, shape=heatmap.shape, mrc_mode=2) as mrc:
                    mrc.data[:] = heatmap

        break