"""
TimeSformer utilities and helper functions.
"""

import torch
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
import cv2
from PIL import Image
import av

class TimeSformerConfig:
    """
    Configuration registry for the supported TimeSformer variants.

    ``MODEL_CONFIGS`` maps a variant name to its frame count, input resolution,
    patch size, number of output classes and attention type. Use
    :meth:`get_config` rather than indexing ``MODEL_CONFIGS`` directly: it
    returns an independent copy that callers may modify freely.
    """

    # Default configurations for different TimeSformer variants.
    # "base-*" variants sample 8 frames at 224px; "hr-*" variants 16 frames at 448px.
    MODEL_CONFIGS = {
        "base-k400": {
            "num_frames": 8,
            "image_size": 224,
            "patch_size": 16,
            "num_classes": 400,
            "attention_type": "divided_space_time"
        },
        "base-k600": {
            "num_frames": 8,
            "image_size": 224,
            "patch_size": 16,
            "num_classes": 600,
            "attention_type": "divided_space_time"
        },
        "hr-k400": {
            "num_frames": 16,
            "image_size": 448,
            "patch_size": 16,
            "num_classes": 400,
            "attention_type": "divided_space_time"
        },
        "hr-k600": {
            "num_frames": 16,
            "image_size": 448,
            "patch_size": 16,
            "num_classes": 600,
            "attention_type": "divided_space_time"
        }
    }

    # Variant whose configuration is returned for unrecognized model names.
    DEFAULT_MODEL = "base-k400"

    @classmethod
    def get_config(cls, model_name: str) -> Dict[str, Any]:
        """
        Get configuration for a specific model variant.

        Args:
            model_name: Key into ``MODEL_CONFIGS`` (e.g. ``"hr-k600"``).

        Returns:
            A fresh copy of the variant's configuration dict. Unknown names fall
            back to the ``DEFAULT_MODEL`` configuration. The copy prevents
            callers from mutating the shared class-level defaults.
        """
        config = cls.MODEL_CONFIGS.get(model_name, cls.MODEL_CONFIGS[cls.DEFAULT_MODEL])
        # Values are scalars/strings, so a shallow copy is fully independent.
        return dict(config)

