#!/usr/bin/env python3
"""
Sharpness Evolution Analysis Script

This script loads training monitors from a directory and evaluates the sharpness
(largest eigenvalue of the Hessian) during training. It plots the sharpness evolution
and saves both eigenvalues and eigenvectors.

Usage:
    python scripts/eval_sharpness_evolution.py --base_dir path/to/experiment --data_path path/to/eval_data.bin --neigs 1
"""

import argparse
import os
import sys
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import h5py
import time
from pathlib import Path
from datetime import datetime
from tqdm import tqdm
from contextlib import nullcontext
from typing import Callable
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.utils.data import DataLoader, Dataset, TensorDataset
from scipy.sparse.linalg import LinearOperator, eigsh

# Add the project root to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.model.model_llama import build_llama_model
from src.utils.logger import load_training_monitor


def lanczos(device: torch.device,
            matrix_vector: Callable,
            dim: int,
            neigs: int):
    """Compute the leading eigenpairs of a symmetric operator with ARPACK's Lanczos solver.

    The operator is accessed only through ``matrix_vector``. ``scipy.sparse.linalg.eigsh``
    drives the iteration on the CPU, and each product is evaluated on ``device``.

    Args:
        device: device (GPU or CPU) on which ``matrix_vector`` is evaluated.
        matrix_vector: callable taking a 1-D float32 tensor of length ``dim`` on ``device``
            and returning the operator applied to it as a 1-D tensor of the same length.
        dim: dimension of the (square) operator.
        neigs: number of eigenpairs to compute (ARPACK requires ``neigs < dim``).

    Returns:
        ``(evals, evecs)`` as float32 CPU tensors. ``evals`` has shape ``(neigs,)``.
        ``evecs`` has shape ``(dim, neigs)``, and ``evecs[:, i]`` pairs with ``evals[i]``.
        Both are in the reverse of the order eigsh reports, i.e. largest first.
    """
    def mv(vec: np.ndarray) -> np.ndarray:
        # torch.tensor copies, so matrix_vector never aliases ARPACK's work buffer.
        gpu_vec = torch.tensor(vec, dtype=torch.float).to(device)
        return matrix_vector(gpu_vec).detach().cpu().numpy()

    # An explicit dtype stops LinearOperator from probing it by calling mv() on a zero
    # vector, which would cost one extra full matrix-vector product.
    # float32 is what that probe infers from the float32 products anyway.
    operator = LinearOperator((dim, dim), matvec=mv, dtype=np.float32)
    evals, evecs = eigsh(operator, neigs)
    # eigsh reports eigenvalues in ascending order; reverse so the largest comes first.
    return torch.from_numpy(np.ascontiguousarray(evals[::-1]).copy()).float(), \
           torch.from_numpy(np.ascontiguousarray(np.flip(evecs, -1)).copy()).float()


def get_filtered_parameters(model, exclude_ln=True):
    """Collect a model's parameters together with their qualified names.

    Args:
        model: the module whose ``named_parameters()`` are collected.
        exclude_ln: when True, drop every parameter whose name contains ``'norm'``
            (case-insensitive). This is how normalization-layer weights are excluded.

    Returns:
        ``(params, names)``: two parallel lists in ``named_parameters()`` order.
    """
    named = list(model.named_parameters())
    if exclude_ln:
        # Name-based filter: anything with 'norm' in its qualified name is skipped.
        named = [(pname, tensor) for pname, tensor in named if 'norm' not in pname.lower()]
    return [tensor for _, tensor in named], [pname for pname, _ in named]


