import pandas as pd
import numpy as np
import re
import argparse
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from vllm.outputs import RequestOutput
from typing import List
import torch
from huggingface_hub import login
from fractions import Fraction
import random
import numpy as np
import torch
from transformers import set_seed

def seed_everything(seed: int = 111):
    """Seed every random number generator this script touches.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and CUDA), then calls
    ``transformers.set_seed`` with the same value. cuDNN is switched to
    deterministic mode with auto-tuning (benchmark) disabled.

    Args:
        seed: Seed value passed to every generator.
    """
    # Configure cuDNN: deterministic kernels, no benchmark auto-tuning.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Apply the same seed to each generator in turn.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        set_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def preprocess_model_output(text):
    """Normalize raw model output before it is passed to extract_answer_number.

    Steps, in order:
      1. Strip markdown symbols (``* ` # ~``) and currency symbols.
      2. Collapse line breaks, tabs and repeated whitespace into single spaces.
      3. Remove commas that sit between two digits (thousands separators).
      4. Rewrite every percentage as its decimal value (``15%`` -> ``0.15``).
         Each percentage is converted independently.
      5. If the text contains a "decrease"-style keyword, contains an unsigned
         number, and mentions neither "increase" nor "growth", append the
         marker ``" (potential negative number)"``.

    Args:
        text: Raw generated text (str).

    Returns:
        The cleaned text with leading/trailing whitespace removed.
    """
    # remove unnecessary markdown symbols
    text = re.sub(r'[\*\`\#\~]+', '', text)

    # remove currency symbols so they do not precede numbers
    text = re.sub(r'[$€¥£]', '', text)

    # clean up line breaks and tabs
    text = re.sub(r'[\n\t\r]+', ' ', text)

    # remove commas (thousands separator)
    text = re.sub(r'(\d),(\d)', r'\1\2', text)

    # clean up multiple spaces
    text = re.sub(r'\s+', ' ', text)

    # Detect the negative-context keywords before the text is rewritten below.
    has_decrease_hint = any(hint in text.lower() for hint in [
        'decreas', 'loss', 'deficit', 'negativ', 'fell', 'decline', 'drop', 'reduc', 'below'
    ])

    # process percentages (15% -> 0.15)
    # A replacement function is used so that each percentage is converted from
    # its own value; substituting a single precomputed string would overwrite
    # every percentage in the text with the first one's value.
    text = re.sub(
        r'([-+]?\d+\.?\d*)\s*%',
        lambda m: str(float(m.group(1)) / 100),
        text,
    )

    # handle cases where context has negative hint but no explicit negative sign
    if has_decrease_hint:
        # find number without negative sign in already preprocessed text
        match = re.search(r'(?<!\-)\b\d+\.?\d*\b', text)
        lowered = text.lower()
        if match and "increase" not in lowered and "growth" not in lowered:
            # negative context but no explicit negative sign in number
            text += " (potential negative number)"

    return text.strip()


def _parse_number_token(token):
    """Convert a regex-captured numeric token to a float, or return None.

    Commas are removed and trailing periods (sentence punctuation captured by
    the permissive ``[\\d\\.,]+`` character class) are stripped. Tokens of the
    form ``a/b`` are parsed with ``Fraction``.
    """
    cleaned = token.strip().replace(',', '').rstrip('.')
    if not cleaned:
        # token consisted only of punctuation such as "," or "."
        return None
    try:
        if '/' in cleaned:
            return float(Fraction(cleaned))
        return float(cleaned)
    except (ValueError, ZeroDivisionError, OverflowError):
        return None


