1
0
forked from Mirror/wren

Get infix arithmetic operators working.

This commit is contained in:
Bob Nystrom
2013-10-31 07:04:44 -07:00
parent 3f433922eb
commit 64f1b39ee9
9 changed files with 179 additions and 42 deletions

View File

@ -1,9 +1,6 @@
class Foo { // line comment io.write(1 + 2)
bar { io.write(3 - 1)
123 io.write(10 / 3)
} io.write(20 * 30)
} io.write("abc" + "def")
var a = Foo.new io.write(1 + "string")
a.bar
"something".contains("meth")
io.write("hey there")

View File

@ -135,6 +135,8 @@ static void error(Compiler* compiler, const char* format, ...);
static void statement(Compiler* compiler); static void statement(Compiler* compiler);
static void expression(Compiler* compiler); static void expression(Compiler* compiler);
static void term(Compiler* compiler);
static void factor(Compiler* compiler);
static void call(Compiler* compiler); static void call(Compiler* compiler);
static void primary(Compiler* compiler); static void primary(Compiler* compiler);
static void number(Compiler* compiler); static void number(Compiler* compiler);
@ -387,8 +389,61 @@ void statement(Compiler* compiler)
} }
void expression(Compiler* compiler) void expression(Compiler* compiler)
{
term(compiler);
}
void term(Compiler* compiler)
{
factor(compiler);
while (match(compiler, TOKEN_PLUS) || match(compiler, TOKEN_MINUS))
{
const char* name;
if (compiler->parser->previous.type == TOKEN_PLUS)
{
name = "+ ";
}
else
{
name = "- ";
}
// Compile the right-hand side.
factor(compiler);
// Call the operator method on the left-hand side.
int symbol = ensureSymbol(&compiler->parser->vm->symbols, name, 2);
emit(compiler, CODE_CALL_1);
emit(compiler, symbol);
}
}
// TODO(bob): Virtually identical to term(). Unify.
void factor(Compiler* compiler)
{ {
call(compiler); call(compiler);
while (match(compiler, TOKEN_STAR) || match(compiler, TOKEN_SLASH))
{
const char* name;
if (compiler->parser->previous.type == TOKEN_STAR)
{
name = "* ";
}
else
{
name = "/ ";
}
// Compile the right-hand side.
call(compiler);
// Call the operator method on the left-hand side.
int symbol = ensureSymbol(&compiler->parser->vm->symbols, name, 2);
emit(compiler, CODE_CALL_1);
emit(compiler, symbol);
}
} }
// Method calls like: // Method calls like:

View File

