import time
from tqdm import tqdm
import csv
import logging
import os, sys
import torch
from typing import Tuple, List, Any

from scripts.utils import (
    convert_responses_to_binary as convert_responses_to_binary, 
    extend_results as extend_results,
    metric_performances as metric_performances,
    format_elapsed_time as format_elapsed_time,
    construct_tensor_path as construct_tensor_path,
    store_tensor as store_tensor
    )

import PIL.Image
import google.generativeai as genai

def _load_img_label_pairs(csv_file_path: str) -> List[Tuple[str, str, int]]:
    """
    Loads image-category-label pairs from a CSV file.

    The first row is treated as a header and skipped. Rows with fewer than
    three columns are ignored. The label column is mapped to 1 when it equals
    'YES' (case-insensitive, surrounding whitespace ignored) and to 0 otherwise.

    Args:
        csv_file_path (str): Path to the CSV file containing image-category-label pairs.

    Returns:
        List[Tuple[str, str, int]]: One ``(img_id, category, label)`` tuple per
        well-formed data row, in file order. Empty when the file has no data rows.
    """
    with open(csv_file_path, 'r', newline='', encoding='utf-8') as csvfile:
        reader = csv.reader(csvfile)
        # Skip the header; the default keeps an empty file from raising StopIteration.
        next(reader, None)
        return [
            # Strip the label so a cell such as " yes" is not silently counted as negative.
            (row[0], row[1], 1 if row[2].strip().upper() == 'YES' else 0)
            for row in reader
            if len(row) >= 3  # Skip malformed rows
        ]

def process_data(args: Any, img_id: str, category: str) -> Tuple[PIL.Image.Image, str, str]:
    """
    Opens one image and builds the presence/absence questions for a category.

    The image path is the directory stored under ``args.DATASET_PATH[args.dataset]``
    joined with ``img_id``. Each question is the matching template from ``args``
    with every ``[class]`` placeholder replaced by ``category``.

    Args:
        args (Any): Configuration object; this function reads ``DATASET_PATH``,
            ``dataset``, ``presence_question_template`` and
            ``absence_question_template`` from it.
        img_id (str): File name of the image, relative to the dataset directory.
        category (str): Category name substituted for ``[class]`` in both templates.

    Returns:
        Tuple[PIL.Image.Image, str, str]:
            - The image as returned by ``PIL.Image.open``.
            - The presence question for ``category``.
            - The absence question for ``category``.

    Example:
        >>> args = Namespace(DATASET_PATH={"dataset": "/path/to/dataset"},
        ...                 dataset="dataset",
        ...                 presence_question_template="Is there a [class] in the image?",
        ...                 absence_question_template="Is there no [class] in the image?")
        >>> image, prompt_presence, prompt_absence = process_data(args, "image1.jpg", "cat")
    """
    # Resolve the dataset directory, then open the requested image inside it.
    dataset_root = args.DATASET_PATH.get(args.dataset)
    image = PIL.Image.open(os.path.join(dataset_root, img_id))

    # Fill the category into both question templates.
    placeholder = '[class]'
    presence_question, absence_question = (
        template.replace(placeholder, category)
        for template in (args.presence_question_template, args.absence_question_template)
    )

    return image, presence_question, absence_question

def gemini_1_5_pro_forward(model: Any, image: Any, prompt: str) -> Any:
    """
    Queries the model with a text prompt followed by an image.

    Args:
        model (Any): Object exposing ``generate_content``; it is called with a
            two-element list.
        image (Any): The image part of the request, passed through unchanged.
        prompt (str): The text part of the request, placed before the image.

    Returns:
        Any: Whatever ``model.generate_content`` returns, unmodified.

    Example:
        >>> response = gemini_1_5_pro_forward(model, image, "Describe the objects in the image.")
    """
    # Text comes first, image second.
    request_parts = [prompt, image]
    return model.generate_content(request_parts)

def prepare_label(presence_label: int) -> Tuple[List[int], List[int]]:
    """
    Wrap a presence label and its complement in single-element lists.

    Args:
        presence_label (int): The presence label (0 or 1).

    Returns:
        Tuple[List[int], List[int]]: ``([presence_label], [1 - presence_label])``,
        i.e. the presence label followed by the derived absence label.
    """
    # The absence label is the complement of the presence label.
    return [presence_label], [1 - presence_label]

