"""
Run bedrock inference given an arn of the model and a jsonl file with the input data.
Will mask out the final "assistant" message in the input data.
"""

import boto3
import json
import argparse
from tqdm import tqdm
from tqdm.contrib.concurrent import thread_map
from copy import deepcopy
import os
from time import sleep
import time
import logging
from dataclasses import dataclass
import datetime
import random

# Set the root logger's level to INFO so the logging.info() calls in this
# module are not filtered out by the default level.
logging.getLogger().setLevel(logging.INFO)


def DeleteThroughput(arn):
    """Ask Bedrock to delete the provisioned throughput identified by ``arn``.

    Returns True when the delete call completes without raising, and False
    when Bedrock raises any of the service exceptions listed below.
    """
    import boto3
    client = boto3.client(service_name='bedrock')
    try:
        client.delete_provisioned_model_throughput(provisionedModelId=arn)
    except (
        client.exceptions.ValidationException,
        client.exceptions.ResourceNotFoundException,
        client.exceptions.InternalServerException,
        client.exceptions.ThrottlingException,
        client.exceptions.ConflictException,
        client.exceptions.AccessDeniedException,
    ):
        # Every handled service error is reported the same way: not deleted.
        return False
    return True
def FindThroughput(model_name):
    """Find a provisioned throughput for the custom model named ``model_name``.

    The custom model is chosen from those whose name contains ``model_name``:
    an exact name match wins, otherwise the most recently created one is used.

    Returns the provisioned-throughput summary dict (as listed by Bedrock) of
    the first throughput that is in service, or failing that the first one
    that is still being created. Returns None when no custom model matches or
    when the model has no such throughput.
    """
    import boto3
    bedrock = boto3.client(service_name='bedrock')

    # Find the model arn
    models = bedrock.list_custom_models(nameContains=model_name).get('modelSummaries') or []
    if not models:
        # Nothing matched the name filter; there is no model to look up.
        return None
    newest_first = sorted(models, key=lambda m: m.get('creationTime'), reverse=True)
    model = next(
        (m for m in newest_first if m.get('modelName') == model_name),
        newest_first[0],
    )
    model_arn = model.get('modelArn')

    # Check if there is a provisioned throughput for this model
    response = bedrock.list_provisioned_model_throughputs(modelArnEquals=model_arn)
    summaries = response.get('provisionedModelSummaries') or []

    # Prefer a throughput that is already in service, then one being created.
    for has_wanted_status in (ThroughputReady, ThroughputStarting):
        for summary in summaries:
            if has_wanted_status(summary.get('provisionedModelArn')):
                return summary
    return None


def AnyHaveThroughput(*model_arns):
    """Return the first of ``model_arns`` that has any provisioned throughput.

    The ARNs are checked in the order given; None is returned when none of
    them has a provisioned throughput listed.
    """
    import boto3
    client = boto3.client(service_name='bedrock')

    def has_throughput(model_arn):
        listing = client.list_provisioned_model_throughputs(modelArnEquals=model_arn)
        return len(listing.get('provisionedModelSummaries')) > 0

    # next() stops at the first hit, so later ARNs are not queried.
    return next((arn for arn in model_arns if has_throughput(arn)), None)

def GroupByModelArn(configs):
    """Bucket job configs by the value of their 'model' entry.

    Returns a dict mapping each model ARN to the list of configs that use it.
    Keys appear in order of first occurrence and each list keeps the input
    order. Configs without a 'model' key are collected under None.
    """
    by_arn = {}
    for job in configs:
        by_arn.setdefault(job.get('model'), []).append(job)
    return by_arn

def OrderConfigs(configs):
    """Flatten grouped configs into one list, starting with a provisioned model.

    ``configs`` maps model ARN -> list of job configs. If one of the ARNs
    already has a provisioned throughput, its jobs are placed first; all other
    groups follow in the mapping's iteration order.
    """
    head_arn = AnyHaveThroughput(*configs)
    ordered = []
    if head_arn:
        ordered.extend(configs[head_arn])
    for arn, jobs in configs.items():
        if head_arn and arn == head_arn:
            # Already placed at the front.
            continue
        ordered.extend(jobs)
    return ordered