def compute_hvp(device: torch.device,
                model: nn.Module, 
                dataset: Dataset, 
                loss_fn: Callable,
                vector: torch.Tensor, 
                physical_batch_size,
                exclude_ln=True) -> torch.Tensor:
    """Compute the product of the dataset-loss Hessian with ``vector``.

    The differentiated objective is the sum, over batches of ``dataset`` (chunked by
    ``physical_batch_size``), of ``loss_fn(logits_flat, targets_flat) / len(dataset)``.
    Only the parameters returned by ``get_filtered_parameters(model, exclude_ln)`` are
    differentiated.

    Args:
        device: GPU or CPU on which the forward/backward passes run
        model: the model; its output is used directly or via its ``.logits`` attribute
        dataset: dataset yielding ``(inputs, targets)`` pairs
        loss_fn: callable ``(logits_flat, targets_flat) -> scalar loss``
        vector: 1-D vector with one entry per selected parameter element
        physical_batch_size: number of dataset items per forward/backward pass
        exclude_ln: whether to exclude parameters whose name contains 'norm'

    Returns:
        1-D float32 tensor on ``device``: the Hessian-vector product.

    Raises:
        RuntimeError: propagated from the forward/backward passes.
    """
    filtered_params, _ = get_filtered_parameters(model, exclude_ln)
    # Count elements directly instead of flattening every parameter into a new vector.
    p = sum(param.numel() for param in filtered_params)
    # NOTE(review): n counts dataset items (sequences), not tokens. With a token-summed
    # loss_fn (evaluate_checkpoint passes reduction='sum'), this is the Hessian of the
    # per-sequence summed loss, i.e. seq_len x the per-token mean-loss Hessian.
    # Confirm this scale is intended before comparing sharpness against 2/lr.
    n = len(dataset)

    print(f"Computing HVP with {p} parameters (exclude_ln={exclude_ln})")

    hvp = torch.zeros(p, dtype=torch.float, device=device)
    vector = vector.to(device)

    # Double backward through the fused SDPA kernels may not be implemented (see the
    # error handled below), so force the math kernel.
    # These are global settings and are not restored afterwards.
    print("Disabling efficient attention for Hessian computation...")
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)

    dataloader = DataLoader(dataset, batch_size=physical_batch_size, shuffle=False)
    try:
        for X, y in tqdm(dataloader, desc="Computing HVP"):
            X, y = X.to(device), y.to(device)
            with torch.enable_grad():
                outputs = model(X)
                logits = outputs.logits if hasattr(outputs, 'logits') else outputs
                # Flatten (..., vocab) logits and matching targets for the loss.
                logits_flat = logits.reshape(-1, logits.size(-1))
                y_flat = y.reshape(-1)
                loss = loss_fn(logits_flat, y_flat) / n

                # H v = d/dθ (∇L · v): build a differentiable gradient, then backprop the dot.
                grads = torch.autograd.grad(loss, inputs=filtered_params, create_graph=True)
                dot = parameters_to_vector(grads).mul(vector).sum()
                # No retain_graph: nothing reuses this graph, so free its buffers now
                # instead of keeping them alive through the next batch's forward pass.
                grads = [g.contiguous() for g in torch.autograd.grad(dot, filtered_params)]
                hvp += parameters_to_vector(grads)
    except RuntimeError as e:
        if "derivative for aten::_scaled_dot_product_efficient_attention_backward is not implemented" in str(e):
            print("Warning: Efficient attention still enabled. This may cause issues with Hessian computation.")
            print("Consider using a different PyTorch version or model configuration.")
        raise

    return hvp


def get_hessian(device: torch.device,
                model: nn.Module, 
                dataset: Dataset,
                loss_fn: Callable, 
                neigs: int = 6, 
                physical_batch_size: int = 1000,
                exclude_ln: bool = True) -> "tuple[torch.Tensor, torch.Tensor]":
    """Compute the leading Hessian eigenpairs of the dataset loss via Lanczos.

    Args:
        device: GPU or CPU
        model: the model
        dataset: the dataset of ``(inputs, targets)`` pairs
        loss_fn: callable ``(logits_flat, targets_flat) -> scalar loss``
        neigs: the number of eigenpairs to compute
        physical_batch_size: number of dataset items per forward/backward pass
        exclude_ln: whether to exclude parameters whose name contains 'norm'

    Returns:
        ``(evals, evecs)`` as returned by ``lanczos``. ``evals`` has shape ``(neigs,)``
        and ``evecs`` has shape ``(nparams, neigs)``, where ``nparams`` counts only the
        selected parameters.
    """
    def hvp_delta(delta: torch.Tensor) -> torch.Tensor:
        return compute_hvp(device, model, dataset, loss_fn, delta,
                           physical_batch_size=physical_batch_size, exclude_ln=exclude_ln)

    # Operator dimension = number of elements in the selected parameters; counted
    # without materializing a flattened copy of the model.
    filtered_params, _ = get_filtered_parameters(model, exclude_ln)
    nparams = sum(param.numel() for param in filtered_params)

    evals, evecs = lanczos(device, hvp_delta, nparams, neigs=neigs)
    return evals, evecs


