import pickle
import glob
import os
import numpy as np
import torch
from pathlib import Path
from tqdm import tqdm
import argparse

def load_single_feature_file(filepath):
    """Unpickle and return the object stored in ``filepath``.

    Loading is best-effort: any exception raised while opening or
    unpickling the file is reported on stdout and ``None`` is returned
    instead of propagating the error.

    Args:
        filepath: Path of the pickle file to read.

    Returns:
        The unpickled object, or ``None`` if the file could not be loaded.
    """
    try:
        with open(filepath, 'rb') as handle:
            loaded = pickle.load(handle)
    except Exception as err:
        print(f"Error loading {filepath}: {err}")
        return None
    else:
        return loaded

def _feature_to_numpy(feat):
    """Return ``feat`` as a NumPy array, moving torch tensors to host memory first."""
    if isinstance(feat, torch.Tensor):
        # .numpy() only works on CPU tensors that do not track gradients.
        tensor = feat.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            # NumPy has no bfloat16 dtype, so upcast before converting.
            tensor = tensor.float()
        return tensor.numpy()
    return np.array(feat)


def convert_features_to_combined(input_dir, output_file=None, output_format='pickle'):
    """
    Convert individual feature pickle files to a combined format

    Every ``*.pkl`` file in ``input_dir`` whose name does not end in
    ``_combined.pkl`` is loaded in sorted filename order. The loaded objects
    are collected into one list, which is then written as a single file.

    Args:
        input_dir: Directory containing individual pickle files
        output_file: Output file path (optional). The extension of the chosen
            format is appended unless the path already ends with it. When
            omitted, ``<input_dir>/<basename of input_dir>_combined`` is used.
        output_format: 'pickle', 'numpy', or 'torch'

    Returns:
        The list of loaded feature sets (one entry per successfully loaded
        file), or None when no pickle files were found or none could be
        loaded.

    Raises:
        ValueError: If ``output_format`` is not a supported format.
    """
    extensions = {'pickle': '.pkl', 'numpy': '.npz', 'torch': '.pt'}
    # Reject an unknown format up front; otherwise nothing would be saved
    # while the caller still receives features as if the conversion worked.
    if output_format not in extensions:
        raise ValueError(
            f"Unsupported output_format {output_format!r}; "
            f"expected one of {sorted(extensions)}"
        )

    # Find all pickle files in the directory, skipping earlier combined outputs
    pickle_files = sorted(
        f for f in glob.glob(os.path.join(input_dir, "*.pkl"))
        if not f.endswith('_combined.pkl')
    )

    print(f"Found {len(pickle_files)} pickle files in {input_dir}")

    if not pickle_files:
        print("No pickle files found!")
        return None

    # Load all features
    all_features = []
    failed_files = []

    print("Loading features from individual files...")
    for filepath in tqdm(pickle_files):
        features = load_single_feature_file(filepath)
        if features is not None:
            all_features.append(features)
        else:
            failed_files.append(filepath)

    print(f"Successfully loaded {len(all_features)} feature sets")
    if failed_files:
        print(f"Failed to load {len(failed_files)} files:")
        for f in failed_files:
            print(f"  {f}")

    if not all_features:
        print("No features loaded successfully!")
        return None

    # Print feature statistics, using the first loaded entry as the example
    print("\nFeature Statistics:")
    print(f"Total images: {len(all_features)}")
    print(f"Number of encoders: {len(all_features[0])}")
    for i, feat in enumerate(all_features[0]):
        # np.shape also copes with plain sequences that have no .shape
        print(f"Encoder {i}: {np.shape(feat)} (example shape)")

    # Resolve the output path for the specified format
    if output_file is None:
        base_name = os.path.basename(input_dir.rstrip('/'))
        output_file = os.path.join(input_dir, f"{base_name}_combined")

    extension = extensions[output_format]
    output_file = os.fspath(output_file)
    # Do not produce names such as "out.pkl.pkl" when the extension was given
    if output_file.endswith(extension):
        output_path = output_file
    else:
        output_path = f"{output_file}{extension}"

    # Make sure the destination directory exists before writing
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"\nSaving combined features to {output_path}")

    if output_format == 'pickle':
        with open(output_path, 'wb') as f:
            pickle.dump(all_features, f)
        print("Saved as pickle file")

    elif output_format == 'numpy':
        # One array per (image, encoder) pair, keyed by both indices
        numpy_features = {
            f"image_{img_idx:06d}_encoder_{enc_idx}": _feature_to_numpy(feat)
            for img_idx, features in enumerate(all_features)
            for enc_idx, feat in enumerate(features)
        }
        np.savez_compressed(output_path, **numpy_features)
        print("Saved as compressed numpy file")

    else:  # 'torch'
        # Convert to torch tensors if needed
        torch_features = [
            [
                feat if isinstance(feat, torch.Tensor) else torch.tensor(feat)
                for feat in features
            ]
            for features in all_features
        ]
        torch.save(torch_features, output_path)
        print("Saved as PyTorch file")

    return all_features

