import itertools
import os
import random
import yaml
from .util import *
from .templates import *
from tqdm import tqdm
import argparse

# ==================================
# Description generation for seeds
# ==================================

def make_description_generation_prompt(uri, representation = ProgramType.GRAPH):
    """Build the chat messages asking a model to describe one material.

    The system message is the description system template filled with the API
    description matching ``representation``. The user message is the
    description-generation template filled with the material's program text
    and its four rendered views.

    Args:
        uri: Location of the material; passed to the representation loader
            and to ``rendered_urls``.
        representation: ``ProgramType.GRAPH`` selects the graph API
            description and graph representation; any other value selects
            the code variants.

    Returns:
        A two-element list ``[system_message, user_message]`` of
        ``{"role": ..., "content": ...}`` dicts.
    """
    use_graph = representation == ProgramType.GRAPH

    # --- system message ---
    if use_graph:
        api_text = graph_api_description
    else:
        api_text = code_api_description
    system_text = description_system_prompt_template.format(api_description=api_text)
    system_message = {"role": "system", "content": format_prompt_openai(system_text)}

    # --- user message ---
    if use_graph:
        program_text = graph_representation(uri)
    else:
        program_text = code_representation(uri)
    views = rendered_urls(uri)
    user_text = description_generation_prompt_template.format(
        dsl_code=program_text,
        top=views['top'],
        front=views['front'],
        right=views['right'],
        top_right=views['top_right'],
    )
    user_message = {"role": "user", "content": format_prompt_openai(user_text)}

    return [system_message, user_message]

def make_simple_description_generation_prompt(uri, representation = ProgramType.GRAPH):
    """Build a minimal single-image description prompt for one material.

    Only the ``top_right`` render of the material is shown to the model.

    Args:
        uri: Location of the material; passed to ``rendered_urls``.
        representation: Accepted but not used by this function.

    Returns:
        A two-element list ``[system_message, user_message]`` of
        ``{"role": ..., "content": ...}`` dicts.
    """
    system_message = {
        "role": "system",
        "content": format_prompt_openai("You are an assistant that generates detailed text descriptions for cellular metamaterials based on a rendering of their unit cell."),
    }

    user_template = "<[{top_right}]>This is an image of a base cell of a cellular metamaterial. Please describe it in detail, taking into account its structures and symmetries."
    views = rendered_urls(uri)
    user_text = user_template.format(top_right=views['top_right'])
    user_message = {"role": "user", "content": format_prompt_openai(user_text)}

    return [system_message, user_message]

def make_simple_quad_view_description_generation_prompt(uri, representation = ProgramType.GRAPH):
    """Build a four-view description prompt for one material.

    The user message shows the angled (``top_right``), top, front and right
    renders, followed by the request to describe the base cell.

    Args:
        uri: Location of the material; passed to ``rendered_urls``.
        representation: Accepted but not used by this function.

    Returns:
        A two-element list ``[system_message, user_message]`` of
        ``{"role": ..., "content": ...}`` dicts.
    """
    system_message = {
        "role": "system",
        "content": format_prompt_openai("You are an assistant that generates detailed text descriptions for cellular metamaterials based on a rendering of their unit cell."),
    }

    # Adjacent literals concatenate into the same single-line template text.
    user_template = (
        "Angled View: <[{top_right}]>\n"
        "Top View: <[{top}]>\n"
        "Front View: <[{front}]>\n"
        "Right View: <[{right}]>\n"
        "This is an image of a base cell of a cellular metamaterial. Please describe it in detail, taking into account its structures and symmetries."
    )
    views = rendered_urls(uri)
    view_urls = {name: views[name] for name in ('top', 'front', 'right', 'top_right')}
    user_text = user_template.format(**view_urls)
    user_message = {"role": "user", "content": format_prompt_openai(user_text)}

    return [system_message, user_message]

# This won the o1 evaluation competition
# Module-level alias for the prompt builder that generate_description_batch_file
# calls as describe_material_prompt(prog, representation).
describe_material_prompt = make_simple_quad_view_description_generation_prompt

