
from transformers import Gemma3ForConditionalGeneration, Gemma3ForCausalLM, AutoModelForCausalLM, AutoProcessor, AutoTokenizer
import torch
import re
import json

def init_agent_wikiarts(base_model: str, adapter_path: str = None):
    '''
    Load a model in eval mode together with its processor.

    Checkpoints whose name contains 'gemma' are loaded through
    Gemma3ForConditionalGeneration; every other checkpoint goes through
    AutoModelForCausalLM. Both use eager attention, automatic device
    placement, and bfloat16 when CUDA reports support for it (float16
    otherwise).

    Args:
        base_model: Model name or path passed to from_pretrained.
        adapter_path: Accepted but not used by this function.

    Returns:
        tuple: (model, processor)
    '''
    # Only the loader class depends on the checkpoint family; the keyword
    # arguments are the same for both.
    if 'gemma' in base_model:
        loader_cls = Gemma3ForConditionalGeneration
    else:
        loader_cls = AutoModelForCausalLM

    half_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    load_kwargs = {
        "device_map": "auto",
        "attn_implementation": "eager",
        "torch_dtype": half_dtype,
    }

    loaded = loader_cls.from_pretrained(base_model, **load_kwargs)
    model = loaded.eval()
    processor = AutoProcessor.from_pretrained(base_model)
    return model, processor

def init_agent(base_model: str, adapter_path: str = None):
    """
    Load a text-only causal language model and its tokenizer.

    Checkpoints whose name contains 'gemma' are loaded through
    Gemma3ForCausalLM, all others through AutoModelForCausalLM. The model is
    returned as loaded (this function does not call .eval()).

    Args:
        base_model: Model name or path passed to from_pretrained.
        adapter_path: Accepted but not used by this function.

    Returns:
        tuple: (model, tokenizer)
    """
    if 'gemma' in base_model:
        loader_cls = Gemma3ForCausalLM
    else:
        loader_cls = AutoModelForCausalLM

    half_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    model = loader_cls.from_pretrained(
        base_model,
        attn_implementation="eager",
        torch_dtype=half_dtype,
        device_map="auto",
    )

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    # Padding needs a pad token; reuse EOS when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

