import pandas as pd
import os
import numpy as np
from scipy import stats
import argparse
import torch
import json
import re
import shutil
from collections import defaultdict
import pickle
import random
import ast
from code.utils import generate_and_save_embeddings, get_embedding_llm
from code.agents import init_agent_wikiarts
import requests
from code.prompts import WIKIARTS_SYNTHETIC_HUMAN_INSTRUCTION, WIKIARTS_SYNTHETIC_HUMAN_SYSTEM_MESSAGE

def get_art_sample(art_style, n_sample=100, seed=42):
    """Return a reproducible random sample of WikiArt rows for one art style.

    Reads ``data/wikiarts/WikiArt-info.tsv`` (tab-separated, resolved against
    the current working directory) and keeps the rows whose ``Category``
    column equals ``art_style``.

    Args:
        art_style: Value matched exactly against the ``Category`` column.
        n_sample: Maximum number of rows to draw. A falsy value (``None`` or
            ``0``) disables sampling and returns every matching row.
        seed: Seed used for the random draw.

    Returns:
        A DataFrame holding at most ``n_sample`` rows of the requested style.
    """
    wikiart_path = os.path.join("data", "wikiarts", "WikiArt-info.tsv")
    wikiart_df = pd.read_csv(wikiart_path, sep='\t')

    style_df = wikiart_df[wikiart_df['Category'] == art_style]

    if n_sample:
        # Hand the seed to pandas instead of calling np.random.seed(): the draw
        # stays repeatable without resetting the process-wide NumPy RNG that
        # other code may be relying on.
        style_df = style_df.sample(
            n=min(n_sample, len(style_df)),
            random_state=seed,
        )

    return style_df

def calculate_big_five_scores(responses):
    """Score the five BFI-10 traits from ten item responses.

    ``responses`` is indexed from zero, so ``responses[0]`` is questionnaire
    item 1. Each trait is the mean of one reverse-keyed item (flipped with
    ``6 - value``) and one forward-keyed item, rounded to two decimals.

    Returns:
        Dict with the keys ``Extraversion``, ``Agreeableness``,
        ``Conscientiousness``, ``Neuroticism`` and ``Openness``.
    """
    # (trait, position of the reverse-keyed item, position of the forward-keyed item)
    scoring_key = (
        ('Extraversion', 0, 5),       # Q1R, Q6
        ('Agreeableness', 6, 1),      # Q7R, Q2
        ('Conscientiousness', 2, 7),  # Q3R, Q8
        ('Neuroticism', 3, 8),        # Q4R, Q9
        ('Openness', 4, 9),           # Q5R, Q10
    )

    scores = {}
    for trait, reversed_pos, forward_pos in scoring_key:
        flipped = 6 - responses[reversed_pos]
        scores[trait] = round((flipped + responses[forward_pos]) / 2, 2)
    return scores

def _parse_user_id(raw_id):
    """Convert a ``human_id`` cell such as ``"user_12"`` (or a bare ``12``) to ``int``."""
    text = str(raw_id)
    if 'user_' in text:
        text = text.split('_')[1]
    return int(text)


def _load_question_paths(output_dir, split):
    """Read ``<output_dir>/<split>/questions.csv``; return (unique ids, id -> image path)."""
    questions_df = pd.read_csv(os.path.join(output_dir, split, "questions.csv"))
    image_paths = {
        question_id: os.path.join(output_dir, split, image_path)
        for question_id, image_path in zip(questions_df['ID'], questions_df['image_path'])
    }
    return questions_df['ID'].unique().tolist(), image_paths


def _collect_user_responses(responses_df, image_paths):
    """Group response rows by integer user id, attaching each question's image path."""
    grouped = defaultdict(list)
    columns = [
        responses_df[name]
        for name in ('human_id', 'image_id', 'response', 'emotions', 'explanation')
    ]
    for human_id, question_id, response, emotion, explanation in zip(*columns):
        grouped[_parse_user_id(human_id)].append(
            (question_id, image_paths[question_id], response, emotion, explanation)
        )
    return grouped


