import torch
import sys
import os

# Put the parent directory of this file's directory on sys.path (only if it is
# not already there) before the project-local imports below run.
# NOTE(review): presumably that directory is the repo root containing
# `exp_utils`, `experiments` and `data` — confirm against the repo layout.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)
from exp_utils.setup_exp import set_exp
import hydra
from omegaconf import DictConfig, OmegaConf
from exp_utils.utils import get_opt_scheduler, instantiate_filtered
from experiments.ar_ic.train import train_model
from exp_utils.utils import CustomNLLLoss, save_checkpoint, load_checkpoint
from data.ass_recall.ar_dataloader import InContextAssociationDataloader


def print_model_parameters(module):
    """Print one line per parameter of ``module``: label, size and grad flag.

    If ``module`` has no ``parameters`` attribute, a notice is printed and
    nothing else happens. If it exposes ``named_parameters``, those names are
    used as labels; otherwise parameters are labelled ``Unnamed_<index>`` in
    the order ``parameters()`` yields them.
    """
    if not hasattr(module, "parameters"):
        print("The provided object does not have parameters (it might not be a model).")
        return

    # Pick the source of (label, parameter) pairs once, then print uniformly.
    if hasattr(module, "named_parameters"):
        labelled = module.named_parameters()
    else:
        labelled = (
            (f"Unnamed_{position}", tensor)
            for position, tensor in enumerate(module.parameters())
        )

    for label, tensor in labelled:
        print(f"Layer: {label} | Size: {tensor.size()} | Requires Grad: {tensor.requires_grad}")


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(config: DictConfig) -> None:
    """Train a model on the in-context associative recall task.

    Sets up the experiment via ``set_exp``, builds the model, optimizer and
    scheduler, resumes from the checkpoint path returned by ``set_exp`` when
    that file exists, runs ``train_model`` for the iterations still
    outstanding, saves a final checkpoint when enabled, and writes the last
    training loss to the W&B run summary.

    Args:
        config: Hydra-composed configuration. This function reads the ``run``,
            ``model``, ``dataset`` and ``training`` groups and overwrites
            ``config.dataset.block_size``.

    Raises:
        ValueError: If ``config.training.loss`` is neither ``"ce_softmax"``
            nor ``"nll"``.
    """
    # Turn off struct mode so config keys can be added/overwritten below.
    OmegaConf.set_struct(config, False)

    # 1. Set up experiment in W&B (including possible "rewind" if run already exists)
    wandb_writer, last_ckpt_path = set_exp(config)

    # For associative recall task, block_size is always 2 * n_pairs + 2
    config.dataset.block_size = 2 * config.dataset.n_pairs + 2

    device = torch.device(config.run.device)

    # 2. Create the model
    model = instantiate_filtered(config.model, config.dataset).to(device)
    print_model_parameters(model)

    # Optional PyTorch 2.0 compile
    # model = torch.compile(model)

    # 3. Create the optimizer and scheduler
    optimizer, scheduler = get_opt_scheduler(
        config.training.lr,
        config.training.wd,
        config.training.iters,
        config.training.scheduler,
        config.training.opt,
        model,
    )

    # Iteration to start from; stays 0 unless a checkpoint is loaded.
    iteration = 0

    # 4. If there's a checkpoint, load it. load_checkpoint receives the model
    # and optimizer and returns the iteration to resume from.
    # NOTE(review): the scheduler is not passed in, so its state is presumably
    # not restored here; whether RNG state is restored is also not visible
    # from this file — confirm in exp_utils.utils.
    if last_ckpt_path and os.path.exists(last_ckpt_path):
        iteration = load_checkpoint(last_ckpt_path, model, optimizer)
        print(f"Resuming from iteration {iteration}")
    else:
        print("No checkpoint found. Starting from iteration 0 (scratch).")

    # 5. Create the data loader only after any checkpoint has been loaded.
    # config.training.iters is the total budget across runs, so a resumed run
    # only performs the iterations that are still outstanding.
    remaining_iters = config.training.iters - iteration
    if remaining_iters <= 0:
        print("All iterations completed in previous runs!")
        # Close the run opened by set_exp on this early exit as well, matching
        # what the normal end-of-training path does.
        if wandb_writer:
            wandb_writer.finish()
        return

    # Create the Associative Recall dataloader
    print("\n--- Creating Associative Recall Dataloader ---")

    train_loader = InContextAssociationDataloader(
        vocab_size=config.dataset.vocab_size,
        n_pairs=config.dataset.n_pairs,
        device=device,
        batch_size=config.training.batch_size,
        block_size=config.dataset.block_size,
        iters=remaining_iters,
        train_data_size=config.training.train_data_size,
        test_data_size=config.training.test_data_size,
        train_dist_args=config.dataset.train_dist_args,
        test_dist_args=config.dataset.test_dist_args,
    )

    # Loss function
    if config.training.loss == "ce_softmax":
        loss_fn = torch.nn.CrossEntropyLoss()  # This uses softmax in output
    elif config.training.loss == "nll":
        loss_fn = CustomNLLLoss()  # This doesn't have softmax in output
    else:
        raise ValueError(f"Loss function {config.training.loss} not recognized.")

    loss_histories = train_model(
        model,
        train_loader,
        optimizer,
        loss_fn,
        device,
        wandb_writer,
        iteration,
        remaining_iters,
        eval_interval=config.training.eval_interval,
        scheduler=scheduler,
        out_dir=config.run.out_dir,
        save_interval=config.training.save_interval,
        plot_interval=config.training.plot_interval,
        save_checkpoints=config.training.save_checkpoint,
    )

    if config.training.save_checkpoint:
        final_ckpt_path = f"{config.run.out_dir}/models/model_{config.training.iters}.pth"
        # Ensure the target directory exists; whether save_checkpoint (or
        # train_model) creates it is not visible from this file.
        os.makedirs(os.path.dirname(final_ckpt_path), exist_ok=True)
        save_checkpoint(
            final_ckpt_path,
            model,
            optimizer,
            config.training.iters,
        )
        print(f"Training finished. Final model saved at iteration {config.training.iters}.")
    else:
        print("Training finished. Checkpoint saving disabled.")

    # Log final metrics to WandB
    if wandb_writer:
        print("Processing loss histories and logging final summary metrics to WandB...")

        train_loss_history = loss_histories.get("train_loss_history", [])
        final_train_loss = train_loss_history[-1] if train_loss_history else None

        summary_metrics = {}
        if final_train_loss is not None:
            summary_metrics["final_train_loss"] = final_train_loss

        # Note: Evaluation metrics (test loss, accuracy) are logged directly
        # by evaluate_model within the training loop at each eval_interval.
        # If a final summary of eval metrics is needed here, evaluate_model
        # would need to be modified to return them, similar to train_loss_history.

        if summary_metrics:
            wandb_writer.summary.update(summary_metrics)

        wandb_writer.finish()


# Script entry point: main() is called without arguments because the
# @hydra.main decorator supplies the `config` argument.
if __name__ == "__main__":
    main()
