import glob 
import os
import sys
import pandas as pd
import torch
import argparse
import json
from pathlib import Path
import csv
from tqdm import tqdm
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)

from PIL import Image
import requests
from io import BytesIO
import re
import time

def load_image(image_file, timeout=30):
    """
    Load an image from a local file path or an HTTP(S) URL.

    Args:
        image_file (str): Path or URL to the image
        timeout (float): Seconds to wait for the remote server before giving
            up; only used when ``image_file`` is a URL

    Returns:
        PIL.Image: The loaded image in RGB format

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an error status.
        OSError: If the local file cannot be opened or is not a valid image.
    """
    # Match the full scheme so a local file that merely starts with "http"
    # (e.g. "http_banner.png") is still read from disk.
    if image_file.startswith(("http://", "https://")):
        # Bound the wait so one stalled server cannot hang the whole run, and
        # fail on 4xx/5xx instead of handing an HTML error page to PIL.
        response = requests.get(image_file, timeout=timeout)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file

    # convert() returns an independent, fully loaded copy, so the handle
    # opened by Image.open can be released as soon as the block exits.
    with Image.open(source) as raw_image:
        return raw_image.convert("RGB")

def generate_query(row, image, temp_str, expname):
    """
    Build the like-prediction prompt for one image.

    Args:
        row (dict): Row of data; read keys are 'company' and the
            'color_', 'caption_' and 'keywords_' columns suffixed with
            '{expname}_{temp_str}'
        image (PIL.Image): The loaded image; only its ``size`` is read
        temp_str (str): Temperature string used as part of the column names
        expname (str): Name of the experiment

    Returns:
        str: Generated query for the model
    """
    # Look the values up in the same order the prompt mentions them.
    company = row['company']
    width = image.size[0]
    height = image.size[1]
    colors = row[f'color_{expname}_{temp_str}']
    caption = row[f'caption_{expname}_{temp_str}']
    keywords = row[f'keywords_{expname}_{temp_str}']

    # One entry per prompt line. The trailing spaces on two of the lines are
    # part of the prompt text and are kept deliberately.
    prompt_lines = [
        f'This is an image that a marketer from company "{company}" wants to post on social media for marketing purposes. The following information about this image is also given:',
        f"(1) image resolution i.e. (width, height): [{width}, {height}],",
        f"(2) image colors and tones: {colors},",
        f"(3) marketer's intended image description: {caption},",
        f"(4) marketer's intended image tags: {keywords},",
        "(5) date of posting: 5 may 2023",
        'Now, carefully observe the image. You have to predict the "number of likes" that this image will get, on a scale of 0 to 100. ',
        "It measures the number of times the viewers will interact with the social media post by clicking the \"Like\" button to express their appreciation for the image. Thus, an image with higher visual appeal, alignment with the company's brand identity, and relevance to the audience, is likely to receive more likes. Moreover, a good image should strongly correspond with the marketer's intended image description and tags to attract the target audience. ",
        'Your predicted "number of likes" will help the marketer to decide whether to post this image or not on the social media platform.',
        "Answer properly in JSON format. Do not include any other information in your answer.",
    ]
    return "\n".join(prompt_lines)

def find_image_path(base_dir, temp_str, image_file):
    """
    Find the image path where the file name ends with the specified image_file

    Every argument is matched literally: glob metacharacters such as ``[``,
    ``]``, ``*`` and ``?`` inside them are escaped, so only the leading
    wildcard added here can expand.

    Args:
        base_dir (str): Base directory to search in
        temp_str (str): Temperature string for subdirectory
        image_file (str): Image file name or part of it

    Returns:
        str or None: Path to the matched image file or None if not found.
            When several files match, the lexicographically smallest path
            is returned so the choice is stable between runs.
    """
    # str() keeps non-string identifiers (e.g. numeric keys read by pandas)
    # working, as they did with plain f-string interpolation.
    directory = f"{glob.escape(str(base_dir))}/{glob.escape(str(temp_str))}"
    pattern = f"{directory}/*{glob.escape(str(image_file))}"
    matching_files = glob.glob(pattern)
    if not matching_files:
        return None
    # glob returns entries in filesystem order, which is not deterministic.
    return min(matching_files)

def _infer_conv_mode(model_name):
    """Return the conversation template name implied by the model name."""
    lowered = model_name.lower()
    if "llama-2" in lowered:
        return "llava_llama_2"
    if "v1" in lowered:
        return "llava_v1"
    if "mpt" in lowered:
        return "mpt"
    return "llava_v0"


