from pathlib import Path
from typing import List, Tuple, Dict, Any
from PIL import Image
import json
import random
import cv2
import numpy as np
import math

def fill_transparent_with_white(input_path):
    """Load an image and flatten it onto an opaque white background.

    Args:
        input_path: Location of the image file; handed to ``Image.open``
            unchanged.

    Returns:
        A new ``RGB`` image with the same size as the source, where the
        source's alpha channel decides how much of the white background
        shows through.
    """
    # Image.open is lazy and holds the file open; the context manager
    # releases the handle deterministically once compositing is done.
    with Image.open(input_path) as img:
        # Normalise to RGBA so an alpha band is always available as a mask.
        rgba = img if img.mode == 'RGBA' else img.convert('RGBA')

        background = Image.new('RGB', rgba.size, (255, 255, 255))
        background.paste(rgba, mask=rgba.getchannel('A'))

    return background

def augment_reference_image(img: Image.Image) -> Image.Image:
    """Apply random geometric augmentations to a reference image.

    Three independent stages run in order, each gated by its own draw from
    the module-level ``random`` generator:

    1. Horizontal flip (probability 0.5).
    2. Stretch of either the width or the height by a factor in
       [0.8, 1.2] (probability 0.3).
    3. Affine warp combining rotation, scale, translation and shear
       (probability 0.4).

    Every stage keeps the original image size; areas uncovered by the
    stretch or the warp are filled with white.

    Args:
        img: Image to augment. The white fill is given as an RGB 3-tuple,
            so an RGB image is presumably expected — verify for other modes.

    Returns:
        The augmented image. When no stage fires this is the input object
        itself.
    """
    white = (255, 255, 255)

    # Random horizontal flip (50% probability)
    if random.random() < 0.5:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)

    # Random stretch (30% probability)
    if random.random() < 0.3:
        width, height = img.size
        # Randomly choose to stretch width or height by ±20%.
        # max(1, ...) keeps tiny inputs from truncating to a zero-sized axis.
        if random.random() < 0.5:
            new_size = (max(1, int(width * random.uniform(0.8, 1.2))), height)
        else:
            new_size = (width, max(1, int(height * random.uniform(0.8, 1.2))))

        stretched = img.resize(new_size, Image.BILINEAR)

        # Resizing straight back to the original size would undo the stretch
        # (only a slight blur would remain). Instead, centre the stretched
        # image on a white canvas of the original size: a shrunken axis is
        # padded with white, an enlarged axis is centre-cropped (paste clips
        # whatever falls outside the canvas).
        canvas = Image.new(img.mode, (width, height), white)
        offset = ((width - new_size[0]) // 2, (height - new_size[1]) // 2)
        canvas.paste(stretched, offset)
        img = canvas

    # Random affine transformation (40% probability)
    if random.random() < 0.4:
        w, h = img.size

        # Rotation: ±10 degrees
        angle = random.uniform(-10, 10)
        # Scale: ±10%
        scale = random.uniform(0.9, 1.1)
        # Translation: ±5% of image size
        translate_x = random.uniform(-0.05, 0.05) * w
        translate_y = random.uniform(-0.05, 0.05) * h
        # Shear: ±5 degrees
        shear = random.uniform(-5, 5)

        rot = math.radians(angle)
        rot_sheared = math.radians(angle + shear)

        # The six coefficients (a, b, c, d, e, f) tell Pillow where each
        # output pixel (x, y) is sampled from in the input:
        # (a*x + b*y + c, d*x + e*y + f).
        # NOTE(review): there is no centre offset, so the warp pivots about
        # the top-left origin, and `scale` multiplies only the diagonal
        # terms — confirm both are intended.
        img = img.transform(
            img.size,
            Image.AFFINE,
            data=[
                scale * math.cos(rot),
                -math.sin(rot_sheared),
                translate_x,
                math.sin(rot),
                scale * math.cos(rot_sheared),
                translate_y,
            ],
            resample=Image.BILINEAR,
            fillcolor=white,  # Fill with white background
        )

    return img

def process_reference_image(input_path: str, eval: bool = False) -> Image.Image:
    """Crop a reference image to its content and centre it on a white square.

    The image is first flattened onto white; unless ``eval`` is true it is
    then randomly augmented. Pixels whose grey value is at most 250 count as
    content, and the union of the content's bounding boxes is copied onto a
    white square whose side equals the longer edge of that union.

    Args:
        input_path: Path of the image to load.
        eval: When true, skip the random augmentation step.

    Returns:
        A square RGB image containing the centred content.

    Raises:
        ValueError: If no content contour is found.
    """
    flattened = fill_transparent_with_white(input_path)
    # Apply augmentations only outside evaluation mode
    source = flattened if eval else augment_reference_image(flattened)

    bgr = cv2.cvtColor(np.array(source), cv2.COLOR_RGB2BGR)

    # Inverted threshold: everything that is not near-white becomes foreground.
    _, mask = cv2.threshold(
        cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY),
        250,
        255,
        cv2.THRESH_BINARY_INV,
    )

    outlines, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not outlines:
        raise ValueError("No content found in the image")

    # Union of the bounding rectangles of every external contour.
    boxes = [cv2.boundingRect(outline) for outline in outlines]
    left = min(bx for bx, _, _, _ in boxes)
    top = min(by for _, by, _, _ in boxes)
    right = max(bx + bw for bx, _, bw, _ in boxes)
    bottom = max(by + bh for _, by, _, bh in boxes)

    content_w = right - left
    content_h = bottom - top
    side = max(content_w, content_h)

    # Integer offsets that centre the content inside the square.
    off_x = (side - content_w) // 2
    off_y = (side - content_h) // 2

    canvas = np.full((side, side, 3), 255, dtype=np.uint8)
    canvas[off_y:off_y + content_h, off_x:off_x + content_w] = bgr[top:bottom, left:right]

    return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))