def preprocess_data(args):
    """Assemble per-user train/test responses from the CSVs under ``args.output_dir``.

    Reads ``train/questions.csv``, ``test/questions.csv``, ``train/responses.csv``
    and ``test/responses.csv``. Writes the question-id -> image-path maps to
    ``question_splits/{train,test}_questions.json`` and caches the full result
    in ``preprocessed_data.pkl``.

    The cache is reused only when ``args.force_recompute`` is falsy (or absent),
    so a forced rerun never serves data built from superseded response files.

    Args:
        args: Namespace providing ``output_dir`` and, optionally,
            ``force_recompute``.

    Returns:
        Tuple ``(user_responses_train, user_responses_test, question_ids,
        final_sampled_user_ids, train_questions_ids, test_questions_ids)``.
        The two response mappings go from integer user id to a list of
        ``(question_id, image_path, response, emotion, explanation)`` tuples.

    Raises:
        KeyError: If a response row references an ``image_id`` that is missing
            from the matching ``questions.csv``.
    """
    preprocessed_data_file = os.path.join(args.output_dir, 'preprocessed_data.pkl')
    force_recompute = getattr(args, 'force_recompute', False)
    if os.path.exists(preprocessed_data_file) and not force_recompute:
        print(f"Preprocessed data already exists at {preprocessed_data_file}. Loading...")
        # NOTE: unpickling can execute arbitrary code; this file is a cache the
        # pipeline wrote itself and must not come from an untrusted source.
        with open(preprocessed_data_file, 'rb') as f:
            (user_responses_train, user_responses_test, question_ids,
             final_sampled_user_ids, train_questions_ids, test_questions_ids) = pickle.load(f)
        return (user_responses_train, user_responses_test, question_ids,
                final_sampled_user_ids, train_questions_ids, test_questions_ids)

    split_dir = os.path.join(args.output_dir, 'question_splits')
    os.makedirs(split_dir, exist_ok=True)

    train_questions_ids, train_questions = _load_question_paths(args.output_dir, "train")
    test_questions_ids, test_questions = _load_question_paths(args.output_dir, "test")
    question_ids = train_questions_ids + test_questions_ids
    print(f"Train question ids: {len(train_questions_ids)}")
    print(f"Test question ids: {len(test_questions_ids)}")
    print(f"All question ids: {len(question_ids)}")

    with open(os.path.join(split_dir, 'train_questions.json'), 'w') as f:
        json.dump(train_questions, f)
    with open(os.path.join(split_dir, 'test_questions.json'), 'w') as f:
        json.dump(test_questions, f)

    train_responses_df = pd.read_csv(os.path.join(args.output_dir, "train", "responses.csv"))
    test_responses_df = pd.read_csv(os.path.join(args.output_dir, "test", "responses.csv"))

    # The sampled users are defined by the training responses only.
    final_sampled_user_ids = [
        _parse_user_id(human_id)
        for human_id in train_responses_df['human_id'].unique().tolist()
    ]
    print(f"Final sampled user ids: {len(final_sampled_user_ids)}")

    user_responses_train = _collect_user_responses(train_responses_df, train_questions)
    user_responses_test = _collect_user_responses(test_responses_df, test_questions)

    result = (user_responses_train, user_responses_test, question_ids,
              final_sampled_user_ids, train_questions_ids, test_questions_ids)

    # Save the preprocessed data
    with open(preprocessed_data_file, 'wb') as f:
        pickle.dump(result, f, protocol=pickle.HIGHEST_PROTOCOL)

    return result

# Typographic quotes mapped to their ASCII equivalents. Only used as a fallback
# when the text does not parse as written.
_TYPOGRAPHIC_QUOTES = str.maketrans({
    '\u201c': '"',
    '\u201d': '"',
    '\u2018': "'",
    '\u2019': "'",
})


def _repair_mojibake(text):
    """Undo UTF-8 text that was mis-decoded as latin1/cp1252; otherwise return ``text``."""
    for codec in ('latin1', 'cp1252'):
        try:
            return text.encode(codec).decode('utf-8')
        except (UnicodeEncodeError, UnicodeDecodeError):
            # Not representable in this codec, or the bytes are not UTF-8:
            # move on to the next codec.
            continue
    return text


def _json_parse_candidates(raw):
    """Yield progressively more forgiving variants of ``raw`` to feed to ``json.loads``."""
    # 1. The text exactly as written, so valid escapes such as \" keep working.
    yield raw
    # 2. Every backslash doubled, which rescues stray backslashes like "\e".
    yield raw.replace('\\', '\\\\')
    straightened = raw.translate(_TYPOGRAPHIC_QUOTES)
    if straightened != raw:
        # 3./4. Same two attempts after replacing curly quotes with ASCII ones.
        yield straightened
        yield straightened.replace('\\', '\\\\')


