import argparse
import pathlib
import torch
import av
import numpy as np
import pickle
from transformers import (
    TimesformerForVideoClassification, 
    TimesformerModel, 
    TimesformerConfig,
    AutoImageProcessor
)
from tqdm.auto import tqdm
import os
from natsort import natsorted
from prompts import ALL_PROMPTS
from huggingface_hub import hf_hub_download
from pathlib import Path
import cv2
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

# Hugging Face cache configuration: the cache directory constant and the three
# cache-related environment variables all point at the current working directory.
# NOTE(review): these variables are assigned after `transformers` has already
# been imported at the top of the file - confirm the library still honours them.
HUGGINGFACE_CACHE_DIR = "./"
os.environ.update(
    {
        "HF_HOME": "./",
        "HF_DATASETS_CACHE": "./",
        "TRANSFORMERS_CACHE": "./",
    }
)

# Short model aliases (used as CLI choices) mapped to Hugging Face model ids.
# The fine-tuned checkpoints follow a regular "<arch>-finetuned-<dataset>"
# naming scheme, so they are generated; the plain "base" entry is added last
# to keep the alias order base-k400, base-k600, hr-k400, hr-k600, base.
TIMESFORMER_MODELS = {
    f"{arch}-{dataset}": f"facebook/timesformer-{arch}-finetuned-{dataset}"
    for arch in ("base", "hr")
    for dataset in ("k400", "k600")
}
TIMESFORMER_MODELS["base"] = "facebook/timesformer-base"

