import argparse
import json
import hydra
import os
import time
from typing import Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    set_seed,
)

from omegaconf import DictConfig

import os
import sys

from tqdm import tqdm

# Put the repository root (the directory one level above this script) at the
# front of sys.path, so the `alignment` package imports below resolve no
# matter which working directory the script is launched from.
_this_dir = os.path.dirname(__file__)
project_root = os.path.abspath(os.path.join(_this_dir, ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from alignment.aligners.base_aligner import (
    AlignConfig,
    LeastRestrictiveReachableAlignerSpec,
    NoiseConfig,
    ReControlAlignerSpec,
    SmoothBlendReachableAlignerSpec,
    GradientAscentReachableAlignerSpec,
    SAPAlignerSpec
)
from alignment.aligners import backwards_reachable_aligner, recontrol_aligner, sap_aligner


# ----------------------------
# Safety classifier helper
# ----------------------------
def classify_texts(
    texts: List[str],
    device: torch.device,
    clf_tok: "AutoTokenizer",
    clf: nn.Module,
    max_length: int = 512,
) -> Tuple[List[int], List[List[float]]]:
    """Classify a batch of texts with a sequence-classification model.

    All texts are tokenized together (padded to the longest, truncated to
    ``max_length``), moved to ``device`` and scored in one no-grad forward pass.

    Args:
        texts: Strings to classify. An empty list returns ``([], [])`` without
            calling the tokenizer or the model.
        device: Device the encoded batch is moved to; should match ``clf``.
        clf_tok: Tokenizer instance paired with ``clf`` (e.g. the object
            returned by ``AutoTokenizer.from_pretrained``).
        clf: Sequence-classification model whose output exposes ``.logits``.
        max_length: Truncation length in tokens. The default 512 keeps inputs
            within roberta-base's limit (514 positions incl. specials).

    Returns:
        ``(preds, probs)``: for each text, the argmax class index and the
        softmax probabilities over all classes. For the offensive classifier
        used in this script, index 0 = safe and index 1 = offensive.
    """
    if not texts:
        # Nothing to pad into a tensor batch; skip the tokenizer/model entirely.
        return [], []

    enc = clf_tok(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(device)
    with torch.no_grad():
        logits = clf(**enc).logits
        probs = F.softmax(logits, dim=-1)  # [B, num_classes]
        preds = probs.argmax(dim=-1)
    return preds.cpu().tolist(), probs.cpu().tolist()


# ----------------------------
# CLI / runner
# ----------------------------
def load_hydra_config():
    """Compose the SAP evaluation config (``eval_config``) with Hydra.

    While Hydra initializes and composes, ``sys.argv`` is reduced to just the
    script name, so this script's own argparse flags are hidden from Hydra.
    The real argv is put back unconditionally, even if Hydra raises.

    Returns:
        The config composed from ``eval_config`` under
        ``../SafetyPolytope/exp_configs``.

    NOTE(review): ``hydra.initialize`` is called outside a ``with`` block, so
    Hydra's global state stays initialized after this returns. A second call
    in the same process will presumably fail; confirm before reusing this.
    """
    real_argv = sys.argv.copy()
    try:
        # Expose only argv[0] to Hydra for the duration of initialize/compose.
        sys.argv = [real_argv[0]]
        hydra.initialize(version_base="1.1", config_path="../SafetyPolytope/exp_configs")
        return hydra.compose(config_name="eval_config")
    finally:
        sys.argv = real_argv

def _parse_python_literal(text: str):
    """argparse ``type=`` callback that parses a Python literal such as ``[16384, 64]``.

    ``ast.literal_eval`` accepts the same list/tuple/number literals that
    ``eval`` did, but it cannot execute arbitrary expressions. A malformed value
    becomes a normal argparse usage error instead of a traceback.
    """
    import ast

    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError) as err:
        raise argparse.ArgumentTypeError(
            f"expected a Python literal such as [16384, 64], got {text!r}"
        ) from err


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script (all flags and defaults)."""
    p = argparse.ArgumentParser()
    p.add_argument("--model_name", type=str, default="Qwen/Qwen2-1.5B")
    p.add_argument("--root_dir", type=str, default="")
    p.add_argument("--layer_idx", type=int, default=20)
    p.add_argument("--max_new_tokens", type=int, default=128)
    p.add_argument("--epsilon_search", type=float, default=0.0)
    p.add_argument("--epsilon_cap", type=float, default=0.0)
    p.add_argument("--num_steer_samples", type=int, default=32)
    p.add_argument("--value_binary_search_max_iter", type=int, default=2)
    p.add_argument("--norm", type=str, choices=["l2", "linf"], default="l2")
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--device", type=str, default="cuda:0")
    # NOTE(review): parsed but not consulted; parse_args() hard-codes bf16=False.
    p.add_argument("--no_bf16", action="store_true")
    p.add_argument(
        "--value_hidden_dims",
        type=_parse_python_literal,
        default=[16384, 64],
        help="Python literal, e.g. '[16384, 64]'.",
    )
    p.add_argument("--model_ckpt", type=str, default=None)
    p.add_argument("--debug", action="store_true")
    p.add_argument("--num_shards", type=int, default=1)
    p.add_argument("--shard_idx", type=int, default=0)
    p.add_argument("--aligner_type", type=str, choices=["brt_least_restrictive", "brt_blend", "brt_gradient_ascent", "sap", "recontrol"], required=True)
    p.add_argument("--split_name", type=str, default="train")
    p.add_argument("--value_threshold", type=float, default=0.2)
    p.add_argument("--value_coeff_gamma", type=float, default=0.1)
    p.add_argument("--calibration", action="store_true")

    # Dataset-specific args
    p.add_argument("--dataset_name", type=str, default="real_toxicity")
    p.add_argument("--sample_dataset", type=int, default=None)
    p.add_argument("--only_unsafe", action="store_true")

    # dataset / IO
    p.add_argument("--prompt", type=str, default="")  # used if not dataset mode

    # Recontrol-specific args (also reused by brt_gradient_ascent)
    p.add_argument("--num_recontrol_updates", type=int, default=10)
    p.add_argument("--recontrol_step_size", type=float, default=0.01)

    p.add_argument("--brt_type", type=str, default=None)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--max_unsafe_prompts", type=int, default=None, help="If you want to sample a subset of the unsafe prompts, only use this arg.")
    return p


def _build_aligner_spec(args, sap_cfg):
    """Build the aligner spec selected by ``args.aligner_type``.

    Args:
        args: Parsed command-line namespace.
        sap_cfg: Hydra config from ``load_hydra_config()`` when the aligner is
            ``"sap"`` (it is updated in place); ``None`` and unused otherwise.

    Returns:
        ``(aligner_spec, save_descriptor)``. ``save_descriptor`` encodes the
        aligner's hyperparameters in the results directory name; it is ``""``
        for SAP.

    Raises:
        ValueError: If ``args.aligner_type`` is not a known aligner.
    """
    def make_noise_cfg():
        # Identical NoiseConfig shared by the two noise-based BRT aligners.
        return NoiseConfig(
            radius_mode="uniform_volume",
            epsilon_search=args.epsilon_search,
            epsilon_cap=args.epsilon_cap,
            K=args.num_steer_samples,
            norm=args.norm,
            include_base=True,
        )

    kind = args.aligner_type
    if kind == "brt_least_restrictive":
        spec = LeastRestrictiveReachableAlignerSpec(
            noise_cfg=make_noise_cfg(),
            value_model_ckpt=args.model_ckpt,
            value_hidden_dims=args.value_hidden_dims,
            value_threshold=args.value_threshold,
        )
        return spec, f"brt_least_restrictive_value_threshold_{args.value_threshold}_epsilon_search_{args.epsilon_search}_epsilon_cap_{args.epsilon_cap}"
    if kind == "brt_blend":
        spec = SmoothBlendReachableAlignerSpec(
            noise_cfg=make_noise_cfg(),
            value_model_ckpt=args.model_ckpt,
            value_hidden_dims=args.value_hidden_dims,
            value_coeff_gamma=args.value_coeff_gamma,
            value_binary_search_max_iter=args.value_binary_search_max_iter,
        )
        return spec, f"brt_blend_value_coeff_gamma_{args.value_coeff_gamma}_epsilon_search_{args.epsilon_search}_epsilon_cap_{args.epsilon_cap}_value_binary_search_max_iter_{args.value_binary_search_max_iter}"
    if kind == "brt_gradient_ascent":
        spec = GradientAscentReachableAlignerSpec(
            num_updates=args.num_recontrol_updates,
            step_size=args.recontrol_step_size,
            value_model_ckpt=args.model_ckpt,
            value_hidden_dims=args.value_hidden_dims,
            value_threshold=args.value_threshold,
        )
        return spec, f"brt_gradient_ascent_num_updates_{args.num_recontrol_updates}_step_size_{args.recontrol_step_size}"
    if kind == "sap":
        # Point the composed Hydra config at this run's model and polytope weights.
        if hasattr(sap_cfg, 'model_path') or 'model_path' in sap_cfg:
            sap_cfg.model_path = args.model_name
        if hasattr(sap_cfg, 'polytope_weight_path') or 'polytope_weight_path' in sap_cfg:
            sap_cfg.polytope_weight_path = args.model_ckpt
        # Important: the SAP aligner must run with use_safe_rep_model enabled.
        sap_cfg.use_safe_rep_model = True
        return SAPAlignerSpec(cfg=sap_cfg), ""
    if kind == "recontrol":
        spec = ReControlAlignerSpec(
            num_updates=args.num_recontrol_updates,
            step_size=args.recontrol_step_size,
            value_model_ckpt=args.model_ckpt,
            value_hidden_dims=args.value_hidden_dims,
        )
        return spec, f"recontrol_num_updates_{args.num_recontrol_updates}_step_size_{args.recontrol_step_size}"
    raise ValueError(f"Unknown aligner type: {args.aligner_type}")


def parse_args():
    """Parse the command line and assemble the run's ``AlignConfig``.

    Returns:
        ``(cfg, args)``: the ``AlignConfig`` for the run and the raw
        ``argparse.Namespace``. ``main`` still reads some flags, such as
        ``root_dir`` and ``batch_size``, directly from the namespace.

    Invalid flag combinations exit through ``argparse``'s usage error
    (``SystemExit``): a shard setting out of range, or a ``brt_*`` aligner
    without ``--brt_type``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Validate up front so bad combinations fail with a usage message instead
    # of an AssertionError or a ZeroDivisionError deep inside main().
    if args.num_shards < 1:
        parser.error(f"--num_shards must be >= 1 (got {args.num_shards})")
    if not 0 <= args.shard_idx < args.num_shards:
        parser.error(f"--shard_idx must be in [0, {args.num_shards}) (got {args.shard_idx})")
    if "brt" in args.aligner_type and args.brt_type is None:
        parser.error("--brt_type is required when --aligner_type is a brt_* aligner")

    # The Hydra config is only needed, and only loaded, for the SAP aligner.
    sap_cfg = load_hydra_config() if args.aligner_type == "sap" else None
    aligner_spec, save_descriptor = _build_aligner_spec(args, sap_cfg)

    cfg = AlignConfig(
        dataset_name=args.dataset_name,
        split_name=args.split_name,
        model_name=args.model_name,
        layer_idx=args.layer_idx,
        max_new_tokens=args.max_new_tokens,
        seed=args.seed,
        device=args.device,
        # NOTE(review): bf16 is hard-coded off, so --no_bf16 has no effect.
        # It is left as-is to keep numerics unchanged; use `not args.no_bf16`
        # if bf16 runs are wanted.
        bf16=False,
        num_shards=args.num_shards,
        shard_idx=args.shard_idx,
        aligner_spec=aligner_spec,
        save_descriptor=save_descriptor,
        sample_dataset=args.sample_dataset,
        only_unsafe=args.only_unsafe,
        calibration=args.calibration,
        brt_type=args.brt_type,
    )
    return cfg, args

def _build_aligner(cfg):
    """Instantiate the aligner implementation that matches ``cfg.aligner_spec``.

    Raises:
        ValueError: If the spec is not one of the known aligner spec types.
    """
    spec = cfg.aligner_spec
    # Checked in this order; the order only matters if spec classes subclass one another.
    if isinstance(spec, LeastRestrictiveReachableAlignerSpec):
        return backwards_reachable_aligner.LeastRestrictiveReachableAligner(cfg)
    if isinstance(spec, SmoothBlendReachableAlignerSpec):
        return backwards_reachable_aligner.SmoothBlendReachableAligner(cfg)
    if isinstance(spec, SAPAlignerSpec):
        return sap_aligner.SAPAligner(cfg)
    if isinstance(spec, ReControlAlignerSpec):
        return recontrol_aligner.ReControlAligner(cfg)
    if isinstance(spec, GradientAscentReachableAlignerSpec):
        return backwards_reachable_aligner.GradientAscentReachableAligner(cfg)
    raise ValueError(f"Unknown aligner spec type: {type(spec)}")


def _load_shard(args, cfg, processed_model_name):
    """Load this shard's prompts, baseline responses, baseline scores and is_safe flags.

    The four ``.pt`` files live under
    ``<root_dir>/<dataset>_llm_generated/<model>/``. Each one is sliced to the
    same contiguous shard range: shards have ``len // num_shards`` items each,
    and the last shard also takes the remainder.

    Returns:
        ``(prompts, baseline_responses, scores, is_safes)`` for the shard.
    """
    data_dir = f"{args.root_dir}/{cfg.dataset_name}_llm_generated/{processed_model_name}"

    prompts = torch.load(f"{data_dir}/{cfg.split_name}_prompts.pt")
    total_num_prompts = len(prompts)
    shard_len = total_num_prompts // cfg.num_shards
    start = shard_len * cfg.shard_idx
    end = shard_len * (cfg.shard_idx + 1) if cfg.shard_idx < cfg.num_shards - 1 else total_num_prompts
    prompts = prompts[start:end]

    baseline_responses = torch.load(f"{data_dir}/{cfg.split_name}_responses.pt")[start:end]

    scores_path = f"{data_dir}/{cfg.split_name}_final_scores.pt"
    if cfg.dataset_name == "beavertails":
        # beavertails scores are stored as a sequence of score chunks (presumably
        # one per generation batch; confirm) and are flattened before slicing.
        raw_scores = torch.load(scores_path)
        scores = torch.cat([torch.tensor(score) for score in raw_scores])[start:end]
    else:
        scores = torch.load(scores_path)[start:end]

    is_safes = torch.load(f"{data_dir}/{cfg.split_name}_is_safes.pt")[start:end]
    return prompts, baseline_responses, scores, is_safes


def _json_default(obj):
    """``json.dump`` fallback that converts torch tensors to Python numbers/lists.

    Guards the final save in case the aligner returns tensor-valued value
    predictions. Any other unsupported type raises ``TypeError``, as json does.
    """
    if isinstance(obj, torch.Tensor):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def main():
    """Generate aligned responses for one data shard, score them, and save the results.

    Steps:
      1. Build the aligner from the command-line config.
      2. Load this shard's prompts and baseline data, and optionally subsample them.
      3. Generate aligned responses in batches.
      4. Score the responses with an offensive-text classifier.
      5. Write the per-prompt outputs and the generation timing to JSON files
         under ``<root_dir>/alignment_results/...``.
    """
    cfg, args = parse_args()
    set_seed(cfg.seed)

    # ----------------------------
    # Aligner setup + data
    # ----------------------------
    aligner = _build_aligner(cfg)
    processed_model_name = cfg.model_name.replace('/', '_')
    prompts, baseline_responses, scores, is_safes = _load_shard(args, cfg, processed_model_name)

    # ----------------------------
    # Classifier
    # ----------------------------
    clf_model_name = "cardiffnlp/twitter-roberta-base-offensive"
    classifier_tokenizer = AutoTokenizer.from_pretrained(clf_model_name)
    classifier = AutoModelForSequenceClassification.from_pretrained(clf_model_name).eval().to(cfg.device)
    clf_device = torch.device(cfg.device)

    # ----------------------------
    # Save dir
    # ----------------------------
    sample_suffix = f"_sampled_{cfg.sample_dataset}" if cfg.sample_dataset is not None else ""
    calibration_str = "" if not cfg.calibration else "/calibration"
    method_str = args.aligner_type
    if cfg.brt_type is not None:
        method_str = cfg.brt_type + "/" + method_str
    unsafe_only = "" if not args.only_unsafe else "_only_unsafe"
    save_dir = f"{args.root_dir}/alignment_results/seed_{args.seed}/{processed_model_name}/{cfg.dataset_name}/{cfg.split_name}{unsafe_only}{sample_suffix}{calibration_str}/{method_str}/{cfg.save_descriptor}"
    print(f"Saving to {save_dir}", flush=True)
    os.makedirs(save_dir, exist_ok=True)

    # ----------------------------
    # Batching loop
    # ----------------------------
    batch_size = args.batch_size
    prompt_outputs = []
    total_previous_score, total_new_score = 0.0, 0.0
    num_unsafe, correct_unsafe = 0, 0

    if cfg.sample_dataset is not None:
        # Random subset of this shard, drawn from the torch RNG seeded via set_seed above.
        indices = torch.randperm(len(prompts))[:cfg.sample_dataset]
        prompts = [prompts[i] for i in indices]
        baseline_responses = [baseline_responses[i] for i in indices]
        scores = [scores[i] for i in indices]
        is_safes = [is_safes[i] for i in indices]

    total_time_spent_in_llm_generation = 0.0
    # Number of prompts admitted under --max_unsafe_prompts (tracked only with --only_unsafe).
    num_unsafe_so_far = 0

    for b_start in tqdm(range(0, len(prompts), batch_size)):
        b_end = min(b_start + batch_size, len(prompts))

        # With --only_unsafe, drop prompts whose is_safes flag is set.
        batch_data = [
            (p, br, sc, safe)
            for p, br, sc, safe in zip(
                prompts[b_start:b_end],
                baseline_responses[b_start:b_end],
                scores[b_start:b_end],
                is_safes[b_start:b_end],
            )
            if not (cfg.only_unsafe and bool(safe))
        ]
        if not batch_data:
            continue

        if cfg.only_unsafe and args.max_unsafe_prompts is not None:
            # Admit at most the remaining budget, and stop once it is used up.
            # The clamp matters: a negative remaining count used as a slice
            # bound would keep a prefix of the batch instead of nothing.
            remaining = args.max_unsafe_prompts - num_unsafe_so_far
            if remaining <= 0:
                break
            batch_data = batch_data[:remaining]
            num_unsafe_so_far += len(batch_data)

        batch_prompts, batch_baselines, batch_scores, _ = zip(*batch_data)

        # Generate aligned responses for the batch (only this call is timed).
        t_start = time.perf_counter()
        _, batch_responses, batch_value_preds = aligner.generate_batch_and_response(batch_prompts)
        total_time_spent_in_llm_generation += time.perf_counter() - t_start

        # Classify responses in batch
        preds, probs = classify_texts(list(batch_responses), clf_device, classifier_tokenizer, classifier)

        # Update stats + save outputs
        for prompt, response, baseline_response, score, pred, prob, vpred in zip(
            batch_prompts, batch_responses, batch_baselines, batch_scores, preds, probs, batch_value_preds
        ):
            previous_score = score.item() if isinstance(score, torch.Tensor) else score

            print("=== Prompt (unsafe) ===", flush=True)
            print(prompt, flush=True)
            print("Prompt:", prompt, flush=True)
            print("Response:", response, flush=True)
            print(f"Baseline response: {baseline_response}", flush=True)
            print(f"Baseline score: {score}", flush=True)
            print(f"New response score: {prob} (0=safe, 1=offensive)", flush=True)

            total_previous_score += previous_score
            total_new_score += prob[0]  # prob[0] = safe
            num_unsafe += 1
            correct_unsafe += (prob[0] > 0.5)

            print(f"Total previous avg score: {total_previous_score / num_unsafe:.4f}", flush=True)
            print(f"Total new avg score: {total_new_score / num_unsafe:.4f}", flush=True)
            print(f"Correctly newly identified unsafe so far: {correct_unsafe} / {num_unsafe}", flush=True)

            print("Classifier preds/probs:", pred, prob, flush=True)  # 0=safe, 1=unsafe/offensive

            prompt_outputs.append({
                "prompt": prompt,
                "response": response,
                "baseline_response": baseline_response,
                "preds_offensive": pred,
                "probs_safe": prob[0],
                "value_preds": vpred,
                "original_safe": previous_score,
                "new_safe": prob[0],
            })

    # ----------------------------
    # Save results
    # ----------------------------
    save_path = f"{save_dir}/qualitative_prompt_output_shard_{cfg.shard_idx}_of_{cfg.num_shards}_shards.json"
    with open(save_path, "w") as f:
        json.dump(prompt_outputs, f, indent=2, default=_json_default)
    print(f"Done! Saved results to {save_path}")

    time_info = {
        "total_time_spent_in_llm_generation": total_time_spent_in_llm_generation,
        "num_prompts_processed": len(prompt_outputs),
        "avg_time_per_prompt": total_time_spent_in_llm_generation / len(prompt_outputs) if prompt_outputs else 0.0,
    }
    time_info_path = f"{save_dir}/time_info_shard_{cfg.shard_idx}_of_{cfg.num_shards}_shards.json"
    with open(time_info_path, "w") as f:
        json.dump(time_info, f, indent=2)
    print(f"Final time info: {time_info}", flush=True)
    print(f"Saved time info to {time_info_path}", flush=True)

if __name__ == "__main__":
    main()
