#include "compiler.h"
#include "scanner.h"
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "object.h"

/* One entry in a compiler's table of in-scope local variables. */
typedef struct {
    Token name; /* identifier token that introduced the local */
    int depth;  /* scope depth of the declaration; -1 is treated by
                   resolveLocal/declareVariable as "not yet initialized" */
} Local;

/*
 * Per-function compilation state. Compilers form a stack linked through
 * `enclosing`; the file-scope `current` points at the innermost one.
 *
 * The struct is tagged so that `struct Compiler*` below names this same type.
 * With an anonymous struct it named a distinct, incomplete type, and
 * `enclosing = current` mixed incompatible pointer types.
 */
typedef struct Compiler {
    struct Compiler* enclosing; /* surrounding compiler, or NULL for the outermost */
    ObjFunction* function;      /* function object being compiled */
    Local locals[UINT8_COUNT];  /* locals currently in scope, innermost last */
    int localCount;             /* number of used entries in `locals` */
    int scopeDepth;             /* 0 = global scope; incremented by beginScope */
} Compiler;

Compiler* current = NULL;

static void advance(Parser* parser, Scanner* scanner);
static uint8_t makeConstant(Value value);

/*
 * Push `compiler` onto the compiler stack and make it current.
 *
 * type: kind of function being compiled (not used yet).
 * name: function name, or NULL for the top-level script (endCompiler's debug
 *       output labels a NULL name as "<script>").
 */
static void initCompiler(Compiler* compiler, FunctionType type, const char* name) {
    (void)type; /* NOTE(review): unused for now; kept for the existing signature. */
    compiler->enclosing = current;
    compiler->function = NULL;
    compiler->localCount = 0;
    compiler->scopeDepth = 0;
    current = compiler;

    /* The original wrote compiler->function->name while function was still
     * NULL, which is undefined behaviour. Allocate the function object first,
     * then store the name as a string object (endCompiler reads name->chars).
     * NOTE(review): assumes object.h provides newFunction() and the two-argument
     * copyString(chars, length) used by string(); confirm. */
    compiler->function = newFunction();
    if (name != NULL) {
        compiler->function->name = copyString(name, (int)strlen(name));
    }
}

/*
 * Report a compile error at `token` on stderr and flag the parse as failed.
 *
 * Output format: "[line N] Error at 'lexeme': message". For end of input it is
 * " at end" instead, and for scanner error tokens no location is added.
 * Previously the message was printed twice and the location landed on its own
 * line.
 */
static void errorAt(Parser* parser, Token* token, const char* message) {
    fprintf(stderr, "[line %d] Error", token->line);
    if (token->type == EOF_TOKEN) {
        fprintf(stderr, " at end");
    } else if (token->type == ERROR_TOKEN) {
        /* The message itself describes the lexical error; no location needed. */
    } else {
        fprintf(stderr, " at '%.*s'", token->length, token->start);
    }
    fprintf(stderr, ": %s\n", message);
    parser->hadError = 1;
}

/*
 * Require the current token to have type `type`. On success the parser
 * advances past it; otherwise `message` is reported at the current token and
 * nothing is consumed.
 */
static void consume(Parser* parser, Scanner* scanner, TokenType type, const char* message) {
    if (parser->current.type != type) {
        errorAtCurrent(parser, message);
        return;
    }
    advance(parser, scanner);
}

/* Report `message` against the token the parser is currently looking at. */
static void errorAtCurrent(Parser* parser, const char* message) {
    Token* where = &parser->current;
    errorAt(parser, where, message);
}

/*
 * Chunk that receives emitted bytecode: the one compile() stored in
 * compilingChunk.
 *
 * Declared with (void) so it is a real prototype; an empty () list means
 * "unspecified parameters" before C23.
 *
 * NOTE(review): this ignores `current`, so code for nested function compilers
 * is also written into the top-level chunk rather than the function's own
 * chunk. Confirm this is intended.
 */
static Chunk* currentChunk(void) {
    return compilingChunk;
}