def extract_json_from_text(text):
    """Pull the JSON object out of a model reply and read its two expected fields.

    A fenced json code block containing an object is preferred; without one,
    the whole stripped text is tried as JSON. Parsing is attempted on the text
    as written first and then on the repaired variants produced by
    ``_json_parse_candidates``.

    Args:
        text: The decoded model output. Non-string input is returned untouched.

    Returns:
        ``(parsed_dict, emotions, explanation)`` when a JSON object was parsed,
        where the last two come from its ``"emotions"`` and ``"explanation"``
        keys (``None`` if absent). Otherwise ``(raw, None, None)``, where
        ``raw`` is the fenced payload if a fence was found and the full text if
        not. Valid JSON that is not an object counts as unparsed.
    """
    if not isinstance(text, str):
        return text, None, None

    text = _repair_mojibake(text)

    match = re.search(r'```json\s*(\{.*?\})\s*```', text, re.DOTALL | re.IGNORECASE)
    if match:
        raw_payload = match.group(1)
        candidate_source = raw_payload
    else:
        raw_payload = text
        candidate_source = text.strip()

    for candidate in _json_parse_candidates(candidate_source):
        try:
            parsed_json = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed_json, dict):
            return parsed_json, parsed_json.get("emotions"), parsed_json.get("explanation")

    return raw_payload, None, None




def generate_human_text_representation(df, chosen_texts):
    """Render one ``"<statement>: <chosen answer>"`` line per questionnaire row.

    Args:
        df: DataFrame with a ``statement_text`` column.
        chosen_texts: Sequence of answers, one per row of ``df`` in row order.

    Returns:
        The lines joined with newlines.

    Raises:
        IndexError: If ``chosen_texts`` has fewer entries than ``df`` has rows.
    """
    # Pair rows with answers by position: the DataFrame's index labels are not
    # guaranteed to be 0..n-1 (e.g. after filtering), so they cannot be used to
    # index into ``chosen_texts``.
    lines = [
        f"{statement_text}: {chosen_texts[position]}"
        for position, statement_text in enumerate(df['statement_text'])
    ]
    return "\n".join(lines)


def _parse_option_list(raw):
    """Turn a stringified list from the questionnaire CSV into a Python list."""
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        # NOTE(review): eval() runs arbitrary code. It is kept only so cells
        # that are not plain literals keep loading; the CSV must be trusted.
        return eval(raw)