def analyze_features(features_list):
    """Print per-encoder shape and activation statistics for loaded features.

    Args:
        features_list: Sequence with one entry per image. Each entry is a
            sequence of per-encoder features, either torch tensors or
            array-likes that NumPy can reduce. The encoder count is taken
            from the first entry; entries with fewer encoders are skipped
            for the missing indices.

    Returns:
        None. All results are printed to stdout.
    """
    if not features_list:
        print("No features to analyze")
        return

    print("\n" + "="*50)
    print("FEATURE ANALYSIS")
    print("="*50)

    num_images = len(features_list)
    num_encoders = len(features_list[0])

    print(f"Total images: {num_images}")
    print(f"Number of encoders: {num_encoders}")

    # Analyze each encoder
    for enc_idx in range(num_encoders):
        print(f"\nEncoder {enc_idx}:")

        shapes = []
        means = []
        stds = []

        for features in features_list:
            if len(features) <= enc_idx:
                continue
            feat = features[enc_idx]
            # np.shape falls back to array conversion for objects without .shape
            shapes.append(np.shape(feat))

            if isinstance(feat, torch.Tensor):
                values = feat.detach()
                # mean()/std() reject integer and bool tensors, so upcast them
                if not (values.is_floating_point() or values.is_complex()):
                    values = values.float()
                # NOTE: torch's std applies Bessel's correction by default,
                # unlike np.std below.
                means.append(values.mean().item())
                stds.append(values.std().item())
            else:
                means.append(np.mean(feat))
                stds.append(np.std(feat))

        if shapes:
            # Sorted so the report is deterministic when shapes differ
            unique_shapes = sorted(set(shapes))
            print(f"  Shapes: {unique_shapes}")
            print(f"  Mean activation: {np.mean(means):.4f} ± {np.std(means):.4f}")
            print(f"  Std activation: {np.mean(stds):.4f} ± {np.std(stds):.4f}")

def clean_up_individual_files(input_dir, keep_combined=True):
    """Delete the ``*.pkl`` files found directly inside ``input_dir``.

    Args:
        input_dir: Directory whose pickle files are removed.
        keep_combined: When true, files whose name ends with
            ``_combined.pkl`` are left in place.

    Deletion is best-effort: a failure to remove one file is printed and
    the remaining files are still processed.
    """
    pattern = os.path.join(input_dir, "*.pkl")
    targets = [
        path
        for path in glob.glob(pattern)
        # Spare combined outputs only when the caller asked to keep them
        if not (keep_combined and path.endswith('_combined.pkl'))
    ]

    print(f"\nRemoving {len(targets)} individual files...")
    for path in targets:
        try:
            os.remove(path)
        except Exception as err:
            print(f"Error removing {path}: {err}")

    print("Cleanup completed")

def main():
    """Command-line entry point: combine feature files, then optionally analyze and clean up.

    Cleanup is only performed when the number of combined feature sets
    equals the number of individual pickle files in the input directory, so
    that files which failed to load are not deleted.
    """
    parser = argparse.ArgumentParser(description='Convert individual feature pickle files to combined format')
    parser.add_argument('input_dir', help='Directory containing individual pickle files')
    parser.add_argument('--output', '-o', help='Output file path (without extension)')
    parser.add_argument('--format', '-f', choices=['pickle', 'numpy', 'torch'],
                       default='pickle', help='Output format')
    parser.add_argument('--analyze', '-a', action='store_true',
                       help='Analyze features after loading')
    parser.add_argument('--cleanup', '-c', action='store_true',
                       help='Remove individual files after successful combination')

    args = parser.parse_args()

    # Convert features
    features = convert_features_to_combined(
        input_dir=args.input_dir,
        output_file=args.output,
        output_format=args.format
    )

    # Analyze if requested
    if args.analyze and features:
        analyze_features(features)

    # Cleanup if requested
    if args.cleanup and features:
        # Count the files cleanup would delete (same filter it applies). A
        # mismatch means some file is not represented in the combined
        # output, e.g. it failed to load, or it is a combined pickle that
        # was written into this directory under a custom name.
        individual_files = [
            f for f in glob.glob(os.path.join(args.input_dir, "*.pkl"))
            if not f.endswith('_combined.pkl')
        ]
        if len(individual_files) != len(features):
            print(
                f"\nSkipping cleanup: combined {len(features)} feature sets but "
                f"found {len(individual_files)} individual pickle files in "
                f"{args.input_dir}; removing them could lose data"
            )
        else:
            clean_up_individual_files(args.input_dir, keep_combined=True)

if __name__ == "__main__":
    # With command-line arguments, run the argparse-driven CLI; with none,
    # fall back to a hard-coded conversion used for testing.
    import sys

    if len(sys.argv) > 1:
        main()
    else:
        # Default behavior for testing
        input_directory = "MER/VLMEvalKit/features/eagle_x4_8b_mme"
        # Given without an extension: convert_features_to_combined adds the
        # ".pkl" suffix for the pickle format.
        output_path = 'features/converted_features_eagle_x4_mme'

        print("Converting features with default settings...")
        features = convert_features_to_combined(
            input_dir=input_directory,
            output_format='pickle',
            output_file=output_path
        )

        if features:
            analyze_features(features)