class VideoFrameExtractor:
    """
    Video frame extraction utilities built on OpenCV and PyAV.

    Every extractor returns frames as RGB ``PIL.Image.Image`` objects.
    """

    @staticmethod
    def _sample_indices(total_frames: int, num_frames: int) -> List[int]:
        """Return up to ``num_frames`` indices spread evenly over ``[0, total_frames)``."""
        if total_frames <= 0 or num_frames <= 0:
            return []
        if total_frames <= num_frames:
            # Video has no more frames than requested: use all of them.
            return list(range(total_frames))
        return np.linspace(0, total_frames - 1, num_frames, dtype=int).tolist()

    @staticmethod
    def _read_frames_at(cap: "cv2.VideoCapture", indices) -> List[Image.Image]:
        """Seek ``cap`` to each index and decode one RGB frame; unreadable frames are skipped."""
        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ok, frame = cap.read()
            if ok:
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        return frames

    @staticmethod
    def extract_frames_uniform(video_path: str, num_frames: int = 8) -> List[Image.Image]:
        """
        Extract frames uniformly distributed across the video.

        Args:
            video_path: Path to video file
            num_frames: Number of frames to extract. If the video is shorter,
                every frame is returned.

        Returns:
            List of PIL Images

        Raises:
            ValueError: If the video reports no frames or none could be decoded.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames <= 0:
                raise ValueError(f"Could not read video or video is empty: {video_path}")
            indices = VideoFrameExtractor._sample_indices(total_frames, num_frames)
            frames = VideoFrameExtractor._read_frames_at(cap, indices)
        finally:
            # Release the capture on every path, including the ValueError above.
            cap.release()

        if not frames:
            raise ValueError(f"No frames could be extracted from video: {video_path}")

        return frames

    @staticmethod
    def extract_frames_temporal_segments(video_path: str, num_frames: int = 8,
                                       segment_overlap: float = 0.1) -> List[List[Image.Image]]:
        """
        Extract multiple temporal segments from video for robust analysis.

        Each segment is ``num_frames`` consecutive frames. Consecutive segments
        start ``int(num_frames * (1 - segment_overlap))`` frames apart. That
        stride is clamped to at least one frame, so large overlaps cannot loop
        forever.

        Args:
            video_path: Path to video file
            num_frames: Number of frames per segment
            segment_overlap: Fraction of a segment shared with the next (0.0 to 1.0)

        Returns:
            List of frame segments, each containing PIL Images. A video with at
            most ``num_frames`` frames yields a single segment produced by
            :meth:`extract_frames_uniform`.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            if total_frames > num_frames:
                segment_length = num_frames
                step_size = max(1, int(segment_length * (1 - segment_overlap)))

                segments = []
                # Start positions for which a full segment still fits in the video.
                for start_frame in range(0, total_frames - segment_length + 1, step_size):
                    indices = range(start_frame, start_frame + segment_length)
                    frames = VideoFrameExtractor._read_frames_at(cap, indices)
                    if frames:
                        segments.append(frames)
                return segments
        finally:
            cap.release()

        # Short video: the capture is already released; return a single segment.
        return [VideoFrameExtractor.extract_frames_uniform(video_path, num_frames)]

    @staticmethod
    def extract_frames_pyav_robust(video_path: str, num_frames: int = 8) -> List[Image.Image]:
        """
        Robust frame extraction using PyAV with an OpenCV fallback.

        Frames are decoded sequentially (no per-frame seeking), and decoding
        stops once the last sampled index has been reached. If PyAV raises, or
        decodes no frames at all, :meth:`extract_frames_uniform` is used instead.

        Args:
            video_path: Path to video file
            num_frames: Number of frames to extract

        Returns:
            List of PIL Images

        Raises:
            ValueError: Propagated from the OpenCV fallback if it also fails.
        """
        try:
            # The context manager closes the container even if decoding raises.
            with av.open(video_path) as container:
                video_stream = container.streams.video[0]
                total_frames = video_stream.frames

                if total_frames <= 0:
                    # Stream header carries no frame count: count by decoding once.
                    total_frames = sum(1 for _ in container.decode(video=0))

                indices = VideoFrameExtractor._sample_indices(total_frames, num_frames)
                frames = []
                if indices:
                    target_indices = set(indices)
                    last_index = indices[-1]  # indices are ascending
                    container.seek(0)
                    for frame_idx, frame in enumerate(container.decode(video=0)):
                        if frame_idx in target_indices:
                            frames.append(Image.fromarray(frame.to_ndarray(format="rgb24")))
                        if frame_idx >= last_index:
                            break
        except Exception as e:  # deliberate best-effort: any PyAV failure falls back
            print(f"PyAV extraction failed: {e}, falling back to OpenCV")
            return VideoFrameExtractor.extract_frames_uniform(video_path, num_frames)

        if not frames:
            print("PyAV extracted no frames, falling back to OpenCV")
            return VideoFrameExtractor.extract_frames_uniform(video_path, num_frames)

        return frames

