# utils/train_utils.py

import os
import torch
import random
import numpy as np
import wandb
from datetime import datetime
from torch.optim.lr_scheduler import LambdaLR
import math
from utils.data_utils import temporal_perturb_inplace
import json
from models.vlm_adapt import VLMAdapt

def train_one_epoch(train_loader, model, optimizer, scheduler, cfg, lambda_self, device, batch_sampling="diversity"):
    """
    Trains the model for one epoch using the selected batch training strategy.

    Args:
        train_loader (Iterable): Yields ``(frames, goal_text, progress_label, valid_mask)``
            batches. ``frames``, ``progress_label`` and ``valid_mask`` are moved to ``device``.
        model (nn.Module): The model to train. Must expose the ``train_batch_*`` /
            ``train_mini_batch`` method that corresponds to ``batch_sampling``.
        optimizer (Optimizer): The optimizer, forwarded to the model's batch-training method.
        scheduler (Scheduler or None): Learning rate scheduler, stepped once per batch when given.
        cfg (Config): Configuration object; reads ``cfg.training.temporal_perturb_rate``,
            ``random_w_size``, ``num_windows`` and ``grad_clip``.
        lambda_self (float): Regularization strength for self-supervision.
        device (Device): Device (CPU or GPU) to run the model on.
        batch_sampling (str): Training strategy. Options: 'ft', 'rnn', 'diversity'
            (alias: 'dissimilarity'), 'random', 'mini_batch'.

    Returns:
        Tuple: Average prediction loss and average self-supervision loss over the
        batches processed this epoch, or ``(0.0, 0.0)`` if the loader yields no batches.

    Raises:
        ValueError: If ``batch_sampling`` is not a known strategy. Raised before any
            batch is processed.
    """
    # Strategy name -> name of the model method implementing it.
    # "diversity" is the documented/default name and "dissimilarity" the name the
    # dispatch historically matched on; both select the same method so that the
    # default argument works and existing callers keep working.
    strategy_methods = {
        "ft": "train_batch_ft",                        # frozen CLIP + learned linear head
        "rnn": "train_batch_rnn",                      # frozen CLIP + temporal modeling using RNN
        "diversity": "train_batch_dissimilarity",      # TTT + dissimilarity-based mini-batch sampling
        "dissimilarity": "train_batch_dissimilarity",
        "random": "train_batch_random",                # TTT on multiple random windows per sequence
        "mini_batch": "train_mini_batch",              # mini-batch TTT (online, chunk-wise adaptation)
    }

    # Validate up front so a bad strategy fails before any data is touched.
    if batch_sampling not in strategy_methods:
        raise ValueError(
            f"Invalid batch_sampling: {batch_sampling}. Must be one of: "
            "'ft', 'rnn', 'diversity', 'dissimilarity', 'random', 'mini_batch'."
        )
    train_batch = getattr(model, strategy_methods[batch_sampling])

    # Set model to training mode
    model.train()

    # Accumulators for prediction and self-supervision losses
    total_loss_pred = 0
    total_loss_self = 0
    num_batches = 0

    for frames, goal_text, progress_label, valid_mask in train_loader:
        frames = frames.to(device)
        progress_label = progress_label.to(device)
        valid_mask = valid_mask.to(device)

        # Apply temporal perturbation to frames and labels (data augmentation)
        frames_ptr, progress_label_ptr = temporal_perturb_inplace(
            frames,
            progress_label,
            valid_mask,
            cfg.training.temporal_perturb_rate,
            device,
        )

        # Every strategy shares the same call signature.
        loss_pred, loss_self = train_batch(
            frames_ptr,
            goal_text,
            progress_label_ptr,
            valid_mask,
            optimizer,
            lambda_self,
            random_w_size=cfg.training.random_w_size,
            num_windows=cfg.training.num_windows,
            grad_clip=cfg.training.grad_clip,
        )

        # Log per batch to WandB; the LR is read before the scheduler steps so it
        # is the rate that was in effect for this batch.
        current_lr = optimizer.param_groups[0]["lr"]
        wandb.log({
            "train/loss_pred_batch": loss_pred,
            "train/loss_self_batch": loss_self,
            "train/lr_batch": current_lr,
        })

        # Step the scheduler if it's provided
        if scheduler is not None:
            scheduler.step()

        total_loss_pred += loss_pred
        total_loss_self += loss_self
        num_batches += 1

    # An empty loader has nothing to average; avoid dividing by zero.
    if num_batches == 0:
        return 0.0, 0.0

    return total_loss_pred / num_batches, total_loss_self / num_batches



