import pandas as pd
import requests 
from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report
import numpy as np
import random
import os
from openai import OpenAI
import google.generativeai as genai
import logging
import sys
import argparse
import math

def setup_logging(log_filename, log_level_str):
    """Configure the root logger to write to a log file and to stdout.

    Handlers already attached to the root logger are removed *and closed*
    first, so calling this function more than once neither duplicates output
    nor leaks open file handles.

    Args:
        log_filename: Path of the log file. It is opened in append mode and
            written as UTF-8 so that non-ASCII text can be logged on any
            platform.
        log_level_str: Level name such as "INFO" or "debug"
            (case-insensitive). Names that do not resolve to a numeric
            logging level fall back to INFO.

    Returns:
        ``logging.getLogger(__name__)``; its records propagate to the root
        logger configured here.
    """
    root_logger = logging.getLogger()

    # getattr() may resolve an upper-case attribute of the logging module that
    # is not a level (e.g. "BASIC_FORMAT"), so only accept integer results.
    log_level = getattr(logging, log_level_str.upper(), None)
    if not isinstance(log_level, int):
        log_level = logging.INFO
    root_logger.setLevel(log_level)

    # Detach and close old handlers; clearing the list alone would leave any
    # previously opened log file handle open.
    for old_handler in list(root_logger.handlers):
        root_logger.removeHandler(old_handler)
        old_handler.close()

    log_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    file_handler = logging.FileHandler(log_filename, mode='a', encoding='utf-8')
    file_handler.setLevel(log_level)
    file_handler.setFormatter(log_format)
    root_logger.addHandler(file_handler)

    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setLevel(log_level)
    stream_handler.setFormatter(log_format)
    root_logger.addHandler(stream_handler)

    return logging.getLogger(__name__)


def chat_with_ai(input_text, api_key, base_url, model_name, logger):
    """Send a single user prompt to an LLM backend and return the reply text.

    Backend selection:
      * If ``api_key`` starts with 'AIzaSy' and ``base_url`` does not start
        with 'http://10.' or 'http://localhost', the request goes through
        ``google.generativeai``.
      * Otherwise an ``OpenAI`` client is created against ``base_url`` with a
        300 second timeout and one chat completion is requested.

    Args:
        input_text: The full prompt, sent as a single user message.
        api_key: API key for the selected backend.
        base_url: Base URL used for the OpenAI-compatible client.
        model_name: Model identifier passed to the backend.
        logger: Logger used to report failures.

    Returns:
        The reply text. An empty string is returned when the backend supplies
        no text (``message.content`` is ``None``). If any exception occurs it
        is logged and the string ``"API Error: <exception>"`` is returned
        instead of raising.
    """
    local_endpoint_prefixes = ('http://10.', 'http://localhost')
    try:
        use_openai_client = (
            not api_key.startswith('AIzaSy')
            or base_url.startswith(local_endpoint_prefixes)
        )
        if use_openai_client:
            client = OpenAI(
                base_url=base_url,
                api_key=api_key,
                timeout=300.0,
            )
            completion = client.chat.completions.create(
                model=model_name,
                messages=[
                    {
                        "role": "user",
                        "content": input_text
                    }
                ]
            )
            response_content = completion.choices[0].message.content
        else:
            genai.configure(api_key=api_key)
            model_instance = genai.GenerativeModel(model_name)
            response = model_instance.generate_content([input_text])
            response_content = response.text

        # Callers treat the result as text (slicing, substring tests), so
        # normalise a missing reply to an empty string rather than None.
        return response_content if response_content is not None else ""
    except Exception as e:
        logger.error(f"API call failed: {e}")
        return f"API Error: {e}"


def extract_toxicity_label(output_text, logger):
    """Extract a binary toxicity label (0 or 1) from raw LLM output.

    Resolution order:
      1. The stripped output is exactly "0" or "1".
      2. The first *standalone* 0 or 1 in the text, i.e. one that is not part
         of a longer number. Plain substring tests such as ``" 1" in text``
         would also match " 100" or " 1984" and are therefore avoided.
      3. If no such digit exists, the phrases "non-toxic" / "not toxic"
         (case-insensitive) map to 0.

    Args:
        output_text: Raw text returned by the model.
        logger: Logger used for warnings and errors.

    Returns:
        0 or 1, or -1 when no valid label can be determined (including when
        ``output_text`` is not a string).
    """
    import re  # local import: `re` is not imported at module level in this file

    if not isinstance(output_text, str):
        logger.error(f"Error in extract_toxicity_label: expected str, got {type(output_text).__name__} for output: '{output_text}'")
        return -1

    stripped = output_text.strip()
    if stripped in ("0", "1"):
        return int(stripped)

    # A 0/1 that has no ASCII digit directly before or after it.
    standalone = re.search(r'(?<![0-9])([01])(?![0-9])', stripped)
    if standalone:
        return int(standalone.group(1))

    lowered = stripped.lower()
    if "non-toxic" in lowered or "not toxic" in lowered:
        return 0

    logger.warning(f"Could not extract valid toxicity label (0 or 1) from: '{output_text}'")
    return -1

