import torch
import sys
import os
import torch.nn.functional as F  # Import MSE Loss
import wandb
import numpy as np

# Treat the directory containing this script as the project root and make sure
# it is on sys.path (appended only if absent) before the project-local imports.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "."))
if sys.path.count(project_root) == 0:
    sys.path.append(project_root)

from exp_utils.setup_exp import set_exp
import hydra
from omegaconf import DictConfig, OmegaConf
from exp_utils.utils import (
    get_opt_scheduler,
    instantiate_filtered,
    save_checkpoint,
    save_tensor,
    load_checkpoint,
)

# Adjusted imports for the new task
from experiments.single_location_linear_regression.train import train_model
from data.single_location_linear_regression.sl_lr_dataloader import (
    SingleLocationLinearRegressionDataloader,
    SingleLocationLinearRegressionTask,
)


def print_model_parameters(module):
    """Print one line per parameter of ``module``: label, size and requires_grad.

    If ``module`` has no ``parameters`` attribute, a notice is printed and the
    function returns. When ``named_parameters`` is available its names are used
    as labels; otherwise parameters are labelled ``Unnamed_<index>`` in the
    order yielded by ``parameters()``.
    """
    if not hasattr(module, "parameters"):
        print("The provided object does not have parameters (it might not be a model).")
        return

    # Build a single stream of (label, parameter) pairs so one loop can print both cases.
    if hasattr(module, "named_parameters"):
        labelled = module.named_parameters()
    else:
        labelled = ((f"Unnamed_{pos}", p) for pos, p in enumerate(module.parameters()))

    for label, p in labelled:
        print(f"Layer: {label} | Size: {p.size()} | Requires Grad: {p.requires_grad}")


