import os
import re

import lightning as L
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

import wandb
from model import ItrAttention
from util import load_config

# Name of the config handed to `load_config` below; the CONFIG_FILENAME
# environment variable overrides the "arc-agi1" default.
config_name = os.getenv("CONFIG_FILENAME", default="arc-agi1")

# Fix the global seed (dataloader workers included) so runs are reproducible.
L.seed_everything(seed=42, workers=True)


def _dataset_dims(train_dataset_name):
    """Return ``(L_sqrt, num_vocab, num_classes)`` for a training dataset name.

    Raises:
        ValueError: if ``train_dataset_name`` is not a supported dataset.
    """
    sudoku_dims = (9, 10, 9)
    arc_dims = (30, 12, 12)
    dims_by_dataset = {
        "sudoku": sudoku_dims,
        "sudoku-hard": sudoku_dims,
        "sudoku-extreme": sudoku_dims,
        "maze": (30, 4, 5),
        "arc": arc_dims,
        "arc2": arc_dims,
    }
    if train_dataset_name not in dims_by_dataset:
        # Fail with a clear message instead of hitting unbound locals later.
        raise ValueError(
            f"Unsupported train_dataset_name {train_dataset_name!r}; "
            f"expected one of {sorted(dims_by_dataset)}"
        )
    return dims_by_dataset[train_dataset_name]


def _training_accelerator():
    """Pick the accelerator for training: CUDA, then MPS if usable, else CPU."""
    if torch.cuda.is_available():
        return "cuda"
    # Only request MPS when this torch build/host actually supports it;
    # asking for it unconditionally fails on CPU-only non-Apple machines.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


def _make_wandb_logger(name):
    """Build the Weights & Biases logger used by every Trainer in this script."""
    return WandbLogger(
        project="sudoku_deq_ebt",
        name=name,
        log_model=False,
        save_code=True,
        settings=wandb.Settings(code_dir="./src"),
    )


def _find_latest_checkpoint(ckpt_dir, ckpt_prefix):
    """Return the file name of the highest-versioned checkpoint in ``ckpt_dir``.

    Matches ``<prefix>.ckpt`` (treated as version 0) and ``<prefix>-v<N>.ckpt``.

    Raises:
        FileNotFoundError: if the directory is missing (from ``os.listdir``)
            or contains no matching checkpoint file.
    """
    pattern = re.compile(rf"{ckpt_prefix}(?:-v(\d+))?\.ckpt$")
    versions = {}
    for file_name in os.listdir(ckpt_dir):
        found = pattern.match(file_name)
        if found:
            versions[file_name] = int(found.group(1) or 0)
    if not versions:
        raise FileNotFoundError(
            f"No '{ckpt_prefix}' checkpoint (*.ckpt) found in {ckpt_dir!r}"
        )
    return max(versions, key=versions.get)


def _run_test_sweep(
    ckpt_path,
    name,
    *,
    test_batch_size,
    min_votes_exponent,
    max_votes_exponent,
    multi_votes_exp_exponent,
    num_rep_attn,
    num_devices,
    num_nodes,
    strategy,
    precision,
):
    """Evaluate one checkpoint over a grid of test-time iterations and votes.

    Iteration counts are ``2**i`` for ``i`` in ``0..6``. Vote counts are
    ``2**j`` for ``j`` in ``min_votes_exponent..max_votes_exponent``, but vote
    exponents above 0 are only run when ``i == multi_votes_exp_exponent``.
    Each grid point reloads the model and runs ``Trainer.test`` on it.
    """
    for iter_exp in range(7):
        for votes_exp in range(min_votes_exponent, max_votes_exponent + 1):
            if iter_exp != multi_votes_exp_exponent and votes_exp > 0:
                continue
            # Batch size halves as votes double, keeping batch_size * num_votes
            # constant. Clamp to 1: the exponent reaches -1 at the largest vote
            # count, which would otherwise round a small batch size down to 0.
            batch_size = max(
                1, int(test_batch_size * 2 ** (max_votes_exponent - votes_exp - 1))
            )
            model = ItrAttention.load_from_checkpoint(
                checkpoint_path=ckpt_path,
                batch_size=batch_size,
                num_iter_test=int(2**iter_exp),
                num_votes_test=int(2**votes_exp),
                max_votes_exponent=max_votes_exponent,
                num_rep_attn=num_rep_attn,
            )
            trainer = L.Trainer(
                accelerator="cuda" if torch.cuda.is_available() else "cpu",
                devices=num_devices,
                num_nodes=num_nodes,
                strategy=strategy,
                logger=_make_wandb_logger(name),
                enable_progress_bar=True,
                precision=precision,
                # NOTE(review): the post-training sweep always disabled
                # inference mode while the test-only sweep did not; both now
                # share this setting so a checkpoint is evaluated the same way
                # in either mode. Presumably the model needs autograd at test
                # time -- confirm against ItrAttention.
                inference_mode=False,
            )
            trainer.test(model)