def extract_emotion_index(output_text, logger):
    """Extract an emotion class index in the range 0-4 from raw LLM output.

    An explicit statement such as "emotion is 2" or "index: 3" takes
    precedence; otherwise the first digit character in the text is used if it
    lies in 0-4.

    Args:
        output_text: Raw text returned by the model.
        logger: Logger used for warnings and errors.

    Returns:
        The index (0-4), or -1 when none can be extracted. Non-string input
        (e.g. ``None``) is logged and also yields -1 rather than raising.
    """
    import re  # local import: `re` is not imported at module level in this file

    try:
        # Kept inside the try block so that a non-iterable input is reported
        # through the generic handler below instead of propagating.
        digits = ''.join(filter(str.isdigit, output_text))
        if not digits:
            logger.warning(f"No digits found to extract emotion index from: '{output_text}'")
            return -1

        match = re.search(r'(?:is|index|emotion is|emotion index is|output is|label is)\s*:?\s*([0-4])', output_text.lower())
        if match:
            return int(match.group(1))

        # Fallback: only the first digit character is considered.
        index = int(digits[0])
        if 0 <= index <= 4:
            return index
        logger.warning(f"Extracted digit '{index}' is out of range (0-4) for emotion from: '{output_text}'")
        return -1
    except ValueError:
        logger.error(f"ValueError in extract_emotion_index for output: '{output_text}'")
        return -1
    except Exception as e:
        logger.error(f"Generic error in extract_emotion_index: {e} for output: '{output_text}'")
        return -1

def extract_rating(output_text, logger):
    """Extract a star rating in the range 1-5 from raw LLM output.

    An explicit statement such as "rating is 4" wins; otherwise the first
    digit character found in the text is used when it lies in 1-5.

    Args:
        output_text: Raw text returned by the model.
        logger: Logger used for warnings and errors.

    Returns:
        The rating (1-5), or -1 when no valid rating can be extracted.
    """
    import re

    try:
        digit_chars = [ch for ch in output_text if str.isdigit(ch)]
        if not digit_chars:
            logger.warning(f"No digits found to extract rating from: '{output_text}'")
            return -1

        # Prefer an explicit "... is N" style statement over a bare digit.
        explicit = re.search(
            r'(?:is|rating is|output is|stars is|star rating is)\s*:?\s*([1-5])',
            output_text.lower(),
        )
        if explicit is not None:
            return int(explicit.group(1))

        # Otherwise fall back to the first digit character only.
        candidate = int(digit_chars[0])
        if candidate < 1 or candidate > 5:
            logger.warning(f"Extracted digit '{candidate}' is out of range (1-5) for rating from: '{output_text}'")
            return -1
        return candidate
    except ValueError:
        logger.error(f"ValueError in extract_rating for output: '{output_text}'")
        return -1
    except Exception as e:
        logger.error(f"Generic error in extract_rating: {e} for output: '{output_text}'")
        return -1


def calculate_classification_metrics(y_true, y_pred):
    """Compute overall classification metrics for a set of predictions.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, aligned with ``y_true``.

    Returns:
        A dict with the keys "Accuracy", "Precision (Weighted)",
        "Recall (Weighted)", "F1-score (Weighted)" and "Confusion Matrix".
        Precision, recall and F1 use weighted averaging with
        ``zero_division=0``.
    """
    summary = {"Accuracy": accuracy_score(y_true, y_pred)}

    weighted_scores = precision_recall_fscore_support(
        y_true, y_pred, average='weighted', zero_division=0
    )
    # The fourth element (support) is not reported.
    summary["Precision (Weighted)"] = weighted_scores[0]
    summary["Recall (Weighted)"] = weighted_scores[1]
    summary["F1-score (Weighted)"] = weighted_scores[2]

    summary["Confusion Matrix"] = confusion_matrix(y_true, y_pred)
    return summary

def get_per_class_metrics_report(y_true, y_pred, target_names, labels_order):
    """Build the per-class classification report in text and dict form.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, aligned with ``y_true``.
        target_names: Display names, passed as ``target_names``.
        labels_order: Label values, passed as ``labels``.

    Returns:
        A ``(report_str, report_dict)`` tuple produced by two calls to
        ``classification_report`` with identical settings, the second one
        with ``output_dict=True``.
    """
    shared_options = {
        "target_names": target_names,
        "labels": labels_order,
        "zero_division": 0,
    }
    text_report = classification_report(y_true, y_pred, **shared_options)
    dict_report = classification_report(y_true, y_pred, output_dict=True, **shared_options)
    return text_report, dict_report

