1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
class SimpleCalculatorKJ(io.ComfyNode):
    """Evaluate a small arithmetic/logic expression over the node's inputs.

    The expression is parsed with :mod:`ast` and interpreted by a
    whitelist-based tree walker (no ``eval``/``exec``). The result is emitted
    three ways: as FLOAT, INT and BOOLEAN.
    """

    @classmethod
    def define_schema(cls):
        """Declare the node: an expression string plus autogrowing inputs a..k."""
        # Each autogrow slot accepts an int, float or boolean. The slot names
        # come from `names`; `min=2` presumably keeps at least "a" and "b"
        # available (autogrow semantics live in the io API - confirm there).
        template = io.Autogrow.TemplateNames(
            input=io.MultiType.Input("var", [io.Int, io.Float, io.Boolean], optional=True),
            names=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"],
            min=2,
        )
        return io.Schema(
            node_id="SimpleCalculatorKJ",
            category="KJNodes/misc",
            description="""
Calculator node that evaluates a mathematical expression using the inputs a, b, c, ... k.
Supported operations: +, -, *, /, //, %, **, <<, >>, unary +/-
Supported comparisons: ==, !=, <, <=, >, >=
Supported logic: and, or, not
Supported functions: abs(), round(), min(), max(), pow(), sqrt(), sin(), cos(), tan(), log(), log10(), exp(), floor(), ceil()
Supported constants: pi, euler, True, False
""",
            search_aliases=["math", "arithmetic", "expression", "logic"],
            inputs=[
                io.String.Input("expression", default="a + b", multiline=True),
                io.Autogrow.Input("variables", template=template),
            ],
            outputs=[
                io.Float.Output(),
                io.Int.Output(),
                io.Boolean.Output(),
            ],
        )

    @classmethod
    def execute(cls, variables, expression, a=None, b=None) -> io.NodeOutput:
        """Evaluate ``expression`` and return the result as (float, int, bool).

        Args:
            variables: Mapping of input name (e.g. ``"a"``) to value; every key
                becomes a name usable in the expression. ``None`` is treated as
                an empty mapping.
            expression: Expression text. Surrounding whitespace is ignored.
            a, b: Optional explicit values that override the ``"a"`` / ``"b"``
                entries from ``variables`` when not ``None``.

        Returns:
            ``io.NodeOutput(float(result), int(result), bool(result))``. If
            parsing, evaluation or conversion fails for any reason, the error
            is logged and ``io.NodeOutput(0.0, 0, False)`` is returned instead
            of raising.
        """
        import ast
        import operator

        # Integer results of ** and << are refused beyond this many bits. A
        # float can only hold integers up to ~1024 bits, so results this large
        # could never be emitted anyway; the cap stops inputs such as
        # "9 ** 9 ** 9" or "1 << 10 ** 12" from hanging or exhausting memory.
        max_int_bits = 1 << 16

        def guarded_pow(base, exp, mod=None):
            """Builtin pow() that refuses huge integer results (same signature)."""
            if mod is not None:
                # Modular exponentiation is bounded by the modulus.
                return pow(base, exp, mod)
            if (
                isinstance(base, int)
                and isinstance(exp, int)
                and exp > 0
                # Lower bound on the result's bit length: floor(log2|base|) * exp.
                and (abs(base).bit_length() - 1) * exp > max_int_bits
            ):
                raise ValueError("Exponentiation result is too large")
            return pow(base, exp)

        def guarded_lshift(value, shift):
            """``value << shift`` that refuses huge integer results."""
            if (
                isinstance(value, int)
                and isinstance(shift, int)
                and value
                and shift > 0
                and value.bit_length() + shift > max_int_bits
            ):
                raise ValueError("Shift result is too large")
            return value << shift

        # Separate tables per node kind, so an operator is only accepted where
        # it is meaningful.
        binary_operators = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
            ast.FloorDiv: operator.floordiv,
            ast.Mod: operator.mod,
            ast.Pow: guarded_pow,
            ast.LShift: guarded_lshift,
            ast.RShift: operator.rshift,
        }
        unary_operators = {
            ast.USub: operator.neg,
            ast.UAdd: operator.pos,
            ast.Not: operator.not_,
        }
        comparison_operators = {
            ast.Eq: operator.eq,
            ast.NotEq: operator.ne,
            ast.Lt: operator.lt,
            ast.LtE: operator.le,
            ast.Gt: operator.gt,
            ast.GtE: operator.ge,
        }
        allowed_functions = {
            'abs': abs,
            'round': round,
            'min': min,
            'max': max,
            'pow': guarded_pow,
            'sqrt': math.sqrt,
            'sin': math.sin,
            'cos': math.cos,
            'tan': math.tan,
            'log': math.log,
            'log10': math.log10,
            'exp': math.exp,
            'floor': math.floor,
            'ceil': math.ceil,
        }

        allowed_names = {'pi': math.pi, 'euler': math.e, 'True': True, 'False': False}
        allowed_names.update(variables or {})
        if a is not None:
            allowed_names['a'] = a
        if b is not None:
            allowed_names['b'] = b

        def eval_node(node):
            """Recursively evaluate one whitelisted AST node."""
            if isinstance(node, ast.Constant):
                # Numeric literals only (bool is an int subclass). Rejecting
                # str/bytes prevents e.g. "'x' * 10 ** 10" memory blow-ups.
                if isinstance(node.value, (int, float)):
                    return node.value
                raise ValueError(f"Constant of type {type(node.value).__name__} is not allowed")

            if isinstance(node, ast.Name):
                if node.id in allowed_names:
                    return allowed_names[node.id]
                raise ValueError(f"Name '{node.id}' is not allowed")

            if isinstance(node, ast.BinOp):
                op = binary_operators.get(type(node.op))
                if op is None:
                    raise ValueError(f"Operator {type(node.op).__name__} is not allowed")
                left = eval_node(node.left)
                right = eval_node(node.right)
                return op(left, right)

            if isinstance(node, ast.UnaryOp):
                op = unary_operators.get(type(node.op))
                if op is None:
                    raise ValueError(f"Operator {type(node.op).__name__} is not allowed")
                return op(eval_node(node.operand))

            if isinstance(node, ast.Compare):
                # Chained comparison: a < b < c means (a < b) and (b < c), with
                # each operand evaluated at most once and lazily.
                left = eval_node(node.left)
                for op_node, comparator in zip(node.ops, node.comparators):
                    op = comparison_operators.get(type(op_node))
                    if op is None:
                        raise ValueError(f"Operator {type(op_node).__name__} is not allowed")
                    right = eval_node(comparator)
                    if not op(left, right):
                        return False
                    left = right
                return True

            if isinstance(node, ast.BoolOp):
                # Short-circuit like Python: stop at the first falsy operand of
                # "and" / truthy operand of "or", so "b == 0 or a / b > 1" never
                # divides by zero. The result is always a bool.
                operands = (eval_node(value) for value in node.values)
                if isinstance(node.op, ast.And):
                    return all(operands)
                return any(operands)

            if isinstance(node, ast.Call):
                if not isinstance(node.func, ast.Name):
                    raise ValueError("Only simple function calls are allowed")
                func = allowed_functions.get(node.func.id)
                if func is None:
                    raise ValueError(f"Function '{node.func.id}' is not allowed")
                # Starred (*args) arguments are rejected by eval_node itself.
                args = [eval_node(arg) for arg in node.args]
                # Honour keyword arguments such as round(x, ndigits=2) instead
                # of silently dropping them.
                kwargs = {}
                for keyword in node.keywords:
                    if keyword.arg is None:
                        raise ValueError("Keyword argument unpacking (**) is not allowed")
                    kwargs[keyword.arg] = eval_node(keyword.value)
                return func(*args, **kwargs)

            raise ValueError(f"Node type {type(node).__name__} is not allowed")

        try:
            # strip(): text from the multiline box may carry leading spaces or
            # newlines, which ast.parse in "eval" mode rejects as an indent.
            tree = ast.parse(expression.strip(), mode='eval')
            result = eval_node(tree.body)
            return io.NodeOutput(float(result), int(result), bool(result))
        except Exception as e:
            # Deliberate best-effort behaviour: log and emit neutral outputs
            # rather than raising out of the node.
            logging.error(f"CalculatorKJ Error: {str(e)}")
            return io.NodeOutput(0.0, 0, False)
|