import pickle
import torch
import numpy as np
import os
import argparse
import pandas as pd
from pathlib import Path
from platonic_alignment_metrics import AlignmentMetrics, remove_outliers
import torch.nn.functional as F


def load_incremental_pickle(pkl_path):
    """
    Read a pickle file written as a sequence of back-to-back pickle.dump calls.

    Each stored object is merged into one dictionary, in file order, so a key
    that appears more than once keeps the value from its last occurrence.

    Security note: pickle.load can execute arbitrary code; only call this on
    files from a trusted source.

    Args:
        pkl_path (str): Path to the pickle file

    Returns:
        dict: Union of every dictionary stored in the file ({} for an empty file)
    """
    def _records(stream):
        # Yield stored objects until the stream is exhausted.
        while True:
            try:
                yield pickle.load(stream)
            except EOFError:
                return

    combined = {}
    with open(pkl_path, "rb") as stream:
        for record in _records(stream):
            combined.update(record)  # each record is one {filename: output_subset}
    return combined


def _pool_text_hidden_state(hs, text_agg):
    """Pool one text hidden-state tensor over dim 1 and return it as a NumPy vector.

    Args:
        hs (torch.Tensor): value of ``text_data[key]["last_hidden_state"]``;
            presumably shaped [1, seq_len, l_dim] — TODO confirm against the extractor.
        text_agg (str): "mean" or "max" pooling over dim 1.

    Returns:
        numpy.ndarray: pooled vector with all size-1 dimensions squeezed out.

    Raises:
        ValueError: if ``text_agg`` is not "mean" or "max".
    """
    if text_agg == "mean":
        vec = hs.mean(dim=1)
    elif text_agg == "max":
        vec, _ = hs.max(dim=1)
    else:
        raise ValueError(f"Unknown text_agg {text_agg}")
    # detach() so tensors that still track gradients can be converted as well.
    return vec.squeeze().detach().cpu().numpy()


