import os
import json
import torch
import datetime
import itertools
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.vlm_adapt import VLMAdapt
from utils.data_utils import TPDataset
from utils.evaluate_utils import evaluate

from configs.dataset_config import dataset_config

from utils.utils import load_config, setup_device, init_wandb, set_seed


def prepare_device(cfg):
    """
    Resolve the evaluation device.

    Reads the GPU setting from ``cfg.evaluation.gpu`` and hands it to
    ``setup_device``, returning whatever that helper produces.
    """
    return setup_device(cfg.evaluation.gpu)


def prepare_tt_combos(cfg):
    """
    Build the grid of test-time-training (TTT) hyperparameters.

    Returns a list of ``(lr, epochs, window_size)`` tuples: the Cartesian
    product of ``cfg.evaluation.ttt_lr``, ``cfg.evaluation.ttt_epochs`` and
    ``cfg.evaluation.window_size``, with the last axis varying fastest.
    """
    eval_cfg = cfg.evaluation
    axes = (eval_cfg.ttt_lr, eval_cfg.ttt_epochs, eval_cfg.window_size)
    return list(itertools.product(*axes))


def load_model(model_name, projection_dim, cfg, device):
    """
    Build a VLMAdapt model and load its weights from a checkpoint file.

    Args:
        model_name: Filesystem path of the checkpoint (used directly as a path).
        projection_dim: Forwarded to ``VLMAdapt`` as ``projection_dim``.
        cfg: Config; ``cfg.model.clip_model`` and ``cfg.model.pretrained_clip``
            are forwarded to ``VLMAdapt``.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.

    Raises:
        FileNotFoundError: If no file exists at ``model_name``.
    """
    checkpoint_path = model_name
    checkpoint_exists = os.path.exists(checkpoint_path)
    print(f"Current working directory: {os.getcwd()}")
    print("Checkpoint found!" if checkpoint_exists else "Checkpoint not found!")
    print(checkpoint_path)
    # Validate the path before building the model so a bad path fails fast.
    if not checkpoint_exists:
        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

    model = VLMAdapt(
        clip_model_name=cfg.model.clip_model,
        pretrained_clip=cfg.model.pretrained_clip,
        projection_dim=projection_dim,
    ).to(device)

    state_dict = torch.load(checkpoint_path, map_location=device)
    # strict=False tolerates key mismatches; report them so a checkpoint that
    # does not match the architecture is not evaluated silently.
    load_report = model.load_state_dict(state_dict, strict=False)
    if load_report.missing_keys:
        print(f"Missing keys when loading checkpoint: {list(load_report.missing_keys)}")
    if load_report.unexpected_keys:
        print(f"Unexpected keys when loading checkpoint: {list(load_report.unexpected_keys)}")

    model.eval()
    return model


def evaluate_model_on_split(model, split, traj_len, baseline_desc, loader, ttt_combos, results, device, cfg, model_name=None):
    """
    Evaluate the model on one dataset split for every TTT hyperparameter combo.

    Each combo's metrics are stored in ``results`` (mutated in place) under a
    key that encodes the split, the combo, the model name, the baseline
    description and the ``shuffling_online`` / ``reset`` evaluation flags.

    Args:
        model: Model passed through to ``evaluate``.
        split: Dataset split name (used for logging and in the result key).
        traj_len: Not used in this function; kept so positional callers keep working.
        baseline_desc: Passed to ``evaluate`` as ``baseline_clip`` and embedded in the key.
        loader: Data loader passed through to ``evaluate``.
        ttt_combos: Iterable of ``(lr, steps, window)`` tuples.
        results: Dict that receives one entry per combo.
        device: Passed through to ``evaluate``.
        cfg: Config; reads ``cfg.evaluation.reset`` and ``cfg.evaluation.shuffling_online``.
        model_name: Identifier embedded in the result key so that runs of
            different models do not overwrite each other. Defaults to the
            model's class name when omitted.
    """
    if model_name is None:
        # The key previously referenced an undefined global ``model_name``
        # (NameError); fall back to a value that is always available.
        model_name = type(model).__name__

    for lr, steps, window in ttt_combos:
        print(f"Evaluating on {split} | TTT lr={lr}, epochs={steps}, window={window}")
        metrics = evaluate(
            loader,
            model,
            ttt_lr_eval=float(lr),
            ttt_epochs_eval=int(steps),
            device=device,
            baseline_clip=baseline_desc,
            window_size=int(window),
            reset=bool(cfg.evaluation.reset),
            shuffling_online=bool(cfg.evaluation.shuffling_online),
        )
        key = f"{split}__tttlr{lr}__tttsteps{steps}__window{window}_model{model_name}_{baseline_desc}_perturb{cfg.evaluation.shuffling_online}_reset{cfg.evaluation.reset}"
        results[key] = metrics
        print(f"--- Metrics for {key} ---")
        for metric_name, value in metrics.items():
            print(f"{metric_name}: {value:.6f}")
        print("-" * 40)


