Source code for epimodels.validation.validators

"""
Validation functions for parameter values, initial conditions, and constraints.
"""

import re
import ast
import operator
from typing import Any
from epimodels.exceptions import ValidationError
from epimodels.validation.specs import ParameterSpec, VariableSpec


# Lookup table from an AST operator node type to a two-argument callable that
# implements it. Relational operators map directly onto the `operator` module.
COMPARISON_OPS = {
    ast.Eq: operator.eq,
    ast.NotEq: operator.ne,
    ast.Lt: operator.lt,
    ast.LtE: operator.le,
    ast.Gt: operator.gt,
    ast.GtE: operator.ge,
}

# `and` / `or` have no function form in `operator`, so they are spelled out.
# Both arguments are already-evaluated values by the time these are called.
COMPARISON_OPS.update(
    {
        ast.And: lambda lhs, rhs: lhs and rhs,
        ast.Or: lambda lhs, rhs: lhs or rhs,
    }
)


def validate_parameter_value(
    name: str,
    value: Any,
    spec: ParameterSpec,
    all_params: dict[str, Any] | None = None,
) -> list[str]:
    """
    Validate a parameter value against its specification.

    Checks run in this order: ``None`` handling, type (with lossless
    int <-> float coercion), bounds, then every expression in
    ``spec.constraints``. A ``None`` value or a type mismatch ends validation
    early; bound and constraint violations accumulate.

    Args:
        name: Parameter name
        value: Parameter value to validate
        spec: Parameter specification
        all_params: All parameter values (for cross-parameter validation)

    Returns:
        List of error messages (empty if valid)

    Example:
        >>> spec = ParameterSpec(name="beta", symbol="β", bounds=(0, None))
        >>> errors = validate_parameter_value("beta", -0.5, spec)
        >>> len(errors)
        1
    """
    import math  # function-scope import: `math` is not imported at module level

    errors: list[str] = []

    if value is None:
        if spec.required:
            errors.append(f"Required parameter '{name}' is None")
        return errors

    if not isinstance(value, spec.dtype):
        if spec.dtype == float and isinstance(value, int):
            value = float(value)
        elif spec.dtype == int and isinstance(value, float) and value.is_integer():
            # Only whole-valued floats are narrowed to int, so nothing is lost.
            value = int(value)
        else:
            errors.append(
                f"Parameter '{name}' has wrong type: expected {spec.dtype.__name__}, "
                f"got {type(value).__name__}"
            )
            return errors

    if spec.bounds is not None and isinstance(value, (int, float)):
        min_val, max_val = spec.bounds
        if isinstance(value, float) and math.isnan(value):
            # NaN compares False against every bound, so the range checks
            # below would silently accept it. Reject it explicitly whenever a
            # bound is actually in force.
            if min_val is not None or max_val is not None:
                errors.append(
                    f"Parameter '{name}' value {value} is NaN and cannot satisfy "
                    f"bounds ({min_val}, {max_val})"
                )
        else:
            if min_val is not None and value < min_val:
                errors.append(
                    f"Parameter '{name}' value {value} is below minimum bound {min_val}"
                )
            if max_val is not None and value > max_val:
                errors.append(
                    f"Parameter '{name}' value {value} exceeds maximum bound {max_val}"
                )

    other_params = all_params or {}
    for constraint_expr in spec.constraints:
        try:
            errors.extend(
                _validate_single_constraint(constraint_expr, name, value, other_params)
            )
        except Exception as e:
            # A broken constraint is reported as a validation error instead of
            # aborting the remaining checks.
            errors.append(
                f"Failed to evaluate constraint '{constraint_expr}' for parameter '{name}': {e}"
            )

    return errors
