#!/usr/bin/env python3
"""
Training script for SAE with sharded HDF5 datasets.
Preserves all SAE training logic while only changing data loading.
"""
import argparse
import math
import numpy as np
import os
import random
import sys
import glob
from datetime import datetime
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from typing import Union
import yaml
import h5py
import lightning.pytorch as pl
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

from Group_SAE.SAETran_model_v2 import (
    BufferUpdateCallback,
    LitSAEWithChannelNew,
)
from Simtransformer.simtransformer.utils import EasyDict, clean_config, clever_load, StepCheckpointCallback
from core.sharded_data_module import (
    ShardedBufferedBatchHDF5DataModule,
    DirectShardedHDF5DataModule
)

# Optional dependency: the memory block data module may not be importable.
# MEMORY_BLOCK_AVAILABLE records whether the import succeeded so later code can
# check it before touching MemoryBlockDataModule (the name is undefined when
# the import fails).
try:
    from core.memory_block_data_module import MemoryBlockDataModule
    MEMORY_BLOCK_AVAILABLE = True
except ImportError:
    MEMORY_BLOCK_AVAILABLE = False
    print("Memory block data module not available")


def float_or_str(val: str) -> Union[float, str]:
    """Convert ``val`` to a float, or hand it back unchanged if it is not numeric.

    Used as an argparse ``type=`` callable for options that accept either a
    number or a keyword string.
    """
    try:
        number = float(val)
    except ValueError:
        # Not a number: keep the raw text so callers can treat it as a keyword.
        return val
    else:
        return number


def float_or_none(val: str) -> Union[float, None]:
    """Convert ``val`` to a float, or return ``None`` if it is not numeric.

    Used as an argparse ``type=`` callable so that any non-numeric text on the
    command line means "not set".
    """
    try:
        number = float(val)
    except ValueError:
        # Any unparsable text is treated as "no value".
        return None
    else:
        return number