def _as_numpy(x):
    """Return ``x`` as a NumPy array, copying torch tensors (CPU or GPU) to host memory."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _batched_cls_tokens(feats):
    """Extract one vector per sample from a stacked image-feature array.

    (N, 1, seq_len, dim) -> drop axis 1, take token 0
    (N, seq_len, dim)    -> take token 0
    anything else        -> assumed to already hold one vector per sample
    """
    if feats.ndim == 4:
        return np.squeeze(feats, axis=1)[:, 0, :]
    if feats.ndim == 3:
        return feats[:, 0, :]
    return feats


def _single_cls_token(feat):
    """Extract the token-0 vector from one sample's image features.

    (1, seq_len, dim) -> drop batch axis, take token 0
    (seq_len, dim)    -> take token 0
    anything else     -> assumed to already be the vector
    """
    if feat.ndim == 3:
        return np.squeeze(feat, axis=0)[0]
    if feat.ndim == 2:
        return feat[0]
    return feat


def prepare_features(text_data, img_data, text_agg: str = "mean", direct_img_features: bool = False):
    """
    Prepare aligned features from text and image data.

    Text features are pooled over dim 1 of each entry's "last_hidden_state".
    Image features are reduced to token 0 (the CLS position) of each sample.

    Two pairing modes:
      * direct_img_features=False: text and image entries are paired by the
        6-character id ``filename[-10:-4]``. Only ids present in both inputs
        are kept, in sorted id order so results are reproducible.
      * direct_img_features=True: no matching. Both modalities are taken in
        their own iteration order and truncated to the shorter length.

    Args:
        text_data (dict): Dictionary of text features with filenames as keys
        img_data (dict, numpy.ndarray or torch.Tensor): Dictionary of image
            features with filenames as keys, or (direct mode only) a stacked
            array/tensor of image features
        text_agg (str): Aggregation method for text features ("mean" or "max")
        direct_img_features (bool): If True, img_data is used without matching

    Returns:
        t_feats: (N, l_dim) numpy array
        i_feats: (N, v_dim) numpy array

    Raises:
        ValueError: on an unknown ``text_agg``, an unsupported image container,
            empty inputs, or when no filenames match.
    """
    if text_agg not in ("mean", "max"):
        raise ValueError(f"Unknown text_agg {text_agg}")

    if direct_img_features:
        if isinstance(img_data, dict):
            if not img_data:
                raise ValueError("img_data is empty")
            stacked = np.stack([_as_numpy(v) for v in img_data.values()])
        elif isinstance(img_data, np.ndarray) or torch.is_tensor(img_data):
            stacked = _as_numpy(img_data)
        else:
            raise ValueError(f"Unsupported image data type: {type(img_data)}")
        # The same extraction rule applies no matter which container the features came in.
        i_feats = _batched_cls_tokens(stacked)

        if not text_data:
            raise ValueError("text_data is empty")
        t_feats = np.stack(
            [_pool_text_hidden_state(text_data[key]["last_hidden_state"], text_agg) for key in text_data],
            axis=0,
        )  # (N, l_dim)

        # Ensure same number of samples
        min_samples = min(t_feats.shape[0], i_feats.shape[0])
        t_feats = t_feats[:min_samples]
        i_feats = i_feats[:min_samples]

        print(f"Using first {min_samples} samples from both modalities")
        print(f"Image feature shape: {i_feats.shape}, Text feature shape: {t_feats.shape}")

    else:
        # Match by the 6 characters before a 4-char extension (e.g. "..._000123.pkl").
        # NOTE(review): filenames sharing an id silently overwrite each other here.
        norm_txt = {fn[-10:-4]: fn for fn in text_data}
        norm_img = {fn[-10:-4]: fn for fn in img_data}
        # Sorted, because set iteration order for strings changes between runs.
        common = sorted(set(norm_txt) & set(norm_img))
        if not common:
            raise ValueError("No matching samples between text_data and img_data")

        t_feats = np.stack(
            [_pool_text_hidden_state(text_data[norm_txt[k]]["last_hidden_state"], text_agg) for k in common],
            axis=0,
        )  # (N, l_dim)
        i_feats = np.stack(
            [_single_cls_token(_as_numpy(img_data[norm_img[k]])) for k in common],
            axis=0,
        )  # (N, v_dim)

        print(f"Matched {len(common)} samples")
        print(f"Image feature shape: {i_feats.shape}, Text feature shape: {t_feats.shape}")

    return t_feats, i_feats


def load_and_prepare_features(text_data_path, img_data_path, text_agg="mean", direct_img_features=False):
    """
    Read the text and image feature pickles and return aligned float32 tensors.

    Args:
        text_data_path (str): Path to the text features pickle file
        img_data_path (str): Path to the image features pickle file
        text_agg (str): Aggregation method for text features ("mean" or "max")
        direct_img_features (bool): If True, image features are used directly without filename matching

    Returns:
        tuple: (text_features, image_features) as torch float32 tensors
    """
    raw_text = load_incremental_pickle(text_data_path)
    raw_img = load_incremental_pickle(img_data_path)

    text_np, img_np = prepare_features(raw_text, raw_img, text_agg, direct_img_features)

    # from_numpy shares memory with the array; .float() then casts to float32.
    return tuple(torch.from_numpy(arr).float() for arr in (text_np, img_np))


def compute_all_metrics(text_features, image_features, topk=10, cca_dim=10,
                        pca_n_components=50, pca_threshold=0.95):
    """
    Compute every metric in ``AlignmentMetrics.SUPPORTED_METRICS`` between
    text and image features.

    Both inputs are first L2-normalized along the last dimension. A metric that
    raises is reported with a warning and left out of the result, so one failing
    metric does not stop the others.

    Args:
        text_features (torch.Tensor): Text features of shape (N, l_dim)
        image_features (torch.Tensor): Image features of shape (N, v_dim)
        topk (int): Number of nearest neighbors for kNN-based metrics; passed
            to "cknna" and to every metric whose name ends in "_knn"
        cca_dim (int): Passed as ``cca_dim`` to "svcca"
        pca_n_components (int): Passed as ``n_components`` to
            "cca_linear_pca" / "cca_kernel_pca"
        pca_threshold (float): Passed as ``threshold`` to
            "cca_linear_pca" / "cca_kernel_pca"

    Returns:
        dict: metric name -> score. Single-element torch tensors are converted
        to Python scalars so they print and serialize to CSV cleanly.
    """
    # Normalize features
    text_features = F.normalize(text_features, dim=-1)
    image_features = F.normalize(image_features, dim=-1)

    metrics = {}

    for metric_name in AlignmentMetrics.SUPPORTED_METRICS:
        if metric_name == "cknna" or metric_name.endswith("_knn"):
            # Every kNN-style metric needs the neighbourhood size, not only cknna.
            kwargs = {"topk": topk}
        elif metric_name == "svcca":
            kwargs = {"cca_dim": cca_dim}
        elif metric_name in ("cca_linear_pca", "cca_kernel_pca"):
            kwargs = {"n_components": pca_n_components, "threshold": pca_threshold}
        else:
            kwargs = {}  # e.g. cka: no extra kwargs

        try:
            score = AlignmentMetrics.measure(metric_name, text_features, image_features, **kwargs)
        except Exception as e:
            # Best-effort: report and continue with the remaining metrics.
            print(f"Warning: Failed to compute {metric_name}: {str(e)}")
            continue

        if torch.is_tensor(score) and score.numel() == 1:
            # Otherwise the CSV cell would contain the string "tensor(...)".
            score = score.item()
        metrics[metric_name] = score

    return metrics


def print_metrics(metrics):
    """
    Print metrics as a ruled table, one "Label: value" line per metric.

    Metric names are shown with underscores replaced by spaces and in title
    case, padded to 30 characters. Values are printed with 4 decimal places.

    Args:
        metrics (dict): Dictionary of metric names and values
    """
    rule = "-" * 50
    print("\nAlignment Metrics:")
    print(rule)
    for name, value in metrics.items():
        label = name.replace("_", " ").title()
        print(f"{label:30s}: {value:.4f}")
    print(rule)


def save_metrics_to_csv(metrics, output_path):
    """
    Save metrics as a single-row CSV file with one column per metric.

    Args:
        metrics (dict): Dictionary of metric names and values
        output_path (str): Path to save the CSV file; its parent directory is
            created if it does not exist
    """
    df = pd.DataFrame([metrics])

    parent = os.path.dirname(output_path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)

    df.to_csv(output_path, index=False)
    print(f"Results saved to: {output_path}")


def compute_platonic_alignment_metrics(text_data_path, img_data_path, output_dir="results", text_agg="mean",
                                       direct_img_features=False, topk=10, cca_dim=10):
    """
    Compute alignment metrics for one text-feature file against the image
    features, save them to CSV, and print a summary.

    The CSV is written to ``<output_dir>/<text file stem>_alignment_test.csv``.

    Args:
        text_data_path (str): Path to the text features pickle file
        img_data_path (str): Path to the image features pickle file
        output_dir (str): Directory to save results
        text_agg (str): Aggregation method for text features
        direct_img_features (bool): If True, image features are loaded directly without matching (i.e., testing on easy non-matching case)
        topk (int): Neighbourhood size forwarded to compute_all_metrics
        cca_dim (int): CCA dimension forwarded to compute_all_metrics

    Returns:
        dict: The computed metrics (metric name -> score)
    """
    text_feats, img_feats = load_and_prepare_features(text_data_path, img_data_path, text_agg, direct_img_features)

    metrics = compute_all_metrics(text_feats, img_feats, topk=topk, cca_dim=cca_dim)

    lang_name = Path(text_data_path).stem
    output_path = os.path.join(output_dir, f"{lang_name}_alignment_test.csv")
    save_metrics_to_csv(metrics, output_path)

    print(f"\nResults for {lang_name}:")
    print_metrics(metrics)
    print("Feature dimensions:")
    print(f"Text features: {text_feats.shape}")
    print(f"Image features: {img_feats.shape}")

    # Returned so the function can also be used programmatically, not only for its output.
    return metrics


def main():
    """
    Command-line entry point: compute alignment metrics for each text file.

    Each ``--text_files`` entry is processed against the single ``--img_file``.
    When one file fails, the error is reported on stderr and the remaining
    files are still processed. The process exits with status 1 if any file
    failed, so shell pipelines and schedulers can detect partial failures.
    """
    import sys

    parser = argparse.ArgumentParser(description="Compute alignment metrics between text and image features")
    parser.add_argument("--text_files", nargs="+", required=True,
                        help="Paths to text feature pickle files")
    parser.add_argument("--img_file", required=True,
                        help="Path to image feature pickle file")
    parser.add_argument("--output_dir", default="results",
                        help="Directory to save results (default: results)")
    parser.add_argument("--text_agg", choices=["mean", "max"], default="mean",
                        help="Text feature aggregation method (default: mean)")
    parser.add_argument("--direct_img_features", action="store_true",
                        help="Use image features directly without matching")

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    failed = []
    for text_file in args.text_files:
        try:
            compute_platonic_alignment_metrics(
                text_file,
                args.img_file,
                args.output_dir,
                args.text_agg,
                direct_img_features=args.direct_img_features
            )
        except Exception as e:
            # Best-effort batch: report this file and move on to the next one.
            print(f"Error processing {text_file}: {str(e)}", file=sys.stderr)
            failed.append(text_file)

    if failed:
        print(f"{len(failed)} of {len(args.text_files)} text file(s) failed", file=sys.stderr)
        sys.exit(1)

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main() 