#include "vm.h"

#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "value.h"
#include "memory.h"
#include "object.h"
#include "debug.h"
#include "compiler.h"

/* Empties the value stack by pointing stackTop back at the first slot. */
static void resetStack(VM* vm) {
    vm->stackTop = &vm->stack[0];
}

/*
 * Puts a VM into its initial state: empty value stack, no call frames,
 * no tracked objects, and freshly initialised globals/strings tables.
 */
void initVM(VM* vm) {
    resetStack(vm);
    vm->frameCount = 0;
    vm->objects = NULL;

    initTable(&vm->globals);
    initTable(&vm->strings);
}




/*
 * Tears down a VM: empties the value stack, releases the globals and
 * strings tables, then hands the VM to freeObjects().
 */
void freeVM(VM* vm) {
    resetStack(vm);
    freeTable(&vm->globals);
    freeTable(&vm->strings);
    // NOTE(review): presumably walks vm->objects and frees each entry — confirm in memory.c.
    freeObjects(vm);
}

/*
 * Stores `value` in the slot at stackTop and advances stackTop by one.
 * NOTE(review): no capacity check is performed here — confirm the stack
 * size is enforced elsewhere.
 */
void push(VM* vm, Value value) {
    *vm->stackTop = value;
    vm->stackTop++;
}

/* Moves stackTop back by one slot and returns the value stored there. */
Value pop(VM* vm) {
    vm->stackTop--;
    return *vm->stackTop;
}

/*
 * Returns the value `distance` slots below the top of the stack without
 * removing it; distance 0 is the most recently pushed value.
 */
static Value peek(VM* vm, int distance) {
    return vm->stackTop[-1 - distance];
}

/*
 * Reports a runtime error and terminates the process.
 *
 * Writes the printf-style message described by `format` (plus a newline) to
 * stderr. If a call frame is active, also prints the source line of the
 * instruction that was executing. This function does not return: it ends the
 * process with EXIT_FAILURE.
 *
 * vm      VM whose topmost call frame identifies the failing instruction.
 * format  printf-style format string, followed by its arguments.
 */
static void runtimeError(VM* vm, const char* format, ...) {
    va_list args;
    va_start(args, format);
    vfprintf(stderr, format, args);
    va_end(args);
    fputc('\n', stderr);

    // Without an active frame there is no instruction to attribute the error
    // to; indexing frames[frameCount - 1] would read before the array.
    if (vm->frameCount > 0) {
        CallFrame* frame = &vm->frames[vm->frameCount - 1];
        // ip has already moved past the byte being executed, hence the -1.
        size_t instruction = (size_t)(frame->ip - frame->function->chunk->code - 1);
        int line = frame->function->chunk->lines[instruction];
        fprintf(stderr, "[line %d] in script\n", line);
    }
    exit(EXIT_FAILURE);
}

/*
 * Pops two string operands (right operand first, then left), builds the
 * NUL-terminated text "left + right" in a scratch buffer, and pushes the
 * string object produced by copyString() for that text.
 *
 * NOTE(review): the scratch buffer is not released in this function. If
 * copyString() duplicates its input rather than taking ownership, the buffer
 * leaks — confirm against copyString()'s contract.
 */
static void concatenate(VM* vm) {
    ObjString* right = AS_STRING(pop(vm));
    ObjString* left = AS_STRING(pop(vm));

    char* joined = ALLOCATE(char, left->length + right->length + 1);
    memcpy(joined, left->chars, left->length);
    memcpy(&joined[left->length], right->chars, right->length);
    joined[left->length + right->length] = '\0';

    ObjString* combined = copyString(joined, left->length + right->length);
    push(vm, OBJ_VAL(combined));
}

/*
 * Begins a call to `function` by pushing a new call frame.
 *
 * The new frame starts executing at the first byte of the function's chunk,
 * and its slot window begins argCount + 1 values below stackTop (the slot
 * just beneath the arguments).
 *
 * Returns true when the frame was pushed. When argCount differs from the
 * function's arity, the mismatch is reported via runtimeError() and the
 * result is false.
 *
 * NOTE(review): nothing here checks that vm->frames has room for another
 * frame — confirm the capacity is enforced elsewhere.
 */
