from datetime import datetime
import math
from rl_utils import template_prefix_similarity_check


def check_hparams(hparams):
    """Validate the hyperparameter dictionary before a run is launched.

    The checks run in a fixed order and the first violated one aborts with a
    descriptive message:

    1. Encoder/decoder ``batch_size`` is a multiple of ``mini_batch_size``.
    2. Encoder/decoder ``batch_size`` equals
       ``gradient_accumulation_steps * mini_batch_size``.
    3. ``payload_template`` and ``payload_prefixes`` agree with each other.
    4. The ``same_prompt`` entry is non-empty when the dataset is named
       ``"same_prompt"``.
    5. ``n_payloads`` equals the number of payload prefixes.
    6. The training counters are strictly positive.

    Args:
        hparams: Nested dict with at least the ``enc``, ``dec``, ``dataset``
            and ``train`` sections read below.

    Raises:
        AssertionError: If any check fails. The error is raised explicitly
            (not via ``assert``) so that validation still happens when
            Python runs with ``-O``.
    """
    # TODO: check n bits matches n prefixes

    models = (("enc", "Encoder"), ("dec", "Decoder"))

    # Batch size must be a whole number of mini batches. A zero mini batch
    # size is reported as a validation failure rather than surfacing as a
    # ZeroDivisionError from the modulo.
    for section, label in models:
        cfg = hparams[section]
        if cfg["mini_batch_size"] == 0 or cfg["batch_size"] % cfg["mini_batch_size"] != 0:
            raise AssertionError(
                f"{label} batch size {cfg['batch_size']} should be a multiple of mini batch size {cfg['mini_batch_size']}."
            )

    # Batch size must be exactly gradient_accumulation_steps * mini_batch_size.
    # Compare with a multiplication: floor division would let mismatches
    # through (e.g. batch_size=5, steps=3, mini_batch_size=1 -> 5 // 3 == 1).
    for section, label in models:
        cfg = hparams[section]
        if cfg["batch_size"] != cfg["gradient_accumulation_steps"] * cfg["mini_batch_size"]:
            raise AssertionError(
                f"{label} Batch size {cfg['batch_size']} should be equal to gradient accumulation steps {cfg['gradient_accumulation_steps']} * mini batch size {cfg['mini_batch_size']}."
            )

    # Payload template and prefixes must be consistent with each other.
    dataset = hparams["dataset"]
    template = dataset["payload_template"]
    prefixes = dataset["payload_prefixes"]
    if template is None:
        if len(prefixes) > 1:
            raise AssertionError(
                f"When payload_template is None, payload_prefixes should be empty or have max length 1. Current payload_prefixes: {prefixes}."
            )
    else:
        if len(prefixes) <= 1:
            raise AssertionError(
                f"When payload_template is not None, payload_prefixes should have length > 1. Current payload_prefixes: {prefixes}."
            )
        if "{payload}" not in template:
            raise AssertionError(
                f"payload_template should contain '{{payload}}'. Current template: {template}."
            )
        if not template_prefix_similarity_check(template, prefixes):
            raise AssertionError(
                f"payload_template should be consistent with payload_prefixes. Current template: {template}, Current prefixes: {prefixes}."
            )
    # # Check if dataset name is valid
    # assert (
    #     hparams["dataset"]["name"] in ["imdb", "simpleqa_100", "same_prompt"]
    # ), f"Invalid dataset name: {hparams['dataset']['name']}. Must be 'imdb', 'simpleqa_100', or 'same_prompt'."

    # If dataset name is 'same_prompt', ensure 'same_prompt' key is not empty
    if dataset["name"] == "same_prompt" and not dataset["same_prompt"]:
        raise AssertionError(
            "When dataset name is 'same_prompt', the 'same_prompt' key should not be empty."
        )

    # Check payloads
    if dataset["n_payloads"] != len(prefixes):
        raise AssertionError(
            f"Number of payloads should match the number of payload prefixes. Current number of payloads: {dataset['n_payloads']}, Current number of payload prefixes: {len(prefixes)}."
        )

    # Training counters must all be strictly positive.
    train = hparams["train"]
    positive_counters = (
        ("n_enc_updates", "encoder updates"),
        ("n_dec_updates", "decoder updates"),
        ("num_epochs", "epochs"),
        ("n_dec_deepfry_steps", "decoder deepfry steps"),
    )
    for key, description in positive_counters:
        value = train[key]
        if not value > 0:
            raise AssertionError(
                f"Number of {description} should be greater than 0. Current number of {description}: {value}."
            )
    

