import os
import cv2
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
import torch
import albumentations as A
import random
import os
import cv2
import json
import random
import numpy as np
import torch
from torch.utils.data import Dataset
import albumentations as A


class ImageSegmentationDataset(Dataset):
    """Image segmentation dataset of polyline annotations plus SegFormer masks.

    Expected layout under ``<root_path>/<mode>/``:

    * ``images/``: input images.
    * ``annotations/``: JSON files with a ``shapes`` list; each shape has a
      ``label`` and a list of pixel ``points``.
    * ``segformer_prediction_clean/``: "strong" prediction masks (``.png``).
    * ``segformer_predictions_b1_ep50/``: "weak" prediction masks (``.png``).

    Files in the four folders are paired by their position in the sorted
    directory listing, so every folder must hold matching, identically-sorted
    file sets.

    Each item is ``(image, mask_strong, mask_weak, target)``: channel-first
    float32 tensors resized to ``transform_image_size`` and a target dict with
    ``boxes`` (one row of ``number_of_points`` normalised ``(x, y)`` pairs per
    kept shape, padded with -1), ``labels``, ``orig_size`` (the original
    ``(height, width)``) and ``image_id``.
    """

    def __init__(self, root_path, mode, feature_extractor=None, transforms=None, image_with_mask_flag=False):
        """Index the files of one split.

        Args:
            root_path: dataset root containing ``train``/``val``/``test`` folders.
            mode: split to load, one of ``'train'``, ``'val'``, ``'test'``.
            feature_extractor: stored as an attribute; not used by this class.
            transforms: optional callable used as ``transforms(image=img)['image']``
                on the image only, and only in train mode.
            image_with_mask_flag: if True, the strong mask is concatenated to the
                image along the channel axis.

        Raises:
            ValueError: if ``mode`` is not a known split.
            AssertionError: if the image and annotation counts differ.
        """
        self.feature_extractor = feature_extractor
        self.root_path = root_path
        self.mode = mode

        if self.mode not in ('train', 'val', 'test'):
            raise ValueError("Invalid mode. Choose 'train', 'val', or 'test'.")

        split_root = os.path.join(root_path, self.mode)
        self.image_folder = os.path.join(split_root, 'images')
        self.annotation_root_path = os.path.join(split_root, 'annotations')
        self.segformer_predictions_root_path_strong = os.path.join(split_root, 'segformer_prediction_clean')
        self.segformer_predictions_root_path_weak = os.path.join(split_root, 'segformer_predictions_b1_ep50')

        self.transforms = transforms
        self.image_with_mask_flag = image_with_mask_flag
        self.class_list = ['VilliM', 'CryptM']
        self.class_map = {'VilliM': 1, 'CryptM': 2}
        self.number_of_points = 3
        self.max_length = self.number_of_points
        self.transform_image_size = 512
        self.random_drop_mask = 0.1  # probability of zeroing both masks when drop_mask is on
        self.drop_mask = False
        self.mixup = True

        # Listing is identical for every split.
        self.images = self._list_files(self.image_folder)
        self.annotations = self._list_files(self.annotation_root_path)
        self.segformer_predictions_mask_strong = self._list_files(self.segformer_predictions_root_path_strong)
        self.segformer_predictions_mask_weak = self._list_files(self.segformer_predictions_root_path_weak)

        assert len(self.images) == len(self.annotations), "There must be as many images as there are segmentation maps"

    @staticmethod
    def _list_files(folder):
        """Return the sorted names of the regular files directly inside ``folder``."""
        return sorted(f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f)))

    def __len__(self):
        return len(self.images)

    def get_unique_points(self, points, width, height):
        """Normalise pixel points by ``(width, height)`` and drop exact duplicates.

        First-seen order is kept. Returns a list of ``(x, y)`` tuples.
        """
        normalised = ((point[0] / width, point[1] / height) for point in points)
        return list(dict.fromkeys(normalised))

    def get_unique_IR_points(self, points, width, height):
        """Rescale pixel points to the ``transform_image_size`` grid and drop exact duplicates.

        First-seen order is kept. Returns a list of ``(x, y)`` tuples.
        """
        size = self.transform_image_size
        scaled = (((point[0] / width) * size, (point[1] / height) * size) for point in points)
        return list(dict.fromkeys(scaled))

    def interpolate_points(self, points, target_length=3):
        """Reduce or expand a polyline to exactly three points.

        * 1 point   -> that point repeated three times (degenerate line).
        * 2 points  -> start, midpoint, end.
        * 3 points  -> returned unchanged.
        * >3 points -> first, middle (index ``n // 2``) and last point.

        ``target_length`` is accepted for API compatibility; the output always
        has three points.

        Raises:
            ValueError: if ``points`` is empty.
        """
        points = np.array(points)
        num_points = len(points)

        if num_points == 0:
            raise ValueError("Cannot interpolate an empty point list.")
        if num_points == 1:
            return np.repeat(points, 3, axis=0)
        if num_points == 2:
            midpoint = (points[0] + points[1]) / 2
            return np.array([points[0], midpoint, points[1]])
        if num_points == 3:
            return points
        return np.array([points[0], points[num_points // 2], points[-1]])

    def flip_points(self, points, img_width, img_height, flip_code):
        """Flip normalised ``(x, y)`` points.

        ``flip_code`` 1 flips horizontally and 0 flips vertically. Any other
        value returns ``points`` unchanged. ``img_width``/``img_height`` are
        unused because the points are expected in [0, 1] coordinates.
        """
        if flip_code == 1:  # Horizontal flip
            return [(1 - x, y) for (x, y) in points]
        elif flip_code == 0:  # Vertical flip
            return [(x, 1 - y) for (x, y) in points]
        return points

    def rotate_points(self, points, angle, img_width, img_height):
        """Rotate normalised ``(x, y)`` points counter-clockwise by 90, 180 or 270 degrees.

        This matches :meth:`rotate_image` for square images. Any other angle
        returns ``points`` unchanged. ``img_width``/``img_height`` are unused.
        """
        if angle == 90:
            return [(y, 1 - x) for (x, y) in points]
        elif angle == 180:
            return [(1 - x, 1 - y) for (x, y) in points]
        elif angle == 270:
            return [(1 - y, x) for (x, y) in points]
        return points

    def rotate_image(self, image, angle):
        """Rotate ``image`` counter-clockwise by 90, 180 or 270 degrees.

        The rotation is an exact index permutation over the first two axes, so
        there is no interpolation and no border shift. The result is a
        contiguous copy, which is safe for ``torch.as_tensor``. Any other angle
        returns ``image`` itself.
        """
        if angle not in (90, 180, 270):
            return image
        return np.ascontiguousarray(np.rot90(image, k=int(angle) // 90))

    def apply_mixup(self, img1, img2, mask1, mask2, alpha=0.4):
        """Blend two images and two masks with one ``Beta(alpha, alpha)`` weight.

        Returns ``(mixed_img, mixed_mask, lam)``.
        """
        lam = np.random.beta(alpha, alpha)
        mixed_img = lam * img1 + (1 - lam) * img2
        mixed_mask = lam * mask1 + (1 - lam) * mask2
        return mixed_img, mixed_mask, lam

    @staticmethod
    def _read_image(path):
        """Read ``path`` with OpenCV, raising instead of silently returning ``None``."""
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return image

    def _read_mask(self, folder, file_name):
        """Load the ``.png`` mask sharing ``file_name``'s stem, resized to the working size."""
        path = os.path.join(folder, os.path.splitext(file_name)[0] + '.png')
        size = self.transform_image_size
        return cv2.resize(self._read_image(path), (size, size))

    def _parse_shapes(self, json_data, width, height):
        """Collect point sets and labels for shapes whose label is in ``class_list``.

        Each point set is a ``(3, 2)`` float32 array normalised by
        ``(width, height)``. Shapes without any points are skipped, because
        they cannot be encoded.
        """
        points_list = []
        class_label_list = []
        for shape in json_data['shapes']:
            class_name = shape['label']
            if class_name not in self.class_list:
                continue
            unique_points = self.get_unique_points(shape['points'], width, height)
            if not unique_points:
                continue
            interpolated = self.interpolate_points(unique_points, target_length=self.number_of_points)
            points_list.append(np.asarray(interpolated, dtype=np.float32))
            class_label_list.append(class_name)
        return points_list, class_label_list

    def _random_geometric_augment(self, arrays, points_list):
        """Apply one random right-angle rotation and then one random flip.

        The same transform is applied to every array in ``arrays`` (image and
        masks) and to every normalised point set, so they stay aligned.
        """
        size = self.transform_image_size
        rotation_angle = random.choice([0, 90, 180, 270])
        if rotation_angle != 0:
            arrays = [self.rotate_image(a, rotation_angle) for a in arrays]
            points_list = [self.rotate_points(p, rotation_angle, size, size) for p in points_list]

        flip_type = random.choice(['none', 'horizontal', 'vertical', 'BothHV'])
        if flip_type in ('horizontal', 'BothHV'):
            arrays = [np.fliplr(a).copy() for a in arrays]
            points_list = [self.flip_points(p, size, size, 1) for p in points_list]
        if flip_type in ('vertical', 'BothHV'):
            arrays = [np.flipud(a).copy() for a in arrays]
            points_list = [self.flip_points(p, size, size, 0) for p in points_list]
        return arrays, points_list

    def _encode_points(self, points_list):
        """Flatten each point set into one row of ``max_length`` (x, y) pairs, padding with -1."""
        row_len = self.max_length * 2
        fixed_arrays = np.full((len(points_list), row_len), -1, dtype=np.float32)
        for row, points in zip(fixed_arrays, points_list):
            flat = np.asarray(points, dtype=np.float32).reshape(-1)
            if flat.size > row_len:
                print("Max_Length Exceeded")
                flat = flat[:row_len]
            row[:flat.size] = flat
        return fixed_arrays.reshape(-1, self.number_of_points * 2)

    def __getitem__(self, index):
        """Load, augment (train only) and encode sample ``index``.

        Returns:
            ``(image, mask_strong, mask_weak, target)`` as described in the
            class docstring.

        Raises:
            FileNotFoundError: if the image, annotation or a mask cannot be read.
        """
        size = self.transform_image_size

        image = self._read_image(os.path.join(self.image_folder, self.images[index]))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # NumPy image shape is (rows, cols, channels) == (height, width, channels).
        original_height, original_width = image.shape[:2]
        image = cv2.resize(image, (size, size))

        annotation_name = os.path.splitext(self.annotations[index])[0] + '.json'
        with open(os.path.join(self.annotation_root_path, annotation_name), 'r') as file:
            json_data = json.load(file)

        # Short-circuiting keeps the RNG untouched unless dropping is enabled in train mode.
        if self.drop_mask and self.mode == 'train' and random.random() < self.random_drop_mask:
            mask_strong = np.zeros((size, size, 3))
            mask_weak = np.zeros((size, size, 3))
        else:
            mask_strong = self._read_mask(self.segformer_predictions_root_path_strong,
                                          self.segformer_predictions_mask_strong[index])
            mask_weak = self._read_mask(self.segformer_predictions_root_path_weak,
                                        self.segformer_predictions_mask_weak[index])

        if self.mode == 'train' and self.mixup:
            # Replace the strong mask with a Beta(0.2, 0.2)-weighted blend of both masks.
            lam = np.random.beta(0.2, 0.2)
            mask_strong = lam * mask_strong + (1 - lam) * mask_weak

        points_list, class_label_list = self._parse_shapes(json_data, original_width, original_height)

        if self.mode == 'train':
            (image, mask_strong, mask_weak), points_list = self._random_geometric_augment(
                [image, mask_strong, mask_weak], points_list)
            if self.transforms is not None:
                image = self.transforms(image=image)['image']

        coord_array = self._encode_points(points_list)
        class_labels = np.array([self.class_map[label] for label in class_label_list], dtype=np.int64)

        if self.image_with_mask_flag:
            # NOTE(review): assumes the strong mask is the intended extra input channels — confirm.
            image = np.concatenate((image, mask_strong), axis=-1)

        image = torch.as_tensor(np.moveaxis(image, -1, 0), dtype=torch.float32)
        mask_strong = torch.as_tensor(np.moveaxis(mask_strong, -1, 0), dtype=torch.float32)
        mask_weak = torch.as_tensor(np.moveaxis(mask_weak, -1, 0), dtype=torch.float32)

        target = {
            'boxes': torch.as_tensor(coord_array, dtype=torch.float32).reshape(-1, self.number_of_points * 2),
            'labels': torch.as_tensor(class_labels, dtype=torch.int64),
            # (height, width) of the image before resizing.
            'orig_size': torch.tensor((original_height, original_width), dtype=torch.int64),
            'image_id': torch.tensor(index, dtype=torch.int64),
        }
        return image, mask_strong, mask_weak, target


# Photometric-only albumentations pipeline: a random RGB shift and a random
# brightness/contrast change, each applied with probability 0.5.
# build_measurement below passes transforms=None, so that builder does not use it.
transform = A.Compose(
    [
        A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
    ]
)

def build_measurement(mode, args):
    """Build the dataset split ``mode`` rooted at ``args.coco_path``.

    Args:
        mode: split name forwarded to ``ImageSegmentationDataset``.
        args: namespace-like object providing ``coco_path``.

    Returns:
        An ``ImageSegmentationDataset`` built without an albumentations transform.
    """
    return ImageSegmentationDataset(root_path=args.coco_path, mode=mode, transforms=None)
