#!/usr/bin/env python3
"""
Step 3: Evaluate and Compare Models

Compares:
- seq: Baseline (seq-only, trained from scratch)
- distilled: Distilled model (seq-only, learned from teacher)
- full: Teacher (full annotations)

Metrics:
- Perplexity on masked language modeling
- Accuracy on base prediction
- Hidden state similarity to teacher (planned; not computed by this script)
"""

import sys
sys.path.append('..')

import os
import json
import argparse
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
from tqdm import tqdm
import pandas as pd

from config import get_model_paths, get_model_config
from distill_config import get_distill_paths, DISTILL_CONFIG, DISTILLED_CONFIG

sys.path.append('../2_train')
from model import annDNA


class EvalDataset(Dataset):
    """Evaluation dataset backed by three NumPy arrays loaded from disk.

    The arrays are read from ``{data_dir}/{split}_input_ids.npy``,
    ``{data_dir}/{split}_labels.npy`` and
    ``{data_dir}/{split}_attention_mask.npy``.
    """

    def __init__(self, data_dir, split='val'):
        # All three files share the same "<dir>/<split>" stem.
        stem = f"{data_dir}/{split}"
        self.input_ids = np.load(stem + "_input_ids.npy")
        self.labels = np.load(stem + "_labels.npy")
        self.attention_mask = np.load(stem + "_attention_mask.npy")

        sample_count = len(self.input_ids)
        print(f"Loaded {sample_count:,} {split} samples")

    def __len__(self):
        # The dataset length is defined by the input_ids array alone.
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the sample at *idx* as a dict of int64 tensors."""

        def to_long(row):
            return torch.tensor(row, dtype=torch.long)

        sample = {}
        sample['input_ids'] = to_long(self.input_ids[idx])
        sample['labels'] = to_long(self.labels[idx])
        sample['attention_mask'] = to_long(self.attention_mask[idx])
        return sample


def load_model(model_name, device, checkpoint_path=None):
    """Build an ``annDNA`` model for *model_name* and load trained weights.

    Args:
        model_name: Key passed to ``get_model_config`` and ``get_model_paths``.
        device: Target device; also used as ``map_location`` for the checkpoint.
        checkpoint_path: Optional checkpoint to load. When falsy, the
            ``'best_model'`` entry of the model's paths is used instead.

    Returns:
        ``(model, vocab)`` with the model on *device* and in eval mode, and
        the vocabulary as parsed from the model's vocab JSON file.
    """
    cfg = get_model_config(model_name)
    model_paths = get_model_paths(model_name)

    with open(model_paths['vocab_file']) as vocab_fh:
        vocab = json.load(vocab_fh)

    vocab_size = len(vocab)
    arch = {
        key: cfg[key]
        for key in ('d_model', 'nhead', 'num_layers', 'max_seq_len')
    }
    net = annDNA(vocab_size=vocab_size, **arch)

    # An explicit checkpoint wins; otherwise use the model's default best checkpoint.
    weights_file = checkpoint_path if checkpoint_path else model_paths['best_model']
    state = torch.load(weights_file, map_location=device)

    net.load_state_dict(state['model_state_dict'])
    net = net.to(device)
    net.eval()

    return net, vocab


def load_distilled_model(device):
    """Build the distilled student model and load its best checkpoint.

    The architecture comes from ``DISTILLED_CONFIG``; the vocabulary comes
    from the vocab file of the model named by
    ``DISTILL_CONFIG['distilled_base']``; the weights come from
    ``best_model.pt`` under ``get_distill_paths()['distilled_model']``.

    Args:
        device: Target device; also used as ``map_location`` for the checkpoint.

    Returns:
        ``(model, vocab)`` with the model on *device* and in eval mode.
    """
    distill_paths = get_distill_paths()
    base_paths = get_model_paths(DISTILL_CONFIG['distilled_base'])

    with open(base_paths['vocab_file']) as vocab_fh:
        vocab = json.load(vocab_fh)

    # Imported here rather than at module level, as in the original layout.
    from train_distilled import DistilledModel

    vocab_size = len(vocab)
    arch = {
        key: DISTILLED_CONFIG[key]
        for key in ('d_model', 'nhead', 'num_layers', 'max_seq_len')
    }
    student = DistilledModel(vocab_size=vocab_size, **arch)

    weights_file = distill_paths['distilled_model'] / 'best_model.pt'
    state = torch.load(weights_file, map_location=device)
    student.load_state_dict(state['model_state_dict'])
    student = student.to(device)
    student.eval()

    return student, vocab


def evaluate_perplexity(model, dataloader, device, is_distilled=False):
    """Compute perplexity of *model* over the scored positions in *dataloader*.

    Cross-entropy is summed over every position whose label is not -100 and
    divided by the number of such positions; perplexity is ``exp`` of that
    mean loss.

    Args:
        model: Called as ``model(input_ids, attention_mask)``. Returns the
            logits directly, or a 2-tuple whose first element is the logits
            when ``is_distilled`` is true.
        dataloader: Iterable of batches, each a dict with ``'input_ids'``,
            ``'labels'`` and ``'attention_mask'`` tensors.
        device: Device the batch tensors are moved to before the forward pass.
        is_distilled: Set when the model returns ``(logits, extra)``.

    Returns:
        ``(perplexity, avg_loss)``. Both are NaN when the dataloader yields
        no position with a label other than -100 (including an empty
        dataloader), instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    # reduction='sum' so the mean can be taken over all batches at the end
    # rather than averaging per-batch means of unequal token counts.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='sum')

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Computing perplexity"):
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            if is_distilled:
                logits, _ = model(input_ids, attention_mask)
            else:
                logits = model(input_ids, attention_mask)

            # reshape (not view) so non-contiguous logits/labels also work.
            loss = criterion(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
            )

            total_loss += loss.item()
            total_tokens += (labels != -100).sum().item()

    if total_tokens == 0:
        # Nothing was scored: the metric is undefined, not a crash.
        undefined = float('nan')
        return undefined, undefined

    avg_loss = total_loss / total_tokens
    perplexity = np.exp(avg_loss)

    return perplexity, avg_loss