def gen_synthetic_humans(df, num_humans=100, seed=42):
    """Generate synthetic humans using parametric distributions for Big Five personality traits.

    Each human is assigned one of three personality clusters (weighted draw).
    For every questionnaire row, an answer is sampled from a normal
    distribution truncated to [1, 5], centred on the cluster's mean for that
    item's trait (mirrored with ``6 - mean`` for reverse-keyed items), then
    rounded to the nearest option.

    Args:
        df: Questionnaire DataFrame with ``statement_id`` (1-10),
            ``statement_text``, and ``option_values`` / ``option_texts``
            columns holding stringified lists of the five answer options.
        num_humans: Number of synthetic humans to create.
        seed: Seed applied to NumPy's global RNG before sampling.

    Returns:
        DataFrame with one row per human: ``human_id``, ``cluster``,
        ``chosen_option_values``, ``chosen_option_texts``,
        ``text_representation`` and the five trait scores.

    Raises:
        KeyError: If a ``statement_id`` is not one of the ten known items.
    """
    # Set the random seed for reproducibility. Both np.random.choice and
    # scipy's rvs (called without random_state) draw from this global RNG.
    np.random.seed(seed)

    # (mean, std) of each trait per cluster, keyed by trait initial.
    personality_clusters = {
        "Extraverted Leaders": {
            "E": (4.8, 0.3),
            "A": (2.6, 0.5),
            "C": (1.5, 0.4),
            "N": (0.9, 0.5),
            "O": (1.2, 0.6),
        },
        "Anxious Creatives": {
            "E": (1.7, 0.5),
            "A": (1.0, 0.5),
            "C": (1.0, 0.6),
            "N": (4.7, 0.3),
            "O": (2.5, 0.3),
        },
        "Conscientious Introverts": {
            "E": (1.2, 0.3),
            "A": (2.5, 0.5),
            "C": (4.8, 0.3),
            "N": (1.4, 0.5),
            "O": (1.6, 0.4),
        },
    }
    # statement_id -> (trait initial, is the item reverse-keyed?)
    question_traits = {
        1: ("E", True), 2: ("A", False), 3: ("C", True), 4: ("N", True), 5: ("O", True),
        6: ("E", False), 7: ("A", True), 8: ("C", False), 9: ("N", False), 10: ("O", False),
    }

    cluster_names = list(personality_clusters.keys())

    cluster_weights = [0.55, 0.2, 0.25]
    if len(cluster_weights) != len(cluster_names) or not np.isclose(sum(cluster_weights), 1.0):
        raise ValueError("Cluster weights must match the number of clusters and sum to 1.0")

    # The questionnaire is identical for every human, so decode its option
    # lists once here instead of once per human per row.
    items = []
    for statement_id, raw_values, raw_texts in zip(
        df['statement_id'], df['option_values'], df['option_texts']
    ):
        trait, is_reversed = question_traits[statement_id]
        items.append(
            (trait, is_reversed, _parse_option_list(raw_values), _parse_option_list(raw_texts))
        )

    synthetic_data = []

    # Generate synthetic humans
    for human_id in range(num_humans):
        cluster_name = np.random.choice(cluster_names, p=cluster_weights)
        cluster_traits = personality_clusters[cluster_name]

        chosen_values = []
        chosen_texts = []

        for trait, is_reversed, option_values, option_texts in items:
            trait_mean, trait_std = cluster_traits[trait]

            # A reverse-keyed item is answered low when the trait is high.
            effective_mean = 6 - trait_mean if is_reversed else trait_mean

            # truncnorm takes its bounds in standard-deviation units around loc.
            a = (1 - effective_mean) / trait_std
            b = (5 - effective_mean) / trait_std
            sample = stats.truncnorm.rvs(a, b, loc=effective_mean, scale=trait_std)

            value_idx = int(np.round(np.clip(sample, 1, 5))) - 1
            chosen_values.append(option_values[value_idx])
            chosen_texts.append(option_texts[value_idx])

        # NOTE: scoring reads answers by position, so it assumes the rows of
        # ``df`` are ordered by statement_id 1..10.
        big_five_scores = calculate_big_five_scores(chosen_values)
        text_representation = generate_human_text_representation(df, chosen_texts)

        record = {
            'human_id': f"user_{human_id + 1}",
            'cluster': cluster_name,
            'chosen_option_values': chosen_values,
            'chosen_option_texts': chosen_texts,
            'text_representation': text_representation,
        }

        record.update(big_five_scores)
        synthetic_data.append(record)

    return pd.DataFrame(synthetic_data)

def load_data(data_path):
    """Read the CSV at ``data_path`` and return it as a DataFrame."""
    return pd.read_csv(data_path)


def process_synthetic_humans(args):
    """Load the cached synthetic humans, or generate and cache them.

    The cache lives at ``<args.output_dir>/synthetic_humans_n<args.n_humans>.csv``.
    It is reused unless ``args.force_recompute`` is set. List-valued columns
    are stored as strings in the CSV and are converted back to lists on load.

    Args:
        args: Namespace providing ``output_dir``, ``n_humans``,
            ``force_recompute`` and ``random_seed``.

    Returns:
        DataFrame with one row per synthetic human.
    """
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"synthetic_humans_n{args.n_humans}.csv")

    if os.path.exists(output_path) and not args.force_recompute:
        print(f"Loading existing synthetic humans from {output_path}")
        synthetic_humans_df = pd.read_csv(output_path)
        # Handle list columns that were stored as strings
        for col in ['chosen_option_values', 'chosen_option_texts']:
            if col not in synthetic_humans_df.columns or synthetic_humans_df.empty:
                # Nothing to convert; an empty frame has no first cell to inspect.
                continue
            if not isinstance(synthetic_humans_df[col].iloc[0], str):
                continue
            try:
                synthetic_humans_df[col] = synthetic_humans_df[col].apply(ast.literal_eval)
            except (ValueError, SyntaxError) as e:
                print(f"Warning: Could not parse column {col} with ast.literal_eval: {e}. Falling back to eval.")
                # NOTE(review): eval() runs arbitrary code; acceptable only
                # because this CSV is a cache written by this pipeline.
                synthetic_humans_df[col] = synthetic_humans_df[col].apply(eval)
        return synthetic_humans_df

    print(f"Generating {args.n_humans} synthetic humans with seed {args.random_seed}")
    # Load the BFI-10 data
    bfi_df = load_data(os.path.join("data", "wikiarts", "bfi10.csv"))

    # Generate synthetic humans
    synthetic_humans_df = gen_synthetic_humans(bfi_df, num_humans=args.n_humans, seed=args.random_seed)

    # Save to CSV
    synthetic_humans_df.to_csv(output_path, index=False)
    print(f"Saved synthetic humans data to: {output_path}")

    return synthetic_humans_df