@load_config(config_path="./config", config_name=config_name)
def main(config):
    """Train an ``ItrAttention`` model (unless ``test_only``) and evaluate it.

    When ``config.test_only`` is false, the model is fitted (resuming from the
    last checkpoint if ``config.resume`` is set) and then the best or last
    checkpoint, chosen by ``config.use_best``, is evaluated. When it is true,
    training is skipped and the newest matching checkpoint under
    ``checkpoints/<run name>/`` is evaluated instead.

    Args:
        config: Run configuration object exposing settings as attributes.
            Presumably injected by the ``load_config`` decorator, since the
            script calls ``main()`` without arguments -- confirm in ``util``.

    Raises:
        ValueError: if ``config.train_dataset_name`` is not supported.
        FileNotFoundError: in test-only mode, if no checkpoint can be found.
        RuntimeError: after training, if the selected checkpoint path is empty.
    """
    # Required settings: a missing attribute raises here.
    num_heads = config.num_heads
    embed_dim = config.embed_dim
    betas = config.betas
    weight_decay = config.weight_decay
    lr = config.lr
    batch_size = config.batch_size
    num_workers = config.num_workers
    accumulate_grad_batches = config.accumulate_grad_batches
    gradient_clip_val = config.gradient_clip_val
    precision = config.precision
    num_nodes = config.num_nodes
    max_epochs = config.max_epochs
    num_iter = config.num_iter
    beta = config.beta
    test_only = config.test_only
    train_dataset_name = config.train_dataset_name
    test_dataset_name = config.test_dataset_name
    check_val_every_n_epoch = config.check_val_every_n_epoch
    resume = config.resume
    use_compile = config.use_compile
    use_best = config.use_best
    test_batch_size = config.test_batch_size

    # Optional settings: fall back to the default when the attribute is absent.
    update_after_step = getattr(config, "update_after_step", 100)
    update_every = getattr(config, "update_every", 10)
    num_remain_grad = getattr(config, "num_remain_grad", 8)
    multi_votes_exp_exponent = getattr(config, "multi_votes_exp_exponent", 6)
    min_votes_exponent = getattr(config, "min_votes_exponent", 0)
    max_votes_exponent = getattr(config, "max_votes_exponent", 10)
    log_every_n_steps = getattr(config, "log_every_n_steps", 50)
    ffn_dim_multiplier = getattr(config, "ffn_dim_multiplier", 4)
    use_cross_attn = getattr(config, "use_cross_attn", True)
    num_rep_attn = getattr(config, "num_rep_attn", 4)
    num_layers = getattr(config, "num_layers", 1)
    use_mpc = getattr(config, "use_mpc", False)
    mpc_every = getattr(config, "mpc_every", 4)
    use_transformer = getattr(config, "use_transformer", False)
    no_truncation = getattr(config, "no_truncation", False)
    confidence_type = getattr(config, "confidence_type", "max_prob")

    L_sqrt, num_vocab, num_classes = _dataset_dims(train_dataset_name)
    print(L_sqrt, num_vocab, num_classes)

    # The run name doubles as the checkpoint sub-directory, so test-only runs
    # must use the same settings here to locate checkpoints from training.
    name = "_".join(
        [
            f"tag_{config.tag}",
            f"head_{num_heads}",
            f"dim_{embed_dim}",
            f"bs_{batch_size}",
            f"weight_decay_{weight_decay}",
            f"lr_{lr}",
            f"use_compile_{use_compile}",
            f"train_{train_dataset_name}",
            f"test_{test_dataset_name}",
            f"num_layers_{num_layers}",
            f"num_rep_attn_{num_rep_attn}",
            f"use_mpc_{use_mpc}",
            f"use_transformer_{use_transformer}",
            f"confidence_type_{confidence_type}",
        ]
    )

    print(config)

    model = ItrAttention(
        num_heads=num_heads,
        embed_dim=embed_dim,
        betas=betas,
        weight_decay=weight_decay,
        lr=lr,
        # The model gets the per-step batch size; the Trainer accumulates
        # gradients over `accumulate_grad_batches` such steps.
        batch_size=batch_size // accumulate_grad_batches,
        num_workers=num_workers,
        num_iter=num_iter,
        beta=beta,
        use_bias=False,
        train_dataset_name=train_dataset_name,
        test_dataset_name=test_dataset_name,
        L_sqrt=L_sqrt,
        num_vocab=num_vocab,
        num_classes=num_classes,
        use_compile=use_compile,
        num_remain_grad=num_remain_grad,
        update_after_step=update_after_step,
        update_every=update_every,
        ffn_dim_multiplier=ffn_dim_multiplier,
        use_cross_attn=use_cross_attn,
        num_rep_attn=num_rep_attn,
        num_layers=num_layers,
        use_mpc=use_mpc,
        mpc_every=mpc_every,
        use_transformer=use_transformer,
        no_truncation=no_truncation,
        confidence_type=confidence_type,
    )

    cuda_available = torch.cuda.is_available()
    num_devices = torch.cuda.device_count() if cuda_available else 1
    print(f"num_devices: {num_devices}")

    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=f"checkpoints/{name}/",
        filename="best",
        save_top_k=1,
        save_last=True,
    )

    if not cuda_available:
        strategy = "auto"
    elif use_compile:
        strategy = "ddp"
    else:
        strategy = "ddp_find_unused_parameters_true"

    if test_only:
        ckpt_dir = os.path.join("checkpoints", name)
        latest = _find_latest_checkpoint(ckpt_dir, "best" if use_best else "last")
        print(f"checkpoint path: {latest}")
        ckpt_path = os.path.join(ckpt_dir, latest)
    else:
        trainer = L.Trainer(
            max_epochs=max_epochs,
            accelerator=_training_accelerator(),
            devices=num_devices,
            num_nodes=num_nodes,
            strategy=strategy,
            logger=_make_wandb_logger(name),
            callbacks=[checkpoint_callback],
            enable_progress_bar=True,
            check_val_every_n_epoch=check_val_every_n_epoch,
            precision=precision,
            accumulate_grad_batches=accumulate_grad_batches,
            gradient_clip_val=gradient_clip_val,
            log_every_n_steps=log_every_n_steps,
        )
        if resume:
            trainer.fit(model, ckpt_path="last")
        else:
            trainer.fit(model)

        print(f"{checkpoint_callback.best_model_path=}")
        print(f"{checkpoint_callback.last_model_path=}")

        ckpt_path = (
            checkpoint_callback.best_model_path
            if use_best
            else checkpoint_callback.last_model_path
        )
        if not ckpt_path:
            # The callback leaves the path empty when it never saved a
            # checkpoint; fail here rather than inside load_from_checkpoint.
            raise RuntimeError(
                f"No {'best' if use_best else 'last'} checkpoint was saved "
                "during training; cannot run the test sweep."
            )

    _run_test_sweep(
        ckpt_path,
        name,
        test_batch_size=test_batch_size,
        min_votes_exponent=min_votes_exponent,
        max_votes_exponent=max_votes_exponent,
        multi_votes_exp_exponent=multi_votes_exp_exponent,
        num_rep_attn=num_rep_attn,
        num_devices=num_devices,
        num_nodes=num_nodes,
        strategy=strategy,
        precision=precision,
    )


if __name__ == "__main__":
    # Script entry point. `main` is called without arguments, so its `config`
    # parameter is presumably supplied by the `load_config` decorator --
    # confirm in `util`.
    main()
