from .util import *
import pathspec
import pickle
import json
from .templates import *
from tqdm.contrib.concurrent import process_map
import random
from copy import deepcopy, copy
import numpy as np


# Generate a Dataset Given a Directory of Programs

def make_dataset(
        data_location: str, 
        language_description: str, # String or file location
        system_prompt: str, # String or file location
        task_formatters: list, # List of objects that generate task examples for each base example
        program_type: ProgramType = ProgramType.GRAPH,
        ):
    """
    List the programs under ``data_location`` for the given ``program_type``.

    The language description and system prompt are resolved through
    ``get_text`` and the prompt is filled with the description, but the
    resulting prompt and ``task_formatters`` are not consumed yet:
    per-example formatting is not wired in, so only the program listing
    is returned.
    """
    # Resolve the shared templates (done in the same order as the inputs are given).
    description_text = get_text(language_description)
    prompt_text = get_text(system_prompt).format(language_description = description_text)

    # Every program found for this representation; currently returned unformatted.
    return list_all_programs(data_location, program_type)

def format_example(
        example: str,
        program_type: "ProgramType",
        language_description: str,
        system_prompt: str,
        task_formatters: list
) -> dict:
    """
    Generates a collection of task LLM traces with masks for training.

    Each formatter in ``task_formatters`` must provide:
      * ``requirements(program_type)`` -- an iterable of hashable
        ``(name, loader)`` pairs, where ``loader(example)`` produces the data
        stored under ``name``;
      * ``format_example(example, example_data)`` -- an iterable of
        ``(task_name, task_example)`` pairs.

    ``language_description`` and ``system_prompt`` are accepted for interface
    compatibility but are not used by this function.

    Returns a dictionary mapping each task name to the list of task examples
    produced for it, in formatter order.
    """

    # Gather the required data for the example (only do this once):
    # take the union of the requirements across all formatters so that each
    # loader runs a single time even when several formatters ask for it.
    required_data = {
        requirement
        for formatter in task_formatters
        for requirement in formatter.requirements(program_type)
    }
    example_data = {name: loader(example) for name, loader in required_data}

    task_examples = {}
    for formatter in task_formatters:
        # A distinct loop name keeps ``example`` pointing at the input, so every
        # formatter sees the original example rather than a previous output.
        for task_name, task_example in formatter.format_example(example, example_data):
            task_examples.setdefault(task_name, []).append(task_example)

    return task_examples

# ========================================================

# One entry per task: (task name, user-message template, assistant-message template).
# "predict_*" tasks share predict_output_template; "generate_*" tasks share
# generate_output_template.
task_templates = [
    (
        'predict_from_image',
        predict_from_image_template,
        predict_output_template,
    ),
    (
        'predict_from_code',
        predict_from_code_template,
        predict_output_template,
    ),
    (
        'predict_from_image_and_code',
        predict_from_image_and_code_template,
        predict_output_template,
    ),
    (
        'generate_from_image',
        generate_from_image_template,
        generate_output_template,
    ),
    (
        'generate_from_properties',
        generate_from_properties_template,
        generate_output_template,
    ),
    (
        'generate_from_image_and_properties',
        generate_from_image_and_properties_template,
        generate_output_template,
    ),
]

def format_task_example(example_pair):
    """
    Build every (representation, task) conversation for one example.

    ``example_pair`` is an ``(index, example)`` tuple. For each of the two
    program representations (python DSL first, then json graph) and each
    entry of ``task_templates`` this produces one record holding the
    system/user/assistant messages. Returns the records as a flat list.
    """
    idx, example_dir = example_pair

    # Loaded once and shared by every representation and task.
    properties = load_and_format_properties(example_dir)
    view_urls = rendered_urls(example_dir)

    # (language tag, code template, API description, source file) per representation.
    representations = [
        ('python', dsl_code_template, code_api_description, 'code.py'),
        ('json', graph_code_template, graph_api_description, 'graph.json'),
    ]

    records = []
    for lang, lang_template, api_description, file_name in representations:
        source = read_text_file(path_append(example_dir, file_name))

        # The program text is wrapped in a fenced block tagged with its language.
        fields = {
            'code': f'```{lang}\n{source}\n```',
            'lang': lang,
            'lang_template': lang_template,
            'api_description': api_description,
        }

        system_message = universal_system_prompt_template.format(**fields)

        for task_name, user_template, output_template in task_templates:
            user_message = user_template.format(**properties, **fields, **view_urls)
            assistant_message = output_template.format(**properties, **fields, **view_urls)

            records.append({
                'idx': idx,
                'lang': lang,
                'task': task_name,
                'example': example_dir,
                'messages': [
                    {'role': 'system', 'content': system_message},
                    {'role': 'user', 'content': user_message},
                    {'role': 'assistant', 'content': assistant_message},
                ],
            })
    return records