# Helper function for downloading images (extracted and modified from main)
def _download_images_and_update_df(df, output_dir_split):
    """Download images and add local paths to DataFrame"""
    img_folder = os.path.join(output_dir_split, "images")
    os.makedirs(img_folder, exist_ok=True)

    for i, row in df.iterrows():
        img_url = row['Image URL']
        img_filename = os.path.basename(img_url)
        img_path_abs = os.path.join(img_folder, img_filename)
        
        # Store relative path for the CSV
        df.loc[i, 'image_path'] = os.path.join("images", img_filename)
        
        if not os.path.exists(img_path_abs):
            try:
                import urllib.request
                urllib.request.urlretrieve(img_url, img_path_abs)
                print(f"Downloaded {img_filename} to {img_path_abs}")
            except Exception as e:
                print(f"Failed to download {img_url}: {e}")
    return df

def prepare_artwork_data(args):
    """Return the train/test artwork tables, building them if they are not cached.

    When both ``train/questions.csv`` and ``test/questions.csv`` exist under
    ``args.output_dir`` and ``args.force_recompute`` is off, they are simply
    read back. Otherwise Impressionism artworks are sampled, the first
    ``args.split_size`` fraction becomes the train set and the rest the test
    set, images are downloaded per split, and both tables are written to CSV.
    """
    train_dir = os.path.join(args.output_dir, "train")
    test_dir = os.path.join(args.output_dir, "test")
    train_csv_path = os.path.join(train_dir, "questions.csv")
    test_csv_path = os.path.join(test_dir, "questions.csv")

    # Fast path: reuse the splits saved by an earlier run.
    if os.path.exists(train_csv_path) and os.path.exists(test_csv_path) and not args.force_recompute:
        print("Loading existing train/test splits for artwork...")
        cached_train = pd.read_csv(train_csv_path)
        cached_test = pd.read_csv(test_csv_path)
        print(f"Train set size: {len(cached_train)}")
        print(f"Test set size: {len(cached_test)}")
        return cached_train, cached_test

    print("Preparing artwork data: sampling, splitting, and downloading images...")
    art_df = get_art_sample('Impressionism', n_sample=args.n_questions, seed=args.random_seed)
    print(f"\nSampled Artworks size: {art_df.shape}")

    # Rows before the cut-off go to train, the remainder to test.
    cutoff = int(len(art_df) * args.split_size)
    train_df = art_df.iloc[:cutoff].copy()
    test_df = art_df.iloc[cutoff:].copy()

    print(f"Train set size: {len(train_df)}")
    print(f"Test set size: {len(test_df)}")

    for split_dir in (train_dir, test_dir):
        images_dir = os.path.join(split_dir, "images")
        # Old images are wiped only on an explicit forced recompute.
        if os.path.exists(images_dir) and args.force_recompute:
            try:
                shutil.rmtree(images_dir)
                print(f"Removed existing images directory: {images_dir}")
            except Exception as e:
                print(f"Error deleting {images_dir}: {e}")
        os.makedirs(images_dir, exist_ok=True)

    print("\nProcessing train images...")
    train_df = _download_images_and_update_df(train_df, train_dir)
    train_df.to_csv(train_csv_path, index=False)
    print(f"Saved train questions to: {train_csv_path}")

    print("\nProcessing test images...")
    test_df = _download_images_and_update_df(test_df, test_dir)
    test_df.to_csv(test_csv_path, index=False)
    print(f"Saved test questions to: {test_csv_path}")

    return train_df, test_df

