#lang plait

;; Start with "poly-lambda.rkt"
;;
;; Add the usual `box`, `unbox`, and `set-box!` forms
;;  - using Plait boxes
;;  - with suitable type checking
;;
;; Then define polymorphic `swap` in Curly

;; Run-time values produced by `interp`.
(define-type Value
  (numV [n : Number])
  (closV [arg : Symbol]
         [body : Exp]
         [env : Env])
  ;; A type abstraction: the body is re-interpreted at each type application
  (polyV [body : Exp]
         [env : Env])
  ;; A mutable cell, represented directly by a Plait box
  (boxV [b : (Boxof Value)]))

;; Abstract syntax produced by `parse`.
(define-type Exp
  (numE [n : Number])
  (idE [s : Symbol])
  (plusE [l : Exp]
         [r : Exp])
  (multE [l : Exp]
         [r : Exp])
  (lamE [n : Symbol]
        [arg-type : Type]
        [body : Exp])
  (appE [fun : Exp]
        [arg : Exp])
  (tylamE [n : Symbol]
          [body : Exp])
  (tyappE [tyfun : Exp]
          [tyarg : Type])
  (boxE [val : Exp])
  (unboxE [b : Exp])
  (set-box!E [b : Exp]
             [v : Exp]))

;; Types produced by `parse-type` and `typecheck`.
(define-type Type
  (numT)
  (boolT)
  (arrowT [arg : Type]
          [result : Type])
  (idT [n : Symbol])
  (forallT [n : Symbol]
           [body : Type])
  (boxofT [t : Type]))

;; Interpreter environment: variable name -> value.
(define-type Binding
  (bind [name : Symbol]
        [val : Value]))

(define-type-alias Env (Listof Binding))

;; Type environment: holds both variable typings (`tbind`) and
;; in-scope type variables (`tid`) in a single list.
(define-type Type-Binding
  (tbind [name : Symbol]
         [type : Type])
  (tid [name : Symbol]))

(define-type-alias Type-Env (Listof Type-Binding))

(define mt-env empty)
(define extend-env cons)

