class Expr:
    """A node of an expression graph: one operator applied to input nodes.

    Every node receives a process-wide unique integer ``id`` taken from
    ``Expr.next_id``; the textual form of a node refers to its inputs as
    ``t<id>``.

    ``+``, ``-`` and ``*`` accept an ``Expr`` or anything ``Const`` accepts
    on either side, so ``x + 1`` and ``1 + x`` both build an ``Add`` node.
    """

    # Next identifier to hand out; shared by all expressions.
    next_id = 0

    # Makes NumPy arrays return NotImplemented from ``ndarray <op> Expr`` so
    # Python falls back to the reflected operators below instead of NumPy
    # broadcasting the operation element by element.
    __array_ufunc__ = None

    def __init__(self, op, inputs):
        """Create a node applying ``op`` to ``inputs``.

        Args:
            op: the ``Op`` instance this node applies.
            inputs: sequence of ``Expr`` nodes consumed by ``op``.

        Raises:
            TypeError: if ``op`` is not an ``Op``.
        """
        # Validate first so a rejected construction does not consume an id.
        # ``(op,)`` keeps the message intact when ``op`` is itself a tuple.
        if not isinstance(op, Op):
            raise TypeError("Not an operator: %s" % (op,))
        self.op = op
        self.inputs = inputs
        self.id = Expr.next_id
        Expr.next_id += 1

    def __dfs_post(self, seen, visitor):
        """Call ``visitor`` once per reachable node, inputs before users.

        ``seen`` is a set of node ids that are already claimed; it is
        updated in place.  The walk uses an explicit stack rather than
        recursion so that very deep graphs do not exhaust the interpreter's
        recursion limit.
        """
        seen.add(self.id)
        stack = [(self, iter(self.inputs))]
        while stack:
            node, pending = stack[-1]
            for child in pending:
                if child.id not in seen:
                    seen.add(child.id)
                    stack.append((child, iter(child.inputs)))
                    # Descend into the child; the parent's iterator resumes
                    # with the next input once the child has been visited.
                    break
            else:
                # Every input has been handled: emit the node itself.
                stack.pop()
                visitor(node)

    def statements(self):
        """Return the graph as text, one ``t<id> = op(args)`` line per node.

        Lines are ordered so that each node appears after all of its inputs;
        a node shared by several users is listed once.
        """
        lines = []
        self.__dfs_post(set(), lambda node: lines.append(str(node)))
        return "\n".join(lines)

    def __str__(self):
        args = ",".join(["t%d" % expr.id for expr in self.inputs])
        return "t%d = %s(%s)" % (self.id, self.op, args)

    @staticmethod
    def __promote(r):
        """Return ``r`` unchanged if it is an ``Expr``, else ``Const(r)``."""
        if isinstance(r, Expr):
            return r
        return Const(r)

    def __add__(self, r):
        return Op("", "Add", 2, {})(self, Expr.__promote(r))

    def __radd__(self, r):
        # Reflected form (``r + self``); operand order is preserved.
        return Op("", "Add", 2, {})(Expr.__promote(r), self)

    def __sub__(self, r):
        return Op("", "Sub", 2, {})(self, Expr.__promote(r))

    def __rsub__(self, r):
        # Reflected form (``r - self``): ``r`` is the minuend.
        return Op("", "Sub", 2, {})(Expr.__promote(r), self)

    def __mul__(self, r):
        return Op("", "Mul", 2, {})(self, Expr.__promote(r))

    def __rmul__(self, r):
        # Reflected form (``r * self``); operand order is preserved.
        return Op("", "Mul", 2, {})(Expr.__promote(r), self)

    def __neg__(self):
        return Op("", "Neg", 1, {})(self)

    def compile(self, builder):
        """Feed every node to ``builder`` and return ``builder.build()``.

        Nodes are passed to ``builder.append`` inputs-first, each exactly
        once.
        """
        self.__dfs_post(set(), builder.append)
        return builder.build()

    def resolve(self, parameters):
        """Call ``op.resolve(parameters)`` on every node's operator.

        Returns:
            ``self``, so calls can be chained.
        """
        self.__dfs_post(set(), lambda node: node.op.resolve(parameters))
        return self