def eval_model_for_row(row, model_path, model_base, conv_mode, gpu_id, args, tokenizer, model, image_processor, context_len, expname):
    """
    Evaluate a single row using LLaVA model

    For each temperature folder ('0.75', '0.85', '0.95') the matching image
    is loaded, a prompt is generated, the model produces a text answer, and a
    second forward pass with ``return_values=True`` yields a scalar value.

    Args:
        row (dict): Row of data containing image information; 'media_keys'
            names the image file
        model_path (str): Path to the model
        model_base (str): Base model path (not used here)
        conv_mode (str): Conversation mode (not read; the effective mode is
            ``args.conv_mode`` when set, otherwise inferred from the model
            name)
        gpu_id (int): GPU ID to use
        args (argparse.Namespace): Command line arguments; reads conv_mode,
            temperature, top_p, num_beams and max_new_tokens
        tokenizer: The model tokenizer
        model: The loaded model
        image_processor: The image processor
        context_len (int): Context length for the model (not used here)
        expname (str): Name of the experiment

    Returns:
        tuple: ``(output_values, output_outputs)`` dictionaries keyed by
            'value_{expname}_{temp}' and 'outputs_{expname}_{temp}'. A
            temperature whose image is missing adds no keys; one that raises
            stores ``None`` and an "Error: ..." string.
    """
    device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
    model_name = get_model_name_from_path(model_path)
    # Ensure model is on the specified device
    model = model.to(device)

    # Resolve the conversation template name once per row: it does not depend
    # on the temperature. An explicit --conv-mode wins over the inferred one,
    # which is what the warning below has always announced.
    inferred_mode = _infer_conv_mode(model_name)
    if args.conv_mode is not None and inferred_mode != args.conv_mode:
        print(
            "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
                inferred_mode, args.conv_mode, args.conv_mode
            )
        )
        conv_mode = args.conv_mode
    else:
        conv_mode = inferred_mode

    # Different temperature settings to evaluate. These select the image
    # sub-folder and the column suffix; sampling uses args.temperature.
    temp_set = ('0.75', '0.85', '0.95')
    image_dir = f"{expname}-sdxl_images"

    output_values = {}
    output_outputs = {}

    for temp_str in temp_set:
        value_key = f'value_{expname}_{temp_str}'
        outputs_key = f'outputs_{expname}_{temp_str}'
        try:
            image_file = row['media_keys']

            # Try the exact path first, then fall back to a suffix match.
            img_path = f"{image_dir}/{temp_str}/{image_file}"
            if not os.path.exists(img_path):
                fallback_path = find_image_path(image_dir, temp_str, image_file)
                if not fallback_path:
                    # Report the path that was looked for, not the empty
                    # result of the fallback search.
                    print(f"Image not found: {img_path}")
                    continue
                img_path = fallback_path

            image = load_image(img_path)

            qs = generate_query(row, image, temp_str, expname)
            image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
            image_token = image_token_se if model.config.mm_use_im_start_end else DEFAULT_IMAGE_TOKEN

            # Put the image token where the placeholder is, or in front of
            # the query when there is no placeholder.
            if IMAGE_PLACEHOLDER in qs:
                qs = re.sub(IMAGE_PLACEHOLDER, image_token, qs)
            else:
                qs = image_token + "\n" + qs

            # Kept inside the try so an unknown template name is recorded as
            # an error for this temperature instead of aborting the row.
            conv = conv_templates[conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            # Load and process image on the specified device
            images_tensor = process_images(
                [image],
                image_processor,
                model.config
            ).to(device, dtype=torch.float16)

            # Tokenize input for the specified device
            input_ids = (
                tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
                .unsqueeze(0)
                .to(device)
            )

            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            keywords = [stop_str]
            stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=images_tensor,
                    do_sample=args.temperature > 0,
                    temperature=args.temperature,
                    top_p=args.top_p if args.top_p is not None else 0.9,
                    num_beams=args.num_beams,
                    max_new_tokens=args.max_new_tokens,
                    use_cache=True,
                    stopping_criteria=[stopping_criteria],
                )

            # The slicing below assumes the generated sequence starts with
            # the prompt tokens; warn when it does not.
            input_token_len = input_ids.shape[1]
            n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
            if n_diff_input_output > 0:
                print(
                    f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
                )
            outputs = tokenizer.batch_decode(
                output_ids[:, input_token_len:], skip_special_tokens=True
            )[0]
            outputs = outputs.strip()

            # NOTE(review): `return_values=True` is a project-specific
            # forward option; presumably it returns a value-head score for
            # the full prompt+answer sequence - confirm against the model.
            with torch.inference_mode():
                values = model(
                    output_ids,
                    images=images_tensor,
                    return_values=True
                )

            # Store the results (.item() requires a single-element tensor)
            output_values[value_key] = values.item()
            output_outputs[outputs_key] = outputs

        except Exception as e:
            # Best effort: one failing temperature must not lose the others.
            print(f"Error processing temperature {temp_str}: {e}")
            output_values[value_key] = None
            output_outputs[outputs_key] = f"Error: {str(e)}"

    return output_values, output_outputs

