import torch
from transformers import AutoTokenizer
from transformer_lens import HookedTransformer
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import types
import warnings
import einops
from transformer_lens.loading_from_pretrained import OFFICIAL_MODEL_NAMES, MODEL_ALIASES, make_model_alias_map
import random

# Silence the "torch_dtype is deprecated" warning message.
warnings.filterwarnings("ignore", message="torch_dtype is deprecated")

# ========== Configuration Parameters ==========
# MODEL_PATH_1 is also the location the shared tokenizer is loaded from.
MODEL_PATH_1 = r""  # Model 1 path
MODEL_ALIAS_1 = "pythia14m_trained"

MODEL_PATH_2 = r""  # Model 2 path
MODEL_ALIAS_2 = "pythia14m_baseline"

# Use the GPU whenever torch reports one as available.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
OUTPUT_DIR = r""  # Directory that receives the heatmaps and the text summary

NUM_SENTENCES = 100  # Size of the random test-sentence pool
NUM_REPEATS = 2  # How many times each tokenized sentence is repeated back to back
# NOTE(review): the three settings below are not referenced in the visible
# part of this script.
NUM_TOP_HEADS_TO_VERIFY = 3
NUM_VERIFICATION_SENTENCES = 20
NUM_DETAIL_EXAMPLES_PER_HEAD = 1
# ============================================

# Startup banner: one print call, one argument per output line.
print(
    "=" * 60,
    "Dual Model Comparison - Induction Head Detection",
    "=" * 60,
    f"Device: {DEVICE}",
    f"Output directory: {OUTPUT_DIR}",
    f"Number of test sentences: {NUM_SENTENCES}",
    f"Repetitions per sentence: {NUM_REPEATS}",
    f"Model 1: {MODEL_ALIAS_1}",
    f"Model 2: {MODEL_ALIAS_2}",
    sep="\n",
)

os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"✓ Output directory created/verified: {OUTPUT_DIR}\n")

# Step 0: Generate test sentences
print("Step 0: Generating test sentence pool...")

def generate_random_sentences(num_sentences=100, seed=42):
    """Build a reproducible list of grammar-free pseudo-sentences.

    Every sentence is made of 8 to 20 words picked with ``random.choice``
    from a fixed vocabulary, joined by single spaces; the first word goes
    through ``str.capitalize`` and a period is appended.

    Side effect: reseeds the global ``random`` and ``numpy.random``
    generators with ``seed``.

    Args:
        num_sentences: How many sentences to produce.
        seed: Seed applied to both global generators before sampling.

    Returns:
        A list of ``num_sentences`` strings.
    """
    random.seed(seed)
    np.random.seed(seed)

    # The vocabulary order matters: random.choice indexes into this list, so
    # reordering would change the generated sentences for a given seed.
    nouns = """
        apple book car dog elephant flower guitar hat
        ice juice kite lamp moon notebook ocean pencil
        queen robot star table umbrella violin window box
        cloud door egg fire grass hill island jar
        key lion mirror nest orange piano ring sock
        tree vase wall yarn zebra ball cat desk
    """.split()
    verbs = """
        jumps runs sleeps eats drinks flies swims walks
        talks sings dances reads writes plays thinks looks
        listens feels smells tastes touches sees hears knows
        likes wants loves needs makes takes gives finds
    """.split()
    adjectives = """
        big small happy sad red blue green yellow
        old new hot cold fast slow high low
        bright dark soft hard smooth rough clean dirty
        pretty ugly long short wide narrow deep shallow
    """.split()
    adverbs = """
        quickly slowly happily sadly loudly quietly carefully
        suddenly always never sometimes often rarely usually
        very really quite almost nearly completely totally
    """.split()
    # Prepositions / conjunctions / articles
    connectives = """
        on in under over beside behind near and but
        because so when if the a an with from to
    """.split()
    pronouns = "he she it they we I you him her them".split()

    vocabulary = nouns + verbs + adjectives + adverbs + connectives + pronouns

    def _make_sentence():
        # Draw the length first, then the words, in that order.
        word_count = random.randint(8, 20)
        picked = [random.choice(vocabulary) for _ in range(word_count)]
        picked[0] = picked[0].capitalize()
        return " ".join(picked) + "."

    return [_make_sentence() for _ in range(num_sentences)]

# Shared pool of random sentences that both models are scored on.
TEST_SENTENCES = generate_random_sentences(NUM_SENTENCES)

