import pandas as pd
import openai
import json
import os
import argparse
import requests

from time import sleep
from tqdm import tqdm


def make_api_call_with_retry(prompt: str, model_call_fn, client, max_retries=2, delay=1.0,
                             rate_limit_delay=240.0):
    """
    Call the model and parse its answer as JSON, retrying on failure.

    ``max_retries`` is the total number of attempts. An attempt fails when
    ``model_call_fn`` raises or when its return value is not valid JSON.

    Args:
        prompt: Fully formatted prompt forwarded to ``model_call_fn``.
        model_call_fn: Callable invoked as ``model_call_fn(prompt, client)``; its
            return value is passed to ``json.loads``.
        client: Model client forwarded unchanged to ``model_call_fn``.
        max_retries: Total number of attempts to make.
        delay: Seconds to wait before the next attempt after a JSON-parsing or
            generic failure.
        rate_limit_delay: Seconds to wait before the next attempt after a
            ``requests.exceptions.HTTPError`` (treated as a rate limit).

    No wait happens after the final attempt.

    Returns:
        The value decoded by ``json.loads`` on success, or ``None`` when every
        attempt failed.
    """
    last_exception = None

    for attempt in range(1, max_retries + 1):
        more_attempts_left = attempt < max_retries
        wait_seconds = delay
        try:
            response = model_call_fn(prompt, client)
            return json.loads(response)
        except json.JSONDecodeError as e:
            last_exception = e
            print(f"JSON parsing failed on attempt {attempt}/{max_retries}: {e}")
        except requests.exceptions.HTTPError as e:
            last_exception = e
            # Any HTTP error is treated as a rate limit and gets the long back-off.
            wait_seconds = rate_limit_delay
            print(f"HTTP error (treated as rate limit) on attempt {attempt}/{max_retries}: {e}")
        except Exception as e:
            last_exception = e
            print(f"API call failed on attempt {attempt}/{max_retries}: {type(e).__name__}: {e}")

        # Only back off when another attempt follows; sleeping after the last
        # attempt (previously up to 4 minutes) just stalls the caller.
        if more_attempts_left:
            if wait_seconds == rate_limit_delay and isinstance(last_exception, requests.exceptions.HTTPError):
                print(f"Retrying in {wait_seconds:g} seconds")
            sleep(wait_seconds)

    # All attempts failed
    print(f"All {max_retries} attempts failed. Last error: {last_exception}")
    return None

def create_default_response():
    """
    Build the placeholder result used when no usable LLM answer was obtained.

    A new dict is returned on every call, so callers may mutate it freely.
    Keys mirror the JSON fields expected from the model.
    """
    return dict(
        Reasoning="Failed to parse LLM response",
        Rating=None,
        Formatted_instruction="Failed to get response",
    )

def gemini_answer(prompt: str, client) -> str:
    """
    Send ``prompt`` through ``AI.ask`` and return the text of the answer tag.

    The text after the LAST ``<answer>`` tag, up to the first following
    ``</answer>``, is returned stripped of surrounding whitespace. Missing tags
    are tolerated: without ``<answer>`` the whole response is used, and without
    ``</answer>`` everything to the end is kept.
    """
    raw_text = AI.ask(prompt, client=client).response

    # rpartition -> text after the last opening tag (whole text if absent).
    _, _, after_open_tag = raw_text.rpartition("<answer>")
    # partition -> text before the first closing tag (whole text if absent).
    answer_body, _, _ = after_open_tag.partition("</answer>")
    return answer_body.strip()

def anthropic_answer(prompt, client):
    """
    Send ``prompt`` through ``AI.ask`` and return the text of the answer tag.

    The text after the LAST ``<answer>`` tag, up to the first following
    ``</answer>``, is returned stripped of surrounding whitespace. Missing tags
    are tolerated: without ``<answer>`` the whole response is used, and without
    ``</answer>`` everything to the end is kept.
    """
    reply = AI.ask(prompt=prompt, client=client)
    raw_text = reply.response

    # rpartition -> text after the last opening tag (whole text if absent).
    after_open_tag = raw_text.rpartition("<answer>")[2]
    # partition -> text before the first closing tag (whole text if absent).
    answer_body = after_open_tag.partition("</answer>")[0]
    return answer_body.strip()
        
