#!/usr/bin/env python3
"""
Simplified Loss Landscape Visualization Script

Features:
- Multi-GPU parallel computation
- Second-order finite difference computation 
- Combined visualization of loss landscape and curvature
- Evolution plots showing training progression

Usage:
    python scripts/plot_loss_landscape.py --base_dir path/to/experiment
"""

import argparse
import json
import os
import sys
import time
import traceback
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import matplotlib.pyplot as plt
import h5py
from tqdm import tqdm

# Add the project root to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.model.model_llama import build_llama_model
from src.utils.logger import load_training_monitor, get_logger
    
class LossComputer:
    """Evaluate the validation loss along a 1-D slice ``w + alpha * d`` of weight space.

    With one GPU the evaluation runs in-process; with several, the alpha values are
    dealt round-robin to one ``torch.multiprocessing`` worker per GPU. Every alpha is
    scored on the same pre-sampled validation batches so the loss curve is not
    polluted by batch-to-batch sampling noise.
    """

    def __init__(self, model_config, data_loader_config, gpu_ids):
        """
        Args:
            model_config: dict whose ``'model_name'`` entry is passed to
                ``build_llama_model``.
            data_loader_config: object exposing ``data_dir``, ``block_size`` and
                ``batch_size``.
            gpu_ids: list of CUDA device indices to evaluate on.
        """
        self.model_config = model_config
        self.data_loader_config = data_loader_config
        self.gpu_ids = gpu_ids
        self.num_gpus = len(gpu_ids)
        print(f"Initialized LossComputer with {self.num_gpus} GPUs: {gpu_ids}")

    def compute_losses(self, original_weights, direction, alpha_values, num_eval_batches=10,
                       seed=None):
        """Return the mean validation loss at ``original_weights + alpha * direction``.

        Args:
            original_weights: list of tensors in ``model.parameters()`` order.
            direction: list of tensors with the same shapes and order as
                ``original_weights``.
            alpha_values: sequence of step sizes along ``direction``.
            num_eval_batches: number of validation batches averaged per alpha.
            seed: seed for sampling the evaluation batches. ``None`` uses
                ``torch.initial_seed()``, so a prior ``torch.manual_seed`` call makes
                the batch choice reproducible.

        Returns:
            ``np.ndarray`` of losses aligned with ``alpha_values``.

        Raises:
            ValueError: if ``num_eval_batches`` is smaller than 1.
        """
        if num_eval_batches < 1:
            raise ValueError(f"num_eval_batches must be >= 1, got {num_eval_batches}")
        if len(alpha_values) == 0:
            return np.array([])
        # Sample the batches once and reuse them for every alpha (and every GPU):
        # re-sampling per alpha adds noise that the second-order finite difference
        # computed from these losses amplifies by 1/h^2.
        eval_batches = self._sample_eval_batches(num_eval_batches, seed)
        if self.num_gpus == 1:
            return self._compute_single_gpu(original_weights, direction, alpha_values, eval_batches)
        return self._compute_multi_gpu(original_weights, direction, alpha_values, eval_batches)

    def _compute_single_gpu(self, original_weights, direction, alpha_values, eval_batches):
        """Evaluate every alpha in-process on ``self.gpu_ids[0]``."""
        losses = self._evaluate_alphas(self.gpu_ids[0], self.model_config, original_weights,
                                       direction, alpha_values, eval_batches)
        return np.array([losses[alpha] for alpha in alpha_values])

    def _compute_multi_gpu(self, original_weights, direction, alpha_values, eval_batches):
        """Evaluate alphas in parallel, one worker process per GPU that has work."""
        # GPUs that would receive no alphas (more GPUs than points) get no worker,
        # so no process builds a model for nothing.
        jobs = [
            (gpu_id, self.model_config, original_weights, direction, alphas, eval_batches)
            for gpu_id, alphas in zip(self.gpu_ids, self._split_alphas(alpha_values))
            if alphas
        ]
        if not jobs:
            return np.array([])

        with mp.Pool(processes=len(jobs)) as pool:
            per_gpu_results = pool.map(LossComputer._compute_loss_on_gpu, jobs)

        all_losses = {}
        for gpu_results in per_gpu_results:
            all_losses.update(gpu_results)
        return np.array([all_losses[alpha] for alpha in alpha_values])

    def _split_alphas(self, alpha_values):
        """Deal ``alpha_values`` round-robin into ``self.num_gpus`` lists."""
        batches = [[] for _ in range(self.num_gpus)]
        for i, alpha in enumerate(alpha_values):
            batches[i % self.num_gpus].append(alpha)
        return batches

    @staticmethod
    def _compute_loss_on_gpu(args):
        """Pool worker: unpack one job tuple and evaluate its alphas on its GPU."""
        gpu_id, model_config, original_weights, direction, alpha_values, eval_batches = args
        return LossComputer._evaluate_alphas(gpu_id, model_config, original_weights,
                                             direction, alpha_values, eval_batches)

    @staticmethod
    def _evaluate_alphas(gpu_id, model_config, original_weights, direction, alpha_values,
                         eval_batches):
        """Return ``{alpha: mean loss over eval_batches}`` evaluated on ``cuda:<gpu_id>``."""
        device = f'cuda:{gpu_id}'
        torch.cuda.set_device(gpu_id)

        model = build_llama_model(model_config['model_name']).to(device)
        model.eval()
        criterion = nn.CrossEntropyLoss()

        results = {}
        with torch.no_grad():
            for alpha in tqdm(alpha_values):
                step = float(alpha)
                # Parameters are matched to weights/direction by position, so both
                # lists must follow model.parameters() order.
                for p, w, d in zip(model.parameters(), original_weights, direction):
                    p.data.copy_((w + step * d).to(device))

                total_loss = 0.0
                for x, y in eval_batches:
                    x, y = x.to(device), y.to(device)
                    outputs = model(x)
                    # Accept outputs carrying a .logits attribute as well as raw logits.
                    logits = outputs.logits if hasattr(outputs, 'logits') else outputs
                    # reshape (unlike view) also works for non-contiguous logits.
                    loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
                    total_loss += loss.item()
                results[alpha] = total_loss / len(eval_batches)
                torch.cuda.empty_cache()
        return results

    def _sample_eval_batches(self, num_eval_batches, seed=None):
        """Draw ``num_eval_batches`` validation batches with a dedicated, seeded RNG.

        Args:
            num_eval_batches: number of ``(x, y)`` batches to draw.
            seed: seed for a private ``torch.Generator``; ``None`` uses
                ``torch.initial_seed()``.

        Returns:
            List of ``(x, y)`` int64 CPU tensor pairs of shape
            ``(batch_size, block_size)``.
        """
        if seed is None:
            seed = torch.initial_seed()
        generator = torch.Generator().manual_seed(seed)
        return [self._get_batch_from_config(self.data_loader_config, 'val', generator)
                for _ in range(num_eval_batches)]

    def _get_batch(self, split):
        """Get a random batch of ``split`` data using the global torch RNG."""
        return self._get_batch_from_config(self.data_loader_config, split)

    @staticmethod
    def _get_batch_from_config(config, split, generator=None):
        """Sample one ``(x, y)`` batch of token windows from ``<data_dir>/<split>.bin``.

        ``split == 'train'`` reads ``train.bin``; anything else reads ``val.bin``.
        The file is read as a flat uint16 token array. ``y`` is ``x`` shifted by one
        token. ``generator`` (optional ``torch.Generator``) controls the window starts.
        """
        filename = 'train.bin' if split == 'train' else 'val.bin'
        data = np.memmap(os.path.join(config.data_dir, filename), dtype=np.uint16, mode='r')
        # Start indices stop block_size short of the end so the shifted target fits.
        starts = torch.randint(len(data) - config.block_size, (config.batch_size,),
                               generator=generator).tolist()
        x = torch.stack([torch.from_numpy(data[i:i + config.block_size].astype(np.int64))
                         for i in starts])
        y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + config.block_size].astype(np.int64))
                         for i in starts])
        return x, y