def extract_agent_answer(output_text):
    """Extract and clean the agent's answer from a single model output string.

    The text after the last "<start_of_turn>model" marker and before the
    following "<end_of_turn>" marker is taken as the answer. Several formats
    are then tried in order: a one-entry JSON object keyed by an uppercase
    letter, "Letter. Text", a text-only object, a truncated JSON object, a
    bare "Letter: text" pair, and a JSON pair missing its opening quote.

    The JSON step first parses strictly. Only if that fails does it retry
    with single quotes converted to double quotes, so Python-style dicts are
    accepted without corrupting apostrophes inside valid JSON.

    Args:
        output_text (str): Decoded model output.

    Returns:
        tuple: (agent_answer, answer_key, answer_text). agent_answer is the
        stripped model turn. answer_key and/or answer_text are None when they
        could not be recovered.
    """
    agent_answer = output_text.split("<start_of_turn>model")[-1].split("<end_of_turn>")[0].strip()
    # Drop Markdown bold markers wrapping the whole answer.
    cleaned_data = re.sub(r'^\s*\*\*\s*', '', agent_answer)
    cleaned_data = re.sub(r'\s*\*\*\s*$', '', cleaned_data)

    match_json = re.search(r'\{.*?\}', cleaned_data)
    if match_json:
        json_like = match_json.group()
        # Tolerate a trailing comma before the closing brace.
        json_like_cleaned = re.sub(r',\s*\}', '}', json_like)
        parsed = None
        try:
            parsed = json.loads(json_like_cleaned)
        except json.JSONDecodeError as e:
            # Retry for Python-style dicts; done only after strict parsing
            # fails so valid JSON containing apostrophes is left untouched.
            try:
                parsed = json.loads(json_like_cleaned.replace("'", '"'))
            except json.JSONDecodeError:
                print(f"Error parsing JSON from extracted string: {e}")
                print(f"Extracted string: {json_like}")

        if isinstance(parsed, dict) and len(parsed) == 1:
            answer_key, answer_text = next(iter(parsed.items()))
            if re.match(r'^[A-Z]$', answer_key):
                return agent_answer, answer_key, answer_text
            print(f"Parsed JSON key '{answer_key}' is not a single uppercase letter.")

    # "A. Some answer" at the start of any line.
    match_pattern = re.search(r'^\s*([A-Z])\s*\.\s*(.*)', cleaned_data, re.MULTILINE)
    if match_pattern:
        answer_key = match_pattern.group(1).strip()
        answer_text = match_pattern.group(2).strip()

        return agent_answer, answer_key, answer_text

    # {"Some answer"}: text without a choice letter.
    match_text_only_json = re.search(r'^\s*\{\s*"(.*?)"\s*\}\s*$', cleaned_data)
    if match_text_only_json:
        answer_text = match_text_only_json.group(1).strip()
        print(f"Found text-only JSON pattern: '{answer_text}'")
        return agent_answer, None, answer_text

    # {"A": "Some answer  (closing quote and brace missing).
    broken_json_match = re.search(r'\{\s*"([A-Z])"\s*:\s*"([^"]+)$', cleaned_data)
    if broken_json_match:
        try:
            answer_key = broken_json_match.group(1)
            partial_text = broken_json_match.group(2)
            fixed_json = f'{{"{answer_key}":"{partial_text}"}}'
            parsed = json.loads(fixed_json)
            if isinstance(parsed, dict) and len(parsed) == 1:
                answer_text = parsed[answer_key]
                return agent_answer, answer_key, answer_text
        except json.JSONDecodeError as e:
            print(f"Failed to fix broken JSON: {e}")

    # A: "Some answer"  (no braces, key unquoted).
    key_value_match = re.search(r'^([A-Z])\s*:\s*"(.+?)"?$', cleaned_data)
    if key_value_match:
        answer_key = key_value_match.group(1)
        answer_text = key_value_match.group(2).strip('"')
        print(f"Found key-value pattern: key={answer_key}, text={answer_text}")
        return agent_answer, answer_key, answer_text

    # A": "Some answer  (opening brace and quote missing).
    incomplete_json_match = re.search(r'^([A-Z])\s*"\s*:\s*"(.+)$', cleaned_data)
    if incomplete_json_match:
        answer_key = incomplete_json_match.group(1)
        answer_text = incomplete_json_match.group(2).strip()
        print(f"Found incomplete JSON pattern: key={answer_key}, text={answer_text}")
        return agent_answer, answer_key, answer_text

    print(f"Could not find JSON, 'Letter. Text', key-value, or '{{\"Text only\"}}' pattern in agent answer: {cleaned_data}")
    return agent_answer, None, None

def _repair_mojibake(text):
    """Undo UTF-8 text that was mis-decoded as Latin-1 or cp1252, if possible."""
    for codec in ('latin1', 'cp1252'):
        try:
            return text.encode(codec).decode('utf-8')
        except UnicodeError:
            # Covers both failure modes: characters the codec cannot encode
            # and bytes that are not valid UTF-8. Try the next codec.
            continue
    return text


def _normalize_smart_quotes(raw):
    """Replace typographic quotes (U+201C/D, U+2018/9) with ASCII quotes."""
    return (
        raw.replace('\u201c', '"')
        .replace('\u201d', '"')
        .replace('\u2018', "'")
        .replace('\u2019', "'")
    )


def _loads_lenient(raw):
    """Parse raw as JSON, trying progressively more forgiving rewrites.

    Raises the json.JSONDecodeError of the first attempt if every attempt fails.
    """
    normalized = _normalize_smart_quotes(raw)
    candidates = (
        # As-is first, so valid escapes and quotes inside strings survive.
        raw,
        # Doubled backslashes rescue invalid escapes such as Windows paths.
        raw.replace('\\', '\\\\'),
        # Smart quotes used as JSON delimiters.
        normalized,
        normalized.replace('\\', '\\\\'),
    )
    first_error = None
    for candidate in candidates:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError as err:
            if first_error is None:
                first_error = err
    raise first_error