class Op:
    """An operator description; calling it on expressions builds an ``Expr``.

    Attributes:
        name: instance name used as a prefix when resolving parameters;
            the empty string marks an unnamed operator.
        op_type: operator kind, e.g. ``"Add"`` or ``"Conv2d"``.
        num_args: required number of inputs; a negative value disables the
            arity check.
        parameters: dict of operator parameters.
    """

    def __init__(self, name, op_type, num_args, parameters):
        self.name = name
        self.op_type = op_type
        self.num_args = num_args
        self.parameters = parameters

    def __call__(self, *inputs):
        """Apply the operator to ``inputs`` and return the resulting ``Expr``.

        Raises:
            TypeError: if the number of inputs differs from a non-negative
                ``num_args``, or if any input is not an ``Expr``.
        """
        if self.num_args >= 0 and self.num_args != len(inputs):
            raise TypeError("%s: need %d arguments but found %d"
                            % (self, self.num_args, len(inputs)))
        for i, expr in enumerate(inputs):
            if not isinstance(expr, Expr):
                raise TypeError("%s: arg %d is not an expression: %s"
                                % (self, i, expr))
        return Expr(self, inputs)

    def __str__(self):
        name = "%s.%s" % (self.name, self.op_type)
        if not self.parameters:
            return name
        # Values exposing ``shape`` are shown by their shape, not contents.
        params = ",".join(["%s=%s" % (k, v.shape if hasattr(v, "shape") else v)
                           for k, v in self.parameters.items()])
        return "%s[%s]" % (name, params)

    def resolve(self, parameters):
        """Copy the entries of ``parameters`` addressed to this operator.

        Every key of the form ``"<name>.<suffix>"`` is stored in
        ``self.parameters`` under ``<suffix>``.  Unnamed operators
        (``name == ""``) take nothing.
        """
        if self.name == "":
            return
        prefix = self.name + "."
        for k, v in parameters.items():
            if k.startswith(prefix):
                self.parameters[k[len(prefix):]] = v


def Const(c):
    """Return a constant expression holding ``c`` converted to float.

    ``int`` and ``float`` values go through ``float(c)``; objects exposing
    a ``shape`` attribute go through ``c.astype(float)``.

    Raises:
        TypeError: if ``c`` is neither of the above.
    """
    if isinstance(c, (int, float)):
        c = float(c)
    elif hasattr(c, "shape"):
        c = c.astype(float)
    else:
        # ``(c,)`` keeps the message intact when ``c`` is itself a tuple.
        raise TypeError("Const must be float or int or ndarray: %s" % (c,))
    return Expr(Op("", "Const", 0, {"value": c}), [])


def Input(n):
    """Return an input expression named ``n``."""
    return Expr(Op(n, "Input", 0, {}), [])


def Input2d(n, h, w, ic):
    """Return an ``Input2d`` expression named ``n``.

    ``h``, ``w`` and ``ic`` are stored as the ``height``, ``width`` and
    ``in_channels`` parameters.
    """
    return Expr(Op(n, "Input2d", 0,
                   {"height": h, "width": w, "in_channels": ic}), [])


def MaxPool2d(k, s):
    """Return a one-input ``MaxPool2d`` operator (kernel size ``k``, stride ``s``)."""
    return Op("", "MaxPool2d", 1, {"kernel_size": k, "stride": s})


def ReLU():
    """Return a one-input ``ReLU`` operator."""
    return Op("", "ReLU", 1, {})


def Flatten():
    """Return a one-input ``Flatten`` operator."""
    return Op("", "Flatten", 1, {})


def Conv2d(n, ic, oc, k, p=0):
    """Return a one-input ``Conv2d`` operator named ``n``.

    ``ic``, ``oc``, ``k`` and ``p`` are stored as the ``in_channels``,
    ``out_channels``, ``kernel_size`` and ``padding`` parameters.
    """
    return Op(n, "Conv2d", 1, {"in_channels": ic, "out_channels": oc,
                               "kernel_size": k, "padding": p})


def Linear(n, i, o):
    """Return a one-input ``Linear`` operator named ``n``.

    ``i`` and ``o`` are stored as the ``in_features`` and ``out_features``
    parameters.
    """
    return Op(n, "Linear", 1, {"in_features": i, "out_features": o})


def Show():
    """Return a one-input ``Show`` operator."""
    return Op("", "Show", 1, {})