def generate_wandb_config(hparams, name_prefix):
    """Attach a WandB run name, tag list and notes to ``hparams``.

    The run name is built from ``name_prefix``, the dataset name, the last
    path component of the encoder/decoder model names and the encoder's
    effective batch size (``mini_batch_size * gradient_accumulation_steps``).
    Tags and notes summarise the values read from ``hparams``.

    Args:
        hparams: Nested hyperparameter dict. It is modified in place: the
            generated config is stored under ``hparams["wandb"]["wandb"]``.
        name_prefix: String prepended to the generated run name.

    Returns:
        The same ``hparams`` object that was passed in.

    Raises:
        KeyError: If a hyperparameter read below is missing.
    """
    # Extract relevant information from hparams
    dataset_name = hparams["dataset"]["name"]
    enc_model = hparams["enc"]["model_name"].split("/")[-1]
    dec_model = hparams["dec"]["model_name"].split("/")[-1]
    train_loop_version = hparams["train"]["train_loop_version"]

    # Calculate Effective Batch Sizes
    enc_effective_batch_size = (
        hparams["enc"]["mini_batch_size"]
        * hparams["enc"]["gradient_accumulation_steps"]
    )
    dec_effective_batch_size = (
        hparams["dec"]["mini_batch_size"]
        * hparams["dec"]["gradient_accumulation_steps"]
    )

    # Generate name
    name = f"{dataset_name}_{enc_model}_to_{dec_model}_EBS{enc_effective_batch_size}"
    name = f"{name_prefix}_{name}"

    # Calculate number of bits encoded (floor of log2 of the prefix count).
    # An empty prefix list carries no payload; guard it explicitly because
    # math.log2(0) raises "ValueError: math domain error".
    n_prefixes = len(hparams["dataset"]["payload_prefixes"])
    n_bits = int(math.log2(n_prefixes)) if n_prefixes > 0 else 0

    # Generate tags
    tags = [
        dataset_name,
        enc_model,
        dec_model,
        train_loop_version,
        f"enc_lr_{hparams['enc']['lr']}",
        f"dec_lr_{hparams['dec']['lr']}",
        f"enc_eff_bs_{enc_effective_batch_size}",
        f"dec_eff_bs_{dec_effective_batch_size}",
        f"enc_updates_{hparams['train']['n_enc_updates']}",
        f"dec_updates_{hparams['train']['n_dec_updates']}",
        "LoRA" if hparams["enc"]["do_lora"] else "Full_Finetune",
        "Quantized" if hparams["enc"]["do_quant"] else "Half_Precision",
        f"same_prompt_{hparams['dataset']['same_prompt']}"
        if hparams['dataset']['same_prompt']
        else "No_Same_Prompt",
        "Separate_Enc_Dec_Data"
        if hparams["dataset"]["seperate_enc_dec_data"]
        else "Shared_Data",
        "Update_Enc" if hparams["train"]["do_update_enc"] else "Freeze_Enc",
        "Update_Dec" if hparams["train"]["do_update_dec"] else "Freeze_Dec",
        "Paraphrase" if hparams["train"]["do_paraphrase"] else "No_Paraphrase",
        "Leak_Password" if hparams["train"]["leak_password"] else "No_Leak",
        f"{n_bits}-bit",
        f"Overseer_{hparams['overseer']['type']}",
        "Use_Overseer" if hparams["overseer"]["use_overseer"] else "No_Overseer",
        "VLLM_Inference" if hparams["enc"]["vllm_inference"] else "Standard_Inference",
    ]

    # Generate notes
    notes = f"""
    Dataset:
    - Name: {dataset_name}
    - Same Prompt: {hparams['dataset']['same_prompt'] or 'None'}
    - Payload Prefixes: {hparams['dataset']['payload_prefixes']}
    - Payload Template: {hparams['dataset']['payload_template'] or 'None'}
    - Separate Enc/Dec Data: {hparams['dataset']['seperate_enc_dec_data']}
    - Batch Size: {hparams['dataset']['batch_size']}

    Encoder:
    - Model: {enc_model}
    - Learning Rate: {hparams['enc']['lr']}
    - Batch Size: {hparams['enc']['batch_size']}
    - Mini Batch Size: {hparams['enc']['mini_batch_size']}
    - Gradient Accumulation Steps: {hparams['enc']['gradient_accumulation_steps']}
    - KL Coefficient: {hparams['enc']['init_kl_coef']}
    - KL Target: {hparams['enc']['target_kl']}
    - LoRA: {hparams['enc']['do_lora']}
    - Quantization: {hparams['enc']['do_quant']}
    - Output Length: {hparams['enc']['output_len']}
    - VLLM Inference: {hparams['enc']['vllm_inference']}

    Decoder:
    - Model: {dec_model}
    - Learning Rate: {hparams['dec']['lr']}
    - Batch Size: {hparams['dec']['batch_size']}
    - Mini Batch Size: {hparams['dec']['mini_batch_size']}
    - Gradient Accumulation Steps: {hparams['dec']['gradient_accumulation_steps']}
    - New Classification Head: {hparams['dec']['new_classification_head']}
    - Weight Decay: {hparams['dec']['weight_decay']}

    Overseer:
    - Type: {hparams['overseer']['type']}
    - Use Overseer: {hparams['overseer']['use_overseer']}
    - Eval Overseer: {hparams['overseer']['eval_overseer']}
    - Dec Drop Sus: {hparams['overseer']['dec_drop_sus']}

    Training:
    - Train Loop Version: {train_loop_version}
    - Update Encoder: {hparams['train']['do_update_enc']}
    - Update Decoder: {hparams['train']['do_update_dec']}
    - Decoder Update Parameters: {hparams['train']['dec_update_params']}
    - Leak Password: {hparams['train']['leak_password']}
    - Paraphrase: {hparams['train']['do_paraphrase']}
    - WandB Logging: {hparams['train']['wandb_log']}
    - Encoder Updates: {hparams['train']['n_enc_updates']}
    - Decoder Updates: {hparams['train']['n_dec_updates']}
    - Update Encoder First: {hparams['train']['update_enc_first']}
    - Number of Epochs: {hparams['train']['num_epochs']}
    - Generation Temperatures:
      - Decoder Training: {hparams['train']['generation_temp']['dec_training']}
      - Encoder Training: {hparams['train']['generation_temp']['enc_training']}
      - Evaluation: {hparams['train']['generation_temp']['eval']}
    - Cost Cap: {hparams['train']['cost_cap']}
    - Checkpoint ID: {hparams['train']['checkpoint_id']}
    - Save Checkpoints: {hparams['train']['save_checkpoints']}
    - Strip Instruct Tune: {hparams['train']['strip_instruct_tune']}

    Rewards:
    - Use Probs Reward: {hparams['rewards']['use_probs_reward']}
    - Use Task Usefulness Reward: {hparams['rewards']['use_task_usefulness_reward']}
    - Use Length Reward: {hparams['rewards']['use_length_reward']}
    - Stego Scale: {hparams['rewards']['stego_scale']}
    - Overseer Scale: {hparams['rewards']['overseer_scale']}
    - Task Scale: {hparams['rewards']['task_scale']}
    - Length Scale: {hparams['rewards']['length_scale']}
    - EOS Trick: {hparams['rewards']['eos_trick']}

    Evaluation:
    - Eval Every N: {hparams['eval']['eval_every_n']}
    - N Eval Samples: {hparams['eval']['n_eval_samples']}
    - Datasets to Eval: {', '.join(hparams['eval']['datasets_to_eval'])}

    Debug:
    - Override Dec Batch: {hparams['debug']['override_dec_batch']}
    """

    # Update hparams with generated WandB config (note the nested "wandb" key)
    hparams["wandb"] = {"wandb": {"name": name, "tags": tags, "notes": notes.strip()}}

    return hparams