class TimeSformerVideoProcessor:
    """
    TimeSformer video processing class for action recognition and feature extraction.

    Frames are decoded with PyAV (falling back to OpenCV), preprocessed with a
    Hugging Face image processor and passed through either the bare
    ``TimesformerModel`` (feature extraction) or
    ``TimesformerForVideoClassification`` (classification).
    """

    def __init__(self, model_name="base-k400", device="auto", extract_features=False):
        """
        Initialize TimeSformer processor and load the model.

        Args:
            model_name (str): Alias from TIMESFORMER_MODELS; any other value is
                used verbatim as the model id.
            device (str): Device to run model on; "auto" selects "cuda" when
                available and "cpu" otherwise.
            extract_features (bool): Whether to extract features or just classify
        """
        self.model_name = model_name
        self.model_id = TIMESFORMER_MODELS.get(model_name, model_name)
        self.device = device if device != "auto" else ("cuda" if torch.cuda.is_available() else "cpu")
        self.extract_features = extract_features

        # Load model and processor
        self._load_model()

    def _load_model(self):
        """Load the TimeSformer model and image processor, then switch to eval mode."""
        print(f"Loading TimeSformer model: {self.model_id}")

        # NOTE(review): the processor checkpoint is a VideoMAE one rather than
        # self.model_id - confirm its resize/normalisation settings match the
        # selected TimeSformer variant (especially the "hr" models).
        self.image_processor = AutoImageProcessor.from_pretrained(
            "MCG-NJU/videomae-base-finetuned-kinetics",
            cache_dir=HUGGINGFACE_CACHE_DIR
        )

        # Both model classes take identical loading arguments; only the head differs.
        model_cls = TimesformerModel if self.extract_features else TimesformerForVideoClassification
        self.model = model_cls.from_pretrained(
            self.model_id,
            cache_dir=HUGGINGFACE_CACHE_DIR,
            local_files_only=True,  # weights must already be present in the cache
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        )

        self.model.to(self.device)
        self.model.eval()

        print(f"Model loaded on {self.device}")

    def read_video_pyav(self, container, indices):
        """
        Decode selected frames from a video using PyAV.

        Repeated indices (produced when a video is shorter than the requested
        clip) yield repeated frames, so the output keeps one frame per requested
        index whenever that index could be decoded.

        Args:
            container: PyAV container
            indices: Sequence of frame indices to decode

        Returns:
            np.ndarray: Stacked RGB frames, in the order given by ``indices``

        Raises:
            ValueError: If ``indices`` is empty or none of them could be decoded
        """
        wanted = [int(i) for i in indices]
        if not wanted:
            raise ValueError("No frame indices were requested")

        wanted_set = set(wanted)  # O(1) membership test inside the decode loop
        last_wanted = max(wanted)

        decoded = {}
        container.seek(0)
        for i, frame in enumerate(container.decode(video=0)):
            if i > last_wanted:
                break
            if i in wanted_set:
                decoded[i] = frame.to_ndarray(format="rgb24")

        # Index back through `wanted` so duplicates are preserved; indices past
        # the real end of the stream are skipped.
        frames = [decoded[i] for i in wanted if i in decoded]
        if not frames:
            raise ValueError("No frames could be decoded from the video")

        return np.stack(frames)

    def sample_frame_indices(self, clip_len, frame_sample_rate, seg_len):
        """
        Sample frame indices from video.

        The clip always starts at frame 0 and spans
        ``min(clip_len * frame_sample_rate, seg_len)`` frames, from which
        ``clip_len`` evenly spaced indices are taken.

        Args:
            clip_len (int): Number of frames to sample
            frame_sample_rate (int): Sample every n-th frame
            seg_len (int): Total number of frames in video

        Returns:
            np.ndarray: int64 array of ``clip_len`` frame indices in [0, seg_len - 1]

        Raises:
            ValueError: If ``seg_len`` is not positive (frame count unknown)
        """
        if seg_len <= 0:
            raise ValueError(f"Cannot sample frames from a video with {seg_len} frames")

        converted_len = int(clip_len * frame_sample_rate)
        # A video shorter than the requested span is sampled uniformly over its
        # whole length instead.
        span = min(converted_len, seg_len)
        indices = np.linspace(0, span - 1, num=clip_len)

        return np.clip(indices, 0, seg_len - 1).astype(np.int64)

    def extract_frames_cv2(self, video_path, num_frames=8):
        """
        Extract frames using OpenCV as fallback.

        Frames are sampled uniformly over the whole video; frames that fail to
        read are skipped.

        Args:
            video_path (str): Path to video file
            num_frames (int): Number of frames to extract

        Returns:
            list: List of PIL Images

        Raises:
            ValueError: If the video reports no frames or no frame can be read
        """
        cap = cv2.VideoCapture(str(video_path))
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames <= 0:
                raise ValueError(f"Could not read video: {video_path}")

            frames = []
            for idx in np.linspace(0, total_frames - 1, num_frames).astype(int):
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ret, frame = cap.read()
                if ret:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(Image.fromarray(frame_rgb))
        finally:
            # Release the capture handle even when validation or decoding fails.
            cap.release()

        if not frames:
            raise ValueError(f"No frames could be read from video: {video_path}")
        return frames

    def _load_frames(self, video_path, num_frames, frame_sample_rate):
        """Decode frames as PIL Images via PyAV, then OpenCV; return None if both fail."""
        try:
            # The context manager closes the container even if decoding fails.
            with av.open(str(video_path)) as container:
                total_frames = container.streams.video[0].frames
                indices = self.sample_frame_indices(num_frames, frame_sample_rate, total_frames)
                video_frames = self.read_video_pyav(container, indices)
            return [Image.fromarray(frame) for frame in video_frames]
        except Exception as e:
            print(f"PyAV failed for {video_path}, trying OpenCV: {e}")

        try:
            return self.extract_frames_cv2(video_path, num_frames)
        except Exception as e2:
            print(f"OpenCV also failed for {video_path}: {e2}")
            return None

    def _prepare_inputs(self, frames):
        """Preprocess frames and move tensors to the model's device and dtype."""
        inputs = self.image_processor(frames, return_tensors="pt")
        # The model is loaded in float16 on CUDA; floating-point inputs must use
        # the same dtype or the forward pass fails with a dtype mismatch.
        model_dtype = self.model.dtype
        return {
            k: v.to(self.device, dtype=model_dtype) if torch.is_floating_point(v) else v.to(self.device)
            for k, v in inputs.items()
        }

    def _feature_result(self, video_path, inputs):
        """Run the bare model and collect CLS, pooled and per-layer features."""
        outputs = self.model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states
        last_hidden_state = outputs.last_hidden_state

        # Get CLS token representation (first token)
        cls_features = last_hidden_state[:, 0, :].float().cpu().numpy()

        # Average pool all tokens except CLS
        pooled_features = torch.mean(last_hidden_state[:, 1:, :].float(), dim=1).cpu().numpy()

        if hidden_states:
            # First batch item of every returned hidden state, stacked along a
            # new leading axis.
            per_layer = np.array([h[0].cpu().float().numpy() for h in hidden_states])
            # Averages over axis 1 (the token axis, assuming each hidden state
            # is (batch, tokens, dim)), leaving one vector per layer.
            language_hidden_states = np.average(per_layer, axis=1)
        else:
            language_hidden_states = None

        return {
            "video_path": str(video_path),
            "cls_features": cls_features,
            "pooled_features": pooled_features,
            "language_hidden_states": language_hidden_states  # Consistent with Qwen naming
        }

    def _classification_result(self, video_path, inputs):
        """Run the classification model and collect prediction, logits and features."""
        # Hidden states must be requested explicitly; otherwise
        # outputs.hidden_states is None and no features can be saved.
        outputs = self.model(**inputs, output_hidden_states=True)
        logits = outputs.logits.float()
        predicted_class = torch.argmax(logits, dim=-1).item()
        confidence = torch.softmax(logits, dim=-1).max().item()

        # Get class label if available
        fallback_label = f"class_{predicted_class}"
        id2label = getattr(self.model.config, 'id2label', None)
        class_label = id2label.get(predicted_class, fallback_label) if id2label else None

        if outputs.hidden_states:
            last_layers = np.array([
                h.cpu().float().numpy() for h in outputs.hidden_states[-4:]
            ])
            # First token of the first batch item, averaged over the last 4 layers.
            language_hidden_states = np.average(last_layers[:, 0, 0, :], axis=0)
        else:
            language_hidden_states = None

        return {
            "video_path": str(video_path),
            "generated_text": [class_label] if class_label else [fallback_label],  # Consistent with Qwen
            "predicted_class": predicted_class,
            "confidence": confidence,
            "logits": logits.cpu().numpy(),
            "language_hidden_states": language_hidden_states  # Consistent with Qwen naming
        }

    def process_single_video(self, video_path, num_frames=8, frame_sample_rate=1):
        """
        Process a single video file.

        Args:
            video_path (str): Path to video file
            num_frames (int): Number of frames to sample
            frame_sample_rate (int): Frame sampling rate (used by the PyAV path
                only; the OpenCV fallback samples uniformly)

        Returns:
            dict: Processing results, or None if the video could not be decoded
        """
        frames = self._load_frames(video_path, num_frames, frame_sample_rate)
        if frames is None:
            return None

        inputs = self._prepare_inputs(frames)

        with torch.no_grad():
            if self.extract_features:
                return self._feature_result(video_path, inputs)
            return self._classification_result(video_path, inputs)

    def process_videos_batch(self, video_paths, batch_size=1, num_frames=8, frame_sample_rate=1):
        """
        Process multiple videos in batches.

        Videos inside a batch are still run through the model one at a time;
        ``batch_size`` only controls the progress-bar granularity.

        Args:
            video_paths (list): List of video file paths
            batch_size (int): Batch size for processing
            num_frames (int): Number of frames to sample per video
            frame_sample_rate (int): Frame sampling rate

        Returns:
            list: List of processing results (failed videos are omitted)
        """
        results = []

        for i in tqdm(range(0, len(video_paths), batch_size), desc="Processing videos"):
            for video_path in video_paths[i:i + batch_size]:
                result = self.process_single_video(video_path, num_frames, frame_sample_rate)
                if result is None:
                    print(f"Skipped {video_path} due to processing error")
                else:
                    results.append(result)

        return results

