"""Evaluation script for ACE model with comprehensive metrics and timing."""

import argparse
import json
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.data.gp_sampler import GPBatchLoader, GPSampler, generate_offline_batches
from src.models.ace import AmortizedConditioningEngine, InferenceEngine
from src.models.modules import Embedder, MixtureGaussian, Transformer
from src.utils import DataAttr


class ModelEvaluator:
    """Evaluate a trained ACE checkpoint: prediction error plus inference timing.

    Loads a checkpoint, rebuilds the model from the config stored inside it,
    runs ``InferenceEngine.sample_sequence`` over online-generated or offline
    GP batches, and writes JSON / text / CSV summaries to ``save_dir``.
    """

    def __init__(
        self,
        checkpoint_path: str,
        data_path: Optional[str] = None,
        device: str = "cuda",
        batch_size: int = 32,
        num_eval_batches: int = 100,
        K: int = 4,
        save_dir: str = "./eval_results",
        compile: bool = False,
    ):
        """Initialize evaluator with model and evaluation settings.

        Args:
            checkpoint_path: Path to a checkpoint dict containing at least
                ``"config"`` and ``"model_state_dict"``; ``"epoch"`` and
                ``"global_step"`` are optional and default to 0.
            data_path: Path to offline evaluation data (if None, batches are
                generated online with a GPSampler).
            device: Device to run evaluation on.
            batch_size: Batch size for evaluation.
            num_eval_batches: Number of batches to evaluate.
            K: Decoding batch size passed to ``sample_sequence``.
            save_dir: Directory to save evaluation results (created if missing).
            compile: Whether to wrap the whole model with ``torch.compile``.
        """
        self.device = torch.device(device)
        self.batch_size = batch_size
        self.num_eval_batches = num_eval_batches
        self.K = K
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Load model and config
        print(f"Loading checkpoint from {checkpoint_path}")
        # NOTE(review): recent torch releases default torch.load to
        # weights_only=True; if this checkpoint stores non-tensor objects the
        # call may need weights_only=False — confirm for the torch in use.
        self.checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.config = OmegaConf.create(self.checkpoint["config"])
        self.epoch = self.checkpoint.get("epoch", 0)
        self.global_step = self.checkpoint.get("global_step", 0)

        # Build and load model
        self.model = self._build_model()

        # Checkpoints saved from a torch.compile'd model prefix every key with
        # "_orig_mod."; strip it so the keys match the plain module.
        state_dict = self.checkpoint["model_state_dict"]
        if any(key.startswith("_orig_mod.") for key in state_dict.keys()):
            state_dict = {
                key.replace("_orig_mod.", ""): value
                for key, value in state_dict.items()
            }

        self.model.load_state_dict(state_dict)
        self.model = self.model.to(self.device)
        self.model.eval()

        # Compile model if requested
        self.compile = compile
        if self.compile:
            print("Compiling model with torch.compile (mode=default)...")
            self.model = torch.compile(self.model, mode="default")

        # Create inference engine
        self.inference_engine = InferenceEngine.from_trained_model(self.model)
        self.inference_engine = self.inference_engine.to(self.device)

        # Setup data. `_data_iter` is the live iterator over the offline
        # loader; it persists across batches so each call yields a new batch.
        self.data_path = data_path
        self._data_iter = None
        if data_path:
            self.dataloader = self._load_offline_data(data_path)
        else:
            self.sampler = self._create_sampler()
            self.dataloader = None

        # Initialize metrics storage
        self.metrics = {
            "timing": {
                "sample_sequence_times": [],
                "prepare_cache_times": [],
                "batch_decode_times": [],
                "per_sample_times": [],
                "total_inference_time": 0.0,
                # Detailed profiling
                "prefill_times": [],  # Self-attention prefill
                "transformer_decode_times": [],  # Cache-based decode
                "embedding_times": [],  # Embedding computation
                "head_times": [],  # Head prediction
                "context_update_times": [],  # Context embedding updates
            },
            "model_stats": {
                "checkpoint_path": checkpoint_path,
                "epoch": self.epoch,
                "global_step": self.global_step,
                "num_parameters": sum(p.numel() for p in self.model.parameters()),
                "device": str(self.device),
                "compiled": self.compile,
            },
            "evaluation": {
                "num_batches": num_eval_batches,
                "batch_size": batch_size,
                "K": K,
                "total_samples": 0,
            },
            "data_stats": {},
            "predictions": [],
        }

    def _build_model(self) -> AmortizedConditioningEngine:
        """Build an untrained model from ``self.config.model``."""
        cfg = self.config.model

        embedder = Embedder(
            dim_x=cfg.dim_x,
            dim_y=cfg.dim_y,
            hidden_dim=cfg.embedder.hidden_dim,
            out_dim=cfg.dim_model,  # Use dim_model from config
            depth=cfg.embedder.depth,
        )

        backbone = Transformer(
            num_layers=cfg.backbone.num_layers,
            dim_model=cfg.dim_model,
            num_head=cfg.backbone.num_heads,
            dim_feedforward=cfg.backbone.dim_feedforward,
            dropout=cfg.backbone.dropout,
        )

        head = MixtureGaussian(
            dim_y=cfg.dim_y,
            dim_model=cfg.dim_model,
            dim_feedforward=cfg.head.dim_feedforward,
            num_components=cfg.head.num_components,
        )

        return AmortizedConditioningEngine(
            embedder=embedder,
            backbone=backbone,
            head=head,
            max_buffer_size=cfg.max_buffer_size,
            num_target_points=cfg.num_target_points,
            targets_block_size_for_buffer_attend=cfg.targets_block_size_for_buffer_attend,
        )

    def _create_sampler(self) -> GPSampler:
        """Create GP sampler for online data generation."""
        return GPSampler(
            x_range=[[-2.0], [2.0]],
            kernel_list=["rbf", "matern52"],
            kernel_weights=[0.6, 0.4],
            lengthscale_range=[0.1, 1.0],
            variance_range=[0.5, 2.0],
            noise_range=[0.01, 0.1],
            device=str(self.device),
        )

    def _load_offline_data(self, data_path: str) -> DataLoader:
        """Wrap pre-batched offline data in a non-shuffling DataLoader."""
        dataset = GPBatchLoader(data_path, device=str(self.device))
        return DataLoader(
            dataset,
            batch_size=None,  # Pre-batched
            shuffle=False,
            num_workers=0,
        )

    def _next_offline_batch(self):
        """Return the next offline batch, starting a new pass when exhausted.

        A single iterator is kept across calls; creating a fresh iterator per
        call would return the first batch every time.
        """
        if self._data_iter is None:
            self._data_iter = iter(self.dataloader)
        try:
            return next(self._data_iter)
        except StopIteration:
            # Wrap around and reuse the offline data from the beginning.
            self._data_iter = iter(self.dataloader)
            return next(self._data_iter)

    def generate_offline_eval_data(self, save_path: str, num_batches: int = 100):
        """Generate offline evaluation data and switch evaluation to it.

        Args:
            save_path: Directory to write the generated batches into.
            num_batches: Number of batches to generate.
        """
        print(f"Generating {num_batches} evaluation batches to {save_path}")

        generate_offline_batches(
            save_dir=Path(save_path),
            num_batches=num_batches,
            batch_size=self.batch_size,
            chunk_size=10,
            # Sampler kwargs
            sampler_kwargs={
                "x_range": [[-2.0], [2.0]],
                "kernel_list": ["rbf", "matern52"],
                "kernel_weights": [0.6, 0.4],
                "lengthscale_range": [0.1, 1.0],
                "variance_range": [0.5, 2.0],
                "noise_range": [0.01, 0.1],
                "device": "cpu",  # Generate on CPU to save GPU memory
            },
            # Generation kwargs
            num_buffer=0,  # No buffer for evaluation
            num_target=50,  # More targets for comprehensive evaluation
            context_range=(10, 30),
        )

        # Point evaluation at the new data and drop any iterator over the old loader.
        self.data_path = save_path
        self.dataloader = self._load_offline_data(save_path)
        self._data_iter = None

    def _timestamp(self) -> float:
        """Return a high-resolution timestamp after pending device work finishes.

        CUDA kernels run asynchronously, so reading the clock without a
        synchronize would mostly measure kernel-launch overhead.
        """
        if self.device.type == "cuda":
            torch.cuda.synchronize(self.device)
        return time.perf_counter()

    @staticmethod
    def _safe_div(numerator: float, denominator: float, default: float = 0.0) -> float:
        """Return ``numerator / denominator``, or ``default`` when the denominator is 0."""
        return numerator / denominator if denominator else default

    def evaluate_batch(self, batch: DataAttr) -> Tuple[Dict, object]:
        """Evaluate a single batch, timing cache preparation and decoding.

        Args:
            batch: Batch with context (``xc``/``yc``) and target (``xt``, and
                optionally ``yt``) tensors; dim 0 is read as the batch size and
                dim 1 as the number of points.

        Returns:
            ``(batch_metrics, predictions)``: ``batch_metrics`` holds sizes,
            timings, throughput and, when ``yt`` is present, MSE/MAE;
            ``predictions`` is whatever ``sample_sequence`` returned.
        """
        batch_size = batch.xc.shape[0]
        num_targets = batch.xt.shape[1]
        num_samples = batch_size * num_targets

        # Time cache preparation
        cache_start = self._timestamp()
        self.inference_engine.prepare_inference_caches(batch, self.K)
        cache_time = self._timestamp() - cache_start

        # Time full sequence sampling
        seq_start = self._timestamp()
        predictions = self.inference_engine.sample_sequence(batch, K=self.K)
        seq_time = self._timestamp() - seq_start

        # Component-level profiling is only run for the first 3 batches.
        if len(self.metrics["timing"]["prefill_times"]) < 3:
            profile_metrics = self._profile_inference_components(batch)
            for key, value in profile_metrics.items():
                if key in self.metrics["timing"]:
                    self.metrics["timing"][key].append(value)

        batch_metrics = {
            "batch_size": batch_size,
            "num_targets": num_targets,
            "num_context": batch.xc.shape[1],
            "cache_prep_time": cache_time,
            "sequence_time": seq_time,
            "per_sample_time": self._safe_div(seq_time, num_samples),
            "throughput_samples_per_sec": self._safe_div(num_samples, seq_time),
        }

        # Calculate prediction error if ground truth available
        if getattr(batch, "yt", None) is not None:
            num_predictions = predictions.yc.shape[1]
            num_truth = batch.yt.shape[1]
            target = batch.yt
            if num_predictions < num_truth:
                # sample_sequence may not predict every target (T // K), so
                # compare only the prefix that was actually predicted.
                target = batch.yt[:, :num_predictions, :]
                batch_metrics["warning"] = f"Only {num_predictions}/{num_truth} targets predicted (T//K issue)"
            diff = predictions.yc - target
            batch_metrics["mse"] = torch.mean(diff ** 2).item()
            batch_metrics["mae"] = torch.mean(torch.abs(diff)).item()

        return batch_metrics, predictions

    def _profile_inference_components(self, batch: DataAttr) -> Dict[str, float]:
        """Time individual inference stages by driving the engine step by step.

        Runs context embedding, KV-cache init and prefill, then a single decode
        step (target embedding, cache-based transformer decode, head sampling,
        context update) for the first target point only.

        Args:
            batch: Batch whose context is embedded and whose first target
                point (``xt[:, 0:1]``) is decoded.

        Returns:
            Mapping from ``self.metrics["timing"]`` list names to seconds.

        NOTE(review): these calls appear to mutate engine state (context
        embeddings, KV cache); this runs after ``sample_sequence`` so its
        output is unaffected — confirm ``prepare_inference_caches`` resets
        state for the next batch.
        """
        from src.utils import create_context_buffer_datapoint

        profile_times = {}

        # 1. Time context embedding
        embed_start = self._timestamp()
        self.inference_engine.store_context_embeddings(batch)
        profile_times["embedding_times"] = self._timestamp() - embed_start

        # 2. Initialize KV cache sized for all context tokens plus every target
        T = batch.xt.shape[1]
        max_seq = self.inference_engine.context_embeddings.shape[1] + T
        self.inference_engine.init_kv_cache(batch.xt.shape[0], max_seq, device=str(batch.xt.device))

        # 3. Time prefill (self-attention over the context)
        context_len = self.inference_engine.context_embeddings.shape[1]
        selfattention_mask = self.inference_engine._get_cached_selfattn_mask(
            context_len, str(batch.xt.device)
        )

        prefill_start = self._timestamp()
        self.inference_engine.prefill_kv_cache(selfattention_mask)
        profile_times["prefill_times"] = self._timestamp() - prefill_start

        # 4. Profile a single cache-based decode step on the first target
        query = DataAttr(
            xc=None, yc=None, xb=None, yb=None,
            xt=batch.xt[:, 0:1, :],
            yt=batch.yt[:, 0:1, :] if batch.yt is not None else None
        )
        query_embedding = self.inference_engine.embedder.embed_target(query)

        decode_start = self._timestamp()
        z = self.inference_engine.transformer_decode(query_embedding)
        profile_times["transformer_decode_times"] = self._timestamp() - decode_start

        head_start = self._timestamp()
        head_output = self.inference_engine.head(z[:, -1, :].unsqueeze(1), num_samples=1)
        profile_times["head_times"] = self._timestamp() - head_start

        # Feed the sampled value back as context and time the embedding update
        yhat = head_output.samples.squeeze(2)
        prediction = create_context_buffer_datapoint(query, yhat)

        update_start = self._timestamp()
        self.inference_engine.update_context_embeddings(prediction)
        profile_times["context_update_times"] = self._timestamp() - update_start

        return profile_times

    def run_evaluation(self):
        """Evaluate ``num_eval_batches`` batches, aggregate metrics and save results."""
        print(f"\nStarting evaluation on {self.num_eval_batches} batches")
        print(f"Device: {self.device}")
        print(f"Batch size: {self.batch_size}")
        print(f"K (decoding batch size): {self.K}")

        all_metrics = []
        total_samples = 0
        # perf_counter-based, so the progress bar below uses perf_counter too.
        eval_start = self._timestamp()

        pbar = tqdm(range(self.num_eval_batches), desc="Evaluating")

        with torch.no_grad():
            for batch_idx in pbar:
                # Get batch (offline loader if configured, otherwise online).
                # `is not None` avoids DataLoader.__len__, which fails for
                # datasets without a length and is falsy for empty ones.
                if self.dataloader is not None:
                    batch = self._next_offline_batch()
                else:
                    batch = self.sampler.generate_batch(
                        batch_size=self.batch_size,
                        num_context=20,
                        num_buffer=0,
                        num_target=50,
                    )

                # Move to device if needed
                if batch.xc.device != self.device:
                    batch = DataAttr(
                        xc=batch.xc.to(self.device),
                        yc=batch.yc.to(self.device),
                        xb=batch.xb.to(self.device) if batch.xb is not None else None,
                        yb=batch.yb.to(self.device) if batch.yb is not None else None,
                        xt=batch.xt.to(self.device),
                        yt=batch.yt.to(self.device) if batch.yt is not None else None,
                    )

                batch_metrics, predictions = self.evaluate_batch(batch)
                all_metrics.append(batch_metrics)

                timing = self.metrics["timing"]
                timing["sample_sequence_times"].append(batch_metrics["sequence_time"])
                timing["prepare_cache_times"].append(batch_metrics["cache_prep_time"])
                timing["per_sample_times"].append(batch_metrics["per_sample_time"])

                total_samples += batch_metrics["batch_size"] * batch_metrics["num_targets"]

                # Update progress bar
                avg_time = np.mean(timing["sample_sequence_times"])
                elapsed = time.perf_counter() - eval_start
                pbar.set_postfix({
                    "avg_seq_time": f"{avg_time:.3f}s",
                    "throughput": f"{self._safe_div(total_samples, elapsed):.1f} samples/s",
                })

                # Save raw inputs/predictions of the first 5 batches for inspection
                if batch_idx < 5:
                    self.metrics["predictions"].append({
                        "batch_idx": batch_idx,
                        "xc": batch.xc.cpu().numpy().tolist(),
                        "yc": batch.yc.cpu().numpy().tolist(),
                        "xt": batch.xt.cpu().numpy().tolist(),
                        "yt": batch.yt.cpu().numpy().tolist() if batch.yt is not None else None,
                        "predictions": predictions.yc.cpu().numpy().tolist(),
                    })

        # Calculate final statistics
        eval_time = self._timestamp() - eval_start
        samples_per_second = self._safe_div(total_samples, eval_time)
        self.metrics["timing"]["total_inference_time"] = eval_time
        self.metrics["evaluation"]["total_samples"] = total_samples
        self.metrics["evaluation"]["samples_per_second"] = samples_per_second

        self._aggregate_metrics(all_metrics)

        print(f"\nEvaluation complete in {eval_time:.2f}s")
        print(f"Total samples processed: {total_samples:,}")
        print(f"Average throughput: {samples_per_second:.1f} samples/s")

        self.save_results()

    def _aggregate_metrics(self, all_metrics: List[Dict]):
        """Aggregate per-batch metrics into ``self.metrics``.

        Args:
            all_metrics: Per-batch dicts as returned by ``evaluate_batch``.
                When empty, nothing is aggregated and the report falls back to
                defaults for missing statistics.
        """
        if not all_metrics:
            return

        evaluation = self.metrics["evaluation"]

        # Timing statistics
        timing_keys = ["sequence_time", "cache_prep_time", "per_sample_time", "throughput_samples_per_sec"]
        for key in timing_keys:
            values = [m[key] for m in all_metrics]
            evaluation[f"{key}_mean"] = float(np.mean(values))
            evaluation[f"{key}_std"] = float(np.std(values))
            evaluation[f"{key}_min"] = float(np.min(values))
            evaluation[f"{key}_max"] = float(np.max(values))
            evaluation[f"{key}_p50"] = float(np.percentile(values, 50))
            evaluation[f"{key}_p90"] = float(np.percentile(values, 90))
            evaluation[f"{key}_p99"] = float(np.percentile(values, 99))

        # Error metrics only when every batch had ground truth
        if all("mse" in m for m in all_metrics):
            for metric in ["mse", "mae"]:
                values = [m[metric] for m in all_metrics]
                evaluation[f"{metric}_mean"] = float(np.mean(values))
                evaluation[f"{metric}_std"] = float(np.std(values))

        # Data statistics
        self.metrics["data_stats"] = {
            "num_context_mean": float(np.mean([m["num_context"] for m in all_metrics])),
            "num_targets_mean": float(np.mean([m["num_targets"] for m in all_metrics])),
            "batch_size": self.batch_size,
        }

    def save_results(self):
        """Save metrics JSON, the text report and the timing CSV to ``save_dir``."""
        metrics_path = self.save_dir / "evaluation_metrics.json"
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(self.metrics, f, indent=2)
        print(f"Saved metrics to {metrics_path}")

        self._generate_report()
        self._save_timing_csv()

    def _generate_report(self):
        """Write a human-readable summary to ``<save_dir>/evaluation_report.txt``."""
        report_path = self.save_dir / "evaluation_report.txt"
        model_stats = self.metrics["model_stats"]
        evaluation = self.metrics["evaluation"]
        timing = self.metrics["timing"]
        stat_names = ["mean", "std", "min", "max", "p50", "p90", "p99"]
        nan = float("nan")

        with open(report_path, "w", encoding="utf-8") as f:
            f.write("=" * 80 + "\n")
            f.write("ACE MODEL EVALUATION REPORT\n")
            f.write("=" * 80 + "\n\n")

            # Model information
            f.write("MODEL INFORMATION\n")
            f.write("-" * 40 + "\n")
            f.write(f"Checkpoint: {model_stats['checkpoint_path']}\n")
            f.write(f"Epoch: {model_stats['epoch']}\n")
            f.write(f"Global Step: {model_stats['global_step']}\n")
            f.write(f"Parameters: {model_stats['num_parameters']:,}\n")
            f.write(f"Device: {model_stats['device']}\n\n")

            # Evaluation settings
            f.write("EVALUATION SETTINGS\n")
            f.write("-" * 40 + "\n")
            f.write(f"Number of batches: {evaluation['num_batches']}\n")
            f.write(f"Batch size: {evaluation['batch_size']}\n")
            f.write(f"K (decode batch size): {evaluation['K']}\n")
            f.write(f"Total samples: {evaluation['total_samples']:,}\n\n")

            # Performance metrics
            f.write("PERFORMANCE METRICS\n")
            f.write("-" * 40 + "\n")
            f.write(f"Total evaluation time: {timing['total_inference_time']:.2f}s\n")
            f.write(f"Overall throughput: {evaluation['samples_per_second']:.1f} samples/s\n\n")

            # Timing statistics
            f.write("TIMING STATISTICS (seconds)\n")
            f.write("-" * 40 + "\n")
            headers = ["Metric", "Mean", "Std", "Min", "Max", "P50", "P90", "P99"]
            f.write(f"{headers[0]:<20} " + " ".join(f"{h:>10}" for h in headers[1:]) + "\n")

            for metric in ["sequence_time", "cache_prep_time", "per_sample_time"]:
                f.write(f"{metric:<20} ")
                for stat in stat_names:
                    f.write(f"{evaluation.get(f'{metric}_{stat}', 0):>10.6f} ")
                f.write("\n")

            # Throughput statistics
            f.write(f"\n{'throughput (samples/s)':<20} ")
            for stat in stat_names:
                f.write(f"{evaluation.get(f'throughput_samples_per_sec_{stat}', 0):>10.1f} ")
            f.write("\n")

            # Error metrics if available
            if "mse_mean" in evaluation:
                f.write("\nPREDICTION ERROR METRICS\n")
                f.write("-" * 40 + "\n")
                f.write(f"MSE: {evaluation['mse_mean']:.6f} ± {evaluation['mse_std']:.6f}\n")
                f.write(f"MAE: {evaluation['mae_mean']:.6f} ± {evaluation['mae_std']:.6f}\n")

            # Detailed profiling results
            if timing["prefill_times"]:
                f.write("\nDETAILED PROFILING (seconds)\n")
                f.write("-" * 40 + "\n")

                prefill_avg = np.mean(timing["prefill_times"])
                decode_avg = np.mean(timing["transformer_decode_times"])
                embed_avg = np.mean(timing["embedding_times"])
                head_avg = np.mean(timing["head_times"])
                update_avg = np.mean(timing["context_update_times"])

                f.write(f"Context Prefill (self-attention): {prefill_avg:.6f}\n")
                f.write(f"Transformer Decode (with cache):  {decode_avg:.6f}\n")
                f.write(f"Embedding Computation:            {embed_avg:.6f}\n")
                f.write(f"Head Prediction:                  {head_avg:.6f}\n")
                f.write(f"Context Update:                   {update_avg:.6f}\n")

                # Ratios report "nan" rather than crashing on a zero denominator.
                prefill_vs_decode = self._safe_div(prefill_avg, decode_avg, nan)
                f.write(f"\nPrefill vs Cache Decode Ratio: {prefill_vs_decode:.1f}x\n")
                f.write(f"Self-attention is {prefill_vs_decode:.1f}x slower than cache-based decode\n")

                # Component breakdown for one decode step
                total_decode_step = decode_avg + head_avg
                f.write("\nSingle Token Generation Breakdown:\n")
                f.write(f"  Transformer: {self._safe_div(decode_avg, total_decode_step, nan) * 100:.1f}%\n")
                f.write(f"  Head:        {self._safe_div(head_avg, total_decode_step, nan) * 100:.1f}%\n")

                # Full sequence cost estimate: one prefill plus one decode per target
                avg_targets = self.metrics["data_stats"].get("num_targets_mean", 50)
                decode_total = decode_avg * avg_targets
                full_total = prefill_avg + decode_total
                f.write(f"\nFull Sequence Generation Cost (avg {avg_targets:.0f} targets):\n")
                f.write(f"  Prefill (1x):     {prefill_avg:.6f}s\n")
                f.write(f"  Decode ({avg_targets:.0f}x):    {decode_total:.6f}s\n")
                f.write(f"  Total:            {full_total:.6f}s\n")
                f.write(f"  Prefill %:        {self._safe_div(prefill_avg, full_total, nan) * 100:.1f}%\n")
                f.write(f"  Decode %:         {self._safe_div(decode_total, full_total, nan) * 100:.1f}%\n")

            f.write("\n" + "=" * 80 + "\n")

        print(f"Saved evaluation report to {report_path}")

    def _save_timing_csv(self):
        """Write per-batch timings to ``<save_dir>/timing_data.csv``."""
        import csv

        csv_path = self.save_dir / "timing_data.csv"
        timing = self.metrics["timing"]
        rows = zip(
            timing["sample_sequence_times"],
            timing["prepare_cache_times"],
            timing["per_sample_times"],
        )

        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["batch_idx", "sequence_time", "cache_prep_time", "per_sample_time"])
            writer.writerows([idx, *row] for idx, row in enumerate(rows))

        print(f"Saved timing data to {csv_path}")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Describe the command-line interface of the evaluation script."""
    cli = argparse.ArgumentParser(description="Evaluate trained ACE model")
    cli.add_argument("checkpoint", type=str, help="Path to model checkpoint")
    cli.add_argument("--data-path", type=str, default=None, help="Path to offline evaluation data")
    cli.add_argument("--generate-data", action="store_true", help="Generate offline evaluation data")
    cli.add_argument("--device", type=str, default="cuda", help="Device to run on")
    cli.add_argument("--batch-size", type=int, default=32, help="Batch size for evaluation")
    cli.add_argument("--num-batches", type=int, default=100, help="Number of batches to evaluate")
    cli.add_argument("--K", type=int, default=4, help="Decoding batch size for sample_sequence")
    cli.add_argument("--save-dir", type=str, default="./eval_results", help="Directory to save results")
    cli.add_argument("--compile", action="store_true", help="Use torch.compile on inference methods")
    return cli


def main():
    """Parse CLI options, optionally generate offline data, then run the evaluation."""
    opts = _build_arg_parser().parse_args()

    evaluator = ModelEvaluator(
        checkpoint_path=opts.checkpoint,
        data_path=opts.data_path,
        device=opts.device,
        batch_size=opts.batch_size,
        num_eval_batches=opts.num_batches,
        K=opts.K,
        save_dir=opts.save_dir,
        compile=opts.compile,
    )

    if opts.generate_data:
        # Generated batches land in <save_dir>/eval_data and become the data source.
        evaluator.generate_offline_eval_data(Path(opts.save_dir) / "eval_data", opts.num_batches)

    evaluator.run_evaluation()

    print("\nEvaluation complete! Results saved to:", opts.save_dir)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()