import os
import cv2
import random
import optuna
import logging
import pandas as pd
from types import MethodType
import numpy as np
import torch
import config_hp
import config_task
import utils
import random
import traceback

import grounding, segmentation, validation, prompt_boosting


# The four pipeline models are constructed once, at import time, and shared as
# module-level globals; vlm_seg() below reads them by these exact names.

# Instantiate grounding model (used by vlm_seg to obtain the bounding box)
grounding_model = grounding.GrdModelCogVLM()
print('Grounding model loaded successfully.')

# Instantiate segmentation model (mask from bbox + point prompts)
segmentation_model = segmentation.SegModelSAM()
# Alternative segmentation backend: swap with the line above to use MedSAM.
#segmentation_model = segmentation.SegModelMedSAM()
print('Segmentation model loaded successfully.')

# Instantiate validation model (returns a dict with at least 'score' and 'reason',
# as consumed by vlm_seg / evaluate_model)
validation_model = validation.ValModelBioMedCLIP()
print('Validation model loaded successfully.')

# Prompt boosting model (produces point prompts given the image and bbox)
boosting_model = prompt_boosting.BstModelDINOKMeans()
print('Prompt Boosting model loaded successfully.')



def vlm_seg(image_path, hp, args, fallback=False, debug_mode='none', save_path=None,
            fallback_threshold=1e-3):
    """
    Main VLM-based segmentation pipeline.

    Runs the module-level models in sequence: grounding (bounding box) ->
    prompt boosting (point prompts, given the box) -> segmentation (mask from
    box + points) -> validation (score for the mask). When ``fallback`` is
    enabled and the validation score is below ``fallback_threshold``, the image
    is re-segmented from a default pixel box (x1, y1, x2, y2) built from
    ``args.center_x_range`` / ``args.center_y_range`` (fractions of the image
    width/height) with no point prompts, and the new mask is re-validated.

    Args:
        image_path (str): Path to the input image.
        hp (dict): Dictionary of hyperparameters.
        args: Task configuration object. It is accessed by attribute
            (e.g. ``args.center_x_range``), not as a dict.
        fallback (bool): Whether to retry with a default box when the
            validation score is abnormally low.
        debug_mode (str): Visualization mode ('none', 'demo', or 'batch').
        save_path (str): Path to save visualization output ('batch' mode only).
        fallback_threshold (float): Validation score below which the fallback
            retry is triggered (only used when ``fallback`` is True).

    Returns:
        Tuple[dict, mask]: Validation result (a dict read for at least
        ``'score'``) and the predicted mask as returned by the segmentation model.

    Raises:
        FileNotFoundError: If the fallback path is taken and OpenCV cannot
            read ``image_path``.
    """
    # Step 1: Grounding to get bounding box
    bbox = grounding_model.predict(args, image_path, hp)

    # Step 2: Prompt boosting (point prompts derived from the image and bbox)
    points = boosting_model.predict(args, image_path, hp, bbox)

    # Step 3: Segmentation based on bbox and points
    mask = segmentation_model.predict(args, image_path, hp, bbox, points)

    # Step 4: Validation
    valret = validation_model.predict(args, image_path, mask)

    # Fallback mechanism if validation score is too low
    if fallback and valret['score'] < fallback_threshold:
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None
            # instead of raising; fail with a clear message.
            raise FileNotFoundError(f"cv2 could not read image: {image_path}")

        # Default box in pixels; the ranges are fractions of width/height.
        h, w = image.shape[:2]
        top_left = (int(args.center_x_range[0] * w), int(args.center_y_range[0] * h))
        bottom_right = (int(args.center_x_range[1] * w), int(args.center_y_range[1] * h))
        bbox = (*top_left, *bottom_right)
        points = []  # retry with a box-only prompt

        mask = segmentation_model.predict(args, image_path, hp, bbox, points)
        valret = validation_model.predict(args, image_path, mask)

    # Step 5: Visualization if enabled
    if debug_mode == 'demo':
        utils.visualization(image_path, bbox, mask, points, args)
    elif debug_mode == 'batch':
        utils.visualization(image_path, bbox, mask, points, args, save_path=save_path)

    return valret, mask