def evaluate_prompt_gen(prompt_gen, example_locations, debug=False):
    """Evaluate a description prompt builder via a description-ranking task.

    First, ``prompt_gen(location)`` is built for every location and sent
    through ``run_prompt_openai`` (with that helper's default model) to obtain
    one description per location. Then, for every location, all descriptions
    are listed in a freshly shuffled order next to that location's four
    renders and "gpt-4o" is asked to rank them from best to worst match. The
    reply is mapped back from shuffled positions to indices into
    ``example_locations``.

    Args:
        prompt_gen: Callable taking one location and returning a prompt
            accepted by ``run_prompt_openai``.
        example_locations: Sequence of material locations (``len()`` is used).
        debug: When true, also return the raw replies and the shuffles used.

    Returns:
        ``rankings`` when ``debug`` is false: one list per location holding
        indices into ``example_locations`` in the order the ranking model
        returned them. When ``debug`` is true, the tuple
        ``(rankings, ranking_txts, permutations)``.

    Raises:
        ValueError: If a whitespace-separated token of a ranking reply is not
            an integer.
        IndexError: If a returned integer is not a valid index into the
            shuffled list.
    """
    num_examples = len(example_locations)

    prompt_template = """You are a cellular metamaterial expert. You will help to evaluate text descriptions of metamaterials.

**Task:**
You will be given renderings of the base cell of a cellular metamaterial from four viewpoints, along with {N} text descriptions of metamaterials, one of which matches the images. Rank the descriptions from best-to-worst match to the images.

**Output Format:**
Output only the ranked list as space-separated integers

**Rendered Viewpoints:**
Angled View (Front-Top-Right): <[{top_right}]>
Top View: <[{top}]>
Front View: <[{front}]>
Right View: <[{right}]>

**Material Descriptions:**
{material_descriptions}
"""
    # Build every prompt up front, then query the model once per prompt.
    example_prompts = [prompt_gen(location) for location in example_locations]
    descriptions = [
        run_prompt_openai(prompt)
        for prompt in tqdm(example_prompts, desc="Generating Descriptions")
    ]

    rankings = []
    ranking_txts = []
    permutations = []
    for location in tqdm(example_locations, total=num_examples, desc="Ranking Descriptions"):
        # permutation[shown_idx] == index of that description in the original
        # example list; a new shuffle is drawn for every location.
        permutation = list(range(num_examples))
        random.shuffle(permutation)

        material_descriptions = "".join(
            f"\nMaterial {dummy_idx}:\n" + descriptions[original_idx]
            for dummy_idx, original_idx in enumerate(permutation)
        )

        images = rendered_urls(location)

        prompt = format_prompt_openai(prompt_template.format(
            N = num_examples,
            top_right = images['top_right'],
            top = images['top'],
            front = images['front'],
            right = images['right'],
            material_descriptions=material_descriptions
        ))

        ranking_txt = run_prompt_openai([{"role": "user", "content": prompt}], "gpt-4o")
        permuted_ranking = [int(n) for n in ranking_txt.strip().split()]

        # Translate shuffled positions back to original example indices.
        ranking = [permutation[r] for r in permuted_ranking]

        rankings.append(ranking)
        ranking_txts.append(ranking_txt)
        permutations.append(permutation)

    if debug:
        return rankings, ranking_txts, permutations
    return rankings

def generate_description_batch_file(
        example_loc="s3://metagen-datasets/data/graph_v2/processed",
        batch_file_loc="./description_batch.jsonl",
        representation: ProgramType = ProgramType.GRAPH):
    """Write one batch-request line per program found under ``example_loc``.

    Each program returned by ``list_all_programs`` is turned into a prompt
    with ``describe_material_prompt``, formatted with
    ``format_for_openai_batch`` and written as its own line.

    Args:
        example_loc: Location searched for programs.
        batch_file_loc: Path of the output file; it is overwritten.
        representation: Program type forwarded to ``list_all_programs`` and
            to the prompt builder.
    """
    programs = list_all_programs(example_loc, program_type=representation)
    with open(batch_file_loc, 'w') as batch_file:
        for program in tqdm(programs):
            messages = describe_material_prompt(program, representation)
            request_line = format_for_openai_batch(messages)
            batch_file.write(request_line + '\n')

def description_result_to_description_uri(result):
    """Derive the ``description.txt`` URI that belongs to a batch result.

    The image URL is read from the request half of the result
    (``result[0]``): second message, second content part. It is converted
    with ``url_to_s3`` and ``'description.txt'`` is appended to it.
    NOTE(review): presumably ``url_to_s3`` yields the material's directory
    URI rather than the image's own URI -- confirm against that helper.
    """
    request = result[0]
    second_message = request['messages'][1]
    content_part = second_message['content'][1]
    image_url = content_part['image_url']['url']
    return path_append(url_to_s3(image_url), 'description.txt')

def description_result_to_description(result):
    """Return the message content of the first choice in ``result[1]``.

    ``result`` is indexed as a pair; only its second element (the response)
    is read here.
    """
    response = result[1]
    first_choice = response['choices'][0]
    return first_choice['message']['content']

def write_result_description(result):
    """Store the description from a batch result at its ``description.txt`` URI.

    The text comes from ``description_result_to_description`` and the target
    URI from ``description_result_to_description_uri``. The text is written
    UTF-8 encoded through ``write_s3``.
    """
    text = description_result_to_description(result)
    target_uri = description_result_to_description_uri(result)
    bucket, key = parse_s3_uri(target_uri)
    payload = text.encode(encoding='utf-8')
    write_s3(bucket, key, payload)