print(
    f"✓ Generated {len(TEST_SENTENCES)} completely random test sentences",
    "  Example sentences:",
    sep="\n",
)
# Preview at most the first five sentences, numbered from 1.
for i, example in enumerate(TEST_SENTENCES[:5]):
    print(f"    {i + 1}. {example}")
print()

# Step 1: Load tokenizer (shared by both models)
print("Step 1: Loading tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH_1)
    # Reuse the EOS token for padding when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("✅ Tokenizer loaded successfully.")
except Exception as load_error:
    # Report the failure, then let the original exception propagate.
    print(f"❌ Failed to load tokenizer: {load_error}")
    raise

# Tokenizer method definitions
def to_tokens_tinystories(self, text: str, prepend_bos: bool = False, move_to_device: bool = True) -> torch.Tensor:
    """Encode ``text`` with ``self.tokenizer`` and return the token tensor.

    Written to be bound onto a model object as a replacement ``to_tokens``.

    Args:
        text: A string to encode. Any non-``str`` value is treated as an
            iterable of strings: each element is encoded on its own and the
            results are concatenated along the last dimension.
        prepend_bos: Forwarded to the tokenizer as ``add_special_tokens``.
        move_to_device: When true, the result is moved to ``self.cfg.device``.

    Returns:
        The tensor produced by ``tokenizer.encode(..., return_tensors="pt")``
        (presumably shaped ``(1, seq_len)`` — callers in this file read
        ``shape[1]`` as the token count).
    """
    def _encode(piece):
        return self.tokenizer.encode(piece, return_tensors="pt", add_special_tokens=prepend_bos)

    if not isinstance(text, str):
        token_ids = torch.cat([_encode(piece) for piece in text], dim=-1)
    else:
        token_ids = _encode(text)

    return token_ids.to(self.cfg.device) if move_to_device else token_ids

def to_str_tokens_tinystories(self, tokens: torch.Tensor, prepend_bos: bool = False) -> list[str]:
    """Decode each token id in ``tokens`` to its own string.

    Written to be bound onto a model object as a replacement
    ``to_str_tokens``.

    Args:
        tokens: Token-id tensor. A 2-D tensor is squeezed on dim 0 first, so
            a ``(1, seq_len)`` input is handled like a 1-D one.
        prepend_bos: Accepted for signature compatibility; not used here.

    Returns:
        One decoded string per token id, special tokens included.
    """
    flat_ids = tokens.squeeze(0) if tokens.dim() == 2 else tokens

    pieces = []
    for token_id in flat_ids.tolist():
        # Decode ids one at a time so every token maps to exactly one string.
        pieces.append(self.tokenizer.decode([token_id], skip_special_tokens=False))
    return pieces

# Step 2: Load both models
print("Step 2: Loading two HookedTransformer models...")

def load_model(model_path, model_alias, device):
    """Load one HookedTransformer and attach the shared tokenizer helpers.

    ``model_path`` is added to TransformerLens' ``OFFICIAL_MODEL_NAMES`` (if
    absent) and mapped to ``[model_alias]`` in ``MODEL_ALIASES`` before the
    model is loaded by alias — presumably so that the alias resolves to the
    custom path.

    Args:
        model_path: Path/name registered as an official model name.
        model_alias: Alias passed to ``from_pretrained_no_processing``.
        device: Device passed through to the loader.

    Returns:
        The model in eval mode, with ``tokenizer`` set to the module-level
        tokenizer and ``to_tokens`` / ``to_str_tokens`` rebound to the
        ``*_tinystories`` helpers.

    Raises:
        Exception: Whatever loading or setup raises; it is printed and
            re-raised unchanged.
    """
    if model_path not in OFFICIAL_MODEL_NAMES:
        OFFICIAL_MODEL_NAMES.append(model_path)
    MODEL_ALIASES[model_path] = [model_alias]
    # NOTE(review): the return value is discarded — confirm this call is
    # needed for a side effect.
    make_model_alias_map()

    try:
        hooked = HookedTransformer.from_pretrained_no_processing(
            model_alias,
            device=device,
            dtype=torch.float32,
            n_devices=1
        )
        hooked.eval()

        # Swap in the shared tokenizer and the matching tokenization methods.
        hooked.tokenizer = tokenizer
        for method_name, implementation in (
            ("to_tokens", to_tokens_tinystories),
            ("to_str_tokens", to_str_tokens_tinystories),
        ):
            setattr(hooked, method_name, types.MethodType(implementation, hooked))
    except Exception as load_error:
        print(f"❌ Model loading failed: {load_error}")
        raise

    return hooked

# Load each model and report its layer/head counts.
model_1 = load_model(MODEL_PATH_1, MODEL_ALIAS_1, DEVICE)
print(
    f"✓ Model 1 loaded successfully: {MODEL_ALIAS_1}",
    f"  - Layers: {model_1.cfg.n_layers}, n_heads: {model_1.cfg.n_heads}",
    sep="\n",
)

model_2 = load_model(MODEL_PATH_2, MODEL_ALIAS_2, DEVICE)
print(
    f"✓ Model 2 loaded successfully: {MODEL_ALIAS_2}",
    f"  - Layers: {model_2.cfg.n_layers}, n_heads: {model_2.cfg.n_heads}\n",
    sep="\n",
)

# Step 3: Define computation functions
def compute_induction_scores_diagonal_offset_1(cache, model, block_len: int, num_repeats: int):
    """Score every head by its attention on the offset-1 diagonal between repeats.

    For each layer the pattern ``cache["blocks.{layer}.attn.hook_pattern"][0]``
    is read (index 0 is presumably the batch entry; the remaining axes are
    treated as ``(head, query_pos, key_pos)``). For every repeat after the
    first, the sub-block "queries in this repeat x keys in the previous
    repeat" is sliced out and the mean of its ``offset=1`` diagonal is taken
    per head, i.e. the attention from query ``i`` to key ``i + 1`` of the
    previous copy. The per-repeat means are then averaged.

    Args:
        cache: Mapping from hook name to attention-pattern tensor.
        model: Object exposing ``cfg.n_layers``, ``cfg.n_heads``, ``cfg.device``.
        block_len: Token length of one copy of the sentence.
        num_repeats: Number of copies in the sequence.

    Returns:
        Tensor of shape ``(n_layers, n_heads)`` on ``model.cfg.device``;
        layers with no usable block keep a score of zero.
    """
    n_layers = model.cfg.n_layers
    scores = torch.zeros(n_layers, model.cfg.n_heads, device=model.cfg.device)

    for layer in range(n_layers):
        pattern = cache[f"blocks.{layer}.attn.hook_pattern"][0]
        per_repeat_means = []

        for rep in range(1, num_repeats):
            # Queries: current copy (clipped to the sequence); keys: previous copy.
            q_lo = rep * block_len
            q_hi = min(q_lo + block_len, pattern.shape[1])
            k_lo = q_lo - block_len
            k_hi = q_lo

            if q_hi - q_lo <= 0 or k_hi - k_lo <= 0:
                continue

            block = pattern[:, q_lo:q_hi, k_lo:k_hi]
            if block.shape[1] <= 1 or block.shape[2] <= 0:
                continue

            stripe = block.diagonal(offset=1, dim1=-2, dim2=-1)
            if stripe.numel() == 0:
                continue

            per_repeat_means.append(stripe.mean(dim=-1))

        if per_repeat_means:
            scores[layer] = torch.stack(per_repeat_means).mean(dim=0)

    return scores

def compute_all_scores(model, sentences, num_repeats, model_name="Model"):
    """Compute per-sentence induction scores for one model.

    Each sentence is tokenized, repeated ``num_repeats`` times along the
    sequence axis, run through the model while caching only hooks whose name
    contains "pattern", and scored with
    ``compute_induction_scores_diagonal_offset_1``.

    Sentences shorter than two tokens are skipped with a warning. Any
    exception while processing a sentence is reported and that sentence is
    skipped (best-effort processing).

    Args:
        model: Model exposing ``to_tokens`` and ``run_with_cache``.
        sentences: Iterable of sentence strings.
        num_repeats: Number of back-to-back copies of each sentence.
        model_name: Label shown in the progress bar.

    Returns:
        ``(all_scores, sentence_lengths)``: a list of CPU score tensors and
        the list of single-copy token lengths. The two lists are
        index-aligned: a length is recorded only for sentences whose score
        was recorded.
    """
    all_scores = []
    sentence_lengths = []

    for idx, sentence in enumerate(tqdm(sentences, desc=f"Processing {model_name}")):
        try:
            single_tokens = model.to_tokens(sentence, prepend_bos=False)
            block_len = single_tokens.shape[1]

            if block_len < 2:
                print(f"\nWarning: Sentence '{sentence}' has token length too short ({block_len}), skipping.")
                continue

            tokens = single_tokens.repeat(1, num_repeats)

            with torch.no_grad():
                logits, cache = model.run_with_cache(
                    tokens,
                    return_type="logits",
                    names_filter=lambda name: "pattern" in name
                )

            scores = compute_induction_scores_diagonal_offset_1(cache, model, block_len, num_repeats)

            # Record the score and its length together, only after scoring
            # succeeded, so a failing sentence cannot leave a stray length
            # behind and skew the reported average sentence length.
            all_scores.append(scores.cpu())
            sentence_lengths.append(block_len)

            # Release the cached activations before the next sentence.
            del cache, logits
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        except Exception as e:
            # Deliberate best-effort: report and move on to the next sentence.
            print(f"\nWarning: Error processing sentence {idx}: {e}")
            continue

    return all_scores, sentence_lengths

# Step 4: Compute scores for both models
print(f"\nStep 4a: Computing induction scores for Model 1 ({MODEL_ALIAS_1})...")
all_scores_1, sentence_lengths = compute_all_scores(model_1, TEST_SENTENCES, NUM_REPEATS, MODEL_ALIAS_1)

if not all_scores_1:
    print("\nError: No valid sentences for Model 1 computation.")
    # exit() is an interactive helper injected by the site module and exited
    # with status 0; raise SystemExit with a non-zero status instead.
    raise SystemExit(1)

# Stack once: shape is (num_sentences, n_layers, n_heads); reduce over sentences.
stacked_scores_1 = torch.stack(all_scores_1)
averaged_scores_1 = stacked_scores_1.mean(dim=0)
# std over a single sample is undefined, so fall back to zeros in that case.
std_scores_1 = stacked_scores_1.std(dim=0) if len(all_scores_1) > 1 else torch.zeros_like(averaged_scores_1)

print(f"\nStep 4b: Computing induction scores for Model 2 ({MODEL_ALIAS_2})...")
all_scores_2, _ = compute_all_scores(model_2, TEST_SENTENCES, NUM_REPEATS, MODEL_ALIAS_2)

if not all_scores_2:
    print("\nError: No valid sentences for Model 2 computation.")
    raise SystemExit(1)

stacked_scores_2 = torch.stack(all_scores_2)
averaged_scores_2 = stacked_scores_2.mean(dim=0)
std_scores_2 = stacked_scores_2.std(dim=0) if len(all_scores_2) > 1 else torch.zeros_like(averaged_scores_2)

print(f"\n✓ Successfully processed {len(all_scores_1)} sentences")
print(f"  Average sentence length: {np.mean(sentence_lengths):.1f} tokens")

# Step 5: Compute differences
print("\nStep 5: Computing differences between the two models...")
diff_scores = averaged_scores_1 - averaged_scores_2
# Per-head spread of the difference: both models' std combined in quadrature.
diff_std = (std_scores_1.pow(2) + std_scores_2.pow(2)).sqrt()

# Each lookup below takes an arg-extreme over the flattened (layer, head)
# grid and splits the flat index back into (layer, head) with divmod.

# Find head with maximum absolute difference
flat_idx_max = torch.argmax(diff_scores.abs().flatten()).item()
max_diff_layer, max_diff_head = divmod(flat_idx_max, model_1.cfg.n_heads)
max_diff_value = diff_scores[max_diff_layer, max_diff_head].item()

# Find head with maximum positive difference (Model 1 stronger)
flat_idx_pos = torch.argmax(diff_scores.flatten()).item()
pos_diff_layer, pos_diff_head = divmod(flat_idx_pos, model_1.cfg.n_heads)
pos_diff_value = diff_scores[pos_diff_layer, pos_diff_head].item()

# Find head with maximum negative difference (Model 2 stronger)
flat_idx_neg = torch.argmin(diff_scores.flatten()).item()
neg_diff_layer, neg_diff_head = divmod(flat_idx_neg, model_1.cfg.n_heads)
neg_diff_value = diff_scores[neg_diff_layer, neg_diff_head].item()

print(
    "\n" + "=" * 60,
    f"Dual Model Comparison Results (Based on {len(all_scores_1)} sentences)",
    "=" * 60,
    sep="\n",
)

# The three report sections share one layout; only heading and head differ.
for heading, sel_layer, sel_head, sel_diff in (
    ("\nHead with maximum absolute difference:", max_diff_layer, max_diff_head, max_diff_value),
    ("\nHead where Model 1 is significantly stronger:", pos_diff_layer, pos_diff_head, pos_diff_value),
    ("\nHead where Model 2 is significantly stronger:", neg_diff_layer, neg_diff_head, neg_diff_value),
):
    print(
        heading,
        f"  Layer {sel_layer}, Head {sel_head}",
        f"  Difference: {sel_diff:+.3f}",
        f"  Model 1 score: {averaged_scores_1[sel_layer, sel_head]:.3f}",
        f"  Model 2 score: {averaged_scores_2[sel_layer, sel_head]:.3f}",
        sep="\n",
    )

print("\nMaximum difference per layer (Model 1 - Model 2):")
for layer in range(model_1.cfg.n_layers):
    layer_diff = diff_scores[layer]
    # Head with the largest |difference| within this layer.
    layer_max_head = torch.argmax(layer_diff.abs()).item()
    layer_diff_value = layer_diff[layer_max_head].item()
    m1_score = averaged_scores_1[layer, layer_max_head]
    m2_score = averaged_scores_2[layer, layer_max_head]
    print(
        f"  Layer {layer}: Head {layer_max_head}, Diff={layer_diff_value:+.3f} "
        f"(M1={m1_score:.3f}, M2={m2_score:.3f})"
    )
print("=" * 60 + "\n")

# Step 6: Generate comparison heatmaps
print("Step 6: Generating comparison heatmaps...")

# Symmetric color limit for the difference plots, so 0 sits at the center.
max_abs_diff = diff_scores.abs().max().item()


def _draw_difference_heatmap(axis, **extra_kwargs):
    # Draw the (Model 1 - Model 2) grid on `axis` with a 0-centered diverging colormap.
    sns.heatmap(diff_scores.numpy(), cmap='RdBu_r', ax=axis, cbar=True,
                xticklabels=range(model_1.cfg.n_heads),
                yticklabels=range(model_1.cfg.n_layers),
                annot=True, fmt='+.2f',
                cbar_kws={'label': 'Score Difference (Model1 - Model2)'},
                linewidths=.5, linecolor='lightgray',
                vmin=-max_abs_diff, vmax=max_abs_diff, center=0,
                **extra_kwargs)


# 6.1 Three-in-one heatmap: Model 1, Model 2, Difference
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(30, 8))