def extract_answer_number(completion):
    """Extract the final answer (a number or yes/no) from a model completion.

    The text is lowercased and then checked in this order; the first rule that
    yields a value wins:
      1. A number following ``**model's final answer is:**`` (fractions such as
         ``3/4`` are supported).
      2. ``yes``/``y`` or ``no``/``n`` following ``**model's final answer is:**``.
      3. A standalone ``yes``/``y`` or the word "correct" anywhere -> "yes";
         a standalone ``no``/``n`` or "incorrect" anywhere -> "no".
      4. A number following one of several answer phrases ("the answer is",
         "equals", ...), a bare number, or a number before ``%``/million/billion.
      5. The first parseable number anywhere in the text.

    The sign of the number is taken only from the number itself; it is never
    inferred from surrounding words such as "decrease" or "loss".

    Args:
        completion: Model output text (str).

    Returns:
        A float, the string "yes" or "no", or None when nothing can be extracted.
    """
    text = completion.strip().lower()  # convert to lowercase

    # pattern for final answer (specialized for financial context)
    final_answer_pattern = r"\*\*model's final answer is:\*\*\s*([-+]?[\d\.,]+(?:/[\d\.,]+)?)"
    final_match = re.search(final_answer_pattern, text, re.IGNORECASE)
    if final_match:
        value = _parse_number_token(final_match.group(1))
        if value is not None:
            return value
        # failed to convert number, try other patterns

    # check 'yes'/'no' response (check Model's Final Answer format)
    # The trailing \b keeps words that merely start with these letters
    # (e.g. "not", "yearly") from being read as an answer.
    yes_pattern = r"\*\*model's final answer is:\*\*\s*(yes|y)\b"
    no_pattern = r"\*\*model's final answer is:\*\*\s*(no|n)\b"

    if re.search(yes_pattern, text, re.IGNORECASE):
        return "yes"
    if re.search(no_pattern, text, re.IGNORECASE):
        return "no"

    # check for general 'yes'/'no' response
    # "incorrect" contains "correct", so the lookbehind excludes it from the
    # "yes" test; it is then picked up by the "no" test below.
    if re.search(r"\b(yes|y)\b", text) or re.search(r"(?<!in)correct", text):
        return "yes"
    if re.search(r"\b(no|n)\b", text) or "incorrect" in text:
        return "no"

    # patterns for identifying numbers, in priority order
    number_expressions = [
        r"the answer is[:\s]+([-+]?[\d\.,]+)",
        r"final answer[:\s]+([-+]?[\d\.,]+)",
        r"answer[:\s]+([-+]?[\d\.,]+)",
        r"result is[:\s]+([-+]?[\d\.,]+)",
        r"equals[:\s]+([-+]?[\d\.,]+)",
        r"comes to[:\s]+([-+]?[\d\.,]+)",
        r"^\s*([-+]?[\d\.,]+)\s*$",         # case where only number is present
        r"([-+]?[\d\.,]+)\s*%",             # case where number is in percentage form
        r"([-+]?[\d\.,]+)\s*million",       # case where number is in million unit
        r"([-+]?[\d\.,]+)\s*billion",       # case where number is in billion unit
        r"is\s*([-+]?[\d\.,]+)"             # case where number is after "is"
    ]

    # Try every occurrence of each pattern: the character class also matches
    # lone punctuation, so the first hit is not always a real number.
    for pattern in number_expressions:
        for match in re.finditer(pattern, text):
            value = _parse_number_token(match.group(1))
            if value is not None:
                return value

    # if no pattern is found, fall back to the first parseable number anywhere
    for match in re.finditer(r'[-+]?[\d\.,]+(?:/[\d\.,]+)?', text):
        value = _parse_number_token(match.group())
        if value is not None:
            return value

    return None



def generate_extract_prompt_with_question(
    extract_template: str,
    answer_text: str,
    tokenizer
) -> str:
    """
    Build the prompt that is passed to the 2nd (answer-extraction) model.

    - extract_template: format string containing an ``{answer_text}`` placeholder
    - answer_text: 1st model's long answer (Chain of Thought, etc.); it is
      stripped of surrounding whitespace before being inserted
    - tokenizer: object providing ``apply_chat_template``

    Returns the chat-formatted prompt with the answer prefix
    "Model's Final Answer is:" appended, so generation continues from there.
    """
    # Fill the template and wrap it as a single user turn.
    user_turn = {
        "role": "user",
        "content": extract_template.format(answer_text=answer_text.strip()),
    }

    # Render the conversation to text (not token ids) with the generation prompt.
    chat_prompt = tokenizer.apply_chat_template(
        [user_turn],
        tokenize=False,
        add_generation_prompt=True
    )

    # Append the answer prefix for the model to continue.
    return chat_prompt + "\n- **Model's Final Answer is:** "




