1
0
forked from Mirror/wren

Fix stack corruption caused by Fn call primitives (#807)

Excerpt from @munificent on the nature of the bug:

In runInterpreter, for performance, the vm caches an IP pointing into some bytecode.

All primitives except for `.call`, do not touch Wren's own callstack. They run a little C code and return, so the array of CallFrames, their IPs, and the IP cached inside run() are not affected at all.

While runInterpreter() is running, the IP in the top CallFrame is not updated, so it gets out of sync. This is deliberate, since storing to a field is slow, but it means the value of that field is stale and doesn't represent where execution actually is at that point in time.

To get that field in sync, we use STORE_FRAME(), which stores the local IP value back into the IP field for the top CallFrame. The interpreter is careful to always call STORE_FRAME() before executing any code that pushes a new CallFrame onto the stack.

In particular, if you look around, you'll see that every place the interpreter calls wrenCallFunction() is preceded by a STORE_FRAME(). That is, except for the call to wrenCallFunction() in the call_fn() primitive. That's the bug.

The .call() method on Fn is special because it does modify the Wren call stack and the C code for that primitive directly calls wrenCallFunction(). When that happens, the correct IP for the current function, which lives only in runInterpreter()'s local variable gets discarded and you're left with a stale IP in the CallFrame.

Giving the function call primitives a different method type and having the case for that method type call STORE_FRAME() before invoking the primitive fixes the bug.
This commit is contained in:
ruby
2020-09-18 12:32:43 -07:00
committed by GitHub
parent f769599bc6
commit 86463acb90
4 changed files with 41 additions and 17 deletions

View File

@ -1247,23 +1247,25 @@ void wrenInitializeCore(WrenVM* vm)
PRIMITIVE(vm->fnClass->obj.classObj, "new(_)", fn_new); PRIMITIVE(vm->fnClass->obj.classObj, "new(_)", fn_new);
PRIMITIVE(vm->fnClass, "arity", fn_arity); PRIMITIVE(vm->fnClass, "arity", fn_arity);
PRIMITIVE(vm->fnClass, "call()", fn_call0);
PRIMITIVE(vm->fnClass, "call(_)", fn_call1); FUNCTION_CALL(vm->fnClass, "call()", fn_call0);
PRIMITIVE(vm->fnClass, "call(_,_)", fn_call2); FUNCTION_CALL(vm->fnClass, "call(_)", fn_call1);
PRIMITIVE(vm->fnClass, "call(_,_,_)", fn_call3); FUNCTION_CALL(vm->fnClass, "call(_,_)", fn_call2);
PRIMITIVE(vm->fnClass, "call(_,_,_,_)", fn_call4); FUNCTION_CALL(vm->fnClass, "call(_,_,_)", fn_call3);
PRIMITIVE(vm->fnClass, "call(_,_,_,_,_)", fn_call5); FUNCTION_CALL(vm->fnClass, "call(_,_,_,_)", fn_call4);
PRIMITIVE(vm->fnClass, "call(_,_,_,_,_,_)", fn_call6); FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_)", fn_call5);
PRIMITIVE(vm->fnClass, "call(_,_,_,_,_,_,_)", fn_call7); FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_)", fn_call6);
PRIMITIVE(vm->fnClass, "call(_,_,_,_,_,_,_,_)", fn_call8); FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_)", fn_call7);
PRIMITIVE(vm->fnClass, "call(_,_,_,_,_,_,_,_,_)", fn_call9); FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_)", fn_call8);
PRIMITIVE(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_)", fn_call10); FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_)", fn_call9);
PRIMITIVE(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_)", fn_call11); FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_)", fn_call10);
PRIMITIVE(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_)", fn_call12); FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_)", fn_call11);
PRIMITIVE(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call13); FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_)", fn_call12);
PRIMITIVE(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call14); FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call13);
PRIMITIVE(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call15); FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call14);
PRIMITIVE(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call16); FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call15);
FUNCTION_CALL(vm->fnClass, "call(_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_)", fn_call16);
PRIMITIVE(vm->fnClass, "toString", fn_toString); PRIMITIVE(vm->fnClass, "toString", fn_toString);
vm->nullClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Null")); vm->nullClass = AS_CLASS(wrenFindVariable(vm, coreModule, "Null"));

View File

@ -16,6 +16,19 @@
wrenBindMethod(vm, cls, symbol, method); \ wrenBindMethod(vm, cls, symbol, method); \
} while (false) } while (false)
// Binds a primitive method named [name] (in Wren) implemented using C function
// [fn] to `ObjClass` [cls], but as a FN call.
#define FUNCTION_CALL(cls, name, function) \
do \
{ \
int symbol = wrenSymbolTableEnsure(vm, \
&vm->methodNames, name, strlen(name)); \
Method method; \
method.type = METHOD_FUNCTION_CALL; \
method.as.primitive = prim_##function; \
wrenBindMethod(vm, cls, symbol, method); \
} while (false)
// Defines a primitive method whose C function name is [name]. This abstracts // Defines a primitive method whose C function name is [name]. This abstracts
// the actual type signature of a primitive function and makes it clear which C // the actual type signature of a primitive function and makes it clear which C
// functions are invoked as primitives. // functions are invoked as primitives.

View File

@ -359,6 +359,9 @@ typedef enum
// this can directly manipulate the fiber's stack. // this can directly manipulate the fiber's stack.
METHOD_PRIMITIVE, METHOD_PRIMITIVE,
// A primitive that handles .call on Fn.
METHOD_FUNCTION_CALL,
// A externally-defined C method. // A externally-defined C method.
METHOD_FOREIGN, METHOD_FOREIGN,

View File

@ -1015,6 +1015,12 @@ static WrenInterpretResult runInterpreter(WrenVM* vm, register ObjFiber* fiber)
} }
break; break;
case METHOD_FUNCTION_CALL:
STORE_FRAME();
method->as.primitive(vm, args);
LOAD_FRAME();
break;
case METHOD_FOREIGN: case METHOD_FOREIGN:
callForeign(vm, fiber, method->as.foreign, numArgs); callForeign(vm, fiber, method->as.foreign, numArgs);
if (wrenHasError(fiber)) RUNTIME_ERROR(); if (wrenHasError(fiber)) RUNTIME_ERROR();