static bool call(VM* vm, ObjFunction* function, uint8_t argCount) {
    if (function->arity == argCount) {
        CallFrame* callee = &vm->frames[vm->frameCount];
        vm->frameCount++;

        callee->function = function;
        callee->ip = function->chunk->code;
        callee->slots = vm->stackTop - argCount - 1;
        return true;
    }

    runtimeError(vm, "Expected %d arguments but got %d.", function->arity, argCount);
    return false;
}

/*
 * Invokes `callee` with `argCount` arguments already on the stack.
 *
 * Function objects are dispatched to call(). Any other value is not callable:
 * the failure is reported through runtimeError() so the user sees why the
 * call was rejected, instead of the interpreter stopping without a message.
 *
 * Returns true if a call frame was set up, false otherwise.
 */
static bool callValue(VM* vm, Value callee, uint8_t argCount) {
    if (IS_OBJ(callee)) {
        switch (OBJ_TYPE(callee)) {
            case OBJ_FUNCTION: {
                return call(vm, AS_FUNCTION(callee), argCount);
            }
            default: {
                // Object kinds without call support fall through to the error below.
                break;
            }
        }
    }
    runtimeError(vm, "Can only call functions.");
    return false;
}

/*
 * Bytecode dispatch loop.
 *
 * Starts in the topmost call frame and executes instructions until OP_RETURN
 * is reached (INTERPRET_OK) or an instruction fails (INTERPRET_RUNTIME_ERROR).
 * Operand-type failures are reported through runtimeError(); an unrecognised
 * opcode is reported on stdout.
 */