# Determine unified colorbar range (both model panels share the same vmax)
vmax_scores = max(averaged_scores_1.max().item(), averaged_scores_2.max().item())

# Model 1 and Model 2 panels: same styling, different data and title
for panel_ax, panel_scores, panel_model, panel_title in (
    (ax1, averaged_scores_1, model_1,
     f"Model 1: {MODEL_ALIAS_1}\n(N={len(all_scores_1)} sentences)"),
    (ax2, averaged_scores_2, model_2,
     f"Model 2: {MODEL_ALIAS_2}\n(N={len(all_scores_2)} sentences)"),
):
    sns.heatmap(panel_scores.numpy(), cmap='viridis', ax=panel_ax, cbar=True,
                xticklabels=range(panel_model.cfg.n_heads),
                yticklabels=range(panel_model.cfg.n_layers),
                annot=True, fmt='.2f', cbar_kws={'label': 'Induction Score'},
                linewidths=.5, linecolor='lightgray', vmin=0, vmax=vmax_scores)
    panel_ax.set_title(panel_title, fontsize=14, fontweight='bold')
    panel_ax.set_xlabel("Head", fontsize=12)
    panel_ax.set_ylabel("Layer", fontsize=12)

# Difference panel
_draw_difference_heatmap(ax3)
ax3.set_title("Difference (Model 1 - Model 2)\nRed=Model1 Stronger, Blue=Model2 Stronger",
              fontsize=14, fontweight='bold')