@ -12,7 +12,7 @@
} }
#define DEF_PRIMITIVE(prim) \ #define DEF_PRIMITIVE(prim) \
static Value primitive_##prim(Value* args, int numArgs) static Value primitive_##prim(VM* vm, Value* args, int numArgs)
#define GLOBAL(cls, name) \ #define GLOBAL(cls, name) \
{ \ { \
@ -29,6 +29,31 @@ DEF_PRIMITIVE(num_abs)
return (Value)makeNum(value); return (Value)makeNum(value);
} }
DEF_PRIMITIVE(num_minus)
{
if (args[1]->type != OBJ_NUM) return vm->unsupported;
return (Value)makeNum(((ObjNum*)args[0])->value - ((ObjNum*)args[1])->value);
}
DEF_PRIMITIVE(num_plus)
{
if (args[1]->type != OBJ_NUM) return vm->unsupported;
// TODO(bob): Handle coercion to string if RHS is a string.
return (Value)makeNum(((ObjNum*)args[0])->value + ((ObjNum*)args[1])->value);
}
DEF_PRIMITIVE(num_multiply)
{
if (args[1]->type != OBJ_NUM) return vm->unsupported;
return (Value)makeNum(((ObjNum*)args[0])->value * ((ObjNum*)args[1])->value);
}
DEF_PRIMITIVE(num_divide)
{
if (args[1]->type != OBJ_NUM) return vm->unsupported;
return (Value)makeNum(((ObjNum*)args[0])->value / ((ObjNum*)args[1])->value);
}
DEF_PRIMITIVE(string_contains) DEF_PRIMITIVE(string_contains)
{ {
const char* string = ((ObjString*)args[0])->value; const char* string = ((ObjString*)args[0])->value;
@ -49,6 +74,24 @@ DEF_PRIMITIVE(string_count)
return (Value)makeNum(count); return (Value)makeNum(count);
} }
DEF_PRIMITIVE(string_plus)
{
if (args[1]->type != OBJ_STRING) return vm->unsupported;
// TODO(bob): Handle coercion to string of RHS.
ObjString* left = (ObjString*)args[0];
ObjString* right = (ObjString*)args[1];
size_t leftLength = strlen(left->value);
size_t rightLength = strlen(right->value);
char* result = malloc(leftLength + rightLength);
strcpy(result, left->value);
strcpy(result + leftLength, right->value);
return (Value)makeString(result);
}
DEF_PRIMITIVE(io_write) DEF_PRIMITIVE(io_write)
{ {
printValue(args[1]); printValue(args[1]);
@ -59,10 +102,21 @@ DEF_PRIMITIVE(io_write)
void registerPrimitives(VM* vm) void registerPrimitives(VM* vm)
{ {
PRIMITIVE(vm->numClass, "abs", num_abs); PRIMITIVE(vm->numClass, "abs", num_abs);
PRIMITIVE(vm->numClass, "- ", num_minus);
PRIMITIVE(vm->numClass, "+ ", num_plus);
PRIMITIVE(vm->numClass, "* ", num_multiply);
PRIMITIVE(vm->numClass, "/ ", num_divide);
PRIMITIVE(vm->stringClass, "contains ", string_contains); PRIMITIVE(vm->stringClass, "contains ", string_contains);
PRIMITIVE(vm->stringClass, "count", string_count); PRIMITIVE(vm->stringClass, "count", string_count);
PRIMITIVE(vm->stringClass, "+ ", string_plus);
ObjClass* ioClass = makeClass(); ObjClass* ioClass = makeClass();
PRIMITIVE(ioClass, "write ", io_write); PRIMITIVE(ioClass, "write ", io_write);
GLOBAL(ioClass, "io"); GLOBAL(ioClass, "io");
ObjClass* unsupportedClass = makeClass();
// TODO(bob): Make this a distinct object type.
vm->unsupported = (Value)makeInstance(unsupportedClass);
} }

View File

@ -27,9 +27,13 @@ typedef struct
int numFrames; int numFrames;
} Fiber; } Fiber;
static Value primitive_metaclass_new(Value* args, int numArgs);
static void callBlock(Fiber* fiber, ObjBlock* block, int firstLocal); static void callBlock(Fiber* fiber, ObjBlock* block, int firstLocal);
static Value primitive_metaclass_new(VM* vm, Value* args, int numArgs);
// Pushes [value] onto the top of the stack.
static void push(Fiber* fiber, Value value); static void push(Fiber* fiber, Value value);
// Removes and returns the top of the stack.
static Value pop(Fiber* fiber); static Value pop(Fiber* fiber);
VM* newVM() VM* newVM()
@ -279,11 +283,10 @@ Value interpret(VM* vm, ObjBlock* block)
case CODE_CALL_10: case CODE_CALL_10:
{ {
int numArgs = frame->block->bytecode[frame->ip - 1] - CODE_CALL_0; int numArgs = frame->block->bytecode[frame->ip - 1] - CODE_CALL_0;
Value receiver = fiber.stack[fiber.stackSize - numArgs - 1];
int symbol = frame->block->bytecode[frame->ip++]; int symbol = frame->block->bytecode[frame->ip++];
// TODO(bob): Support classes for other object types. Value receiver = fiber.stack[fiber.stackSize - numArgs - 1];
ObjClass* classObj; ObjClass* classObj;
switch (receiver->type) switch (receiver->type)
{ {
@ -312,8 +315,11 @@ Value interpret(VM* vm, ObjBlock* block)
switch (method->type) switch (method->type)
{ {
case METHOD_NONE: case METHOD_NONE:
// TODO(bob): Should return nil or suspend fiber or something. printf("Receiver ");
printf("No method.\n"); printValue(receiver);
printf(" does not implement method \"%s\".\n",
vm->symbols.names[symbol]);
// TODO(bob): Throw an exception or halt the fiber or something.
exit(1); exit(1);
break; break;
@ -322,15 +328,18 @@ Value interpret(VM* vm, ObjBlock* block)
break; break;
case METHOD_PRIMITIVE: case METHOD_PRIMITIVE:
// TODO(bob): Pass args to primitive. {
fiber.stack[fiber.stackSize - numArgs - 1] = Value* args = &fiber.stack[fiber.stackSize - numArgs - 1];
method->primitive(&fiber.stack[fiber.stackSize - numArgs - 1], // TODO(bob): numArgs passed to primitive should probably include
numArgs); // receiver.
Value result = method->primitive(vm, args, numArgs);
fiber.stack[fiber.stackSize - numArgs - 1] = result;
// Discard the stack slots for the arguments. // Discard the stack slots for the arguments.
fiber.stackSize = fiber.stackSize - numArgs; fiber.stackSize -= numArgs;
break; break;
}
case METHOD_BLOCK: case METHOD_BLOCK:
callBlock(&fiber, method->block, fiber.stackSize - numArgs); callBlock(&fiber, method->block, fiber.stackSize - numArgs);
break; break;
@ -362,14 +371,6 @@ void printValue(Value value)
// TODO(bob): Do more useful stuff here. // TODO(bob): Do more useful stuff here.
switch (value->type) switch (value->type)
{ {
case OBJ_NUM:
printf("%g", ((ObjNum*)value)->value);
break;
case OBJ_STRING:
printf("%s", ((ObjString*)value)->value);
break;
case OBJ_BLOCK: case OBJ_BLOCK:
printf("[block]"); printf("[block]");
break; break;
@ -381,14 +382,15 @@ void printValue(Value value)
case OBJ_INSTANCE: case OBJ_INSTANCE:
printf("[instance]"); printf("[instance]");
break; break;
}
}
Value primitive_metaclass_new(Value* args, int numArgs) case OBJ_NUM:
{ printf("%g", ((ObjNum*)value)->value);
ObjClass* classObj = (ObjClass*)args[0]; break;
// TODO(bob): Invoke initializer method.
return (Value)makeInstance(classObj); case OBJ_STRING:
printf("%s", ((ObjString*)value)->value);
break;
}
} }
void callBlock(Fiber* fiber, ObjBlock* block, int firstLocal) void callBlock(Fiber* fiber, ObjBlock* block, int firstLocal)
@ -408,6 +410,13 @@ void callBlock(Fiber* fiber, ObjBlock* block, int firstLocal)
fiber->numFrames++; fiber->numFrames++;
} }
Value primitive_metaclass_new(VM* vm, Value* args, int numArgs)
{
ObjClass* classObj = (ObjClass*)args[0];
// TODO(bob): Invoke initializer method.
return (Value)makeInstance(classObj);
}
void push(Fiber* fiber, Value value) void push(Fiber* fiber, Value value)
{ {
// TODO(bob): Check for stack overflow. // TODO(bob): Check for stack overflow.

View File

@ -28,7 +28,9 @@ typedef struct
typedef Obj* Value; typedef Obj* Value;
typedef Value (*Primitive)(Value* args, int numArgs); typedef struct sVM VM;
typedef Value (*Primitive)(VM* vm, Value* args, int numArgs);
typedef struct typedef struct
{ {
@ -127,7 +129,7 @@ typedef enum
CODE_CALL_8, CODE_CALL_8,
CODE_CALL_9, CODE_CALL_9,
CODE_CALL_10, CODE_CALL_10,
// The current block is done and should be exited. // The current block is done and should be exited.
CODE_END CODE_END
} Code; } Code;
@ -139,7 +141,7 @@ typedef struct
int count; int count;
} SymbolTable; } SymbolTable;
typedef struct struct sVM
{ {
SymbolTable symbols; SymbolTable symbols;
@ -148,10 +150,13 @@ typedef struct
ObjClass* numClass; ObjClass* numClass;
ObjClass* stringClass; ObjClass* stringClass;
// The singleton value "unsupported".
Value unsupported;
SymbolTable globalSymbols; SymbolTable globalSymbols;
// TODO(bob): Using a fixed array is gross here. // TODO(bob): Using a fixed array is gross here.
Value globals[MAX_SYMBOLS]; Value globals[MAX_SYMBOLS];
} VM; };
VM* newVM(); VM* newVM();
void freeVM(VM* vm); void freeVM(VM* vm);

4
test/number_divide.wren Normal file
View File

@ -0,0 +1,4 @@
io.write(8 / 2) // expect: 4
// TODO(bob): Floating point numbers.
// TODO(bob): Unsupported RHS types.

5
test/number_minus.wren Normal file
View File

@ -0,0 +1,5 @@
io.write(5 - 3) // expect: 2
io.write(3 - 2 - 1) // expect: 0
// TODO(bob): Floating point numbers.
// TODO(bob): Unsupported RHS types.

View File

@ -0,0 +1,4 @@
io.write(5 * 3) // expect: 15
// TODO(bob): Floating point numbers.
// TODO(bob): Unsupported RHS types.

4
test/number_plus.wren Normal file
View File

@ -0,0 +1,4 @@
io.write(3 + 5 + 2) // expect: 10
// TODO(bob): Floating point numbers.
// TODO(bob): Unsupported RHS types.