import os
import numpy as np
import torch
import cv2
from PIL import Image, ImageEnhance, ImageOps
import torch.nn.functional as F
import torchvision.transforms as transforms
import random


class ImageLoader:
    """Load images from disk as float CHW tensors with values clamped to [0, 1].

    Two kinds of input are supported, selected by file extension:

    * ``.npz`` archives: the array stored under the key ``arr_0`` is used.
    * anything else: the file is opened with PIL and converted to RGB
      (RGBA images are composited onto a white background).

    Args:
        target_size: Side length (pixels) of the square output of :meth:`load`.
        augmentation_config: Optional dict. Augmentation transforms are only
            built when ``enable`` is truthy. Recognised keys: ``brightness``,
            ``contrast``, ``saturation``, ``hue``, ``rotation``,
            ``horizontal_flip``.
    """

    def __init__(self, target_size=224, augmentation_config=None):
        self.target_size = target_size
        self.augmentation_config = augmentation_config or {}

        # Base transform: resize to a square and convert to a tensor.
        self.base_transform = transforms.Compose([
            transforms.Resize((target_size, target_size), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
        ])

        config = self.augmentation_config
        if not config.get('enable', False):
            # Augmentation disabled: both pipelines are the same object.
            self.augment_transform = self.base_transform
            return

        augment_list = []

        # Color enhancement. NOTE: the whole ColorJitter step is gated on
        # 'brightness' alone; contrast/saturation/hue only take effect when
        # brightness > 0.
        if config.get('brightness', 0) > 0:
            augment_list.append(
                transforms.ColorJitter(
                    brightness=config.get('brightness', 0.1),
                    contrast=config.get('contrast', 0.1),
                    saturation=config.get('saturation', 0.1),
                    hue=config.get('hue', 0.05)
                )
            )

        # Geometric transformations.
        if config.get('rotation', 0) > 0:
            augment_list.append(
                transforms.RandomRotation(
                    degrees=config.get('rotation', 5),
                    interpolation=transforms.InterpolationMode.BICUBIC
                )
            )

        if config.get('horizontal_flip', 0) > 0:
            augment_list.append(
                transforms.RandomHorizontalFlip(
                    p=config.get('horizontal_flip', 0.5)
                )
            )

        # Augmentations run at the source resolution; resize happens last.
        augment_list.extend([
            transforms.Resize((target_size, target_size), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
        ])

        self.augment_transform = transforms.Compose(augment_list)

    @staticmethod
    def _read_npz_chw(path):
        """Read ``arr_0`` from an .npz archive as a float32 array, CHW where recognised.

        Recognised layouts (checked in this order): ``(H, W)``, ``(H, W, 3)``,
        ``(3, H, W)``, ``(H, W, 1)`` and ``(1, H, W)``. Single-channel inputs
        are replicated to three channels. Any other layout is returned
        unchanged so the caller's own error handling decides what to do.

        Values are divided by 255 when the array maximum exceeds 1.0 (a
        heuristic for 8-bit data); otherwise they are only cast to float32.
        """
        # The context manager closes the underlying zip file; the indexed
        # array is fully read into memory before the archive is closed.
        with np.load(path) as archive:
            arr = archive['arr_0']

        if arr.ndim == 2:
            arr = np.stack([arr] * 3, axis=0)
        elif arr.ndim == 3:
            if arr.shape[2] == 3:
                arr = arr.transpose(2, 0, 1)  # HWC -> CHW
            elif arr.shape[0] == 3:
                pass  # Already CHW
            elif arr.shape[2] == 1:
                arr = np.stack([arr.squeeze(2)] * 3, axis=0)
            elif arr.shape[0] == 1:
                arr = np.repeat(arr, 3, axis=0)

        # Decide on scaling before the cast so the test sees the raw values.
        needs_scaling = arr.max() > 1.0
        arr = arr.astype(np.float32)
        if needs_scaling:
            arr = arr / 255.0
        return np.ascontiguousarray(arr)

    @staticmethod
    def _open_rgb(path):
        """Open an image file with PIL and return an in-memory RGB image.

        The file handle is released before returning. RGBA images are pasted
        onto a white background using their alpha channel as the mask; every
        other mode goes through ``convert('RGB')`` (which yields a copy when
        the image is already RGB).
        """
        with Image.open(path) as img:
            if img.mode == 'RGBA':
                background = Image.new('RGB', img.size, (255, 255, 255))
                background.paste(img, mask=img.split()[-1])
                return background
            return img.convert('RGB')

    def load(self, path, augment=False):
        """Load an image and return it as a ``(3, target_size, target_size)`` tensor.

        Args:
            path: Image path. ``.npz`` files are read with NumPy, all other
                extensions with PIL.
            augment: Use ``augment_transform`` instead of ``base_transform``.
                Only applies to PIL-readable files; ``.npz`` inputs are
                resized bilinearly and never augmented.

        Returns:
            Float tensor clamped to [0, 1]. On any error, or if the image
            contains invalid values, an all-zero tensor of shape
            ``(3, target_size, target_size)`` is returned and a message is
            printed.
        """
        try:
            ext = os.path.splitext(path)[1].lower()

            if ext == '.npz':
                image = torch.from_numpy(self._read_npz_chw(path)).float()

                # Resize only when needed.
                if image.shape[1] != self.target_size or image.shape[2] != self.target_size:
                    image = F.interpolate(
                        image.unsqueeze(0),
                        size=(self.target_size, self.target_size),
                        mode='bilinear',
                        align_corners=False
                    ).squeeze(0)
            else:
                transform = self.augment_transform if augment else self.base_transform
                image = transform(self._open_rgb(path))

            # Ensure values are in [0, 1]. Clamping also maps +/-inf into the
            # range, so in practice only NaN survives to the check below.
            image = torch.clamp(image, 0.0, 1.0)

            if torch.isnan(image).any() or torch.isinf(image).any():
                print(f"Warning: Invalid values detected in image {path}")
                return torch.zeros(3, self.target_size, self.target_size)

            return image

        except Exception as e:
            print(f"Error loading image {path}: {e}")
            # Best-effort fallback so a single bad file does not abort a run.
            return torch.zeros(3, self.target_size, self.target_size)

    def load_raw(self, path):
        """Load an image at its original resolution (no resizing, no augmentation).

        Args:
            path: Image path; same format handling as :meth:`load`.

        Returns:
            Float tensor clamped to [0, 1], CHW for recognised layouts. On
            error an all-zero ``(3, 256, 256)`` tensor is returned and a
            message is printed.
        """
        try:
            ext = os.path.splitext(path)[1].lower()

            if ext == '.npz':
                img = torch.from_numpy(self._read_npz_chw(path)).float()
            else:
                arr = np.array(self._open_rgb(path)).transpose(2, 0, 1)
                img = torch.from_numpy(arr).float() / 255.0

            return torch.clamp(img, 0.0, 1.0)

        except Exception as e:
            print(f"Error loading raw image {path}: {e}")
            return torch.zeros(3, 256, 256)  # Default size


class ROIQualityAnalyzer:
    """Label ROI images by detection quality and map those labels to weights.

    Args:
        detection_threshold: Box-area ratio below which an 'Abnormal' ROI is
            treated as a precise, region-level detection.
    """

    # Weight per quality label; anything not listed falls back to 0.3.
    _QUALITY_WEIGHTS = {
        'abnormal_region_detection': 1.0,  # Abnormal, small box (precise detection)
        'abnormal_global_detection': 0.7,  # Abnormal, large box (global detection)
        'normal_global': 0.5,  # Normal (essentially the whole image)
        'no_roi': 0.3,  # No ROI available
        'failed': 0.1,  # Image could not be loaded / analysed
    }

    def __init__(self, detection_threshold=0.7):
        self.detection_threshold = detection_threshold

    def analyze_roi_quality(self, roi_path, roi_type):
        """Classify the detection quality of the ROI image at ``roi_path``.

        Args:
            roi_path: Path of the ROI image, read with ``cv2.imread``.
            roi_type: ``'Abnormal'``; every other value is handled as normal.

        Returns:
            One of ``'abnormal_region_detection'``,
            ``'abnormal_global_detection'``, ``'normal_global'`` or
            ``'failed'`` (unreadable image or any exception).
        """
        try:
            image = cv2.imread(roi_path)
            if image is None:
                return 'failed'

            # The ratio is computed for every readable image, although only
            # the 'Abnormal' branch consults it.
            ratio = self._estimate_detection_box_ratio(image)

            if roi_type != 'Abnormal':
                return 'normal_global'
            if ratio < self.detection_threshold:
                return 'abnormal_region_detection'
            return 'abnormal_global_detection'

        except Exception as e:
            print(f"ROI analysis error {roi_path}: {e}")
            return 'failed'

    def _estimate_detection_box_ratio(self, img):
        """Estimate the largest contour bounding box area relative to the image area.

        Runs Canny edge detection on the grayscale image, takes the external
        contours and returns ``largest bounding-rect area / image area``.
        Returns 1.0 when no contour is found, when the image has zero area,
        or when any step raises.
        """
        try:
            grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            edge_map = cv2.Canny(grayscale, 50, 150)
            contours, _ = cv2.findContours(edge_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            if not contours:
                return 1.0  # Nothing detected: treat as whole image

            box_areas = (w * h for _, _, w, h in map(cv2.boundingRect, contours))
            largest_area = max(box_areas, default=0)

            image_area = img.shape[0] * img.shape[1]
            if image_area > 0:
                return largest_area / image_area
            return 1.0

        except Exception as e:
            print(f"Detection box ratio estimation error: {e}")
            return 1.0

    def get_quality_weight(self, quality_type):
        """Return the weight for ``quality_type`` (0.3 for unknown labels)."""
        return self._QUALITY_WEIGHTS.get(quality_type, 0.3)


def crop_and_resize_patch(image, x0, y0, x1, y1, size):
    """Crop the box ``[x0, y0, x1, y1]`` from a C×H×W tensor and resize it to ``size``×``size``.

    Coordinates may be ints or floats; the start corner is floored and the
    end corner is ceiled so a fractional box is fully covered. The box is
    then clamped to the image, and a degenerate or inverted box is widened to
    at least one pixel in each dimension (so the crop is never empty for a
    non-empty image).

    Args:
        image: Tensor of shape ``(C, H, W)``; bilinear interpolation is
            applied to it, so a floating-point dtype is expected.
        x0, y0: Top-left corner (inclusive).
        x1, y1: Bottom-right corner (exclusive).
        size: Side length of the square output patch.

    Returns:
        Tensor of shape ``(C, size, size)``. If anything fails, an all-zero
        tensor of shape ``(image.shape[0], size, size)`` is returned and a
        message is printed.
    """
    import math

    try:
        _, height, width = image.shape

        # Slicing needs integers; float boxes previously raised and ended up
        # in the all-zero fallback.
        x0 = int(math.floor(float(x0)))
        y0 = int(math.floor(float(y0)))
        x1 = int(math.ceil(float(x1)))
        y1 = int(math.ceil(float(y1)))

        # Clamp to the image and guarantee x1 > x0 and y1 > y0. Because of
        # this guarantee no further validity check is needed.
        x0 = max(0, min(x0, width - 1))
        y0 = max(0, min(y0, height - 1))
        x1 = max(x0 + 1, min(x1, width))
        y1 = max(y0 + 1, min(y1, height))

        patch = image[:, y0:y1, x0:x1]

        # interpolate expects a batch dimension: add it, resize, drop it.
        return F.interpolate(
            patch.unsqueeze(0),
            size=(size, size),
            mode='bilinear',
            align_corners=False
        ).squeeze(0)

    except Exception as e:
        print(f"Error cropping patch: {e}")
        return torch.zeros(image.shape[0], size, size)


def validate_text_inputs(tokenized_text, max_length):
    """Check that a tokenizer output looks well-formed.

    The output is accepted when it contains both ``input_ids`` and
    ``attention_mask``, the two have identical shapes, their last dimension
    equals ``max_length``, and neither contains a negative value.

    Args:
        tokenized_text: Mapping-like tokenizer output holding tensors.
        max_length: Required size of the last (sequence) dimension.

    Returns:
        ``True`` if every check passes, ``False`` otherwise.
    """
    required_keys = ('input_ids', 'attention_mask')
    if any(key not in tokenized_text for key in required_keys):
        return False

    ids = tokenized_text['input_ids']
    mask = tokenized_text['attention_mask']

    # Shapes must agree, and the sequence dimension must be fully padded.
    if not (ids.shape == mask.shape) or ids.shape[-1] != max_length:
        return False

    # Token ids and mask entries are never negative in a valid encoding.
    if (ids < 0).any() or (mask < 0).any():
        return False

    return True


def safe_tokenize(tokenizer, text, max_length=128, return_tensors='pt'):
    """Tokenize ``text`` with padding/truncation, falling back to empty text on failure.

    Args:
        tokenizer: Callable accepting ``(text, max_length=, padding=,
            truncation=, return_tensors=)``.
        text: Text to tokenize. Falsy or non-string values are replaced by
            the empty string.
        max_length: Target sequence length; the text is also cut to
            ``max_length * 10`` characters beforehand as a rough guard
            against very long inputs.
        return_tensors: Passed through to the tokenizer.

    Returns:
        The tokenizer output for ``text``. If tokenization raises, or a
        ``'pt'`` result fails ``validate_text_inputs``, the tokenizer output
        for the empty string is returned instead and a message is printed.
        An exception raised while encoding the empty string propagates.
    """

    def encode(value):
        # Single place for the tokenizer arguments shared by all code paths.
        return tokenizer(
            value,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors=return_tensors
        )

    try:
        if not text or not isinstance(text, str):
            text = ""

        # Limit text length to avoid overly long input (rough estimate).
        char_budget = max_length * 10
        if len(text) > char_budget:
            text = text[:char_budget]

        result = encode(text)

        # validate_text_inputs relies on torch tensor operations, so it is
        # only meaningful for 'pt' output. Running it on lists/NumPy arrays
        # raised, and the broad handler below then replaced the real text
        # with the empty-string encoding.
        if return_tensors == 'pt' and not validate_text_inputs(result, max_length):
            print("Warning: Tokenization result invalid, using empty text")
            result = encode("")

        return result

    except Exception as e:
        print(f"Tokenization error: {e}, using empty text")
        return encode("")


def _empty_token_batch(tokenizer, max_length):
    """Return a zero-row ``[0, max_length]`` batch whose dtypes match the tokenizer output."""
    # The empty string is tokenized only to learn the dtypes the tokenizer uses.
    probe = tokenizer("", max_length=max_length, padding='max_length', truncation=True,
                      return_tensors='pt')
    return {
        'input_ids': torch.empty(0, max_length, dtype=probe['input_ids'].dtype),
        'attention_mask': torch.empty(0, max_length, dtype=probe['attention_mask'].dtype)
    }


def process_fixed_negative_samples(negative_list, tokenizer, max_length, expected_count=5, allow_empty=False):
    """Tokenize negative texts into a batch with a fixed number of rows.

    Processing steps:

    1. ``negative_list`` is normalised: a string becomes a one-element list,
       lists and tuples are used as given, anything else is treated as empty.
    2. If the normalised list is empty and ``allow_empty`` is set (the
       "No Finding" case), an empty ``[0, max_length]`` batch is returned.
    3. ``None`` entries and entries that are blank after ``str(...).strip()``
       are dropped.
    4. If nothing remains and ``expected_count > 0``, the default text
       ``"No abnormality detected"`` is used ``expected_count`` times.
       Fewer entries than ``expected_count`` are padded by cycling through
       them; more are truncated to the first ``expected_count``.

    Args:
        negative_list: Negative sample texts (list, tuple or single string).
        tokenizer: Callable accepting ``(text_or_texts, max_length=,
            padding=, truncation=, return_tensors=)`` and returning a mapping
            with ``input_ids`` and ``attention_mask``.
        max_length: Padded sequence length.
        expected_count: Number of rows to return.
        allow_empty: Return an empty batch for empty input instead of defaults.

    Returns:
        Dict with ``input_ids`` and ``attention_mask`` of shape
        ``[N, max_length]``. If processing raises, the error is printed and a
        batch of ``expected_count`` default texts (or an empty batch when
        ``expected_count == 0``) is returned.
    """
    default_negative = "No abnormality detected"

    try:
        # Normalise the input container. Tuples used to be discarded, which
        # silently replaced real negatives with the default text.
        if isinstance(negative_list, str):
            negative_list = [negative_list]
        elif isinstance(negative_list, (list, tuple)):
            negative_list = list(negative_list)
        else:
            negative_list = []

        # Empty input explicitly allowed (No Finding case).
        if not negative_list and allow_empty:
            return _empty_token_batch(tokenizer, max_length)

        # Drop None entries and entries that are blank after stripping.
        stripped = (str(item).strip() for item in negative_list if item is not None)
        cleaned_negatives = [text for text in stripped if text]

        if not cleaned_negatives:
            if expected_count > 0:
                cleaned_negatives = [default_negative] * expected_count
        elif len(cleaned_negatives) < expected_count:
            # Pad by cycling through the available samples in order.
            available = len(cleaned_negatives)
            cleaned_negatives = [cleaned_negatives[i % available] for i in range(expected_count)]
        elif len(cleaned_negatives) > expected_count:
            cleaned_negatives = cleaned_negatives[:expected_count]

        if not cleaned_negatives:
            return _empty_token_batch(tokenizer, max_length)

        # Batch tokenization keeps padding consistent across rows.
        result = tokenizer(
            cleaned_negatives,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': result['input_ids'],  # [N, max_length]
            'attention_mask': result['attention_mask']  # [N, max_length]
        }

    except Exception as e:
        print(f"Error processing negative samples: {e}")
        import traceback
        traceback.print_exc()

        # Safe return in the error case, keeping the expected row count.
        if expected_count == 0:
            return _empty_token_batch(tokenizer, max_length)

        result = tokenizer(
            [default_negative] * expected_count,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': result['input_ids'],
            'attention_mask': result['attention_mask']
        }


def process_negative_samples(negative_list, tokenizer, max_length, max_negatives):
    """Tokenize up to ``max_negatives`` negative texts (backward-compatible interface).

    The first ``max_negatives`` entries are used. If there are fewer, the
    list is padded with entries drawn at random (``random.choice``) from the
    given texts; an empty input is padded with empty strings. Each text is
    tokenized separately with ``safe_tokenize`` and the per-text results are
    concatenated along dimension 0.

    Args:
        negative_list: Negative sample texts. A single string is treated as
            one sample; other iterables are copied into a list, so the
            caller's object is never modified.
        tokenizer: Tokenizer passed on to ``safe_tokenize``.
        max_length: Padded sequence length.
        max_negatives: Number of negatives to return.

    Returns:
        Dict with concatenated ``input_ids`` and ``attention_mask``. If
        processing raises, the error is printed and the empty-string encoding
        repeated ``max_negatives`` times is returned.
    """
    try:
        # Normalise to a list. A bare string used to be sliced into
        # characters, and tuples failed on .append(); both ended up as
        # all-empty negatives via the except branch.
        if isinstance(negative_list, str):
            candidates = [negative_list]
        else:
            candidates = list(negative_list)

        # Ensure there is something to sample from when padding.
        if not candidates:
            candidates = [""]

        # Limit the count, then pad up to it with random re-draws.
        neg_list = candidates[:max_negatives]
        while len(neg_list) < max_negatives:
            neg_list.append(random.choice(candidates))

        if not neg_list:
            # Nothing requested: return a zero-row batch directly instead of
            # relying on torch.cat failing on an empty list.
            empty_result = safe_tokenize(tokenizer, "", max_length)
            return {
                'input_ids': empty_result['input_ids'].repeat(0, 1),
                'attention_mask': empty_result['attention_mask'].repeat(0, 1)
            }

        results = [safe_tokenize(tokenizer, text, max_length) for text in neg_list]

        return {
            'input_ids': torch.cat([r['input_ids'] for r in results], dim=0),
            'attention_mask': torch.cat([r['attention_mask'] for r in results], dim=0)
        }

    except Exception as e:
        print(f"Error processing negative samples: {e}")
        # Return empty negative samples
        empty_result = safe_tokenize(tokenizer, "", max_length)
        return {
            'input_ids': empty_result['input_ids'].repeat(max_negatives, 1),
            'attention_mask': empty_result['attention_mask'].repeat(max_negatives, 1)
        }


class DatasetStats:
    """Accumulate simple per-sample statistics for a dataset and print a summary."""

    def __init__(self):
        self.image_sizes = []  # one entry per sample that reported a size
        self.text_lengths = []  # one entry per sample that reported a length
        self.region_counts = []  # one entry per sample that reported a count
        self.roi_quality_stats = {}  # quality label -> number of samples
        self.load_errors = 0
        self.total_samples = 0
        self.domains = set()

    def update(self, image_size=None, text_length=None, region_count=None,
               roi_quality=None, domain=None, has_error=False):
        """Record one sample.

        Every call increments ``total_samples``. The optional values are
        recorded only when provided.

        Args:
            image_size: Image size to record; skipped when falsy.
            text_length: Text length to record; ``0`` is recorded, ``None``
                is skipped.
            region_count: Region count to record; ``0`` is recorded,
                ``None`` is skipped.
            roi_quality: Quality label to count; skipped when falsy.
            domain: Domain name to add to the set; skipped when falsy.
            has_error: Whether loading this sample failed.
        """
        self.total_samples += 1
        if has_error:
            self.load_errors += 1
        if image_size:
            self.image_sizes.append(image_size)
        # A zero-length text is a real observation, so test against None
        # (as region_count does) rather than truthiness.
        if text_length is not None:
            self.text_lengths.append(text_length)
        if region_count is not None:
            self.region_counts.append(region_count)
        if roi_quality:
            self.roi_quality_stats[roi_quality] = self.roi_quality_stats.get(roi_quality, 0) + 1
        if domain:
            self.domains.add(domain)

    def _percent_of_total(self, count):
        """Return ``count`` as a percentage of ``total_samples`` (0.0 when no samples)."""
        if not self.total_samples:
            return 0.0
        return 100 * count / self.total_samples

    @staticmethod
    def _print_range(label, values):
        """Print min/max/mean of ``values`` under ``label``; print nothing when empty."""
        if values:
            print(f"{label}: min={min(values)}, max={max(values)}, avg={np.mean(values):.1f}")

    def print_stats(self):
        """Print the accumulated statistics to stdout.

        Safe to call before any sample has been recorded; percentages are
        then reported as 0.
        """
        print("\n=== Dataset Statistics ===")
        print(f"Total samples: {self.total_samples}")
        print(f"Loading errors: {self.load_errors} ({self._percent_of_total(self.load_errors):.2f}%)")
        print(f"Medical domains: {len(self.domains)} - {', '.join(sorted(self.domains))}")

        self._print_range("Image sizes", self.image_sizes)
        self._print_range("Text lengths", self.text_lengths)
        self._print_range("Region counts", self.region_counts)

        if self.roi_quality_stats:
            print("ROI quality distribution:")
            for quality, count in self.roi_quality_stats.items():
                print(f"  {quality}: {count} ({self._percent_of_total(count):.1f}%)")

        print("================\n")