def _summarize_eval_losses(id_losses, eval_interval, threshold=0.1, explosion_factor=1.2):
    """Derive the plateau length and the explosion flag from the ID eval losses.

    Args:
        id_losses: Sequence of in-distribution eval losses, in evaluation order.
        eval_interval: Number of training iterations between two evaluations.
            The interpolation assumes entry ``i`` was measured at iteration
            ``i * eval_interval`` — TODO confirm against ``train_model``.
        threshold: Loss value that marks the end of the plateau.
        explosion_factor: An eval loss larger than ``explosion_factor`` times
            the preceding eval loss marks the run as exploded.

    Returns:
        A ``(plateau_length_iter, run_exploded)`` tuple. ``plateau_length_iter``
        is ``float("inf")`` if the loss never crosses ``threshold`` from above.
    """
    plateau_length_iter = float("inf")  # Default if plateau not reached
    run_exploded = False

    for i in range(1, len(id_losses)):
        prev_loss = id_losses[i - 1]
        current_loss = id_losses[i]

        # Plateau end: first time the loss crosses the threshold from above.
        # prev_loss > threshold >= current_loss guarantees a non-zero denominator.
        if (
            plateau_length_iter == float("inf")
            and current_loss <= threshold
            and prev_loss > threshold
        ):
            # Linear interpolation between the two evaluations around the crossing
            plateau_length_iter = int(
                eval_interval * (i - 1 + (prev_loss - threshold) / (prev_loss - current_loss))
            )

        # Explosion: checked on every consecutive pair, independently of the
        # plateau detection above.
        if current_loss > explosion_factor * prev_loss:
            run_exploded = True

    return plateau_length_iter, run_exploded


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(config: DictConfig) -> None:
    """Train a model on the single-location linear regression task.

    Steps: set up the experiment, build the model / optimizer / scheduler,
    resume from the last checkpoint if one exists, build the task and its
    dataloaders (plus an optional OOD evaluation loader), run ``train_model``,
    optionally save the final checkpoint and task tensors, and finally log
    summary metrics to the W&B writer when one is available.

    Args:
        config: Hydra configuration; the ``run``, ``dataset``, ``model`` and
            ``training`` groups are read.
    """
    # Hydra / OmegaConf housekeeping
    OmegaConf.set_struct(config, False)

    # 1. Set up experiment in W&B
    wandb_writer, last_ckpt_path = set_exp(config)

    device = torch.device(config.run.device)

    # Determine input dimension based on the flag
    show_relevant_token = config.dataset.get("show_relevant_token", False)  # Use .get for safety
    input_dim = config.dataset.dimension
    if show_relevant_token:
        input_dim += 1
        print(f"show_relevant_token is True. Input dimension adjusted to {input_dim}")

    # 2. Create the model
    model_config_overrides = {
        "vocab_size": input_dim,  # Size of the input
        "output_size": config.dataset.dimension,  # Size of the output (ignore the eventual +1)
        "embedding_type": "linear",  # Linear embedding for regression (if used)
        "pos_enc": "sin_cos",  # we consider long sequences so we need to use cos_sin
        # "enable_norm": False,  # no norm for this task (hard to regress small / big values otherwise)
        "block_size": config.dataset.sequence_length + 1,  # +1 for the last prediction
    }
    if config.model.get("skip_embedding", False) or config.model.get(
        "force_embedding_dim_to_input_dim", False
    ):
        assert (
            input_dim % config.model.n_heads == 0
        ), f"Input dimension {input_dim} must be divisible by number of heads {config.model.n_heads}"  # to make sure that we can split the embedding dimension in n_heads
        model_config_overrides["embedding_dim"] = input_dim
        print(f"Forcing embedding_dim to be equal to input_dim: {input_dim}")

    model = instantiate_filtered(
        config.model, config.dataset, overwrite_cfgs=model_config_overrides
    ).to(device)
    print("Model Configuration:", config.model)
    print("Dataset Configuration:", config.dataset)
    print_model_parameters(model)

    # Optional PyTorch 2.0 compile
    # model = torch.compile(model)

    # 3. Create the optimizer and scheduler
    optimizer, scheduler = get_opt_scheduler(
        config.training.lr,
        config.training.wd,
        config.training.iters,
        config.training.scheduler,
        config.training.opt,
        model,
    )

    # Track iteration number
    iteration = 0

    # 4. Load checkpoint if exists
    if last_ckpt_path and os.path.exists(last_ckpt_path):
        iteration = load_checkpoint(
            last_ckpt_path, model, optimizer, scheduler
        )  # Pass scheduler too
        print(f"Resuming from iteration {iteration}")
    else:
        print("No checkpoint found. Starting from iteration 0 (scratch).")

    # 5. Calculate remaining iterations
    remaining_iters = config.training.iters - iteration
    if remaining_iters <= 0:
        print("All iterations completed in previous runs!")
        if wandb_writer:
            wandb_writer.finish()
        return

    # 6. Create the Task Generator
    random_relevant_token_positions = config.dataset.get("random_relevant_token_positions", False)
    task_generator = SingleLocationLinearRegressionTask(
        dimension=config.dataset.dimension,
        sequence_length=config.dataset.sequence_length,
        p_repeat=config.dataset.p_repeat,
        burstiness=config.dataset.burstiness,
        device=device,
        show_relevant_token=show_relevant_token,
        random_relevant_token_positions=random_relevant_token_positions,
        size_pool_repetition=config.dataset.get("size_pool_repetition", 1),
    )

    # 7. Create the primary dataloader using the shared task
    train_loader = SingleLocationLinearRegressionDataloader(
        task=task_generator,  # Pass the task object
        device=device,
        batch_size=config.training.batch_size,
        iters=remaining_iters,  # Dataloader should yield for remaining iters
        train_data_size=config.dataset.train_data_size,
        test_data_size=config.dataset.test_data_size,
    )

    # 8. Create the secondary eval dataloader if needed, reusing the training W
    eval_loader_generalization = None
    # .get keeps a missing key equivalent to "not set", like the other optional keys.
    burstiness_eval = config.dataset.get("burstiness_eval")
    p_repeat_eval = config.dataset.get("p_repeat_eval")
    if burstiness_eval is not None or p_repeat_eval is not None:
        # Any eval parameter left unset falls back to its training value.
        if burstiness_eval is None:
            burstiness_eval = config.dataset.burstiness
        if p_repeat_eval is None:
            p_repeat_eval = task_generator.p_repeat

        # Create a separate task instance for OOD evaluation parameters
        # NOTE(review): size_pool_repetition is not forwarded here, unlike for the
        # training task — confirm this is intentional.
        ood_task_generator = SingleLocationLinearRegressionTask(
            dimension=config.dataset.dimension,
            sequence_length=config.dataset.sequence_length,
            p_repeat=p_repeat_eval,
            burstiness=burstiness_eval,
            device=device,
            show_relevant_token=show_relevant_token,
            random_relevant_token_positions=random_relevant_token_positions,
        )
        # Crucially, overwrite the W matrix to be the same as the training task's W
        ood_task_generator.W = task_generator.W

        eval_loader_generalization = SingleLocationLinearRegressionDataloader(
            task=ood_task_generator,  # Pass the OOD task object
            device=device,
            batch_size=config.training.batch_size,
            iters=1,  # Only need one iteration for evaluation data
            train_data_size=config.training.batch_size,  # Minimal train data size
            test_data_size=config.dataset.test_data_size,  # Same test size as primary
        )

    # 9. Start training
    print(f"Starting training for {remaining_iters} iterations...")
    # Capture the loss histories returned by train_model
    loss_histories = train_model(
        model=model,
        train_loader=train_loader,
        eval_loader_generalization=eval_loader_generalization,
        optimizer=optimizer,
        device=device,
        wandb_writer=wandb_writer,
        iteration=iteration,
        remaining_iters=remaining_iters,
        eval_interval=config.training.eval_interval,
        scheduler=scheduler,
        out_dir=config.run.out_dir,
        save_interval=config.training.save_interval,
        plot_interval=config.training.plot_interval,
        config=config,
        save_checkpoint=config.training.save_checkpoint,
    )

    # 10. Final save
    final_iter = iteration + remaining_iters
    if config.training.save_checkpoint:
        save_checkpoint(
            f"{config.run.out_dir}/models/model_{final_iter}.pth",
            model,
            optimizer,
            final_iter,
        )
        # Also save the W matrix from the task generator
        save_tensor(
            f"{config.run.out_dir}/models/task.pth",
            {"W": task_generator.W, "burst_indices": task_generator.burst_indices},
        )
        print(f"Training finished. Final model saved at iteration {final_iter}.")
    else:
        print("Training finished. Checkpoint saving disabled.")

    # 11. Process loss histories and log final metrics to WandB
    if wandb_writer:
        print("Processing loss histories and logging final summary metrics to WandB...")

        train_loss_history = loss_histories["train_loss_history"]
        eval_loss_history = loss_histories["eval_loss_history"]

        # A missing key (e.g. no OOD loader was created) is treated as "no data"
        # instead of failing on len(None).
        id_losses = eval_loss_history.get("ID")
        if id_losses is None:
            id_losses = []
        ood_losses = eval_loss_history.get("OOD")
        if ood_losses is None:
            ood_losses = []

        final_train_loss = train_loss_history[-1] if train_loss_history else None
        final_eval_loss_ID = id_losses[-1] if len(id_losses) > 0 else None
        final_eval_loss_OOD = ood_losses[-1] if len(ood_losses) > 0 else None

        threshold = 0.1  # Arbitrary, 0 prediction gives loss of 1
        plateau_length_iter, run_exploded = _summarize_eval_losses(
            id_losses, config.training.eval_interval, threshold=threshold
        )

        # Log summary metrics
        wandb_writer.summary.update(
            {
                "final_train_loss": final_train_loss,
                "final_eval_loss_ID": final_eval_loss_ID,
                "final_eval_loss_OOD": final_eval_loss_OOD,
                "plateau_length": plateau_length_iter,
                "run_exploded": run_exploded,
            }
        )

        wandb_writer.finish()


# Script entry point: main() is invoked without arguments; its config parameter
# is supplied through the @hydra.main decorator applied to it.
if __name__ == "__main__":
    main()