def parse_args(debug: bool) -> argparse.Namespace:
    """Parse the command-line options for sharded SAE training.

    Args:
        debug: When true, a fixed set of small values is written over the
            parsed options after parsing (see ``debug_overrides`` below).

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="SAE Training with Sharded Datasets")
    add = parser.add_argument

    # ---- Data ----
    add("--shard_pattern", type=str, required=True,
        help="Pattern to match shard files (e.g., '/path/to/data_L12_shard_*.h5')")
    add("--use_buffered", action="store_true",
        help="Use buffered loading (legacy option, consider using memory_block for better performance)")
    # store_true combined with default=True means this value is always True;
    # --no_memory_block is the switch that actually turns the mode off.
    add("--use_memory_block", action="store_true", default=True,
        help="Use memory block loading for pre-shuffled data (default, fastest)")
    add("--no_memory_block", action="store_true",
        help="Disable memory block loading and use direct loading instead")
    add("--block_size_gb", type=float, default=4.0,
        help="Memory block size in GB for memory block loading")
    add("--batch_size", type=int, default=128)
    add("--num_workers", type=int, default=4)
    add("--buffer_in_GB", type=float, default=2.0,
        help="Buffer size in GB (only used with --use_buffered)")
    add("--split_ratio", type=float, default=0.999)

    # ---- High-norm filtering / data normalization ----
    add("--remove_high_norm", type=float, default=None,
        help="Remove samples with norm > median * this factor (e.g., 10.0 for Qwen models)")
    add("--filter_cache_size", type=int, default=10000,
        help="Number of samples to use for computing median norm when filtering")
    add("--normalize_by_mean", action="store_true", default=False,
        help="Normalize data by dividing by the mean norm (computed from sampled data)")

    # ---- Model ----
    add("--num_neurons", type=int, default=65536)
    add("--activation", type=str, default="relu")
    add("--normalize_W_enc", action="store_true")
    add("--use_tunable_threshold_activation", action="store_true")
    add("--topk", type=int, default=None)

    # ---- Training ----
    add("--max_epochs", type=int, default=3)
    add("--lr", type=float, default=1e-4)
    add("--L1_decay", type=float, default=0.0)
    add("--tune_b_enc", action="store_true", help="Whether allow gradient on b_enc")
    add("--tune_b_dec", action="store_true")
    add("--adjust_b_enc", action="store_true")

    # ---- Batch normalization ----
    add("--normalize_batch", action="store_true")
    add("--normalize_batch_with_tanh_threshold", type=float_or_none, default=None)
    add("--divide_by", type=float, default=1.0)

    # ---- b_enc adjustment ----
    # The string default is run through float_or_str by argparse, so the
    # parsed value is a float unless a keyword is given on the command line.
    add("--freq_threshold_high", type=float_or_str, default="0.1")
    add("--factor_down", type=float, default=0.08)
    add("--factor_up", type=float, default=0.0001)
    add("--init_FTH", type=float, default=0.2)
    add("--end_FTH", type=float, default=1e-3)
    add("--num_groups", type=int, default=10)
    add("--clamp_b_enc_max", type=float_or_none, default=None)
    add("--clamp_b_enc_min", type=float_or_none, default=None)

    # ---- Experiment bookkeeping ----
    add("--wandb_project", type=str, default="SAE-Sharded")
    add("--wandb_entity", type=str, default="SAE_atomic",
        help="Wandb entity (organization/team name)")
    add("--exp_name", type=str, default="SAE-Sharded")
    add("--seed", type=int, default=None)
    add("--ckpt_folder", type=str, default="checkpoints_sharded")
    add("--resume_ckpt", type=str, default=None,
        help="Path to checkpoint to resume from")

    args = parser.parse_args()
    if not debug:
        return args

    # Debug runs overwrite whatever was parsed with small, fixed settings.
    debug_overrides = {
        'batch_size': 2,
        'num_workers': 0,
        'max_epochs': 1,
        'buffer_in_GB': 0.5,
        'split_ratio': 0.999,
        'wandb_project': "<PROJECT>",
        'seed': 42,
    }
    for option, value in debug_overrides.items():
        setattr(args, option, value)
    return args


def infer_data_info_from_shards(shard_pattern: str) -> dict:
    """Infer dataset info from shard files.

    Each shard matching ``shard_pattern`` is opened once and only the metadata
    of its ``"non_padding_cache"`` dataset is inspected; no sample data is
    read, so an empty first shard does not cause a failure.

    Args:
        shard_pattern: Glob pattern selecting the HDF5 shard files.

    Returns:
        A dict with the keys ``'dimensions'`` (size of the last dataset axis,
        taken from the first shard in sorted order), ``'length'`` (sum of the
        first-axis sizes over all shards), ``'data_type'`` (dtype of the first
        shard's dataset, as a string) and ``'num_shards'``.

    Raises:
        ValueError: If no file matches the pattern, or if the first shard's
            dataset has no per-sample feature axis (fewer than 2 dimensions).
    """
    shard_files = sorted(glob.glob(shard_pattern))
    if not shard_files:
        raise ValueError(f"No shard files found matching pattern: {shard_pattern}")

    dimensions = None
    dtype = None
    total_samples = 0
    for shard_file in shard_files:
        with h5py.File(shard_file, 'r') as f:
            dataset = f["non_padding_cache"]
            if dimensions is None:
                # Shape/dtype come from the dataset header of the first shard;
                # reading a sample row just for its shape is unnecessary.
                if len(dataset.shape) < 2:
                    raise ValueError(
                        f"Dataset 'non_padding_cache' in {shard_file} has shape "
                        f"{dataset.shape}; expected at least 2 dimensions"
                    )
                dimensions = dataset.shape[-1]
                dtype = dataset.dtype
            total_samples += dataset.shape[0]

    data_info = {
        'dimensions': dimensions,
        'length': total_samples,
        'data_type': str(dtype),
        'num_shards': len(shard_files)
    }

    print(f"Dataset info inferred from {len(shard_files)} shards:")
    print(f"  Dimensions: {dimensions}")
    print(f"  Total samples: {total_samples}")
    print(f"  Data type: {dtype}")

    return data_info


def compute_buffer_params(num_samples: int, hidden_size: int, gb: float,
                          bytes_per_element: int = 4):
    """Size an in-memory buffer and count the buffers needed to cover the data.

    Args:
        num_samples: Number of samples that have to be covered.
        hidden_size: Number of elements per sample.
        gb: Buffer budget, in units of 1024**3 bytes.
        bytes_per_element: Storage size of one element. Defaults to 4
            (float32), which is what the original computation assumed.

    Returns:
        ``(size_in_samples, buffers_per_epoch)``: how many samples fit in one
        buffer, and ``ceil(num_samples / size_in_samples)``.

    Raises:
        ValueError: If the per-sample size is not positive, or if the budget
            cannot hold even one sample (previously this surfaced as an
            opaque ``ZeroDivisionError``).
    """
    bytes_per_sample = hidden_size * bytes_per_element
    if bytes_per_sample <= 0:
        raise ValueError(
            f"hidden_size ({hidden_size}) and bytes_per_element "
            f"({bytes_per_element}) must both be positive"
        )

    size_in_samples = int(gb * 1024 ** 3 // bytes_per_sample)
    if size_in_samples < 1:
        raise ValueError(
            f"Buffer of {gb} GB cannot hold a single sample of "
            f"{bytes_per_sample} bytes"
        )

    buffers_per_epoch = math.ceil(num_samples / size_in_samples)
    return size_in_samples, buffers_per_epoch


def build_configs(args: argparse.Namespace, data_info: dict):
    """Build the combined data / model / training configuration.

    Args:
        args: Parsed command-line options (see ``parse_args``).
        data_info: Dataset summary providing ``'dimensions'``, ``'length'``
            and ``'data_type'`` (see ``infer_data_info_from_shards``).

    Returns:
        ``(cfg, buf_epochs)`` where ``cfg`` is an ``EasyDict`` holding the
        data settings under ``data_config`` with the model and training
        settings merged at the top level, and ``buf_epochs`` is the number of
        buffers per epoch (1 unless buffered loading is selected).
    """
    # Memory block loading is only selected when its module was importable.
    # main() falls back to the buffered/direct data modules in that case, so
    # the data config has to be built for the same mode; otherwise the
    # fallback module would receive memory-block settings.
    use_memory_block = (
        args.use_memory_block
        and not args.no_memory_block
        and MEMORY_BLOCK_AVAILABLE
    )
    num_train_samples = int(data_info['length'] * args.split_ratio)

    # Settings shared by every loading mode.
    data_cfg = EasyDict({
        'shard_pattern': args.shard_pattern,
        'batch_size': args.batch_size,
        'num_workers': args.num_workers,
        'split_ratio': args.split_ratio,
        'seed': args.seed,
        'data_type': data_info['data_type'],
        'num_samples': num_train_samples,
        'use_memory_block': use_memory_block,
        'use_buffered': False,
        'remove_high_norm': args.remove_high_norm,
        'filter_cache_size': args.filter_cache_size if args.remove_high_norm else None,
        'normalize_by_mean': args.normalize_by_mean,
    })

    # Mode-specific additions.
    buf_epochs = 1
    if use_memory_block:
        data_cfg.update({'block_size_gb': args.block_size_gb})
    elif args.use_buffered:
        buf_samples, buf_epochs = compute_buffer_params(
            num_train_samples,
            data_info['dimensions'],
            args.buffer_in_GB
        )
        data_cfg.update({
            'buffer_size_samples': buf_samples,
            'use_buffered': True,
        })

    # Model config
    model_cfg = EasyDict({
        'hidden_size': data_info['dimensions'],
        'num_neurons': args.num_neurons,
        'activation': args.activation,
        'channel_size_ls': [],
        'use_neuron_weight': True,
    })

    # Training config: start from the packaged defaults, then override.
    train_cfg = EasyDict(clever_load(
        os.path.join(
            sys.path[0],
            '..', '..',
            'Simtransformer/simtransformer/configurations',
            'train_config_default.yaml')
    ))

    # NOTE(review): this step count uses the full dataset length, whereas
    # main() derives the trainer's step budget from the training split
    # (length * split_ratio). Left unchanged; confirm which is intended.
    total_steps = int(data_info['length'] // args.batch_size + 1) * args.max_epochs
    long_run = total_steps > 1000
    train_cfg.update({
        'freq_threshold_high': args.freq_threshold_high,
        'max_epochs': args.max_epochs,
        'batch_size': data_cfg.batch_size,
        'optimizer': 'AdamW',
        'AdamW_optimizer_config': {
            'lr': args.lr,
            'weight_decay': 0.01,
            'betas': [0.9, 0.999],
        },
        'GroupedAdamW_optimizer_config': {
            'lr': args.lr,
            'weight_decay': 0.01,
            'betas': [0.9, 0.999],
        },
        'wandb_config': {
            'wandb_project': args.wandb_project,
            'wandb_entity': args.wandb_entity,
        },
        'num_neuron_vis': 50,
        'seed': args.seed,
        'lr_scheduler': 'cosine',
        'cosine_scheduler_config': {
            # Runs longer than 1000 steps reserve 1000 steps for warmup.
            'lr_decay_steps': total_steps - 1000 if long_run else total_steps,
            # NOTE(review): min_lr is set to the base lr itself; presumably
            # this makes the decay phase flat - confirm this is intended.
            'min_lr': args.lr,
            'warmup_steps': 1000 if long_run else 0,
        },
        'normalize_W_enc': args.normalize_W_enc,
        'SAE_output_scale': 1,
        'b_dec_zero': not args.tune_b_dec,
        'tune_b_dec': args.tune_b_dec,
        'save_phase_end_ckpt': False,
        'save_phase_start_ckpt': False,
    })

    # Optional normalization settings are only written when requested.
    if args.normalize_batch_with_tanh_threshold is not None:
        train_cfg.update({
            'normalize_batch_with_tanh_threshold': args.normalize_batch_with_tanh_threshold,
        })

    if hasattr(args, 'divide_by'):
        train_cfg.update({'divide_by': args.divide_by})

    if args.normalize_batch:
        train_cfg.update({'normalize_batch': args.normalize_batch})

    # Phase 0 config; its end step is set to twice the computed step count.
    train_cfg.update({
        'phase_0_config': {
            'start_step': 0,
            'end_step': 2 * total_steps,
            'tune_W_enc': True,
            'tune_b_enc': args.tune_b_enc,
            'use_alignment_loss': False,
            "use_tunable_threshold_activation": args.use_tunable_threshold_activation,
            'adjust_b_enc_config': {
                'adjust_b_enc': args.adjust_b_enc,
                'group_size': 1,
                'group_partitions': [1.0],
                'interval': 50,
                'factor_up': args.factor_up,
                'factor_down': args.factor_down,
                'group_1': {
                    'freq_threshold_low': 1 / (100 * data_cfg.batch_size),
                    'freq_threshold_high': args.freq_threshold_high,
                },
            },
            'L1_decay': args.L1_decay,
        },
    })

    # Clamping settings
    if args.clamp_b_enc_max is not None:
        train_cfg.phase_0_config.update({'clamp_b_enc_max': args.clamp_b_enc_max})
    if args.clamp_b_enc_min is not None:
        train_cfg.phase_0_config.update({'clamp_b_enc_min': args.clamp_b_enc_min})

    # TopK setting
    if args.topk is not None:
        train_cfg.update({'topk': args.topk})

    # "mixed" mode: split the neurons into equal groups whose upper frequency
    # thresholds are log-spaced from init_FTH to end_FTH. A float never
    # compares equal to the string, so no separate type check is needed.
    if args.freq_threshold_high == "mixed":
        num_groups = args.num_groups
        adjust_cfg = train_cfg.phase_0_config.adjust_b_enc_config
        adjust_cfg.update({
            'adjust_b_enc': args.adjust_b_enc,
            'group_partitions': [1.0 / num_groups for _ in range(num_groups)],
            'interval': 50,
            'factor_up': args.factor_up,
            'factor_down': args.factor_down,
        })

        group_thresholds = np.logspace(
            np.log10(args.init_FTH), np.log10(args.end_FTH), num_groups
        )
        for i, threshold_high in enumerate(group_thresholds):
            adjust_cfg.update({
                f'group_{i + 1}': {
                    'freq_threshold_low': 1 / (data_info['length'] * 0.1),
                    'freq_threshold_high': float(threshold_high),
                },
            })

    # Combine all configs
    cfg = EasyDict({'data_config': data_cfg, **model_cfg, **train_cfg})
    return cfg, buf_epochs


def get_experiment_name(cfg: EasyDict, args: argparse.Namespace = None) -> str:
    """Return the experiment name used to group runs.

    Args:
        cfg: Combined configuration; currently unused.
        args: Parsed command-line options. When omitted, the default value of
            ``--exp_name`` is returned instead of failing on ``None``.

    Returns:
        ``args.exp_name``, or ``"SAE-Sharded"`` when ``args`` is ``None``.
    """
    if args is None:
        # The signature advertises args as optional; fall back to the same
        # default that parse_args() uses for --exp_name.
        return "SAE-Sharded"
    return args.exp_name


def get_run_name(seed: int) -> str:
    """Build a run name from the seed and the current local timestamp.

    The result has the form ``seed-<seed>_<YYYYmmdd>_<HHMMSS>``.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return "seed-" + str(seed) + "_" + timestamp


def prepare_trainer(cfg: EasyDict, exp_name: str, run_name: str, buffers: int, total_steps: int, args) -> pl.Trainer:
    """Prepare Lightning trainer with callbacks and logger.

    Creates the checkpoint directory, writes the cleaned config to
    ``config.yaml`` inside it, records the directory on ``cfg.checkpoint_dir``
    and returns a trainer limited to ``total_steps`` optimisation steps.

    Args:
        cfg: Combined configuration; must provide ``wandb_config`` and
            ``data_config``. Mutated: ``checkpoint_dir`` is set on it.
        exp_name: Experiment name, used as the W&B group and directory prefix.
        run_name: Run name, used as the W&B run name and directory suffix.
        buffers: Unused here; kept for interface compatibility.
        total_steps: Value passed to the trainer as ``max_steps``; also
            determines the fixed-step checkpoint schedule.
        args: Parsed command-line options; only ``ckpt_folder`` is read.

    Returns:
        The configured ``pl.Trainer``.
    """
    logger = WandbLogger(
        project=cfg.wandb_config.wandb_project,
        entity=cfg.wandb_config.wandb_entity,
        group=exp_name,
        name=run_name,
        config=cfg,
    )

    ckpt_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        '..',
        args.ckpt_folder,
        f"{exp_name}-{run_name}"
    )
    os.makedirs(ckpt_dir, exist_ok=True)
    # Use a context manager so the config file is flushed and closed right
    # away instead of relying on garbage collection.
    with open(os.path.join(ckpt_dir, 'config.yaml'), 'w') as config_file:
        yaml.dump(clean_config(cfg), config_file)
    cfg.checkpoint_dir = ckpt_dir

    # Fixed-step checkpoints at multiples of total_steps // 5, dropping the
    # first (step 0) and the last grid point. The interval is clamped to 1
    # because np.arange rejects a zero step, which total_steps < 5 would give.
    ckpt_interval = max(1, total_steps // 5)
    ckpt_grid = np.arange(0, total_steps, step=ckpt_interval)
    ckpt_steps = [int(step) for step in ckpt_grid[1:-1]]

    callbacks = [
        LearningRateMonitor(logging_interval='step'),
        ModelCheckpoint(
            dirpath=ckpt_dir,
            filename='{step}-{val_loss:.4f}',
            monitor='val_loss',
            mode='min',
            save_top_k=1,
            every_n_train_steps=500,
            save_last=True,
        ),
        StepCheckpointCallback(
            dirpath=ckpt_dir,
            ckpt_steps=ckpt_steps,
        )
    ]

    # Add buffer update callback if using buffered loading
    if cfg.data_config.get('use_buffered', False):
        callbacks.append(BufferUpdateCallback())

    # For IterableDataset, val_check_interval must be 1.0 or an int
    # Using int means check validation every N training batches
    val_check_interval = 1000  # Check validation every 1000 batches

    trainer = pl.Trainer(
        max_steps=total_steps,
        logger=logger,
        callbacks=callbacks,
        val_check_interval=val_check_interval,
        precision=getattr(cfg, 'precision', '32'),
        accelerator='gpu',
        log_every_n_steps=30,
    )

    return trainer


def main(debug: bool = False):
    """Run SAE training on sharded HDF5 data, end to end.

    Steps: parse options, fix the random seed, pick the data loading mode,
    infer dataset info, build the config, create the data module and model,
    then fit with a Lightning trainer (optionally resuming from
    ``--resume_ckpt``).

    Args:
        debug: Forwarded to ``parse_args`` to enable its debug overrides.
    """
    args = parse_args(debug=debug)

    # Resolve the seed BEFORE the config is built so that an auto-generated
    # seed is recorded in the config and handed to the data modules. The
    # explicit None check keeps a user-supplied seed of 0.
    if args.seed is None:
        args.seed = random.randint(0, 2**32 - 1)
    seed = args.seed
    seed_everything(seed)
    print(f"Using seed {seed}")

    # Decide the loading mode up front. If memory block loading was requested
    # but its module could not be imported, switch it off on args so the
    # config built below matches the data module that is actually created.
    memory_block_requested = args.use_memory_block and not args.no_memory_block
    if memory_block_requested and not MEMORY_BLOCK_AVAILABLE:
        fallback = "buffered" if args.use_buffered else "direct"
        print(f"Memory block loading requested but unavailable; falling back to {fallback} loading")
        args.no_memory_block = True
    use_memory_block = memory_block_requested and MEMORY_BLOCK_AVAILABLE

    # Infer data info from shards
    data_info = infer_data_info_from_shards(args.shard_pattern)

    # Build configs
    cfg, buf_epochs = build_configs(args, data_info)

    # Create experiment and run names
    exp_name = get_experiment_name(cfg, args)
    run_name = get_run_name(seed)

    if args.remove_high_norm is not None:
        print(f"High norm filtering enabled: remove_high_norm={args.remove_high_norm}")
        print(f"Using {args.filter_cache_size:,} samples for median computation")

    if args.normalize_by_mean:
        print("Mean normalization enabled: data will be divided by the mean norm")
        print(f"Using {args.filter_cache_size:,} samples for computing mean norm")

    # Create data module based on loading choice
    if use_memory_block:
        print("Using memory block loading for pre-shuffled data (RECOMMENDED)")
        print(f"Block size: {args.block_size_gb}GB")
        data_module = MemoryBlockDataModule(
            shard_pattern=args.shard_pattern,
            batch_size=args.batch_size,
            block_size_gb=args.block_size_gb,
            num_workers=args.num_workers,
            split_ratio=args.split_ratio,
            use_iterable=True,  # Use streaming for efficiency
            prefetch_factor=2,
            persistent_workers=(args.num_workers > 0),
            remove_high_norm=args.remove_high_norm,
            filter_cache_size=args.filter_cache_size,
            normalize_by_mean=args.normalize_by_mean
        )
    elif args.use_buffered:
        print("Using buffered sharded data loading (DEPRECATED - consider using memory block)")
        data_module = ShardedBufferedBatchHDF5DataModule(**cfg.data_config)
    else:
        print("Using direct sharded data loading (for data requiring runtime shuffling)")
        data_module = DirectShardedHDF5DataModule(**cfg.data_config)

    # Setup data module to compute norm statistics
    data_module.setup()

    # Get norm statistics and add to config if using memory block
    if use_memory_block:
        norm_stats = data_module.get_norm_statistics()
        if norm_stats:
            cfg.data_config.update(norm_stats)
            print("\nComputed norm statistics saved to config:")
            for key, value in norm_stats.items():
                print(f"  {key}: {value:.4f}")

    # Create model (preserving original SAE model)
    model = LitSAEWithChannelNew(cfg, feat_set=None, split_neuron_by_group=True)

    # Training length is expressed in steps over the training split; the
    # trainer is bounded by max_steps rather than by an epoch count.
    total_steps = int(data_info['length'] * cfg.data_config.split_ratio // args.batch_size + 1) * args.max_epochs

    # Prepare trainer (this will save the updated config with norm statistics)
    trainer = prepare_trainer(cfg, exp_name, run_name, buf_epochs, total_steps, args)

    # Train
    print(f"Starting trainer.fit() at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Data module type: {type(data_module).__name__}")
    print(f"Number of workers: {args.num_workers}")
    print(f"Batch size: {args.batch_size}")

    if args.resume_ckpt:
        print(f"Resuming from checkpoint: {args.resume_ckpt}")
        trainer.fit(model, data_module, ckpt_path=args.resume_ckpt)
    else:
        print("Starting fresh training...")
        trainer.fit(model, data_module)

    print(f"Training completed at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")


def load_model(ckpt_path: str, split_neuron_by_group: bool = True, load_data: bool = True):
    """Restore a trained model, its config and optionally its data module.

    The config is read from ``config.yaml`` located in the same directory as
    the checkpoint file.

    Args:
        ckpt_path: Path to the checkpoint file.
        split_neuron_by_group: Forwarded to the model loader.
        load_data: When true, also build the data module selected by the
            ``use_memory_block`` / ``use_buffered`` flags in the saved
            data config; otherwise the returned data module is ``None``.

    Returns:
        ``(config, model, data_module)``.
    """
    config_file = os.path.join(os.path.dirname(ckpt_path), 'config.yaml')
    config = EasyDict(clever_load(config_file))

    data_module = None
    if load_data:
        data_settings = config.data_config
        if data_settings.get('use_memory_block', False) and MEMORY_BLOCK_AVAILABLE:
            # Memory block module (best for pre-shuffled data)
            module_cls = MemoryBlockDataModule
        elif data_settings.get('use_buffered', False):
            # Legacy buffered loading
            module_cls = ShardedBufferedBatchHDF5DataModule
        else:
            # Direct loading (supports runtime shuffling)
            module_cls = DirectShardedHDF5DataModule
        # NOTE(review): the whole saved data config is forwarded as keyword
        # arguments; this assumes each data module accepts every stored key
        # (e.g. 'data_type', 'num_samples') - confirm against the modules.
        data_module = module_cls(**data_settings)

    model = LitSAEWithChannelNew.load_from_checkpoint(
        ckpt_path,
        config=config,
        feat_set=None,
        split_neuron_by_group=split_neuron_by_group
    )

    return config, model, data_module


# Script entry point: run training with the debug overrides disabled.
if __name__ == '__main__':
    main(debug=False)