import pickle
import os
import torch
import numpy as np
from pathlib import Path
import fcntl

def _resolve_rank():
    """Return the torch.distributed rank, else the current CUDA device index, else 0."""
    import torch.distributed as dist

    # is_initialized() only exists on builds with distributed support.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    # NOTE(review): outside torch.distributed, processes that all see device 0
    # (e.g. via CUDA_VISIBLE_DEVICES) resolve to the same rank and share a file.
    return torch.cuda.current_device() if torch.cuda.is_available() else 0


def _rank_file_path(file_path, rank):
    """Insert ``_rank_{rank}`` between the stem and the suffix of ``file_path``."""
    base_path = Path(file_path)
    return base_path.parent / f"{base_path.stem}_rank_{rank}{base_path.suffix}"


def _load_entry_list(path):
    """Load the list of entries stored at ``path``, or return [] if the file is missing.

    An unreadable file is moved aside to ``<name>.corrupt`` rather than being
    silently overwritten, so previously saved entries are never lost.
    """
    if not path.exists():
        return []
    try:
        with open(path, 'rb') as f:
            data = pickle.load(f)
    except Exception as exc:  # pickle.load can raise many unrelated exception types
        backup = path.with_name(path.name + '.corrupt')
        print(f"Warning: could not read {path} ({exc!r}); moved it to {backup} and started a new list")
        os.replace(path, backup)
        return []
    return data if isinstance(data, list) else [data]  # Convert single entry to list


def _atomic_pickle_dump(obj, path):
    """Pickle ``obj`` to ``path`` through a temp file + os.replace so a crash never leaves a truncated file."""
    tmp_path = path.with_name(path.name + '.tmp')
    with open(tmp_path, 'wb') as f:
        pickle.dump(obj, f)
    os.replace(tmp_path, path)


def save_vision_attention_scores(outputs, file_path, 
                                vision_token_start, vision_token_end, num_vision_tokens,
                                image_path=None, prompt=None, response=None, 
                                FIXED_ATTENTION_POWER=0.3):
    """
    Extract and save vision token attention scores for later analysis.

    Takes the last generation step's last-layer attention and averages it
    over heads. It keeps the final query token's row and appends a summary
    record to a per-rank pickle file. That file lives next to ``file_path`` as
    ``<stem>_rank_<rank><suffix>`` and holds a list of such records. The file
    is rewritten atomically, and an unreadable existing file is moved to
    ``<name>.corrupt`` instead of being overwritten.

    Args:
        outputs: Generation output mapping whose ``"attentions"`` entry is
            indexed as ``[step][layer]`` (the last of each is used), plus
            optionally ``"sequences"``. Assumes batch size 1 — TODO confirm.
        file_path: Base pickle path; the rank is inserted before the suffix.
        vision_token_start: Start index (inclusive) of the vision tokens in
            the attention row.
        vision_token_end: End index (exclusive) of the vision tokens.
        num_vision_tokens: Stored as metadata only.
        image_path: Optional metadata, stored verbatim.
        prompt: Optional metadata, stored verbatim.
        response: Optional metadata, stored verbatim.
        FIXED_ATTENTION_POWER: Exponent used to produce the "enhanced" values.

    Returns:
        dict: The record that was appended to the per-rank file.
    """
    import time

    # Last generation step, last layer; squeeze(0) drops the batch dim -> [num_heads, q_len, kv_len]
    last_layer_attention = outputs["attentions"][-1][-1].squeeze(0)
    head_mean_attention = last_layer_attention.mean(dim=0)  # Average across heads
    last_token_attention = head_mean_attention[-1].detach().cpu()  # Last query token's attention row
    if last_token_attention.dtype in (torch.float16, torch.bfloat16):
        # Upcast half precision: sums/powers are more accurate and NumPy has no
        # bfloat16, so .numpy() on the saved tensors would otherwise fail.
        last_token_attention = last_token_attention.float()

    # Vision slice and full row, raw and power-scaled (same scaling as visualization)
    vision_attention_raw = last_token_attention[vision_token_start:vision_token_end]
    vision_attention_enhanced = torch.pow(vision_attention_raw, FIXED_ATTENTION_POWER)
    full_attention_raw = last_token_attention
    full_attention_enhanced = torch.pow(full_attention_raw, FIXED_ATTENTION_POWER)

    # Statistics
    vision_attention_sum = vision_attention_raw.sum().item()
    total_attention_sum = full_attention_raw.sum().item()
    vision_percentage = (vision_attention_sum / total_attention_sum) * 100 if total_attention_sum > 0 else 0

    # One file per process/GPU so concurrent writers never touch the same file
    rank = _resolve_rank()
    process_file_path = _rank_file_path(file_path, rank)

    attention_data = {
        'rank': rank,
        'timestamp': time.time(),  # Unix time in seconds
        'image_path': image_path,
        'prompt': prompt,
        'response': response,

        # Vision token attention data
        'vision_attention_raw': vision_attention_raw.clone(),
        'vision_attention_enhanced': vision_attention_enhanced.clone(),
        'vision_token_start': vision_token_start,
        'vision_token_end': vision_token_end,
        'num_vision_tokens': num_vision_tokens,

        # Full attention context
        'full_attention_raw': full_attention_raw.clone(),
        'full_attention_enhanced': full_attention_enhanced.clone(),
        'total_tokens': len(full_attention_raw),

        # Statistics
        'vision_attention_sum': vision_attention_sum,
        'total_attention_sum': total_attention_sum,
        'vision_percentage': vision_percentage,
        'max_vision_attention': vision_attention_raw.max().item(),
        'min_vision_attention': vision_attention_raw.min().item(),
        'mean_vision_attention': vision_attention_raw.mean().item(),

        # Parameters used
        'attention_power': FIXED_ATTENTION_POWER,

        # Additional metadata
        'sequence_length': len(outputs["sequences"][0]) if "sequences" in outputs else None,
        # attentions is indexed [step][layer], so the layer count is the length of one step
        'num_attention_layers': len(outputs["attentions"][-1]),
    }

    # Path.parent is '.' for bare file names, so this also works without a directory part
    process_file_path.parent.mkdir(parents=True, exist_ok=True)

    existing_data = _load_entry_list(process_file_path)
    existing_data.append(attention_data)
    _atomic_pickle_dump(existing_data, process_file_path)

    print(f"Saved attention data to {process_file_path} (total entries: {len(existing_data)})")
    print(f"Vision attention: {vision_attention_sum:.6f} ({vision_percentage:.1f}% of total)")

    return attention_data