def main(args):
    """
    Rate each instruction in ``args.input_csv`` with an LLM and save the results.

    For every row without an existing rating:
    - the template in ``prompts/filtering_prompt.txt`` is formatted with the
      row's ``llm`` (as ``category``) and ``instruction`` values;
    - the prompt is sent through ``make_api_call_with_retry``;
    - the parsed ``Reasoning`` / ``Rating`` fields are stored in the
      ``filter_reasoning`` / ``rating`` columns.

    Rows that already have a rating are skipped, so re-running on an earlier
    output CSV resumes the job.

    The DataFrame is written to ``args.output_csv`` every 75 processed rows and
    once more when the loop ends. The final write also happens when the loop
    stops on an exception or Ctrl-C, so completed API calls are not lost.

    Args:
        args: Parsed CLI namespace with ``input_csv``, ``output_csv`` and
            ``is_initial_csv`` attributes.
    """
    df = pd.read_csv(args.input_csv, index_col=0)

    # Result columns and the value each one starts with.
    result_columns = {"filter_reasoning": "", "rating": None, "formatted_instruction": ""}
    for column, initial_value in result_columns.items():
        # Initial categorization file: (re)create every result column.
        # Earlier output file: only add columns that are missing, so the
        # lookups below cannot fail with KeyError.
        if args.is_initial_csv or column not in df.columns:
            df[column] = initial_value

    with open("prompts/filtering_prompt.txt", "r", encoding="utf-8") as f:
        prompt_template = f.read()

    # select model calling fn and client
    model_call_fn = anthropic_answer

    client = AnthropicClient(model_name="claude-sonnet-4-20250514")
    # client = GoogleAIClient(model_name="gemini-2.5-pro", api_key=os.environ["IGEMINI_API_KEY"])
    # client = GoogleAIClient(model_name="gemini-2.5-flash", api_key=os.environ["IGEMINI_API_KEY"])

    processed_count = 0
    try:
        for i, row in tqdm(df.iterrows(), total=len(df), desc="Rating"):
            # Skip AI model call if rating already exists
            existing_rating = df.at[i, "rating"]
            if pd.notna(existing_rating) and existing_rating != "":
                print(f"Skipping row {i}: rating already exists")
                continue

            prompt = prompt_template.format(category=row.llm, instruction=row.instruction)

            # Use the retry mechanism for API call and JSON parsing
            parsed_ans = make_api_call_with_retry(prompt, model_call_fn, client)

            # If all retries failed, use default response
            if parsed_ans is None:
                parsed_ans = create_default_response()

            # Store the parsed results in separate columns
            if isinstance(parsed_ans, dict):
                df.at[i, "filter_reasoning"] = parsed_ans.get("Reasoning", "")
                df.at[i, "rating"] = parsed_ans.get("Rating", None)
                # df.at[i, "formatted_instruction"] = parsed_ans.get("Formatted_instruction", "")
            else:
                # Valid JSON that is not an object (e.g. a bare string or list)
                df.at[i, "filter_reasoning"] = "Failed to get response"
                df.at[i, "rating"] = None
                # df.at[i, "formatted_instruction"] = "Failed to get response"

            processed_count += 1

            # Save dataframe every 75 processed rows
            if processed_count % 75 == 0:
                df.to_csv(args.output_csv)
                print(f"Intermediate save: {processed_count} rows processed, saved to {args.output_csv}")

            # if you want to early stop to inspect outputs
            # if i >= 2:
            #     break
    finally:
        # Always persist results, including when the loop is interrupted by an
        # exception or KeyboardInterrupt; the exception still propagates afterwards.
        df.to_csv(args.output_csv)
        print(f"Results saved to {args.output_csv}")


if __name__ == "__main__":
    # Command-line entry point: parse the CLI flags and run the rating job.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--input_csv",
        required=True,
        type=str,
        help="Input csv file name with instructions and llm categories",
    )
    cli.add_argument(
        "--is_initial_csv",
        action="store_true",
        help="Set true if the input csv is the initial categorization file (does not have filter_reasoning and rating columns)",
    )
    cli.add_argument(
        "--output_csv",
        required=True,
        type=str,
        help="Output csv file names for the ratings",
    )
    main(cli.parse_args())