ax3.set_xlabel("Head", fontsize=12)
ax3.set_ylabel("Layer", fontsize=12)

plt.tight_layout()
save_path = os.path.join(OUTPUT_DIR, "induction_scores_comparison.png")
plt.savefig(save_path, dpi=200, bbox_inches='tight')
plt.close()
print(f"  ✓ Comparison heatmap saved: {save_path}")

# 6.2 Standalone detailed difference heatmap
fig, ax = plt.subplots(figsize=(12, 10))
_draw_difference_heatmap(ax, annot_kws={'size': 9})
detail_title = (
    f"Induction Score Difference Heatmap\n{MODEL_ALIAS_1} - {MODEL_ALIAS_2}\n"
    f"(Red: Model1 Stronger, Blue: Model2 Stronger, White: Similar)\nN={len(all_scores_1)} sentences"
)
ax.set_title(detail_title, fontsize=14, fontweight='bold')
ax.set_xlabel("Head", fontsize=12)
ax.set_ylabel("Layer", fontsize=12)

plt.tight_layout()
save_path = os.path.join(OUTPUT_DIR, "induction_diff_detailed.png")
plt.savefig(save_path, dpi=200, bbox_inches='tight')
plt.close()
print(f"  ✓ Detailed difference heatmap saved: {save_path}")