def save_results(epoch, run_name, run_id, results, save_dir="checkpoints"):
    """
    Write the results of one epoch into a JSON results file under ``save_dir``.

    The file is named ``training_results_{run_name}_{timestamp}.json`` and holds a
    nested mapping ``run_id -> "epoch_{epoch+1}" -> results``. If a file with the
    same name already exists, its content is loaded and the new entry is merged in.

    NOTE(review): the timestamp has one-second resolution and is taken on every
    call, so an existing file is only found when two calls land in the same
    second; otherwise each call creates a new file. Confirm whether a per-run
    file name (without the timestamp) was intended.

    Args:
        epoch (int): Zero-based epoch index; stored under the key ``epoch_{epoch+1}``.
        run_name (str): Run name used in the file name.
        run_id: Run identifier. Stored as a string key, because JSON object keys
            are always strings.
        results: JSON-serializable results for this epoch.
        save_dir (str): Directory to write to; created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    results_path = os.path.join(save_dir, f"training_results_{run_name}_{timestamp}.json")

    # Load previous content if the file is already there (EAFP: no exists/open race).
    try:
        with open(results_path, "r", encoding="utf-8") as f:
            all_results = json.load(f)
    except FileNotFoundError:
        all_results = {}

    # Keys loaded back from JSON are strings. Normalising here keeps a non-string
    # run_id (e.g. an int) from being added as a second, duplicate key on merge.
    run_key = str(run_id)

    # Nested structure: run_id -> epoch_X
    all_results.setdefault(run_key, {})[f"epoch_{epoch+1}"] = results

    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)

    print(f"Saved results for run_id={run_id}, epoch={epoch+1}")



def save_checkpoint(model, epoch, run_name, save_dir="checkpoints"):
    """
    Save the model's ``state_dict`` to a timestamped ``.pth`` file.

    The file is written to ``save_dir`` (created if missing) as
    ``epoch{epoch+1}_{run_name}_{YYYYmmdd-HHMMSS}.pth`` and the path is printed.

    Args:
        model: Object exposing ``state_dict()``; its result is passed to ``torch.save``.
        epoch (int): Zero-based epoch index; the file name uses ``epoch + 1``.
        run_name (str): Run name embedded in the file name.
        save_dir (str): Target directory.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Timestamp keeps checkpoints from different invocations apart.
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    ckpt_path = os.path.join(save_dir, f"epoch{epoch+1}_{run_name}_{stamp}.pth")

    # --- Save model ---
    weights = model.state_dict()
    torch.save(weights, ckpt_path)
    print(f"Model checkpoint saved at {ckpt_path}")



def build_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
    """
    Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The multiplier applied to each param group's base LR is:
      * ``current_step / warmup_steps`` while ``current_step < warmup_steps`` (0 -> 1),
      * ``0.5 * (1 + cos(pi * progress))`` afterwards (1 -> 0), where ``progress``
        is the fraction of the post-warmup steps completed,
      * 0 for every step at or beyond ``total_steps``.

    Args:
        optimizer (Optimizer): Optimizer whose learning rate is scheduled.
        warmup_steps (int): Number of scheduler steps spent in linear warmup.
        total_steps (int): Total number of scheduler steps, warmup included.

    Returns:
        LambdaLR: Scheduler to be stepped once per optimizer step.
    """
    def lr_lambda(current_step):
        # warmup lr: 0 -> base_lr
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        # lr:  base_lr -> 0
        progress = (current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        # Clamp: without this the cosine turns around once progress exceeds 1 and
        # the LR climbs back towards base_lr if training runs past total_steps.
        progress = min(1.0, progress)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    # lr = base_lr * lr_lambda(current_step)
    return LambdaLR(optimizer, lr_lambda)


def setup_model_and_optimizer(cfg, projection_dim, device, size):
    """
    Create a fresh model, optimizer and LR scheduler for one run.

    Args:
        cfg (Config): Reads ``cfg.model.clip_model``, ``cfg.model.pretrained_clip``,
            ``cfg.training.base_lr``, ``weight_decay``, ``cosine_warmup_pct`` and
            ``num_epochs``.
        projection_dim: Forwarded to ``VLMAdapt`` as ``projection_dim``.
        device (Device): Device the model is moved to.
        size (int): Multiplied by ``num_epochs`` to get the scheduler's total step
            count — presumably the number of batches per epoch; confirm with caller.

    Returns:
        Tuple: ``(model, optimizer, scheduler)``.
    """
    net = VLMAdapt(
        clip_model_name=cfg.model.clip_model,
        pretrained_clip=cfg.model.pretrained_clip,
        projection_dim=projection_dim,
    )
    net = net.to(device)

    train_cfg = cfg.training
    adamw = torch.optim.AdamW(
        net.parameters(),
        lr=train_cfg.base_lr,
        weight_decay=train_cfg.weight_decay,
    )

    # Warmup covers a fixed fraction of all scheduler steps; truncated to an int.
    n_warmup = int(train_cfg.cosine_warmup_pct * size * train_cfg.num_epochs)
    n_total = size * train_cfg.num_epochs
    lr_schedule = build_cosine_schedule_with_warmup(
        adamw,
        warmup_steps=n_warmup,
        total_steps=n_total,
    )

    return net, adamw, lr_schedule