def process_csv(input_csv, output_csv, model_path, model_base, conv_mode, gpu_id, args, expname):
    """
    Process the entire CSV file

    Loads the model once, evaluates every input row, and appends one
    'value_{expname}_{temp}' and one 'outputs_{expname}_{temp}' column per
    temperature ('0.75', '0.85', '0.95'). Rows are written and flushed one at
    a time so partial results survive an interrupted run. A row whose
    evaluation raises is reported and left out of the output.

    Args:
        input_csv (str): Path to input CSV file
        output_csv (str): Path to output CSV file
        model_path (str): Path to the model
        model_base (str): Base model path
        conv_mode (str): Conversation mode
        gpu_id (int): GPU ID to use
        args (argparse.Namespace): Command line arguments; ``args.sep`` (if
            present) is the field separator for reading the input and, when
            it is a single character, for writing the output
        expname (str): Name of the experiment
    """
    disable_torch_init()
    device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
    if torch.cuda.is_available():
        torch.cuda.set_device(gpu_id)

    # Load model
    print(f"Loading model from {model_path}...")
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, model_base, model_name, device=device
    )
    print("Model loaded successfully")

    # Ensure output directory exists. A bare file name has an empty dirname,
    # and os.makedirs("") raises, so only create a directory when one is given.
    output_dir = os.path.dirname(output_csv)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Honor --sep; fall back to a comma when args carries no separator.
    sep = getattr(args, "sep", None) or ","
    # The csv module only accepts a one-character delimiter, while pandas
    # also takes longer (regex) separators; keep the comma for output then.
    output_delimiter = sep if len(sep) == 1 else ","

    # Read the input CSV to get total number of rows for progress tracking
    df = pd.read_csv(input_csv, sep=sep)
    total_rows = len(df)

    # Temperature settings to evaluate
    temp_set = ['0.75', '0.85', '0.95']

    # Prepare output fieldnames
    output_fieldnames = df.columns.tolist()
    for temp_str in temp_set:
        output_fieldnames.extend([
            f'value_{expname}_{temp_str}',
            f'outputs_{expname}_{temp_str}'
        ])

    failed_rows = 0

    # Process rows and write to output CSV. UTF-8 is set explicitly so model
    # output with non-ASCII text does not depend on the platform's locale.
    with open(output_csv, 'w', newline='', encoding='utf-8') as output_file:
        writer = csv.DictWriter(
            output_file, fieldnames=output_fieldnames, delimiter=output_delimiter
        )
        writer.writeheader()

        # Process each row with progress bar
        for idx, row in tqdm(df.iterrows(), total=total_rows, desc="Processing rows"):
            try:
                # Convert row to dictionary
                row_dict = row.to_dict()

                # Evaluate the row
                output_values, output_outputs = eval_model_for_row(
                    row_dict, model_path, model_base, conv_mode, gpu_id,
                    args, tokenizer, model, image_processor, context_len, expname
                )

                # Merge original row with results; columns missing from the
                # result dictionaries are written as empty cells.
                output_row = {**row_dict, **output_values, **output_outputs}

                # Write the row
                writer.writerow(output_row)

                # Flush to ensure writing
                output_file.flush()

            except Exception as e:
                # Best effort: keep going, but remember the failure so the
                # summary does not claim more than was actually written.
                failed_rows += 1
                print(f"Error processing row {idx}: {e}")

    if failed_rows:
        print(
            f"Processed {total_rows - failed_rows} of {total_rows} rows from {input_csv} "
            f"and saved to {output_csv} ({failed_rows} failed and were skipped)"
        )
    else:
        print(f"Processed all {total_rows} rows from {input_csv} and saved to {output_csv}")

def main():
    """Parse the command-line options and run the CSV scoring pipeline."""
    # One (flag, add_argument options) pair per option, in registration order.
    option_table = [
        # Model parameters
        ("--model-path", dict(type=str, default="facebook/opt-350m",
                              help="Path to the pretrained model")),
        ("--model-base", dict(type=str, default=None,
                              help="Base model path if applicable")),
        ("--conv-mode", dict(type=str, default=None,
                             help="Conversation mode to use with the model")),
        # Generation parameters
        ("--temperature", dict(type=float, default=0.2,
                               help="Temperature for sampling")),
        ("--top_p", dict(type=float, default=0.9,
                         help="Top-p sampling parameter")),
        ("--num_beams", dict(type=int, default=1,
                             help="Number of beams for beam search")),
        ("--max_new_tokens", dict(type=int, default=512,
                                  help="Maximum number of new tokens to generate")),
        # Input/output parameters
        ("--input-csv", dict(type=str, required=True,
                             help="Path to the input CSV file")),
        ("--output-csv", dict(type=str, required=True,
                              help="Path to the output CSV file")),
        ("--experiment-name", dict(type=str, required=True,
                                   help="Name of the experiment (used for column naming and image paths)")),
        # Hardware parameters
        ("--gpu-id", dict(type=int, default=0,
                          help="GPU ID to use for processing")),
        # CSV parameters
        ("--sep", dict(type=str, default=",",
                       help="Separator to use for CSV files")),
    ]

    cli = argparse.ArgumentParser(description="Process images in a CSV file with LLaVA model to predict scores")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    # The full namespace is passed along as well because the generation
    # settings are read from it further down.
    process_csv(
        input_csv=args.input_csv,
        output_csv=args.output_csv,
        model_path=args.model_path,
        model_base=args.model_base,
        conv_mode=args.conv_mode,
        gpu_id=args.gpu_id,
        args=args,
        expname=args.experiment_name,
    )

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()