import os
import cv2
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
import torch
import albumentations as A

class ImageSegmentationDataset(Dataset):
    """Image dataset whose targets are fixed-length point sequences from JSON annotations.

    Expected layout under ``root_path``::

        <root_path>/<mode>/images/              input images
        <root_path>/<mode>/annotations/         one JSON annotation per image
        <root_path>/<mode>/all_parameter_mask/  extra mask image per image
                                                (read only when
                                                ``image_with_mask_flag`` is set)

    Files in these folders are paired purely by their position in sorted
    filename order, so each folder must hold exactly one file per sample.

    Each annotation JSON has a ``shapes`` list whose entries carry a ``label``
    and a list of ``points``. Shapes whose label is in ``class_list`` are
    resampled to ``number_of_points`` points and become targets.

    ``__getitem__`` returns ``(image, target)``, where ``target`` contains:

    * ``boxes``: float32 ``(num_shapes, number_of_points * 2)``. Each row is
      ``x0, y0, x1, y1, ...``, divided by the original image width/height.
    * ``labels``: int64 ``(num_shapes,)`` ids taken from ``class_map``.
    * ``orig_size``: int64 ``(height, width)`` of the image as read from disk.
    * ``image_id``: int64 scalar equal to the requested index.
    """

    # Side length (pixels) that every image and mask is resized to.
    IMAGE_SIZE = 512
    _MODES = ('train', 'val', 'test')

    def __init__(self, root_path, mode, feature_extractor=None, transforms=None, image_with_mask_flag=False):
        """Index the files of one split.

        Args:
            root_path: dataset root containing ``train``/``val``/``test`` folders.
            mode: split to load: ``'train'``, ``'val'`` or ``'test'``.
            feature_extractor: stored as an attribute; not used by this class.
            transforms: optional callable invoked as
                ``transforms(image=..., mask=...)`` that returns a dict with an
                ``'image'`` entry (e.g. an ``albumentations.Compose``).
            image_with_mask_flag: if true, the matching ``all_parameter_mask``
                image is concatenated to the RGB image along the channel axis.

        Raises:
            ValueError: if ``mode`` is not one of the three splits.
        """
        self.feature_extractor = feature_extractor
        self.root_path = root_path
        self.mode = mode
        if mode not in self._MODES:
            raise ValueError("Invalid mode. Choose 'train', 'val', or 'test'.")

        split_root = os.path.join(root_path, mode)
        self.image_folder = os.path.join(split_root, 'images')
        self.annotation_root_path = os.path.join(split_root, 'annotations')
        self.all_parameter_root_path = os.path.join(split_root, 'all_parameter_mask')
        self.transforms = transforms

        # All three lists are sorted: samples are paired across folders by
        # index, and raw os.walk order is filesystem-dependent.
        self.images = self._list_files(self.image_folder)
        self.annotations = self._list_files(self.annotation_root_path)
        self.all_parameter_mask = self._list_files(self.all_parameter_root_path)

        self.class_list = ['VilliM', 'CryptM']
        self.class_map = {'VilliM': 1, 'CryptM': 2}
        self.number_of_points = 3
        self.max_length = self.number_of_points
        self.image_with_mask_flag = image_with_mask_flag

        assert len(self.images) == len(self.annotations), "There must be as many images as there are segmentation maps"
        if image_with_mask_flag:
            assert len(self.all_parameter_mask) == len(self.images), \
                "There must be as many parameter masks as there are images"

    @staticmethod
    def _list_files(folder):
        """Return the sorted base names of all files found under ``folder``.

        ``os.walk`` recurses into sub-folders, but only base names are kept, and
        they are later joined to ``folder`` itself. A missing folder yields [].
        """
        return sorted(name for _, _, files in os.walk(folder) for name in files)

    def __len__(self):
        return len(self.images)

    def get_unique_points(self, points, width, height):
        """Return ``(x / width, y / height)`` for each point, dropping exact duplicates.

        First-occurrence order is preserved.
        """
        return list(dict.fromkeys((point[0] / width, point[1] / height) for point in points))

    def get_unique_IR_points(self, points, width, height):
        """Like ``get_unique_points``, but scaled to resized-image pixels (``IMAGE_SIZE``)."""
        size = self.IMAGE_SIZE
        return list(dict.fromkeys(((point[0] / width) * size, (point[1] / height) * size) for point in points))

    def interpolate_points(self, points, target_length=4):
        """Resample a polyline to exactly ``target_length`` points.

        * Same length: returned unchanged (as a new array).
        * Fewer points: the midpoint of the currently longest segment is
          inserted repeatedly (the first segment wins ties) until the length
          matches.
        * More points: both endpoints are kept, plus ``target_length - 2``
          interior points chosen at random via ``np.random``, in their
          original order.

        Args:
            points: sequence of ``(x, y)`` pairs.
            target_length: number of points to return.

        Returns:
            ``np.ndarray`` of shape ``(target_length, 2)``.

        Raises:
            ValueError: if resampling is needed and ``points`` is empty or
                ``target_length < 2``.
        """
        points = np.array(points)
        if len(points) == target_length:
            return points
        if target_length < 2:
            raise ValueError(f"target_length must be at least 2, got {target_length}")
        if len(points) == 0:
            raise ValueError("Cannot interpolate an empty list of points")

        if len(points) < target_length:
            while len(points) < target_length:
                if len(points) == 1:
                    # A single point has no segment; its only "midpoint" is itself.
                    points = np.insert(points, 1, points[0], axis=0)
                    continue
                segment_lengths = np.linalg.norm(np.diff(points, axis=0), axis=1)
                insert_pos = int(np.argmax(segment_lengths)) + 1
                mid_point = (points[insert_pos - 1] + points[insert_pos]) / 2
                points = np.insert(points, insert_pos, mid_point, axis=0)
            return points

        # Sample exactly target_length - 2 interior points, so that together
        # with both endpoints the result always has target_length points.
        # NOTE(review): the sampling is random for every split, so val/test
        # targets can differ between epochs; confirm this is intended.
        interior = points[1:-1]
        keep = np.sort(np.random.choice(len(interior), size=target_length - 2, replace=False))
        return np.concatenate((points[:1], interior[keep], points[-1:]), axis=0)

    @staticmethod
    def _imread(path):
        """Read an image with OpenCV (BGR), raising instead of returning ``None``."""
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return image

    def _load_annotation(self, index):
        """Load the JSON annotation paired (by sorted position) with sample ``index``."""
        # splitext keeps names containing extra dots intact (e.g. "a.v2.json").
        stem = os.path.splitext(self.annotations[index])[0]
        with open(os.path.join(self.annotation_root_path, stem + '.json'), 'r', encoding='utf-8') as file:
            return json.load(file)

    def _parse_shapes(self, json_data, original_width, original_height):
        """Split annotation shapes into target polylines and IR regions.

        Returns:
            points_list: one float32 ``(k, 2)`` array per shape whose label is
                in ``class_list``. Coordinates are divided by the original
                width/height.
            class_label_list: the matching label strings.
            ir_bboxes: one int32 array of points, in resized-image pixels, per
                ``"Interpretable Region"`` shape.
        """
        points_list, class_label_list, ir_bboxes = [], [], []
        for shape in json_data['shapes']:
            label = shape['label']
            if label in self.class_list:
                coords = self.get_unique_points(shape['points'], original_width, original_height)
                if self.number_of_points > 2:
                    coords = self.interpolate_points(coords, target_length=self.number_of_points)
                else:
                    coords = [coords[0], coords[-1]]
                class_label_list.append(label)
                points_list.append(np.asarray(coords, dtype=np.float32))
            if label == "Interpretable Region":
                corners = self.get_unique_IR_points(shape['points'], original_width, original_height)
                ir_bboxes.append(np.array(corners, dtype=np.int32))
        return points_list, class_label_list, ir_bboxes

    def _pack_points(self, points_list):
        """Flatten each point array into one row of a ``-1``-padded float32 matrix.

        Row ``i`` holds ``x0, y0, x1, y1, ...`` for ``points_list[i]``. At most
        ``max_length`` points are kept, and unused slots stay ``-1``.
        """
        packed = np.full((len(points_list), self.max_length * 2), -1, dtype=np.float32)
        for row, points in zip(packed, points_list):  # each row is a view into packed
            if len(points) > self.max_length:
                print("Max_Length Exceeded")
            kept = np.asarray(points, dtype=np.float32)[:self.max_length]
            row[:kept.size] = kept.reshape(-1)
        return packed

    def _build_ir_mask(self, ir_bboxes):
        """Rasterise IR regions into a uint8 ``(IMAGE_SIZE, IMAGE_SIZE)`` mask of 0/1.

        Only the first two points of each region are used, as opposite corners
        of a filled rectangle.
        """
        ir_mask = np.zeros((self.IMAGE_SIZE, self.IMAGE_SIZE), dtype=np.uint8)
        for bbox in ir_bboxes:
            corner_a = (int(bbox[0][0]), int(bbox[0][1]))
            corner_b = (int(bbox[1][0]), int(bbox[1][1]))
            cv2.rectangle(ir_mask, corner_a, corner_b, color=1, thickness=-1)
        return ir_mask

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, target)``.

        ``image`` is a float32 tensor in channel-first layout of size
        ``IMAGE_SIZE`` x ``IMAGE_SIZE``. It has 3 RGB channels, plus the
        parameter mask's channels when ``image_with_mask_flag`` is set. Pixel
        values are not rescaled here. ``target`` is described in the class
        docstring.
        """
        size = self.IMAGE_SIZE
        image = self._imread(os.path.join(self.image_folder, self.images[index]))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # OpenCV arrays are (rows, cols, channels) == (height, width, channels).
        original_height, original_width = image.shape[:2]
        image = cv2.resize(image, (size, size))

        json_data = self._load_annotation(index)
        points_list, class_label_list, ir_bboxes = self._parse_shapes(json_data, original_width, original_height)
        coord_array = self._pack_points(points_list)
        class_labels = np.array([self.class_map[label] for label in class_label_list], dtype=np.int64)

        # NOTE(review): ir_mask is built but never returned in ``target``;
        # either expose it or drop this work.
        ir_mask = self._build_ir_mask(ir_bboxes)

        if self.image_with_mask_flag:
            mask_path = os.path.join(self.all_parameter_root_path, self.all_parameter_mask[index])
            all_parameter_mask = cv2.resize(self._imread(mask_path), (size, size))
            image = np.concatenate((image, all_parameter_mask), axis=-1)

        if self.transforms is not None:
            # NOTE(review): only the image (plus an all-zero placeholder mask)
            # goes through the transform. Geometric augmentations such as flips
            # or rotations would therefore NOT update the point targets.
            placeholder_mask = np.zeros((size, size), dtype=np.uint8)
            image = self.transforms(image=image, mask=placeholder_mask)['image']
        image = np.moveaxis(image, -1, 0)  # HWC -> CHW

        target = {
            'boxes': torch.as_tensor(coord_array, dtype=torch.float32).reshape(-1, self.number_of_points * 2),
            'labels': torch.as_tensor(class_labels, dtype=torch.int64),
            'orig_size': torch.tensor((original_height, original_width), dtype=torch.int64),
            'image_id': torch.tensor(index, dtype=torch.int64),
        }
        return torch.as_tensor(image, dtype=torch.float32), target
    
# Example augmentation pipeline for ImageSegmentationDataset(transforms=...);
# build_measurement does not pass it.
# NOTE(review): Flip and Rotate move pixels, but the dataset only hands the
# image and an all-zero placeholder mask to the transform, so the point
# targets would not follow these geometric changes. Newer albumentations
# releases may also deprecate A.Flip; confirm against the pinned version.
transform = A.Compose(
    transforms=[
        A.Flip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.Rotate(p=0.5, limit=30),
    ],
)


def build_measurement(mode, args):
    """Create an ``ImageSegmentationDataset`` for the requested split.

    Args:
        mode: split name forwarded to the dataset (``'train'``, ``'val'`` or
            ``'test'``).
        args: namespace-like object; only ``args.coco_path`` is read, and it
            is used as the dataset root folder.

    Returns:
        The constructed dataset, with all other options left at their defaults.
    """
    return ImageSegmentationDataset(mode=mode, root_path=args.coco_path)