# Step 7: Save results summary
results_path = os.path.join(OUTPUT_DIR, "comparison_summary.txt")
summary_rule = "=" * 60
with open(results_path, "w", encoding="utf-8") as f:
    f.write("Dual Model Induction Head Comparison Results\n")
    f.write(summary_rule + "\n")
    f.write(f"Model 1: {MODEL_ALIAS_1}\n")
    f.write(f"Model 2: {MODEL_ALIAS_2}\n")
    f.write(f"Number of test sentences: {len(all_scores_1)}\n")
    f.write(f"Repetitions per sentence: {NUM_REPEATS}\n")
    f.write(f"Average sentence length: {np.mean(sentence_lengths):.1f} tokens\n\n")

    f.write(summary_rule + "\n")
    f.write("Difference Statistics (Model 1 - Model 2)\n")
    f.write(summary_rule + "\n")
    f.write(f"Mean difference: {diff_scores.mean():.4f}\n")
    f.write(f"Std dev of difference: {diff_scores.std():.4f}\n")
    f.write(f"Maximum difference: {diff_scores.max():.4f}\n")
    f.write(f"Minimum difference: {diff_scores.min():.4f}\n\n")

    f.write("Head with maximum absolute difference:\n")
    f.write(f"  Layer {max_diff_layer}, Head {max_diff_head}\n")
    f.write(f"  Difference: {max_diff_value:+.3f}\n")
    f.write(f"  Model 1 score: {averaged_scores_1[max_diff_layer, max_diff_head]:.3f}\n")
    f.write(f"  Model 2 score: {averaged_scores_2[max_diff_layer, max_diff_head]:.3f}\n\n")

    # This section previously consisted of a lone, unlabeled "Difference:"
    # line; write the heading and the head location so it mirrors the
    # "Model 2 is significantly stronger" section below.
    f.write("Head where Model 1 is significantly stronger:\n")
    f.write(f"  Layer {pos_diff_layer}, Head {pos_diff_head}\n")
    f.write(f"  Difference: {pos_diff_value:+.3f}\n\n")

    f.write("Head where Model 2 is significantly stronger:\n")
    f.write(f"  Layer {neg_diff_layer}, Head {neg_diff_head}\n")
    f.write(f"  Difference: {neg_diff_value:+.3f}\n\n")

    f.write("Maximum difference per layer:\n")
    for layer in range(model_1.cfg.n_layers):
        layer_diff = diff_scores[layer]
        # Head with the largest |difference| within this layer.
        layer_max_head = layer_diff.abs().argmax().item()
        layer_diff_value = layer_diff[layer_max_head].item()
        f.write(f"  Layer {layer}: Head {layer_max_head}, Diff={layer_diff_value:+.3f}\n")

    f.write("\nFirst 10 example sentences:\n")
    for i, example_sentence in enumerate(TEST_SENTENCES[:10]):
        f.write(f"  {i + 1}. {example_sentence}\n")

print(f"  ✓ Comparison results summary saved: {results_path}")

print("\n" + "=" * 60)
print("Dual model comparison analysis complete!")
print("=" * 60)
# Report the paths the files were written to (built with os.path.join,
# exactly as when saving) rather than hard-coding a "/" separator.
print("Output files:")
print(f"  1. Three-in-one comparison heatmap: {os.path.join(OUTPUT_DIR, 'induction_scores_comparison.png')}")
print(f"  2. Detailed difference heatmap: {os.path.join(OUTPUT_DIR, 'induction_diff_detailed.png')}")
print(f"  3. Results summary: {results_path}")
print("=" * 60)