def merge_distributed_attention_files(base_file_path, output_file_path=None):
    """
    Merge attention files from multiple processes into a single file.

    Collects every ``<stem>_rank_<N><suffix>`` file in the directory of
    ``base_file_path``. It concatenates their entry lists in ascending numeric
    rank order (``rank_2`` before ``rank_10``) and pickles the result.

    Args:
        base_file_path: Original file path (e.g., "attention_analysis/eagle_x4_mme_all.pkl")
        output_file_path: Output merged file path (optional). Defaults to
            ``<stem>_merged<suffix>`` in the same directory; missing parent
            directories are created.

    Returns:
        list | None: The merged entries, or None if no rank files were found.
        Files that fail to load are reported and skipped.
    """
    base_path = Path(base_file_path)
    if output_file_path is None:
        output_file_path = base_path.parent / f"{base_path.stem}_merged{base_path.suffix}"
    output_file_path = Path(output_file_path)

    # Find all rank-specific files
    prefix = f"{base_path.stem}_rank_"
    suffix = base_path.suffix
    pattern = f"{prefix}*{suffix}"
    rank_files = list(base_path.parent.glob(pattern))

    if not rank_files:
        print(f"No rank-specific files found matching pattern: {pattern}")
        return

    def rank_sort_key(path):
        # Sort numerically by rank so rank_10 comes after rank_2; non-numeric tags go last.
        tag = path.name[len(prefix):len(path.name) - len(suffix)]
        return (0, int(tag), path.name) if tag.isdecimal() else (1, 0, path.name)

    merged_data = []
    loaded_files = 0

    for rank_file in sorted(rank_files, key=rank_sort_key):
        try:
            with open(rank_file, 'rb') as f:
                data = pickle.load(f)
        except Exception as e:  # best effort: report and skip unreadable files
            print(f"Error loading {rank_file}: {e}")
            continue
        if not isinstance(data, list):
            data = [data]
        merged_data.extend(data)
        loaded_files += 1
        print(f"Loaded {len(data)} samples from {rank_file}")

    # Save merged data
    output_file_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file_path, 'wb') as f:
        pickle.dump(merged_data, f)

    print(f"Merged {len(merged_data)} samples from {loaded_files} files into {output_file_path}")

    # Optionally remove individual rank files
    # for rank_file in rank_files:
    #     rank_file.unlink()

    return merged_data

