import torch
import torch.nn.functional as F
import numpy as np
import cv2
import matplotlib.pyplot as plt
from typing import Union, Tuple, Optional

def resize_image(image: np.ndarray, target_size: Tuple[int, int], keep_aspect_ratio: bool = True) -> np.ndarray:
    """Resize an image to ``target_size``, optionally letterboxing it.

    Args:
        image: Array whose first two axes are (height, width); an optional
            third axis holds the channels.
        target_size: Desired output size as ``(height, width)``.
        keep_aspect_ratio: When True the image is scaled uniformly so it fits
            inside the target and is centred on a zero-filled canvas. When
            False the image is stretched to the target size directly.

    Returns:
        An array with spatial size ``target_size`` and the dtype of ``image``
        (letterbox path), or whatever ``cv2.resize`` returns (stretch path).

    Raises:
        ValueError: If ``image`` has a zero-length height or width.
    """
    src_h, src_w = image.shape[:2]
    target_h, target_w = target_size

    if src_h == 0 or src_w == 0:
        raise ValueError(f"Cannot resize an empty image of shape {image.shape}")

    if not keep_aspect_ratio:
        return cv2.resize(image, (target_w, target_h), interpolation=cv2.INTER_LINEAR)

    scale = min(target_h / src_h, target_w / src_w)
    # Round instead of truncating so float error cannot shave a pixel off the
    # side that should exactly fill the target, and clamp so a very elongated
    # image never collapses to a 0-pixel side (which cv2.resize rejects).
    new_h = min(target_h, max(1, round(src_h * scale)))
    new_w = min(target_w, max(1, round(src_w * scale)))
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    channel_dims = image.shape[2:]
    canvas = np.zeros((target_h, target_w) + channel_dims, dtype=image.dtype)

    top = (target_h - new_h) // 2
    left = (target_w - new_w) // 2
    # cv2.resize returns a 2-D array for an (H, W, 1) input; reshaping restores
    # the channel axis so the paste works for every channel count.
    canvas[top:top + new_h, left:left + new_w] = resized.reshape((new_h, new_w) + channel_dims)
    return canvas

def normalize_image(image: np.ndarray, mean: Tuple[float, float, float] = (0.485, 0.456, 0.406), 
                   std: Tuple[float, float, float] = (0.229, 0.224, 0.225)) -> np.ndarray:
    """Scale an image by 1/255 and standardise its first three channels.

    Args:
        image: Array indexed as ``[row, col, channel]`` with at least three
            channels.
        mean: Per-channel value subtracted after the 1/255 scaling.
        std: Per-channel value the centred channel is divided by.

    Returns:
        A new float32 array; the input array is left untouched because the
        dtype conversion makes a copy.
    """
    scaled = image.astype(np.float32) / 255.0
    for channel in range(3):
        # ``plane`` is a view, so the in-place arithmetic updates ``scaled``.
        plane = scaled[:, :, channel]
        plane -= mean[channel]
        plane /= std[channel]
    return scaled

def denormalize_image(image: np.ndarray, mean: Tuple[float, float, float] = (0.485, 0.456, 0.406), 
                     std: Tuple[float, float, float] = (0.229, 0.224, 0.225)) -> np.ndarray:
    """Undo per-channel standardisation and convert back to 8-bit pixels.

    The first three channels are multiplied by ``std`` and shifted by
    ``mean``; the result is scaled by 255, rounded to the nearest integer,
    clipped to ``[0, 255]`` and cast to ``uint8``.

    Args:
        image: Array indexed as ``[row, col, channel]`` with at least three
            channels. It is not modified.
        mean: Per-channel offset added back.
        std: Per-channel factor multiplied back.

    Returns:
        A new ``uint8`` array with the same shape as ``image``.
    """
    # Work on a private floating-point copy: writing into ``image`` would
    # silently corrupt the caller's array (and truncate for integer dtypes).
    if np.issubdtype(image.dtype, np.floating):
        restored = image.copy()
    else:
        restored = image.astype(np.float32)

    for channel in range(3):
        restored[:, :, channel] = restored[:, :, channel] * std[channel] + mean[channel]

    # Round before the cast; plain truncation turns e.g. 127.99999 into 127,
    # so a standardise -> de-standardise round trip would not be lossless.
    return np.clip(np.rint(restored * 255.0), 0, 255).astype(np.uint8)

def tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a tensor to a NumPy array, moving the leading axis last.

    A 4-D tensor first has its leading axis squeezed (a no-op unless that
    axis has length 1). If the result is 3-D, its axes are reordered from
    ``(0, 1, 2)`` to ``(1, 2, 0)``. Tensors of any other rank are converted
    without reordering.

    Args:
        tensor: Input tensor; it may require grad or live on another device.

    Returns:
        The detached data as a NumPy array on the CPU.
    """
    unbatched = tensor.squeeze(0) if tensor.dim() == 4 else tensor
    reordered = unbatched.permute(1, 2, 0) if unbatched.dim() == 3 else unbatched
    return reordered.detach().cpu().numpy()

def numpy_to_tensor(array: np.ndarray, device: str = 'cpu') -> torch.Tensor:
    """Wrap a NumPy array as a tensor on ``device``.

    A 3-D array has its last axis moved to the front and gains a leading
    axis of length 1, so ``(a, b, c)`` becomes ``(1, c, a, b)``. Arrays of
    any other rank are converted with their shape unchanged.

    Args:
        array: Source array.
        device: Target device passed to ``Tensor.to``.

    Returns:
        The resulting tensor.
    """
    if array.ndim != 3:
        return torch.from_numpy(array).to(device)

    channels_first = np.moveaxis(array, -1, 0)
    return torch.from_numpy(channels_first[np.newaxis]).to(device)

def visualize_depth(depth_map: np.ndarray, colormap: int = cv2.COLORMAP_VIRIDIS) -> np.ndarray:
    """Render a depth map as a colour-mapped 8-bit image.

    The map is min-max normalised to the range 0..255 as 8-bit data and then
    passed through ``cv2.applyColorMap``.

    Args:
        depth_map: Depth values to visualise.
        colormap: OpenCV colormap identifier.

    Returns:
        The colour-mapped image produced by OpenCV.
    """
    scaled_8bit = cv2.normalize(
        depth_map,
        None,
        alpha=0,
        beta=255,
        norm_type=cv2.NORM_MINMAX,
        dtype=cv2.CV_8U,
    )
    return cv2.applyColorMap(scaled_8bit, colormap)

def create_comparison_grid(images: list, titles: list = None, figsize: Tuple[int, int] = (15, 5)) -> plt.Figure:
    """Lay out images side by side in a single-row matplotlib figure.

    Three-dimensional images whose last axis has length 3 are shown as-is;
    everything else is shown with the ``viridis`` colormap. Axes are hidden,
    and the i-th title is applied when ``titles`` provides one.

    Args:
        images: Images to display, one per column.
        titles: Optional column titles; may be shorter than ``images``.
        figsize: Figure size forwarded to ``plt.subplots``.

    Returns:
        The created figure.
    """
    # squeeze=False always yields a 2-D axes array, so the single-image case
    # needs no special handling.
    fig, axes_grid = plt.subplots(1, len(images), figsize=figsize, squeeze=False)
    labels = titles if titles else []

    for index, (ax, img) in enumerate(zip(axes_grid[0], images)):
        is_three_channel = len(img.shape) == 3 and img.shape[2] == 3
        imshow_kwargs = {} if is_three_channel else {'cmap': 'viridis'}
        ax.imshow(img, **imshow_kwargs)
        ax.axis('off')
        if index < len(labels):
            ax.set_title(labels[index])

    plt.tight_layout()
    return fig

def calculate_psnr(img1: np.ndarray, img2: np.ndarray, max_pixel: float = 255.0) -> float:
    """Compute the peak signal-to-noise ratio between two images in dB.

    Args:
        img1: First image.
        img2: Second image, broadcast-compatible with ``img1``.
        max_pixel: Peak signal value; 255 for 8-bit images, 1.0 for images
            scaled to the unit interval.

    Returns:
        ``20 * log10(max_pixel / sqrt(mse))``, or ``inf`` when the images are
        identical.
    """
    # Promote to float64 before subtracting: with unsigned integer inputs the
    # difference would otherwise wrap around and yield a bogus error.
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float('inf')
    return float(20 * np.log10(max_pixel / np.sqrt(mse)))

def calculate_ssim(img1: np.ndarray, img2: np.ndarray) -> float:
    """Compute the structural similarity of two images via scikit-image.

    For 3-D inputs the call is made with ``multichannel=True`` and
    ``channel_axis=2``; otherwise no extra arguments are passed.

    Args:
        img1: First image.
        img2: Second image.

    Returns:
        The value returned by ``structural_similarity``, or ``0.0`` (after
        printing a notice) when an ``ImportError`` is raised.
    """
    try:
        from skimage.metrics import structural_similarity

        # NOTE(review): both the legacy and the current channel keyword are
        # passed; how they interact depends on the installed scikit-image
        # version - confirm against the pinned release.
        channel_options = {}
        if len(img1.shape) == 3:
            channel_options = {'multichannel': True, 'channel_axis': 2}
        return structural_similarity(img1, img2, **channel_options)
    except ImportError:
        print("scikit-image required for SSIM calculation")
        return 0.0

def save_results(results: dict, output_dir: str, filename_prefix: str = "result"):
    """Write the tensors found in ``results`` to PNG files in ``output_dir``.

    Recognised keys (each is skipped when missing or ``None``):

    * ``'output'`` and ``'deblurred'``: scaled by 255, clipped to 8 bits,
      converted RGB -> BGR and written as ``<prefix>_<key>.png``.
    * ``'depth'``: written twice, as a colour-mapped preview
      (``<prefix>_depth.png``) and as a 16-bit image scaled by 65535
      (``<prefix>_depth_raw.png``).

    Args:
        results: Mapping from the keys above to tensors.
            NOTE(review): values are presumably in [0, 1] - confirm with the
            producer of ``results``.
        output_dir: Destination directory; created if it does not exist.
        filename_prefix: Prefix shared by all written file names.
    """
    import os
    os.makedirs(output_dir, exist_ok=True)

    def png_path(suffix: str) -> str:
        return os.path.join(output_dir, f"{filename_prefix}_{suffix}.png")

    def write_rgb(key: str) -> None:
        tensor = results.get(key)
        if tensor is None:
            return
        rgb = np.clip(tensor_to_numpy(tensor) * 255, 0, 255).astype(np.uint8)
        cv2.imwrite(png_path(key), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))

    write_rgb('output')

    depth = results.get('depth')
    if depth is not None:
        depth_map = tensor_to_numpy(depth)
        cv2.imwrite(png_path('depth'), visualize_depth(depth_map))

        # Clip before the cast, as is done for the RGB outputs: values outside
        # [0, 1] would otherwise wrap around in uint16 instead of saturating.
        depth_raw = np.clip(depth_map * 65535, 0, 65535).astype(np.uint16)
        cv2.imwrite(png_path('depth_raw'), depth_raw)

    write_rgb('deblurred')

class ModelProfiler:
    """Collect parameter counts and forward-pass timings for a model."""

    def __init__(self, model: torch.nn.Module):
        """Store the model and the device of its first parameter.

        Args:
            model: Module to profile. A module without parameters is treated
                as living on the CPU.
        """
        self.model = model
        first_param = next(model.parameters(), None)
        # A parameter-free module has no device of its own; a bare next()
        # would raise StopIteration here.
        self.device = first_param.device if first_param is not None else torch.device('cpu')

    def count_parameters(self):
        """Return ``(total, trainable)`` parameter element counts."""
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        return total_params, trainable_params

    def _synchronize(self):
        """Wait for queued CUDA work so wall-clock timings are meaningful."""
        if self.device.type == 'cuda':
            torch.cuda.synchronize(self.device)

    def measure_inference_time(self, input_shape: Tuple[int, ...], num_runs: int = 100):
        """Time forward passes on a random input.

        Ten untimed warm-up passes are followed by ``num_runs`` timed passes,
        all in eval mode with gradients disabled. The train/eval state of
        every sub-module is restored afterwards.

        Args:
            input_shape: Shape of the random input, including any batch axis.
            num_runs: Number of timed forward passes; must be positive.

        Returns:
            ``(avg_time, fps)``: mean seconds per pass and its reciprocal.

        Raises:
            ValueError: If ``num_runs`` is not positive.
        """
        import time

        if num_runs <= 0:
            raise ValueError("num_runs must be a positive integer")

        dummy_input = torch.randn(input_shape).to(self.device)
        first_param = next(self.model.parameters(), None)
        if first_param is not None and first_param.is_floating_point():
            # Match the parameter dtype so half/double models can be profiled.
            dummy_input = dummy_input.to(first_param.dtype)

        # Remember each sub-module's mode: profiling should not leave the
        # model switched to eval behind the caller's back.
        saved_modes = [(module, module.training) for module in self.model.modules()]
        self.model.eval()
        try:
            with torch.no_grad():
                for _ in range(10):
                    self.model(dummy_input)
                self._synchronize()

                # perf_counter is monotonic and high-resolution, unlike time.time.
                start_time = time.perf_counter()
                for _ in range(num_runs):
                    self.model(dummy_input)
                self._synchronize()
                elapsed = time.perf_counter() - start_time
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        avg_time = elapsed / num_runs
        fps = 1.0 / avg_time if avg_time > 0 else float('inf')

        return avg_time, fps

    def get_model_summary(self, input_shape: Tuple[int, ...], num_runs: int = 100):
        """Return a dict combining parameter counts, timing and size.

        Args:
            input_shape: Shape of the random input used for timing.
            num_runs: Number of timed forward passes.

        Returns:
            Dict with keys ``total_parameters``, ``trainable_parameters``,
            ``inference_time_ms``, ``fps`` and ``model_size_mb``. The size is
            computed from each parameter's actual element size.
        """
        total_params, trainable_params = self.count_parameters()
        avg_time, fps = self.measure_inference_time(input_shape, num_runs=num_runs)
        param_bytes = sum(p.numel() * p.element_size() for p in self.model.parameters())

        return {
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'inference_time_ms': avg_time * 1000,
            'fps': fps,
            'model_size_mb': param_bytes / (1024 * 1024),
        }

def adaptive_threshold_analysis(exit_decisions: list, threshold_range: Tuple[float, float] = (0.1, 0.9), 
                              num_points: int = 50):
    """Sweep a threshold and report the fraction of decisions exceeding it.

    Each decision tensor is passed through a softmax over dim 1 and the
    value at index 1 of that dim is compared against every threshold.

    Args:
        exit_decisions: Iterable of batches, each an iterable of tensors.
            NOTE(review): tensors are presumably (batch, 2) logits - confirm
            with the producer.
        threshold_range: ``(low, high)`` bounds of the sweep, both inclusive.
        num_points: Number of evenly spaced thresholds.

    Returns:
        ``(thresholds, early_exit_rates)``: the NumPy array of thresholds and
        a list with one rate per threshold (0 when there are no decisions).
    """
    thresholds = np.linspace(threshold_range[0], threshold_range[1], num_points)

    # The softmax does not depend on the threshold, so compute it once rather
    # than once per threshold.
    exit_probs = [
        torch.softmax(decision_tensor, dim=1)[:, 1]
        for batch_decisions in exit_decisions
        for decision_tensor in batch_decisions
    ]
    total_decisions = sum(probs.numel() for probs in exit_probs)

    early_exit_rates = []
    for threshold in thresholds:
        early_exits = sum((probs > threshold).sum().item() for probs in exit_probs)
        early_exit_rates.append(early_exits / total_decisions if total_decisions > 0 else 0)

    return thresholds, early_exit_rates

def load_image(image_path: str, target_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
    """Read an image from disk and convert it from BGR to RGB.

    Args:
        image_path: Path handed to ``cv2.imread``.
        target_size: Optional size forwarded to ``resize_image``; when falsy
            the image keeps its original size.

    Returns:
        The loaded (and possibly resized) RGB image.

    Raises:
        ValueError: If ``cv2.imread`` returns ``None`` for the path.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise ValueError(f"Could not load image from {image_path}")

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return resize_image(rgb, target_size) if target_size else rgb

def create_depth_visualization(depth_map: np.ndarray, original_image: Optional[np.ndarray] = None) -> np.ndarray:
    """Colour-map a depth map, optionally placed to the right of an image.

    Args:
        depth_map: Depth values passed to ``visualize_depth``.
        original_image: Optional reference image. When its height/width
            differ from the depth map's, it is resized to match first.

    Returns:
        The colour-mapped depth alone, or the reference image and the
        colour-mapped depth stacked horizontally.
    """
    colored_depth = visualize_depth(depth_map)

    if original_image is None:
        return colored_depth

    reference = original_image
    if reference.shape[:2] != depth_map.shape[:2]:
        # cv2.resize takes its size as (width, height).
        matching_size = (depth_map.shape[1], depth_map.shape[0])
        reference = cv2.resize(reference, matching_size)

    return np.hstack([reference, colored_depth])