/* Append one byte to the current chunk, tagged with the line of the most
 * recently consumed token. */
static void emitByte(Parser* parser, uint8_t byte) {
    Chunk* target = currentChunk();
    int sourceLine = parser->previous.line;
    writeChunk(target, byte, sourceLine);
}

/* Emit a single OP_RETURN instruction. */
static void emitReturn(Parser* parser) {
    const uint8_t opcode = OP_RETURN;
    emitByte(parser, opcode);
}

/*
 * Finish the innermost compiler: terminate its code with OP_RETURN, pop it off
 * the compiler stack and return the function object it was building.
 *
 * With DEBUG_PRINT_CODE defined, and no errors so far, the current chunk is
 * disassembled, labelled with the function name or "<script>".
 */
static ObjFunction* endCompiler(Parser* parser) {
    emitReturn(parser);
    Compiler* finished = current;
    ObjFunction* function = finished->function;

#ifdef DEBUG_PRINT_CODE
    if (!parser->hadError) {
        const char* label =
            (function->name == NULL) ? "<script>" : function->name->chars;
        disassembleChunk(currentChunk(), label);
    }
#endif

    current = finished->enclosing;
    return function;
}

/* Emit two consecutive bytes (in this file: an opcode followed by its operand). */
static void emitBytes(Parser* parser, uint8_t byte1, uint8_t byte2) {
    const uint8_t pair[2] = { byte1, byte2 };
    for (int k = 0; k < 2; k++) {
        emitByte(parser, pair[k]);
    }
}


/* Forward declarations for parsing entry points defined later in this file
 * (expression and grouping call each other). */
static void grouping(Parser* parser);
static void expression(Parser* parser);
static void declaration(Parser* parser);

/* True when the current (not yet consumed) token has the given type. */
static bool check(Parser* parser, TokenType type) {
    const Token* lookahead = &parser->current;
    return type == lookahead->type;
}

/* Consume the current token if it has the given type, using the shared
 * defaultScanner; returns whether a token was consumed. */
static bool match(Parser* parser, TokenType type) {
    bool hit = check(parser, type);
    if (hit) {
        advance(parser, &defaultScanner);
    }
    return hit;
}

/*
 * Compile call arguments up to and including the closing ')'.
 *
 * Returns the number of arguments compiled. The count is capped at 255 because
 * it is emitted as a one-byte operand; previously it wrapped silently.
 * The loop also stops at end of input, or when an iteration consumes no tokens.
 * Without those checks a missing ')' made the loop spin forever.
 */
static uint8_t argumentList(Parser* parser) {
    int argCount = 0;
    while (!match(parser, RIGHT_PAREN)) {
        if (check(parser, EOF_TOKEN)) {
            errorAtCurrent(parser, "Expect ')' after arguments.");
            break;
        }
        const char* before = parser->current.start;
        expression(parser);
        if (argCount == 255) {
            errorAtCurrent(parser, "Can't have more than 255 arguments.");
        } else {
            argCount++;
        }
        if (parser->current.start == before) {
            /* expression() consumed nothing; retrying would never make progress. */
            errorAtCurrent(parser, "Expect ')' after arguments.");
            break;
        }
    }
    return (uint8_t)argCount;
}

/*
 * Compile a call's argument list and emit OP_CALL with the argument count.
 *
 * canAssign is kept for signature compatibility (presumably a parse-function
 * table shape). A call result is not an assignment target, so the original
 * trailing OP_SET_LOCAL has been removed: it stored into local slot
 * `argCount`, an arbitrary slot.
 */
static void call(Parser* parser, bool canAssign) {
    (void)canAssign;
    uint8_t argCount = argumentList(parser);
    emitBytes(parser, OP_CALL, argCount);
}

/*
 * Compile `source` into `chunk`. Returns nonzero on success, 0 if any compile
 * error was reported.
 */