def GetOrCreateThroughput(model_arn):
    """Return a ready provisioned throughput for ``model_arn``, creating one if needed.

    An existing throughput that is in service is preferred, then one that is
    still being created. If neither exists, a new throughput with one model
    unit is created, named after the current timestamp.

    Blocks (via WaitforService) until the chosen throughput is in service.

    Returns:
        (provisioned_model_arn, created_throughput): the ARN string of the
        throughput and whether this call created it.
    """
    import boto3
    bedrock = boto3.client(service_name='bedrock')
    response = bedrock.list_provisioned_model_throughputs(modelArnEquals=model_arn)
    summaries = response.get('provisionedModelSummaries') or []

    # Always work with the ARN string, never the summary dict: the result is
    # passed to WaitforService and later used as the model id for inference.
    ptp_arn = None
    for has_wanted_status in (ThroughputReady, ThroughputStarting):
        for summary in summaries:
            candidate = summary.get('provisionedModelArn')
            if has_wanted_status(candidate):
                ptp_arn = candidate
                break
        if ptp_arn is not None:
            break

    created_throughput = False
    if ptp_arn is None:
        # Create a new provisioned throughput
        pt_name = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        created = bedrock.create_provisioned_model_throughput(
            modelUnits=1,
            provisionedModelName=pt_name,
            modelId=model_arn,
        )
        ptp_arn = created.get('provisionedModelArn')
        created_throughput = True

    WaitforService(ptp_arn)
    return ptp_arn, created_throughput

def WaitforService(ptp):
    """Block until the provisioned throughput ``ptp`` reports status 'InService'.

    The status is polled with a delay that starts at 1 second and doubles
    after every poll, up to 60 seconds.

    Raises:
        RuntimeError: if the throughput reports status 'Failed', since it
            would otherwise be polled forever.
    """
    import time
    wait_time = 1
    max_wait_time = 60
    while True:
        # One status call per iteration; the same value is tested and logged.
        status = ThroughputStatus(ptp)
        if status == 'InService':
            return
        if status == 'Failed':
            raise RuntimeError(f"Provisioned throughput {ptp} is in the Failed state")
        logging.info(f"Waiting for provisioned throughput {ptp} to be ready. Status: {status}")
        time.sleep(wait_time)
        wait_time = min(wait_time * 2, max_wait_time)
def UpdateThroughput(ptp, model_arn):
    """Point the provisioned throughput ``ptp`` at the model ``model_arn``.

    Issues the update request, looks up the model's name, then blocks until
    the throughput is in service again. Returns the model name.
    """
    import boto3
    client = boto3.client(service_name='bedrock')
    update_request = {
        'provisionedModelId': ptp,
        'desiredModelId': model_arn,
    }
    client.update_provisioned_model_throughput(**update_request)

    # The name is resolved before waiting for the update to complete.
    name = GetModelId(model_arn)
    WaitforService(ptp)
    return name

def GetModelId(arn):
    """Return the 'modelName' of the model identified by ``arn``.

    The foundation-model listing is searched for a matching 'modelArn' first;
    if nothing matches, ``arn`` is looked up as a custom model.
    """
    import boto3
    client = boto3.client(service_name='bedrock')

    catalog = client.list_foundation_models().get('modelSummaries')
    matches = [entry for entry in catalog if entry.get('modelArn') == arn]
    model = matches[0] if matches else client.get_custom_model(modelIdentifier=arn)
    return model.get('modelName')

def ModelDetailsFromARN(arn):
    """
    Resolve ``arn`` to what is needed to run inference against it.

    Returns a tuple (model_name, inference_arn, created_throughput):
      * foundation model (matched on 'modelArn' or 'modelId'): ``arn`` itself
        is the inference ARN and no throughput is created;
      * a provisioned throughput is already listed for ``arn``: the first
        listed throughput is used once it is in service;
      * otherwise ``arn`` is looked up as a custom model and a throughput is
        obtained through GetOrCreateThroughput.

    NOTE(review): the throughput listing is filtered with ``modelArnEquals``,
    so it appears to match throughputs attached to the model ``arn`` rather
    than a provisioned-throughput ARN passed in directly -- confirm intent.
    """
    import boto3
    client = boto3.client(service_name='bedrock')

    # Case 1: a foundation model.
    catalog = client.list_foundation_models().get('modelSummaries')
    hits = [
        entry for entry in catalog
        if entry.get('modelArn') == arn or entry.get('modelId') == arn
    ]
    if hits:
        return hits[0].get('modelName'), arn, False

    # Case 2: a provisioned throughput is already listed for this ARN.
    listing = client.list_provisioned_model_throughputs(modelArnEquals=arn)
    existing = listing.get('provisionedModelSummaries')
    if len(existing) > 0:
        first = existing[0]
        backing_model_arn = first.get('modelArn')
        inference_arn = first.get('provisionedModelArn')
        name = client.get_custom_model(modelIdentifier=backing_model_arn).get('modelName')
        WaitforService(inference_arn)
        return name, inference_arn, False

    # Case 3: a custom model that needs a throughput.
    custom = client.get_custom_model(modelIdentifier=arn)
    custom_arn = custom.get('modelArn')
    name = custom.get('modelName')
    inference_arn, created = GetOrCreateThroughput(custom_arn)
    return name, inference_arn, created


