# Source listing info: 145 lines, 3.8 KiB, Python, executable file.
class Expr:
    """A node in an expression graph.

    Each node applies an ``Op`` to a sequence of input ``Expr`` nodes and is
    given a unique integer ``id`` from the class-level ``next_id`` counter.
    The id names the node in textual output (``t<id>``) and de-duplicates
    shared sub-expressions during traversal.
    """

    # Next id to hand out; shared by all instances and bumped in __init__.
    next_id = 0

    # Opt out of NumPy ufunc dispatch so that ``ndarray <op> expr`` returns
    # NotImplemented from NumPy and falls back to Expr's reflected operators,
    # producing a single Const node instead of an object array of Exprs.
    __array_ufunc__ = None

    def __init__(self, op, inputs):
        """Create a node applying *op* to *inputs*.

        Raises:
            TypeError: if *op* is not an ``Op``. The check runs before an id
                is allocated, so a rejected node does not consume one.
        """
        if not isinstance(op, Op):
            # (op,) so a tuple-valued op is formatted rather than unpacked.
            raise TypeError("Not an operator: %s" % (op,))
        self.op = op
        self.inputs = inputs
        self.id = Expr.next_id
        Expr.next_id += 1

    def __dfs_post(self, ids, visitor):
        """Call *visitor* on every node reachable from ``self``, in post-order.

        Inputs are handled left to right before the node that consumes them,
        and each node is visited at most once. *ids* is a dict used as a
        visited set keyed by node id and is updated in place.

        Uses an explicit stack instead of recursion so that long chains
        (e.g. the result of adding many terms) do not exceed Python's
        recursion limit.
        """
        ids[self.id] = True
        stack = [(self, iter(self.inputs))]
        while stack:
            node, pending = stack[-1]
            for expr in pending:
                if expr.id not in ids:
                    ids[expr.id] = True
                    stack.append((expr, iter(expr.inputs)))
                    break
            else:
                # Every input of ``node`` is done: emit it, resume its parent.
                stack.pop()
                visitor(node)

    def statements(self):
        """Return one ``str(node)`` line per reachable node, in post-order."""
        lines = []
        self.__dfs_post({}, lambda that: lines.append(str(that)))
        return "\n".join(lines)

    def __str__(self):
        """Render as ``t<id> = <op>(t<in0>,t<in1>,...)``."""
        args = ",".join("t%d" % expr.id for expr in self.inputs)
        return "t%d = %s(%s)" % (self.id, self.op, args)

    @staticmethod
    def __promote(r):
        """Return *r* unchanged if it is an ``Expr``; otherwise wrap it with ``Const``."""
        return r if isinstance(r, Expr) else Const(r)

    def __add__(self, r):
        return Op("", "Add", 2, {})(self, Expr.__promote(r))

    def __radd__(self, r):
        # Reflected form so ``2 + expr`` and ``sum(exprs)`` work; order kept.
        return Op("", "Add", 2, {})(Expr.__promote(r), self)

    def __sub__(self, r):
        return Op("", "Sub", 2, {})(self, Expr.__promote(r))

    def __rsub__(self, r):
        # ``r - self``: the promoted left operand comes first.
        return Op("", "Sub", 2, {})(Expr.__promote(r), self)

    def __mul__(self, r):
        return Op("", "Mul", 2, {})(self, Expr.__promote(r))

    def __rmul__(self, r):
        return Op("", "Mul", 2, {})(Expr.__promote(r), self)

    def __neg__(self):
        return Op("", "Neg", 1, {})(self)

    def compile(self, builder):
        """Pass every reachable node to ``builder.append`` in post-order
        (inputs before consumers), then return ``builder.build()``."""
        self.__dfs_post({}, lambda that: builder.append(that))
        return builder.build()

    def resolve(self, parameters):
        """Call ``op.resolve(parameters)`` on every reachable node's op and
        return ``self`` so calls can be chained."""
        self.__dfs_post({}, lambda that: that.op.resolve(parameters))
        return self