def generate_conversations(
        stats_path: str = '../datasets/dataset_stats.pkl',
        conversation_path: str = '../datasets/task_examples.json'
    ):
    """
    Format every example listed in the dataset stats and save the result.

    Does nothing when ``conversation_path`` already exists. Otherwise the
    pickled stats at ``stats_path`` are read, each entry of its
    ``'example_order'`` is formatted in parallel with ``format_task_example``,
    and the records are grouped as ``{lang: {task: [records]}}`` and written
    to ``conversation_path`` as JSON.
    """
    if os.path.exists(conversation_path):
        return

    with open(stats_path, 'rb') as stats_file:
        example_order = pickle.load(stats_file)['example_order']

    per_example = process_map(
        format_task_example,
        enumerate(example_order),
        total=len(example_order),
        desc='Formatting Examples',
        chunksize=10,
    )

    # Regroup the flat per-example records by language, then by task.
    grouped = {}
    for records in per_example:
        for record in records:
            by_task = grouped.setdefault(record['lang'], {})
            by_task.setdefault(record['task'], []).append(record)

    if conversation_path:
        with open(conversation_path, 'w') as out_file:
            json.dump(grouped, out_file)

# Three-letter codes representing the task types (used when building example ids).
task_codes = {
    'predict_from_image': 'PIm',
    'predict_from_code': 'PCo',
    'predict_from_image_and_code': 'PIC',
    'generate_from_image': 'GIm',
    'generate_from_properties': 'GPr',
    'generate_from_image_and_properties': 'GIP',
}

# One-letter codes representing how a query is built.
query_codes = {
    'test': 'Z',   # zero-shot
    'krn': 'R',    # random neighbors
    'knn': 'N',    # nearest neighbors
    'train': 'T',
}