def ThroughputStatus(arn):
    """Return the 'status' Bedrock reports for the provisioned throughput ``arn``."""
    import boto3
    client = boto3.client(service_name='bedrock')
    details = client.get_provisioned_model_throughput(provisionedModelId=arn)
    return details.get('status')
def ThroughputUpdating(arn):
    """Return True when the provisioned throughput ``arn`` has status 'Updating'."""
    return ThroughputStatus(arn) == 'Updating'
def ThroughputStarting(arn):
    """Return True when the provisioned throughput ``arn`` has status 'Creating'."""
    return ThroughputStatus(arn) == 'Creating'
def ThroughputReady(arn):
    """Return True when the provisioned throughput ``arn`` has status 'InService'."""
    return ThroughputStatus(arn) == 'InService'

class Timer:
    """Context manager that times the body of its ``with`` block.

    ``start`` is set from time.time() on entry; ``end`` and ``interval``
    (``end - start``) are set on exit. Exceptions raised in the body are not
    suppressed.
    """

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *exc_info):
        finished = time.time()
        self.end = finished
        self.interval = finished - self.start

def ChainedInference(configs, args):
    """Run every job in ``configs``, reusing one provisioned throughput.

    Jobs are grouped by model ARN and ordered so that a model that already has
    a provisioned throughput runs first. The throughput is found or created
    for the first job and re-pointed (UpdateThroughput) whenever the next job
    uses a different model.

    The throughput is deleted afterwards if this run created it or if
    ``args.delete`` is set. The deletion runs in a ``finally`` block so that a
    failure during inference does not leave the throughput provisioned.

    Args:
        configs: list of job dicts; each must have a 'model' key.
        args: parsed command-line arguments (uses ``delete``, and is passed
            on to run_inference).
    """
    if not configs:
        logging.warning("No jobs to run.")
        return

    configs = OrderConfigs(GroupByModelArn(configs))
    # Find or Create throughput for the starting config
    current_arn = configs[0]['model']
    model_id, ptp, created_throughput = ModelDetailsFromARN(current_arn)
    try:
        for config in configs:
            model_arn = config['model']
            if model_arn != current_arn:
                current_arn = model_arn
                model_id = UpdateThroughput(ptp, current_arn)
            run_inference(model_id, config, ptp, args)
    finally:
        if created_throughput or args.delete:
            # Delete the provisioned throughput
            logging.info(f"Deleting provisioned throughput {ptp}")
            if not DeleteThroughput(ptp):
                logging.warning(f"Could not delete provisioned throughput {ptp}")

def run_inference(model_id, config, ptp, args):
    """Run one job, in parallel or serially depending on ``args.parallel``.

    Args:
        model_id: model name recorded in the output records.
        config: the job's configuration dict.
        ptp: ARN used as the model id when invoking Bedrock.
        args: parsed command-line arguments (uses ``parallel`` and ``threads``).
    """
    if args.parallel:
        parallel_inference(model_id, config, ptp, args.threads)
    else:
        # serial_inference takes (model_id, config, ptp); passing ``args`` as a
        # fourth positional argument raised TypeError.
        serial_inference(model_id, config, ptp)



def invoke_bedrock(input, model_arn):
    """Invoke ``model_arn`` with the request ``input`` and return the parsed response.

    ``input`` is serialized to JSON once, before any attempt, so an
    unserializable request fails immediately instead of being retried.

    Failed invocations are retried with a delay that grows by one second per
    attempt, capped at 60 seconds. Errors whose service error code shows the
    request itself is at fault are re-raised instead of retried, because
    repeating the same request cannot succeed.

    Returns:
        The response body decoded from JSON.
    """
    client = boto3.client(
        "bedrock-runtime",
        region_name="us-east-1",
    )
    non_retryable_codes = ('ValidationException', 'AccessDeniedException')
    max_delay = 60
    body = json.dumps(input)

    attempt = 0
    while True:
        try:
            response = client.invoke_model(
                modelId=model_arn,
                body=body
            )
            break
        except Exception as e:
            # Service errors carry their code in a ``response`` dict; other
            # exception types have no code and are treated as retryable.
            error_info = getattr(e, 'response', None)
            code = None
            if isinstance(error_info, dict):
                code = error_info.get('Error', {}).get('Code')
            if code in non_retryable_codes:
                logging.error(f'Non-retryable error from Bedrock: {e}')
                raise
            attempt += 1
            delay = min(attempt, max_delay)
            logging.warning(f'Exception raised: {e}\nRetrying in {delay} seconds')
            time.sleep(delay)
    modelOutput = json.loads(response["body"].read())
    return modelOutput