class Op:
    """An operator: a (possibly named) operation type plus its parameters.

    Calling an ``Op`` on ``Expr`` inputs builds a new ``Expr`` node. The same
    ``Op`` instance may be applied more than once; every resulting node holds
    a reference to it, and therefore to the same ``parameters`` dict.
    """

    def __init__(self, name, op_type, num_args, parameters):
        """
        Args:
            name: instance name used as the key prefix in ``resolve``; ``""``
                marks an anonymous op that ``resolve`` leaves untouched.
            op_type: operation kind, e.g. ``"Conv2d"`` or ``"Add"``.
            num_args: required number of inputs; a negative value disables
                the arity check in ``__call__``.
            parameters: dict of parameters. It is stored by reference, not
                copied, and ``resolve`` updates it in place.
        """
        self.name = name
        self.op_type = op_type
        self.num_args = num_args
        self.parameters = parameters

    def __call__(self, *inputs):
        """Apply this op to *inputs* and return the new ``Expr`` node.

        Raises:
            TypeError: if ``num_args`` is non-negative and differs from the
                number of inputs, or if any input is not an ``Expr``.
        """
        if self.num_args >= 0 and len(inputs) != self.num_args:
            raise TypeError("%s: need %d arguments but found %d" % (self, self.num_args, len(inputs)))
        for i, expr in enumerate(inputs):
            if not isinstance(expr, Expr):
                raise TypeError("%s: arg %d is not an expression: %s" % (self, i, expr))
        return Expr(self, inputs)

    def __str__(self):
        """Render as ``<name>.<op_type>``, followed by ``[k=v,...]`` when there
        are parameters. Values that have a ``shape`` attribute are shown by
        their shape instead of their contents."""
        label = "%s.%s" % (self.name, self.op_type)
        if not self.parameters:
            return label
        params = ",".join(
            "%s=%s" % (k, v.shape if hasattr(v, "shape") else v)
            for k, v in self.parameters.items()
        )
        return "%s[%s]" % (label, params)

    def resolve(self, parameters):
        """For each ``"<name>.<key>": value`` entry in *parameters*, set
        ``self.parameters[<key>] = value``. Anonymous ops (``name == ""``)
        are left unchanged."""
        if self.name == "":
            return
        prefix = self.name + "."
        for k, v in parameters.items():
            if k.startswith(prefix):
                self.parameters[k[len(prefix):]] = v


def Const(c):
    """Build a leaf ``Expr`` holding the constant *c*.

    Python ints and floats (including bools, which are ints) become
    ``float``. Objects with a ``shape`` attribute (e.g. NumPy arrays) are
    converted with ``c.astype(float)``. The value is stored under the
    ``"value"`` parameter of an anonymous ``Const`` op.

    Raises:
        TypeError: for any other type of *c*.
    """
    if isinstance(c, (int, float)):
        value = float(c)
    elif hasattr(c, "shape"):
        value = c.astype(float)
    else:
        # (c,) so a tuple argument is formatted rather than unpacked into
        # the format arguments, which would raise an unrelated error.
        raise TypeError("Const must be float or int or ndarray: %s" % (c,))
    return Expr(Op("", "Const", 0, {"value": value}), [])


def Input(n):
    """Return a leaf ``Expr`` for the named input *n* (op type ``Input``,
    no inputs, no parameters)."""
    source_op = Op(name=n, op_type="Input", num_args=0, parameters={})
    return Expr(source_op, inputs=[])


def Input2d(n, h, w, ic):
    """Return a leaf ``Expr`` for the named 2-D input *n*.

    The op's parameters record ``height`` (*h*), ``width`` (*w*) and
    ``in_channels`` (*ic*), in that order.
    """
    dims = dict(height=h, width=w, in_channels=ic)
    leaf_op = Op(n, "Input2d", 0, dims)
    return Expr(leaf_op, [])


def MaxPool2d(k, s):
    """Return an anonymous one-input ``MaxPool2d`` op with parameters
    ``kernel_size`` (*k*) and ``stride`` (*s*)."""
    settings = {}
    settings["kernel_size"] = k
    settings["stride"] = s
    return Op("", "MaxPool2d", 1, settings)


def ReLU():
    """Return a new anonymous one-input ``ReLU`` op with no parameters."""
    kind = "ReLU"
    return Op(name="", op_type=kind, num_args=1, parameters={})


def Flatten():
    """Return a new anonymous one-input ``Flatten`` op with no parameters."""
    kind = "Flatten"
    return Op(name="", op_type=kind, num_args=1, parameters={})


def Conv2d(n, ic, oc, k, p=0):
    """Return a named one-input ``Conv2d`` op.

    Its parameters are ``in_channels`` (*ic*), ``out_channels`` (*oc*),
    ``kernel_size`` (*k*) and ``padding`` (*p*, default 0). Further entries
    can be merged in later by ``Op.resolve`` using keys prefixed ``"<n>."``.
    """
    config = dict(
        in_channels=ic,
        out_channels=oc,
        kernel_size=k,
        padding=p,
    )
    return Op(n, "Conv2d", 1, config)


def Linear(n, i, o):
    """Return a named one-input ``Linear`` op with parameters
    ``in_features`` (*i*) and ``out_features`` (*o*)."""
    config = dict(in_features=i, out_features=o)
    return Op(name=n, op_type="Linear", num_args=1, parameters=config)


def Show():
    """Return a new anonymous one-input ``Show`` op with no parameters."""
    kind = "Show"
    return Op(name="", op_type=kind, num_args=1, parameters={})