int compile(const char* source, Chunk* chunk) {
    initScanner(&defaultScanner, source);
    /* defaultParser is shared across calls; clear any error from a previous run. */
    defaultParser.hadError = 0;

    Compiler compiler;
    /* initCompiler takes (compiler, type, name); the original passed only the
     * compiler. NULL name marks the top-level script.
     * NOTE(review): TYPE_SCRIPT assumed to be the FunctionType enumerator for
     * the top-level script; confirm. */
    initCompiler(&compiler, TYPE_SCRIPT, NULL);
    compilingChunk = chunk;

    advance(&defaultParser, &defaultScanner);
    while (!match(&defaultParser, EOF_TOKEN)) {
        grouping(&defaultParser);
    }
    endCompiler(&defaultParser);
    return !defaultParser.hadError;
}

/*
 * Compile the body of a `(program ...)` form: groupings until the closing ')'.
 *
 * The ')' is left for the caller. grouping() always consumes it afterwards, so
 * consuming it here (as before) demanded a second ')'. The loop also stops at
 * end of input, or when an iteration makes no progress. grouping() has already
 * reported an error in that case, so this avoids spinning forever.
 */
static void program(Parser* parser) {
    while (!check(parser, RIGHT_PAREN) && !check(parser, EOF_TOKEN)) {
        const char* before = parser->current.start;
        grouping(parser);
        if (parser->current.start == before) {
            break;
        }
    }
}

/* Enter a nested block scope in the current compiler (parser is unused). */
static void beginScope(Parser* parser) {
    (void)parser;
    current->scopeDepth += 1;
}

/*
 * Leave the innermost block scope. Every local declared deeper than the new
 * depth gets an OP_POP and is dropped from the locals table, innermost first.
 */
static void endScope(Parser* parser) {
    Compiler* c = current;
    c->scopeDepth--;
    for (;;) {
        if (c->localCount <= 0) break;
        if (c->locals[c->localCount - 1].depth <= c->scopeDepth) break;
        emitByte(parser, OP_POP);
        c->localCount--;
    }
}

/*
 * Look up `name` among the current compiler's locals, innermost first.
 *
 * Returns the local's index. Returns -1 when the name is not a local. It also
 * returns -1, after reporting an error, when the local is still uninitialized
 * (depth -1), i.e. read inside its own initializer.
 * The original called errorAtCurrent without the parser argument, which does
 * not compile.
 */
static int resolveLocal(Parser* parser, Token* name) {
    for (int i = current->localCount - 1; i >= 0; i--) {
        const Local* local = &current->locals[i];
        if (!identifierMatch(local->name, name)) {
            continue;
        }
        if (local->depth == -1) {
            errorAtCurrent(parser, "Can't read local variable in its own initializer.");
            return -1;
        }
        return i;
    }
    return -1;
}


/*
 * Emit a read of `name`. When canAssign is set and '=' follows, compile the
 * assigned value and emit a write instead. Locals are addressed by their index
 * in the locals table; anything else is treated as a global through a name
 * constant.
 */
static void namedVariable(Parser* parser, Token name, bool canAssign) {
    int slot = resolveLocal(parser, &name);
    bool isLocal = (slot != -1);
    int operand = isLocal ? slot : identifierConstant(parser, &name);
    uint8_t getOp = isLocal ? OP_GET_LOCAL : OP_GET_GLOBAL;
    uint8_t setOp = isLocal ? OP_SET_LOCAL : OP_SET_GLOBAL;

    bool isAssignment = canAssign && match(parser, EQUAL);
    if (!isAssignment) {
        emitBytes(parser, getOp, operand);
        return;
    }

    /* match() consumed '='. expression() compiles parser->previous, so step
     * onto the first token of the assigned value before calling it. */
    advance(parser, &defaultScanner);
    expression(parser);
    emitBytes(parser, setOp, operand);
}
        

/*
 * Append `name` to the current compiler's locals as declared but not yet
 * initialized. Reports an error, and adds nothing, when the table is full.
 *
 * The depth starts at -1; markInitialized() later stamps the real scope depth.
 * Previously the depth was set immediately, so the "own initializer" checks in
 * resolveLocal/declareVariable could never fire.
 */