def print_overall_classification_metrics(metrics, method_name, logger):
    """Log the overall classification metrics at INFO level.

    Args:
        metrics: Dict containing "Accuracy", "Precision (Weighted)",
            "Recall (Weighted)", "F1-score (Weighted)" and "Confusion Matrix".
        method_name: Label used in the section header.
        logger: Logger that receives the output.
    """
    logger.info(f"\n=== {method_name} - Overall Classification Metrics ===")

    # The printed label of each scalar metric is identical to its dict key.
    scalar_keys = (
        "Accuracy",
        "Precision (Weighted)",
        "Recall (Weighted)",
        "F1-score (Weighted)",
    )
    for key in scalar_keys:
        logger.info(f"  {key}: {metrics[key]:.4f}")

    matrix = metrics['Confusion Matrix']
    logger.info(f"  Confusion Matrix:\n{matrix}")

def print_per_class_classification_metrics(report_str, method_name, logger):
    """Log a per-class classification report, one INFO record per line.

    Args:
        report_str: Multi-line report text; it is split on '\\n'.
        method_name: Label used in the section header.
        logger: Logger that receives the output.
    """
    header = f"\n=== {method_name} - Per-Class Classification Metrics ==="
    logger.info(header)

    report_lines = report_str.split('\n')
    for report_line in report_lines:
        logger.info(report_line)

def calculate_regression_metrics(y_true, y_pred):
    """Compute regression-style metrics for numeric predictions.

    Args:
        y_true: Ground-truth numeric values.
        y_pred: Predicted numeric values, aligned with ``y_true``.

    Returns:
        A dict with "MSE", "RMSE", "R2_standard" and "R2_custom". The custom
        value is the standard R2 multiplied by ``sqrt(mean(y_true) / 2.05)``;
        the factor is 0.0 when ``y_true`` is empty or its sum is not positive.
    """
    mse = mean_squared_error(y_true, y_pred)
    rmse = np.sqrt(mse)
    r2 = r2_score(y_true, y_pred)

    # NOTE(review): the origin of the 2.05 divisor is not documented in this
    # file — confirm its intended meaning before changing it.
    scale = 0.0
    sample_count = len(y_true)
    if sample_count > 0 and sum(y_true) > 0:
        scaled_mean = sum(y_true) / sample_count / 2.05
        if scaled_mean >= 0:
            scale = math.sqrt(scaled_mean)
        else:
            scale = 0

    return {
        "MSE": mse,
        "RMSE": rmse,
        "R2_standard": r2,
        "R2_custom": r2 * scale,
    }

def calculate_per_class_regression_metrics(y_true, y_pred):
    """Compute the MSE separately for each distinct true label.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values, aligned with ``y_true``.

    Returns:
        A dict mapping each unique value of ``y_true`` to ``{"MSE": value}``.
        ``np.nan`` is stored if no sample is selected for a label.
    """
    true_values = np.array(y_true)
    predicted_values = np.array(y_pred)

    per_label = {}
    for label in np.unique(y_true):
        selector = true_values == label
        if np.sum(selector) > 0:
            label_mse = mean_squared_error(true_values[selector], predicted_values[selector])
        else:
            label_mse = np.nan
        per_label[label] = {"MSE": label_mse}
    return per_label

def print_regression_metrics(metrics, method_name, logger):
    """Log the overall regression metrics at INFO level.

    Args:
        metrics: Dict containing "MSE", "RMSE", "R2_standard" and "R2_custom".
        method_name: Label used in the section header.
        logger: Logger that receives the output.
    """
    logger.info(f"\n=== {method_name} - Overall Regression Metrics ===")

    # (display label, key in ``metrics``) in output order.
    rows = (
        ("MSE", "MSE"),
        ("RMSE", "RMSE"),
        ("R2 (Standard)", "R2_standard"),
        ("R2 (Custom)", "R2_custom"),
    )
    for display_name, key in rows:
        logger.info(f"  {display_name}: {metrics[key]:.4f}")

def print_per_class_regression_metrics(class_metrics, method_name, logger):
    """Log the per-label MSE values at INFO level.

    Args:
        class_metrics: Dict mapping a label to ``{"MSE": value}``.
        method_name: Label used in the section header.
        logger: Logger that receives the output.
    """
    logger.info(f"\n=== {method_name} - Per-Class MSE (Regression Metrics) ===")
    for label in class_metrics:
        mse_value = class_metrics[label]['MSE']
        if np.isnan(mse_value):
            # NaN marks a label for which no MSE could be computed.
            message = f"  Label {label}: No data found for MSE calculation."
        else:
            message = f"  Label {label}: MSE: {mse_value:.4f}"
        logger.info(message)


