"""
Main experiment script. Usage: python3 acr-wmdp.py
Edit config/config.yaml to change the experiment settings, or override individual keys on the
command line using Hydra syntax (e.g. `python3 acr-wmdp.py model_name=rmu`).
"""

import os
import re
import json
import hydra
import torch
import logging
import numpy as np
from omegaconf import DictConfig
from datasets import load_dataset
import prompt_optimization as prompt_opt
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_example_dataset():
    """Return a tiny hand-written dataset for smoke-testing the pipeline.

    Returns:
        A new list of dicts, each with a 'hint' (prompt prefix) and a
        'target' (the continuation to elicit), in a fixed order.
    """
    hint_target_pairs = (
        ('The cat leapt over the fence,', 'chasing after a butterfly.'),
        ('The quick brown fox', 'jumps over the lazy dog.'),
        ('The sun was shining', 'brightly in the sky.'),
    )
    return [{'hint': hint, 'target': target} for hint, target in hint_target_pairs]

def load_wmdp_dataset(config_name='wmdp-bio', style='choose'):
    """Download a WMDP subset and convert each question into a hint/target pair.

    Args:
        config_name: WMDP subset, one of 'wmdp-bio', 'wmdp-chem', 'wmdp-cyber'.
        style: prompt format:
            'choose' / 'option': hint is the question followed by the four
                lettered choices and a trailing space; target is
                'Answer: <letter>'.
            'repeat': same hint without the trailing space; target is
                'Answer: <letter>. <text of the correct choice>'.
            'generate': hint is the bare question; target is the text of
                the correct choice.

    Returns:
        A list of {'hint', 'target'} dicts covering every split returned by
        ``load_dataset``, in split order.

    Raises:
        ValueError: if config_name or style is unknown. Both checks run
            before anything is downloaded.
    """
    if config_name not in ['wmdp-bio', 'wmdp-chem', 'wmdp-cyber']:
        raise ValueError(f'Unknown config_name: {config_name}')
    if style not in ['choose', 'option', 'repeat', 'generate']:
        raise ValueError(f'Unknown style: {style}')

    # Answer index -> option letter. A dict (not a string) so an
    # out-of-range index raises KeyError instead of wrapping around.
    letter_of = dict(enumerate('ABCD'))

    def to_pair(example):
        question = example['question']
        choices = example['choices']
        if style == 'generate':
            return {'hint': question, 'target': choices[example['answer']]}
        listing = ' '.join(f'{letter_of[k]}. {choices[k]}' for k in range(4))
        hint = f'{question} {listing}'
        answer = example['answer']
        if style == 'repeat':
            return {'hint': hint, 'target': f"Answer: {letter_of[answer]}. {choices[answer]}"}
        # 'choose' and 'option' share a format that ends with a trailing space.
        return {'hint': hint + ' ', 'target': f"Answer: {letter_of[answer]}"}

    hub_dataset = load_dataset("cais/wmdp", config_name)
    return [to_pair(example) for split_rows in hub_dataset.values() for example in split_rows]

def set_random_seed(seed):
    """Seed every global RNG the experiment may draw from, for reproducible runs.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU and
    all CUDA devices). Also forces cuDNN into deterministic, non-benchmark
    mode.

    Args:
        seed: integer seed. ``None`` leaves all RNGs and cuDNN settings untouched.
    """
    if seed is None:
        return
    # Imported locally so this block stays self-contained. The stdlib RNG
    # was previously left unseeded, so any code drawing from it was
    # not reproducible.
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe without CUDA: torch silently ignores these calls on CPU-only setups.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def set_up_model(model_name='base'):
    """Load one of the Zephyr-7B checkpoints compared in this experiment.

    Args:
        model_name: checkpoint key:
            'base'      -> HuggingFaceH4/zephyr-7b-beta
            'rmu'       -> cais/Zephyr_RMU
            'npo-bio'   -> J4Q8/zephyr-npo-bio
            'npo-cyber' -> J4Q8/zephyr-npo-cyber

    Returns:
        (device, model, tokenizer). device is "cuda" when available, else "cpu".

    Raises:
        ValueError: if model_name is not one of the keys above. This is
            raised before anything is downloaded.
    """
    # Single source of truth: valid names are exactly the keys of this table.
    model_urls = {
        'base': "HuggingFaceH4/zephyr-7b-beta",
        'rmu': "cais/Zephyr_RMU",
        'npo-bio': "J4Q8/zephyr-npo-bio",
        'npo-cyber': "J4Q8/zephyr-npo-cyber",
    }
    if model_name not in model_urls:
        raise ValueError(f'Unknown model_name: {model_name}')
    model_url = model_urls[model_name]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        # NOTE(review): trust_remote_code=True lets the hub repo execute
        # code, but only on the CUDA path (the CPU path passes False).
        # Confirm these checkpoints actually need it.
        model_args = dict(trust_remote_code=True, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, device_map="auto")
    else:
        model_args = dict(trust_remote_code=False, low_cpu_mem_usage=True)
    model = AutoModelForCausalLM.from_pretrained(model_url, **model_args)
    # The base Zephyr tokenizer is used for every checkpoint. Presumably the
    # fine-tuned variants share its vocabulary; verify if a new model is added.
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    return device, model, tokenizer

