"""
CAID/LLuMi Language Server for Metagen
"""

from .benchmarks import code_api_description, universal_system_prompt_template
#from .templates import code_api_description, universal_system_prompt_template
from argparse import ArgumentParser
from multiprocessing.connection import Listener
#import logging
from .util import exec_with_return
import io
import traceback
import contextlib
import sys
import boto3
import json

def _error_reply(message: str) -> dict:
    """Build the ``error`` reply sent back to a client."""
    return {'type': 'error', 'message': message}


def _handle_message(msg, model_id: str, client) -> dict:
    """Validate one client command, dispatch it, and return the reply dict.

    Args:
        msg: Object received from the client; expected to be a dict with a
            ``type`` key plus the fields that command type requires.
        model_id: Bedrock model id / ARN used for ``chat`` commands.
        client: boto3 ``bedrock-runtime`` client used for ``chat`` commands.

    Returns:
        The reply to send. ``{'type': 'close'}`` is returned only for a
        ``close`` command and tells the caller to end the connection.
    """
    if not isinstance(msg, dict):
        return _error_reply('Command must be a dictionary.')
    if 'type' not in msg:
        return _error_reply('Command must contain a "type".')

    message_type = msg['type']
    if message_type == 'close':
        return {'type': 'close'}
    if message_type == 'describe':
        return {'type': 'describe', 'description': describe_language()}
    if message_type == 'execute':
        if 'code' not in msg:
            return _error_reply('execute commands must contain a "code".')
        return {'type': 'execute', 'result': execute_code(msg['code'])}
    if message_type == 'chat':
        if 'messages' not in msg:
            return _error_reply('chat commands must contain "messages".')
        messages = msg['messages']
        if not isinstance(messages, list):
            return _error_reply('The "messages" field must be a list.')
        if not messages:
            return _error_reply('The "messages" list cannot be empty.')
        return {
            'type': 'chat',
            'response': execute_chat(messages, model_id, client),
        }
    return _error_reply(f'Unknown command type "{message_type}"')


def _send_reply(conn, reply: dict) -> bool:
    """Send *reply* on *conn*; return False if the client has gone away."""
    try:
        conn.send(reply)
    except (EOFError, OSError):
        return False
    except Exception as e:
        # Connection.send pickles the whole object before writing, so a
        # serialization failure leaves the stream usable for an error reply.
        traceback.print_exc(file=sys.stderr)
        try:
            conn.send(_error_reply(f'Could not serialize reply: {e}'))
        except (EOFError, OSError):
            return False
    return True


def _serve_connection(conn, model_id: str, client) -> None:
    """Answer commands on *conn* until the client sends ``close`` or disconnects."""
    while True:
        try:
            msg = conn.recv()
        except (EOFError, OSError):
            # Client disconnected without sending 'close'; this used to crash
            # the whole server.
            return
        try:
            reply = _handle_message(msg, model_id, client)
        except Exception as e:
            # A failing handler (e.g. a Bedrock error) is reported to the
            # client instead of taking the server down.
            traceback.print_exc(file=sys.stderr)
            reply = _error_reply(f'Internal error: {e}')
        if not _send_reply(conn, reply):
            return
        if reply['type'] == 'close':
            return


