import os
import re
import prompts
from datetime import datetime
from transformers import AutoTokenizer
import pandas as pd
import json
import requests
import argparse
from tqdm import tqdm
import random
import numpy as np
import torch

# Seed every RNG the script imports so repeated runs are reproducible.
# The same value is also forwarded to Ollama as the "seed" generation option in run().
SEED = 2025
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)
del _seed_rng

def run(prompt_selection, temp, model_name):
    system_prompt = '''Generate three persuasive counterarguments using the given structure.

- Aim to write in a way that could realistically persuade the original author, while keeping the tone respectful and well-reasoned.
- Do not use first-person language (e.g., "I", "we", "as a").
- Output must be plain text only (no markdown).
- Each counterargument must be at least 10 sentences and under 500 tokens.
- Follow the format exactly with no extra text.'''

    if prompt_selection == "basic":
        base_prompt = prompts.diverse_basic_prompt
    elif prompt_selection == "op_and_persona":
        base_prompt = prompts.diverse_op_and_persona_prompt
        persona_df = pd.read_pickle(f'./data/cluster_diff_persona.pickle')
    elif prompt_selection == "op":
        base_prompt = prompts.diverse_op_prompt
        persona_df = pd.read_pickle(f'./data/cluster_diff_persona.pickle')
    else:
        raise ValueError("Invalid selection. Please choose from: example, example_experience, persona_example, persona_example_experience.")

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    match = re.match(r"[a-zA-Z]+", model_name)
    model_prefix = match.group() if match else model_name

    output_dir = "./results/"
    filename = f'base_llm_{model_prefix}_three_{prompt_selection}_temperature_{temp}_{timestamp}.json'
    os.makedirs(os.path.dirname(output_dir), exist_ok=True)

    file_path = os.path.join(output_dir, filename)

    data_df = pd.read_pickle('../data/processed_multiple_test_data.pickle')

    data = [' '.join(conclusion) + ' ' + ' '.join(premises) for conclusion, premises in zip(data_df['conclusion'].tolist(), data_df['premises'].tolist())]
    post_id_list = data_df['post_id'].tolist()
    
    results = []
    for i, (sample, pid) in tqdm(enumerate(zip(data, post_id_list)), total=len(data)):
        if prompt_selection == "op_and_persona" or prompt_selection == "op":
            author_persona = persona_df.loc[persona_df['post_id'] == pid, 'author_persona'].values[0]
            if prompt_selection == "op_and_persona":
                personas = persona_df.loc[persona_df['post_id'] == pid, 'personas'].values[0]
                prompt = base_prompt.format(input=sample, author_persona=author_persona, persona_1=personas[0], persona_2=personas[1], persona_3=personas[2])
            else:
                prompt = base_prompt.format(input=sample, author_persona=author_persona)
        else:
            prompt = base_prompt.format(input=sample)

        wrapped_prompt = f"<think>\n{prompt}\n</think>" if "deepseek" in model_name.lower() else prompt

        response = requests.post(
            'http://localhost:11434/api/generate',
            json={
                "model": model_name,
                "prompt": wrapped_prompt,
                "system": system_prompt,
                "options": {
                    "top_p": 0.95,
                    "num_predict": 3000,
                    "temperature": temp,
                    "seed": SEED
                },
                "stream": False
            }
        )

        if response.status_code == 200 and "response" in response.json():
            output = response.json()["response"]
        else:
            print(f"[Error] {response.text}")
            continue

        results.append({
            "post_id": pid,
            "input": sample,
            "output": output
        })
        
        with open(file_path, 'w') as f:
            json.dump(results, f, indent=4)

def print_run(prompt_selection, temp, model_name):
    """Print the run configuration and the current CUDA_VISIBLE_DEVICES value."""
    visible_gpus = os.environ.get('CUDA_VISIBLE_DEVICES')
    header = f"Running with prompt: {prompt_selection}, temperature: {temp}, model: {model_name}"
    print(header)
    print(f"Using GPUs: {visible_gpus}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Run with selected prompt and temperature.')
    # `choices` rejects typos before any data is loaded and lists valid values in --help.
    parser.add_argument(
        '--prompt_selection',
        type=str,
        required=True,
        choices=['basic', 'op', 'op_and_persona'],
        help='Name of the prompt to use',
    )
    parser.add_argument('--temp', type=float, default=0.8, help='Temperature for generation (default: 0.8)')
    parser.add_argument('--gpu_num', type=str, default="0", help='Comma-separated GPU device IDs to use (default: "0")')
    parser.add_argument('--model_name', type=str, default="llama3.1:8b-instruct-q8_0", help='Model name for Ollama (e.g., llama3.1:8b or deepseek-r1:8b)')

    args = parser.parse_args()

    # Only affects this process, and only if CUDA has not been initialized yet.
    # NOTE(review): generation happens inside the Ollama server process, so this
    # probably does not control which GPUs the model runs on; confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num

    print_run(args.prompt_selection, args.temp, args.model_name)
    run(args.prompt_selection, args.temp, args.model_name)