def evaluate_accuracy(model, dataloader, device, is_distilled=False):
    """Compute argmax accuracy of *model* over the scored positions.

    Only positions whose label is not -100 are counted; a position is
    correct when the argmax over the last logits dimension equals its label.

    Args:
        model: Called as ``model(input_ids, attention_mask)``. Returns the
            logits directly, or a 2-tuple whose first element is the logits
            when ``is_distilled`` is true.
        dataloader: Iterable of batches, each a dict with ``'input_ids'``,
            ``'labels'`` and ``'attention_mask'`` tensors.
        device: Device the batch tensors are moved to before the forward pass.
        is_distilled: Set when the model returns ``(logits, extra)``.

    Returns:
        Fraction of scored positions predicted correctly, or NaN when there
        are no scored positions (including an empty dataloader), instead of
        raising ``ZeroDivisionError``.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Computing accuracy"):
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            if is_distilled:
                logits, _ = model(input_ids, attention_mask)
            else:
                logits = model(input_ids, attention_mask)

            preds = logits.argmax(dim=-1)

            # Restrict both numerator and denominator to labelled positions.
            mask = labels != -100
            correct += ((preds == labels) & mask).sum().item()
            total += mask.sum().item()

    if total == 0:
        # No labelled positions: accuracy is undefined.
        return float('nan')

    return correct / total


def _make_val_loader(data_dir, batch_size):
    """Build an unshuffled DataLoader over the 'val' split stored in *data_dir*."""
    return DataLoader(
        EvalDataset(data_dir, 'val'),
        batch_size=batch_size,
        shuffle=False,
        num_workers=4
    )


def main():
    """Evaluate the baseline, distilled and teacher models and report results.

    Command-line options:
        --gpu: Value written to ``CUDA_VISIBLE_DEVICES`` (default ``'0'``).
        --batch_size: Evaluation batch size (default 32).

    For each model the perplexity, mean loss and accuracy are computed on its
    'val' split, printed, and written to ``evaluation_results.csv`` under
    ``get_distill_paths()['results']``. If the distilled model (or its data)
    raises ``FileNotFoundError`` it is skipped and the final comparison
    section is omitted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', default='0')
    parser.add_argument('--batch_size', type=int, default=32)
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    distill_paths = get_distill_paths()
    distill_paths['results'].mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print("Model Evaluation")
    print("=" * 60)

    results = []

    # =========================================================================
    # Evaluate seq (Baseline)
    # =========================================================================
    print("\n[1] Loading seq (baseline)...")
    seq_model, _ = load_model('seq', device)

    seq_paths = get_model_paths('seq')
    seq_loader = _make_val_loader(seq_paths['processed_dir'], args.batch_size)

    print("Evaluating seq...")
    seq_ppl, seq_loss = evaluate_perplexity(seq_model, seq_loader, device)
    seq_acc = evaluate_accuracy(seq_model, seq_loader, device)

    results.append({
        'model': 'seq (baseline)',
        'perplexity': seq_ppl,
        'loss': seq_loss,
        'accuracy': seq_acc,
    })
    print(f"  Perplexity: {seq_ppl:.4f}, Accuracy: {seq_acc:.4f}")

    # Drop the model before loading the next one to release its memory.
    del seq_model
    torch.cuda.empty_cache()

    # =========================================================================
    # Evaluate Distilled
    # =========================================================================
    print("\n[2] Loading Distilled...")
    try:
        distilled_model, _ = load_distilled_model(device)

        # Use distillation data for distilled evaluation.
        # NOTE(review): this split lives under distill_paths['teacher_hidden']
        # while the baseline uses seq_paths['processed_dir'] -- confirm both
        # hold the same samples, otherwise the comparison below is not
        # like-for-like.
        distilled_loader = _make_val_loader(
            distill_paths['teacher_hidden'], args.batch_size
        )

        print("Evaluating Distilled...")
        distilled_ppl, distilled_loss = evaluate_perplexity(
            distilled_model, distilled_loader, device, is_distilled=True
        )
        distilled_acc = evaluate_accuracy(
            distilled_model, distilled_loader, device, is_distilled=True
        )

        results.append({
            'model': 'Distilled',
            'perplexity': distilled_ppl,
            'loss': distilled_loss,
            'accuracy': distilled_acc,
        })
        print(f"  Perplexity: {distilled_ppl:.4f}, Accuracy: {distilled_acc:.4f}")

        del distilled_model
        torch.cuda.empty_cache()

    except FileNotFoundError as e:
        # A missing checkpoint or data file skips the distilled model only;
        # the baseline and teacher are still reported.
        print(f"  Distilled model not found: {e}")
        distilled_ppl, distilled_acc = None, None

    # =========================================================================
    # Evaluate full (Teacher)
    # =========================================================================
    print("\n[3] Loading full (teacher)...")
    full_model, _ = load_model('full', device)

    full_paths = get_model_paths('full')
    full_loader = _make_val_loader(full_paths['processed_dir'], args.batch_size)

    print("Evaluating full...")
    full_ppl, full_loss = evaluate_perplexity(full_model, full_loader, device)
    full_acc = evaluate_accuracy(full_model, full_loader, device)

    results.append({
        'model': 'full (teacher)',
        'perplexity': full_ppl,
        'loss': full_loss,
        'accuracy': full_acc,
    })
    print(f"  Perplexity: {full_ppl:.4f}, Accuracy: {full_acc:.4f}")

    # =========================================================================
    # Summary
    # =========================================================================
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)

    df = pd.DataFrame(results)
    print(df.to_string(index=False))

    # Save results
    output_file = distill_paths['results'] / 'evaluation_results.csv'
    df.to_csv(output_file, index=False)
    print(f"\nResults saved to: {output_file}")

    # =========================================================================
    # Comparison
    # =========================================================================
    if distilled_ppl is not None:
        print("\n" + "=" * 60)
        print("DISTILLATION EFFECT")
        print("=" * 60)

        # Positive values mean the distilled model is better than the baseline.
        ppl_vs_seq = (seq_ppl - distilled_ppl) / seq_ppl * 100
        # A baseline accuracy of exactly 0 makes the relative change undefined.
        if seq_acc:
            acc_vs_seq = (distilled_acc - seq_acc) / seq_acc * 100
        else:
            acc_vs_seq = float('nan')

        print("\nDistilled vs seq (Baseline):")
        print(f"  Perplexity: {distilled_ppl:.2f} vs {seq_ppl:.2f} ({ppl_vs_seq:+.1f}%)")
        print(f"  Accuracy:   {distilled_acc:.4f} vs {seq_acc:.4f} ({acc_vs_seq:+.1f}%)")

        if ppl_vs_seq > 0:
            print("\n  => Distillation SUCCESSFUL!")
            print("     Distilled outperforms baseline (lower perplexity)")
        else:
            print("\n  => Distillation needs tuning")
            print("     Try adjusting alpha or training longer")

        # How close to teacher? Share of the baseline->teacher accuracy gap
        # that the distilled model has recovered: 0% when it equals the
        # baseline, 100% when it matches the teacher.
        if full_acc != seq_acc:
            acc_gap_closed = (distilled_acc - seq_acc) / (full_acc - seq_acc) * 100
        else:
            acc_gap_closed = 0
        print(f"\n  Teacher gap closed: {acc_gap_closed:.1f}%")


# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
