Pratt parsing FTW!

This commit is contained in:
Bob Nystrom
2013-11-03 10:16:14 -08:00
parent 1741bb2693
commit dae0246581

View File

@ -107,6 +107,29 @@ typedef struct sCompiler
SymbolTable locals;
} Compiler;
// Signature shared by the prefix and infix handlers stored in a ParseRule:
// each one receives the active compiler and returns nothing.
typedef void (*ParseFn)(Compiler*);
// Binding strength of infix operators, ordered from loosest to tightest.
// parsePrecedence() compares these numerically, so the relative order of the
// enumerators is significant.
enum
{
  // Assigned to tokens that are not handled as infix operators.
  PREC_NONE = 0,

  // Starting level used by expression().
  PREC_LOWEST = 1,

  PREC_EQUALITY = 2,    // == !=
  PREC_COMPARISON = 3,  // < > <= >=
  PREC_BITWISE = 4,     // | &
  PREC_TERM = 5,        // + -
  PREC_FACTOR = 6,      // * / %
  PREC_CALL = 7         // ()
};
// One row of the parser table: how a single token type behaves when it is
// encountered inside an expression. The rules[] initialisers are positional,
// so the field order here must not change.
typedef struct
{
// Handler invoked by parsePrecedence() when the token starts an expression,
// or NULL if the token cannot start one.
ParseFn prefix;
// Handler invoked by parsePrecedence() when the token follows an already
// parsed left operand, or NULL if the token has no infix handler.
ParseFn infix;
// One of the PREC_* values; parsePrecedence() keeps consuming infix tokens
// while this is at least the precedence it was asked for.
int precedence;
// Symbol text that infixOp() passes to ensureSymbol() for operator tokens;
// NULL in every row that does not use infixOp.
const char* name;
} ParseRule;
static void initCompiler(Compiler* compiler, Parser* parser, Compiler* parent);
static ObjBlock* compileBlock(Parser* parser, Compiler* parent,
@ -135,12 +158,17 @@ static void error(Compiler* compiler, const char* format, ...);
// Forward declarations for the parsing functions in this file.
static void statement(Compiler* compiler);
static void expression(Compiler* compiler);
// NOTE(review): term/factor/primary look like the pre-Pratt recursive-descent
// entry points whose definitions this diff replaces, and call is declared
// again further down - confirm which of these prototypes survive the commit.
static void term(Compiler* compiler);
static void factor(Compiler* compiler);
static void call(Compiler* compiler);
static void primary(Compiler* compiler);
// Pratt-parser core: consumes one prefix expression and then every infix
// operator whose rule precedence is at least [precedence].
static void parsePrecedence(Compiler* compiler, int precedence);
// Prefix and infix handlers referenced from the rules[] table.
static void grouping(Compiler* compiler);
static void block(Compiler* compiler);
static void name(Compiler* compiler);
static void number(Compiler* compiler);
static void string(Compiler* compiler);
static void call(Compiler* compiler);
static void infixOp(Compiler* compiler);
// Token-stream helpers.
static TokenType peek(Compiler* compiler);
static int match(Compiler* compiler, TokenType expected);
static void consume(Compiler* compiler, TokenType expected);
@ -190,6 +218,43 @@ static char peekChar(Parser* parser);
// range.
static void makeToken(Parser* parser, TokenType type);
// Parser table indexed by token type: parsePrecedence() and infixOp() look up
// rules[token.type], so the rows must stay in the same order as the TokenType
// enumerators named in the trailing comments (that enum is declared
// elsewhere - keep the two in sync when adding tokens).
//
// Columns: prefix handler, infix handler, infix precedence, operator symbol.
// A row whose infix handler is non-NULL must have a precedence above
// PREC_NONE, otherwise parsePrecedence() will never dispatch to it.
ParseRule rules[] =
{
  { grouping, NULL,    PREC_NONE,   NULL }, // TOKEN_LEFT_PAREN
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_RIGHT_PAREN
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_LEFT_BRACKET
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_RIGHT_BRACKET
  { block,    NULL,    PREC_NONE,   NULL }, // TOKEN_LEFT_BRACE
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_RIGHT_BRACE
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_COLON
  { NULL,     call,    PREC_CALL,   NULL }, // TOKEN_DOT
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_COMMA
  { NULL,     infixOp, PREC_FACTOR, "* " }, // TOKEN_STAR
  { NULL,     infixOp, PREC_FACTOR, "/ " }, // TOKEN_SLASH
  // '%' binds like '*' and '/'. With PREC_NONE it could never be reached as
  // an infix operator because expression() starts at PREC_LOWEST.
  { NULL,     infixOp, PREC_FACTOR, "% " }, // TOKEN_PERCENT
  { NULL,     infixOp, PREC_TERM,   "+ " }, // TOKEN_PLUS
  { NULL,     infixOp, PREC_TERM,   "- " }, // TOKEN_MINUS
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_PIPE
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_AMP
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_BANG
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_EQ
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_LT
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_GT
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_LTEQ
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_GTEQ
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_EQEQ
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_BANGEQ
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_CLASS
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_META
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_VAR
  { name,     NULL,    PREC_NONE,   NULL }, // TOKEN_NAME
  { number,   NULL,    PREC_NONE,   NULL }, // TOKEN_NUMBER
  { string,   NULL,    PREC_NONE,   NULL }, // TOKEN_STRING
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_LINE
  { NULL,     NULL,    PREC_NONE,   NULL }, // TOKEN_ERROR
  { NULL,     NULL,    PREC_NONE,   NULL }  // TOKEN_EOF
};
ObjBlock* compile(VM* vm, const char* source, size_t sourceLength)
{
Parser parser;
@ -327,7 +392,7 @@ void error(Compiler* compiler, const char* format, ...)
va_start(args, format);
vfprintf(stderr, format, args);
va_end(args);
fprintf(stderr, "\n");
}
@ -389,207 +454,77 @@ void statement(Compiler* compiler)
// Parses one complete expression. Starting at PREC_LOWEST lets every infix
// operator that has a real precedence in rules[] take part in it.
void expression(Compiler* compiler)
{
  // parsePrecedence() returns void, so it is called as a plain statement:
  // `return <expr>;` is not permitted in a void function. The old leading
  // call to term() is gone because parsePrecedence() now parses the operand.
  parsePrecedence(compiler, PREC_LOWEST);
}
// Core of the Pratt parser. Consumes one token and runs its prefix handler,
// then, while the upcoming token's rule precedence is at least [precedence],
// consumes that token and runs its infix handler. Each handler finds the
// token that triggered it in compiler->parser->previous.
void parsePrecedence(Compiler* compiler, int precedence)
{
  nextToken(compiler->parser);
  ParseFn prefix = rules[compiler->parser->previous.type].prefix;

  if (prefix == NULL)
  {
    // TODO(bob): Handle error better.
    error(compiler, "No prefix parser.");

    // Bail out: falling through would call a NULL function pointer.
    return;
  }

  prefix(compiler);

  while (precedence <= rules[compiler->parser->current.type].precedence)
  {
    nextToken(compiler->parser);
    ParseFn infix = rules[compiler->parser->previous.type].infix;

    // A row with a real precedence but no infix handler is a table mistake;
    // report it rather than jumping through NULL.
    if (infix == NULL)
    {
      error(compiler, "No infix parser.");
      return;
    }

    infix(compiler);
  }
}
// NOTE(review): this span comes from a diff hunk and appears to interleave the
// removed factor() (header, the call()/while loop over TOKEN_STAR and
// TOKEN_SLASH) with the added grouping(). Only the trailing
// expression()/consume() pair seems to belong to grouping(), which rules[]
// installs as the prefix handler for TOKEN_LEFT_PAREN - confirm against the
// commit before treating this as a single function.
// TODO(bob): Virtually identical to term(). Unify.
void factor(Compiler* compiler)
void grouping(Compiler* compiler)
{
call(compiler);
while (match(compiler, TOKEN_STAR) || match(compiler, TOKEN_SLASH))
{
const char* name;
if (compiler->parser->previous.type == TOKEN_STAR)
{
name = "* ";
}
else
{
name = "/ ";
}
// Compile the right-hand side.
call(compiler);
// Call the operator method on the left-hand side.
int symbol = ensureSymbol(&compiler->parser->vm->symbols, name, 2);
emit(compiler, CODE_CALL_1);
emit(compiler, symbol);
}
// Parse the expression inside the parentheses, then require the closing ')'.
expression(compiler);
consume(compiler, TOKEN_RIGHT_PAREN);
}
// NOTE(review): this span comes from a diff hunk and appears to interleave the
// removed in-place call() (its header comment, primary(), the TOKEN_DOT loop
// and the name-mangling code, all of which recur in the standalone call()
// later in the file) with the added block(). The lines that look like block()
// are the compileBlock() call, the store into the constant table and the
// CODE_CONSTANT emit - confirm against the commit before treating this as a
// single function. rules[] installs block as the prefix handler for
// TOKEN_LEFT_BRACE.
// Method calls like:
//
// foo.bar
// foo.bar(arg, arg)
// foo.bar { block } other { block }
// foo.bar(arg) nextPart { arg } lastBit
void call(Compiler* compiler)
void block(Compiler* compiler)
{
primary(compiler);
// Compile the nested block; TOKEN_RIGHT_BRACE is presumably the token that
// ends it (compileBlock is defined outside this view).
ObjBlock* block = compileBlock(
compiler->parser, compiler, TOKEN_RIGHT_BRACE);
while (match(compiler, TOKEN_DOT))
{
char name[MAX_NAME];
int length = 0;
int numArgs = 0;
// Add the block to the constant table.
compiler->block->constants[compiler->block->numConstants++] = (Value)block;
consume(compiler, TOKEN_NAME);
// Build the method name. The mangled name includes all of the name parts
// in a mixfix call as well as spaces for every argument.
// So a method call like:
//
// foo.bar(arg, arg) else { block } last
//
// Will have name: "bar else last"
// Compile all of the name parts.
for (;;)
{
// Add the just-consumed part name to the method name.
int partLength = compiler->parser->previous.end -
compiler->parser->previous.start;
strncpy(name + length,
compiler->parser->source + compiler->parser->previous.start,
partLength);
length += partLength;
// TODO(bob): Check for length overflow.
// TODO(bob): Allow block arguments.
// Parse the argument list, if any.
if (match(compiler, TOKEN_LEFT_PAREN))
{
for (;;)
{
expression(compiler);
numArgs++;
// Add a space in the name for each argument. Lets us overload by
// arity.
name[length++] = ' ';
if (!match(compiler, TOKEN_COMMA)) break;
}
consume(compiler, TOKEN_RIGHT_PAREN);
// If there isn't another part name after the argument list, stop.
if (!match(compiler, TOKEN_NAME)) break;
}
else
{
// If there isn't an argument list, we're done.
break;
}
}
int symbol = ensureSymbol(&compiler->parser->vm->symbols, name, length);
// Compile the method call.
emit(compiler, CODE_CALL_0 + numArgs);
// TODO(bob): Handle > 10 args.
emit(compiler, symbol);
}
// Compile the code to load it.
// The operand is the index of the constant-table slot filled above.
emit(compiler, CODE_CONSTANT);
emit(compiler, compiler->block->numConstants - 1);
}
// NOTE(review): this span comes from a diff hunk and appears to interleave the
// removed primary() (the match() chain over TOKEN_LEFT_BRACE, TOKEN_NAME,
// TOKEN_NUMBER, TOKEN_STRING and TOKEN_LEFT_PAREN) with the added name(),
// which is why the local and global lookups each appear twice. Confirm
// against the commit before treating this as a single function.
//
// What the lookup lines show: the identifier text is the source slice
// [previous.start, previous.end); it is searched for first in the compiler's
// locals and then in the VM's globalSymbols, with -1 checked as the
// "not found" result. A hit emits CODE_LOAD_LOCAL or CODE_LOAD_GLOBAL followed
// by the index; otherwise "Undefined variable." is reported. rules[] installs
// name as the prefix handler for TOKEN_NAME.
void primary(Compiler* compiler)
void name(Compiler* compiler)
{
// Block.
if (match(compiler, TOKEN_LEFT_BRACE))
// See if it's a local in this scope.
int local = findSymbol(&compiler->locals,
compiler->parser->source + compiler->parser->previous.start,
compiler->parser->previous.end - compiler->parser->previous.start);
if (local != -1)
{
ObjBlock* block = compileBlock(
compiler->parser, compiler, TOKEN_RIGHT_BRACE);
// Add the block to the constant table.
compiler->block->constants[compiler->block->numConstants++] = (Value)block;
// Compile the code to load it.
emit(compiler, CODE_CONSTANT);
emit(compiler, compiler->block->numConstants - 1);
emit(compiler, CODE_LOAD_LOCAL);
emit(compiler, local);
return;
}
// Variable name.
if (match(compiler, TOKEN_NAME))
// TODO(bob): Look up names in outer local scopes.
// See if it's a global variable.
int global = findSymbol(&compiler->parser->vm->globalSymbols,
compiler->parser->source + compiler->parser->previous.start,
compiler->parser->previous.end - compiler->parser->previous.start);
if (global != -1)
{
// See if it's a local in this scope.
int local = findSymbol(&compiler->locals,
compiler->parser->source + compiler->parser->previous.start,
compiler->parser->previous.end - compiler->parser->previous.start);
if (local != -1)
{
emit(compiler, CODE_LOAD_LOCAL);
emit(compiler, local);
return;
}
// TODO(bob): Look up names in outer local scopes.
// See if it's a global variable.
int global = findSymbol(&compiler->parser->vm->globalSymbols,
compiler->parser->source + compiler->parser->previous.start,
compiler->parser->previous.end - compiler->parser->previous.start);
if (global != -1)
{
emit(compiler, CODE_LOAD_GLOBAL);
emit(compiler, global);
return;
}
// TODO(bob): Look for names in outer scopes.
error(compiler, "Undefined variable.");
}
// Number.
if (match(compiler, TOKEN_NUMBER))
{
number(compiler);
emit(compiler, CODE_LOAD_GLOBAL);
emit(compiler, global);
return;
}
// String.
if (match(compiler, TOKEN_STRING))
{
string(compiler);
return;
}
// Parentheses.
if (match(compiler, TOKEN_LEFT_PAREN))
{
expression(compiler);
consume(compiler, TOKEN_RIGHT_PAREN);
return;
}
// TODO(bob): Look for names in outer scopes.
error(compiler, "Undefined variable.");
}
void number(Compiler* compiler)
@ -634,6 +569,90 @@ void string(Compiler* compiler)
emit(compiler, constant);
}
// Infix handler for TOKEN_DOT: compiles a method call on the left-hand
// operand that parsePrecedence() has already compiled. Accepted forms:
//
//     foo.bar
//     foo.bar(arg, arg)
//     foo.bar(arg) nextPart(arg)
//
// The method symbol is built from every name part followed by one space per
// argument of that part (which lets methods be overloaded by arity), so
// "foo.bar(a, b) else(c)" is looked up as "bar  else ". The emitted code is
// CODE_CALL_0 + <total argument count>, then the symbol index.
//
// The mangled name is held in a MAX_NAME-byte buffer. Anything that does not
// fit is dropped and reported once as a compile error instead of being
// written past the end of the buffer.
void call(Compiler* compiler)
{
  char name[MAX_NAME];
  int length = 0;
  int numArgs = 0;
  int tooLong = 0;

  consume(compiler, TOKEN_NAME);

  // Compile all of the name parts.
  for (;;)
  {
    // Add the just-consumed part name to the method name, clamped to the
    // space left in the buffer.
    int partLength = compiler->parser->previous.end -
        compiler->parser->previous.start;
    if (partLength > MAX_NAME - length)
    {
      partLength = MAX_NAME - length;
      tooLong = 1;
    }

    // The name is length-counted, not NUL-terminated, so a plain byte copy
    // is what is wanted here.
    memcpy(name + length,
           compiler->parser->source + compiler->parser->previous.start,
           (size_t)partLength);
    length += partLength;

    // TODO(bob): Allow block arguments.

    // If there isn't an argument list, we're done.
    if (!match(compiler, TOKEN_LEFT_PAREN))
    {
      break;
    }

    // Parse the argument list.
    for (;;)
    {
      expression(compiler);
      numArgs++;

      // Add a space in the name for each argument. Lets us overload by
      // arity.
      if (length < MAX_NAME)
      {
        name[length++] = ' ';
      }
      else
      {
        tooLong = 1;
      }

      if (!match(compiler, TOKEN_COMMA))
      {
        break;
      }
    }

    consume(compiler, TOKEN_RIGHT_PAREN);

    // If there isn't another part name after the argument list, stop.
    if (!match(compiler, TOKEN_NAME))
    {
      break;
    }
  }

  if (tooLong)
  {
    error(compiler, "Method name too long.");
  }

  int symbol = ensureSymbol(&compiler->parser->vm->symbols, name, length);

  // Compile the method call.
  emit(compiler, CODE_CALL_0 + numArgs);
  // TODO(bob): Handle > 10 args.
  emit(compiler, symbol);
}
// Infix handler shared by the binary operators in rules[]. The operator token
// has just been consumed and the left operand is already compiled; this
// compiles the right operand and then emits a one-argument method call
// (CODE_CALL_1 followed by the symbol index) using the rule's symbol name.
void infixOp(Compiler* compiler)
{
  ParseRule* rule = &rules[compiler->parser->previous.type];

  // Compile the right-hand side. Parsing one level tighter than this
  // operator stops the right operand from swallowing another operator of the
  // same precedence, which makes these operators left-associative.
  parsePrecedence(compiler, rule->precedence + 1);

  // Call the operator method on the left-hand side. The symbol length is
  // taken from the rule's name rather than assumed to be 2, so operators
  // whose names are not exactly two bytes resolve to the right symbol.
  int symbol = ensureSymbol(&compiler->parser->vm->symbols,
                            rule->name, (int)strlen(rule->name));
  emit(compiler, CODE_CALL_1);
  emit(compiler, symbol);
}
TokenType peek(Compiler* compiler)
{
return compiler->parser->current.type;