class SharpnessEvaluator:
    """Evaluate Hessian sharpness for the saved training monitors of one experiment.

    Construction creates a timestamped directory under ``out/sharpness_analysis`` and
    samples a fixed evaluation dataset. ``run_analysis`` then evaluates the checkpoints
    and writes plots and results into that directory.
    """

    def __init__(self, args):
        """Set up device/precision, the output directory and the evaluation dataset.

        Args:
            args: parsed command-line namespace; reads ``device``, ``dtype``,
                ``base_dir``, ``exclude_ln``, ``data_path``, ``num_samples``,
                ``block_size`` and ``batch_size`` during construction.
        """
        self.args = args
        self.device = torch.device(args.device)
        self.dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
        # autocast needs a bare device type ('cuda', 'cpu', ...). Derive it from the parsed
        # device so strings such as 'cuda:1' do not reach autocast as an invalid type.
        # NOTE: self.ctx is not used by any method of this class.
        self.ctx = (nullcontext() if self.device.type == 'cpu'
                    else torch.amp.autocast(device_type=self.device.type, dtype=self.dtype))

        # Create output directory
        self.output_dir = self._create_output_dir()

        # Load data and sample the fixed dataset used for every checkpoint.
        self.eval_data = self._load_data()
        self.eval_dataset = self.create_eval_dataset(self.args.num_samples)

    def _create_output_dir(self):
        """Create and return ``out/sharpness_analysis/sharpness_<experiment>_<ln>_<timestamp>``."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        experiment_name = os.path.basename(os.path.normpath(self.args.base_dir))
        ln_suffix = 'exclude_ln' if self.args.exclude_ln else 'include_ln'
        output_dir = os.path.join('out', 'sharpness_analysis',
                                  f'sharpness_{experiment_name}_{ln_suffix}_{timestamp}')
        os.makedirs(output_dir, exist_ok=True)
        print(f"Results will be saved to: {output_dir}")
        return output_dir

    def _load_data(self):
        """Memory-map the token file at ``args.data_path`` as a read-only uint16 array."""
        data = np.memmap(self.args.data_path, dtype=np.uint16, mode='r')
        return data

    def find_monitors(self):
        """Map iteration number -> path of ``train_monitor/training_monitor_iter_<N>.pkl``.

        Files whose ``<N>`` is not an integer are ignored. A missing directory yields ``{}``.
        """
        monitor_dir = os.path.join(self.args.base_dir, 'train_monitor')
        if not os.path.exists(monitor_dir):
            print(f"Monitor directory not found: {monitor_dir}")
            return {}

        monitors = {}
        for filename in os.listdir(monitor_dir):
            if filename.startswith('training_monitor_iter_') and filename.endswith('.pkl'):
                try:
                    iter_str = filename.replace('training_monitor_iter_', '').replace('.pkl', '')
                    iter_num = int(iter_str)
                    monitors[iter_num] = os.path.join(monitor_dir, filename)
                except ValueError:
                    continue
        return monitors

    def load_config(self):
        """Load ``configs/experiment_config.json``; return ``{}`` if missing or unreadable."""
        config_path = os.path.join(self.args.base_dir, 'configs/experiment_config.json')
        if not os.path.exists(config_path):
            print(f"Warning: Config file not found at {config_path}")
            return {}

        try:
            with open(config_path, 'r') as f:
                return json.load(f)
        except Exception as e:
            print(f"Error loading config from {config_path}: {e}")
            return {}

    def get_eval_batch(self):
        """Sample ``batch_size`` random windows of ``block_size`` tokens from the eval data.

        Returns:
            ``(x, y)`` int64 tensors of shape ``(batch_size, block_size)`` on ``self.device``.
            ``y`` holds the next token for every position of ``x``.
        """
        data = self.eval_data
        # Upper bound keeps the shifted target window inside the data.
        ix = torch.randint(len(data) - self.args.block_size, (self.args.batch_size,))
        x = torch.stack([torch.from_numpy((data[i:i+self.args.block_size]).astype(np.int64)) for i in ix])
        y = torch.stack([torch.from_numpy((data[i+1:i+1+self.args.block_size]).astype(np.int64)) for i in ix])

        if self.device.type == 'cuda':
            x, y = x.pin_memory().to(self.device, non_blocking=True), y.pin_memory().to(self.device, non_blocking=True)
        else:
            x, y = x.to(self.device), y.to(self.device)

        return x, y

    def create_eval_dataset(self, num_samples=12):
        """Build a CPU ``TensorDataset`` from ``num_samples`` sampled batches.

        Returns:
            a dataset of ``num_samples * batch_size`` (input, target) sequence pairs.
        """
        print(f"Creating evaluation dataset with {num_samples} samples")

        inputs = []
        targets = []

        for _ in tqdm(range(num_samples), desc="Creating dataset"):
            x, y = self.get_eval_batch()
            # Keep sequences in their original shape [batch_size, block_size].
            # Move to CPU to save GPU memory during dataset creation.
            inputs.append(x.cpu())
            targets.append(y.cpu())

        inputs = torch.cat(inputs, dim=0)  # [num_samples * batch_size, block_size]
        targets = torch.cat(targets, dim=0)  # [num_samples * batch_size, block_size]

        print(f"Dataset created with shape: inputs {inputs.shape}, targets {targets.shape}")

        dataset = TensorDataset(inputs, targets)
        return dataset

    def restore_eigenvector_to_model(self, model, eigenvector, exclude_ln=True):
        """Write a flat eigenvector into the model's parameters, in place.

        Args:
            model: the model whose parameters are overwritten
            eigenvector: 1-D tensor with one entry per selected parameter element
            exclude_ln: must match the filtering used in the Hessian computation. When True,
                parameters with 'norm' in their name are left untouched, so they keep
                whatever values the model already holds.

        Returns:
            the same ``model`` object.
        """
        filtered_params, _ = get_filtered_parameters(model, exclude_ln)
        vector_to_parameters(eigenvector, filtered_params)
        return model

    def evaluate_checkpoint(self, monitor_path, iter_num):
        """Compute the top-``neigs`` Hessian eigenpairs for one training monitor.

        Returns:
            ``(evals, evecs, eigenvector_models, computation_time)``. ``evecs[:, i]`` is the
            i-th eigenvector. ``eigenvector_models[i]`` is a fresh model built from the
            checkpoint with that eigenvector written into its selected parameters.
            ``computation_time`` is the Hessian solve time in seconds.
        """
        print(f"Evaluating checkpoint at iteration {iter_num}...")

        monitor = load_training_monitor(monitor_path)
        model = build_llama_model(monitor['config']['model_name'])
        model.load_state_dict(monitor['model_state_dict'])
        model.to(self.device)
        model.eval()

        # Token-summed cross entropy; compute_hvp divides by the dataset length.
        def loss_fn(logits_flat, targets_flat):
            return F.cross_entropy(logits_flat, targets_flat, ignore_index=-1, reduction='sum')

        start_time = time.time()
        evals, evecs = get_hessian(
            device=self.device,
            model=model,
            dataset=self.eval_dataset,
            loss_fn=loss_fn,
            neigs=self.args.neigs,
            physical_batch_size=self.args.physical_batch_size,
            exclude_ln=self.args.exclude_ln
        )
        computation_time = time.time() - start_time

        print(f"Hessian computation completed in {computation_time:.2f} seconds")
        print(f"Top {self.args.neigs} eigenvalues: {evals}")

        # One CPU model copy per eigenvector, used only to save it as a state dict.
        eigenvector_models = []
        for i in range(self.args.neigs):
            eigenvector_model = build_llama_model(monitor['config']['model_name'])
            eigenvector_model.load_state_dict(monitor['model_state_dict'])
            eigenvector_models.append(self.restore_eigenvector_to_model(eigenvector_model, evecs[:, i], self.args.exclude_ln))

        return evals, evecs, eigenvector_models, computation_time

    def plot_sharpness_evolution(self, iter_nums, all_evals, all_evecs):
        """Plot top-eigenvalue (and, if ``neigs > 1``, all eigenvalue) trajectories.

        ``all_evecs`` is accepted for interface symmetry but not used. Nothing is plotted
        for fewer than two checkpoints.
        """
        if len(iter_nums) < 2:
            print("Need at least 2 checkpoints to plot evolution")
            return

        sharpness_values = [evals[0].item() for evals in all_evals]

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

        ax1.plot(iter_nums, sharpness_values, 'bo-', linewidth=2, markersize=6)
        ax1.set_xlabel('Training Iteration', fontsize=12)
        ax1.set_ylabel('Sharpness (Largest Eigenvalue)', fontsize=12)
        ln_status = 'w/o LayerNorm' if self.args.exclude_ln else 'w/ LayerNorm'
        ax1.set_title(f'Sharpness Evolution During Training ({ln_status})', fontsize=14)
        ax1.grid(True, alpha=0.3)
        ax1.set_yscale('log')  # NOTE: non-positive eigenvalues cannot be shown on a log axis

        if self.args.neigs > 1:
            colors = plt.cm.viridis(np.linspace(0, 1, self.args.neigs))
            for i in range(self.args.neigs):
                eigenvals = [evals[i].item() for evals in all_evals]
                ax2.plot(iter_nums, eigenvals, color=colors[i], linewidth=2,
                         marker='o', markersize=4, alpha=0.8, label=f'Eigenvalue {i+1}')

            ax2.set_xlabel('Training Iteration', fontsize=12)
            ax2.set_ylabel('Eigenvalue', fontsize=12)
            ax2.set_title('Top Eigenvalues Evolution', fontsize=14)
            ax2.grid(True, alpha=0.3)
            ax2.set_yscale('log')
            ax2.legend()
        else:
            ax2.text(0.5, 0.5, 'Only one eigenvalue computed',
                     ha='center', va='center', transform=ax2.transAxes, fontsize=12)
            ax2.set_title('Top Eigenvalues Evolution (Single Eigenvalue)', fontsize=14)

        plt.tight_layout()
        plot_path = os.path.join(self.output_dir, 'sharpness_evolution.png')
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        plt.close()

        print(f"Saved sharpness evolution plot: {plot_path}")

    def save_results(self, iter_nums, all_evals, all_evecs, all_eigenvector_models, computation_times):
        """Write eigenpairs (HDF5), eigenvector state dicts and a JSON summary.

        All files are rewritten from the full lists on every call.
        """
        results_path = os.path.join(self.output_dir, 'sharpness_results.h5')
        with h5py.File(results_path, 'w') as f:
            f['iterations'] = np.array(iter_nums)
            f['eigenvalues'] = np.array([evals.numpy() for evals in all_evals])
            f['eigenvectors'] = np.array([evecs.numpy() for evecs in all_evecs])
            f['computation_times'] = np.array(computation_times)

        for i, iter_num in enumerate(iter_nums):
            iter_dir = os.path.join(self.output_dir, f'iter_{iter_num}')
            os.makedirs(iter_dir, exist_ok=True)

            for j in range(self.args.neigs):
                eigenvector_model = all_eigenvector_models[i][j]
                model_path = os.path.join(iter_dir, f'eigenvector_{j+1}_model.pth')
                torch.save(eigenvector_model.state_dict(), model_path)

        summary_path = os.path.join(self.output_dir, 'summary.json')
        summary = {
            'iterations': iter_nums,
            'sharpness_values': [evals[0].item() for evals in all_evals],
            'computation_times': computation_times,
            'total_computation_time': sum(computation_times),
            'num_checkpoints': len(iter_nums),
            'neigs': self.args.neigs,
            'exclude_ln': self.args.exclude_ln,
        }

        with open(summary_path, 'w') as f:
            json.dump(summary, f, indent=2)

        print(f"Saved results to: {results_path}")
        print(f"Saved summary to: {summary_path}")

    def run_analysis(self):
        """Evaluate every discovered checkpoint except the earliest, then report sharpness.

        The plot and result files are rewritten after each successfully processed
        checkpoint, so partial results survive a later failure. Checkpoints that raise
        are reported and skipped.

        Returns:
            0 if at least one checkpoint was processed, 1 otherwise.
        """
        monitors = self.find_monitors()
        if not monitors:
            print(f"No monitors found in {self.args.base_dir}")
            return 1

        # The earliest monitor is skipped (presumably the initialization checkpoint).
        # NOTE(review): confirm this skip is intended.
        sorted_iters = sorted(monitors.keys())[1:]
        print(f"Found {len(monitors)} monitors, evaluating {len(sorted_iters)} "
              f"(earliest skipped): {sorted_iters}")

        all_evals = []
        all_evecs = []
        all_eigenvector_models = []
        computation_times = []
        processed_iters = []

        total_start_time = time.time()

        for iter_num in sorted_iters:
            try:
                evals, evecs, eigenvector_models, comp_time = self.evaluate_checkpoint(
                    monitors[iter_num], iter_num
                )
            except Exception as e:
                print(f"Error processing iteration {iter_num}: {e}")
                continue

            all_evals.append(evals)
            all_evecs.append(evecs)
            all_eigenvector_models.append(eigenvector_models)
            computation_times.append(comp_time)
            processed_iters.append(iter_num)

            # Incremental save: rewrites everything gathered so far.
            self.plot_sharpness_evolution(processed_iters, all_evals, all_evecs)
            self.save_results(processed_iters, all_evals, all_evecs, all_eigenvector_models, computation_times)

        total_elapsed = time.time() - total_start_time

        # Guard the per-checkpoint average below against division by zero.
        if not processed_iters:
            print(f"\nNo checkpoints were successfully processed in {self.args.base_dir}")
            return 1

        avg_time = total_elapsed / len(processed_iters)
        print(f"\nAnalysis completed: {len(processed_iters)} checkpoints")
        print(f"Average time per checkpoint: {avg_time:.2f} seconds")
        print(f"Total computation time: {total_elapsed:.2f} seconds")
        print(f"Results saved in: {self.output_dir}")

        print("\nSharpness Summary:")
        print(f"{'Iteration':<10} {'Sharpness':<15} {'Comp Time (s)':<15}")
        print("-" * 40)
        for iter_num, evals, comp_time in zip(processed_iters, all_evals, computation_times):
            print(f"{iter_num:<10} {evals[0].item():<15.6f} {comp_time:<15.2f}")

        return 0


def set_seed(seed):
    """Seed NumPy and PyTorch (CPU and current CUDA device) RNGs; request deterministic cuDNN."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored by PyTorch when CUDA is unavailable.
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def main():
    """Parse command-line options, seed the RNGs and run the sharpness analysis.

    Returns:
        The integer status returned by ``SharpnessEvaluator.run_analysis``, so that
        ``sys.exit(main())`` reports failures through the process exit code.
    """
    parser = argparse.ArgumentParser(description='Sharpness Evolution Analysis for TPPT')

    # Required arguments
    parser.add_argument('--base_dir', type=str, required=True,
                        help='Base experiment directory containing monitors')
    parser.add_argument('--data_path', type=str, required=True,
                        help='Path to evaluation data file')

    # Model and computation settings
    parser.add_argument('--block_size', type=int, default=1024,
                        help='Block size for data loading (default: 1024)')
    parser.add_argument('--batch_size', type=int, default=12,
                        help='Batch size for evaluation (default: 12)')
    parser.add_argument('--neigs', type=int, default=2,
                        help='Number of eigenvalues to compute (default: 2)')
    parser.add_argument('--physical_batch_size', type=int, default=16,
                        help='Physical batch size for Hessian computation (default: 16)')
    parser.add_argument('--num_samples', type=int, default=12,
                        help='Number of samples for evaluation dataset (default: 12)')

    # Device settings
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device to use (default: cuda)')
    parser.add_argument('--dtype', type=str, default='bfloat16',
                        choices=['float32', 'bfloat16', 'float16'],
                        help='Data type (default: bfloat16)')

    # Other settings
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed (default: 42)')
    parser.add_argument('--exclude_ln', action='store_true', dest='exclude_ln', default=False,
                        help='Exclude LayerNorm parameters from sharpness computation')

    args = parser.parse_args()

    set_seed(args.seed)

    # Propagate the analysis status so the caller's sys.exit() sees failures.
    evaluator = SharpnessEvaluator(args)
    return evaluator.run_analysis()

if __name__ == "__main__":
    # Equivalent to sys.exit(main()): main()'s return value becomes the exit status.
    raise SystemExit(main())