class TimeSformerFeatureAnalyzer:
    """Utilities for analyzing TimeSformer features and outputs."""

    @staticmethod
    def _split_patch_tokens(hidden_states: torch.Tensor, num_frames: int) -> torch.Tensor:
        """
        Drop the CLS token and reshape the patch tokens per frame.

        Assumes token 0 is the CLS token and the remaining tokens are ordered
        frame-major (all patches of frame 0, then frame 1, ...).

        Args:
            hidden_states: Tensor of shape [batch, seq_len, hidden_dim]
            num_frames: Number of frames the patch tokens are split into

        Returns:
            Tensor of shape [batch, num_frames, patches_per_frame, hidden_dim]

        Raises:
            ValueError: If ``num_frames`` is not positive or the
                ``seq_len - 1`` patch tokens do not divide evenly into frames.
        """
        if num_frames <= 0:
            raise ValueError(f"num_frames must be positive, got {num_frames}")

        batch_size, seq_len, hidden_dim = hidden_states.shape
        num_patches = seq_len - 1
        if num_patches % num_frames != 0:
            raise ValueError(
                f"{num_patches} patch tokens cannot be split evenly into {num_frames} frames"
            )

        return hidden_states[:, 1:, :].reshape(
            batch_size, num_frames, num_patches // num_frames, hidden_dim
        )

    @staticmethod
    def compute_temporal_attention_maps(attention_weights: torch.Tensor) -> torch.Tensor:
        """
        Compute CLS-token attention maps from TimeSformer attention weights.

        Averages over heads, then returns the attention row of the CLS query
        (token 0) to every other token. The CLS->CLS entry is dropped. Despite
        the name, this does not isolate the temporal attention of
        divided space-time blocks; that depends on the model's attention layout.

        Args:
            attention_weights: Attention weights tensor [batch, heads, seq_len, seq_len]

        Returns:
            Tensor of shape [batch, seq_len - 1]
        """
        avg_attention = attention_weights.mean(dim=1)
        return avg_attention[:, 0, 1:]

    @staticmethod
    def extract_spatial_features(hidden_states: torch.Tensor, num_frames: int = 8) -> torch.Tensor:
        """
        Extract spatial features from TimeSformer hidden states.

        Args:
            hidden_states: Hidden states tensor [batch, seq_len, hidden_dim]
            num_frames: Number of frames in the input

        Returns:
            Per-patch features averaged over frames: [batch, patches_per_frame, hidden_dim]

        Raises:
            ValueError: If the patch tokens cannot be split into ``num_frames`` frames.
        """
        per_frame = TimeSformerFeatureAnalyzer._split_patch_tokens(hidden_states, num_frames)
        return per_frame.mean(dim=1)

    @staticmethod
    def extract_temporal_features(hidden_states: torch.Tensor, num_frames: int = 8) -> torch.Tensor:
        """
        Extract temporal features from TimeSformer hidden states.

        Args:
            hidden_states: Hidden states tensor [batch, seq_len, hidden_dim]
            num_frames: Number of frames in the input

        Returns:
            Per-frame features averaged over patches: [batch, num_frames, hidden_dim]

        Raises:
            ValueError: If the patch tokens cannot be split into ``num_frames`` frames.
        """
        per_frame = TimeSformerFeatureAnalyzer._split_patch_tokens(hidden_states, num_frames)
        return per_frame.mean(dim=2)

    @staticmethod
    def compute_action_similarity(features1: np.ndarray, features2: np.ndarray,
                                method: str = "cosine") -> float:
        """
        Compute similarity between action features.

        Inputs are flattened to 1-D, so the result is always a scalar.

        Args:
            features1: First feature vector
            features2: Second feature vector
            method: Similarity method: "cosine" (0.0 if either vector has
                zero norm), "euclidean" (negated L2 distance, so larger means
                more similar) or "dot"

        Returns:
            Similarity score as a Python float

        Raises:
            ValueError: For an unknown ``method``.
        """
        vec1 = np.asarray(features1).ravel()
        vec2 = np.asarray(features2).ravel()

        if method == "cosine":
            norm1 = np.linalg.norm(vec1)
            norm2 = np.linalg.norm(vec2)
            if norm1 == 0 or norm2 == 0:
                return 0.0
            return float(np.dot(vec1, vec2) / (norm1 * norm2))

        if method == "euclidean":
            return float(-np.linalg.norm(vec1 - vec2))

        if method == "dot":
            return float(np.dot(vec1, vec2))

        raise ValueError(f"Unknown similarity method: {method}")

