import os
import cv2
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
import torch
import albumentations as A

import os
import cv2
import json
import numpy as np
import torch
from torch.utils.data import Dataset
import albumentations as A


class ImageSegmentationDataset(Dataset):
    """Dataset of images paired with point annotations for 'VilliM' / 'CryptM' shapes.

    Expected layout under ``root_path`` (one sub-folder per split)::

        <root_path>/<mode>/images/               input images
        <root_path>/<mode>/annotations/          one JSON file per image
        <root_path>/<mode>/all_parameter_mask/   one auxiliary image per image

    The three folders are paired by sorted filename order (the i-th file of each
    folder forms sample i), not by matching names.

    ``__getitem__`` returns ``(image, target)``:

    * ``image`` -- float32 tensor, channels first, resized to
      ``image_size`` x ``image_size``. The all-parameter mask is appended along
      the channel axis when ``image_with_mask_flag`` is true.
    * ``target['boxes']`` -- float32 tensor of shape
      ``(num_shapes, number_of_points * 2)``. Each row holds flattened ``(x, y)``
      points normalized by the original image width / height.
    * ``target['labels']`` -- int64 class ids taken from ``class_map``.
    * ``target['orig_size']`` -- int64 ``(height, width)`` of the original image.
    * ``target['image_id']`` -- int64 dataset index.
    """

    # Side length (pixels) that images and masks are resized to; overridden per instance.
    image_size = 512

    def __init__(self, root_path, mode, feature_extractor=None, transforms=None,
                 image_with_mask_flag=True, image_size=512):
        """Index the files of one dataset split.

        Args:
            root_path: Dataset root containing the 'train', 'val' and 'test' folders.
            mode: Split to load: 'train', 'val' or 'test'.
            feature_extractor: Stored on the instance; not used by this class.
            transforms: Optional callable invoked as ``transforms(image=..., mask=...)``.
                It must return a dict with an ``'image'`` entry.
            image_with_mask_flag: If true, the all-parameter mask is appended to the
                image along the channel axis.
            image_size: Side length (pixels) that images and masks are resized to.

        Raises:
            ValueError: If ``mode`` is invalid. Also raised if the image and
                annotation file counts differ, or, when ``image_with_mask_flag``
                is set, if the all-parameter mask count differs.
        """
        self.feature_extractor = feature_extractor
        self.root_path = root_path
        self.mode = mode

        if self.mode not in ('train', 'val', 'test'):
            raise ValueError("Invalid mode. Choose 'train', 'val', or 'test'.")
        split_root = os.path.join(root_path, self.mode)
        self.image_folder = os.path.join(split_root, 'images')
        self.annotation_root_path = os.path.join(split_root, 'annotations')
        self.all_parameter_root_path = os.path.join(split_root, 'all_parameter_mask')

        self.transforms = transforms
        self.image_with_mask_flag = image_with_mask_flag
        self.image_size = image_size
        self.class_list = ['VilliM', 'CryptM']
        self.class_map = {'VilliM': 1, 'CryptM': 2}
        self.number_of_points = 3
        self.max_length = self.number_of_points

        # Sorted so that the i-th entries of the three lists form one sample.
        self.images = self._list_files(self.image_folder)
        self.annotations = self._list_files(self.annotation_root_path)
        self.all_parameter_mask = self._list_files(self.all_parameter_root_path)

        # Explicit checks rather than assert, which is stripped under ``python -O``.
        if len(self.images) != len(self.annotations):
            raise ValueError("There must be as many images as there are segmentation maps")
        if self.image_with_mask_flag and len(self.all_parameter_mask) != len(self.images):
            raise ValueError("There must be as many images as there are all_parameter masks")

    @staticmethod
    def _list_files(folder):
        """Return the sorted names of regular files directly inside *folder*."""
        return sorted(
            name for name in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, name))
        )

    @staticmethod
    def _read_image(path):
        """Load *path* with ``cv2.imread``, raising instead of returning ``None``."""
        image = cv2.imread(path)
        if image is None:
            raise OSError(f"Could not read image: {path}")
        return image

    def __len__(self):
        """Return the number of samples (image files) in the split."""
        return len(self.images)

    def get_unique_points(self, points, width, height):
        """Normalize ``(x, y)`` points by ``width`` / ``height`` and drop duplicates.

        Only the first two items of each point are used. First-occurrence order
        is preserved.
        """
        # dict.fromkeys keeps insertion order and deduplicates in O(n).
        return list(dict.fromkeys((point[0] / width, point[1] / height) for point in points))

    def get_unique_IR_points(self, points, width, height):
        """Rescale ``(x, y)`` points to ``image_size`` pixels and drop duplicates.

        First-occurrence order is preserved.
        """
        scale = self.image_size
        return list(dict.fromkeys(
            ((point[0] / width) * scale, (point[1] / height) * scale) for point in points
        ))

    def interpolate_points(self, points, target_length=3):
        """Pick ``target_length`` points spread evenly along ``points``.

        The first and last points are always kept. With the default
        ``target_length=3`` the result is ``[first, points[n // 2], last]``.
        Two points yield ``[p0, p1, p1]`` and a single point is repeated.

        Args:
            points: Sequence of ``(x, y)`` pairs.
            target_length: Number of points to return (>= 1).

        Returns:
            ndarray of shape ``(target_length, 2)``.

        Raises:
            ValueError: If ``points`` is empty or ``target_length`` < 1.
        """
        points = np.asarray(points)
        num_points = len(points)
        if num_points == 0:
            raise ValueError("Cannot interpolate an empty list of points")
        if target_length < 1:
            raise ValueError("target_length must be at least 1")

        # Index i is round-half-up(i * (n - 1) / (target_length - 1)), computed in
        # exact integer arithmetic; for target_length == 3 the middle index is n // 2.
        span = max(target_length - 1, 1)
        steps = np.arange(target_length)
        indices = (2 * steps * (num_points - 1) + span) // (2 * span)
        return points[indices]

    def __getitem__(self, index):
        """Load sample *index* and return ``(image, target)`` (see the class docstring)."""
        size = self.image_size

        image = self._read_image(os.path.join(self.image_folder, self.images[index]))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # OpenCV arrays are (height, width, channels).
        original_height, original_width = image.shape[:2]
        image = cv2.resize(image, (size, size))

        # All-zero placeholder mask: only handed to the transforms; their output mask is unused.
        segmentation_map = np.zeros(image.shape[:2], dtype=np.uint8)

        # splitext keeps dotted stems intact ("a.b.json" -> "a.b"), unlike split('.')[0].
        annotation_stem = os.path.splitext(self.annotations[index])[0]
        annotation_path = os.path.join(self.annotation_root_path, annotation_stem + '.json')
        # Binary mode lets json detect UTF-8/16/32 (and a BOM) instead of using the locale encoding.
        with open(annotation_path, 'rb') as file:
            json_data = json.load(file)

        points_list = []
        class_label_list = []
        for shape in json_data['shapes']:
            class_name = shape['label']
            if class_name not in self.class_list:
                continue
            class_label_list.append(class_name)
            unique_points = self.get_unique_points(shape['points'], original_width, original_height)
            points_list.append(
                self.interpolate_points(unique_points, target_length=self.number_of_points)
            )

        # One row per kept shape; slots without a point stay at -1.
        coord_array = np.full((len(points_list), self.max_length * 2), -1, dtype=np.float32)
        for row, points in zip(coord_array, points_list):
            flat = np.asarray(points, dtype=np.float32)[:self.max_length].reshape(-1)
            row[:flat.size] = flat
        class_labels = np.array(
            [self.class_map[name] for name in class_label_list], dtype=np.int64
        )

        if self.image_with_mask_flag:
            # The auxiliary mask is only loaded when it is actually concatenated.
            mask_path = os.path.join(self.all_parameter_root_path, self.all_parameter_mask[index])
            all_parameter_mask = cv2.resize(self._read_image(mask_path), (size, size))
            image = np.concatenate((image, all_parameter_mask), axis=-1)

        if self.transforms is not None:
            # NOTE(review): only image/mask go through the transforms; coord_array is not
            # updated, so geometric augmentations would misalign the point targets.
            image = self.transforms(image=image, mask=segmentation_map)['image']
        image = np.moveaxis(image, -1, 0)  # HWC -> CHW

        target = {
            'boxes': torch.as_tensor(coord_array, dtype=torch.float32).reshape(
                -1, self.number_of_points * 2
            ),
            'labels': torch.as_tensor(class_labels, dtype=torch.int64),
            'orig_size': torch.tensor((original_height, original_width), dtype=torch.int64),
            'image_id': torch.tensor(index, dtype=torch.int64),
        }
        return torch.as_tensor(image, dtype=torch.float32), target


# Example augmentation pipeline; build_measurement does not pass it to the dataset.
# NOTE(review): Flip and Rotate are geometric, but ImageSegmentationDataset only feeds
# image/mask through its transforms and never moves the point targets, so using this
# pipeline as-is would misalign image and coordinates. A.Flip may also be deprecated
# or removed in recent albumentations releases -- confirm against the installed version.
transform = A.Compose(
    transforms=[
        A.Flip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.Rotate(limit=30, p=0.5),
    ]
)


def build_measurement(mode, args):
    """Create the measurement dataset for one split rooted at ``args.coco_path``.

    Args:
        mode: Split name; the dataset accepts 'train', 'val' or 'test'.
        args: Any object exposing a ``coco_path`` attribute (the dataset root).

    Returns:
        An ``ImageSegmentationDataset`` built with its default options
        (no transforms, all-parameter mask concatenation enabled).
    """
    return ImageSegmentationDataset(mode=mode, root_path=args.coco_path)