static void addLocal(Parser* parser, Token name) {
    if (current->localCount >= UINT8_COUNT) {
        errorAtCurrent(parser, "Too many local variables in this scope.");
        return;
    }
    Local* local = &current->locals[current->localCount++];
    local->name = name;
    local->depth = -1;
}

/* True when two identifier tokens have the same length and the same bytes. */
static bool identifierMatch(Token name, Token* expected) {
    if (name.length != expected->length) {
        return false;
    }
    return 0 == memcmp(expected->start, name.start, name.length);
}

/*
 * Declare the identifier in parser->previous as a local of the innermost scope.
 * At global scope (depth 0) nothing is recorded.
 *
 * Redeclaring a name already declared in the same scope is reported as an
 * error, and no second local is added. Before, the function returned a value
 * from a void function (`return i;`, which does not compile), reused the
 * "own initializer" message, and otherwise let the duplicate through silently.
 */
static void declareVariable(Parser* parser) {
    if (current->scopeDepth == 0) {
        return;
    }

    Token* name = &parser->previous;
    for (int i = current->localCount - 1; i >= 0; i--) {
        const Local* local = &current->locals[i];
        /* Stop at the first local from an enclosing (shallower) scope. */
        if (local->depth != -1 && local->depth < current->scopeDepth) {
            break;
        }
        if (identifierMatch(local->name, name)) {
            errorAtCurrent(parser, "Already a variable with this name in this scope.");
            return;
        }
    }

    addLocal(parser, *name);
}

/*
 * Mark the most recently declared local as initialized by giving it the
 * current scope depth. Does nothing at global scope or when no local exists.
 * The original indexed locals[-1] in that case, which is undefined behaviour.
 */
static void markInitialized(void) {
    if (current->scopeDepth == 0) return;
    if (current->localCount == 0) return;
    current->locals[current->localCount - 1].depth = current->scopeDepth;
}

/*
 * Finish a variable definition. At global scope, emit OP_DEFINE_GLOBAL with the
 * name-constant index. Inside a scope the value already sits on the stack, so
 * the latest local is just marked initialized.
 */
static void defineVariable(Parser* parser, uint8_t globalindex) {
    if (current->scopeDepth <= 0) {
        emitBytes(parser, OP_DEFINE_GLOBAL, globalindex);
        return;
    }
    markInitialized();
}

/* Placeholder for compiling an `object` form; currently emits nothing. */
static void object(Parser* parser) {
    (void)parser; /* TODO: Implement object. */
}


/*
 * Compile the body of a `(let ...)` form inside a fresh block scope, then close
 * the scope and emit OP_RETURN.
 *
 * The closing ')' is left for grouping(), which always consumes it. Consuming
 * it here as well (as before) demanded a second ')'. The loop stops at end of
 * input, or when expression() consumes nothing, instead of spinning forever.
 */
static void letDeclaration(Parser* parser) {
    beginScope(parser);
    while (!check(parser, RIGHT_PAREN) && !check(parser, EOF_TOKEN)) {
        const char* before = parser->current.start;
        expression(parser);
        if (parser->current.start == before) {
            errorAtCurrent(parser, "Expect ')' after let body.");
            break;
        }
    }
    endScope(parser);
    /* NOTE(review): the original comment said this pushes the last expression's
     * value, but OP_RETURN is emitted into the same chunk as the surrounding
     * code. Confirm this is the intended instruction. */
    emitByte(parser, OP_RETURN);
}

/*
 * Intern the identifier's text as a string constant and return its index in
 * the current chunk's constant table.
 * NOTE(review): copyString is called with (parser, chars, length) here but with
 * (chars, length) in string(); one of the two does not match object.h.
 */
static uint8_t identifierConstant(Parser* parser, Token* name) {
    Value nameValue = OBJ_VAL(copyString(parser, name->start, name->length));
    return makeConstant(nameValue);
}