# ==================================================================
# Augmentation Generation
# ==================================================================

def collect_seeds(seed_dir, representation: ProgramType = ProgramType.GRAPH):
    """Group the programs under ``seed_dir`` by category.

    A program's category is ``file_dir`` applied to its path with trailing
    slashes removed.

    Args:
        seed_dir: Location searched with ``list_all_programs``.
        representation: Program type forwarded to ``list_all_programs``.

    Returns:
        A dict mapping each category to the list of its programs, in the
        order ``list_all_programs`` returned them.
    """
    seed_groups = {}
    for program in list_all_programs(seed_dir, representation):
        category = file_dir(program.rstrip('/'))
        seed_groups.setdefault(category, []).append(program)
    return seed_groups

def augmentation_samples(seed_dir, max_product_samples = 1000, representation: ProgramType = ProgramType.GRAPH):
    """Enumerate ordered pairs of distinct seeds to use as augmentation parents.

    Seeds are grouped with ``collect_seeds``. For every ordered pair of
    categories (including a category paired with itself) all ordered pairs of
    unequal seeds are formed. When one category pair yields more than
    ``max_product_samples`` seed pairs, a random subset of that size is kept.

    Args:
        seed_dir: Location of the seed programs.
        max_product_samples: Upper bound on the pairs kept per category pair.
        representation: Program type forwarded to ``collect_seeds``.

    Returns:
        A flat list of ``(seed_a, seed_b)`` tuples.
    """
    seed_groups = collect_seeds(seed_dir, representation)
    samples = []
    for g1, g2 in itertools.product(seed_groups, repeat=2):
        pairs = [
            (s1, s2)
            for s1, s2 in itertools.product(seed_groups[g1], seed_groups[g2])
            if s1 != s2
        ]
        if len(pairs) > max_product_samples:
            pairs = random.sample(pairs, max_product_samples)
        # extend() appends in place; rebuilding the list with `+` on every
        # iteration made the accumulation quadratic.
        samples.extend(pairs)
    return samples

def augmentation_prompt(parents, representation: ProgramType = ProgramType.GRAPH):
    """Build the chat messages asking a model to derive a new material from parents.

    For every parent its four renders, its program text and its description
    are rendered through ``augmentation_material_template``. The resulting
    sections are joined with blank lines and inserted into the user template.

    Args:
        parents: Iterable of parent material locations.
        representation: ``ProgramType.GRAPH`` selects the graph API
            description and the ``json`` code-fence language; any other value
            selects the code API description and ``python``.

    Returns:
        A two-element list ``[system_message, user_message]`` of
        ``{"role": ..., "content": ...}`` dicts.
    """
    is_graph = representation == ProgramType.GRAPH
    api_description = graph_api_description if is_graph else code_api_description
    lang = 'json' if is_graph else 'python'

    system_message = {"role":"system", "content": format_prompt_openai(
        augmentation_system_prompt_template.format(api_description=api_description)
    )}

    parent_renders = [rendered_urls(p) for p in parents]
    # Load the parent in the requested representation so the program text
    # agrees with `lang` and `api_description` (same two-argument call that
    # synthesize_prompt makes).
    parent_reps = [get_representation(p, representation) for p in parents]
    parent_descriptions = [get_description(p) for p in parents]

    parent_material_info = "\n\n".join(
        augmentation_material_template.format(
            n=i,
            top_right=renders['top_right'],
            top=renders['top'],
            front=renders['front'],
            right=renders['right'],
            lang=lang,
            description=description,
            code=code)
        for i, (renders, code, description)
        in enumerate(zip(parent_renders, parent_reps, parent_descriptions))
    )

    user_message = {
        "role": "user",
        "content": format_prompt_openai(
        augmentation_user_prompt_template.format(materials=parent_material_info, lang=lang))
    }

    return [system_message, user_message]

def extract_program(augmentation_result, representation: ProgramType = ProgramType.GRAPH):
    """Split a model reply into its leading description and fenced program.

    The program starts after the first language-tagged fence (```json for
    ``ProgramType.GRAPH``, ```python otherwise). If no tagged fence exists,
    the first bare ``` fence is used instead. The program ends at the last
    ``` in the reply.

    Args:
        augmentation_result: The reply text.
        representation: Selects which language tag is looked for.

    Returns:
        ``(description, program)``: the stripped text before the opening
        fence, and the stripped text between the opening and the last fence.
        If the reply contains only a single fence, ``program`` is the empty
        string.

    Raises:
        ValueError: If the reply contains no ``` fence at all.
    """
    lang = 'json' if representation == ProgramType.GRAPH else 'python'
    fence = "```"
    tagged_fence = f"{fence}{lang}"

    program_start = augmentation_result.find(tagged_fence)
    if program_start >= 0:
        prefix_len = len(tagged_fence)
    else:
        # No tagged fence: fall back to a bare one. index() raises ValueError
        # when the reply has no fence at all.
        program_start = augmentation_result.index(fence)
        prefix_len = len(fence)

    program_end = augmentation_result.rindex(fence)
    program = augmentation_result[program_start + prefix_len:program_end].strip()
    description = augmentation_result[:program_start].strip()
    return description, program


