import os
import json
import argparse
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, Callback
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies import DeepSpeedStrategy
from pytorch_lightning.callbacks import TQDMProgressBar
import glob


from model import PythiaModel


from mmap_dataset_lightning import setup_pythia_data


class CleanupCallback(Callback):
    """Remove optimizer-state files from saved checkpoint folders.

    At the end of every validation run, each entry of ``base_dir`` whose name
    ends in ``.ckpt`` (except ``last.ckpt``) is searched recursively for files
    matching ``*_optim_states.pt`` and every match is deleted. ``last.ckpt``
    is left untouched, so its optimizer state stays on disk.

    Args:
        model_name: Run/model name used to build the default checkpoint
            directory.
        base_dir: Optional directory to clean. When omitted, the historical
            default ``~pythia_replicate/trained_models/pythia_output_<model_name>``
            is used unchanged.
    """

    def __init__(self, model_name, base_dir=None):
        super().__init__()
        if base_dir is None:
            base_dir = f"~pythia_replicate/trained_models/pythia_output_{model_name}"
        self.base_dir = base_dir

    def on_validation_end(self, trainer, pl_module):
        """Delete ``*_optim_states.pt`` files below every non-last checkpoint.

        Does nothing when ``base_dir`` is not an existing directory.
        """
        base_dir = self.base_dir
        if not os.path.isdir(base_dir):
            return

        for entry in os.listdir(base_dir):
            if not entry.endswith(".ckpt") or entry == "last.ckpt":
                continue

            ckpt_folder = os.path.join(base_dir, entry)
            # Escape the directory part so glob metacharacters in a run name
            # or checkpoint name (e.g. "[", "*", "?") are matched literally.
            pattern = os.path.join(
                glob.escape(ckpt_folder), "**", "*_optim_states.pt"
            )
            for path in glob.glob(pattern, recursive=True):
                try:
                    os.remove(path)
                except FileNotFoundError:
                    # Already gone (e.g. removed by another process between
                    # the glob and the delete); nothing left to do.
                    continue
                except OSError as err:
                    # Cleanup is best-effort: report the failure instead of
                    # aborting the training run.
                    print(f"[cleanup] could not remove {path}: {err}")
                    continue
                print(f"[cleanup] removed {path}")


class DetailedProgressBar(TQDMProgressBar):
    """TQDM progress bar that also displays the current learning rate."""

    def get_metrics(self, trainer, model):
        """Return the parent's metrics, extended with an ``lr`` entry.

        ``lr`` is read from the first parameter group of the trainer's first
        optimizer; it is omitted when the trainer has no optimizers.
        """
        metrics = super().get_metrics(trainer, model)
        if not trainer.optimizers:
            return metrics

        first_group = trainer.optimizers[0].param_groups[0]
        metrics["lr"] = first_group["lr"]
        return metrics


def _build_deepspeed_config(config):
    """Translate the training config into the dict handed to DeepSpeed.

    The ``fp16`` section is copied from ``config["fp16"]`` with
    ``initial_scale_power`` and ``loss_scale_window`` forced to 12 and 1000.
    The ZeRO stage is forced to 0 when ``reinitialize_optim_for_heads`` is
    truthy. An ``activation_checkpointing`` section is added only when
    ``checkpoint-activations`` is truthy.

    Raises:
        KeyError: If a required key is missing from ``config``.
    """
    zero_cfg = config["zero_optimization"]
    deepspeed_config = {
        "fp16": {
            **config["fp16"],
            # Listed after the spread so they override the config's values.
            "initial_scale_power": 12,
            "loss_scale_window": 1000,
        },
        "zero_optimization": {
            "stage": (
                zero_cfg["stage"]
                if not config["reinitialize_optim_for_heads"]
                else 0
            ),
            "allgather_partitions": zero_cfg.get("allgather_partitions", True),
            "allgather_bucket_size": zero_cfg.get(
                "allgather_bucket_size", 200000000
            ),
            "overlap_comm": zero_cfg.get("overlap_comm", True),
            "reduce_scatter": zero_cfg.get("reduce_scatter", True),
            "reduce_bucket_size": zero_cfg.get("reduce_bucket_size", 200000000),
            "contiguous_gradients": zero_cfg.get("contiguous_gradients", True),
        },
        "gradient_clipping": config["gradient_clipping"],
        "train_micro_batch_size_per_gpu": config["train_micro_batch_size_per_gpu"],
        "steps_per_print": config["steps_per_print"],
        "wall_clock_breakdown": config["wall_clock_breakdown"],
    }

    if config["checkpoint-activations"]:
        deepspeed_config["activation_checkpointing"] = {
            "partition_activations": config["partition-activations"],
            "cpu_checkpointing": False,
            "contiguous_memory_optimization": False,
            "synchronize_checkpoint_boundary": config["synchronize-each-layer"],
        }
    return deepspeed_config