def parse_arguments(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, so existing callers are
            unaffected; passing a list allows programmatic use.

    Returns:
        The populated ``argparse.Namespace``. ``--task_type`` and
        ``--api_key`` are required; every other option has a default.
    """
    parser = argparse.ArgumentParser(description="LLM Few-Shot Classification/Rating Evaluation Script")

    parser.add_argument('--task_type', type=str, required=True, choices=['toxicity', 'emotion', 'rating'], help="Type of classification/rating task.")

    g_data = parser.add_argument_group('Data and Setup')
    g_data.add_argument('--toxicity_csv_path', type=str, default="dataset/ToxicComment/filtered_data.csv", help="Path to CSV for toxicity task.")
    g_data.add_argument('--emotion_train_tsv_path', type=str, default="dataset/GoEmotions/train.tsv", help="Path to training TSV for emotion task.")
    g_data.add_argument('--emotion_test_tsv_path', type=str, default="dataset/GoEmotions/test.tsv", help="Path to test TSV for emotion task.")
    g_data.add_argument('--rating_csv_path', type=str, default="dataset/AmazonReviews/filtered_reviews.csv", help="Path to CSV for rating task.")
    g_data.add_argument('--log_filename', type=str, default='llm_eval_combined.log', help="Name of the log file.")
    g_data.add_argument('--log_level', type=str, default='INFO', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help="Logging level.")
    g_data.add_argument('--random_seed', type=int, default=42, help="Random seed for reproducibility.")

    g_api = parser.add_argument_group('API and Model Configuration')
    g_api.add_argument('--api_key', type=str, required=True, help="API Key for the LLM service.")
    g_api.add_argument('--base_url', type=str, default="https://openrouter.ai/api/v1", help="Base URL for the LLM API.")
    g_api.add_argument('--model_name', type=str, default="meta-llama/llama-3.3-70b-instruct:free", help="Name of the LLM model to use.")

    g_sampling = parser.add_argument_group('Few-Shot and Test Set Sampling Parameters')
    g_sampling.add_argument('--few_shot_toxic_count_toxicity', type=int, default=3, help="Number of toxic few-shot examples for toxicity task.")
    g_sampling.add_argument('--few_shot_nontoxic_count_toxicity', type=int, default=10, help="Number of non-toxic few-shot examples for toxicity task.")
    g_sampling.add_argument('--total_test_samples_toxicity', type=int, default=520, help="Total number of test samples for toxicity task.")
    g_sampling.add_argument('--toxic_test_ratio_toxicity', type=float, default=1/26, help="Fraction of toxic samples in the toxicity test set.")

    g_sampling.add_argument('--num_few_shot_per_class_emotion', type=int, default=5, help="Number of few-shot examples per emotion class.")
    g_sampling.add_argument('--use_all_test_data_emotion', action='store_true', help="Use all available filtered test data for emotion task.")

    g_sampling.add_argument('--few_shot_per_rating_rating', type=int, default=5, help="Number of few_shot samples per rating class (1-5) for rating task.")
    # For rating task test_counts, keeping it hardcoded as it's a specific distribution.
    # Could be made a CLI arg (e.g., comma-separated string) if more flexibility is needed.

    g_eval_flags = parser.add_argument_group('Evaluation Flags')
    g_eval_flags.add_argument('--enable_regression_metrics_rating', action='store_true', help="Enable calculation of regression metrics (MSE, RMSE, R^2) for the rating task.")

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script flow:
    #   1. Parse CLI arguments, configure logging and seed the RNGs.
    #   2. Build few-shot examples, the test set and the prompt pieces for the
    #      selected task (toxicity / emotion / rating).
    #   3. Query the LLM once per test text and extract a label from each reply.
    #   4. Report metrics over the predictions for which a label was extracted.
    args = parse_arguments()

    logger = setup_logging(args.log_filename, args.log_level)

    # Relative dataset paths are resolved against the script's own directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(script_dir)
    logger.info(f"Current working directory set to: {script_dir}")

    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    logger.info(f"Random seed set to: {args.random_seed}")
    logger.info(f"Running with arguments: {args}")

    # Task-specific configuration, filled in by exactly one branch below.
    texts_to_classify_list = []
    true_labels_list = []
    few_shot_samples = []
    label_extraction_func = None
    prompt_initial_instruction = ""
    prompt_example_format_str = ""
    prompt_labels_definition_str = ""
    prompt_final_instruction = ""
    prompt_query_prefix = ""
    report_target_names = []
    report_labels_order = []


    if args.task_type == "toxicity":
        logger.info(f"Configuring for TOXICITY classification task...")
        try:
            df = pd.read_csv(args.toxicity_csv_path)
            logger.info(f"Toxicity data loaded from '{args.toxicity_csv_path}'. Initial shape: {df.shape}")
            if 'comment_text' not in df.columns or 'binary_label' not in df.columns:
                raise ValueError("Toxicity CSV must contain 'comment_text' and 'binary_label' columns.")
            df.dropna(subset=['comment_text'], inplace=True)
            df = df[df['comment_text'].str.strip() != '']
            logger.info(f"Shape after removing empty comments: {df.shape}")
        except Exception as e:
            logger.error(f"Failed to load or process toxicity data: {e}")
            sys.exit(1)

        df_toxic = df[df['binary_label'] == 1]
        df_nontoxic = df[df['binary_label'] == 0]
        # Row indices used as few-shot examples; excluded from the test set below.
        few_shot_indices = []

        temp_few_shot_toxic_count = min(args.few_shot_toxic_count_toxicity, len(df_toxic))
        if temp_few_shot_toxic_count < args.few_shot_toxic_count_toxicity:
            logger.warning(f"Not enough toxic samples for few-shot. Found {len(df_toxic)}, needed {args.few_shot_toxic_count_toxicity}. Using {temp_few_shot_toxic_count}.")
        selected_toxic_indices = df_toxic.sample(n=temp_few_shot_toxic_count, random_state=args.random_seed).index
        for idx in selected_toxic_indices:
            row = df.loc[idx]
            few_shot_samples.append({'text': row['comment_text'], 'label': int(row['binary_label'])})
            few_shot_indices.append(idx)

        temp_few_shot_nontoxic_count = min(args.few_shot_nontoxic_count_toxicity, len(df_nontoxic))
        if temp_few_shot_nontoxic_count < args.few_shot_nontoxic_count_toxicity:
             logger.warning(f"Not enough non-toxic samples for few-shot. Found {len(df_nontoxic)}, needed {args.few_shot_nontoxic_count_toxicity}. Using {temp_few_shot_nontoxic_count}.")
        selected_nontoxic_indices = df_nontoxic.sample(n=temp_few_shot_nontoxic_count, random_state=args.random_seed).index
        for idx in selected_nontoxic_indices:
            row = df.loc[idx]
            few_shot_samples.append({'text': row['comment_text'], 'label': int(row['binary_label'])})
            few_shot_indices.append(idx)

        logger.info(f"Toxicity few-shot samples created: {len(few_shot_samples)}")
        random.shuffle(few_shot_samples)

        test_samples_tuples = []
        toxic_test_count = int(args.total_test_samples_toxicity * args.toxic_test_ratio_toxicity)
        nontoxic_test_count = args.total_test_samples_toxicity - toxic_test_count
        logger.info(f"Target toxicity test set: {nontoxic_test_count} non-toxic, {toxic_test_count} toxic.")

        df_test_candidates = df[~df.index.isin(few_shot_indices)]
        df_test_toxic = df_test_candidates[df_test_candidates['binary_label'] == 1]
        df_test_nontoxic = df_test_candidates[df_test_candidates['binary_label'] == 0]

        actual_toxic_test_count = min(toxic_test_count, len(df_test_toxic))
        if actual_toxic_test_count < toxic_test_count:
            logger.warning(f"Not enough toxic samples for test set. Using all {len(df_test_toxic)} available.")
        selected_test_toxic = df_test_toxic.sample(n=actual_toxic_test_count, random_state=args.random_seed)
        for _, row in selected_test_toxic.iterrows():
            test_samples_tuples.append((row['comment_text'], int(row['binary_label'])))

        actual_nontoxic_test_count = min(nontoxic_test_count, len(df_test_nontoxic))
        if actual_nontoxic_test_count < nontoxic_test_count:
            logger.warning(f"Not enough non-toxic samples for test set. Using all {len(df_test_nontoxic)} available.")
        selected_test_nontoxic = df_test_nontoxic.sample(n=actual_nontoxic_test_count, random_state=args.random_seed)
        for _, row in selected_test_nontoxic.iterrows():
            test_samples_tuples.append((row['comment_text'], int(row['binary_label'])))

        random.shuffle(test_samples_tuples)
        texts_to_classify_list = [item[0] for item in test_samples_tuples]
        true_labels_list = [item[1] for item in test_samples_tuples]

        prompt_initial_instruction = "You are a toxicity classifier. I will provide you with a customer comment, and you need to determine if the comment is toxic.\nOutput 1 if the comment is toxic.\nOutput 0 if the comment is not toxic."
        prompt_example_format_str = "Comment: \"{text}\"\nToxicity: {label}"
        prompt_final_instruction = "Please output only the integer 0 or 1 corresponding to the toxicity, without any other content.\nYou should use your reasoning skills to make your own decisions without consulting your memory or other external sources."
        prompt_query_prefix = "Comment: \"{text_to_classify}\"\nToxicity:"
        label_extraction_func = lambda x: extract_toxicity_label(x, logger)
        report_target_names = ["Class 0 (Non-Toxic)", "Class 1 (Toxic)"]
        report_labels_order = [0, 1]

    elif args.task_type == "emotion":
        logger.info(f"Configuring for EMOTION classification task...")
        try:
            df_test_emotion_raw = pd.read_csv(args.emotion_test_tsv_path, sep='\t', header=None, on_bad_lines='skip')
            df_train_emotion_raw = pd.read_csv(args.emotion_train_tsv_path, sep='\t', header=None, on_bad_lines='skip')
            logger.info(f"Emotion data loaded: train from '{args.emotion_train_tsv_path}', test from '{args.emotion_test_tsv_path}'.")
            df_test_emotion_raw.columns = ['text', 'label', 'id']
            df_train_emotion_raw.columns = ['text', 'label', 'id']
        except Exception as e:
            logger.error(f"Failed to load emotion data: {e}")
            sys.exit(1)

        # Five target emotions; the original dataset label ids (as strings)
        # are remapped to the contiguous indices 0-4 used in the prompt.
        emotion_target_labels_text = ["disappointment", "sadness", "disapproval", "gratitude", "approval"]
        emotion_target_indices_original_str = ['9', '25', '10', '15', '4']
        emotion_target_indices_mapped_0_4 = {'9': 0, '25': 1, '10': 2, '15': 3, '4': 4}

        # Keep only single-label rows whose label is one of the target emotions.
        filtered_df_test = df_test_emotion_raw[df_test_emotion_raw['label'].apply(lambda x: len(str(x).split(',')) == 1)].copy()
        filtered_df_test.loc[:, 'label_str'] = filtered_df_test['label'].astype(str)
        filtered_df_test = filtered_df_test[filtered_df_test['label_str'].isin(emotion_target_indices_original_str)].copy()

        if args.use_all_test_data_emotion:
            texts_to_classify_list = filtered_df_test['text'].tolist()
            true_labels_original_str = filtered_df_test['label_str'].tolist()
        else:
            # Both branches currently use the full filtered test set.
            logger.warning("Sampling for emotion test set by count not implemented, using all filtered target data as per --use_all_test_data_emotion or default.")
            texts_to_classify_list = filtered_df_test['text'].tolist()
            true_labels_original_str = filtered_df_test['label_str'].tolist()

        true_labels_list = [emotion_target_indices_mapped_0_4[y_str] for y_str in true_labels_original_str if y_str in emotion_target_indices_mapped_0_4]

        if len(texts_to_classify_list) != len(true_labels_list): # Should be rare if filtering is correct
             valid_indices = [i for i, y_str in enumerate(true_labels_original_str) if y_str in emotion_target_indices_mapped_0_4]
             texts_to_classify_list = [texts_to_classify_list[i] for i in valid_indices]


        for label_index_str in emotion_target_indices_original_str:
            # Match full string for single labels in training data
            label_df = df_train_emotion_raw[df_train_emotion_raw['label'].astype(str).str.fullmatch(label_index_str)]
            if not label_df.empty:
                num_to_sample = min(args.num_few_shot_per_class_emotion, len(label_df))
                selected_samples = label_df.sample(n=num_to_sample, random_state=args.random_seed)
                for _, row in selected_samples.iterrows():
                    mapped_label_idx = emotion_target_indices_mapped_0_4[label_index_str]
                    text_label_for_prompt = emotion_target_labels_text[mapped_label_idx]
                    few_shot_samples.append({'text': row['text'], 'label': f"{mapped_label_idx}. {text_label_for_prompt}"})
            else:
                logger.warning(f"Emotion Task: No training samples found for few-shot label index: {label_index_str}")
        logger.info(f"Emotion few-shot samples created: {len(few_shot_samples)}")
        random.shuffle(few_shot_samples)

        prompt_initial_instruction = "You are an emotion classifier. I will provide you with a text, and you need to determine the main emotion expressed in the text and output the corresponding index of that emotion."
        prompt_example_format_str = "Text: \"{text}\"\nEmotion: {label}"
        prompt_labels_definition_str = "\nEmotion labels and their corresponding indices are as follows:\n\n" + \
                                   "\n".join([f"{i}. {emotion_target_labels_text[i]}" for i in range(len(emotion_target_labels_text))])
        prompt_final_instruction = "Please output only the index corresponding to the emotion, without any other content."
        prompt_query_prefix = "Text: \"{text_to_classify}\"\nEmotion Index:"
        label_extraction_func = lambda x: extract_emotion_index(x, logger)
        report_target_names = emotion_target_labels_text
        report_labels_order = list(range(len(emotion_target_labels_text)))

    elif args.task_type == "rating":
        logger.info(f"Configuring for RATING prediction task...")
        try:
            df = pd.read_csv(args.rating_csv_path)
            logger.info(f"Rating data loaded from '{args.rating_csv_path}'. Shape: {df.shape}")
            df.dropna(subset=['text', 'overall'], inplace=True) # Ensure text and overall are present
            df = df[df['text'].str.strip() != '']
            # Values that fail numeric conversion are dropped from the converted
            # series, so after index alignment they become NaN in the column and
            # are removed by the isin() filter on the next line.
            df['overall'] = pd.to_numeric(df['overall'], errors='coerce').dropna().astype(int) # Ensure overall is int
            df = df[df['overall'].isin([1,2,3,4,5])] # Keep only valid ratings 1-5
            logger.info(f"Shape after cleaning and filtering ratings: {df.shape}")
        except Exception as e:
            logger.error(f"Failed to load or process rating data: {e}")
            sys.exit(1)

        # Row indices used as few-shot examples; excluded from the test set below.
        few_shot_indices = []
        for rating_val in range(1, 6):
            rating_df = df[df['overall'] == rating_val]
            if not rating_df.empty:
                num_to_sample = min(args.few_shot_per_rating_rating, len(rating_df))
                selected_samples_indices = rating_df.sample(n=num_to_sample, random_state=args.random_seed).index
                for sample_index in selected_samples_indices:
                    row = df.loc[sample_index]
                    few_shot_samples.append({'text': row['text'], 'label': f"{int(row['overall'])}.0"})
                    few_shot_indices.append(sample_index)
            else:
                logger.warning(f"Rating Task: No samples found for few-shot rating: {rating_val}")
        logger.info(f"Rating few-shot samples created: {len(few_shot_samples)}")
        random.shuffle(few_shot_samples)

        test_samples_tuples = []
        # Fixed per-rating test set sizes (not configurable from the CLI).
        test_counts_rating = {1: 50, 2: 50, 3: 100, 4: 150, 5: 150}
        logger.info(f"Target rating test set distribution: {test_counts_rating}")

        for rating_val, count in test_counts_rating.items():
            # Exclude few-shot rows and reviews containing the word "star"/"stars".
            rating_df_candidates = df[
                (df['overall'] == rating_val) &
                (~df.index.isin(few_shot_indices)) &
                (~df['text'].str.contains(r'\bstar\b', case=False, regex=True, na=False)) &
                (~df['text'].str.contains(r'\bstars\b', case=False, regex=True, na=False))
            ]
            num_to_sample = min(count, len(rating_df_candidates))
            if num_to_sample < count:
                logger.warning(f"Not enough samples for rating {rating_val} for test set (found {len(rating_df_candidates)}, needed {count} after filters). Using {num_to_sample}.")

            if num_to_sample > 0:
                selected_test_samples = rating_df_candidates.sample(n=num_to_sample, random_state=args.random_seed)
                for _, row in selected_test_samples.iterrows():
                    test_samples_tuples.append((row['text'], int(row['overall'])))

        random.shuffle(test_samples_tuples)
        texts_to_classify_list = [item[0] for item in test_samples_tuples]
        true_labels_list = [item[1] for item in test_samples_tuples]

        prompt_initial_instruction = "You are a product rating classifier. I will provide you with a customer review text, and you need to determine the product rating (number of stars) that the customer provided in the review, and output the corresponding integer from 1-5."
        prompt_example_format_str = "Text: \"{text}\"\nRating: {label}"
        prompt_final_instruction = "Please output only the integer from 1-5 corresponding to the rating, without any other content.\nYou should use your reasoning skills to make your own decisions without consulting your memory or other external sources."
        prompt_query_prefix = "Text: \"{text_to_classify}\"\nRating:"
        label_extraction_func = lambda x: extract_rating(x, logger)
        report_target_names = [f"Rating {i}" for i in range(1, 6)]
        report_labels_order = [1, 2, 3, 4, 5]

    else:
        logger.error(f"Invalid --task_type specified: {args.task_type}. Choose 'toxicity', 'emotion', or 'rating'.")
        sys.exit(1)


    if not texts_to_classify_list:
        logger.error(f"No test samples prepared for task '{args.task_type}'. Exiting.")
        sys.exit(1)

    logger.info(f"\nTest set size for task '{args.task_type}': {len(texts_to_classify_list)}")
    if texts_to_classify_list:
        unique_labels, counts = np.unique(true_labels_list, return_counts=True)
        logger.info(f"Test set composition (true_labels overall): {dict(zip(unique_labels, counts))}")
        logger.debug(f"First 50 true_labels for test set: {true_labels_list[:50]}")


    y_pred_llm = []
    invalid_prediction_count = 0

    # The few-shot examples are the same for every query, so render them once
    # instead of on every loop iteration.
    examples_str = "\n\n".join([prompt_example_format_str.format(text=ex['text'], label=ex['label']) for ex in few_shot_samples])

    for idx, text_to_classify_item in enumerate(texts_to_classify_list):
        logger.info(f"Processing text {idx + 1}/{len(texts_to_classify_list)} for task '{args.task_type}'...")

        full_prompt = f"{prompt_initial_instruction}\n\nHere are some examples:\n\n{examples_str}\n{prompt_labels_definition_str}\n\n{prompt_final_instruction}\n\nHere is the text to be classified:\n\n{prompt_query_prefix.format(text_to_classify=text_to_classify_item)}"

        pred_llm_orig = chat_with_ai(full_prompt, args.api_key, args.base_url, args.model_name, logger)

        # chat_with_ai reports failures as an "API Error: ..." string. Such a
        # message can contain digits (for example a status code), which the
        # label extractors would otherwise interpret as a prediction. Treat a
        # failed or missing reply as an invalid prediction instead.
        if not isinstance(pred_llm_orig, str) or pred_llm_orig.startswith("API Error:"):
            logger.warning("  LLM call returned no usable reply; counting it as an invalid prediction.")
            pred_llm_extracted = -1
        else:
            pred_llm_extracted = label_extraction_func(pred_llm_orig)

        if pred_llm_extracted == -1:
            invalid_prediction_count += 1

        y_pred_llm.append(pred_llm_extracted)

        logger.info(f"  Text Snippet: {text_to_classify_item[:100]}...")
        logger.info(f"  Original LLM output: {pred_llm_orig}")
        logger.info(f"  Predicted label index: {pred_llm_extracted}")
        logger.info(f"  Actual label index: {true_labels_list[idx]}\n")

    # Metrics are computed only over items for which a label was extracted.
    y_pred_llm_filtered = []
    y_true_filtered = []
    for true_val, pred_val in zip(true_labels_list, y_pred_llm):
        if pred_val != -1:
            y_true_filtered.append(true_val)
            y_pred_llm_filtered.append(pred_val)

    logger.info(f"\n--- LLM Prediction Method Results for Task: {args.task_type} ---")
    logger.info(f"Model Used: {args.model_name}")
    logger.info(f"Total test items processed: {len(texts_to_classify_list)}")
    logger.info(f"Total LLM calls made: {len(y_pred_llm)}")
    logger.info(f"Valid predictions (where a label could be extracted): {len(y_pred_llm_filtered)}")
    logger.info(f"Invalid predictions (filtered out where label was -1): {invalid_prediction_count}")


    if len(y_pred_llm_filtered) > 1:
        method_full_name = f"LLM ({args.model_name}) Task '{args.task_type}'"

        overall_clf_metrics = calculate_classification_metrics(y_true_filtered, y_pred_llm_filtered)
        print_overall_classification_metrics(overall_clf_metrics, method_full_name, logger)

        per_class_clf_report_str, _ = get_per_class_metrics_report(y_true_filtered, y_pred_llm_filtered, target_names=report_target_names, labels_order=report_labels_order)
        print_per_class_classification_metrics(per_class_clf_report_str, method_full_name, logger)

        if args.task_type == "rating" and args.enable_regression_metrics_rating:
            logger.info(f"\nCalculating REGRESSION metrics for RATING task...")
            overall_regr_metrics = calculate_regression_metrics(y_true_filtered, y_pred_llm_filtered)
            print_regression_metrics(overall_regr_metrics, method_full_name, logger)

            per_class_regr_metrics = calculate_per_class_regression_metrics(y_true_filtered, y_pred_llm_filtered)
            print_per_class_regression_metrics(per_class_regr_metrics, method_full_name, logger)

    elif len(y_pred_llm_filtered) == 1 :
        logger.warning("Only one valid prediction. Basic accuracy calculation:")
        accuracy = accuracy_score(y_true_filtered, y_pred_llm_filtered)
        logger.info(f"  Accuracy on the single valid prediction: {accuracy:.4f}")
    else:
        logger.error("Unable to calculate any metrics due to no valid predictions.")

    logger.info(f"\nLLM Few-Shot Evaluation for task '{args.task_type}' finished.")