
import wandb
from tqdm import tqdm
import time
from utils.data_utils import load_dataset 
from utils.train_utils import train_one_epoch, save_checkpoint, save_results
from utils.evaluate_utils import evaluate
from utils.utils import load_config, setup_device, init_wandb, set_seed
from utils.train_utils import setup_model_and_optimizer, train_one_epoch
from itertools import product






# ---------------------------------------------------------------------------
# Experiment setup: configuration, device, seeding, sweep axes and data.
# ---------------------------------------------------------------------------
cfg = load_config()

# Resolve the compute device from the GPU id given in the config.
device_id = cfg.training.gpu
device = setup_device(device_id)

# Apply the configured seed before the dataset is loaded; each run in the
# training loop re-seeds itself later.
set_seed(cfg.training.seed)

# Sweep axes: the training loop iterates over the product of these two.
lambda_list, proj_list = cfg.training.lambda_self, cfg.model.projection_dim

# Settings shared by every run of the sweep.
num_runs = cfg.training.num_runs
random_w_size, num_windows = cfg.training.random_w_size, cfg.training.num_windows
batch_sampling = cfg.training.batch_sampling

# Data loaders (and the trajectory length reported by the loader helper).
train_loader, val_loader, traj_length = load_dataset(cfg)

# --- Train loop ---
# Sweeps every (lambda_self, projection_dim) combination, repeating each one
# `num_runs` times. Each repeat gets a fresh model/optimizer, its own W&B run,
# periodic checkpoints, and a results record saved at the end.
start_time = time.time()

# One entry per completed run (hyperparameter combination x repeat).
all_results = []
# Loop over hyperparameter combinations
for lambda_self, projection_dim in product(lambda_list, proj_list):
    for run_id in range(num_runs):  # Run multiple times per setting
        print(f"\n=== Run {run_id+1} Training with lambda_self={lambda_self}, projection_dim={projection_dim}, random_win={random_w_size}, n_win ={num_windows} ===")
        # Write this run's values into the config. product() has already
        # consumed the sweep lists, so overwriting them here is safe.
        cfg.training.lambda_self = lambda_self
        cfg.model.projection_dim = projection_dim

        # Seed each repeat with its run index and record it in the config.
        seed = run_id
        set_seed(seed)
        cfg.training.seed = seed
        desc = cfg.training.desc

        # NOTE(review): run_name carries no run_id/seed, so repeats of the
        # same setting share one name; presumably checkpoints keyed on
        # run_name are overwritten by later repeats -- confirm against
        # save_checkpoint before relying on per-repeat checkpoints.
        run_name = f"lambda{lambda_self}_proj{projection_dim}_wsize{random_w_size}_nwind_{num_windows}_{desc}"
        # Start a W&B run for this combination/repeat.
        init_wandb(run_name, cfg)

        # Initialize model and optimizer again for each run
        model, optimizer, scheduler = setup_model_and_optimizer(cfg, projection_dim, device, len(train_loader))

        training_results = {
            "run_id": run_id + 1,
            "lambda_self": lambda_self,
            "projection_dim": projection_dim,
            "seed": seed,
            "epochs": [],
            "val_metrics": []
        }

        num_epochs = cfg.training.num_epochs
        # Per-run clock: the ETA must be based on this run's epochs only,
        # not on the time already spent in earlier runs of the sweep.
        run_start = time.time()

        for epoch in tqdm(range(num_epochs), desc="Training Progress", position=0):
            print(f"\n[Epoch {epoch+1}/{num_epochs}] Starting...")

            epoch_start = time.time()

            # Train for one epoch, then evaluate on the validation loader.
            avg_loss_pred, avg_loss_self = train_one_epoch(train_loader, model, optimizer, scheduler, cfg, lambda_self, device,  batch_sampling= batch_sampling)
            val_metrics = evaluate(val_loader, model, ttt_lr_eval  = cfg.training.ttt_lr_eval, ttt_epochs_eval = cfg.training.ttt_epochs_eval, device = device, baseline_clip= 'a robot')
            current_lr = optimizer.param_groups[0]["lr"]

            # Timing: epoch duration, wall time since the sweep started, and
            # an ETA for the remainder of the current run.
            now = time.time()
            epoch_time = now - epoch_start
            total_elapsed = now - start_time
            run_elapsed = now - run_start
            epochs_done = epoch + 1
            eta = (run_elapsed / epochs_done) * (num_epochs - epochs_done)

            # Log per epoch to WandB
            wandb.log({
                "train/loss_pred": avg_loss_pred,
                "train/loss_self": avg_loss_self,
                "train/lr": current_lr,
                "train/epoch_time_sec": epoch_time,
                "train/elapsed_time_sec": total_elapsed,
                "train/eta_sec": eta,
                **val_metrics,
                "epoch": epoch,
            })

            training_results["epochs"].append(epoch)
            training_results["val_metrics"].append({
                "epoch": epoch,
                "loss_pred": avg_loss_pred,
                "loss_self": avg_loss_self,
                "lr": current_lr,
                "val_metrics": val_metrics
            })
            # Print training and validation metrics
            val_metrics_str = ", ".join([f"{k}: {v:.4f}" for k, v in val_metrics.items()])
            print(f"[Epoch {epoch+1}/{num_epochs}] "
                f"Train Loss Pred: {avg_loss_pred:.4f}, Self: {avg_loss_self:.4f} | "
                f"Val: {val_metrics_str} | "
                f"Time: {epoch_time:.2f}s | Elapsed: {total_elapsed:.1f}s | ETA: {eta/60:.1f} min")
            # Save model checkpoint every `checkpoint.steps` epochs
            if (epoch + 1) % cfg.training.checkpoint.steps == 0:
                save_checkpoint(model, epoch, run_name, save_dir= cfg.training.checkpoint.dir)

        # Record the run once (not once per epoch) and persist its results.
        all_results.append(training_results)
        save_results(epoch, run_name, run_id, training_results, save_dir=cfg.training.checkpoint.dir)
        # Close this run's W&B session so the next init starts a separate run.
        wandb.finish()