def parse_args():
    """Parse the command-line options of this script.

    Options:
        --model: model path/name (str, no default -> None when omitted).
        --data_file: data path (str, default ``./data/gsm8k_test``).
        --tensor_parallel_size: tensor parallel size (int, default 4).

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--model", {"type": str}),
        ("--data_file", {"type": str, "default": './data/gsm8k_test'}),
        ("--tensor_parallel_size", {"type": int, "default": 4}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def _normalize_ground_truth(val):
    """Normalize one raw ``GT`` cell: yes/no stay strings, numeric strings become floats.

    Non-string values are returned unchanged. Strings are stripped and
    lowercased; commas are removed before the float conversion. A string that
    is neither yes/no nor numeric is returned in its cleaned (stripped,
    lowercased, comma-free) form.
    """
    if not isinstance(val, str):
        return val  # already a number (or NaN); keep as is
    val = val.strip().lower()
    if val in ('yes', 'no'):
        return val
    val = val.replace(',', '')
    try:
        return float(val)
    except ValueError:
        return val


def _score_prediction(gen_val, gt_val, debug_info):
    """Decide whether an extracted value matches the ground truth.

    Diagnostic fields are written into ``debug_info`` as a side effect.

    Rules:
      * Two strings are compared case-insensitively.
      * If at least one side is an int/float, both are converted with
        ``float()``; a failed conversion scores as incorrect.
      * If the two magnitudes differ by a factor of 98-102, the larger one is
        divided by 100 (percentage vs. decimal representation).
      * Tolerance: absolute difference <= 1.0 when both magnitudes are >= 1,
        otherwise absolute difference <= 0.01. The latter is stored under the
        key ``rel_error`` although it is not a relative error.
      * Values with opposite signs are still accepted when their magnitudes
        are within the tolerance (flagged with ``sign_but_correct``).
    """
    # both are strings (yes/no)
    if isinstance(gen_val, str) and isinstance(gt_val, str):
        return gen_val.lower() == gt_val.lower()

    # Neither side is a Python int/float: nothing numeric to compare.
    if not (isinstance(gen_val, (int, float)) or isinstance(gt_val, (int, float))):
        debug_info['error'] = 'Type mismatch'
        return False

    try:
        gen_float = float(gen_val)
        gt_float = float(gt_val)
    except (ValueError, TypeError):
        debug_info['error'] = 'Type conversion error'
        return False

    # check difference between percentage and decimal representation (100x difference)
    if abs(gen_float) > abs(gt_float):
        ratio = abs(gen_float / gt_float) if gt_float != 0 else float('inf')
        if 98 <= ratio <= 102:
            adjusted = gen_float / 100  # convert percentage to decimal
            debug_info['percent_conversion'] = True
            debug_info['original_gen'] = gen_float
            debug_info['adjusted_gen'] = adjusted
            gen_float = adjusted
    elif abs(gt_float) > abs(gen_float):
        ratio = abs(gt_float / gen_float) if gen_float != 0 else float('inf')
        if 98 <= ratio <= 102:
            adjusted = gt_float / 100  # convert percentage to decimal
            debug_info['percent_conversion'] = True
            debug_info['original_gt'] = gt_float
            debug_info['adjusted_gt'] = adjusted
            gt_float = adjusted

    # special case: different signs, compare magnitudes only
    if (gen_float > 0 and gt_float < 0) or (gen_float < 0 and gt_float > 0):
        abs_gen = abs(gen_float)
        abs_gt = abs(gt_float)
        debug_info['sign_issue'] = True
        debug_info['abs_gen'] = abs_gen
        debug_info['abs_gt'] = abs_gt

        diff = abs(abs_gen - abs_gt)
        if abs_gen >= 1 and abs_gt >= 1:
            key, tolerance = 'abs_diff', 1.0
        else:
            key, tolerance = 'rel_error', 0.01
        if diff <= tolerance:
            # the difference is only recorded when the sample is accepted
            debug_info[key] = diff
            debug_info['sign_but_correct'] = True
            return True
        return False

    # general comparison (same sign, or at least one value is zero)
    diff = abs(gen_float - gt_float)
    if abs(gen_float) >= 1 and abs(gt_float) >= 1:
        key, tolerance = 'abs_diff', 1.0
    else:
        key, tolerance = 'rel_error', 0.01
    debug_info[key] = diff
    return diff <= tolerance


def main():
    """Extract final answers with an LLM, score them against ``GT`` and save results.

    Workflow:
      1. Read the CSV at ``--data_file`` (columns ``GT``, ``answer`` and
         ``query`` are required) and normalize the ``GT`` column.
      2. Build one extraction prompt per ``answer`` and generate with the
         hard-coded extractor model. ``--model`` is not used to load a model;
         it only labels the printed accuracy and the debug file name.
      3. Preprocess and parse each generation, compare with the ground truth
         and print the accuracy.
      4. Write per-sample debug information to ``./debug_parsing_<model>.csv``
         and write the data frame (with the added ``prompts`` and
         ``gen_answer`` columns and the normalized ``GT``) back to
         ``--data_file``, overwriting it.
    """
    args = parse_args()

    # load CSV file and normalize the ground-truth column
    try:
        df = pd.read_csv(args.data_file)

        # Fail before the (expensive) model load if a needed column is absent.
        missing = [col for col in ('GT', 'answer', 'query') if col not in df.columns]
        if missing:
            raise KeyError(f"missing required column(s): {missing}")

        # convert GT values (separate numbers and strings)
        df['GT'] = [_normalize_ground_truth(val) for val in df['GT']]

        print(f"Loaded CSV with {len(df)} rows")
    except Exception as e:
        print(f"Error loading CSV file: {e}")
        return

    model_name = 'meta-llama/Llama-3.3-70B-Instruct'
    model = LLM(
        model_name,
        tensor_parallel_size=args.tensor_parallel_size,
        max_model_len=16392,
        trust_remote_code=True,
        gpu_memory_utilization=0.95,
        dtype='auto',
        enforce_eager=True)

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        padding_side="left")

    # use template for financial data
    template_path = './template/extract_template_fin.txt'

    with open(template_path, 'r', encoding='utf-8') as file:
        extract_template = file.read()

    lst = [
        generate_extract_prompt_with_question(extract_template, df['answer'][i], tokenizer)
        for i in range(len(df))
    ]

    df['prompts'] = lst

    sampling_params = SamplingParams(temperature=0.6, top_p=0.95, min_p=0, top_k=20, max_tokens=10)

    outputs: List[RequestOutput] = model.generate(lst, sampling_params)

    generated_texts = [output.outputs[0].text for output in outputs]

    df['gen_answer'] = generated_texts

    # per-sample records for debugging
    parsing_results = []

    count = 0
    for i in range(len(df)):
        # apply preprocessing
        gen_text = preprocess_model_output(generated_texts[i])
        gen_val = extract_answer_number(gen_text)
        gt_val = df['GT'][i]

        debug_info = {
            'index': i,
            'query': df['query'][i],
            'model_answer': gen_text,
            'extracted_value': gen_val,
            'ground_truth': gt_val,
            'is_correct': False  # default value
        }

        if gen_val is None:
            debug_info['error'] = 'Failed to extract value'
        else:
            is_correct = _score_prediction(gen_val, gt_val, debug_info)
            debug_info['is_correct'] = is_correct
            if is_correct:
                count += 1

        parsing_results.append(debug_info)

    total = len(df)
    # An empty data file yields 0.0 instead of a ZeroDivisionError.
    accuracy = count / total if total else 0.0

    print('='*80)
    print(f'ACC of {args.model}:')
    print('ACC:', accuracy)
    print('='*80)

    # save debug results
    # Path separators in the model name (e.g. "org/model") would otherwise be
    # interpreted as directories, so they are replaced with underscores.
    safe_model_name = re.sub(r'[\\/]+', '_', str(args.model))
    debug_path = f'./debug_parsing_{safe_model_name}.csv'
    debug_df = pd.DataFrame(parsing_results)
    debug_df.to_csv(debug_path, index=False)
    print(f'Debug info saved to: {debug_path}')

    # every sample that was not counted as correct is incorrect
    print(f'Total incorrect samples: {len(parsing_results) - count}')

    # save results to original data (overwrites the input file)
    df.to_csv(args.data_file, index=False)
    
    

if __name__ == "__main__":
    # Script entry point: fix all RNG seeds first, then run the evaluation.
    seed_everything(seed=666)
    main()