def _load_equivalence_lookup(config):
    """Load the equivalence lookup from disk, or return ``None``.

    ``None`` is returned when ``use_equivalence_bigram_masking`` is falsy or
    when the file at ``equivalence_lookup_path`` does not exist (a warning is
    printed in the latter case).
    """
    if not config["use_equivalence_bigram_masking"]:
        return None

    equiv_path = config["equivalence_lookup_path"]
    if not os.path.exists(equiv_path):
        print(
            "Warning: Equivalence lookup file not found. It will be created during model init."
        )
        return None

    equiv_lookup = torch.load(equiv_path, map_location="cpu")
    print(f"Loaded equivalence lookup from {equiv_path}")
    return equiv_lookup


def _select_strategy(config, deepspeed_config):
    """Choose the Lightning strategy from the visible GPUs and ZeRO stage.

    Returns ``"auto"`` unless CUDA is available with more than one device.
    Otherwise returns a ``DeepSpeedStrategy`` when the configured ZeRO stage
    is positive, else ``"ddp"``.
    """
    if not (torch.cuda.is_available() and torch.cuda.device_count() > 1):
        return "auto"

    # NOTE: the decision uses the stage from the training config, not the
    # (possibly forced-to-0) stage stored in ``deepspeed_config``.
    if config["zero_optimization"]["stage"] > 0:
        strategy = DeepSpeedStrategy(
            config=deepspeed_config,
        )
        # Report the value actually passed to DeepSpeed; the strategy object's
        # own ``initial_scale_power`` attribute only holds its constructor
        # default and does not reflect the config dict.
        print(
            "DeepSpeed fp16 initial_scale_power: "
            f"{deepspeed_config['fp16']['initial_scale_power']}"
        )
        print(f"Using DeepSpeed strategy with {torch.cuda.device_count()} GPUs")
        return strategy

    print(f"Using DDP strategy with {torch.cuda.device_count()} GPUs")
    return "ddp"