def _painting_prompt(sys_msg, image_path, instruction):
    """Build the system + user chat turns that show one painting to one persona."""
    user_content = [
        {"type": "text", "text": "Painting:\n"},
        {"type": "image", "image": image_path},
        {"type": "text", "text": instruction},
    ]
    return [
        {"role": "system", "content": [{"type": "text", "text": sys_msg}]},
        {"role": "user", "content": user_content},
    ]


def generate_all_model_responses(args, synthetic_humans_df, train_df, test_df, model, processor):
    """Ask the model, as each synthetic human, to react to every painting of each split.

    For each split in ``args.splits`` the answers are written to
    ``<args.output_dir>/<split>/responses.csv``. A split whose file already
    exists is skipped unless ``args.force_recompute`` is set. All paintings of
    a split are sent as one batch per human; if that batch raises, the error is
    printed and the human contributes no rows for the split.
    """
    instruction = WIKIARTS_SYNTHETIC_HUMAN_INSTRUCTION
    n_humans = len(synthetic_humans_df)

    for split in args.splits:
        output_csv_path = os.path.join(os.path.join(args.output_dir, split), "responses.csv")

        if os.path.exists(output_csv_path) and not args.force_recompute:
            print(f"Loading existing responses from {output_csv_path} for {split} split.")
            continue

        print(f"\nGenerating responses for {split} split...")
        # Any split name other than "train" is served from the test table.
        data_df = train_df if split == "train" else test_df
        collected_rows = []

        for index, human_row in synthetic_humans_df.iterrows():
            human_id = human_row["human_id"]
            sys_msg = WIKIARTS_SYNTHETIC_HUMAN_SYSTEM_MESSAGE + human_row["text_representation"]
            print(f"Processing human: {human_id} for split: {split} ({index + 1}/{n_humans})")

            batch_image_ids = []
            batch_messages_list = []
            for _, image_row in data_df.iterrows():
                batch_image_ids.append(image_row["ID"])
                # The stored image path is relative to the split directory.
                image_path_abs = os.path.join(args.output_dir, split, image_row["image_path"])
                batch_messages_list.append(_painting_prompt(sys_msg, image_path_abs, instruction))

            if not batch_messages_list:
                print(f"  No images to process for human {human_id} in {split} split.")
                continue

            try:
                print(f"  Tokenizing batch of {len(batch_messages_list)} images for human {human_id}...")
                inputs = processor.apply_chat_template(
                    batch_messages_list,
                    add_generation_prompt=True,
                    tokenize=True,
                    return_dict=True,
                    return_tensors="pt",
                    padding=True
                ).to(model.device)

                # Generated sequences start with the prompt; remember its
                # length so it can be sliced off below.
                input_len = inputs["input_ids"].shape[-1]

                print(f"  Generating responses for human {human_id}...")
                with torch.inference_mode():
                    generation = model.generate(
                        **inputs,
                        max_new_tokens=512,
                        do_sample=True,
                        temperature=1.0,
                        pad_token_id=processor.tokenizer.pad_token_id
                    )

                print(f"  Decoding responses for human {human_id}...")
                decoded_batch = processor.batch_decode(
                    generation[:, input_len:], skip_special_tokens=True
                )

                for current_image_id, decoded_text in zip(batch_image_ids, decoded_batch):
                    raw_or_parsed_json, emotion, explanation = extract_json_from_text(decoded_text)
                    if isinstance(raw_or_parsed_json, dict):
                        response_text = json.dumps(raw_or_parsed_json)
                    else:
                        response_text = raw_or_parsed_json
                    collected_rows.append({
                        "human_id": human_id,
                        "image_id": current_image_id,
                        "emotions": emotion,
                        "explanation": explanation,
                        "response": response_text
                    })

                del inputs, generation
                torch.cuda.empty_cache()

            except Exception as e:
                print(f"Error processing batch for human {human_id}: {e}")
                torch.cuda.empty_cache()

        if not collected_rows:
            print(f"No responses generated for split {split}.")
            continue

        responses_df = pd.DataFrame(collected_rows)
        responses_df.to_csv(output_csv_path, index=False)
        print(f"Saved {len(responses_df)} responses to {output_csv_path}")