def run_server():
    """Parse command-line options and serve clients forever.

    Command line:
        port          TCP port to listen on (optional, default 6027).
        --host        Interface to bind (default '0.0.0.0').
        --arn / -a    Bedrock model ARN used for ``chat`` commands.
        --authkey     Optional shared secret. When given, clients must
                      connect with the same authkey.

    Clients are served one at a time. The process runs until it is
    interrupted.
    """
    from multiprocessing import AuthenticationError

    parser = ArgumentParser()
    # nargs='?' makes the positional optional so its default can apply.
    parser.add_argument('port', type=int, nargs='?', default=6027)
    parser.add_argument('--host', type=str, default='0.0.0.0')
    parser.add_argument('--arn', '-a', type=str, default="arn:aws:bedrock:us-east-1:537124976905:provisioned-model/oikph7kc38le")
    parser.add_argument(
        '--authkey', type=str, default=None,
        help='shared secret clients must present (recommended: execute '
             'commands run arbitrary Python in this process)')
    args = parser.parse_args()

    client = boto3.client(
        "bedrock-runtime",
        region_name="us-east-1",
    )
    model_id = args.arn

    address = (args.host, args.port)
    authkey = args.authkey.encode('utf-8') if args.authkey else None
    # SECURITY: without --authkey, anyone who can reach host:port can run
    # arbitrary code through 'execute' commands.
    listener = Listener(address, authkey=authkey)
    print(f'Listening to {args.host} on port {args.port}', file=sys.stderr)
    while True:
        try:
            conn = listener.accept()
        except (AuthenticationError, EOFError, ConnectionError) as e:
            # A client failing or abandoning the handshake must not stop the server.
            print(f'rejected connection: {e!r}', file=sys.stderr)
            continue
        print(f'connection accepted from {listener.last_accepted}', file=sys.stderr)
        try:
            _serve_connection(conn, model_id, client)
        finally:
            conn.close()  # idempotent; also covers the 'close' command path
        print(f'disconnected from {listener.last_accepted}', file=sys.stderr)
            

# Data-URL image subtypes Nova accepts; any other subtype falls back to 'png'
# (the previous hard-coded format).
_NOVA_IMAGE_FORMATS = frozenset({'png', 'jpeg', 'gif', 'webp'})


def _to_nova_image(data_url: str) -> dict:
    """Convert a ``data:image/<fmt>;base64,<payload>`` string to a Nova image block.

    The format comes from the data-URL header when it names a Nova-supported
    format ('jpg' is normalised to 'jpeg'). Otherwise it defaults to 'png'.
    A value without a comma is treated as a bare base64 payload.
    """
    header, sep, payload = data_url.partition(',')
    if not sep:
        header, payload = '', header
    image_format = 'png'
    prefix = 'data:image/'
    if header.startswith(prefix):
        subtype = header[len(prefix):].split(';', 1)[0].strip().lower()
        if subtype == 'jpg':
            subtype = 'jpeg'
        if subtype in _NOVA_IMAGE_FORMATS:
            image_format = subtype
    return {
        'image': {
            'format': image_format,
            'source': {'bytes': payload.strip()},
        }
    }


def _to_nova_content(content) -> list:
    """Convert one message's content to a list of Nova content blocks.

    Input items look like ``{'type': 'text'|'image', 'value': ...}``. Items of
    any other or missing type are skipped. A bare string is treated as one
    text item.
    """
    if isinstance(content, str):
        content = [{'type': 'text', 'value': content}]
    nova_content = []
    for item in content:
        if not isinstance(item, dict):
            continue
        item_type = item.get('type')
        if item_type == 'text':
            nova_content.append({'text': item['value']})
        elif item_type == 'image':
            nova_content.append(_to_nova_image(item['value']))
    return nova_content