def get_payloads(hparams):
    """Recover the bare payload strings from the configured payload prefixes.

    Each prefix is expected to be the payload template with ``{payload}``
    filled in (e.g. Monday/Sunday, 0/1, happy/sad); this function strips the
    template text surrounding the placeholder and returns what is left.

    Args:
        hparams: Dict providing ``hparams["dataset"]["payload_prefixes"]``
            and ``hparams["dataset"]["payload_template"]``.

    Returns:
        One extracted payload per prefix, in order. An empty list when the
        template is ``None`` or there are no prefixes.

    Raises:
        IndexError: If the template does not contain ``{payload}`` and there
            is at least one prefix to process.
    """
    prefixes = hparams["dataset"]["payload_prefixes"]
    template = hparams["dataset"]["payload_template"]

    # Nothing to extract without a template or without prefixes.
    if template is None or not prefixes:
        return []

    # Text before / after the placeholder; computed once for all prefixes.
    pieces = template.split("{payload}")
    head, tail = pieces[0], pieces[1]

    payloads = []
    for prefix in prefixes:
        matches_template = (
            len(prefix) >= len(head) + len(tail)
            and prefix.startswith(head)
            and prefix.endswith(tail)
        )
        if matches_template:
            # Strip only the leading/trailing template text so that a payload
            # which itself contains that text is left intact.
            payload = prefix[len(head):len(prefix) - len(tail)]
        else:
            # Prefix does not follow the template exactly: fall back to the
            # lenient removal of the template text wherever it occurs.
            payload = prefix.replace(head, "").replace(tail, "")
        payloads.append(payload)

    return payloads