def train_pythia(args):
    """Main training function for Pythia model using MMapIndexedDataset.

    Loads the JSON config, builds the model and data module, configures
    checkpointing, logging and the distributed strategy, runs
    ``Trainer.fit`` and finally saves the HuggingFace weights to
    ``<output_dir>/final_model``.

    Args:
        args: Parsed command-line namespace. The attributes read are
            ``config``, ``seed``, ``output_dir``, ``run_name``,
            ``save_step0_init``, ``save_weights_only``, ``num_nodes`` and
            ``resume_from_checkpoint``.

    Raises:
        ValueError: If the config has no (or an empty) ``train-data-paths``.
        KeyError: If a required config key is missing.
    """
    print(f"Loading config from {args.config}")
    with open(args.config, "r") as f:
        config = json.load(f)

    torch.set_float32_matmul_precision("medium")

    deepspeed_config = _build_deepspeed_config(config)

    pl.seed_everything(args.seed)

    # Single source of truth for where this run's checkpoints live; shared by
    # the step-0 dump, ModelCheckpoint and the cleanup callback.
    checkpoint_dir = os.path.join(args.output_dir, f"pythia_output_{args.run_name}")

    print("Creating Pythia model...")
    model = PythiaModel.from_config(config=config)

    if args.save_step0_init:
        step0_dir = os.path.join(checkpoint_dir, "step0_init")
        os.makedirs(step0_dir, exist_ok=True)
        model.model.save_pretrained(step0_dir)
        model.tokenizer.save_pretrained(step0_dir)
        print(f"Saved random init HuggingFace weights to {step0_dir}")

    print("Setting up MMap data module...")
    if "train-data-paths" not in config or not config["train-data-paths"]:
        raise ValueError("Config must include 'train-data-paths'")

    if "use_single_file" not in config:
        config["use_single_file"] = True
        print("Using single file mode (using portion of training data for validation)")

    equiv_lookup = _load_equivalence_lookup(config)
    data_module = setup_pythia_data(config, equiv_lookup=equiv_lookup)

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename="{step}",
        save_top_k=-1,  # keep every periodic checkpoint
        every_n_train_steps=config["checkpoint-factor"],
        save_last=True,
    )

    lr_monitor = LearningRateMonitor(logging_interval="step")

    logger = TensorBoardLogger(
        save_dir=os.path.join(args.output_dir, "tensorboard_logs"),
        name="pythia",
        version=args.run_name,
    )

    strategy = _select_strategy(config, deepspeed_config)

    micro_batch = config["train_micro_batch_size_per_gpu"]
    gas = config["gas"]
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
    num_nodes = args.num_nodes
    total_batch_size = micro_batch * gas * num_gpus * num_nodes
    print(
        f"Effective batch size: {micro_batch} (micro) × {gas} (accumulation) × {num_gpus} (GPUs) × {num_nodes} (nodes) = {total_batch_size}"
    )

    cbs = [checkpoint_callback, lr_monitor, DetailedProgressBar()]
    if args.save_weights_only:
        cleanup_callback = CleanupCallback(args.run_name)
        # Point the cleanup at the directory ModelCheckpoint actually writes
        # to, so a non-default --output_dir is honoured.
        cleanup_callback.base_dir = checkpoint_dir
        cbs.append(cleanup_callback)
        print("Adding cleanup callback")

    trainer = pl.Trainer(
        max_steps=config["train-iters"],
        gradient_clip_val=config["gradient_clipping"],
        accumulate_grad_batches=gas,
        limit_val_batches=20,
        val_check_interval=gas * 50,
        precision="16-mixed" if config["fp16"]["fp16"] else "32",
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices="auto",
        num_nodes=args.num_nodes,
        strategy=strategy,
        callbacks=cbs,
        logger=logger,
        log_every_n_steps=config["log-interval"],
        num_sanity_val_steps=0,
    )

    # An unset/empty checkpoint argument means a fresh run; ``ckpt_path=None``
    # is what ``Trainer.fit`` uses by default.
    ckpt_path = args.resume_from_checkpoint or None
    if ckpt_path:
        print(f"Resuming from checkpoint: {ckpt_path}")
    else:
        print("Starting training from scratch")
    trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)

    # Every rank reaches this point in a distributed run; let only the global
    # rank-zero process write, so ranks do not race on the same output files.
    if trainer.is_global_zero:
        model_save_path = os.path.join(args.output_dir, "final_model")
        print(f"Saving final model to {model_save_path}")
        model.model.save_pretrained(model_save_path)

    print("Training completed successfully!")


if __name__ == "__main__":
    # Command-line options, registered in the order they appear in --help.
    cli = argparse.ArgumentParser(
        description="Train a Pythia model with MMapIndexedDataset"
    )
    option_specs = (
        (
            "--config",
            {
                "type": str,
                "default": "~pythia_replicate/pythia-160m.json",
                "help": "Path to config file",
            },
        ),
        (
            "--save_weights_only",
            {"action": "store_true", "help": "Save weights only"},
        ),
        (
            "--output_dir",
            {
                "type": str,
                "default": "./trained_models",
                "help": "Directory to save checkpoints and logs",
            },
        ),
        (
            "--run_name",
            {"type": str, "required": True, "help": "Name for this training run"},
        ),
        (
            "--resume_from_checkpoint",
            {"type": str, "help": "Path to checkpoint to resume from"},
        ),
        ("--seed", {"type": int, "default": 1234, "help": "Random seed"}),
        (
            "--num_nodes",
            {
                "type": int,
                "default": 1,
                "help": "Number of nodes for distributed training",
            },
        ),
        (
            "--save_step0_init",
            {"action": "store_true", "help": "Save step0 init"},
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    args = cli.parse_args()

    # Make sure the output root exists before training writes into it.
    os.makedirs(args.output_dir, exist_ok=True)

    train_pythia(args)