class TimeSformerDataAugmentation:
    """Data augmentation utilities for TimeSformer training/inference."""

    @staticmethod
    def temporal_crop(frames: List[Image.Image], crop_ratio: float = 0.8) -> List[Image.Image]:
        """
        Perform temporal cropping on frame sequence.

        A contiguous window of ``int(len(frames) * crop_ratio)`` frames is kept,
        starting at a random position. The window is never shorter than one
        frame, so a small ratio cannot produce an empty clip.

        Args:
            frames: List of PIL Images
            crop_ratio: Ratio of frames to keep

        Returns:
            Cropped frame sequence (the input list itself if no cropping is needed)
        """
        num_frames = len(frames)
        crop_length = max(1, int(num_frames * crop_ratio))

        if crop_length >= num_frames:
            return frames

        start_idx = np.random.randint(0, num_frames - crop_length + 1)
        return frames[start_idx:start_idx + crop_length]

    @staticmethod
    def temporal_subsample(frames: List[Image.Image], target_frames: int = 8) -> List[Image.Image]:
        """
        Subsample frames to target number.

        Args:
            frames: List of PIL Images
            target_frames: Target number of frames

        Returns:
            ``target_frames`` frames at evenly spaced positions, or the input
            list unchanged if it is not longer than ``target_frames``.
        """
        if len(frames) <= target_frames:
            return frames

        indices = np.linspace(0, len(frames) - 1, target_frames, dtype=int)
        return [frames[i] for i in indices]

    @staticmethod
    def spatial_crop_consistent(frames: List[Image.Image], crop_size: Tuple[int, int],
                              position: Optional[str] = None) -> List[Image.Image]:
        """
        Apply consistent spatial cropping across all frames.

        The same crop box, computed from the first frame's size, is applied to
        every frame. Each crop dimension is clamped to the frame size, so a crop
        larger than the frame along one axis keeps that axis whole.

        Args:
            frames: List of PIL Images
            crop_size: (width, height) of crop
            position: "center" for a centered crop; any other value (including
                None or "random") picks a random offset

        Returns:
            Cropped frames (the input list itself if no cropping is needed)
        """
        if not frames:
            return frames

        width, height = frames[0].size
        # Clamp per axis: otherwise an oversized axis gives a negative offset
        # (out-of-bounds box for "center", ValueError from randint for random).
        crop_w = min(crop_size[0], width)
        crop_h = min(crop_size[1], height)

        if crop_w == width and crop_h == height:
            return frames

        if position == "center":
            left = (width - crop_w) // 2
            top = (height - crop_h) // 2
        else:  # random
            left = np.random.randint(0, width - crop_w + 1)
            top = np.random.randint(0, height - crop_h + 1)

        right = left + crop_w
        bottom = top + crop_h

        return [frame.crop((left, top, right, bottom)) for frame in frames]

def create_kinetics_label_mapping(num_classes: int = 400) -> Dict[int, str]:
    """
    Create a placeholder class-index -> label-name mapping for Kinetics models.

    Every index gets a generic ``"action_<i>"`` name, and indices 0-4 are then
    overridden with a few hand-written sample names.

    NOTE(review): the sample names are illustrative placeholders and are not
    believed to match the official Kinetics class order. Replace this with the
    real label list before interpreting predictions.

    Args:
        num_classes: Number of class indices to generate (400 for Kinetics-400,
            600 for Kinetics-600). Defaults to 400.

    Returns:
        Dict mapping each index in ``range(num_classes)`` to a label string.
    """
    labels = {i: f"action_{i}" for i in range(num_classes)}

    # You can replace this with actual Kinetics labels.
    sample_labels = {
        0: "eating_spaghetti",
        1: "playing_basketball",
        2: "dancing",
        3: "cooking",
        4: "swimming",
    }

    # Only override indices that exist for this class count.
    labels.update({i: name for i, name in sample_labels.items() if i < num_classes})
    return labels