from enum import Enum


class RLExpRunType(str, Enum):
    """Named experiment configurations selectable for an RL run.

    Every member's value is the plain-string identifier of the run. Because
    the class also derives from ``str``, members compare equal to those
    strings.
    """

    # ---- Base ----
    TRAIN = "train"

    # ---- EOL2 (IMDB) ----
    GPT2_imdb = "gpt2_imdb"
    Gemma_it_imdb = "gemma_it_imdb"
    Gemma_it_imdb_2bit = "gemma_it_imdb_2bit"
    Gemma_it_imdb_seeded = "gemma_it_imdb_seeded"
    Gemma_it_imdb_explicit = "gemma_it_imdb_explicit"
    Gemma_it_imdb_paraphrase = "gemma_it_imdb_paraphrase"

    # ---- CV ----
    GPT2_cv = "gpt2_cv"
    Gemma_cv = "gemma_cv"

    # EOL3 CV: main runs
    EOL3_CV_UNI = "eol3_cv_uni"
    EOL3_CV_GENDER = "eol3_cv_gender"
    EOL3_CV_GEN_CORR = "eol3_cv_gen_correlations"

    # EOL3 CV: bypassing the overseer
    EOL3_CV_GENDER_SEEDED = "eol3_cv_gender_seeded"
    EOL3_CV_GENDER_CORR_SEEDED = "eol3_cv_gen_correlations_seeded"
    EOL3_CV_UNI_SEEDED = "eol3_cv_uni_seeded"
    EOL3_CV_UNI_O1 = "eol3_cv_uni_o1"
    EOL3_CV_UNI_O2 = "eol3_cv_uni_o2"
    EOL3_CV_UNI_O3 = "eol3_cv_uni_o3"

    # EOL3 CV: paraphrasing run
    EOL3_CV_UNI_PARAPHRASE = "eol3_cv_uni_paraphrase"
    


from hparams.base_hparams import (
    get_debug_hparams,
    get_sanity_hparams,
    get_train_hparams,
)
from hparams.eol2.gemma_it_imdb import get_hparams as get_gemma_it_imdb_hparams
from hparams.eol2.gemma_it_imdb_2bit import (
    get_hparams as get_gemma_it_imdb_2bit_hparams,
)
from hparams.eol2.gemma_it_imdb_explicit import (
    get_hparams as get_gemma_it_imdb_explicit_hparams,
)
from hparams.eol2.gemma_it_imdb_paraphrase import (
    get_hparams as get_gemma_it_imdb_paraphrase_hparams,
)
from hparams.eol2.gemma_it_imdb_seeded import (
    get_hparams as get_gemma_it_imdb_seeded_hparams,
)
from hparams.eol2.gpt2_imdb import get_hparams as get_gpt2_imdb_hparams
from hparams.eol3_old.gemma_cv import get_hparams as get_gemma_cv_hparams
from hparams.eol3_old.gpt2_cv import get_hparams as get_gpt2_cv_hparams