class DataLoaderConfig:
    """Container for the data-loading settings consumed by ``LossComputer``.

    Attributes:
        dataset: dataset name.
        block_size: number of tokens per sampled sequence.
        batch_size: number of sequences per batch.
        data_dir: directory holding the ``train.bin`` / ``val.bin`` token files.
    """

    def __init__(self, dataset, block_size, batch_size, data_dir):
        self.dataset = dataset
        self.block_size = block_size
        self.batch_size = batch_size
        self.data_dir = data_dir

    def __repr__(self):
        # Show the actual settings when the config is printed or logged.
        return (f"{type(self).__name__}(dataset={self.dataset!r}, "
                f"block_size={self.block_size!r}, batch_size={self.batch_size!r}, "
                f"data_dir={self.data_dir!r})")


class LossLandscapeAnalyzer:
    """Drive a 1-D loss-landscape analysis over every saved training monitor.

    For each ``train_monitor/training_monitor_iter_<N>.pkl`` under ``base_dir`` the
    analyzer evaluates the loss along ``w + alpha * d`` for a normalized random
    direction ``d``. It derives a second-order finite difference of that curve and
    writes per-iteration plots, HDF5 data and JSON metrics, plus evolution summaries,
    into a timestamped output directory.
    """

    def __init__(self, base_dir, gpu_ids, norm_type,
                 block_type='all', direction_type='random',
                 ignore_ln=True, experiment_name=None):
        """
        Args:
            base_dir: experiment directory containing ``train_monitor/`` and
                ``configs/experiment_config.json``.
            gpu_ids: CUDA device indices handed to ``LossComputer``.
            norm_type: direction normalization: 'filter', 'layer' or 'weight'.
            block_type: parameter group to perturb ('all' perturbs every parameter).
            direction_type: 'random' (one direction normalized against the freshly
                built model), 'checkpoint' (normalized against the last checkpoint)
                or 'iter' (a new direction for every checkpoint).
            ignore_ln: with 'filter' normalization, zero the direction for 1-D
                parameters.
            experiment_name: label used in titles and filenames; defaults to the
                last component of ``base_dir``.

        Side effect: creates the output directory (see ``_create_output_dir``).
        """
        self.base_dir = base_dir
        self.gpu_ids = gpu_ids
        self.norm_type = norm_type
        self.block_type = block_type
        self.direction_type = direction_type
        self.ignore_ln = ignore_ln
        self.experiment_name = experiment_name or os.path.basename(os.path.normpath(base_dir))
        self.output_dir = self._create_output_dir()

    def _create_output_dir(self):
        """Create and return the run's output directory.

        The layout is ``out/loss_landscape_analysis/landscape_<experiment>/<tag>``,
        relative to the current working directory. The tag encodes the normalization,
        block type, direction type, LN handling and a timestamp.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        ignore_ln_str = 'ignore_ln' if self.ignore_ln else 'include_ln'
        run_tag = (f'{self.norm_type}_{self.block_type}_{self.direction_type}_'
                   f'{ignore_ln_str}_{timestamp}')
        output_dir = os.path.join('out', 'loss_landscape_analysis',
                                  f'landscape_{self.experiment_name}', run_tag)
        os.makedirs(output_dir, exist_ok=True)
        print(f"Results will be saved to: {output_dir}")
        return output_dir

    def find_monitors(self):
        """Map iteration number to the path of ``training_monitor_iter_<N>.pkl``.

        Returns an empty dict (after printing a message) when ``train_monitor/`` is
        missing. Files whose ``<N>`` is not an integer are skipped.
        """
        monitor_dir = os.path.join(self.base_dir, 'train_monitor')
        if not os.path.exists(monitor_dir):
            print(f"Monitor directory not found: {monitor_dir}")
            return {}

        prefix, suffix = 'training_monitor_iter_', '.pkl'
        monitors = {}
        for filename in os.listdir(monitor_dir):
            if not (filename.startswith(prefix) and filename.endswith(suffix)):
                continue
            try:
                iter_num = int(filename[len(prefix):-len(suffix)])
            except ValueError:
                continue
            monitors[iter_num] = os.path.join(monitor_dir, filename)
        return monitors

    def load_config(self):
        """Load ``configs/experiment_config.json``; return ``{}`` if missing or unreadable."""
        config_path = os.path.join(self.base_dir, 'configs/experiment_config.json')
        if not os.path.exists(config_path):
            print(f"Warning: Config file not found at {config_path}")
            return {}

        try:
            with open(config_path, 'r') as f:
                return json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
            print(f"Error loading config from {config_path}: {e}")
            return {}

    def setup_data_loader(self, config):
        """Build a ``DataLoaderConfig`` from the experiment config.

        Defaults: dataset 'finewebedu', block_size 1024, batch_size 12. The data
        directory is ``<data_root>/<dataset>``.

        Raises:
            ValueError: if the config has no ``data_root``.
            FileNotFoundError: if the dataset directory does not exist.
        """
        dataset = config.get('dataset', 'finewebedu')
        block_size = config.get('block_size', 1024)
        batch_size = config.get('batch_size', 12)
        data_root = config.get('data_root')
        if not data_root:
            # os.path.join(None, ...) would otherwise fail with an opaque TypeError.
            raise ValueError("Experiment config has no 'data_root'; cannot locate the dataset")

        data_dir = os.path.join(data_root, dataset)
        if not os.path.exists(data_dir):
            raise FileNotFoundError(f"Dataset directory not found: {data_dir}")

        return DataLoaderConfig(dataset, block_size, batch_size, data_dir)

    def get_model_weights(self, model):
        """Return ``(named_params, params)`` as detached CPU copies in parameter order."""
        named_params = [(n, p.data.cpu().clone()) for n, p in model.named_parameters()]
        params = [p for _, p in named_params]
        return named_params, params

    def get_random_direction(self, model, norm_type='filter'):
        """Return a normalized Gaussian direction, one tensor per model parameter.

        Parameters outside ``self.block_type`` (unless it is 'all') get a zero
        direction. The result is normalized against the model's current weights.
        """
        direction = []
        named_params, weights = self.get_model_weights(model)
        for name, w in named_params:
            if self.block_type == 'all' or self._get_param_type(name) == self.block_type:
                d = torch.randn_like(w)
            else:
                d = torch.zeros_like(w)
            direction.append(d)

        self._normalize_direction(direction, weights, norm_type)
        # For example, block_type='ln' with 'filter' + ignore_ln zeroes everything.
        if all(not torch.any(d) for d in direction):
            print(f"Warning: the perturbation direction is all zeros (block_type={self.block_type}, "
                  f"norm_type={norm_type}, ignore_ln={self.ignore_ln}); the loss will not vary with alpha.")
        return direction

    def _get_param_type(self, param_name):
        """Classify a parameter by substrings of its name (first match wins)."""
        if 'embed' in param_name:
            return 'embed'
        elif 'q_proj' in param_name or 'k_proj' in param_name:
            return 'qk'
        elif 'v_proj' in param_name or 'o_proj' in param_name:
            return 'vo'
        elif ('mlp' in param_name or 'gate_proj' in param_name
              or 'up_proj' in param_name or 'down_proj' in param_name):
            return 'mlp'
        elif 'norm' in param_name:
            return 'ln'
        elif 'lm_head' in param_name:
            return 'head'
        else:
            return 'other'

    def _normalize_direction(self, direction, weights, norm_type='filter'):
        """Rescale each direction tensor in place relative to its weight tensor.

        'filter': for tensors with more than one dim, every slice along dim 0 (one
            filter / output row) is rescaled to the norm of the matching weight slice
            (filter normalization). 1-D tensors are zeroed when ``self.ignore_ln`` is
            set, otherwise rescaled to the weight tensor's norm.
        'layer':  each whole tensor is rescaled to the norm of its weight tensor.
        'weight': each entry is multiplied by the corresponding weight entry.
        Any other value leaves the direction untouched.
        """
        eps = 1e-10  # avoids division by zero for all-zero (unselected) directions
        for d, w in zip(direction, weights):
            if norm_type == 'filter':
                if d.dim() > 1:
                    reduce_dims = tuple(range(1, d.dim()))
                    w_norms = torch.linalg.vector_norm(w, dim=reduce_dims, keepdim=True)
                    d_norms = torch.linalg.vector_norm(d, dim=reduce_dims, keepdim=True)
                    d.mul_(w_norms / (d_norms + eps))
                elif self.ignore_ln:
                    d.fill_(0)
                else:
                    d.mul_(w.norm() / (d.norm() + eps))
            elif norm_type == 'layer':
                d.mul_(w.norm() / (d.norm() + eps))
            elif norm_type == 'weight':
                d.mul_(w)

    def compute_second_order(self, alpha_values, loss_values):
        """Central second-order finite difference of ``loss_values`` w.r.t. alpha.

        Interior points use ``(L[i+1] - 2 L[i] + L[i-1]) / h^2``; each end point
        reuses its neighbour's stencil, so its value equals its neighbour's.

        Raises:
            ValueError: with fewer than 3 points, or when ``alpha_values`` is not
                uniformly spaced with non-zero step.
        """
        alpha_values = np.asarray(alpha_values, dtype=float)
        loss_values = np.asarray(loss_values, dtype=float)
        n = len(alpha_values)
        if n < 3:
            raise ValueError("Need at least 3 points to compute second-order finite difference")

        h = alpha_values[1] - alpha_values[0]
        if h == 0 or not np.allclose(np.diff(alpha_values), h):
            raise ValueError("alpha_values must be distinct and uniformly spaced")

        second_order = np.empty_like(loss_values)
        second_order[1:-1] = (loss_values[2:] - 2 * loss_values[1:-1] + loss_values[:-2]) / (h * h)
        second_order[0] = second_order[1]
        second_order[-1] = second_order[-2]
        return second_order

    def compute_metrics(self, alpha_values, second_order_values):
        """Summarize the curvature curve.

        Returns a dict with:
            mean_second_order: mean over all points.
            positive_peak_alpha / negative_peak_alpha: the alpha > 0 / alpha < 0 at
                which the second-order value is largest (None if that side is empty).
            alpha_distance: |positive_peak_alpha - negative_peak_alpha|, or None
                unless both peaks exist.
        All numbers are plain Python floats so the dict is JSON-serializable.
        """
        alpha_values = np.asarray(alpha_values, dtype=float)
        second_order_values = np.asarray(second_order_values, dtype=float)

        def peak_alpha(mask):
            if not np.any(mask):
                return None
            return float(alpha_values[mask][np.argmax(second_order_values[mask])])

        positive_peak_alpha = peak_alpha(alpha_values > 0)
        negative_peak_alpha = peak_alpha(alpha_values < 0)

        alpha_distance = None
        if positive_peak_alpha is not None and negative_peak_alpha is not None:
            alpha_distance = abs(positive_peak_alpha - negative_peak_alpha)

        return {
            'mean_second_order': float(np.mean(second_order_values)),
            'alpha_distance': alpha_distance,
            'positive_peak_alpha': positive_peak_alpha,
            'negative_peak_alpha': negative_peak_alpha
        }

    def plot_results(self, alpha_values, loss_values, second_order_values,
                     iter_num, metrics, output_paths):
        """Save one iteration's figure, raw curves and metrics.

        ``output_paths`` maps 'combined' to the PNG path, 'data' to the HDF5 path
        and 'metrics' to the JSON path.
        """
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))

        # Loss landscape
        ax1.plot(alpha_values, loss_values, 'b-', linewidth=2, marker='o', markersize=4)
        ax1.set_xlabel('Alpha', fontsize=12)
        ax1.set_ylabel('Loss', fontsize=12)
        ax1.set_title(f'{self.experiment_name} - Loss Landscape - Iteration {iter_num}', fontsize=14)
        ax1.grid(True, alpha=0.3)

        # Second-order finite difference
        ax2.plot(alpha_values, second_order_values, 'r-', linewidth=2, marker='s', markersize=4)
        ax2.set_xlabel('Alpha', fontsize=12)
        ax2.set_ylabel('Second-order Finite Difference', fontsize=12)
        ax2.set_title(f'{self.experiment_name} - Second-order Curvature - Iteration {iter_num}', fontsize=14)
        ax2.grid(True, alpha=0.3)
        ax2.axhline(y=0, color='k', linestyle='--', alpha=0.5)

        fig.tight_layout()
        fig.savefig(output_paths['combined'], dpi=300, bbox_inches='tight')
        plt.close(fig)

        with h5py.File(output_paths['data'], 'w') as f:
            f['alpha_values'] = alpha_values
            f['loss_values'] = loss_values
            f['second_order_values'] = second_order_values

        with open(output_paths['metrics'], 'w') as f:
            json.dump(metrics, f, indent=2)

    def plot_evolution(self, alpha_values, all_loss_values, all_second_order_values,
                       iter_nums, all_metrics):
        """Overlay all iterations' curves and plot metric trends (needs >= 2 iterations)."""
        if len(iter_nums) < 2:
            return

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 12))
        # Colour each curve through the same iteration-number normalization the
        # colorbar uses, so unevenly spaced checkpoints sit where the bar says.
        cmap = plt.cm.viridis
        norm = plt.Normalize(vmin=min(iter_nums), vmax=max(iter_nums))
        colors = [cmap(norm(iter_num)) for iter_num in iter_nums]

        for color, iter_num, loss_values in zip(colors, iter_nums, all_loss_values):
            ax1.plot(alpha_values, loss_values, color=color, linewidth=2,
                     marker='o', markersize=3, alpha=0.8, label=f'Iter {iter_num}')

        ax1.set_xlabel('Alpha (Perturbation Scale)', fontsize=14)
        ax1.set_ylabel('Loss', fontsize=14)
        ax1.set_title(f'{self.experiment_name} - Loss Landscape Evolution', fontsize=16)
        ax1.grid(True, alpha=0.3)

        for color, iter_num, second_order_values in zip(colors, iter_nums, all_second_order_values):
            ax2.plot(alpha_values, second_order_values, color=color, linewidth=2,
                     marker='s', markersize=3, alpha=0.8, label=f'Iter {iter_num}')

        ax2.set_xlabel('Alpha (Perturbation Scale)', fontsize=14)
        ax2.set_ylabel('Second-order Finite Difference', fontsize=14)
        ax2.set_title(f'{self.experiment_name} - Second-order Curvature Evolution', fontsize=16)
        ax2.grid(True, alpha=0.3)
        ax2.axhline(y=0, color='k', linestyle='--', alpha=0.5)

        # Lay out the subplots first (reserving the right 10%), then place the
        # manually positioned colorbar axes, which tight_layout does not manage.
        fig.tight_layout(rect=[0, 0, 0.9, 1])
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        cbar = fig.colorbar(sm, cax=cbar_ax)
        cbar.set_label('Training Iteration', fontsize=12)

        evolution_path = os.path.join(self.output_dir, f'{self.experiment_name}_evolution_summary.png')
        fig.savefig(evolution_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

        self._plot_metrics_evolution(iter_nums, all_metrics)

        print(f"Saved evolution summary: {evolution_path}")

    def _plot_metrics_evolution(self, iter_nums, all_metrics):
        """Plot mean second-order value and peak-to-peak alpha distance per iteration."""
        mean_second_orders = [metrics['mean_second_order'] for metrics in all_metrics]
        # Only iterations with a peak on both sides of alpha = 0 have a distance.
        distance_points = [(iter_num, metrics['alpha_distance'])
                           for iter_num, metrics in zip(iter_nums, all_metrics)
                           if metrics['alpha_distance'] is not None]

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

        ax1.plot(iter_nums, mean_second_orders, 'bo-', linewidth=2, markersize=6)
        ax1.set_xlabel('Training Iteration', fontsize=12)
        ax1.set_ylabel('Mean Second-order Finite Difference', fontsize=12)
        ax1.set_title('Evolution of Mean Second-order Finite Difference', fontsize=14)
        ax1.grid(True, alpha=0.3)
        ax1.axhline(y=0, color='k', linestyle='--', alpha=0.5)

        if distance_points:
            valid_iter_nums, alpha_distances = zip(*distance_points)
            ax2.plot(valid_iter_nums, alpha_distances, 'ro-', linewidth=2, markersize=6)
            ax2.set_xlabel('Training Iteration', fontsize=12)
            ax2.set_ylabel('Alpha Distance Between Peaks', fontsize=12)
            ax2.set_title('Evolution of Alpha Distance Between Peaks', fontsize=14)
            ax2.grid(True, alpha=0.3)
        else:
            ax2.text(0.5, 0.5, 'No valid alpha distances computed',
                     ha='center', va='center', transform=ax2.transAxes, fontsize=12)
            ax2.set_title('Alpha Distance Between Peaks (No Data)', fontsize=14)

        fig.tight_layout()
        metrics_path = os.path.join(self.output_dir, f'{self.experiment_name}_metrics_evolution.png')
        fig.savefig(metrics_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

        print(f"Saved metrics evolution: {metrics_path}")

    def process_iteration(self, monitor_path, model_config, direction, alpha_range,
                          num_points, data_loader_config, loss_computer, iter_num, num_eval_batches):
        """Evaluate, save and return one checkpoint's landscape.

        ``direction`` is replaced by a fresh one when ``direction_type == 'iter'``.
        ``data_loader_config`` is accepted for interface compatibility but not used here.

        Returns:
            ``(alpha_values, loss_values, second_order_values, metrics)``.
        """
        print(f"Processing iteration {iter_num}...")

        monitor = load_training_monitor(monitor_path)
        temp_model = build_llama_model(model_config['model_name'])
        temp_model.load_state_dict(monitor['model_state_dict'])
        _, original_weights = self.get_model_weights(temp_model)

        if self.direction_type == 'iter':
            print(f"reset direction each iter when direction type is {self.direction_type}")
            direction = self.get_random_direction(temp_model, self.norm_type)

        alpha_values = np.linspace(alpha_range[0], alpha_range[1], num_points)
        loss_values = loss_computer.compute_losses(original_weights, direction, alpha_values, num_eval_batches)

        second_order_values = self.compute_second_order(alpha_values, loss_values)
        metrics = self.compute_metrics(alpha_values, second_order_values)

        print(f"  Metrics for iteration {iter_num}:")
        print(f"    Mean second-order: {metrics['mean_second_order']:.6f}")
        if metrics['alpha_distance'] is not None:
            print(f"    Alpha distance: {metrics['alpha_distance']:.6f}")

        iter_output_dir = os.path.join(self.output_dir, 'iterations', f'iter_{iter_num}')
        os.makedirs(iter_output_dir, exist_ok=True)

        output_paths = {
            'combined': os.path.join(iter_output_dir, f'{self.experiment_name}_combined_iter_{iter_num}.png'),
            'data': os.path.join(iter_output_dir, f'{self.experiment_name}_data_iter_{iter_num}.h5'),
            'metrics': os.path.join(iter_output_dir, f'{self.experiment_name}_metrics_iter_{iter_num}.json')
        }

        self.plot_results(alpha_values, loss_values, second_order_values, iter_num, metrics, output_paths)

        return alpha_values, loss_values, second_order_values, metrics

    def run_analysis(self, alpha_range, num_points, num_eval_batches):
        """Run the full analysis over every monitor found under ``base_dir``.

        Args:
            alpha_range: ``(start, stop)`` for ``np.linspace``.
            num_points: number of alpha values (at least 3 for the finite difference).
            num_eval_batches: validation batches averaged per alpha.

        Returns:
            0 when at least one iteration was processed. 1 when the config or
            monitors could not be loaded, or when every iteration failed.
        """
        config = self.load_config()
        if not config:
            print("Error: Could not load experiment configuration.")
            return 1

        model_name = config.get('model_name', '170M')
        print(f"Model: {model_name}")

        data_loader_config = self.setup_data_loader(config)
        model_config = {'model_name': model_name}
        loss_computer = LossComputer(model_config, data_loader_config, self.gpu_ids)

        monitors = self.find_monitors()
        if not monitors:
            print(f"No monitors found in {self.base_dir}")
            return 1

        sorted_iters = sorted(monitors)
        print(f"Found {len(sorted_iters)} monitors: {sorted_iters}")

        # Shared direction; process_iteration replaces it when direction_type == 'iter'.
        temp_model = build_llama_model(model_name)
        if self.direction_type == 'checkpoint':
            print(f"load the last checkpoint when direction type is {self.direction_type}")
            last_monitor = load_training_monitor(monitors[sorted_iters[-1]])
            temp_model.load_state_dict(last_monitor['model_state_dict'])
        direction = self.get_random_direction(temp_model, self.norm_type)

        all_loss_values = []
        all_second_order_values = []
        all_metrics = []
        processed_iters = []
        alpha_values = None

        total_start_time = time.time()

        for iter_num in sorted_iters:
            try:
                alpha_vals, loss_vals, second_order_vals, metrics = self.process_iteration(
                    monitors[iter_num], model_config, direction, alpha_range,
                    num_points, data_loader_config, loss_computer, iter_num, num_eval_batches
                )
            except Exception as e:
                # Best effort across checkpoints, but keep the traceback for debugging.
                print(f"Error processing iteration {iter_num}: {e}")
                traceback.print_exc()
                continue

            if alpha_values is None:
                alpha_values = alpha_vals
            all_loss_values.append(loss_vals)
            all_second_order_values.append(second_order_vals)
            all_metrics.append(metrics)
            processed_iters.append(iter_num)

        total_end_time = time.time()

        if not processed_iters:
            print("Error: no iterations were processed successfully.")
            return 1

        self.plot_evolution(alpha_values, all_loss_values, all_second_order_values,
                            processed_iters, all_metrics)

        avg_time = (total_end_time - total_start_time) / len(processed_iters)
        print(f"\nAnalysis completed: {len(processed_iters)} iterations")
        print(f"Average time per iteration: {avg_time:.2f} seconds")
        print(f"Results saved in: {self.output_dir}")

        print("\nMetrics Summary:")
        print(f"{'Iteration':<10} {'Mean 2nd Order':<15} {'Alpha Distance':<15}")
        print("-" * 40)
        for iter_num, metrics in zip(processed_iters, all_metrics):
            distance_str = f"{metrics['alpha_distance']:.6f}" if metrics['alpha_distance'] is not None else "N/A"
            print(f"{iter_num:<10} {metrics['mean_second_order']:<15.6f} {distance_str:<15}")

        return 0


def main(argv=None):
    """Command-line entry point: parse arguments, check the GPUs and run the analysis.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) reads the real command line.

    Returns:
        Exit code: 1 when no usable GPU is selected, otherwise the value returned by
        ``LossLandscapeAnalyzer.run_analysis``. Invalid landscape settings exit
        through ``argparse`` with status 2.
    """
    # Loss workers use CUDA in child processes, which PyTorch supports only with the
    # spawn/forkserver start methods.
    mp.set_start_method('spawn', force=True)

    parser = argparse.ArgumentParser(description='Simplified Loss Landscape Visualization for TPPT')
    # Required arguments
    parser.add_argument('--base_dir', type=str, required=True,
                        help='Base experiment directory containing monitors')
    # GPU settings
    parser.add_argument('--gpus', type=int, nargs='*', default=None,
                        help='GPU IDs to use (default: all available)')
    # Landscape settings
    parser.add_argument('--alpha_range', type=float, nargs=2, default=[-1.0, 1.0],
                        help='Range of alpha values (default: -1.0 1.0)')
    parser.add_argument('--num_points', type=int, default=21,
                        help='Number of points to evaluate (default: 21)')
    parser.add_argument('--num_eval_batches', type=int, default=10,
                        help='Number of evaluation batches per point (default: 10)')
    parser.add_argument('--norm_type', type=str, default='filter',
                        choices=['filter', 'layer', 'weight'],
                        help='Direction normalization type (default: filter)')
    parser.add_argument('--direction_type', type=str, default='random',
                        choices=['random', 'checkpoint', 'iter'],
                        help='Type of direction to use (default: random)')
    parser.add_argument('--block_type', type=str, default='all',
                        choices=['all', 'embed', 'qk', 'vo', 'mlp', 'ln', 'head'],
                        help='Type of block to perturb (default: all)')
    parser.add_argument('--ignore_ln', action='store_true', default=False, dest='ignore_ln',
                        help='Ignore ln (1-D) weights when norm_type is filter (default: False)')
    # Other settings
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for direction generation (default: 42)')
    args = parser.parse_args(argv)

    # Reject settings that would otherwise make every checkpoint fail or give NaN curvature.
    if args.num_points < 3:
        parser.error('--num_points must be at least 3 (second-order finite differences need 3 points)')
    if args.alpha_range[0] == args.alpha_range[1]:
        parser.error('--alpha_range must contain two different values')
    if args.num_eval_batches < 1:
        parser.error('--num_eval_batches must be at least 1')

    # Setup GPUs
    if args.gpus is None:
        if torch.cuda.is_available():
            gpu_ids = list(range(torch.cuda.device_count()))
        else:
            print("Error: No CUDA devices available")
            return 1
    else:
        gpu_ids = args.gpus

    if not gpu_ids:
        print("Error: No GPUs specified")
        return 1

    print(f"Using GPUs: {gpu_ids}")

    # Validate GPU availability (negative ids are invalid too).
    device_count = torch.cuda.device_count()
    for gpu_id in gpu_ids:
        if not 0 <= gpu_id < device_count:
            print(f"Error: GPU {gpu_id} not available")
            return 1

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    analyzer = LossLandscapeAnalyzer(
        args.base_dir, gpu_ids, args.norm_type, block_type=args.block_type,
        direction_type=args.direction_type, ignore_ln=args.ignore_ln
    )

    return analyzer.run_analysis(
        alpha_range=args.alpha_range,
        num_points=args.num_points,
        num_eval_batches=args.num_eval_batches,
    )


if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    raise SystemExit(main())