def run_embedding_generation_pipeline(args, model, processor):
    """Preprocesses data and generates embeddings for train and test sets.

    Embeddings are written under ``<args.output_dir>/human_embeddings_train``
    and ``<args.output_dir>/human_embeddings_test``.

    Args:
        args: Namespace providing ``output_dir`` and ``force_recompute``.
        model: Model passed through to ``generate_and_save_embeddings``.
        processor: Processor passed through to ``generate_and_save_embeddings``.
    """
    print("\n-----Starting Embedding Generation-----")
    # One call supplies everything needed here; the remaining return values
    # (question id lists) are not used for embedding generation.
    user_responses_train, user_responses_test, _, final_sampled_user_ids, _, _ = preprocess_data(args)

    train_embs_dir = os.path.join(args.output_dir, 'human_embeddings_train')
    test_embs_dir = os.path.join(args.output_dir, 'human_embeddings_test')

    print("\nGenerating embeddings for the training set...")
    generate_and_save_embeddings(
        model, processor, user_responses_train, train_embs_dir, final_sampled_user_ids,
        args.force_recompute, desc_suffix="Train", use_llm=True,
        data_dir=os.path.join(args.output_dir, "train")
    )

    print("\nGenerating embeddings for the test set...")
    generate_and_save_embeddings(
        model, processor, user_responses_test, test_embs_dir, final_sampled_user_ids,
        args.force_recompute, desc_suffix="Test", use_llm=True,
        data_dir=os.path.join(args.output_dir, "test")
    )
    print("-----Embedding Generation Complete-----")


def main():
    """Run the whole pipeline from the command line.

    Stages: build or load synthetic humans, prepare the artwork splits, load
    the response model and generate responses, then load the embedding model
    and generate embeddings.
    """
    import gc

    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Generate synthetic humans, art responses, and embeddings for WikiArts.')
    parser.add_argument('--n_humans', type=int, default=100,
                        help='Number of synthetic humans to generate')
    parser.add_argument('--n_questions', type=int, default=20,
                        help='Number of artworks to sample')
    parser.add_argument('--random_seed', type=int, default=42,
                        help='Random seed for reproducibility')
    parser.add_argument('--syn_human_model', type=str, default="google/gemma-3-27b-it",
                        help='Model name for the human')
    parser.add_argument('--emb_model', type=str, default="google/gemma-3-12b-it",
                        help='Model name for embedding generation')
    parser.add_argument('--output_dir', type=str, default="output/wikiarts",
                        help='Directory to save output files')
    parser.add_argument('--force_recompute', action='store_true', default=False,
                        help='Force recomputation even if output files exist')
    parser.add_argument('--split_size', type=float, default=0.5,
                        help='Proportion of artwork data to use for training')
    parser.add_argument('--splits', nargs='+', type=str, choices=['train', 'test'], default=['train', 'test'],
                        help='Which dataset splits to process for responses and embeddings')

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Stage 1: Process synthetic humans
    print("-----Processing Synthetic Humans-----")
    synthetic_humans_df = process_synthetic_humans(args)
    print(f"Loaded/Generated {len(synthetic_humans_df)} synthetic humans.")

    # Stage 2: Load WikiArt data, split into train/test, and download images
    print("\n-----Preparing Artwork Data-----")
    train_art_df, test_art_df = prepare_artwork_data(args)

    # Stage 3: Initialize model and processor
    print("\n-----Initializing Model and Processor-----")
    model, processor = init_agent_wikiarts(args.syn_human_model, adapter_path=None)
    print(f"Initialized model: {args.syn_human_model}")

    # Stage 4: Query the model with synthetic humans
    print("\n-----Generating Model Responses-----")
    # Responses are sampled (do_sample=True), so torch must be seeded as well
    # for --random_seed to make them repeatable.
    torch.manual_seed(args.random_seed)
    generate_all_model_responses(args, synthetic_humans_df, train_art_df, test_art_df, model, processor)

    # Drop the response model before loading the embedding model so the two
    # are not held in memory at the same time.
    del model, processor
    gc.collect()
    torch.cuda.empty_cache()

    # Stage 5: Generate embeddings
    emb_model, emb_processor = init_agent_wikiarts(args.emb_model, adapter_path=None)
    run_embedding_generation_pipeline(args, emb_model, emb_processor)

    print("\nProcessing complete.")

# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()