def parse_arguments():
    """Parse the command line and return the argparse namespace.

    Arguments:
        configpath      positional, path to the configuration file
        --parallel/-p   flag, run inference in parallel
        --threads/-t    int, worker threads for parallel inference (default 2)
        --delete/-d     flag, delete the provisioned throughput afterwards
    """
    cli = argparse.ArgumentParser(description="Run Bedrock inference.")
    cli.add_argument('configpath', type=str, help='Path to the configuration file.')

    # Both switches are plain on/off flags that default to off.
    flag = {'action': 'store_true', 'default': False}
    cli.add_argument(
        "--parallel", '-p',
        help="Run inference in parallel.",
        **flag,
    )
    cli.add_argument(
        '--threads', '-t',
        type=int,
        default=2,
        help='Number of threads to use for parallel inference. Default is 2.',
    )
    cli.add_argument(
        '--delete', '-d',
        help='Delete the provisioned throughput after inference, even if not provisioned by this process.',
        **flag,
    )
    return cli.parse_args()


def run_example(packed_example):
    """Run one input record through the model and build its output record.

    ``packed_example`` is a tuple (example, model_arn, model_name, model_id):
        example: input record; its 'modelInput' entry is the request to send
        model_arn: ARN passed to Bedrock as the model id
        model_name: human readable name for the model
        model_id: full name of model

    The request is sent with 'schemaVersion' set to 'messages-v1' and with a
    trailing "assistant" message removed. The caller's record is not modified.

    Returns a dict with the original record ('modelInput'), the model's
    response ('modelOutput'), and the two names ('model', 'modelId').
    """
    example, model_arn, model_name, model_id = packed_example
    complete_example = deepcopy(example)

    # Build the request on a shallow copy so the caller's record stays intact.
    request = dict(example['modelInput'])
    request['schemaVersion'] = 'messages-v1'
    messages = request['messages']
    # Mask out the final "assistant" message in the input data
    if messages and messages[-1]['role'] == 'assistant':
        request['messages'] = messages[:-1]

    # Invoke the model and extract the response body.
    modelOutput = invoke_bedrock(request, model_arn)
    return {
        'modelInput': complete_example,
        'modelOutput': modelOutput,
        'model': model_name,
        'modelId': model_id,
    }

def get_random_temperature():
    """
    Pick a random temperature in steps of one tenth.

    A random integer from 0 to 10 is divided by ten; a result of 0 is replaced
    by 0.00001, so the returned value lies between 0.00001 and 1 inclusive.
    """
    tenths = random.randint(0, 10)
    return tenths / 10.0 if tenths else 0.00001

def parallel_inference(model_id, config, ptp, threads):
    """Run one job's examples through the model using a thread pool.

    Reads the JSONL ``config['input_file']`` and appends one JSON line per
    example to ``config['output_file']``. If the output file already exists,
    the first N examples are skipped, where N is the number of lines already
    in it, so an interrupted run can be resumed.

    Results are written and flushed as they become available (in input order)
    rather than after the whole batch has finished, so the resume logic has
    something to resume from.

    ``config['temperature']`` (optional) overrides the sampling temperature:
    'random' draws one per example, anything else is parsed as a float, and 0
    is replaced by 0.00001.

    Args:
        model_id: model name recorded in each output record.
        config: job dict with 'input_file', 'output_file', 'model_name' and
            optionally 'temperature'.
        ptp: ARN passed to Bedrock as the model id.
        threads: number of worker threads.
    """
    from concurrent.futures import ThreadPoolExecutor

    print(f"Running in parallel with {threads} threads")
    input_file = config['input_file']
    output_file = config['output_file']

    logging.info(f"Input file: {input_file}")
    logging.info(f"Output file: {output_file}")
    logging.info(f"Model ID: {model_id}")
    with open(input_file, 'r') as f:
        # Blank lines (e.g. a trailing newline) are not records.
        examples = [json.loads(line) for line in f if line.strip()]

    mode = 'w'
    if os.path.exists(output_file):
        with open(output_file, 'r') as f:
            n = sum(1 for _ in f)
        examples = examples[n:]
        mode = 'a'
        print(f'{n} examples already processed, starting from there')

    if 'temperature' in config:
        requested = config['temperature']
        if requested == 'random':
            logging.info('Temperure override, adding random temperatures to input examples.')
            temperatures = [get_random_temperature() for _ in examples]
        else:
            try:
                temp = float(requested)
            except (TypeError, ValueError):
                logging.error(f'Could not convert temperature "{requested}" to a float, aborting.')
                # Exit with a failure status; a bare exit() reports success.
                raise SystemExit(1)
            if temp == 0:
                logging.info('Temperature = 0 chosen, but 0.00001 is the minimum allowed. Overriding.')
                temp = 0.00001
            logging.info(f'Temperature override, adding {temp} temperature to input examples')
            temperatures = [temp] * len(examples)

        for example, temp in zip(examples, temperatures):
            # Kept on the outer record so the output file format is unchanged.
            example['inferenceConfig'] = {'temperature': temp}
            # Only 'modelInput' is sent to the model, so the override has to
            # be placed there to have any effect on the request.
            example['modelInput'].setdefault('inferenceConfig', {})['temperature'] = temp

    num_examples = len(examples)
    model_name = config['model_name']
    packed_examples = [(example, ptp, model_name, model_id) for example in examples]

    # dirname is '' for a bare file name, and os.makedirs('') raises.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_file, mode) as f, ThreadPoolExecutor(max_workers=threads) as pool:
        # pool.map yields results lazily and in input order, which keeps the
        # output lines aligned with the input lines for resuming.
        results = pool.map(run_example, packed_examples)
        for output in tqdm(results, desc="Invoking Model...", total=num_examples):
            f.write(json.dumps(output) + '\n')
            f.flush()