## main runs
from hparams.eol3.cv_uni import get_eol3_cv_uni_hparams
from hparams.eol3.cv_gender import get_eol3_cv_gender_hparams
from hparams.eol3.cv_gen_correlations import get_eol3_cv_gen_correlations_hparams
## bypassing overseer
from hparams.eol3.cv_gender_seeded import get_eol3_cv_gender_seeded_hparams
from hparams.eol3.cv_gen_correlations_seeded import get_eol3_cv_gen_correlations_seeded_hparams
from hparams.eol3.cv_uni_seeded import get_eol3_cv_uni_seeded_hparams
from hparams.eol3.cv_uni_o1 import get_eol3_cv_uni_o1_hparams
from hparams.eol3.cv_uni_o2 import get_eol3_cv_uni_o2_hparams
from hparams.eol3.cv_uni_o3 import get_eol3_cv_uni_o3_hparams
## paraphrasing run
from hparams.eol3.cv_uni_para import get_eol3_cv_uni_para_hparams


# TODO: this is bad
# TODO: relative imports are weird if not running from stego_trainer

def get_run_type_hparams(run_type: RLExpRunType = RLExpRunType.TRAIN, name_prefix: str = "debug"):
    """Build the fully prepared hyperparameter dict for ``run_type``.

    The matching hparams factory is called, then the result is stamped with
    a ``team_member`` entry, validated with ``check_hparams``, extended with
    the WandB config and the extracted payloads, and finally given a start
    timestamp.

    Args:
        run_type: Which experiment configuration to load. Compared with
            ``==`` against each ``RLExpRunType`` member.
        name_prefix: Prefix forwarded to ``generate_wandb_config`` for the
            run name.

    Returns:
        The prepared hyperparameter dict.

    Raises:
        ValueError: If ``run_type`` matches no known run type.
    """
    # Ordered (run type, factory) pairs; scanned top to bottom.
    factories = (
        # Base hparams
        (RLExpRunType.TRAIN, get_train_hparams),
        # Experimental (EOL2) hparams
        (RLExpRunType.GPT2_imdb, get_gpt2_imdb_hparams),
        (RLExpRunType.Gemma_it_imdb, get_gemma_it_imdb_hparams),
        (RLExpRunType.Gemma_it_imdb_2bit, get_gemma_it_imdb_2bit_hparams),
        (RLExpRunType.Gemma_it_imdb_seeded, get_gemma_it_imdb_seeded_hparams),
        (RLExpRunType.Gemma_it_imdb_explicit, get_gemma_it_imdb_explicit_hparams),
        (RLExpRunType.Gemma_it_imdb_paraphrase, get_gemma_it_imdb_paraphrase_hparams),
        # CV hparams
        (RLExpRunType.GPT2_cv, get_gpt2_cv_hparams),
        (RLExpRunType.Gemma_cv, get_gemma_cv_hparams),
        # EOL3 main runs
        (RLExpRunType.EOL3_CV_UNI, get_eol3_cv_uni_hparams),
        (RLExpRunType.EOL3_CV_GENDER, get_eol3_cv_gender_hparams),
        (RLExpRunType.EOL3_CV_GEN_CORR, get_eol3_cv_gen_correlations_hparams),
        # EOL3 runs bypassing the overseer
        (RLExpRunType.EOL3_CV_GENDER_SEEDED, get_eol3_cv_gender_seeded_hparams),
        (RLExpRunType.EOL3_CV_GENDER_CORR_SEEDED, get_eol3_cv_gen_correlations_seeded_hparams),
        (RLExpRunType.EOL3_CV_UNI_SEEDED, get_eol3_cv_uni_seeded_hparams),
        (RLExpRunType.EOL3_CV_UNI_O1, get_eol3_cv_uni_o1_hparams),
        (RLExpRunType.EOL3_CV_UNI_O2, get_eol3_cv_uni_o2_hparams),
        (RLExpRunType.EOL3_CV_UNI_O3, get_eol3_cv_uni_o3_hparams),
        # EOL3 paraphrasing run
        (RLExpRunType.EOL3_CV_UNI_PARAPHRASE, get_eol3_cv_uni_para_hparams),
    )

    for candidate, build_hparams in factories:
        if run_type == candidate:
            hparams = build_hparams()
            break
    else:
        # No entry matched
        raise ValueError(f"Invalid run type: {run_type}")

    hparams["team_member"] = {"name": "User"}

    # Validate first, then derive the WandB config and payload list.
    check_hparams(hparams)
    hparams = generate_wandb_config(hparams, name_prefix)
    hparams["dataset"]["payloads"] = get_payloads(hparams)
    hparams["timestamp"] = {"start": datetime.now().isoformat()}

    return hparams