def validate_initial_condition(
    name: str,
    value: float,
    spec: VariableSpec,
    all_values: dict[str, float] | None = None,
) -> list[str]:
    """
    Validate an initial condition value against its specification.

    Checks, in order: the non-negativity flag, the optional ``(min, max)``
    bounds, then every expression in ``spec.constraints``. All violations are
    collected; none of them stops the remaining checks.

    Args:
        name: Variable name
        value: Initial condition value
        spec: Variable specification
        all_values: All initial condition values (for cross-variable validation)

    Returns:
        List of error messages (empty if valid)
    """
    import math  # function-scope import: `math` is not imported at module level

    errors: list[str] = []

    min_val, max_val = spec.bounds if spec.bounds is not None else (None, None)

    if isinstance(value, float) and math.isnan(value):
        # NaN compares False against everything, so the range checks in the
        # else-branch would all pass. Reject it explicitly whenever the spec
        # restricts the range at all.
        if spec.non_negative or min_val is not None or max_val is not None:
            errors.append(
                f"Initial condition '{name}' is NaN and cannot satisfy its range restrictions"
            )
    else:
        if spec.non_negative and value < 0:
            errors.append(f"Initial condition '{name}' must be non-negative, got {value}")
        if min_val is not None and value < min_val:
            errors.append(
                f"Initial condition '{name}' value {value} is below minimum {min_val}"
            )
        if max_val is not None and value > max_val:
            errors.append(
                f"Initial condition '{name}' value {value} exceeds maximum {max_val}"
            )

    other_values = all_values or {}
    for constraint_expr in spec.constraints:
        try:
            errors.extend(
                _validate_single_constraint(constraint_expr, name, value, other_values)
            )
        except Exception as e:
            # A broken constraint is reported as a validation error instead of
            # aborting the remaining checks.
            errors.append(
                f"Failed to evaluate constraint '{constraint_expr}' for variable '{name}': {e}"
            )

    return errors
def evaluate_constraint(expression: str, context: dict[str, Any]) -> tuple[bool, str | None]:
    """
    Evaluate a constraint expression in the given context.

    Any exception raised during evaluation is caught and reported as an
    unsatisfied constraint, so this function does not raise for bad
    expressions.

    Args:
        expression: Constraint expression (e.g., "beta > gamma")
        context: Dictionary mapping names to values

    Returns:
        Tuple of (is_satisfied, error_message). ``error_message`` is ``None``
        when the expression evaluated to a ``bool``; otherwise ``is_satisfied``
        is ``False`` and the message explains why.

    Example:
        >>> satisfied, msg = evaluate_constraint("x > y", {"x": 5, "y": 3})
        >>> satisfied
        True
    """
    try:
        outcome = _safe_eval_expression(expression, context)
        # Guard clause: only a genuine bool counts as a constraint verdict.
        if not isinstance(outcome, bool):
            return False, f"Expression '{expression}' did not evaluate to boolean"
        return outcome, None
    except Exception as e:
        return False, f"Failed to evaluate expression: {e}"
def _alias_param_as_value(expr: str, param_name: str) -> str:
    """Rewrite every whole-identifier occurrence of ``param_name`` in ``expr`` to ``value``."""
    if not param_name or param_name == "value":
        return expr
    # Lookarounds instead of a plain substring replace: "beta" must not match
    # inside "beta_max", and "a" must not match inside "gamma".
    pattern = rf"(?<!\w){re.escape(param_name)}(?!\w)"
    return re.sub(pattern, "value", expr)


def _validate_single_constraint(
    constraint_expr: str, param_name: str, value: Any, all_params: dict[str, Any]
) -> list[str]:
    """
    Validate a single constraint expression for a parameter.

    The expression can use 'value' to refer to the current parameter value,
    or use the parameter name directly. References to the parameter name are
    rewritten to ``value`` before evaluation, which also lets names that are
    Python keywords (e.g. ``lambda``) appear in constraint strings.

    Args:
        constraint_expr: Constraint expression to check
        param_name: Name of the parameter being validated
        value: Current value of that parameter
        all_params: Other named values the expression may reference

    Returns:
        Empty list if the constraint holds, otherwise a one-element list with
        the violation message.
    """
    # Copy so the caller's mapping is never mutated.
    context = dict(all_params)
    context[param_name] = value
    context["value"] = value

    expr = _alias_param_as_value(constraint_expr.strip(), param_name)

    satisfied, error_msg = evaluate_constraint(expr, context)
    if satisfied:
        return []
    return [
        f"Parameter '{param_name}' value {value} violates constraint: {constraint_expr}"
        + (f" ({error_msg})" if error_msg else "")
    ]


