import os
import cv2
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
import torch
import albumentations as A
import random
import os
import cv2
import json
import numpy as np
import torch
from torch.utils.data import Dataset
import albumentations as A


class ImageSegmentationDataset(Dataset):
    """Image dataset with fixed-length key-point annotations.

    Expected layout on disk::

        <root_path>/<mode>/images/              input images
        <root_path>/<mode>/annotations/         one JSON file per image
        <root_path>/<mode>/all_parameter_mask/  auxiliary mask images

    Files of the three folders are paired by their position in the sorted
    directory listing. Each annotation JSON must contain a ``"shapes"`` list
    whose entries carry a ``"label"`` and a list of ``"points"`` given in pixel
    coordinates of the original image (this looks like the LabelMe schema --
    TODO confirm). Only shapes labelled ``VilliM`` or ``CryptM`` are used.

    ``__getitem__`` returns ``(image, target)``:

    * ``image``: float32 tensor, channels first, resized to
      ``transform_image_size`` x ``transform_image_size``. When
      ``image_with_mask_flag`` is set, the channels of the mask image are
      appended after the RGB channels.
    * ``target['boxes']``: float32 tensor of shape ``(N, number_of_points * 2)``
      holding ``x0, y0, x1, y1, ...`` normalised to ``[0, 1]`` by the original
      image width / height.
    * ``target['labels']``: int64 class ids (``VilliM`` -> 1, ``CryptM`` -> 2).
    * ``target['orig_size']``: ``(height, width)`` of the image before resizing.
    * ``target['image_id']``: the sample index.

    In ``'train'`` mode a random horizontal / vertical flip is applied to the
    image, the mask and the points, followed by the optional ``transforms``
    callable (called as ``transforms(image=image)['image']``).
    """

    # Flip type -> sequence of flip codes understood by ``flip_points``
    # (1 = horizontal, 0 = vertical).
    _FLIP_CODES = {
        'none': (),
        'horizontal': (1,),
        'vertical': (0,),
        'BothHV': (1, 0),
    }

    def __init__(self, root_path, mode, feature_extractor=None, transforms=None, image_with_mask_flag=False):
        """Index the files of one dataset split.

        Args:
            root_path: Directory containing the ``train`` / ``val`` / ``test`` folders.
            mode: One of ``'train'``, ``'val'`` or ``'test'``.
            feature_extractor: Stored on the instance; not used by this class.
            transforms: Optional callable applied to training images only.
            image_with_mask_flag: Append the mask channels to the image when True.

        Raises:
            ValueError: If ``mode`` is unknown, if the number of images and
                annotations differ, or if masks are requested but there are
                fewer masks than images.
        """
        super().__init__()
        if mode not in ('train', 'val', 'test'):
            raise ValueError("Invalid mode. Choose 'train', 'val', or 'test'.")

        self.feature_extractor = feature_extractor
        self.root_path = root_path
        self.mode = mode

        self.image_folder = os.path.join(root_path, mode, 'images')
        self.annotation_root_path = os.path.join(root_path, mode, 'annotations')
        self.all_parameter_root_path = os.path.join(root_path, mode, 'all_parameter_mask')

        self.transforms = transforms
        self.image_with_mask_flag = image_with_mask_flag
        self.class_list = ['VilliM', 'CryptM']
        self.class_map = {'VilliM': 1, 'CryptM': 2}
        self.number_of_points = 3
        self.max_length = self.number_of_points
        self.transform_image_size = 512

        self.images = self._list_files(self.image_folder)
        self.annotations = self._list_files(self.annotation_root_path)
        self.all_parameter_mask = self._list_files(self.all_parameter_root_path)

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if len(self.images) != len(self.annotations):
            raise ValueError("There must be as many images as there are segmentation maps")
        if self.image_with_mask_flag and len(self.all_parameter_mask) < len(self.images):
            raise ValueError("There must be at least as many parameter masks as there are images")

    @staticmethod
    def _list_files(folder):
        """Return the sorted names of the regular files directly inside *folder*."""
        return sorted(name for name in os.listdir(folder) if os.path.isfile(os.path.join(folder, name)))

    def __len__(self):
        """Return the number of images in the split."""
        return len(self.images)

    def get_unique_points(self, points, width, height):
        """Normalise points to ``[0, 1]`` and drop duplicates, keeping first-seen order.

        Args:
            points: Iterable of ``(x, y, ...)`` pixel coordinates.
            width: Value every x coordinate is divided by.
            height: Value every y coordinate is divided by.

        Returns:
            List of unique ``(x / width, y / height)`` tuples.
        """
        # ``dict.fromkeys`` de-duplicates in O(n) while preserving insertion order.
        return list(dict.fromkeys((point[0] / width, point[1] / height) for point in points))

    def get_unique_IR_points(self, points, width, height):
        """Like ``get_unique_points`` but scaled to the resized image.

        Returns:
            List of unique ``(x, y)`` tuples expressed in pixels of the
            ``transform_image_size`` square image.
        """
        size = self.transform_image_size
        return list(dict.fromkeys(
            ((point[0] / width) * size, (point[1] / height) * size) for point in points
        ))

    def interpolate_points(self, points, target_length=3):
        """Resample a point sequence to exactly ``target_length`` points.

        * Fewer points than requested: the last point is repeated.
        * Exactly ``target_length`` points: returned unchanged.
        * More points: the first and last points are kept and the remaining
          ones are picked at evenly spaced indices (for ``target_length=3``
          this is ``first, points[len // 2], last``).

        Args:
            points: Non-empty sequence of ``(x, y)`` pairs.
            target_length: Number of points to return (>= 1).

        Returns:
            ``numpy.ndarray`` of shape ``(target_length, 2)``.

        Raises:
            ValueError: If ``points`` is empty or ``target_length`` < 1.
        """
        if target_length < 1:
            raise ValueError("target_length must be at least 1")

        points = np.array(points)
        num_points = len(points)
        if num_points == 0:
            raise ValueError("Cannot resample an empty point sequence")

        if num_points == target_length:
            return points

        if num_points < target_length:
            padding = np.repeat(points[-1:], target_length - num_points, axis=0)
            return np.concatenate([points, padding], axis=0)

        if target_length == 1:
            return points[:1]

        last_index = num_points - 1
        indices = [
            min((step * num_points) // (target_length - 1), last_index)
            for step in range(target_length)
        ]
        return points[indices]

    def flip_points(self, points, img_width, img_height, flip_code):
        """Flip normalised points horizontally (code 1) or vertically (code 0).

        ``img_width`` and ``img_height`` are unused because the points are
        already normalised to ``[0, 1]``; they are kept for the call signature.
        Any other ``flip_code`` returns ``points`` unchanged.
        """
        if flip_code == 1:  # Horizontal flip
            return [(1 - x, y) for (x, y) in points]
        if flip_code == 0:  # Vertical flip
            return [(x, 1 - y) for (x, y) in points]
        return points

    def rotate_points(self, points, angle, img_width, img_height):
        """Rotate normalised points by 90, 180 or 270 degrees.

        The mapping matches ``np.rot90(image, k=angle // 90)``. ``img_width``
        and ``img_height`` are unused (points are normalised). Any other angle
        returns ``points`` unchanged. Not called by ``__getitem__``.
        """
        if angle == 90:
            return [(y, 1 - x) for (x, y) in points]
        if angle == 180:
            return [(1 - x, 1 - y) for (x, y) in points]
        if angle == 270:
            return [(1 - y, x) for (x, y) in points]
        return points

    def _flip_sample(self, flip_type, image, mask, points_list):
        """Apply one named flip to the image, the optional mask and all points.

        The mask is flipped together with the image so that both stay aligned
        when they are later concatenated channel-wise.
        """
        size = self.transform_image_size
        for flip_code in self._FLIP_CODES[flip_type]:
            flip = np.fliplr if flip_code == 1 else np.flipud
            # ``.copy()`` turns the negative-stride view into a plain array.
            image = flip(image).copy()
            if mask is not None:
                mask = flip(mask).copy()
            points_list = [self.flip_points(points, size, size, flip_code) for points in points_list]
        return image, mask, points_list

    def _pack_points(self, points_list):
        """Flatten each point sequence into one row of ``max_length * 2`` floats.

        Rows shorter than ``max_length`` points are padded with -1; longer
        ones are truncated.
        """
        row_width = self.max_length * 2
        packed = np.full((len(points_list), row_width), -1, dtype=np.float32)
        for row, points in zip(packed, points_list):
            flat = np.asarray(points, dtype=np.float32).reshape(-1)[:row_width]
            row[:flat.size] = flat
        return packed

    def _read_resized(self, path):
        """Read an image with OpenCV and resize it to the square network input size.

        Raises:
            FileNotFoundError: If OpenCV cannot read ``path``.
        """
        pixels = cv2.imread(path)
        if pixels is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Could not read image file: {path}")
        return pixels

    def __getitem__(self, index):
        """Load, augment and encode the sample at ``index``.

        Returns:
            Tuple ``(image, target)`` as described in the class docstring.

        Raises:
            FileNotFoundError: If the image (or the requested mask) cannot be read.
        """
        size = self.transform_image_size

        image = self._read_resized(os.path.join(self.image_folder, self.images[index]))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # numpy image layout is (rows, cols, channels) == (height, width, channels).
        original_height, original_width = image.shape[:2]
        image = cv2.resize(image, (size, size))

        # The mask is only needed (and only read) when it is appended to the image.
        all_parameter_mask = None
        if self.image_with_mask_flag:
            all_parameter_mask = self._read_resized(
                os.path.join(self.all_parameter_root_path, self.all_parameter_mask[index])
            )
            all_parameter_mask = cv2.resize(all_parameter_mask, (size, size))

        # ``splitext`` keeps dots that are part of the file stem.
        annotation_name = os.path.splitext(self.annotations[index])[0] + '.json'
        with open(os.path.join(self.annotation_root_path, annotation_name), 'r') as file:
            json_data = json.load(file)

        points_list = []
        class_label_list = []
        for data in json_data['shapes']:
            class_name = data['label']
            if class_name not in self.class_list:
                continue
            unique_points = self.get_unique_points(data['points'], original_width, original_height)
            if not unique_points:
                # A shape without points carries no geometry; skip it (and its label).
                continue
            resampled = self.interpolate_points(unique_points, target_length=self.number_of_points)
            points_list.append(np.asarray(resampled, dtype=np.float32))
            class_label_list.append(class_name)

        if self.mode == 'train':
            flip_type = random.choice(['none', 'horizontal', 'vertical', 'BothHV'])
            image, all_parameter_mask, points_list = self._flip_sample(
                flip_type, image, all_parameter_mask, points_list
            )
            if self.transforms is not None:
                # Only the image is passed, so ``transforms`` must not move pixels.
                image = self.transforms(image=image)['image']

        coord_array = self._pack_points(points_list)
        class_labels = [self.class_map[label_type] for label_type in class_label_list]

        if all_parameter_mask is not None:
            image = np.concatenate((image, all_parameter_mask), axis=-1)

        image = np.moveaxis(image, -1, 0)  # HWC -> CHW

        target = {
            'boxes': torch.as_tensor(coord_array, dtype=torch.float32).reshape(-1, self.number_of_points * 2),
            'labels': torch.as_tensor(class_labels, dtype=torch.int64),
            'orig_size': torch.tensor((original_height, original_width), dtype=torch.int64),
            'image_id': torch.tensor(index, dtype=torch.int64),
        }
        return torch.as_tensor(image, dtype=torch.float32), target


# Colour-only augmentation pipeline. The dataset calls its transforms as
# ``transforms(image=image)`` without the annotated points, so only
# augmentations that leave pixel positions untouched are listed here.
transform = A.Compose(
    [
        A.RGBShift(
            r_shift_limit=25,
            g_shift_limit=25,
            b_shift_limit=25,
            p=0.5,
        ),
        A.RandomBrightnessContrast(
            brightness_limit=0.3,
            contrast_limit=0.3,
            p=0.5,
        ),
    ]
)

def build_measurement(mode, args):
    """Create the measurement dataset for one split.

    Args:
        mode: Split name forwarded to ``ImageSegmentationDataset``.
        args: Object exposing a ``coco_path`` attribute, used as the dataset root.

    Returns:
        An ``ImageSegmentationDataset`` built with ``transforms=None``.
    """
    return ImageSegmentationDataset(
        root_path=args.coco_path,
        mode=mode,
        transforms=None,
    )