def create_loader(cfg, split, traj_len):
    """
    Create the DataLoader for a specific dataset split.

    Builds a ``TPDataset`` rooted at ``cfg.dataset_eval.root_dir`` for the given
    ``split`` and ``traj_len`` and wraps it in an unshuffled ``DataLoader`` with
    batch size 1 and ``cfg.dataset_eval.num_workers`` workers.
    """
    dataset = TPDataset(
        root_dir=cfg.dataset_eval.root_dir,
        split=split,
        traj_len=traj_len,
    )
    loader = DataLoader(
        dataset,
        batch_size=1,  # step-by-step evaluation
        shuffle=False,
        num_workers=cfg.dataset_eval.num_workers,
        pin_memory=True,
    )
    return loader


def save_results(results, desc, reset):
    """
    Save the evaluation results to a timestamped JSON file.

    The file is written to the ``results`` directory (created if needed,
    relative to the current working directory); its name embeds the current
    local timestamp, ``desc`` and ``reset``.
    """
    os.makedirs("results", exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = os.path.join("results", f"f_evaluation_results_{timestamp}_{desc}_reset{reset}.json")

    # Explicit encoding keeps the output independent of the platform default.
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print(f"Saved evaluation results to {save_path}")


def run_evaluate(cfg):
    """
    Evaluate every configured checkpoint on every configured split.

    For each path in ``cfg.evaluation.model_dirs`` the model is loaded with the
    projection dimension at the same index of ``cfg.evaluation.projection_dims``,
    then evaluated on each split of ``cfg.dataset_eval.splits`` for every TTT
    hyperparameter combination. All metrics are collected in one dict and
    written to a JSON file once every run has finished.
    """
    # Set up device
    device = prepare_device(cfg)

    # Prepare TTT hyperparameter combinations
    ttt_combos = prepare_tt_combos(cfg)

    # Per-split settings; splits missing from dataset_config use the defaults
    # "a robot" (baseline description) and 49 (trajectory length).
    splits = cfg.dataset_eval.splits
    baseline_descriptions = [dataset_config.get(split, {}).get("baseline", "a robot") for split in splits]
    traj_lengths = [dataset_config.get(split, {}).get("traj_len", 49) for split in splits]

    print('traj_lens:', traj_lengths)
    print('splits:', splits)
    print('baseline_descriptions:', baseline_descriptions)
    print('ttt_combos:', ttt_combos)

    # Prepare for evaluation
    results = {}
    desc = cfg.evaluation.desc

    model_dirs = cfg.evaluation.model_dirs
    proj_dims = cfg.evaluation.projection_dims
    shuffling_online = cfg.evaluation.shuffling_online

    # One "run" is one (model, split, combo) evaluation. The bar advances once
    # per (model, split), so each advance accounts for all combos of that split.
    runs_per_split = len(ttt_combos)
    total_runs = len(model_dirs) * len(splits) * runs_per_split

    print(f"Model directories: {model_dirs}")
    print(f"Projection dimensions: {proj_dims}")
    print(f"Shuffling online: {shuffling_online}, desc: {desc}, reset: {cfg.evaluation.reset}")
    print(f"Using GPU: {cfg.evaluation.gpu}")

    # Start evaluation with tqdm progress bar
    with tqdm(total=total_runs, desc="Evaluating Configs") as pbar:
        for i, model_name in enumerate(model_dirs):
            model = load_model(model_name, proj_dims[i], cfg, device)

            for split, traj_len, baseline_desc in zip(splits, traj_lengths, baseline_descriptions):
                loader = create_loader(cfg, split, traj_len)
                # Pass the checkpoint path explicitly so it ends up in the result
                # key and results of different models stay distinct.
                evaluate_model_on_split(
                    model, split, traj_len, baseline_desc, loader,
                    ttt_combos, results, device, cfg,
                    model_name=model_name,
                )
                pbar.update(runs_per_split)

    # Save results to JSON
    reset = cfg.evaluation.reset
    save_results(results, desc, reset)


# Load config and start evaluation.
# NOTE(review): these two statements execute whenever the module is loaded,
# including on a plain import, not only when the file is run as a script.
# Consider an `if __name__ == "__main__":` guard if this module is ever
# imported from elsewhere.
cfg = load_config()
run_evaluate(cfg)
