import os
import cv2
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
import torch
import albumentations as A

import os
import cv2
import json
import numpy as np
import torch
from torch.utils.data import Dataset
import albumentations as A


class ImageSegmentationDataset(Dataset):
    """Dataset yielding resized images together with polyline targets.

    The expected on-disk layout is ``<root_path>/<mode>/images``,
    ``<root_path>/<mode>/annotations`` and
    ``<root_path>/<mode>/all_parameter_mask``.  Images and annotations are
    paired by their position in the sorted directory listings.

    Each annotation is a JSON file with a ``shapes`` list whose entries carry
    a ``label`` and a list of ``points`` (presumably LabelMe-style pixel
    coordinates of the original image -- TODO confirm).  Only labels listed
    in ``class_list`` are used.

    Every kept shape becomes one row of ``target['boxes']``::

        [x0, y0, ..., x{k-1}, y{k-1}, c0, ..., c{degree}]

    where ``k == number_of_points``, the coordinates are divided by the
    original image width/height, and ``c*`` are the polynomial coefficients
    returned by :meth:`fit_polynomial`.

    NOTE(review): ``transforms`` is applied to the image (and an all-zero
    mask) only; the point targets are *not* transformed, so geometric
    augmentations (flip/rotate) will de-align image and targets.
    """

    IMAGE_SIZE = 512  # side length (pixels) every image/mask is resized to
    VALID_MODES = ('train', 'val', 'test')

    def __init__(self, root_path, mode, feature_extractor=None, transforms=None, image_with_mask_flag=False):
        """Index the files of one dataset split.

        Args:
            root_path: Directory containing the ``train``/``val``/``test`` folders.
            mode: One of ``'train'``, ``'val'`` or ``'test'``.
            feature_extractor: Stored on the instance; not used by this class.
            transforms: Optional callable invoked as
                ``transforms(image=..., mask=...)`` returning a mapping with
                an ``'image'`` entry.
            image_with_mask_flag: When true, the parameter mask is appended to
                the image along the channel axis.

        Raises:
            ValueError: If ``mode`` is not a known split name.
        """
        if mode not in self.VALID_MODES:
            raise ValueError("Invalid mode. Choose 'train', 'val', or 'test'.")

        self.feature_extractor = feature_extractor
        self.root_path = root_path
        self.mode = mode

        split_root = os.path.join(root_path, mode)
        self.image_folder = os.path.join(split_root, 'images')
        self.annotation_root_path = os.path.join(split_root, 'annotations')
        self.all_parameter_root_path = os.path.join(split_root, 'all_parameter_mask')

        self.transforms = transforms
        self.image_with_mask_flag = image_with_mask_flag
        self.class_list = ['VilliM', 'CryptM']
        self.class_map = {'VilliM': 1, 'CryptM': 2}
        self.number_of_points = 3
        self.max_length = self.number_of_points

        self.degree = 2
        self.degree_points = self.degree + 1  # a degree-n polynomial has n + 1 coefficients

        # Load image and annotation file names
        self.images = self._list_files(self.image_folder)
        self.annotations = self._list_files(self.annotation_root_path)
        self.all_parameter_mask = self._list_files(self.all_parameter_root_path)

        assert len(self.images) == len(self.annotations), "There must be as many images as there are segmentation maps"

    @staticmethod
    def _list_files(folder):
        """Return the sorted names of the regular files directly inside *folder*."""
        return sorted(
            name for name in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, name))
        )

    def __len__(self):
        """Return the number of images in the split."""
        return len(self.images)

    def get_unique_points(self, points, width, height):
        """Normalise *points* by (*width*, *height*) and drop repeated points.

        Returns a list of ``(x, y)`` tuples in first-occurrence order.
        """
        scaled = ((point[0] / width, point[1] / height) for point in points)
        # dict keys keep insertion order, giving an O(n) order-preserving dedupe.
        return list(dict.fromkeys(scaled))

    def get_unique_IR_points(self, points, width, height):
        """Like :meth:`get_unique_points`, but scaled to the resized image.

        Coordinates are normalised and then multiplied by ``IMAGE_SIZE``.
        """
        size = self.IMAGE_SIZE
        scaled = (((point[0] / width) * size, (point[1] / height) * size) for point in points)
        return list(dict.fromkeys(scaled))

    def interpolate_points(self, points, target_length=4):
        """Resample a polyline to exactly *target_length* points.

        * Too few points: the midpoint of the longest segment is inserted
          repeatedly (a single point is simply repeated).
        * Too many points: both end points are kept and
          ``target_length - 2`` interior points are drawn at random, keeping
          their original order.

        Args:
            points: Sequence of ``(x, y)`` pairs.
            target_length: Number of points to return.

        Returns:
            ``np.ndarray`` of shape ``(target_length, 2)`` and float dtype.

        Raises:
            ValueError: If *points* is empty, or if down-sampling is requested
                with ``target_length < 2`` (both end points must be kept).
        """
        # Float dtype so inserted midpoints are not truncated for integer input.
        points = np.array(points, dtype=np.float64)
        count = len(points)
        if count == target_length:
            return points
        if count == 0:
            raise ValueError("Cannot resample an empty list of points.")

        if count < target_length:
            if count == 1:
                return np.repeat(points, target_length, axis=0)
            while len(points) < target_length:
                # Find the longest segment and insert a midpoint
                segment_lengths = np.linalg.norm(np.diff(points, axis=0), axis=1)
                insert_pos = int(np.argmax(segment_lengths)) + 1
                mid_point = (points[insert_pos - 1] + points[insert_pos]) / 2
                points = np.insert(points, insert_pos, mid_point, axis=0)
            return points

        if target_length < 2:
            raise ValueError("target_length must be at least 2 when down-sampling.")

        # Keep first and last point; sample the interior without replacement and
        # sort the indices so the original ordering along the line is preserved.
        interior_needed = target_length - 2
        interior_idx = np.sort(
            np.random.choice(count - 2, size=interior_needed, replace=False)
        ) + 1
        keep = np.concatenate(([0], interior_idx, [count - 1])).astype(int)
        return points[keep]

    def fit_polynomial(self, points):
        """Fit ``y = p(x)`` of degree ``self.degree`` through *points*.

        NOTE(review): the original file called this method without defining
        it; this least-squares fit is inferred from the call site and from
        ``degree_points == degree + 1`` -- confirm it matches the model's
        expected target.

        Args:
            points: Array-like of shape ``(n, 2)`` holding ``(x, y)`` pairs.

        Returns:
            ``np.ndarray`` of ``self.degree + 1`` float32 coefficients,
            highest power first (``np.polyfit`` ordering).  Ill-conditioned
            input (e.g. repeated x values) makes NumPy emit a ``RankWarning``.
        """
        pts = np.asarray(points, dtype=np.float64).reshape(-1, 2)
        return np.polyfit(pts[:, 0], pts[:, 1], self.degree).astype(np.float32)

    def __getitem__(self, index):
        """Load sample *index*.

        Returns:
            ``(image, target)`` where ``image`` is a float32 tensor in
            channel-first layout and ``target`` is a dict with ``boxes``
            (see class docstring), ``labels`` (int64 class ids),
            ``orig_size`` (``(height, width)`` of the image before resizing)
            and ``image_id``.

        Raises:
            FileNotFoundError: If the image (or, when requested, the
                parameter mask) cannot be read.
        """
        size = self.IMAGE_SIZE

        image_path = os.path.join(self.image_folder, self.images[index])
        image = cv2.imread(image_path)
        if image is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Image arrays are indexed (row, column): shape[0] is height, shape[1] is width.
        original_height, original_width = image.shape[0], image.shape[1]
        image = cv2.resize(image, (size, size))

        # All-zero placeholder mask; only handed to `transforms`.
        segmentation_map = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)

        annotation_name = os.path.splitext(self.annotations[index])[0] + '.json'
        with open(os.path.join(self.annotation_root_path, annotation_name), 'r') as file:
            json_data = json.load(file)

        points_list = []
        class_label_list = []
        polynomial_coeffs_list = []
        for data in json_data['shapes']:
            class_name = data['label']
            if class_name not in self.class_list:
                continue
            coords = self.get_unique_points(data['points'], original_width, original_height)
            if not coords:
                continue  # a shape without points carries no geometry
            if self.number_of_points > 2:
                interpolated_coords = self.interpolate_points(coords, target_length=self.number_of_points)
            else:
                interpolated_coords = [coords[0], coords[-1]]

            points = np.asarray(interpolated_coords, dtype=np.float32).reshape(-1, 2)
            points_list.append(points)
            class_label_list.append(class_name)
            # Fit a polynomial to the points and keep its coefficients
            polynomial_coeffs_list.append(self.fit_polynomial(points))

        # One row per shape, padded with -1: point coordinates first,
        # polynomial coefficients in the trailing `degree_points` columns.
        row_width = self.max_length * 2 + self.degree_points
        fixed_arrays = np.full((len(points_list), row_width), -1, dtype=np.float32)
        for i, (points, coeffs) in enumerate(zip(points_list, polynomial_coeffs_list)):
            flat = points[:self.max_length].reshape(-1)  # never write past the point columns
            fixed_arrays[i, :flat.size] = flat
            fixed_arrays[i, -self.degree_points:] = coeffs

        class_labels = np.array(
            [self.class_map[label_type] for label_type in class_label_list],
            dtype=np.int64,
        )

        if self.image_with_mask_flag:
            # The mask is only needed (and therefore only read) in this mode.
            all_parameter_mask_path = os.path.join(self.all_parameter_root_path, self.all_parameter_mask[index])
            all_parameter_mask = cv2.imread(all_parameter_mask_path)
            if all_parameter_mask is None:
                raise FileNotFoundError(f"Could not read mask: {all_parameter_mask_path}")
            all_parameter_mask = cv2.resize(all_parameter_mask, (size, size))
            image = np.concatenate((image, all_parameter_mask), axis=-1)

        if self.transforms is not None:
            augmented = self.transforms(image=image, mask=segmentation_map)
            image = augmented['image']
        image = np.moveaxis(image, -1, 0)  # channel-last -> channel-first

        target = {}
        image = torch.as_tensor(image, dtype=torch.float32)
        target['boxes'] = torch.as_tensor(fixed_arrays, dtype=torch.float32).reshape(-1, row_width)
        target['labels'] = torch.as_tensor(class_labels, dtype=torch.int64)
        target['orig_size'] = torch.tensor((original_height, original_width), dtype=torch.int64)
        target['image_id'] = torch.tensor(index, dtype=torch.int64)

        return image, target


# Example augmentation pipeline for the dataset: a flip and a rotation (each
# applied with probability 0.5) plus brightness/contrast jitter (probability 0.2).
# NOTE(review): `build_measurement` below does not pass this pipeline to the
# dataset, so it only takes effect if a caller supplies it as `transforms=`.
# NOTE(review): the dataset hands only the image and a mask to the transform,
# so the flip/rotate here would presumably not be applied to the point
# targets -- confirm before enabling.
transform = A.Compose([
    A.Flip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Rotate(limit=30, p=0.5),
])


def build_measurement(mode, args):
    """Create the dataset for the split named by *mode*.

    The dataset root is read from ``args.coco_path``; every other dataset
    option keeps its default value.
    """
    return ImageSegmentationDataset(root_path=args.coco_path, mode=mode)