def _safe_eval_expression(expression: str, context: dict[str, Any]) -> Any:
    """
    Safely evaluate a constraint expression.

    Uses AST parsing to allow only safe operations.

    Args:
        expression: Expression to evaluate
        context: Variable bindings

    Returns:
        Result of evaluation

    Raises:
        ValueError: If expression contains unsafe operations
    """
    try:
        tree = ast.parse(expression, mode="eval")
    except SyntaxError as e:
        raise ValueError(f"Invalid expression syntax: {expression}") from e

    return _eval_node(tree.body, context)


# Arithmetic operators accepted inside constraint expressions.
# NOTE(review): `**` is unbounded, so a huge exponent in a constraint string
# can be expensive to evaluate — confirm constraint strings only come from
# trusted spec definitions.
_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Pow: operator.pow,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
}

# Unary operators accepted inside constraint expressions.
_UNARY_OPS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
    ast.Not: operator.not_,
}


def _eval_node(node: ast.AST, context: dict[str, Any]) -> Any:
    """
    Recursively evaluate an AST node.

    Only allows:
    - Constants (numbers, strings, and other literals)
    - Variable names (looked up in context)
    - Comparison operators (==, !=, <, <=, >, >=), including chains
    - Boolean operators (and, or), with short-circuit evaluation
    - Arithmetic operators (+, -, *, /, //, %, **)
    - Unary operators (+, -, not)

    Args:
        node: AST node to evaluate
        context: Variable bindings

    Returns:
        The value the node evaluates to

    Raises:
        ValueError: If the node references an unknown variable or uses an
            unsupported operator or expression type
    """
    if isinstance(node, ast.Constant):
        return node.value

    if isinstance(node, ast.Name):
        name = node.id
        if name in context:
            return context[name]
        raise ValueError(f"Unknown variable: {name}")

    if isinstance(node, ast.Compare):
        # Chained comparison (a < b < c): each operand is compared with its
        # right-hand neighbour, stopping at the first failing link.
        prev_val = _eval_node(node.left, context)
        for op, comparator in zip(node.ops, node.comparators):
            right = _eval_node(comparator, context)
            op_func = COMPARISON_OPS.get(type(op))
            if op_func is None:
                raise ValueError(f"Unsupported comparison operator: {type(op).__name__}")
            if not op_func(prev_val, right):
                return False
            prev_val = right
        return True

    if isinstance(node, ast.BoolOp):
        if isinstance(node.op, ast.And):
            stop_on_truthy = False
        elif isinstance(node.op, ast.Or):
            stop_on_truthy = True
        else:
            raise ValueError(f"Unsupported boolean operator: {type(node.op).__name__}")
        # Operands are evaluated lazily, left to right, like Python's own
        # and/or, so a guard such as `x != 0 and 1 / x > 2` never reaches the
        # division when the guard fails.
        result = None
        for operand in node.values:
            result = _eval_node(operand, context)
            if bool(result) is stop_on_truthy:
                break
        return result

    if isinstance(node, ast.BinOp):
        left = _eval_node(node.left, context)
        right = _eval_node(node.right, context)
        binary_func = _BINARY_OPS.get(type(node.op))
        if binary_func is None:
            raise ValueError(f"Unsupported binary operator: {type(node.op).__name__}")
        return binary_func(left, right)

    if isinstance(node, ast.UnaryOp):
        operand = _eval_node(node.operand, context)
        unary_func = _UNARY_OPS.get(type(node.op))
        if unary_func is None:
            raise ValueError(f"Unsupported unary operator: {type(node.op).__name__}")
        return unary_func(operand)

    raise ValueError(f"Unsupported expression type: {type(node).__name__}")