def load_vision_attention_scores(file_path):
    """
    Read attention entries previously written by the save/merge helpers.

    Args:
        file_path: Path to the pickle file

    Returns:
        list: The stored entries. A file holding a single non-list object is
        returned wrapped in a one-element list.
    """
    # pickle can execute arbitrary code on load — only open files this pipeline produced.
    with open(file_path, 'rb') as handle:
        loaded = pickle.load(handle)

    entries = loaded if isinstance(loaded, list) else [loaded]

    print(f"Loaded {len(entries)} attention entries from {file_path}")
    return entries

def _attention_to_numpy(values):
    """Convert one saved attention vector to a NumPy array (bfloat16 has no NumPy dtype, so it is upcast)."""
    if isinstance(values, torch.Tensor):
        values = values.detach().cpu()
        if values.dtype == torch.bfloat16:
            values = values.float()
        return values.numpy()
    return np.asarray(values)


def _stack_attention_arrays(arrays):
    """Stack equal-shape arrays into one ndarray; fall back to a 1-D object array when shapes differ."""
    if len({a.shape for a in arrays}) <= 1:
        return np.array(arrays)
    # Full-row lengths can differ between samples, and NumPy >= 1.24 refuses
    # to build ragged arrays implicitly, so build the object array explicitly.
    stacked = np.empty(len(arrays), dtype=object)
    for i, a in enumerate(arrays):
        stacked[i] = a
    return stacked


def extract_attention_arrays_for_plotting(data_list, enhanced=True):
    """
    Extract attention arrays from saved data for plotting.

    Args:
        data_list: List of attention data from load_vision_attention_scores()
        enhanced: Whether to use enhanced (power-scaled) or raw attention values

    Returns:
        dict: Arrays ready for plotting and analysis. ``vision_attentions`` and
        ``full_attentions`` are 2-D arrays when every sample has the same
        length. Otherwise they are 1-D object arrays holding one 1-D array
        per sample.
    """
    vision_key = 'vision_attention_enhanced' if enhanced else 'vision_attention_raw'
    full_key = 'full_attention_enhanced' if enhanced else 'full_attention_raw'

    vision_attentions = [_attention_to_numpy(entry[vision_key]) for entry in data_list]
    full_attentions = [_attention_to_numpy(entry[full_key]) for entry in data_list]
    metadata = [
        {
            'image_path': entry.get('image_path'),
            'prompt': entry.get('prompt'),
            'response': entry.get('response'),
            'vision_percentage': entry.get('vision_percentage'),
            'vision_sum': entry.get('vision_attention_sum'),
        }
        for entry in data_list
    ]

    return {
        'vision_attentions': _stack_attention_arrays(vision_attentions),
        'full_attentions': _stack_attention_arrays(full_attentions),
        'metadata': metadata,
        'num_samples': len(data_list),
        'enhanced': enhanced
    }

# Example usage in your main script:
# def example_usage():
#     """
#     Example of how to use the function in your evaluation loop
#     """
    
#     # In your evaluation loop:
#     attention_data = save_vision_attention_scores(
#         outputs=outputs,
#         file_path="attention_analysis/vision_attention_scores.pkl",
#         vision_token_start=vision_token_start,
#         vision_token_end=vision_token_end,
#         num_vision_tokens=num_vision_tokens,
#         image_path=image_path,
#         prompt=input_prompt,
#         response=tokenizer.decode(outputs["sequences"][0]).strip(),
#         FIXED_ATTENTION_POWER=0.3
#     )
    
#     # Later, for analysis:
#     data_list = load_vision_attention_scores("attention_analysis/vision_attention_scores.pkl")
#     plotting_data = extract_attention_arrays_for_plotting(data_list, enhanced=True)
#
#     # Now you can analyze:
#     mean_vision_attention = np.mean(plotting_data['vision_attentions'], axis=0)
#     std_vision_attention = np.std(plotting_data['vision_attentions'], axis=0)
#     # etc.