def run_model(args: Any) -> None:
    """
    Runs the generative model on a dataset and evaluates its performance.

    For every ``(image, category, label)`` row loaded from ``args.dataset_csv_path``
    the model is asked the presence and the absence question. Both answers are
    converted to binary predictions and accumulated together with their labels.
    Metrics are logged periodically and once more at the end, after which the
    accumulated prediction and label tensors are stored.

    Args:
        args (Any): The argument object containing dataset configurations and model
            parameters. Attributes read directly here: ``module``,
            ``dataset_csv_path``, ``nproc_per_node`` and, optionally, ``api_key``.
            The whole object is also forwarded to ``process_data`` and
            ``construct_tensor_path``.

    Returns:
        None: The function does not return a value, but logs performance metrics and saves prediction tensors.

    Raises:
        ValueError: If ``args.module`` is not a model this runner knows how to query.

    Example:
        >>> args = Namespace(module="gemini-1.5-pro", dataset_csv_path="/path/to/dataset.csv")
        >>> run_model(args)
    """
    model_path = args.module

    # Fail fast on an unsupported model: otherwise the responses would never be
    # bound inside the loop and the first sample would die with UnboundLocalError.
    if model_path != "gemini-1.5-pro":
        raise ValueError(f"Unsupported model: {model_path!r}")

    # Resolve the API key: explicit argument first, then the GOOGLE_API_KEY
    # environment variable, and finally the original placeholder value.
    api_key = getattr(args, "api_key", None) or os.environ.get("GOOGLE_API_KEY") or "your_api_key"

    # Load model
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(model_name=model_path)

    # Prepare data
    img_category_pair = _load_img_label_pairs(args.dataset_csv_path)

    start_time = time.time()

    # Storage for predictions and labels
    all_preds, all_labels = torch.empty(0), torch.empty(0)
    batch_preds, batch_labels = torch.empty(0), torch.empty(0)

    for idx, data in enumerate(tqdm(img_category_pair)):
        # Load batch data
        img_ids, categories, presence_label = data
        presence_label, absence_label = prepare_label(presence_label)

        # Process data for model
        image, prompt_presence, prompt_absence = process_data(args, img_ids, categories)

        try:
            response_presences = gemini_1_5_pro_forward(model, image, prompt_presence)
            response_absences = gemini_1_5_pro_forward(model, image, prompt_absence)

            # NOTE(review): per the SDK docs, reading `.text` can itself raise ValueError
            # when the response carries no usable text part; that is handled below too.
            response_presences_content = response_presences.text
            response_absences_content = response_absences.text
            if not (isinstance(response_presences_content, str)
                    and isinstance(response_absences_content, str)):
                raise ValueError("Response is empty or malformed")

        except (ValueError, KeyError):
            # Single log entry per skipped sample (the failure is reported only here).
            logging.info(f'Error: Response is empty or malformed for data {data}')
            continue

        # Convert responses to binary
        response_presence_binary = convert_responses_to_binary([response_presences_content])
        response_absence_binary = convert_responses_to_binary([response_absences_content])

        results = (img_ids, categories, response_presence_binary, presence_label, response_absence_binary, absence_label)
        formulated_data = [results]

        # Update prediction and label storage
        all_preds, all_labels, batch_preds, batch_labels = \
            extend_results(formulated_data, all_preds, all_labels, batch_preds, batch_labels)

        # Periodic evaluation: fires at idx 10, 20, ... for samples that were not skipped,
        # then resets the running window.
        if idx % 10 == 0 and idx > 1:
            accuracy, f1, precision, recall, mcc = metric_performances(batch_preds, batch_labels)
            logging.info(f"Batch {idx} - Accuracy: {accuracy}, F1: {f1}, Precision: {precision}, Recall: {recall}, MCC: {mcc}")
            batch_preds, batch_labels = torch.empty(0), torch.empty(0)

    end_time = time.time()
    formatted_time = format_elapsed_time(start_time, end_time)

    # Final evaluation metrics
    accuracy, f1, precision, recall, mcc = metric_performances(all_preds, all_labels)
    # The sample count is reported as -1 when predictions and labels disagree in length.
    logging.info(f"-Formatted time: {formatted_time}, GPUs num is {args.nproc_per_node}. The dataset has a total of {all_preds.size(0) if all_preds.size(0) == all_labels.size(0) else -1} samples complete evaluation of the performance of the metrics as follows: - Accuracy: {accuracy}, F1: {f1}, Precision: {precision}, Recall: {recall}, MCC: {mcc}")

    # Construct tensor file paths
    pred_tensor_dir = construct_tensor_path(label=False, args=args)
    label_tensor_dir = construct_tensor_path(label=True, args=args)

    # Save prediction and label tensors
    store_tensor(all_preds, pred_tensor_dir)
    store_tensor(all_labels, label_tensor_dir)