(module+ test
  (print-only-errors #t))

;; parse ----------------------------------------

;; Converts a Curly S-expression into an `Exp`.
;; `let` is desugared into an application of a `lambda`.
;; Clause order matters: the specific forms (`box`, `unbox`, ...) must be
;; tried before the generic two-element application pattern `{ANY ANY}`.
;; Raises a "invalid input" error for anything that matches no form.
(define (parse [s : S-Exp]) : Exp
  (cond
    [(s-exp-match? `NUMBER s) (numE (s-exp->number s))]
    [(s-exp-match? `SYMBOL s) (idE (s-exp->symbol s))]
    [(s-exp-match? `{+ ANY ANY} s)
     (plusE (parse (second (s-exp->list s)))
            (parse (third (s-exp->list s))))]
    [(s-exp-match? `{* ANY ANY} s)
     (multE (parse (second (s-exp->list s)))
            (parse (third (s-exp->list s))))]
    [(s-exp-match? `{let {[SYMBOL : ANY ANY]} ANY} s)
     (let ([bs (s-exp->list (first
                             (s-exp->list (second
                                           (s-exp->list s)))))])
       (appE (lamE (s-exp->symbol (first bs))
                   (parse-type (third bs))
                   (parse (third (s-exp->list s))))
             (parse (fourth bs))))]
    [(s-exp-match? `{lambda {[SYMBOL : ANY]} ANY} s)
     (let ([arg (s-exp->list
                 (first (s-exp->list 
                         (second (s-exp->list s)))))])
       (lamE (s-exp->symbol (first arg))
             (parse-type (third arg))
             (parse (third (s-exp->list s)))))]
    [(s-exp-match? `[LAMBDA ['SYMBOL] ANY] s)
     (tylamE (s-exp->symbol
              (second (s-exp->list
                       (first (s-exp->list
                               (second (s-exp->list s)))))))
             (parse (third (s-exp->list s))))]
    [(s-exp-match? `[@ ANY ANY] s)
     (tyappE (parse (second (s-exp->list s)))
             (parse-type (third (s-exp->list s))))]
    [(s-exp-match? `{box ANY} s)
     (boxE (parse (second (s-exp->list s))))]
    [(s-exp-match? `{unbox ANY} s)
     (unboxE (parse (second (s-exp->list s))))]
    [(s-exp-match? `{set-box! ANY ANY} s)
     (set-box!E (parse (second (s-exp->list s)))
                (parse (third (s-exp->list s))))]
    [(s-exp-match? `{ANY ANY} s)
     (appE (parse (first (s-exp->list s)))
           (parse (second (s-exp->list s))))]
    [else (error 'parse "invalid input")]))

;; Converts a type S-expression into a `Type`.
;; Raises a "invalid input" error for anything that matches no form.
(define (parse-type [s : S-Exp]) : Type
  (cond
    [(s-exp-match? `num s) 
     (numT)]
    [(s-exp-match? `bool s)
     (boolT)]
    [(s-exp-match? `(ANY -> ANY) s)
     (arrowT (parse-type (first (s-exp->list s)))
             (parse-type (third (s-exp->list s))))]
    [(s-exp-match? `'SYMBOL s)
     ; ''SYMBOL is equivalent to '(quote SYMBOL)
     (idT (s-exp->symbol (second (s-exp->list s))))]
    [(s-exp-match? `(forall ('SYMBOL) ANY) s)
     (forallT (s-exp->symbol
               (second (s-exp->list
                        (first (s-exp->list
                                (second (s-exp->list s)))))))
              (parse-type (third (s-exp->list s))))]
    [(s-exp-match? `(boxof ANY) s)
     (boxofT (parse-type (second (s-exp->list s))))]
    [else (error 'parse-type "invalid input")]))

(module+ test
  (test (parse `2)
        (numE 2))
  (test (parse `x) ; note: backquote instead of normal quote
        (idE 'x))
  (test (parse `{+ 2 1})
        (plusE (numE 2) (numE 1)))
  (test (parse `{* 3 4})
        (multE (numE 3) (numE 4)))
  (test (parse `{+ {* 3 4} 8})
        (plusE (multE (numE 3) (numE 4))
               (numE 8)))
  (test (parse `{let {[x : num {+ 1 2}]}
                  y})
        (appE (lamE 'x (numT) (idE 'y))
              (plusE (numE 1) (numE 2))))
  (test (parse `{lambda {[x : num]} 9})
        (lamE 'x (numT) (numE 9)))
  (test (parse `{double 9})
        (appE (idE 'double) (numE 9)))
  (test (parse `[LAMBDA ['a] {lambda {[x : 'a]} x}])
        (tylamE 'a (lamE 'x (idT 'a) (idE 'x))))
  (test (parse `[@ f num])
        (tyappE (idE 'f) (numT)))
  ;; Box forms must win over the generic application pattern
  (test (parse `{box 1})
        (boxE (numE 1)))
  (test (parse `{unbox b})
        (unboxE (idE 'b)))
  (test (parse `{set-box! b 2})
        (set-box!E (idE 'b) (numE 2)))
  (test/exn (parse `{{+ 1 2}})
            "invalid input")

  (test (parse-type `num)
        (numT))
  (test (parse-type `bool)
        (boolT))
  (test (parse-type `(num -> bool))
        (arrowT (numT) (boolT)))
  (test (parse-type `'a)
        (idT 'a))
  (test (parse-type `(forall ('a) ('a -> 'a)))
        (forallT 'a (arrowT (idT 'a) (idT 'a))))
  (test (parse-type `(boxof num))
        (boxofT (numT)))
  (test (parse-type `(boxof (num -> 'a)))
        (boxofT (arrowT (numT) (idT 'a))))
  (test/exn (parse-type `1)
            "invalid input"))

;; interp ----------------------------------------

;; Evaluates `a` in environment `env`.
;; Errors: "free variable" (via `lookup`), "not a number" (via `num-op`),
;; "not a function", "not a box", "not a polymorphic value".
(define (interp [a : Exp] [env : Env]) : Value
  (type-case Exp a
    [(numE n) (numV n)]
    [(idE s) (lookup s env)]
    [(plusE l r) (num+ (interp l env) (interp r env))]
    [(multE l r) (num* (interp l env) (interp r env))]
    [(boxE v) (boxV (box (interp v env)))]
    [(unboxE v)
     (type-case Value (interp v env)
       [(boxV bx) (unbox bx)]
       [else (error 'interp "not a box")])]
    [(set-box!E b v)
     (type-case Value (interp b env)
       [(boxV bx)
        ;; The new value is computed before the old content is read;
        ;; the result of `set-box!` is the box's previous content.
        (let ([new-val (interp v env)]
              [old-val (unbox bx)])
          (begin
            (set-box! bx new-val)
            old-val))]
       [else (error 'interp "not a box")])]
    [(lamE n t body)
     (closV n body env)]
    [(appE fun arg)
     (type-case Value (interp fun env)
       [(closV n body c-env)
        (interp body
                (extend-env
                 (bind n
                       (interp arg env))
                 c-env))]
       [else (error 'interp "not a function")])]
    [(tylamE n body)
     (polyV body env)]
    [(tyappE tyfun tyarg)
     ;; Types are erased at run time: the type argument is ignored and
     ;; the delayed body is simply interpreted.
     (type-case Value (interp tyfun env)
       [(polyV body p-env)
        (interp body p-env)]
       [else (error 'interp "not a polymorphic value")])]))

(module+ test
  (test (interp (parse `2) mt-env)
        (numV 2))
  (test/exn (interp (parse `x) mt-env)
            "free variable")
  (test (interp (parse `x) 
                (extend-env (bind 'x (numV 9)) mt-env))
        (numV 9))
  (test (interp (parse `{+ 2 1}) mt-env)
        (numV 3))
  (test (interp (parse `{* 2 1}) mt-env)
        (numV 2))
  (test (interp (parse `{+ {* 2 3} {+ 5 8}})
                mt-env)
        (numV 19))
  (test (interp (parse `{lambda {[x : num]} {+ x x}})
                mt-env)
        (closV 'x (plusE (idE 'x) (idE 'x)) mt-env))
  (test (interp (parse `{let {[x : num 5]}
                          {+ x x}})
                mt-env)
        (numV 10))
  (test (interp (parse `{let {[x : num 5]}
                          {let {[x : num {+ 1 x}]}
                            {+ x x}}})
                mt-env)
        (numV 12))
  (test (interp (parse `{let {[x : num 5]}
                          {let {[y : num 6]}
                            x}})
                mt-env)
        (numV 5))
  (test (interp (parse `{{lambda {[x : num]} {+ x x}} 8})
                mt-env)
        (numV 16))
  (test (interp (parse `[LAMBDA ['a] {lambda {[x : 'a]} x}])
                mt-env)
        (polyV (lamE 'x (idT 'a) (idE 'x)) mt-env))
  (test (interp (parse `[@ [LAMBDA ['a] {lambda {[x : 'a]} x}]
                           num])
                mt-env)
        (closV 'x (idE 'x) mt-env))
  (test (interp (parse `{let {[f : (forall ('a) ('a -> 'a))
                                 [LAMBDA ['a] {lambda {[x : 'a]} x}]]}
                          {+ {[@ f num] 1}
                             {{[@ f (num -> num)] {lambda {[n : num]} {+ n 1}}}
                              2}}})
                mt-env)
        (numV 4))

  ;; Boxes
  (test (interp (parse `{unbox {box 1}}) mt-env)
        (numV 1))
  (test (interp (parse `{set-box! {box 1} 2}) mt-env)
        (numV 1))
  (test (interp (parse `{let {[b : (boxof num) {box 1}]}
                          {+ {set-box! b 5} {unbox b}}})
                mt-env)
        (numV 6))
  (test/exn (interp (parse `{unbox 1}) mt-env)
            "not a box")
  (test/exn (interp (parse `{set-box! 1 2}) mt-env)
            "not a box")

  (test/exn (interp (parse `{1 2}) mt-env)
            "not a function")
  (test/exn (interp (parse `{+ 1 {lambda {[x : num]} x}}) mt-env)
            "not a number")
  (test/exn (interp (parse `{let {[bad : (num -> num) {lambda {[x : num]} {+ x y}}]}
                              {let {[y : num 5]}
                                {bad 2}}})
                    mt-env)
            "free variable")
  (test/exn (interp (parse `[@ 1 num]) mt-env)
            "not a polymorphic value"))

;; num+ and num* ----------------------------------------

;; Applies `op` to two number values; errors with "not a number" otherwise.
(define (num-op [op : (Number Number -> Number)] [l : Value] [r : Value]) : Value
  (cond
    [(and (numV? l) (numV? r))
     (numV (op (numV-n l) (numV-n r)))]
    [else
     (error 'interp "not a number")]))

(define (num+ [l : Value] [r : Value]) : Value
  (num-op + l r))

(define (num* [l : Value] [r : Value]) : Value
  (num-op * l r))

(module+ test
  (test (num+ (numV 1) (numV 2))
        (numV 3))
  (test (num* (numV 2) (numV 3))
        (numV 6)))

;; lookup ----------------------------------------

;; Builds a lookup function over a list of bindings.
;;  check?  : selects which binding variants are eligible
;;  name-of : extracts the name of an eligible binding
;;  val-of  : extracts the result from the matching binding
;; The returned function yields the result for the first eligible binding
;; whose name equals the requested one, or errors with "free variable".
;; The search loop is a single local function, so the closure is built
;; once per `make-lookup` call rather than once per list element.
(define (make-lookup [check? : ('a -> Boolean)]
                     [name-of : ('a -> Symbol)]
                     [val-of : ('a -> 'b)])
  (local [(define (scan name vals)
            (cond
              [(empty? vals)
               (error 'find "free variable")]
              [(and (check? (first vals))
                    (equal? name (name-of (first vals))))
               (val-of (first vals))]
              [else (scan name (rest vals))]))]
    scan))

(define lookup
  (make-lookup bind? bind-name bind-val))

(module+ test
  (test/exn (lookup 'x mt-env)
            "free variable")
  (test (lookup 'x (extend-env (bind 'x (numV 8)) mt-env))
        (numV 8))
  (test (lookup 'x (extend-env
                    (bind 'x (numV 9))
                    (extend-env (bind 'x (numV 8)) mt-env)))
        (numV 9))
  (test (lookup 'y (extend-env
                    (bind 'x (numV 9))
                    (extend-env (bind 'y (numV 8)) mt-env)))
        (numV 8))
  (test/exn (lookup 'z (extend-env
                        (bind 'x (numV 9))
                        (extend-env (bind 'y (numV 8)) mt-env)))
            "free variable"))

;; typecheck ----------------------------------------

;; Computes the type of `a` under `tenv`, or raises a "no type" error.
;; Types are compared with `equal?`, so two `forall` types that differ
;; only in the name of their bound variable are treated as different.
;;
;; `set-box!` has the type of the box's content, matching `interp`,
;; which returns the box's previous content.
;;
;; NOTE(review): `tylamE` extends `tenv` without renaming. If an inner
;; `LAMBDA` reuses a type-variable name that still occurs in the type of
;; an enclosing variable, the two variables become indistinguishable.
;; Consider rejecting or renaming such a rebinding.
(define (typecheck [a : Exp] [tenv : Type-Env])
  (type-case Exp a
    [(numE n) (numT)]
    [(plusE l r) (typecheck-nums l r tenv)]
    [(multE l r) (typecheck-nums l r tenv)]
    [(idE n) (type-lookup n tenv)]
    [(lamE n arg-type body)
     (begin
       ;; The declared argument type may only mention bound type variables
       (tvarcheck arg-type tenv)
       (arrowT arg-type
               (typecheck body 
                          (extend-env (tbind n arg-type)
                                      tenv))))]
    [(appE fun arg)
     (type-case Type (typecheck fun tenv)
       [(arrowT arg-type result-type)
        (if (equal? arg-type
                    (typecheck arg tenv))
            result-type
            (type-error arg
                        (to-string arg-type)))]
       [else (type-error fun "function")])]
    [(tylamE n body)
     (forallT n (typecheck body (extend-env (tid n) tenv)))]
    [(tyappE tyfun tyarg)
     (begin
       (tvarcheck tyarg tenv)
       (type-case Type (typecheck tyfun tenv)
         [(forallT n body)
          (type-subst n tyarg body)]
         [else (type-error tyfun "polymorphic value")]))]
    [(boxE c)
     (boxofT (typecheck c tenv))]
    [(unboxE b)
     (type-case Type (typecheck b tenv)
       [(boxofT contained-t) contained-t]
       [else (type-error b "box")])]
    [(set-box!E b v)
     (type-case Type (typecheck b tenv)
       [(boxofT contained-t)
        (if (equal? (typecheck v tenv) contained-t)
            contained-t
            (type-error v (to-string contained-t)))]
       [else (type-error b "box")])]))

;; Checks that both operands are numbers; the result type is `num`.
(define (typecheck-nums l r tenv)
  (type-case Type (typecheck l tenv)
    [(numT)
     (type-case Type (typecheck r tenv)
       [(numT) (numT)]
       [else (type-error r "num")])]
    [else (type-error l "num")]))

;; Raises a `typecheck` error of the form "no type: <a> not <msg>".
(define (type-error a msg)
  (error 'typecheck (string-append
                     "no type: "
                     (string-append
                      (to-string a)
                      (string-append " not "
                                     msg)))))

;; Looks up the type of a variable; skips `tid` entries.
(define type-lookup
  (make-lookup tbind? tbind-name tbind-type))

;; Looks up a type variable; skips `tbind` entries. Only the success or
;; failure of the lookup matters, so the name itself is returned.
(define type-var-lookup
  (make-lookup tid? tid-name tid-name))

(module+ test
  (test (typecheck (parse `10) mt-env)
        (numT))
  (test (typecheck (parse `{+ 10 17}) mt-env)
        (numT))
  (test (typecheck (parse `{* 10 17}) mt-env)
        (numT))
  (test (typecheck (parse `{lambda {[x : num]} 12}) mt-env)
        (arrowT (numT) (numT)))
  (test (typecheck (parse `{lambda {[x : num]} {lambda {[y : bool]} x}}) mt-env)
        (arrowT (numT) (arrowT (boolT) (numT))))

  (test (typecheck (parse `{{lambda {[x : num]} 12}
                            {+ 1 17}})
                   mt-env)
        (numT))

  (test (typecheck (parse `{let {[x : num 4]}
                             {let {[f : (num -> num)
                                      {lambda {[y : num]} {+ x y}}]}
                               {f x}}})
                   mt-env)
        (numT))

  (test (typecheck (parse `[LAMBDA ['a] {lambda {[x : 'a]} x}])
                   mt-env)
        (forallT 'a (arrowT (idT 'a) (idT 'a))))
  (test (typecheck (parse `[@ [LAMBDA ['a] {lambda {[x : 'a]} x}]
                              num])
                   mt-env)
        (arrowT (numT) (numT)))
  (test (typecheck (parse `{let {[f : (forall ('a) ('a -> 'a))
                                    [LAMBDA ['a] {lambda {[x : 'a]} x}]]}
                             {+ {[@ f num] 1}
                                {{[@ f (num -> num)] {lambda {[n : num]} {+ n 1}}}
                                 2}}})
                   mt-env)
        (numT))

  ;; Boxes
  (test/exn (typecheck (parse `{+ 1 {box 4}}) empty)
            "no type")
  (test (typecheck (parse `{box 4}) empty)
        (boxofT (numT)))
  (test (typecheck (parse `{unbox {box 4}}) mt-env)
        (numT))
  (test (typecheck (parse `{set-box! {box 4} 5}) mt-env)
        (numT))
  (test (typecheck (parse `[LAMBDA ['a] {lambda {[x : 'a]} {box x}}])
                   mt-env)
        (forallT 'a (arrowT (idT 'a) (boxofT (idT 'a)))))
  (test/exn (typecheck (parse `{unbox 4}) mt-env)
            "no type")
  (test/exn (typecheck (parse `{set-box! 4 5}) mt-env)
            "no type")
  (test/exn (typecheck (parse `{set-box! {box 4} {lambda {[x : num]} x}}) mt-env)
            "no type")

  (test/exn (typecheck (parse `{1 2})
                       mt-env)
            "no type")
  (test/exn (typecheck (parse `{{lambda {[x : bool]} x} 2})
                       mt-env)
            "no type")
  (test/exn (typecheck (parse `{+ 1 {lambda {[x : num]} x}})
                       mt-env)
            "no type")
  (test/exn (typecheck (parse `{* {lambda {[x : num]} x} 1})
                       mt-env)
            "no type")
  (test/exn (typecheck (parse `[@ 1 num])
                       mt-env)
            "no type"))

;; tvarcheck ----------------------------------------

;; Checks that every type variable mentioned in `ty` is bound, either by
;; a `tid` in `tenv` or by an enclosing `forall` inside `ty` itself.
;; Returns `(values)` on success; errors with "free variable" otherwise.
(define (tvarcheck ty tenv)
  (type-case Type ty
    [(numT) (values)]
    [(boolT) (values)]
    [(arrowT a b)
     (begin
       (tvarcheck a tenv)
       (tvarcheck b tenv))]
    [(boxofT t)
     (tvarcheck t tenv)]
    [(idT id)
     (begin
       (type-var-lookup id tenv)
       (values))]
    [(forallT id t)
     (tvarcheck t (extend-env (tid id) tenv))]))

(module+ test
  (test (tvarcheck (numT) mt-env)
        (values))
  (test (tvarcheck (boolT) mt-env)
        (values))
  (test (tvarcheck (arrowT (numT) (boolT)) mt-env)
        (values))
  (test (tvarcheck (idT 'a) (extend-env (tid 'a) mt-env))
        (values))
  (test (tvarcheck (forallT 'a (idT 'a)) mt-env)
        (values))
  (test (tvarcheck (boxofT (idT 'a)) (extend-env (tid 'a) mt-env))
        (values))
  (test/exn (tvarcheck (boxofT (idT 'a)) mt-env)
            "free variable")
  (test/exn (tvarcheck (idT 'a) mt-env)
            "free variable"))

;; type-subst ----------------------------------------

;; Replaces every free occurrence of the type variable `what` in `in`
;; with the type `for`, renaming bound variables where needed to avoid
;; capturing free variables of `for`.
(define (type-subst [what : Symbol] [for : Type] [in : Type])
  (type-case Type in
    [(numT) (numT)]
    [(boolT) (boolT)]
    [(arrowT l r)
     (arrowT (type-subst what for l)
             (type-subst what for r))]
    [(idT n)
     (if (equal? what n)
         for
         (idT n))]
    [(boxofT t)
     (boxofT (type-subst what for t))]
    [(forallT n body)
     (cond
       ;; `what` is shadowed here, so nothing below is a free occurrence
       [(equal? what n) (forallT n body)]
       [(free-type-var? n for)
        ;; If we want to replace `a` in
        ;; `{forall b (b -> a)}` with `b`, the result
        ;; `{forall b (b -> b)}` would be wrong, since the
        ;; `b` would get captured. We instead need to
        ;; produce `{forall b1 (b1 -> b)}`.
        (local [(define new-n (gen-name n 1 for body))
                (define new-body (type-subst n (idT new-n) body))]
          (type-subst what for (forallT new-n new-body)))]
       [else
        (forallT n (type-subst what for body))])]))

;; Helper function for substitution: generates a name like `n` that is
;; not currently used (as a free type variable) in `for` or `body`.
;; Candidates are `n` followed by `i`, `i`+1, ...
(define (gen-name [n : Symbol] [i : Number] [for : Type] [body : Type])
  (let ([new-n (string->symbol
                (string-append (symbol->string n)
                               (to-string i)))])
    (if (or (free-type-var? new-n for)
            (free-type-var? new-n body))
        (gen-name n (+ i 1) for body)
        new-n)))

;; Helper function for substitution: check whether a name is used as a
;; free type variable in a type.
(define (free-type-var? [n : Symbol] [t : Type])
  (type-case Type t
    [(numT) #f]
    [(boolT) #f]
    [(arrowT l r)
     (or (free-type-var? n l)
         (free-type-var? n r))]
    [(idT n-v) (equal? n-v n)]
    [(boxofT t) (free-type-var? n t)]
    [(forallT n-f body)
     (cond
       [(equal? n n-f) #f]
       [else (free-type-var? n body)])]))

(module+ test
  (test (free-type-var? 'a (numT))
        #f)
  (test (free-type-var? 'a (boolT))
        #f)
  (test (free-type-var? 'a (arrowT (idT 'b) (idT 'b)))
        #f)
  (test (free-type-var? 'a (arrowT (idT 'a) (idT 'b)))
        #t)
  (test (free-type-var? 'a (idT 'a))
        #t)
  (test (free-type-var? 'a (idT 'b))
        #f)
  (test (free-type-var? 'a (boxofT (idT 'a)))
        #t)
  (test (free-type-var? 'a (boxofT (idT 'b)))
        #f)
  (test (free-type-var? 'a (forallT 'a (idT 'a)))
        #f)
  (test (free-type-var? 'a (forallT 'b (idT 'a)))
        #t)
  (test (free-type-var? 'a (forallT 'b (idT 'c)))
        #f)

  (test (gen-name 'a 1 (numT) (numT))
        'a1)
  (test (gen-name 'a 1 (numT) (idT 'a1))
        'a2)
  (test (gen-name 'a 1 (idT 'a1) (numT))
        'a2)
  (test (gen-name 'a 1 (arrowT (numT) (idT 'a1)) (numT))
        'a2)
  (test (gen-name 'a 1 (forallT 'a1 (idT 'a1)) (numT))
        'a1)
  (test (gen-name 'a 1 (forallT 'b (idT 'a1)) (numT))
        'a2)

  (test (type-subst 'a (boolT) (numT))
        (numT))
  (test (type-subst 'a (numT) (boolT))
        (boolT))
  (test (type-subst 'a (numT) (arrowT (idT 'a) (boolT)))
        (arrowT (numT) (boolT)))
  (test (type-subst 'a (numT) (idT 'a))
        (numT))
  (test (type-subst 'a (numT) (idT 'b))
        (idT 'b))
  (test (type-subst 'a (numT) (boxofT (idT 'a)))
        (boxofT (numT)))
  (test (type-subst 'a (numT) (forallT 'a (idT 'a)))
        (forallT 'a (idT 'a)))
  (test (type-subst 'a (numT) (forallT 'b (idT 'a)))
        (forallT 'b (numT)))
  (test (type-subst 'a (idT 'b) (forallT 'b (idT 'a)))
        (forallT 'b1 (idT 'b)))
  (test (type-subst 'a (idT 'b) (forallT 'b (arrowT (idT 'a)
                                                    (arrowT (idT 'b1)
                                                            (idT 'b)))))
        (forallT 'b2 (arrowT (idT 'b)
                             (arrowT (idT 'b1)
                                     (idT 'b2))))))

;; swap ----------------------------------------

(module+ test
  ;; Wraps `body` in a binding of the polymorphic `swap!`, which exchanges
  ;; the contents of two boxes and returns the first box's old content.
  ;; It relies on `set-box!` returning the box's previous content.
  (define (add-swap body)
    `{let {[swap! : (forall ('a) ((boxof 'a) -> ((boxof 'a) -> 'a)))
                  [LAMBDA ['a]
                          {lambda {[a : (boxof 'a)]}
                            {lambda {[b : (boxof 'a)]}
                              {let {[b-v : 'a {set-box! b {unbox a}}]}
                                {set-box! a b-v}}}}]]}
       ,body})

  ;; Swap two boxes of numbers
  (define num-prog
    (add-swap
     `{let {[a : (boxof num) {box 1}]}
        {let {[b : (boxof num) {box 2}]}
          {let {[d : num {{[@ swap! num] a} b}]}
            {unbox b}}}}))

  (test (typecheck (parse num-prog) mt-env)
        (numT))
  (test (interp (parse num-prog) mt-env)
        (numV 1))

  ;; Swap two boxes of functions, instantiating `swap!` at `(num -> num)`
  (define func-prog
    (add-swap
     `{let {[a : (boxof (num -> num)) {box {lambda {[x : num]} {+ 2 x}}}]}
        {let {[b : (boxof (num -> num)) {box {lambda {[x : num]} {* 2 x}}}]}
          {let {[d : (num -> num) {{[@ swap! (num -> num)] a} b}]}
            {{unbox b} 5}}}}))

  (test (typecheck (parse func-prog) mt-env)
        (numT))
  (test (interp (parse func-prog) mt-env)
        (numV 7)))