/*
 * Consume an identifier and declare it.
 *
 * At global scope, returns the index of its name constant for
 * OP_DEFINE_GLOBAL. Inside a scope, it is recorded as a local and 0 is
 * returned; defineVariable() ignores the index there.
 *
 * The original never declared the variable. defineVariable() ->
 * markInitialized() then stamped whatever local happened to be last, or
 * locals[-1].
 */
static uint8_t parseVariable(Parser* parser, const char* errorMessage) {
    consume(parser, &defaultScanner, IDENTIFIER, errorMessage);
    declareVariable(parser);
    if (current->scopeDepth > 0) {
        /* Locals are resolved through the locals table; no name constant needed. */
        return 0;
    }
    return identifierConstant(parser, &parser->previous);
}

/*
 * Compile a `fun`/`mapto` form. Parameters are either a parenthesised list of
 * identifiers or a single identifier, followed by the body.
 *
 * The body is compiled under a fresh nested Compiler. The resulting function
 * object is emitted as an OP_CONSTANT in the enclosing code, followed by
 * OP_RETURN as before.
 */
static void functionDeclaration(Parser* parser) {
    /* endCompiler() pops `current`, so it needs a matching initCompiler().
     * Without it the enclosing (script) compiler was popped instead, leaving
     * `current` NULL for the rest of the compile.
     * NOTE(review): TYPE_FUNCTION assumed to be the FunctionType enumerator for
     * ordinary functions; confirm. */
    Compiler compiler;
    initCompiler(&compiler, TYPE_FUNCTION, NULL);

    /* NOTE(review): the new compiler starts at scopeDepth 0, so parameters are
     * defined as globals. Consider beginScope() here once locals are wired up. */
    current->function->arity = 0;
    if (match(parser, LEFT_PAREN)) {
        /* Zero or more parameters up to ')'. */
        while (!match(parser, RIGHT_PAREN)) {
            /* Without this check a non-identifier (including EOF) is never
             * consumed and the loop never ends. */
            if (!check(parser, IDENTIFIER)) {
                errorAtCurrent(parser, "Expect variable name.");
                break;
            }
            uint8_t arg = parseVariable(parser, "Expect variable name.");
            defineVariable(parser, arg);
            current->function->arity++;
        }
    } else {
        /* Single parameter. */
        uint8_t arg = parseVariable(parser, "Expect variable name.");
        defineVariable(parser, arg);
        current->function->arity++;
    }

    /* NOTE(review): expression() compiles parser->previous, which here is the
     * last parameter token or ')', not the first token of the body. */
    expression(parser);
    ObjFunction* function = endCompiler(parser);
    emitBytes(parser, OP_CONSTANT, makeConstant(OBJ_VAL(function)));
    /* Was emitBytes(parser, OP_RETURN), which does not compile (missing operand).
     * NOTE(review): returning right after pushing the function ends the
     * enclosing code; confirm this is intended. */
    emitByte(parser, OP_RETURN);
}

/* Compile what follows a form's opening '(' up to, but not including, its
 * closing ')'. The keyword that opens the form picks the handler; anything else
 * is compiled as an expression. */
static void groupingBody(Parser* parser) {
    if (match(parser, PROGRAM)) {
        program(parser);
        return;
    }
    if (match(parser, OBJECT)) {
        /* TODO: Implement object. */
        return;
    }
    if (match(parser, LET)) {
        letDeclaration(parser);
        return;
    }
    if (match(parser, FUN) || match(parser, MAPTO)) {
        functionDeclaration(parser);
        return;
    }
    expression(parser);
}

/* Compile one parenthesised form and require its closing ')'. */
static void grouping(Parser* parser) {
    groupingBody(parser);
    consume(parser, &defaultScanner, RIGHT_PAREN, "Expect ')' after expression.");
}

/*
 * Compile a `program` form, an `object` form, or else an expression.
 * NOTE(review): nothing in this file calls declaration(); grouping() is the
 * entry point compile() uses.
 */
static void declaration(Parser* parser) {
    if (match(parser, PROGRAM)) {
        program(parser);
        return;
    }
    if (match(parser, OBJECT)) {
        object(parser);
        return;
    }
    expression(parser);
}