static InterpretResult run(VM* vm) {
    CallFrame* frame = &vm->frames[vm->frameCount - 1];
    // Reads the byte at ip and advances ip past it.
    #define READ_BYTE() (*frame->ip++)
    // Reads a 16-bit operand stored high byte first and advances ip past it.
    #define READ_SHORT() \
        (frame->ip += 2, \
        (uint16_t)((frame->ip[-2] << 8) | frame->ip[-1]))
    // Reads a one-byte index and yields that entry of the current function's constants.
    #define READ_CONSTANT() \
        (frame->function->chunk->constants.values[READ_BYTE()])
    // Reads a constant and views it as a string object.
    #define READ_STRING() AS_STRING(READ_CONSTANT())
    // Pops two numeric operands (right first), applies `op`, pushes the result.
    #define BINARY_OP(op) \
        do { \
            if (!IS_NUMBER(peek(vm, 0)) || !IS_NUMBER(peek(vm, 1))) { \
                runtimeError(vm, "Operands must be numbers."); \
                return INTERPRET_RUNTIME_ERROR; \
            } \
            double b = AS_NUMBER(pop(vm)); \
            double a = AS_NUMBER(pop(vm)); \
            push(vm, NUMBER_VAL(a op b)); \
        } while (0)

    for (;;) {
        #ifdef DEBUG_TRACE_EXECUTION
        // Dump the whole value stack, then disassemble the instruction about to run.
        printf("          ");
        for (Value* slot = vm->stack; slot < vm->stackTop; slot++) {
            printf("[ ");
            printValue(*slot);
            printf(" ]");
        }
        printf("\n");
        disassembleInstruction(frame->function->chunk, (int)(frame->ip - frame->function->chunk->code));
        #endif
        uint8_t instruction = READ_BYTE();
        switch (instruction) {
            case OP_CONSTANT: {
                Value constant = READ_CONSTANT();
                push(vm, constant);
                break;
            }
            case OP_RETURN: {
                // NOTE(review): this ends the whole run without popping the
                // current call frame, so a return inside a called function
                // also stops the interpreter — confirm this matches what the
                // compiler emits.
                printValue(pop(vm));
                printf("\n");
                return INTERPRET_OK;
            }
            case OP_NEGATE: {
                if (!IS_NUMBER(peek(vm, 0))) {
                    runtimeError(vm, "Operand must be a number.");
                    return INTERPRET_RUNTIME_ERROR;
                }
                push(vm, NUMBER_VAL(-AS_NUMBER(pop(vm))));
                break;
            }
            case OP_ADD: {
                // '+' is overloaded: string concatenation or numeric addition.
                if (IS_STRING(peek(vm, 0)) && IS_STRING(peek(vm, 1))) {
                    concatenate(vm);
                } else if (IS_NUMBER(peek(vm, 0)) && IS_NUMBER(peek(vm, 1))) {
                    double b = AS_NUMBER(pop(vm));
                    double a = AS_NUMBER(pop(vm));
                    push(vm, NUMBER_VAL(a + b));
                } else {
                    runtimeError(vm, "Operands must be two numbers or two strings.");
                    return INTERPRET_RUNTIME_ERROR;
                }
                break;
            }
            case OP_SUBTRACT: {
                BINARY_OP(-);
                break;
            }
            case OP_MULTIPLY: {
                BINARY_OP(*);
                break;
            }
            case OP_DIVIDE: {
                BINARY_OP(/);
                break;
            }
            case OP_POP: {
                pop(vm);
                break;
            }
            case OP_GET_LOCAL: {
                // Locals are addressed relative to the current frame's slot window.
                uint8_t slot = READ_BYTE();
                push(vm, frame->slots[slot]);
                break;
            }
            case OP_GET_GLOBAL: {
                ObjString* name = READ_STRING();
                push(vm, getGlobal(vm, name));
                break;
            }
            case OP_SET_LOCAL: {
                // The assigned value stays on the stack (peek, not pop).
                uint8_t slot = READ_BYTE();
                frame->slots[slot] = peek(vm, 0);
                break;
            }
            case OP_SET_GLOBAL: {
                ObjString* name = READ_STRING();
                setGlobal(vm, name, peek(vm, 0));
                break;
            }
            case OP_JUMP: {
                uint16_t offset = READ_SHORT();
                frame->ip += offset;
                break;
            }
            case OP_CALL: {
                uint8_t argCount = READ_BYTE();
                // The callee sits just below its arguments on the stack.
                if (!callValue(vm, peek(vm, argCount), argCount)) {
                    return INTERPRET_RUNTIME_ERROR;
                }
                // A successful call pushed a frame; continue executing in it.
                frame = &vm->frames[vm->frameCount - 1];
                break;
            }
            case OP_JUMP_IF_FALSE: {
                uint16_t offset = READ_SHORT();
                // Jump only when the condition is falsey (nil or boolean
                // false). The condition is left on the stack.
                Value condition = peek(vm, 0);
                if (IS_NIL(condition) || (IS_BOOL(condition) && !AS_BOOL(condition))) {
                    frame->ip += offset;
                }
                break;
            }
            default: {
                printf("Unknown opcode %d\n", instruction);
                return INTERPRET_RUNTIME_ERROR;
            }
        }
    }
    #undef READ_BYTE
    #undef READ_CONSTANT
    #undef BINARY_OP
    #undef READ_SHORT
    #undef READ_STRING
}

/*
 * Compiles `source` and executes the resulting function as the top-level
 * call frame.
 *
 * Returns INTERPRET_COMPILE_ERROR when compile() yields NULL; otherwise
 * returns whatever run() reports.
 *
 * After run() finishes, the VM's value stack and frame count are reset.
 * run() never pops the frame or the function value pushed here, so without
 * the reset each call would leave one more frame and stack entry behind and
 * the next call's `frame->slots = vm->stack` would no longer point at its
 * own function.
 *
 * NOTE(review): `chunk` is a stack local that is initialised here but never
 * released in this function, and it is not visible whether compile() keeps a
 * pointer to it — confirm who owns its storage.
 */
InterpretResult interpret(VM* vm, const char* source){
    Chunk chunk;
    initChunk(&chunk);
    ObjFunction* function = compile(source, &chunk);
    if (function == NULL) return INTERPRET_COMPILE_ERROR;

    // Slot 0 of the top-level frame holds the script function itself.
    push(vm, OBJ_VAL(function));
    CallFrame* frame = &vm->frames[vm->frameCount++];
    frame->function = function;
    frame->ip = function->chunk->code;
    frame->slots = vm->stack;

    InterpretResult result = run(vm);

    // Leave the VM ready for another interpret() call.
    vm->frameCount = 0;
    resetStack(vm);
    return result;
}