def batchify(data, batch_size):
    """
    Split data into consecutive batches.

    Args:
        data: Sliceable sequence supporting ``len()``
        batch_size (int): Maximum number of items per batch; must be >= 1

    Returns:
        list: Slices of ``data``; the last one may be shorter than ``batch_size``

    Raises:
        ValueError: If ``batch_size`` is smaller than 1
    """
    # A negative step would make range() empty and silently drop every item.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    return [data[start:start + batch_size] for start in range(0, len(data), batch_size)]

def process_videos_with_timesformer(video_paths, batch_size, output_dir, model_name="base-k400", 
                                   extract_features=False, num_frames=8, frame_sample_rate=1):
    """
    Main processing function for TimeSformer video analysis.

    Results are pickled once per batch (``batch_<n>.pkl``) and once more in
    full (``timesformer_all_results.pkl``) inside ``output_dir``.
    
    Args:
        video_paths (list): List of video file paths
        batch_size (int): Batch size for processing
        output_dir (Path): Output directory for results (created if missing)
        model_name (str): TimeSformer model variant
        extract_features (bool): Whether to extract features or classify
        num_frames (int): Number of frames to sample per video
        frame_sample_rate (int): Frame sampling rate

    Returns:
        list: Results for every video that was processed successfully
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Initialize processor
    processor = TimeSformerVideoProcessor(
        model_name=model_name,
        extract_features=extract_features
    )
    
    # Build the batches once so the progress message reports the real batch
    # count, including when len(video_paths) is an exact multiple of batch_size.
    batches = batchify(video_paths, batch_size)
    total_batches = len(batches)
    all_results = []
    
    for batch_num, batch in enumerate(tqdm(batches, desc="Processing batches"), start=1):
        print(f"\nProcessing batch {batch_num}/{total_batches}")
        
        batch_results = processor.process_videos_batch(
            batch, 
            batch_size=1,  # Process one at a time for memory efficiency
            num_frames=num_frames,
            frame_sample_rate=frame_sample_rate
        )
        
        all_results.extend(batch_results)
        
        # Save batch results (consistent with Qwen naming)
        batch_file = output_dir / f"batch_{batch_num}.pkl"
        with open(batch_file, "wb") as f:
            pickle.dump(batch_results, f)
        
        print(f"✅ Saved batch {batch_num} results to {batch_file}")
    
    # Save all results
    final_file = output_dir / "timesformer_all_results.pkl"
    with open(final_file, "wb") as f:
        pickle.dump(all_results, f)
    
    print(f"✅ Processing complete! All results saved to {final_file}")
    return all_results

def _build_arg_parser():
    """Create the argument parser describing every command-line option."""
    cli = argparse.ArgumentParser(description="Video Processing with TimeSformer")

    # Input / output locations
    cli.add_argument("-v", "--video-dir", required=True, type=pathlib.Path,
                     help="Directory containing video files")
    cli.add_argument("-d", "--output-dir", required=True, type=pathlib.Path,
                     help="Directory to save outputs")
    cli.add_argument("--file-pattern", default="*.mp4", type=str,
                     help="File pattern to match video files")

    # Model selection and task
    cli.add_argument("-m", "--model-name", default="base-k400",
                     choices=list(TIMESFORMER_MODELS.keys()),
                     help="TimeSformer model variant to use")
    cli.add_argument("--extract-features", action="store_true",
                     help="Extract features instead of classification")

    # Sampling / batching
    cli.add_argument("-b", "--batch-size", default=4, type=int,
                     help="Batch size for video processing")
    cli.add_argument("--num-frames", default=8, type=int,
                     help="Number of frames to sample per video")
    cli.add_argument("--frame-sample-rate", default=1, type=int,
                     help="Frame sampling rate")
    return cli


def main():
    """Parse command-line options, process the matching videos and print a summary."""
    opts = _build_arg_parser().parse_args()

    # Collect the videos in natural sort order.
    found_videos = natsorted(opts.video_dir.glob(opts.file_pattern))
    if not found_videos:
        print(f"No video files found in {opts.video_dir} with pattern {opts.file_pattern}")
        return

    print(f"Found {len(found_videos)} video files")
    print(f"First video: {found_videos[0]}")

    # Results go into a sub-directory named after the model variant and task.
    if opts.extract_features:
        task_type = "features"
        task_label = "Feature extraction"
    else:
        task_type = "classification"
        task_label = "Classification"
    variant = opts.model_name.replace("-", "_")
    destination = opts.output_dir / f"timesformer_{variant}_{task_type}"

    processed = process_videos_with_timesformer(
        video_paths=found_videos,
        batch_size=opts.batch_size,
        output_dir=destination,
        model_name=opts.model_name,
        extract_features=opts.extract_features,
        num_frames=opts.num_frames,
        frame_sample_rate=opts.frame_sample_rate,
    )

    # Print summary
    summary_lines = (
        f"\n📊 Processing Summary:",
        f"Total videos processed: {len(processed)}",
        f"Model used: {opts.model_name}",
        f"Task: {task_label}",
        f"Frames per video: {opts.num_frames}",
        f"Results saved to: {destination}",
    )
    for line in summary_lines:
        print(line)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