def extract_json_from_text(text):
    """
    Extracts the first JSON block enclosed in ```json ... ``` from a text string.
    Falls back to parsing the whole text as JSON when no fenced block exists.

    Before parsing, text that looks like UTF-8 mis-decoded as Latin-1/cp1252
    is repaired. Parsing is attempted on the string as-is first; only when
    that fails are backslashes doubled and/or smart quotes converted to ASCII
    quotes. Valid JSON escapes are therefore decoded normally.

    Args:
        text (str): The input text containing the JSON block.

    Returns:
        tuple: (parsed_json, emotions, explanation) on success, where
               emotions and explanation are the values of the "emotions" and
               "explanation" keys (None when absent).
               (json_block_string, None, None) if a fenced block is found but
               cannot be parsed.
               (text, None, None) if no fenced block is found and the text is
               not a JSON object.
    """
    text = _repair_mojibake(text)

    match = re.search(r'```json\s*(\{.*?\})\s*```', text, re.DOTALL | re.IGNORECASE)
    if match:
        json_string = match.group(1)
        try:
            parsed_json = _loads_lenient(json_string)
        except json.JSONDecodeError as e:
            print(f"Error: Found JSON block but failed to parse: {e}")
            print(f"Original JSON string: {json_string}")
            return json_string, None, None
        if not isinstance(parsed_json, dict):
            print("Error: Found JSON block but it is not a JSON object.")
            return json_string, None, None
        return parsed_json, parsed_json.get("emotions"), parsed_json.get("explanation")

    try:
        parsed_json = _loads_lenient(text.strip())
    except json.JSONDecodeError:
        if text and not text.isspace():
            print("Error: No JSON block found and the text itself is not valid JSON.")
        return text, None, None

    # A bare list/number/string is valid JSON but has no keys to read.
    if not isinstance(parsed_json, dict):
        print("Error: No JSON block found and the text is not a JSON object.")
        return text, None, None

    print("Info: No ```json block, but parsed the entire text as JSON.")
    return parsed_json, parsed_json.get("emotions"), parsed_json.get("explanation")

def get_agent_responses(model, tokenizer, questions_dict, pd_questions, context, batch_size=64):
    """
    Queries the agent model for responses to a given set of questions with context using batch processing.

    Each generated answer is matched against the question's options: first
    by exact answer text, then by the choice letter. A question whose answer
    or metadata cannot be resolved is still recorded, with None in the
    unresolved fields, so one bad row does not abort the run.

    Args:
        model: The language model.
        tokenizer: The tokenizer for the model.
        questions_dict: A dictionary where keys are question IDs and values are question texts.
        pd_questions: DataFrame containing question metadata; the 'key',
            'references' (options) and 'option_ordinal' columns are read.
        context: The context passed to build_messages for each question. When
            it is a list, outputs are parsed with extract_json_from_text;
            otherwise with extract_agent_answer.
        batch_size: Number of questions to process in parallel (default: 64)

    Returns:
        list: One tuple per question,
        (qid, question, answer_key, answer_text, answer_ordinal, raw_answer,
        max_val, min_val). max_val/min_val are None when the question's
        metadata could not be read.
    """
    results = []
    question_ids = list(questions_dict.keys())

    # Process questions in batches
    for start in range(0, len(question_ids), batch_size):
        batch_ids = question_ids[start:start + batch_size]
        batch_questions = [questions_dict[qid] for qid in batch_ids]

        messages_batch = [build_messages(q, context) for q in batch_questions]
        prompts_batch = [tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True) for m in messages_batch]
        inputs_batch = tokenizer(prompts_batch, return_tensors="pt", padding=True).to(model.device)
        with torch.inference_mode():
            outputs = model.generate(
                **inputs_batch,
                do_sample=True,
                temperature=1.0,
                pad_token_id=tokenizer.pad_token_id,
                max_new_tokens=1024
            )

        output_texts = [tokenizer.decode(output) for output in outputs]
        if isinstance(context, list):
            batch_outputs = [extract_json_from_text(text) for text in output_texts]
        else:
            batch_outputs = [extract_agent_answer(text) for text in output_texts]

        # Process results for each question in batch
        for qid, (raw_answer, answer_key, answer_text) in zip(batch_ids, batch_outputs):
            question = questions_dict[qid]
            # Stay None if the metadata lookup below fails.
            max_val = None
            min_val = None

            try:
                question_row = pd_questions[pd_questions['key'] == qid]
                # NOTE(security): eval executes whatever the cell contains.
                # If these cells are plain Python literals, ast.literal_eval
                # would be the safe replacement -- confirm before swapping.
                options = eval(question_row['references'].values[0])
                option_values = eval(question_row['option_ordinal'].values[0])
                max_val = max(option_values)
                min_val = min(option_values)

                if answer_key and answer_text:
                    try:
                        # Trust the answer text over the letter when the text
                        # matches an option exactly.
                        answer_idx = options.index(answer_text)
                        derived_key = chr(65 + answer_idx)
                        if derived_key != answer_key:
                            answer_key = derived_key
                        answer_ordinal = option_values[answer_idx]
                        results.append((qid, question, answer_key, answer_text, answer_ordinal, raw_answer, max_val, min_val))
                    except ValueError:
                        # Text is not an option; fall back to the letter if it
                        # is within range.
                        if answer_key and 'A' <= answer_key <= chr(65 + len(options) - 1):
                            answer_idx = ord(answer_key) - 65
                            answer_text_from_key = options[answer_idx]
                            answer_ordinal = option_values[answer_idx]
                            results.append((qid, question, answer_key, answer_text_from_key, answer_ordinal, raw_answer, max_val, min_val))
                        else:
                            print(f"Error QID {qid}: Invalid key '{answer_key}' or text '{answer_text}' not in options. Storing None.")
                            results.append((qid, question, None, None, None, raw_answer, max_val, min_val))
                else:
                    print(f"Error QID {qid}: Could not extract answer key/text. Raw: '{raw_answer}'")
                    results.append((qid, question, None, None, None, raw_answer, max_val, min_val))

            except Exception as e:
                print(f"Error processing result for question {qid}: {e}")
                results.append((qid, question, None, None, None, raw_answer, max_val, min_val))

    return results