def synthesize_prompt(parents, representation: ProgramType = ProgramType.DSL, use_renders = False):
    """Build the chat messages asking a model to synthesize from parent programs.

    Each parent is rendered through the parent template (with or without
    renders), the sections are joined with newlines and inserted into the
    matching user template. The system message is the universal system
    template filled with the API description for ``representation``.

    Args:
        parents: Iterable of parent material locations.
        representation: ``ProgramType.GRAPH`` selects the json/graph
            templates; any other value selects the python/DSL ones.
        use_renders: When true, each parent's ``rendered_urls`` are passed to
            the parent template as extra fields and the render-aware
            templates are used.

    Returns:
        A two-element list ``[system_message, user_message]`` of
        ``{'role': ..., 'content': ...}`` dicts.
    """
    lang = 'python'
    lang_template = dsl_code_template
    api_description = code_api_description
    if representation == ProgramType.GRAPH:
        lang = 'json'
        lang_template = graph_code_template
        api_description = graph_api_description

    parent_template = parent_code_and_render_template if use_renders else parent_code_template

    parent_descriptions = []
    for p in parents:
        rep = get_representation(p, representation)
        # Without renders there are no extra template fields. Previously
        # `renders` was left unassigned in that case, so the default
        # use_renders=False raised UnboundLocalError at `**renders`.
        renders = rendered_urls(p) if use_renders else {}
        parent_descriptions.append(parent_template.format(
            lang=lang,
            code=rep,
            **renders
        ))
    parent_descriptions = '\n'.join(parent_descriptions)

    system = universal_system_prompt_template.format(api_description = api_description)

    user_template = synthesize_from_code_and_renders_template if use_renders else synthesize_from_code_template
    user = user_template.format(
        materials = parent_descriptions,
        lang_template = lang_template
    )

    return [
        {'role': 'system', 'content': system},
        {'role': 'user', 'content': user}
    ]


# ======================================================
# Prepare Seeds for S3
# ======================================================
import subprocess
def clean_and_upload_seed_programs(seed_dir: str, sync_dir: str):
    """Strip the main guard from every seed file and sync the result to ``sync_dir``.

    Every file found while walking ``seed_dir`` is truncated at the first
    occurrence of ``if __name__ ==``, stripped of surrounding whitespace and
    written to the same relative location inside a temporary directory. That
    directory is then uploaded with ``aws s3 sync --delete``.

    Args:
        seed_dir: Local directory containing the seed programs.
        sync_dir: Destination passed to ``aws s3 sync``.

    Raises:
        ValueError: If a file does not contain ``if __name__ ==``.
        subprocess.CalledProcessError: If the ``aws s3 sync`` command exits
            with a non-zero status.
    """
    count = 0
    with TempDir() as tmp:
        for dirpath, _dirnames, filenames in os.walk(seed_dir):
            for filename in filenames:
                src = path_append(dirpath, filename)
                # Mirror the file's location relative to seed_dir under tmp.
                dst = path_append(tmp, path_append(dirpath[len(seed_dir):], filename))
                with open(src, 'r') as f:
                    code = f.read()
                try:
                    end = code.index('if __name__ ==')
                except ValueError as err:
                    # Same exception type as before, but name the file.
                    raise ValueError(
                        f'{src} does not contain an "if __name__ ==" block to strip'
                    ) from err
                code = code[:end].strip()
                os.makedirs(file_dir(dst), exist_ok=True)
                with open(dst, 'w') as f:
                    f.write(code)
                count += 1
        # check=True: do not report success below when the sync itself failed.
        subprocess.run(['aws', 's3', 'sync', '--delete', tmp, sync_dir], check=True)
        print(f'Uploaded {count} seeds')

def run_upload_seeds():
    """Command-line entry point for ``clean_and_upload_seed_programs``.

    Parses the required ``-i/--input`` and ``-o/--output`` options from the
    process arguments and passes them on as the seed directory and the sync
    destination.
    """
    cli = argparse.ArgumentParser()
    for short_flag, long_flag in (('-i', '--input'), ('-o', '--output')):
        cli.add_argument(short_flag, long_flag, type=str, required=True)
    options = cli.parse_args()
    clean_and_upload_seed_programs(options.input, options.output)