def _summarize_result(result):
    """Reduce one raw optimization result to the JSON record that gets saved.

    Successful examples keep 'success' plus the free-token count under
    'free_tokens'. Failed ones keep 'success', plus 'error' when an exception
    message was captured.
    """
    if result['success']:
        return {'success': result['success'], 'free_tokens': result['num_free_tokens']}
    if 'error' in result:
        return {'success': result['success'], 'error': result['error']}
    return {'success': result['success']}


@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Run prompt minimization over a slice of WMDP and save per-example results.

    Reads from cfg: seed, dataset_name, style, test_start, test_cnt,
    model_name, chat_template, topk, num_steps, batch_size, mini_batch_size,
    discrete_optimizer and max_tokens. Writes
    results/<model>_<dataset>_<style>_<start>_<end>.json relative to the
    current working directory, creating results/ if needed.
    """
    logger = logging.getLogger('main')
    logger.info('Program started ...')

    logger.info(f'Setting random seed to {cfg.seed}')
    set_random_seed(cfg.seed)

    logger.info(f'Loading dataset {cfg.dataset_name} ...')
    test_end = cfg.test_start + cfg.test_cnt
    dataset = load_wmdp_dataset(cfg.dataset_name, cfg.style)
    dataset = dataset[cfg.test_start:test_end]

    logger.info(f'Loading model {cfg.model_name} ...')
    device, model, tokenizer = set_up_model(cfg.model_name)

    logger.info('Evaluating the model ...')
    system_prompt = ""
    chat_template = (cfg.chat_template[0], cfg.chat_template[1])
    optimization_args = {
        'topk': cfg.topk,
        'num_steps': cfg.num_steps,
        'batch_size': cfg.batch_size,
        'mini_batch_size': cfg.mini_batch_size,
        'discrete_optimizer': cfg.discrete_optimizer,
    }
    root_logger = logging.getLogger()
    # Level to restore after each optimization run. Previously the root level
    # was forced to DEBUG, which turned on debug output from every library.
    previous_level = root_logger.level
    results = []
    for i, element in enumerate(dataset):
        logger.info(f'Processing example {i+1} / {len(dataset)} ...')
        input_str = element['hint']
        target_str = element['target']
        # Silence sub-WARNING logging while the optimizer runs. The `finally`
        # restores the level even on KeyboardInterrupt.
        root_logger.setLevel(logging.WARNING)
        try:
            if cfg.style == 'choose':
                minimize = prompt_opt.minimize_prompt_choice_only
            else:
                minimize = prompt_opt.minimize_prompt
            result = minimize(model, tokenizer, input_str, target_str, system_prompt,
                              chat_template, device, optimization_args, max_tokens=cfg.max_tokens)
        except Exception as e:
            # Best-effort sweep: record the failure and continue with the next example.
            result = {'success': False, 'error': str(e)}
        finally:
            root_logger.setLevel(previous_level)
        logger.info(f'Example {i+1} / {len(dataset)} finished.')
        logger.info(f'    Success: {result["success"]}')
        if result["success"]:
            logger.info(f'    Input: {input_str}')
            logger.info(f'    Target: {target_str}')
            logger.info(f'    Solution: {tokenizer.decode(result["input_ids"])}')
            logger.info(f'    Free token length: {result["num_free_tokens"]}')
        elif 'error' in result:
            logger.warning(f'    Error: {result["error"]}')
        results.append(result)

    logger.info('Saving results ...')
    # Build the summary before opening the file, so a failure here does not
    # leave an empty or truncated results file behind.
    summaries = [_summarize_result(result) for result in results]
    os.makedirs('results', exist_ok=True)
    out_path = os.path.join(
        'results',
        f'{cfg.model_name}_{cfg.dataset_name}_{cfg.style}_{cfg.test_start}_{test_end}.json',
    )
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(summaries, f, indent=2)

if __name__ == "__main__":
    # @hydra.main wraps main(): it parses command-line overrides and
    # supplies cfg, so main() is called here without arguments.
    main()