def make_splits(task_examples, train_test_split_path: str, num_train: int = 2500):
    """
    Return ``(train_indices, test_indices)``, creating and caching them if needed.

    If ``train_test_split_path`` already exists, the split stored there is
    returned unchanged (``num_train`` is ignored in that case). Otherwise the
    indices ``0 .. N-1`` -- where ``N`` is the number of
    ``task_examples['python']['generate_from_image']`` entries -- are shuffled
    with the global ``random`` state, the first ``num_train`` become the
    training set and the remainder the test set. Both lists are sorted and
    written to ``train_test_split_path`` as JSON.

    The default of 2500 comes from the original dataset: 2693 examples, so
    2500 for training and 193 for testing (about 93% / 7%). That also gives
    15000 training examples total, which is within the nova-lite bound
    (5k would need to be removed for nova-pro).
    """
    if os.path.exists(train_test_split_path):
        splits = read_json(train_test_split_path)
        return splits['train'], splits['test']

    # Split into train and test sets
    num_examples = len(task_examples['python']['generate_from_image'])
    shuffled_indices = list(range(num_examples))
    random.shuffle(shuffled_indices)
    train_indices = sorted(shuffled_indices[:num_train])
    test_indices = sorted(shuffled_indices[num_train:])

    # A bare filename has an empty dirname, which os.makedirs rejects.
    parent_dir = os.path.dirname(train_test_split_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(train_test_split_path, 'w') as f:
        json.dump({'train': train_indices, 'test': test_indices}, f)
    return train_indices, test_indices

def to_nova(ex, inference=False, max_new_tokens = 2000):
    """
    Convert one conversation record into the 'messages-v1' layout.

    Works on a deep copy of ``ex``. At most one system message is allowed; if
    present its formatted content is stored under ``'system'``. For inference
    a trailing assistant message is dropped, an ``inferenceConfig`` with
    ``max_new_tokens`` is added, and the result is wrapped as
    ``{"recordId": ..., "modelInput": ...}``.
    """
    record = deepcopy(ex)
    example_id = record['example_id']

    # Separate the system prompt from the conversational turns.
    system_turns = []
    dialogue = []
    for turn in record['messages']:
        if turn['role'] == 'system':
            system_turns.append(turn)
        else:
            dialogue.append(turn)
    assert len(system_turns) <= 1

    payload = {"schemaVersion": 'messages-v1'}
    if len(system_turns) == 1:
        payload['system'] = format_prompt_contents(
            system_turns[0]['content'], llm=LLM.NOVA, inference=inference)

    # The expected answer must not be part of an inference prompt.
    if inference and dialogue[-1]['role'] == 'assistant':
        dialogue = dialogue[:-1]
    for turn in dialogue:
        turn['content'] = format_prompt_contents(
            turn['content'], llm=LLM.NOVA, inference=inference)
    payload['messages'] = dialogue

    if not inference:
        return payload

    payload["inferenceConfig"] = {"max_new_tokens": max_new_tokens}
    return {"recordId": example_id, "modelInput": payload}

def to_openai(ex, inference=False, o1=False):
    """
    Convert one conversation record into an OpenAI chat payload.

    Works on a deep copy of ``ex``. With ``o1`` the 'system' role is renamed
    to 'developer'; without it 'developer' is renamed back to 'system'. For
    training the result is ``{'messages': [...]}``; for inference a trailing
    assistant message is dropped and a request record addressed to
    "/v1/chat/completions" is returned.
    """
    record = deepcopy(ex)
    example_id = record['example_id']
    turns = record['messages']
    if inference and turns[-1]['role'] == 'assistant':
        turns = turns[:-1]

    target_llm = LLM.O1 if o1 else LLM.GPT
    # Which role has to be renamed, and to what, for the selected model family.
    if o1:
        old_role, new_role = 'system', 'developer'
    else:
        old_role, new_role = 'developer', 'system'

    for turn in turns:
        if turn['role'] == old_role:
            turn['role'] = new_role
        turn['content'] = format_prompt_contents(turn['content'], llm=target_llm, inference=inference)

    if not inference:
        return {'messages': turns}

    if o1:
        body = {"model": "o1", "reasoning_effort": "high", "messages": turns}
    else:
        body = {"model": "gpt-4o-mini", "max_completion_tokens": 2000, "messages": turns}
    return {
        "custom_id": example_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": body,
    }

def to_llava(ex, inference=False):
    """
    Convert one conversation record into ``{"id", "images", "messages"}``.

    Works on a deep copy of ``ex``. Every message's content is formatted for
    LLaVA; the ``'source'`` of each image part is moved into the ``images``
    list (in encounter order), leaving a blank image token in the message.
    """
    record = deepcopy(ex)
    turns = record['messages']
    image_sources = []
    example_id = record['example_id']
    # For LlaVa the dataloader handles removing the assistant role for
    # inference, so the final assistant message is intentionally kept here.
    for turn in turns:
        parts = format_prompt_contents(turn['content'], llm = LLM.LLAVA, inference=inference)
        turn['content'] = parts
        for part in parts:
            if part['type'] != 'image':
                continue
            # Removing the key leaves a blank image token behind.
            image_sources.append(part.pop('source'))
    return {"id": example_id, "images": image_sources, "messages": turns}

def update_id(example, qcode):
    """
    Replace the first character of the example's id with ``qcode``, in place.

    Whichever of the keys "recordId", "custom_id" and "id" are present are
    all rewritten. Returns None.
    """
    for id_key in ("recordId", "custom_id", "id"):
        if id_key in example:
            example[id_key] = qcode + example[id_key][1:]

def add_context_examples(example, context_examples, qcode):
    """
    Return a copy of ``example`` with in-context examples placed before its query.

    ``example`` is deep-copied and its id re-coded with ``qcode``. The
    non-system, non-developer messages of every entry in ``context_examples``
    are inserted, in order, directly in front of the first non-system /
    non-developer message of the copy. If the copy has no such message (only
    system/developer messages, or none at all) they are appended at the end.

    The message list is looked up under ``'messages'``, ``'body'`` or
    ``'modelInput'`` (checked in that order).
    The inserted message dicts are the same objects held by
    ``context_examples`` (they are not copied).

    Raises:
        AssertionError: if ``example`` has none of the recognised layouts.
    """
    icl_example = deepcopy(example)
    update_id(icl_example, qcode)
    context_messages = [
        cm
        for c_ex in context_examples
        for cm in c_ex['messages']
        if cm['role'] not in ('system', 'developer')
    ]

    if 'messages' in icl_example:
        messages = icl_example['messages']
    elif 'body' in icl_example:
        messages = icl_example['body']['messages']
    elif 'modelInput' in icl_example:
        messages = icl_example['modelInput']['messages']
    else:
        # Raised explicitly so the check is not stripped under ``python -O``.
        raise AssertionError(
            "example has none of 'messages', 'body' or 'modelInput'")

    # Find the first non-system message to insert in front of; fall back to
    # the end of the list so the system/developer prompt always stays first.
    insert_at = next(
        (pos for pos, m in enumerate(messages)
         if m['role'] not in ('system', 'developer')),
        len(messages),
    )
    messages[insert_at:insert_at] = context_messages

    # NOTE: images belonging to the context examples are not merged here.

    return icl_example

def find_neighbors(neighbors_path, dataset_stats, train_indices, test_indices, k=5):
    """
    Compute (or load) nearest and random training neighbors for every example.

    returns a 4-tuple ``(KNN, KRN, test_knn, test_krn)``:
    KNN: the k nearest training neighbors for each example (one row per example)
    KRN: k randomly chosen training neighbors for each example
    test_knn: the rows of KNN for ``test_indices``
    test_krn: the rows of KRN for ``test_indices``

    The values are indices into the train_indices array, NOT example indices.

    If ``neighbors_path`` exists, its pickled contents are returned as-is and
    ``k`` is ignored; otherwise the result is computed from
    ``dataset_stats['distances']`` and pickled to ``neighbors_path``.
    """
    if os.path.exists(neighbors_path):
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at cache files produced by this function.
        with open(neighbors_path, 'rb') as f:
            neighbors = pickle.load(f)
        return neighbors['KNN'], neighbors['KRN'], neighbors['test_knn'], neighbors['test_krn']

    distances = dataset_stats['distances']
    train_distances = distances[:, train_indices] # column number now indexes into the train_indices array
    ranked = np.argsort(train_distances, axis=-1)

    # Drop the closest entry of every row, since that will be the example itself...
    candidates = ranked[:, 1:].copy()
    # ...except for test rows: a test example is not among the training columns,
    # so its closest entry is a real neighbor and the farthest one is dropped instead.
    candidates[test_indices, :] = ranked[test_indices, :-1]

    KNN = candidates[:, :k]

    # Shuffling the transpose permutes the columns, i.e. the same random
    # selection of neighbor ranks is applied to every row.
    shuffled_columns = candidates.copy().T
    np.random.shuffle(shuffled_columns)
    KRN = shuffled_columns.T[:, :k]

    # A bare filename has an empty dirname, which os.makedirs rejects.
    parent_dir = os.path.dirname(neighbors_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    test_knn = KNN[test_indices]
    test_krn = KRN[test_indices]
    with open(neighbors_path, 'wb') as f:
        pickle.dump({'KNN': KNN, 'KRN': KRN, 'test_knn': test_knn, 'test_krn': test_krn}, f)
    return KNN, KRN, test_knn, test_krn

def add_query_types(conversations, train_indices, test_indices, test_knn, test_krn):
    """
    Split Conversations into train, test, knn, and krn

    For every language ('python', 'json') and every task in ``task_codes``,
    the list ``conversations[lang][task]`` is replaced IN PLACE by a dict
    ``{'test': [...], 'krn': [...], 'knn': [...], 'train': [...]}``.
    'train' holds the examples at ``train_indices``; the other three hold the
    examples at ``test_indices``. For 'knn' / 'krn' the non-leading messages
    of each neighbor are inserted, in neighbor order, between the example's
    first message and the rest of its messages.

    ``test_knn`` / ``test_krn`` have one row per entry of ``test_indices``;
    their values are positions into ``train_indices`` (as produced by
    ``find_neighbors``) and are translated to example indices here.

    Each produced example is a shallow copy extended with 'example_id',
    'split_idx' and 'query_type'. Returns ``conversations``.
    """
    neighbor_tables = {'knn': test_knn, 'krn': test_krn}

    for lang in ['python', 'json']:
        lang_code = 'D' if lang == 'python' else 'G'
        for task_name, task_input_code in task_codes.items():
            examples = conversations[lang][task_name]
            examples_by_query = {}
            for query_type in ['test', 'krn', 'knn', 'train']:
                query_code = query_codes[query_type]
                indices = train_indices if query_type == 'train' else test_indices
                neighbors = neighbor_tables.get(query_type)
                query_examples = []
                for split_idx, example_idx in enumerate(indices):
                    # Shallow copy: 'messages' is rebound below, never mutated,
                    # so the source example is left untouched.
                    example = copy(examples[example_idx])

                    if neighbors is not None:
                        context_messages = []
                        for neighbor_pos in neighbors[split_idx]:
                            # Neighbor values index into train_indices, so map
                            # them back to a real example index first.
                            neighbor = examples[train_indices[neighbor_pos]]
                            # Skip the neighbor's first message; the query's own is kept.
                            context_messages.extend(neighbor['messages'][1:])
                        example['messages'] = (
                            [example['messages'][0]]
                            + context_messages
                            + example['messages'][1:]
                        )

                    example['example_id'] = f'{query_code}{lang_code}{task_input_code}{example_idx:06d}'
                    example['split_idx'] = split_idx
                    example['query_type'] = query_type
                    query_examples.append(example)
                examples_by_query[query_type] = query_examples
            conversations[lang][task_name] = examples_by_query
    return conversations
                        

def generate_splits(
    stats_path: str = '../datasets/dataset_stats.pkl',
    conversation_path: str = '../datasets/task_examples_redo.json',
    train_test_split_path: str = '../datasets/train_test_split.json',
    neighbors_path: str = '../datasets/neighbors.pkl',
):
    """
    Produce the train/test split and the neighbor tables for the dataset.

    Loads the pickled stats and the JSON conversations, then calls
    ``make_splits`` and ``find_neighbors``. The computed values are not
    returned or used further here.
    """
    with open(stats_path, 'rb') as stats_file:
        dataset_stats = pickle.load(stats_file)
    conversations = read_json(conversation_path)

    train_indices, test_indices = make_splits(conversations, train_test_split_path)

    # Results are currently discarded (this function returns None).
    _knn, _krn, _test_knn, _test_krn = find_neighbors(
        neighbors_path, dataset_stats, train_indices, test_indices)

from functools import partial
def format_for_vlms(example):
    """
    Render one example in every supported model format.

    Anything whose ``'query_type'`` is not 'train' is formatted for inference.
    Returns a dict keyed by 'Nova', 'GPT', 'O1' and 'LlaVa'.
    """
    is_inference = example['query_type'] != 'train'

    return {
        'Nova': to_nova(example, inference=is_inference),
        'GPT': to_openai(example, inference=is_inference, o1=False),
        'O1': to_openai(example, inference=is_inference, o1=True),
        'LlaVa': to_llava(example, inference=is_inference),
    }


def submit_tuning_job(model_name, base_model, training_data, output_location, role_arn, timestamp, validation_data=None):
    """
    Create a Bedrock fine-tuning (model customization) job and return its ARN.

    The job is named ``"{model_name}-{timestamp}"`` cut to at most 63
    characters. ``training_data``, ``output_location`` and the optional
    ``validation_data`` are passed as S3 URIs. The job inputs are printed
    before the request is sent. Returns the ``'jobArn'`` field of the response.
    """
    import boto3

    client = boto3.client(service_name='bedrock')

    # Keep the job name to at most 63 characters.
    job_name = f"{model_name}-{timestamp}"[:63]

    # Set parameters
    request = {
        'jobName': job_name,
        'customModelName': model_name,
        'roleArn': role_arn,
        'baseModelIdentifier': base_model,
        'hyperParameters': {
            "epochCount": "2",
            "batchSize": "1",
            "learningRate": "0.00001",
            "learningRateWarmupSteps": "10"
        },
        'trainingDataConfig': {"s3Uri": training_data},
        'outputDataConfig': {"s3Uri": output_location},
        'customizationType': "FINE_TUNING",
    }

    print('Job inputs:')
    reported_fields = (
        'baseModelIdentifier',
        'roleArn',
        'jobName',
        'customModelName',
        'hyperParameters',
        'trainingDataConfig',
        'outputDataConfig',
    )
    for field in reported_fields:
        print(f'{field}: {request[field]}')

    # Create job; the validation set is only attached when one was supplied.
    if validation_data is not None:
        request['validationDataConfig'] = {"validators": [{"s3Uri": validation_data}]}
    response = client.create_model_customization_job(**request)

    return response.get('jobArn')

def run_bedrock_training(
        model_name, 
        training_data, 
        output_location, 
        role_arn = 'arn:aws:iam::537124976905:role/service-role/TrainMetagenV1DSLNovaLite10k', 
        base_model = 'arn:aws:bedrock:us-east-1::foundation-model/amazon.nova-lite-v1:0:300k',
        validation_data = None
    ):
    """
    Submit a tuning job under a timestamped model name and return the job ARN.

    The current local time (``YYYYmmddHHMMSS``) is appended to ``model_name``
    (result cut to at most 63 characters), and that name is also used as the
    sub-folder of ``output_location`` for the job output.
    """
    from datetime import datetime

    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")

    stamped_name = (model_name + '-' + timestamp)[:63]
    output = output_location.rstrip('/') + '/' + stamped_name

    return submit_tuning_job(
        stamped_name, base_model, training_data, output, role_arn, timestamp, validation_data)

def check_status(jobArn):
    """
    Return the current status of a Bedrock model customization job.

    returns InProgress | Completed | Failed | Stopping | Stopped
    """
    # Imported locally, as submit_tuning_job does: boto3 is not among the
    # explicit module-level imports, so relying on a global name could fail.
    import boto3

    bedrock = boto3.client(service_name='bedrock')
    return bedrock.get_model_customization_job(jobIdentifier=jobArn).get('status')

def is_job_complete(jobArn):
    """Return True when the job's status is 'COMPLETE' or 'COMPLETED' (case-insensitive)."""
    normalized_status = check_status(jobArn).upper()
    return normalized_status == 'COMPLETE' or normalized_status == 'COMPLETED'

def get_output_model_arn(jobArn):
    """
    returns the model ARN of the output model

    The value is read from the ``'outputModelArn'`` field of the job
    description; None is returned if that field is absent.
    """
    # Imported locally, as submit_tuning_job does: boto3 is not among the
    # explicit module-level imports, so relying on a global name could fail.
    import boto3

    bedrock = boto3.client(service_name='bedrock')
    response = bedrock.get_model_customization_job(jobIdentifier=jobArn)
    return response.get('outputModelArn')

def run_and_monitor_bedrock_training(
        model_name, 
        training_data, 
        output_location, 
        role_arn = 'arn:aws:iam::537124976905:role/service-role/TrainMetagenV1DSLNovaLite10k', 
        base_model = 'arn:aws:bedrock:us-east-1::foundation-model/amazon.nova-lite-v1:0:300k',
        validation_data = None
    ):
    """
    Submit a tuning job, block until it reaches a terminal state, and return
    the resulting model ARN.

    The job status is polled with an increasing delay (20s, 40s, 80s, 160s,
    then every 300s). Returns the output model ARN when the job completes, or
    None when it ends in a failed or stopped state.
    """
    import time
    print(f'Submitting tuning job for {model_name}')
    jobArn = run_bedrock_training(model_name, training_data, output_location, role_arn, base_model, validation_data)
    print(f'Job ARN: {jobArn}')
    wait_time = 10
    print(f'Monitoring...')
    while True:
        # Fetch the status once per iteration so the completion and failure
        # checks both look at the same value (and ``status`` is always bound).
        status = check_status(jobArn)
        normalized = status.upper()
        if normalized in ('COMPLETE', 'COMPLETED'):
            break
        # 'Stopped' is terminal as well; without this check a stopped job
        # would be polled forever. ('Stopping' is still in transition.)
        if 'FAIL' in normalized or normalized == 'STOPPED':
            print(f"Job failed with status: {status}")
            print(f"Model ARN: {None}")
            return None
        # Exponential back-off, capped at five minutes between checks.
        wait_time = min(wait_time * 2, 300)
        # Print timestamp and status
        print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} - Job - arn: {jobArn} - status: {status}")
        time.sleep(wait_time)

    modelArn = get_output_model_arn(jobArn)
    print(f"Job completed with status: {status}")
    print(f"Model ARN: {modelArn}")
    return modelArn

def enqueue_bedrock_training(configs):
    """
    Run the tuning jobs described by ``configs`` one after another.

    Each config must provide 'model_name', 'training_data',
    'output_location', 'role_arn' and 'base_model'. Every job is run to
    completion before the next one is submitted, and its result is stored
    back into the config under 'modelArn'. Returns the same ``configs`` list.
    """
    total = len(configs)
    print(f'Enqueuing {total} jobs')
    for position, job in enumerate(configs, start=1):
        name = job['model_name']
        data_uri = job['training_data']
        output_uri = job['output_location']
        role = job['role_arn']
        base = job['base_model']
        print(f'Enqueuing job {position}/{total} for {name}')
        job['modelArn'] = run_and_monitor_bedrock_training(name, data_uri, output_uri, role, base)
    return configs