def get_agent_responses_wikiarts(model, tokenizer, questions_dict, pd_questions, context, batch_size=4):
    """
    Queries the model for each question in batches and parses JSON answers.

    Every question yields exactly one result row. If generation fails for a
    whole batch, or parsing fails for one item, the row is still recorded
    with None in the fields that could not be produced.

    Args:
        model: The model used for generation.
        tokenizer: Object providing apply_chat_template and batch_decode, and
            a nested .tokenizer whose pad_token_id is passed to generate.
        questions_dict: Maps question IDs to the value given to build_messages
            as the question.
        pd_questions: Not used by this function.
        context: The context passed to build_messages for each question.
        batch_size: Number of questions generated together (default: 4).

    Returns:
        list: Tuples (qid, question, raw_or_parsed_json, emotion, explanation).
    """
    results = []
    question_ids = list(questions_dict.keys())

    # Process questions in batches
    for start in range(0, len(question_ids), batch_size):
        batch_ids = question_ids[start:start + batch_size]
        batch_questions = [questions_dict[qid] for qid in batch_ids]

        batch_messages_list = [build_messages(q, context) for q in batch_questions]

        try:
            inputs = tokenizer.apply_chat_template(
                batch_messages_list,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
                padding=True
            ).to(model.device)

            input_len = inputs["input_ids"].shape[-1]

            with torch.inference_mode():
                generation = model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=1.0,
                    pad_token_id=tokenizer.tokenizer.pad_token_id
                )

            # Keep only the newly generated tokens, dropping the prompt.
            generated_tokens = generation[:, input_len:]

            # Batch decode
            decoded_batch = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        except Exception as e:
            print(f"Error processing batch: {e}")
            # Record the failed questions so they do not silently vanish from
            # the results; same placeholder shape as a per-item failure.
            results.extend(
                (qid, question, None, None, None)
                for qid, question in zip(batch_ids, batch_questions)
            )
            continue

        # Process results for the batch
        for qid, question, decoded_text in zip(batch_ids, batch_questions, decoded_batch):
            try:
                raw_or_parsed_json, emotion, explanation = extract_json_from_text(decoded_text)
                results.append((qid, question, raw_or_parsed_json, emotion, explanation))
            except Exception as e:
                print(f"Error processing decoded text: {e}")
                results.append((qid, question, None, None, None))

    return results

def build_messages(prompt_question, context):
    """
    Build a single-turn user message for the chat template.

    Args:
        prompt_question: The image reference (list context) or the question
            text (string context).
        context: Either a list of content parts, to which a painting prompt
            and the image are appended, or a string that is prepended to the
            question.

    Returns:
        list: A one-element list holding the user message, or None when
        context is neither a list nor a string.
    """
    if isinstance(context, list):
        # Build a new list so the caller's context is not mutated.
        parts = [
            *context,
            {"type": "text", "text": "Painting:\n"},
            {"type": "image", "image": prompt_question},
            {"type": "text", "text": "Response: "},
        ]
        return [{"role": "user", "content": parts}]

    if isinstance(context, str):
        answer_format = """\n\nAlways select one answer from the provided options. Return your answer in the following format, including both the choice letter and the full text of the selected option: {"<choice_letter>": "<full_answer_text>"}"""
        full_text = "".join([f"{context}", f"Question: {prompt_question}", answer_format])
        text_part = {"type": "text", "text": full_text}
        return [{"role": "user", "content": [text_part]}]

    return None