def _cosine_similarity_matrix(matrix: np.ndarray) -> np.ndarray:
    """
    Return the [n, n] cosine-similarity matrix of the rows of ``matrix``.

    Zero-norm rows have similarity 0.0 with every row. This matches
    ``TimeSformerFeatureAnalyzer.compute_action_similarity(..., method="cosine")``.
    """
    norms = np.linalg.norm(matrix, axis=1)
    # Dividing a zero row by 1 leaves it all zeros, so its dot products are 0.0.
    unit_rows = matrix / np.where(norms == 0, 1.0, norms)[:, None]
    return unit_rows @ unit_rows.T


def analyze_video_embeddings(embeddings: List[np.ndarray],
                           labels: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Analyze a collection of video embeddings.

    Pairwise cosine similarities are computed for every unordered pair of
    videos. When labels are given, the pairs are split into within-class
    (same label) and between-class (different labels) groups.

    Args:
        embeddings: List of 1-D embedding vectors of equal length
        labels: Optional labels for the videos, one per embedding

    Returns:
        Analysis results dictionary. Similarity statistics are 0.0 when there
        are no pairs to compare. An empty ``embeddings`` list yields
        ``num_videos == 0``, ``embedding_dim == 0`` and empty mean/std arrays.

    Raises:
        ValueError: If ``labels`` is non-empty and its length differs from
            ``embeddings``, or the embeddings do not form a 2-D array.
    """
    num_videos = len(embeddings)
    has_labels = labels is not None and len(labels) > 0
    if has_labels and len(labels) != num_videos:
        raise ValueError(
            f"Got {len(labels)} labels for {num_videos} embeddings; lengths must match"
        )

    if num_videos == 0:
        return {
            "num_videos": 0,
            "embedding_dim": 0,
            "mean_embedding": np.zeros(0),
            "std_embedding": np.zeros(0),
            "mean_similarity": 0.0,
            "std_similarity": 0.0,
            "min_similarity": 0.0,
            "max_similarity": 0.0,
        }

    embeddings_array = np.asarray(embeddings)
    if embeddings_array.ndim != 2:
        raise ValueError(
            f"Expected a list of 1-D embedding vectors, got array of shape {embeddings_array.shape}"
        )

    def stat(reducer, values: np.ndarray) -> float:
        """Apply ``reducer`` to ``values``, or return 0.0 when there is nothing to reduce."""
        return float(reducer(values)) if values.size else 0.0

    # Each unordered pair (i < j) appears exactly once in the upper triangle.
    rows, cols = np.triu_indices(num_videos, k=1)
    pair_sims = _cosine_similarity_matrix(embeddings_array)[rows, cols]

    analysis = {
        "num_videos": num_videos,
        "embedding_dim": embeddings_array.shape[1],
        "mean_embedding": embeddings_array.mean(axis=0),
        "std_embedding": embeddings_array.std(axis=0),
        "mean_similarity": stat(np.mean, pair_sims),
        "std_similarity": stat(np.std, pair_sims),
        "min_similarity": stat(np.min, pair_sims),
        "max_similarity": stat(np.max, pair_sims),
    }

    if has_labels:
        # Integer code per distinct label so pair membership can be tested vectorially.
        code_of = {label: code for code, label in enumerate(dict.fromkeys(labels))}
        codes = np.array([code_of[label] for label in labels])
        same_class = codes[rows] == codes[cols]
        within_class_sims = pair_sims[same_class]
        between_class_sims = pair_sims[~same_class]

        analysis.update({
            "num_classes": len(code_of),
            "within_class_similarity": {
                "mean": stat(np.mean, within_class_sims),
                "std": stat(np.std, within_class_sims),
            },
            "between_class_similarity": {
                "mean": stat(np.mean, between_class_sims),
                "std": stat(np.std, between_class_sims),
            },
        })

    return analysis
