"""
Main script for QK CKA Similarity Calculation.
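This script computes layer-wise CKA similarity of attention projection weights (Q and K, plus V and O
when the similarity backend reports them) between pairs of models. To compare models whose hidden
sizes differ, hidden dimensions are first matched through word-embedding alignment over the shared
vocabulary (an optimal one-to-one assignment of embedding columns, with sign correction).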
Usage:
    python run_analysis.py --config moe_pairs [--device {auto,cpu,cuda}]
"""

import argparse
import torch
import matplotlib.pyplot as plt
import os
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics.pairwise import cosine_similarity
import torch.nn.functional as F

# NOTE: runs at import time. Switching to this script's own directory makes
# relative paths used later (e.g. the "figures/qk_cka" output directory) resolve
# against the script location instead of the caller's working directory.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

from similarity_metrics import (
    load_all_weights_from_dir,
    get_attention_weights,
    calculate_attention_cka_similarities,
    generate_negative_sample,
    get_word_embedding_weight,
    load_vocab_from_dir,
    find_overlapping_vocab
)
from configs import get_config, AVAILABLE_CONFIGS, CHECKPOINT_BASE_DIR

# --- Plotting Style Configuration ---
TITLE_FONTSIZE = 36
LABEL_FONTSIZE = 32
TICK_FONTSIZE = 28
LEGEND_FONTSIZE = 28
FIG_SIZE = (18, 12)
# NOTE(review): not referenced anywhere in this script; kept for compatibility.
COLORS = ['green', 'blue', 'purple', 'red', 'orange', 'brown', 'cyan', 'magenta']

# Line style, color and legend label per metric key returned by the similarity
# backend. '*_permuted' metrics are drawn solid, '*_direct' metrics dashed.
# All '*_direct' metrics share the "(Aligned)" legend suffix so the same metric
# family is not shown under two different names (V/O previously read "(Direct)").
# Metrics missing from this table fall back to a solid black line in the plot.
METRIC_STYLES = {
    'q_permuted': {'color': 'blue', 'linestyle': '-', 'label': 'Q (Permuted)'},
    'k_permuted': {'color': 'green', 'linestyle': '-', 'label': 'K (Permuted)'},
    'v_permuted': {'color': 'cyan', 'linestyle': '-', 'label': 'V (Permuted)'},
    'o_permuted': {'color': 'magenta', 'linestyle': '-', 'label': 'O (Permuted)'},
    'q_direct': {'color': 'red', 'linestyle': '--', 'label': 'Q (Aligned)'},
    'k_direct': {'color': 'purple', 'linestyle': '--', 'label': 'K (Aligned)'},
    'v_direct': {'color': 'orange', 'linestyle': '--', 'label': 'V (Aligned)'},
    'o_direct': {'color': 'brown', 'linestyle': '--', 'label': 'O (Aligned)'},
}

def plot_layer_similarities(results, output_path, title):
    """Plot layer-wise CKA similarity curves for each model pair and save the figure.

    Args:
        results: iterable of result dicts. Each provides a 'label' (used as the
            legend prefix) and optionally 'layers', a mapping of metric name ->
            {layer number: score}. Metrics with empty layer data are skipped.
        output_path: destination file for the figure (saved at 300 dpi).
        title: figure title.
    """
    import contextlib

    print("\n--- Generating layer-wise similarity plot ---")

    # Styles affect artists when they are created, so the style must be active
    # while the figure is built. Using a context (instead of plt.style.use) also
    # keeps the style from leaking into the global matplotlib state. If the style
    # is not shipped with the installed matplotlib, fall back to a plain grid.
    style_name = 'seaborn-v0_8-whitegrid'
    has_style = style_name in plt.style.available
    style_ctx = plt.style.context(style_name) if has_style else contextlib.nullcontext()

    with style_ctx:
        fig, ax = plt.subplots(figsize=FIG_SIZE)
        try:
            if not has_style:
                ax.grid(True)

            for item in results:
                pair_label = item['label']
                for metric, layer_data in item.get('layers', {}).items():
                    if not layer_data:
                        continue
                    style = METRIC_STYLES.get(metric, {})
                    layers = sorted(layer_data.keys())
                    scores = [layer_data[layer] for layer in layers]
                    ax.plot(layers, scores, marker='o',
                            linestyle=style.get('linestyle', '-'),
                            label=f"{pair_label} - {style.get('label', metric)}",
                            color=style.get('color', 'black'))

            ax.set_xlabel('Layer Number', fontsize=LABEL_FONTSIZE)
            ax.set_ylabel('Attention CKA Similarity', fontsize=LABEL_FONTSIZE)
            ax.set_title(title, fontsize=TITLE_FONTSIZE, pad=25)
            ax.tick_params(axis='both', which='major', labelsize=TICK_FONTSIZE, pad=10)
            # Only draw a legend when something was plotted (avoids matplotlib's
            # "No artists with labels found" warning on empty results).
            if ax.get_legend_handles_labels()[0]:
                ax.legend(fontsize=LEGEND_FONTSIZE, loc='best')
            fig.tight_layout()
            fig.savefig(output_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if saving fails.
            plt.close(fig)
    print(f"Plot saved to: {output_path}")

def _load_models(model_paths):
    """Load attention weights, vocabulary and word embedding for each model.

    Args:
        model_paths: mapping of model name -> checkpoint folder (joined onto
            CHECKPOINT_BASE_DIR).

    Returns:
        Tuple ``(attention_weights, vocabs, embeddings)`` of dicts keyed by model
        name. Models whose state dict fails to load are reported and left out of
        all three dicts.
    """
    attention_weights, vocabs, embeddings = {}, {}, {}
    for model_name, model_folder in model_paths.items():
        print(f"Loading model: {model_name}")
        model_path = os.path.join(CHECKPOINT_BASE_DIR, model_folder)
        state_dict = load_all_weights_from_dir(model_path)
        if not state_dict:
            print(f"Warning: Failed to load model {model_name}")
            continue
        attention_weights[model_name] = get_attention_weights(state_dict)
        vocabs[model_name] = load_vocab_from_dir(model_path)
        embeddings[model_name] = get_word_embedding_weight(state_dict)
    return attention_weights, vocabs, embeddings


def _add_random_sample(attention_weights, vocabs, embeddings):
    """Register a "random" negative-control model in the three dicts (in place).

    The random attention weights are generated from the first loaded model with
    non-empty attention weights. That model's vocabulary and embedding are reused,
    so the random model goes through the same alignment path. If no reference
    model exists, "random" is registered with ``None`` entries so that pairs
    using it are skipped later.
    """
    print("\n--- Generating random negative sample ---")
    ref_model_name = next((name for name, weights in attention_weights.items() if weights), None)

    if ref_model_name:
        print(f"Using {ref_model_name} as reference to generate random sample")
        attention_weights["random"] = generate_negative_sample(attention_weights[ref_model_name])
        vocabs["random"] = vocabs.get(ref_model_name)
        embeddings["random"] = embeddings.get(ref_model_name)
    else:
        print("Warning: Cannot generate random sample, no suitable reference model found.")
        attention_weights["random"], vocabs["random"], embeddings["random"] = None, None, None


def _compute_dimension_alignment(vocab1, embedding1, vocab2, embedding2, device):
    """Match hidden dimensions of two models using their shared-vocabulary embeddings.

    The embedding rows of tokens present in both vocabularies are gathered, and
    hidden dimensions (columns) are compared by cosine similarity. A linear
    assignment on ``1 - |similarity|`` then pairs every dimension of the narrower
    ("target") model with a distinct dimension of the wider ("base") model.

    Returns:
        ``(subselect_indices, subselect_signs, base_model_is_first)``, where
        ``subselect_indices[j]`` is the base dimension matched to target dimension
        ``j``, and ``subselect_signs[j]`` is +1/-1, the sign of their similarity.
        Returns ``(None, None, None)`` when the vocabularies share no tokens.
    """
    _, indices1, indices2 = find_overlapping_vocab(vocab1, vocab2)
    if len(indices1) == 0:
        print("  Warning: No overlapping vocabulary, cannot perform alignment.")
        return None, None, None

    # index_select needs an integer index tensor on the embedding's device.
    idx1 = torch.as_tensor(indices1, dtype=torch.long, device=embedding1.device)
    idx2 = torch.as_tensor(indices2, dtype=torch.long, device=embedding2.device)
    # Cast to float32 first: numpy has no bfloat16, and .numpy() needs a CPU tensor.
    emb1 = torch.index_select(embedding1, 0, idx1).to(torch.float32).cpu().numpy()
    emb2 = torch.index_select(embedding2, 0, idx2).to(torch.float32).cpu().numpy()

    # 1. The model with the wider hidden size is the base; the other is the target.
    base_model_is_first = emb1.shape[1] >= emb2.shape[1]
    emb_base, emb_target = (emb1, emb2) if base_model_is_first else (emb2, emb1)
    print(f"  Base model embedding: {emb_base.shape}, Target model embedding: {emb_target.shape}")

    # 2. (d_base, d_target) column-wise cosine similarity, then LAP on 1 - |sim|.
    use_cuda = torch.cuda.is_available() and device == 'cuda'
    compute_device = torch.device("cuda" if use_cuda else "cpu")
    print(f"  Using device for alignment computation: {compute_device}")

    if use_cuda:
        # L2-normalise each hidden dimension over the shared tokens; the matrix
        # product of the normalised columns is then the cosine similarity.
        base_cols = F.normalize(torch.from_numpy(emb_base).to(compute_device).T, p=2, dim=1)
        target_cols = F.normalize(torch.from_numpy(emb_target).to(compute_device).T, p=2, dim=1)
        similarity_matrix = torch.mm(base_cols, target_cols.T).cpu().numpy()
    else:
        similarity_matrix = cosine_similarity(emb_base.T, emb_target.T)
    # Use |similarity| so that anti-correlated dimensions can still be matched;
    # their sign is recorded separately below.
    cost_matrix = 1 - np.abs(similarity_matrix)

    base_indices, target_indices = linear_sum_assignment(cost_matrix)

    # 3. With d_base >= d_target, every target column is assigned exactly once.
    # Reorder so that position j holds the base dimension matched to target dim j.
    order = np.argsort(target_indices)
    subselect_indices = base_indices[order]

    d_target = emb_target.shape[1]
    subselect_signs = np.sign(similarity_matrix[subselect_indices, np.arange(d_target)])
    # np.sign(0) == 0. A matched pair with exactly zero similarity (e.g. an
    # all-zero embedding column) would otherwise carry sign 0, which could zero
    # out that dimension if signs are applied multiplicatively; treat it as "no flip".
    subselect_signs[subselect_signs == 0] = 1

    print(f"  Found {len(subselect_indices)} matching dimensions, with {(subselect_signs == -1).sum()} needing sign flips.")
    return subselect_indices, subselect_signs, base_model_is_first


def _print_summary(config_name, qk_results):
    """Print each pair's average scores; pairs are sorted by their primary score (descending)."""
    print(f"\n--- {config_name} Experiment Results Summary ---")
    # Sort on 'avg_direct' if any result reports it, otherwise on 'q_direct'.
    sort_key = 'avg_direct' if any('avg_direct' in res['averages'] for res in qk_results) else 'q_direct'
    qk_results.sort(key=lambda x: x['averages'].get(sort_key, 0), reverse=True)
    for res in qk_results:
        print(f"--- {res['label']} ---")
        # Individual metrics first (alphabetically), then the 'avg_*' aggregates.
        sorted_averages = sorted(
            res['averages'].items(),
            key=lambda item: (item[0].startswith('avg_'), item[0])
        )
        for metric, avg_score in sorted_averages:
            print(f"  Average {metric:<15} CKA Similarity = {avg_score:.4f}")


def run_experiment(config_name, device='cpu'):
    """Run the attention CKA similarity experiment for one configuration.

    The experiment proceeds in these steps:

    1. Load every model listed in the config.
    2. Optionally synthesise a "random" negative-control model.
    3. For each analysis pair, align hidden dimensions via shared-vocabulary word
       embeddings and compute the CKA similarities.
    4. Save a layer-wise plot under ``figures/qk_cka/`` and print a summary.

    Args:
        config_name: name passed to ``get_config``. The returned dict provides
            "model_paths" (model name -> checkpoint folder) and "analysis_pairs"
            (dicts with "weights1_name", "weights2_name" and "label").
        device: 'cpu' or 'cuda'. It is forwarded to the similarity computation
            and, when CUDA is available, also used for the embedding alignment.
    """
    print(f"\n{'='*60}")
    print(f"Starting experiment: {config_name}")
    print(f"Using device: {device}")
    print(f"{'='*60}")

    config = get_config(config_name)
    model_paths = config["model_paths"]
    analysis_pairs = config["analysis_pairs"]

    print("\n--- Pre-loading model weights, vocabularies, and embeddings ---")
    all_attention_weights, all_vocabs, all_embeddings = _load_models(model_paths)

    referenced_models = {p["weights1_name"] for p in analysis_pairs} | {p["weights2_name"] for p in analysis_pairs}
    if "random" in referenced_models:
        _add_random_sample(all_attention_weights, all_vocabs, all_embeddings)

    print("\n--- Starting QK CKA Similarity Analysis ---")
    qk_results = []

    for pair in analysis_pairs:
        name1, name2, label = pair["weights1_name"], pair["weights2_name"], pair["label"]
        weights1, weights2 = all_attention_weights.get(name1), all_attention_weights.get(name2)

        if not all([weights1, weights2]):
            print(f"Skipping analysis for '{label}' - missing model weights.")
            continue

        vocab1, embedding1 = all_vocabs.get(name1), all_embeddings.get(name1)
        vocab2, embedding2 = all_vocabs.get(name2), all_embeddings.get(name2)

        # Without alignment, the similarity backend receives None for all three.
        subselect_indices = subselect_signs = base_model_is_first = None
        if all([vocab1, embedding1 is not None, vocab2, embedding2 is not None]):
            print(f"\nApplying unified dimension alignment for '{label}'...")
            subselect_indices, subselect_signs, base_model_is_first = _compute_dimension_alignment(
                vocab1, embedding1, vocab2, embedding2, device
            )
        else:
            print(f"Warning: Missing embeddings or vocabularies, skipping alignment for '{label}'.")

        print(f"\nAnalyzing pair: {label} ({name1} vs {name2})")
        result = calculate_attention_cka_similarities(
            weights1, weights2,
            device=device,
            subselect_indices=subselect_indices,
            subselect_signs=subselect_signs,
            base_model_is_first=base_model_is_first
        )

        if result and 'averages' in result and result['averages']:
            result['label'] = label
            qk_results.append(result)

    if qk_results:
        # Relative to the script directory (see the os.chdir at import time).
        output_dir = "figures/qk_cka"
        os.makedirs(output_dir, exist_ok=True)
        output_path = f"{output_dir}/qk_cka_similarity_{config_name}.png"
        title = f"Layer-wise Attention CKA Similarity ({config_name})"
        plot_layer_similarities(qk_results, output_path, title)
        _print_summary(config_name, qk_results)

    print(f"\n{'='*60}")
    print("Experiment finished!")
    print(f"{'='*60}")

def resolve_device(requested):
    """Map a ``--device`` choice ('auto', 'cpu' or 'cuda') to a concrete device name.

    'auto' selects 'cuda' when CUDA is available, otherwise 'cpu'. An explicit
    'cuda' request is honoured when CUDA is available. Otherwise the function
    falls back to 'cpu' with a warning rather than failing later.
    """
    cuda_available = torch.cuda.is_available()
    if requested == "auto":
        return "cuda" if cuda_available else "cpu"
    if requested == "cuda" and not cuda_available:
        print("Warning: CUDA was requested but is not available; falling back to CPU.")
        return "cpu"
    return requested


def main():
    """Parse command-line arguments and run the selected experiment."""
    parser = argparse.ArgumentParser(description="Run QK CKA Similarity Analysis Experiments")
    parser.add_argument(
        "--config",
        choices=list(AVAILABLE_CONFIGS.keys()),
        default="debug_pairs",
        help="Select the experiment configuration to run."
    )
    parser.add_argument(
        "--device",
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="Select the computation device ('auto' detects CUDA availability)."
    )

    args = parser.parse_args()

    # Previously an explicit `--device cuda` was silently mapped to 'cpu'.
    device_to_use = resolve_device(args.device)

    run_experiment(config_name=args.config, device=device_to_use)


if __name__ == "__main__":
    main()