import argparse
import copy
import inspect
import os

import torch
import wandb
import yaml

from models import get_model
from trainer import *
from utils import *

def main(
    config: Config
):
    """Train a model as described by ``config`` and save its best weights.

    Pipeline: start a wandb run, seed RNGs, build the train/validation
    dataloaders, instantiate the model named in ``config.model``, create an
    AdamW optimizer and an LR scheduler, run ``config.epochs`` epochs of
    training followed by validation, restore the weights that achieved the
    lowest validation loss, and write them to ``<model_path>/model.pt``.

    Args:
        config: Run configuration. Fields read here: ``project``, ``tags``,
            ``seed``, ``device``, ``epochs``, ``data.training_size``,
            ``data.batch_size``, ``model.name``, ``model.params``,
            ``optimizer.params.learning_rate``,
            ``optimizer.params.weight_decay``, ``scheduler.name`` and, for
            the cosine scheduler, ``scheduler.params.learning_rate_min``.

    Raises:
        ValueError: If ``config.scheduler.name`` is neither ``"constant"``
            nor ``"cosine"``.
    """
    wandb.login()
    wandb.init(
        project=config.project,
        config=config.__dict__,
        tags=config.tags,
        # NOTE(review): an empty entity looks like a placeholder — set it to
        # the intended wandb user/team before running.
        entity="",
    )

    # Set random seeds for reproducibility
    set_seed(config.seed)

    # Set up device (GPU/CPU); config.device is presumably a CUDA index.
    device = f"cuda:{config.device}" if torch.cuda.is_available() else "cpu"
    print(f"Using {device}.")

    ########
    # Data #
    ########
    train_dataloader, valid_dataloader = build_dataloader(config.data.training_size, config.data.batch_size)

    #########
    # Model #
    #########

    # Initialize model based on configuration
    model_class = get_model(config.model.name)
    model = model_class(
        **config.model.params.to_dict()
    ).to(device)

    # Enable model parameter tracking in wandb
    wandb.watch(model)

    # Count and log parameters and trainable parameters
    params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    wandb.log(
        {
            "model/parameters": params,
            "model/trainable_parameters": trainable_params
        }
    )

    #########
    # Optim #
    #########

    # Mean Squared Error for both training and validation
    loss_fn = torch.nn.MSELoss()
    loss_fn_valid = torch.nn.MSELoss()

    # AdamW optimizer with configured learning rate and weight decay
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.optimizer.params.learning_rate,
        weight_decay=config.optimizer.params.weight_decay
    )

    # Set up learning rate scheduler based on configuration
    scheduler_name = config.scheduler.name
    if scheduler_name == "constant":
        # factor=1.0 keeps the initial learning rate unchanged throughout
        scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer,
            factor=1.0
        )
    elif scheduler_name == "cosine":
        # One cosine cycle spanning all of training. T_max is counted in
        # batches, which assumes train() calls scheduler.step() once per
        # batch — TODO confirm against trainer.train.
        iters = len(train_dataloader) * config.epochs
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=iters,
            eta_min=config.scheduler.params.learning_rate_min
        )
    else:
        # Fail loudly here instead of an UnboundLocalError inside the loop.
        raise ValueError(
            f"Unknown scheduler {scheduler_name!r}; expected 'constant' or 'cosine'."
        )

    ##############
    # Train loop #
    ##############
    print("Training...")
    best_valid_loss = float('inf')
    best_model_state = None

    for epoch in range(config.epochs):
        print(f"Epoch: {epoch + 1}/{config.epochs}")

        train(
            dataloader=train_dataloader,
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device
        )

        valid_loss = valid(
            dataloader=valid_dataloader,
            model=model,
            loss_fn=loss_fn_valid,
            device=device
        )

        # Snapshot the weights whenever validation loss improves.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            # state_dict() holds references to the live parameter tensors,
            # which the optimizer keeps updating in place; a shallow copy
            # would end up holding the final weights. Deep-copy instead.
            best_model_state = copy.deepcopy(model.state_dict())
            print(f"New best model with validation loss: {best_valid_loss:.6f}")
            # wandb.log({"valid/best_loss": best_valid_loss})

    # Restore best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Restored best model with validation loss: {best_valid_loss:.6f}")

    print("Training complete.")

    ##############
    # Save model #
    ##############

    # TODO: point this at a real output directory; '' means the current
    # working directory.
    model_path = ''
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when one was actually given.
    if model_path:
        os.makedirs(model_path, exist_ok=True)
    save_file = os.path.join(model_path, "model.pt")

    torch.save(model.state_dict(), save_file)

    print(f"Model saved as {save_file}.")

    # Clean up wandb
    wandb.finish()

if __name__ == "__main__":
    # Command-line interface: a YAML config file plus optional overrides.
    parser = argparse.ArgumentParser()
    # Required: path to the YAML configuration file.
    parser.add_argument("--config", type=str, required=True,
                        help="Path to config file.")

    # Repeatable dotted-key overrides. Each value is parsed with
    # yaml.safe_load, so "3", "true" or "[1, 2]" become int/bool/list.
    parser.add_argument("--set", metavar="KEY=VAL", action="append",
                        help="Override any config entry, e.g. --set model.params.activation=relu")

    args = parser.parse_args()

    # Load configuration from YAML file
    config = Config.from_yaml(args.config)

    for item in args.set or []:
        key, sep, raw = item.partition("=")
        key = key.strip()
        # Report malformed overrides as a usage error instead of an opaque
        # tuple-unpacking ValueError.
        if not sep or not key:
            parser.error(f"--set expects KEY=VAL, got {item!r}")
        config.set(key, yaml.safe_load(raw))

    # Start the training process with the configured settings
    main(config)