import argparse
import inspect
import os
import yaml

import torch
import wandb

from models import get_model
from trainer import *
from utils import *

def main(
    config: Config
):
    """Train, evaluate and save the model described by ``config``.

    Steps, in order:
      1. Start a Weights & Biases run and seed the RNGs.
      2. Load the train/valid data from under ``<this dir>/data/parquet/``,
         optionally fit a ``MaxScaler`` on the training targets (and apply it
         to the validation targets), and build dataloaders on the device.
      3. Build the model, an AdamW optimizer and the LR scheduler.
      4. Train for ``config.epochs`` epochs, snapshotting the weights with the
         lowest validation loss and restoring that snapshot at the end.
      5. Roll the model out from the validation initial conditions and log
         trajectory metrics; for a ``RecurrentModel`` also sweep ``k``.
      6. Save the weights (and the scaler, if used) under
         ``<this dir>/models/.pt/<wandb run name>/``.

    Args:
        config: Experiment configuration (wandb project/tags, seed, device,
            data, model, optimizer, scheduler and epoch settings).

    Raises:
        ValueError: If ``config.scheduler.name`` is neither ``"constant"``
            nor ``"cosine"``.
    """
    wandb.login()
    wandb.init(
        project=config.project,
        config=config.to_dict(),
        tags=config.tags
    )

    # Set random seeds for reproducibility
    set_seed(config.seed)

    # Set up device (GPU/CPU)
    device = f"cuda:{config.device}" if torch.cuda.is_available() else "cpu"
    print(f"Using {device}.")

    # Set up data paths
    data_path = os.path.dirname(os.path.abspath(__file__)) + "/data/parquet/"

    #########
    # Train #
    #########

    # Load and process training data
    train_path = data_path + config.data.train_path
    train_data = get_data(train_path)
    train_data_input, train_data_output = process_data(
        data=train_data,
        data_dt=config.data.data_dt,
        pred_dt=config.data.pred_dt,
        residual=config.data.residual
    )

    # Initialize and fit data scaler if enabled (fitted on training targets only)
    scaler = None
    if config.data.scale:
        scaler = MaxScaler()
        scaler.fit(train_data_output)

        # Apply scaling to training data
        train_data_output = scaler.transform(train_data_output)

        # Log scaling factor to wandb
        wandb.log({
            "scaler/scale": scaler.scale
        })

    # Move training data to device and create dataloader
    train_data_input = train_data_input.to(device)
    train_data_output = train_data_output.to(device)
    train_dataloader = torch.utils.data.DataLoader(
        Dataset(
            input=train_data_input,
            output=train_data_output
        ),
        batch_size=config.data.batch_size,
        shuffle=True
    )

    #########
    # Valid #
    #########

    # Load and process validation data
    valid_path = data_path + config.data.valid_path
    valid_data = get_data(valid_path)
    valid_data_input, valid_data_output = process_data(
        data=valid_data,
        data_dt=config.data.data_dt,
        pred_dt=config.data.pred_dt,
        residual=config.data.residual
    )

    # Scale validation data with the scaler fitted on the training data
    if config.data.scale:
        valid_data_output = scaler.transform(valid_data_output)

    # Move validation data to device and create dataloader
    valid_data_input = valid_data_input.to(device)
    valid_data_output = valid_data_output.to(device)
    valid_dataloader = torch.utils.data.DataLoader(
        Dataset(
            input=valid_data_input,
            output=valid_data_output
        ),
        batch_size=config.data.batch_size,
        shuffle=False
    )

    # Prepare trajectory for evaluation
    valid_traj = process_traj(
        traj=valid_data,
        data_dt=config.data.data_dt,
        pred_dt=config.data.pred_dt
    )
    valid_traj = valid_traj.to(device)

    #########
    # Model #
    #########

    # Initialize model based on configuration
    model_class = get_model(config.model.name)
    model = model_class(
        **config.model.params.to_dict()
    ).to(device)

    # Enable model parameter tracking in wandb
    wandb.watch(model)

    # Count and log parameters and trainable parameters
    params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    wandb.log(
        {
            "model/parameters": params,
            "model/trainable_parameters": trainable_params
        }
    )

    #########
    # Optim #
    #########

    # Initialize loss function (Mean Squared Error) for training
    loss_fn = torch.nn.MSELoss()

    # Initialize AdamW optimizer with configured learning rate
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.optimizer.params.learning_rate,
        weight_decay=config.optimizer.params.weight_decay
    )

    # Set up learning rate scheduler based on configuration
    if config.scheduler.name == "constant":
        # factor=1.0 keeps the initial learning rate unchanged throughout training
        scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer,
            factor=1.0
        )
    elif config.scheduler.name == "cosine":
        # One cosine cycle over the whole run (batches * epochs); this assumes
        # `train` steps the scheduler once per batch — TODO confirm.
        iters = len(train_dataloader) * config.epochs
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=iters,
            eta_min=config.scheduler.params.learning_rate_min
        )
    else:
        # Fail fast with a clear message; otherwise `scheduler` would be
        # unbound and training would crash with a NameError.
        raise ValueError(
            f"Unknown scheduler {config.scheduler.name!r}; "
            "expected 'constant' or 'cosine'."
        )

    # Train Loop...
    print("Training...")
    best_valid_loss = float('inf')
    best_model_state = None

    for epoch in range(config.epochs):
        print(f"Epoch: {epoch + 1}/{config.epochs}")

        train(
            dataloader=train_dataloader,
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            scheduler=scheduler
        )

        valid_loss = valid(
            dataloader=valid_dataloader,
            model=model,
            loss_fn=loss_fn
        )

        # Snapshot the weights whenever validation loss improves
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            # state_dict() holds references to the live parameter tensors, and
            # a shallow dict copy would keep those references, so later
            # optimizer steps would silently overwrite the "best" snapshot.
            # Clone every tensor to freeze its current values.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            print(f"New best model with validation loss: {best_valid_loss:.6f}")

    # Restore best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Restored best model with validation loss: {best_valid_loss:.6f}")

    print("Training complete.")

    #######################
    # Evaluate Trajectory #
    #######################

    pred_traj = generate_traj(
        model=model,
        ic=valid_traj[:, 0, :],
        n_steps=valid_traj.shape[1],
        residual=config.data.residual,
        scaler=scaler
    )

    step_loss, traj_loss, step_corr, steps_above_threshold_avg, steps_above_threshold_min, plots = evaluate_traj(
        pred_traj=pred_traj,
        targ_traj=valid_traj,
        visualize=True
    )

    # Correlation thresholds paired with the steps-above-threshold results.
    # NOTE(review): assumes evaluate_traj uses these same 51 thresholds — confirm.
    # Plain floats (not 0-dim tensors), matching the `.item()` x values.
    thresholds = torch.linspace(0.5, 1, 51).tolist()

    # Log metrics to wandb
    wandb.log(
        {
            "valid/step_loss": wandb.plot.line(
                wandb.Table(
                    data=[
                        [i * config.data.pred_dt, step_loss[i]]
                        for i in range(len(step_loss))
                    ],
                    columns=["Time", "Loss"]
                ),
                "Time",
                "Loss",
                title="Loss over Time"
            ),
            "valid/traj_loss": traj_loss,
            "valid/step_corr": wandb.plot.line(
                wandb.Table(
                    data=[
                        [i * config.data.pred_dt, step_corr[i]]
                        for i in range(len(step_corr))
                    ],
                    columns=["Time", "Correlation"]
                ),
                "Time",
                "Correlation",
                title="Correlation over Time"
            ),
            "valid/corr_above_threshold_avgx": wandb.plot.line(
                wandb.Table(
                    data=[
                        [step.item() * config.data.pred_dt, threshold]
                        for step, threshold in zip(steps_above_threshold_avg, thresholds)
                    ],
                    columns=["Time", "Corr"]
                ),
                "Time",
                "Corr",
                title="Time above Threshold (Average)"
            ),
            "valid/corr_above_threshold_minx": wandb.plot.line(
                wandb.Table(
                    data=[
                        [step.item() * config.data.pred_dt, threshold]
                        for step, threshold in zip(steps_above_threshold_min, thresholds)
                    ],
                    columns=["Time", "Corr"]
                ),
                "Time",
                "Corr",
                title="Time above Threshold (Minimum)"
            ),
            "valid/traj_plots": plots
        }
    )

    # RecurrentModel ## TODO: Currently metrics do not return time. They return steps.
    if isinstance(model, get_model("RecurrentModel")):
        # Sweep k over 1 .. 2 * k_bar; a range is reusable, so it serves both
        # as the loop domain and as the x-axis of the plots below.
        k_values = range(1, int(model.recurrent_k_distribution["k_bar"]) * 2 + 1)

        step_losses = []
        traj_losses = []
        step_corrs = []
        steps_above_threshold_avgs = []
        steps_above_threshold_mins = []

        for k in k_values:
            print(f"Evaluating k={k}...")
            pred_traj = generate_traj(
                model=model,
                ic=valid_traj[:, 0, :],
                n_steps=valid_traj.shape[1],
                residual=config.data.residual,
                scaler=scaler,
                kwargs={"k": k}
            )

            step_loss, traj_loss, step_corr, steps_above_threshold_avg, steps_above_threshold_min, _ = evaluate_traj(
                pred_traj=pred_traj,
                targ_traj=valid_traj,
                visualize=False
            )

            step_losses.append(step_loss)
            traj_losses.append(traj_loss)
            step_corrs.append(step_corr)
            steps_above_threshold_avgs.append(steps_above_threshold_avg)
            steps_above_threshold_mins.append(steps_above_threshold_min)

        # Transpose (k, metric-index) -> (metric-index, k) so each step /
        # threshold becomes one line in a wandb multi-line plot over k.
        step_loss_vs_k = [list(i) for i in zip(*[tensor.tolist() for tensor in step_losses])]
        best_traj_loss = min(traj_losses)
        step_corr_vs_k = [list(i) for i in zip(*[tensor.tolist() for tensor in step_corrs])]
        steps_above_threshold_avg_vs_k = [list(i) for i in zip(*[tensor.tolist() for tensor in steps_above_threshold_avgs])]
        steps_above_threshold_min_vs_k = [list(i) for i in zip(*[tensor.tolist() for tensor in steps_above_threshold_mins])]

        # Log metrics to wandb
        wandb.log(
            {
                "valid/step_loss_vs_k": wandb.plot.line_series(
                    xs=k_values,
                    ys=step_loss_vs_k,
                    keys=[f"Step {t+1}" for t in range(len(step_loss_vs_k))],
                    title="Step Loss vs K",
                    xname="k"
                ),
                "valid/traj_loss_vs_k": wandb.plot.line(
                    wandb.Table(
                        data=[
                            [k, loss]
                            for k, loss in zip(k_values, traj_losses)
                        ],
                        columns=["k", "Traj Loss"]
                    ),
                    "k",
                    "Traj Loss",
                    title="Traj Loss vs K"
                ),
                "valid/best_traj_loss": best_traj_loss,
                "valid/step_corr_vs_k": wandb.plot.line_series(
                    xs=k_values,
                    ys=step_corr_vs_k,
                    keys=[f"Step {t+1}" for t in range(len(step_corr_vs_k))],
                    title="Step Corr vs K",
                    xname="k"
                ),
                "valid/steps_above_threshold_avg_vs_k": wandb.plot.line_series(
                    xs=k_values,
                    ys=steps_above_threshold_avg_vs_k,
                    keys=[f"Threshold {i+1}" for i in range(len(steps_above_threshold_avg_vs_k))],
                    title="Steps Above Threshold (Average) vs K",
                    xname="k"
                ),
                "valid/steps_above_threshold_min_vs_k": wandb.plot.line_series(
                    xs=k_values,
                    ys=steps_above_threshold_min_vs_k,
                    keys=[f"Threshold {i+1}" for i in range(len(steps_above_threshold_min_vs_k))],
                    title="Steps Above Threshold (Minimum) vs K",
                    xname="k"
                ),
            }
        )

    # Save model and scaler under a directory named after the wandb run
    model_path = os.path.dirname(os.path.abspath(__file__)) + f"/models/.pt/{wandb.run.name}"
    os.makedirs(model_path, exist_ok=True)

    torch.save(model.state_dict(), f"{model_path}/model.pt")
    if config.data.scale:
        scaler.save(f"{model_path}/scaler.json")

    print(f"Model saved as {model_path}/model.pt.")

    # Clean up wandb
    wandb.finish()

if __name__ == "__main__":
    # Command-line interface: a YAML config file plus optional dotted-key overrides.
    parser = argparse.ArgumentParser()
    # Required: Path to the YAML configuration file
    parser.add_argument("--config", type=str, required=True,
                        help="Path to config file.")

    # Repeatable override. The value is parsed with yaml.safe_load, so e.g.
    # `3`, `true` or `[1, 2]` become int / bool / list rather than strings.
    parser.add_argument("--set", metavar="KEY=VAL", action="append",
                        help="Override any config entry, e.g. --set model.params.activation=relu")

    args = parser.parse_args()

    # Load configuration from YAML file
    config = Config.from_yaml(args.config)

    for item in args.set or []:
        # partition() never raises, so a malformed override gets a clean CLI
        # error instead of an opaque tuple-unpacking ValueError.
        key, sep, raw = item.partition("=")
        key = key.strip()
        if not sep or not key:
            parser.error(f"--set expects KEY=VAL, got {item!r}")
        config.set(key, yaml.safe_load(raw))

    # Start the training process with the configured settings
    main(config)