import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from .config import config

# Load the Llama 3 8B Instruct model and tokenizer identified by `model_id`.
# NOTE: these are module-level statements, so they execute when this module
# is imported; `tokenizer` and `model` are then shared by the functions below.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # request bfloat16 weights
    device_map="auto",  # leave device placement to the library
)

def generate(prompt, headline, n_variations):
    """Generate sampled variations of ``headline`` with the local LLaMA model.

    ``prompt`` is sent as the system message and ``headline`` as the user
    message of a single chat exchange. Sampling settings ``max_new_tokens``
    and ``headline_generation_temperture`` are read from the ``data_gen``
    section of the project config.

    Args:
        prompt: Instruction text used as the system message.
        headline: Text to transform, used as the user message.
        n_variations: Number of sampled completions to produce.

    Returns:
        A list of decoded completions, one per returned sequence, with the
        prompt tokens and special tokens removed and surrounding whitespace
        stripped. An empty list when ``n_variations`` is not positive.
    """
    # Nothing was requested: skip tokenization and generation entirely rather
    # than handing a non-positive num_return_sequences to the model.
    if n_variations <= 0:
        return []

    max_new_tokens = config.get("data_gen", "max_new_tokens")
    # The config key is spelled "temperture"; it must match the config file.
    headline_generation_temperture = config.get("data_gen", "headline_generation_temperture")

    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": headline},
    ]

    # Ask for a dict so the attention mask is available alongside input_ids.
    # The pad token is set to EOS below, so the mask cannot be inferred from
    # the ids and has to be passed explicitly.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    # Stop on either the tokenizer's EOS token or the end-of-turn token.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        do_sample=True,
        temperature=headline_generation_temperture,
        top_p=0.9,
        num_return_sequences=n_variations,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.2,
    )

    # Each returned sequence starts with the prompt tokens; keep only the
    # newly generated tail before decoding.
    return [
        tokenizer.decode(sequence[prompt_length:], skip_special_tokens=True).strip()
        for sequence in outputs
    ]


def generate_alternatives(headline, n_rephrased):
    """Generate ``n_rephrased`` meaning-preserving rewordings of ``headline``.

    The system prompt is selected by the ``("dataset", "dataset")`` config
    value; supported values are ``"nifty"``, ``"bigdata22"`` and ``"imdb"``.

    Args:
        headline: Text to rephrase (headline, tweet or review).
        n_rephrased: Number of rephrasings to request from ``generate``.

    Returns:
        The list returned by ``generate`` for the selected prompt.

    Raises:
        ValueError: If the configured dataset has no rephrasing prompt.
    """
    rephrase_prompts = {
        "nifty": "Please reword this headline for me, preserving the exact semantic meaning perfectly. Your returned headline should contain the exact information with no meaning added or subtracted, but just rephrased. Please generate the headline, and return only that with no other text. Thanks:",
        "bigdata22": "Please reword this tweet for me, preserving the exact semantic meaning perfectly. Your returned tweet should contain the exact information with no meaning added or subtracted, but just rephrased. Please generate the tweet, and return only that with no other text. Thanks:",
        "imdb": "Please reword this movie review for me, preserving the exact semantic meaning perfectly. Your returned review should contain the exact information with no meaning added or subtracted, but just rephrased. Ensure the generated review is a similar length. Please generate the review, and return only that with no other text. Thanks:",
    }

    # Read the dataset name once and fail with a clear message for unknown
    # values instead of hitting an unbound local variable further down.
    dataset = config.get("dataset", "dataset")
    if dataset not in rephrase_prompts:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(rephrase_prompts)}"
        )

    return generate(rephrase_prompts[dataset], headline, n_rephrased)

def generate_ablation(headline, n_ablations):
    """Generate ``n_ablations`` related-but-different variants of ``headline``.

    The prompts ask the model to change the topic slightly while keeping the
    good/bad polarity of the original text. The system prompt is selected by
    the ``("dataset", "dataset")`` config value; supported values are
    ``"nifty"``, ``"bigdata22"`` and ``"imdb"``.

    Args:
        headline: Text to modify (headline, tweet or review).
        n_ablations: Number of ablations to request from ``generate``.

    Returns:
        The list returned by ``generate`` for the selected prompt.

    Raises:
        ValueError: If the configured dataset has no ablation prompt.
    """
    ablation_prompts = {
        "nifty": "Please modify this headline slightly, so it is about something related but different. If the headline is good news, ensure it remains good news, and if it is bad news, ensure it remains bad news. Please generate the headline, and return only that with no other text. Thanks:",
        "bigdata22": "Please modify this tweet slightly, so it is about something related but different. If the tweet is good news, ensure it remains good news, and if it is bad news, ensure it remains bad news. Please generate the tweet, and return only that with no other text. Thanks:",
        "imdb": "Please modify this movie review slightly, so it is about something related but different. If the review is good, ensure it remains good, and if it is bad, ensure it remains bad. Ensure the generated review is a similar length. Please generate the review, and return only that with no other text. Thanks:",
    }

    # Read the dataset name once and fail with a clear message for unknown
    # values instead of hitting an unbound local variable further down.
    dataset = config.get("dataset", "dataset")
    if dataset not in ablation_prompts:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(ablation_prompts)}"
        )

    return generate(ablation_prompts[dataset], headline, n_ablations)

def generate_negative(headline, n_negatives):
    """Generate ``n_negatives`` opposite-meaning rewordings of ``headline``.

    The system prompt is selected by the ``("dataset", "dataset")`` config
    value; supported values are ``"nifty"``, ``"bigdata22"`` and ``"imdb"``.

    Args:
        headline: Text to negate (headline, tweet or review).
        n_negatives: Number of negatives to request from ``generate``.

    Returns:
        The list returned by ``generate`` for the selected prompt.

    Raises:
        ValueError: If the configured dataset has no negation prompt.
    """
    negative_prompts = {
        "nifty": "Please reword this headline for me such that the information is the same except that it now is about the opposite meaning. Please generate the headline, and return only that with no other text. Thanks:",
        "bigdata22": "Please reword this tweet for me such that the information is the same except that it now is about the opposite meaning. Please generate the tweet, and return only that with no other text. Thanks:",
        "imdb": "Please reword this movie review for me such that the information is the same except that it now is about the opposite meaning. Ensure the generated review is a similar length. Please generate the review, and return only that with no other text. Thanks:",
    }

    # Read the dataset name once and fail with a clear message for unknown
    # values instead of hitting an unbound local variable further down.
    dataset = config.get("dataset", "dataset")
    if dataset not in negative_prompts:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(negative_prompts)}"
        )

    return generate(negative_prompts[dataset], headline, n_negatives)