def execute_chat(messages: list, MODEL_ID: str, client,
                 max_new_tokens: int = 2000) -> "str | dict":
    """Send a conversation to a Bedrock Nova model and return its reply text.

    Args:
        messages: Conversation in this server's format: dicts with ``role``
            and ``content``. ``content`` is a list of
            ``{'type': 'text'|'image', 'value': ...}`` items. Image values are
            ``data:image/...;base64,...`` strings. Only ``user`` and
            ``assistant`` messages are forwarded. Everything else, including
            ``dsl`` results from ``execute_code`` and non-dict entries, is
            dropped.
        MODEL_ID: Bedrock model id or ARN passed to ``invoke_model``.
        client: boto3 ``bedrock-runtime`` client.
        max_new_tokens: Generation limit sent in ``inferenceConfig``.

    Returns:
        The text of the first content block of the model's reply (a ``str``).
        If no user/assistant message survives filtering, it returns an
        assistant message dict saying so instead.
        NOTE(review): the two return shapes differ; callers must handle both.

    Raises:
        Whatever ``client.invoke_model`` raises. KeyError/IndexError if the
        response lacks ``output.message.content[0].text``.
    """
    filtered_messages = [
        {'role': message['role'], 'content': message['content']}
        for message in messages
        if isinstance(message, dict)
        and 'role' in message
        and 'content' in message
        and message['role'] in ('user', 'assistant')
    ]
    if not filtered_messages:
        return {'role': 'assistant', 'content': [{'type': 'text', 'value': "No valid messages provided."}]}

    # NOTE(review): 'dsl' messages (execute_code results) are dropped by the
    # filter above. If the model should see them, map them to a user turn here.

    # Input:  {'role': ..., 'content': [{'type': 'text', 'value': '...'},
    #                                   {'type': 'image', 'value': 'data:image/png;base64,...'}]}
    # Nova:   {'role': ..., 'content': [{'text': '...'},
    #                                   {'image': {'format': 'png', 'source': {'bytes': '...'}}}]}
    # TODO: use a templating system to adapt to other model families.
    formatted_messages = [
        {'role': message['role'], 'content': _to_nova_content(message['content'])}
        for message in filtered_messages
    ]

    system_prompt = universal_system_prompt_template.format(api_description=code_api_description)

    native_request = {
        'messages': formatted_messages,
        'schemaVersion': 'messages-v1',
        'inferenceConfig': {'max_new_tokens': max_new_tokens},
        'system': [{'text': system_prompt}],  # multiple system entries are allowed
    }

    response = client.invoke_model(modelId=MODEL_ID, body=json.dumps(native_request))
    model_output = json.loads(response["body"].read())
    return model_output['output']['message']['content'][0]['text']

def _dsl_text_result(text: str) -> dict:
    """Wrap *text* in the ``dsl``-role message shape returned to clients."""
    return {'role': 'dsl', 'content': [{'type': 'text', 'value': text}]}


def _dsl_error_result(exc: BaseException) -> dict:
    """Format *exc* and the traceback being handled as a dsl error message.

    Must be called from inside an ``except`` block so ``format_exc`` sees it.
    """
    return _dsl_text_result(f"Error: {str(exc)}\n{traceback.format_exc()}")


def execute_code(code: str) -> dict:
    """Execute a code snippet and return a result message for the client.

    SECURITY: ``code`` arrives from a network client and is executed via
    ``exec_with_return`` in this process with no sandbox. Only expose the
    server to trusted clients.

    The result is chosen in this order:
      1. If executing the snippet raises, return a dsl error message with the
         traceback.
      2. If the snippet defines ``make_structure``, call it and return
         ``._repr_llm_()`` of the result.
      3. Otherwise, if ``exec_with_return`` produced a non-None value, return
         that value's ``_repr_llm_()`` when it has one, else its ``repr()``.
      4. Otherwise, return everything the snippet printed to stdout.
    Exceptions in steps 2-3 are also reported as dsl error messages. A
    snippet calling ``exit()`` (SystemExit) is reported rather than stopping
    the server.

    Returns:
        A ``{'role': 'dsl', 'content': [...]}`` dict, or whatever
        ``_repr_llm_()`` returns. NOTE(review): that is presumably the same
        message shape; confirm.
    """
    output_buffer = io.StringIO()
    namespace = {}  # receives the snippet's top-level definitions

    try:
        with contextlib.redirect_stdout(output_buffer):
            return_value = exec_with_return(code, namespace)
    except (Exception, SystemExit) as e:
        return _dsl_error_result(e)

    try:
        if 'make_structure' in namespace:
            # Keep the factory's prints out of the server's own stdout.
            with contextlib.redirect_stdout(output_buffer):
                structure = namespace['make_structure']()
            return structure._repr_llm_()
        if return_value is not None:
            repr_llm = getattr(return_value, '_repr_llm_', None)
            if repr_llm is None:
                return _dsl_text_result(repr(return_value))
            return repr_llm()
    except (Exception, SystemExit) as e:
        return _dsl_error_result(e)

    return _dsl_text_result(output_buffer.getvalue())

def describe_language() -> str:
    """Return the code API description sent to clients for ``describe`` commands."""
    description = code_api_description
    return description