def evaluate_model(hp, image_paths, args, fallback=False, debug='none', log_name='optuna'):
    """
    Run ``vlm_seg`` on every image and return the mean validation score.

    All outputs go to ``<args.output_dir>/<log_name>/``. Each image's
    visualization path (``<dir>/<file_name>``) is forwarded to ``vlm_seg`` as
    ``save_path``. In 'batch' debug mode, a ``<file_name>.txt`` log is also
    written per image with: gt Dice, validation score, their difference, and
    the validator's reason. A summary line (hp, mean gt score, mean validation
    score) is appended to ``log_all.txt`` and printed.

    Args:
        hp (dict): Hyperparameters forwarded to ``vlm_seg``.
        image_paths (Sequence[str]): Images to evaluate; must be non-empty.
        args: Task configuration object (uses ``output_dir`` and ``gt_dir``).
        fallback (bool): Forwarded to ``vlm_seg``.
        debug (str): Forwarded to ``vlm_seg`` as ``debug_mode``; 'batch' also
            writes the per-image text logs.
        log_name (str): Name of the output subdirectory.

    Returns:
        float: Mean validation score over ``image_paths``.

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    n = len(image_paths)
    if n == 0:
        # Fail before touching the filesystem; averaging would divide by zero.
        raise ValueError('evaluate_model requires at least one image path')

    output_dir = os.path.join(args.output_dir, log_name)
    os.makedirs(output_dir, exist_ok=True)

    val_score_sum = 0.0
    gt_score_sum = 0.0
    for image_path in image_paths:
        file_name = os.path.basename(image_path)
        save_path = os.path.join(output_dir, file_name)
        response, mask = vlm_seg(image_path, hp, args, fallback, debug, save_path)
        val_score = response['score']
        reason = response['reason']

        # Ground-truth Dice is only computed when a gt directory is configured.
        gt_score = utils.calc_dice_for_file(image_path, mask, args) if args.gt_dir else 0
        gt_score_sum += gt_score
        val_score_sum += val_score

        if debug == 'batch':
            # Separate local name so the ``log_name`` parameter is not shadowed.
            image_log_path = os.path.join(output_dir, file_name + '.txt')
            with open(image_log_path, 'w') as f:
                f.write(f"{file_name}\t{gt_score:.4f}\t{val_score:.4f}\t{val_score - gt_score:.4f}\n{reason}\n")

    val_score_avg = val_score_sum / n
    gt_score_avg = gt_score_sum / n
    with open(os.path.join(output_dir, 'log_all.txt'), 'a') as f:
        summary = f"{hp}\t[gt]{gt_score_avg:.4f}\t[val]{val_score_avg:.4f}"
        f.write(f"{summary}\n")
        print(summary)

    return val_score_avg

def _suggest_hp(trial, name, val_type, val_cfg):
    """Ask an Optuna trial for one hyperparameter value according to its hp_space spec."""
    # 'categorial' is the spelling this code has always matched in
    # config_hp.hp_space; the correct spelling is accepted as well.
    if val_type in ('categorial', 'categorical'):
        return trial.suggest_categorical(name, val_cfg)
    if val_type == 'int':
        return trial.suggest_int(name, val_cfg[0], val_cfg[1])
    if val_type == 'float':
        return trial.suggest_float(name, val_cfg[0], val_cfg[1])
    # An unknown type must fail loudly; otherwise a stale value from a
    # previous parameter could silently be reused.
    raise ValueError('wrong val_type %s for hyperparameter %s' % (val_type, name))


def optimize_with_optuna(args, trial_num=100):
    """
    Tune hyperparameters with Optuna's TPE sampler, maximizing the mean
    validation score returned by ``evaluate_model`` on ``args.val_images``.

    The search space is ``config_hp.hp_space``, a mapping of
    ``name -> (val_type, val_cfg)``. ``val_type`` is 'categorial'/'categorical'
    (``val_cfg`` = choices), 'int', or 'float' (``val_cfg`` = (low, high)).
    Parameter sets already evaluated are pruned instead of re-run. Optuna log
    messages go to ``<args.output_dir>/optuna_log.txt``. All trials are
    exported to ``optuna_trials.csv`` and ``optuna_trials.json`` (JSON lines)
    in ``args.output_dir``.

    Args:
        args: Task configuration object (uses ``output_dir`` and ``val_images``).
        trial_num (int): Number of Optuna trials to run.

    Returns:
        optuna.Study: The finished study.

    Raises:
        ValueError: If ``config_hp.hp_space`` contains an unknown ``val_type``.
    """
    seen = set()

    # Setup logging: mirror Optuna's log messages into a per-run file.
    optuna_logger = optuna.logging.get_logger("optuna")
    optuna_logger.setLevel(logging.INFO)
    log_file = os.path.join(args.output_dir, "optuna_log.txt")
    file_handler = logging.FileHandler(log_file, mode="w")
    file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
    optuna_logger.addHandler(file_handler)

    def objective(trial):
        param_dict = {}
        for name, (val_type, val_cfg) in config_hp.hp_space.items():
            param_dict[name] = _suggest_hp(trial, name, val_type, val_cfg)

        # Skip combinations that were already evaluated in this study.
        key = tuple(sorted(param_dict.items()))
        if key in seen:
            raise optuna.TrialPruned()
        seen.add(key)

        return evaluate_model(param_dict, args.val_images, args)

    # Fewer random start-up trials for short runs.
    startup_num = 10 if trial_num > 50 else 1
    try:
        study = optuna.create_study(
            direction="maximize",
            sampler=optuna.samplers.TPESampler(n_startup_trials=startup_num, multivariate=True, group=True, seed=42)
        )
        study.optimize(objective, n_trials=trial_num)
    finally:
        # Detach and close the handler so repeated calls neither duplicate log
        # lines into earlier files nor leak the file descriptor.
        optuna_logger.removeHandler(file_handler)
        file_handler.close()

    print("Best score:", study.best_value)
    print("Best params:", study.best_params)

    df = study.trials_dataframe()
    df.to_csv(os.path.join(args.output_dir, "optuna_trials.csv"), index=False)
    df.to_json(os.path.join(args.output_dir, "optuna_trials.json"), orient="records", lines=True)
    return study


def set_seed(seed: int = 42):
    """
    Seed every RNG used by the pipeline and force deterministic cuDNN kernels.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch (CPU, the
    current CUDA device, and all CUDA devices). It then disables cuDNN
    autotuning in favor of deterministic algorithms.

    Note: ``PYTHONHASHSEED`` is exported for reproducibility of any child
    processes. Hash randomization of the *current* interpreter is fixed at
    startup and is not affected.

    Args:
        seed: Seed value applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # multi-GPU
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def run(task_name, trial_num=100, seed=0):
    """
    Run the full experiment for one task.

    Steps: (1) clear and recreate ``args.output_dir``; (2) evaluate default
    hyperparameters on the test images; (3) evaluate random hyperparameters;
    (4) run Optuna on the validation images; (5) evaluate the best parameters
    found on the test images. Outputs are logged under 'base', 'random', and
    'optimal'. RNGs are re-seeded before each stochastic stage so the stages
    are independently reproducible.

    Any exception is caught; its traceback is printed together with a failure
    message, so a failing task does not propagate to the caller.

    Args:
        task_name (str): Task name passed to ``config_task.read_hp_config``.
        trial_num (int): Number of Optuna trials.
        seed (int): Seed passed to ``set_seed`` before each stage.
    """
    import shutil  # only needed here, for clearing the output directory

    try:
        set_seed(seed)
        args = config_task.read_hp_config(task_name)

        # Start from a clean output directory. Removing it in-process (rather
        # than via a shell `rm -rf`) avoids quoting/injection problems with the
        # configured path and works on every OS.
        out_dir = args.output_dir
        if os.path.islink(out_dir) or os.path.isfile(out_dir):
            os.remove(out_dir)
        else:
            shutil.rmtree(out_dir, ignore_errors=True)
        os.makedirs(out_dir, exist_ok=True)

        print('[%s] run BO on %d images, evaluate on %d images' % (task_name, len(args.val_images), len(args.test_images)))

        print("Running default evaluation...")
        hp = config_hp.default_hp()
        evaluate_model(hp, args.test_images, args, fallback=False, debug='batch', log_name='base')

        set_seed(seed)
        print("Running random evaluation...")
        hp = config_hp.random_hp()
        evaluate_model(hp, args.test_images, args, fallback=False, debug='batch', log_name='random')

        set_seed(seed)
        print("Starting Optuna optimization...")
        study = optimize_with_optuna(args, trial_num=trial_num)

        print("Evaluating best found parameters...")
        best_hp = study.best_params
        evaluate_model(best_hp, args.test_images, args, fallback=False, debug='batch', log_name='optimal')

    except Exception:
        # Top-level boundary: report and keep going rather than abort the caller.
        traceback.print_exc()
        print('task %s failed' % task_name)


# Script entry point: only launch the experiment when executed directly, so
# importing this module (e.g. to reuse vlm_seg) does not start a full run.
if __name__ == "__main__":
    run('REFUGE_small')