def main():
    """Entry point: parse the command line, load the config file, run its jobs."""
    cli_args = parse_arguments()
    with open(cli_args.configpath, 'r') as config_file:
        settings = json.load(config_file)
    jobs = settings['jobs']
    ChainedInference(jobs, cli_args)


def serial_inference(model_id, config, ptp, args=None):
    """Run one job's examples through the model one at a time.

    Reads the JSONL ``config['input_file']`` and appends one JSON line per
    example to ``config['output_file']``. If the output file already exists,
    the first N examples are skipped, where N is the number of lines already
    in it, so an interrupted run can be resumed. Each line is flushed as soon
    as it is written.

    ``config['temperature']`` (optional) overrides the sampling temperature:
    'random' draws one per example, anything else is parsed as a float, and 0
    is replaced by 0.00001.

    Args:
        model_id: model name recorded in each output record.
        config: job dict with 'input_file', 'output_file', 'model_name' and
            optionally 'temperature'.
        ptp: ARN passed to Bedrock as the model id.
        args: unused; accepted so that callers passing the parsed
            command-line arguments as a fourth argument keep working.
    """
    logging.info("Running in serial")
    input_file = config['input_file']
    output_file = config['output_file']
    model_name = config['model_name']
    logging.info(f"Input file: {input_file}")
    logging.info(f"Output file: {output_file}")
    logging.info(f"Model ID: {model_id}")
    with open(input_file, 'r') as f:
        # Blank lines (e.g. a trailing newline) are not records.
        examples = [json.loads(line) for line in f if line.strip()]

    mode = 'w'
    if os.path.exists(output_file):
        with open(output_file, 'r') as f:
            n = sum(1 for _ in f)
        examples = examples[n:]
        mode = 'a'
        print(f'{n} examples already processed, starting from there')

    if 'temperature' in config:
        requested = config['temperature']
        if requested == 'random':
            logging.info('Temperure override, adding random temperatures to input examples.')
            temperatures = [get_random_temperature() for _ in examples]
        else:
            try:
                temp = float(requested)
            except (TypeError, ValueError):
                logging.error(f'Could not convert temperature "{requested}" to a float, aborting.')
                # Exit with a failure status; a bare exit() reports success.
                raise SystemExit(1)
            if temp == 0:
                logging.info('Temperature = 0 chosen, but 0.00001 is the minimum allowed. Overriding.')
                temp = 0.00001
            logging.info(f'Temperature override, adding {temp} temperature to input examples')
            temperatures = [temp] * len(examples)

        for example, temp in zip(examples, temperatures):
            # Kept on the outer record so the output file format is unchanged.
            example['inferenceConfig'] = {'temperature': temp}
            # Only 'modelInput' is sent to the model, so the override has to
            # be placed there to have any effect on the request.
            example['modelInput'].setdefault('inferenceConfig', {})['temperature'] = temp

    # dirname is '' for a bare file name, and os.makedirs('') raises.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_file, mode) as f:
        for example in tqdm(examples):
            # Same masking, invocation and output record as the parallel path.
            output = run_example((example, ptp, model_name, model_id))
            f.write(json.dumps(output) + '\n')
            f.flush()

# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