/*
 * Shift current into previous and scan the next token. Each ERROR_TOKEN
 * produced along the way is reported, using its text as the message, and
 * skipped.
 */
static void advance(Parser* parser, Scanner* scanner) {
    parser->previous = parser->current;
    do {
        parser->current = scanToken(scanner);
        if (parser->current.type == ERROR_TOKEN) {
            errorAtCurrent(parser, parser->current.start);
        }
    } while (parser->current.type == ERROR_TOKEN);
}

/*
 * Add `value` to the current chunk's constant table and return its index as a
 * one-byte operand.
 *
 * Indices that do not fit in a byte, or a negative result, are reported and
 * yield 0. Previously large indices were silently truncated by the uint8_t
 * return, and the error call lacked its parser argument (a compile error).
 * makeConstant has no parser parameter, so errors go to the shared
 * defaultParser.
 */
static uint8_t makeConstant(Value value) {
    int constant = addConstant(currentChunk(), value);
    if (constant < 0 || constant > UINT8_MAX) {
        errorAtCurrent(&defaultParser, "Too many constants in this scope.");
        return 0;
    }
    return (uint8_t)constant;
}

/* Store `value` in the constant table and emit OP_CONSTANT with its index. */
static void emitConstant(Parser* parser, Value value) {
    uint8_t index = makeConstant(value);
    emitBytes(parser, OP_CONSTANT, index);
}

/* Emit the opcode for a true/false/nil literal in parser->previous; any other
 * token type emits nothing. */
static void literal(Parser* parser) {
    TokenType kind = parser->previous.type;
    uint8_t opcode;
    if (kind == TRUE) {
        opcode = OP_TRUE;
    } else if (kind == FALSE) {
        opcode = OP_FALSE;
    } else if (kind == NIL) {
        opcode = OP_NIL;
    } else {
        return; /* Unreachable. */
    }
    emitByte(parser, opcode);
}

/* Parse the numeric lexeme in parser->previous with strtod and emit it as a
 * constant. */
static void number(Parser* parser) {
    const char* digits = parser->previous.start;
    double parsed = strtod(digits, NULL);
    emitConstant(parser, NUMBER_VAL(parsed));
}


/*
 * Emit the string literal in parser->previous as a constant. The first and last
 * characters of the lexeme (presumably the quote delimiters) are dropped.
 * The original omitted emitConstant's parser argument, which does not compile.
 */
static void string(Parser* parser) {
    const char* chars = parser->previous.start + 1;
    int length = parser->previous.length - 2;
    emitConstant(parser, OBJ_VAL(copyString(chars, length)));
}

/*
 * Compile the expression starting at parser->previous, the already consumed
 * token. No precedence handling is needed because the language is LISP-like.
 *
 * Changes:
 * - `static` now matches the earlier forward declaration.
 * - true/false/nil are delegated to literal() instead of duplicating it.
 * - The error call now passes the parser; previously it did not compile.
 */
static void expression(Parser* parser) {
    TokenType type = parser->previous.type;
    if (type == NUMBER) {
        number(parser);
    } else if (type == FALSE || type == TRUE || type == NIL) {
        literal(parser);
    } else if (type == STRING) {
        string(parser);
    } else if (type == LEFT_PAREN) {
        grouping(parser);
    } else {
        errorAtCurrent(parser, "Expect expression.");
    }
}

/*
 * Debug helper: scan the rest of defaultScanner and print each token's numeric
 * type and lexeme, up to and including EOF.
 *
 * NOTE(review): despite its name this does not compile equality expressions,
 * and nothing in this file calls it. The previous `line` variable was written
 * but never read, so it has been removed.
 */
static void equality(Parser* parser) {
    (void)parser;
    for (;;) {
        Token token = scanToken(&defaultScanner);
        printf("%2d '%.*s'\n", token.type, token.length, token.start);
        if (token.type == EOF_TOKEN) {
            break;
        }
    }
}
