INTRODUCTION
While working on a predictive analytics AI platform for a global logistics client, our engineering team was tasked with modeling spatial boundaries and physical constraints for various freight types. To achieve high-performance physics simulations and hardware acceleration, we leveraged Python with JAX and the Equinox library. Equinox is phenomenal for defining neural networks and mathematical models as PyTrees, but it comes with strict architectural opinions.
During the implementation phase, we needed to define geometric bounds for cargo containers. We had already created an abstract base class for shapes and a concrete class for elliptical footprints. When it came time to model standard cylindrical drums, our first instinct was to subclass the ellipse to create a circle. However, Equinox strongly recommends the abstract/final design pattern, under which concrete classes should never be subclassed.
This constraint matters in production. Violating it lets a subclass inherit fields and an initializer that no longer match its own constructor. That makes models harder to reason about once they pass through JAX transformations such as Just-In-Time (JIT) compilation, and it weakens type safety. This challenge inspired this article, which shows how to wrap a concrete class using composition instead of inheritance so other engineering teams can avoid the same pitfall.
PROBLEM CONTEXT: MODELING SPATIAL BOUNDARIES
In our logistics platform, accurately computing the area and physical footprint of freight was critical for the load-balancing simulation. We built an AbstractShape base class on top of equinox.Module. Because Equinox modules are dataclasses, every field needs an explicit type annotation. We declared the shared interface as abstract methods to keep all shape models consistent.
Here is a simplified look at the initial class structure we were working with:
```python
from abc import abstractmethod

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, Float


class AbstractShape(eqx.Module):
    @property
    @abstractmethod
    def area(self) -> Float[Array, ""]:
        pass

    @abstractmethod
    def compute_area(self) -> Float[Array, ""]:
        raise NotImplementedError


class Ellipse(AbstractShape):
    # Semi-axis lengths, so the area is pi * a * b
    major_axis: Float[Array, ""]
    minor_axis: Float[Array, ""]

    @property
    def area(self) -> Float[Array, ""]:
        return self.compute_area()

    def compute_area(self) -> Float[Array, ""]:
        return jnp.pi * self.major_axis * self.minor_axis
```
When engineering leaders look to hire Python developers for scalable data systems, they expect teams to foresee how foundational models will scale. Our goal was to add a Circle model. Mathematically, a circle is a special ellipse whose major and minor axes are equal. Traditional Object-Oriented Programming (OOP) principles therefore suggest that Circle should inherit from Ellipse.
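The relationship the Circle model relies on can be written out directly. The area formula in Ellipse implies that major_axis and minor_axis hold semi-axis lengths, written here as a and b:

```latex
A_{\text{ellipse}} = \pi a b,
\qquad
a = b = r \;\Longrightarrow\; A_{\text{circle}} = \pi r \cdot r = \pi r^{2}
```

For example, r = 2 gives A = 4π ≈ 12.566.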
WHAT WENT WRONG: THE INHERITANCE TRAP
We initially drafted a Circle class that subclassed Ellipse and overrode the __init__ method to accept only a radius, passing it twice to the parent constructor. However, this breaks the design rule that the Equinox documentation recommends. In the abstract/final design pattern, a class is either strictly abstract (meant to be subclassed) or strictly final (concrete and never subclassed).
Attempting to subclass Ellipse resulted in a few key issues:
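- PyTree Structure Ambiguity: Equinox registers every equinox.Module subclass as a JAX PyTree whose leaves come from its dataclass fields. A Circle that subclasses Ellipse still carries major_axis and minor_axis as two independent leaves. Nothing in the PyTree structure keeps them equal when leaves are updated individually (for example with eqx.tree_at).
- Dataclass Initialization Conflicts: equinox.Module is built on Python dataclasses. A subclass that replaces __init__ with a radius-only signature still inherits the parent's major_axis and minor_axis fields. Any code that rebuilds an instance from its fields then calls a constructor that no longer accepts them.
- Weaker Type Guarantees: A subclassed Circle is an Ellipse, so static type checkers (like MyPy or Pyright) accept it anywhere an Ellipse is expected. That includes code that assumes the two axes can vary independently, which degrades the strict type safety we rely on for enterprise AI systems.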
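The initialization conflict can be reproduced without Equinox. The sketch below is a standard-library approximation, not the original code. Frozen dataclasses stand in for equinox.Module; Equinox modules are dataclass-based, but their internals differ. The Circle subclasses a concrete Ellipse and swaps in a radius-only __init__, as described above:

```python
import math
from dataclasses import dataclass, fields, replace


@dataclass(frozen=True)
class Ellipse:
    # Semi-axis lengths, so area = pi * a * b
    major_axis: float
    minor_axis: float

    def compute_area(self) -> float:
        return math.pi * self.major_axis * self.minor_axis


class Circle(Ellipse):
    # Anti-pattern: subclass a concrete dataclass and replace its
    # constructor with a radius-only signature
    def __init__(self, radius: float) -> None:
        super().__init__(major_axis=radius, minor_axis=radius)


circle = Circle(2.0)
print(circle.compute_area())  # pi * 2 * 2

# The subclass still exposes both inherited fields...
print([f.name for f in fields(circle)])  # ['major_axis', 'minor_axis']

# ...but rebuilding an instance from those fields calls
# Circle(major_axis=..., minor_axis=...), which the new __init__ rejects.
try:
    replace(circle, major_axis=3.0)
except TypeError as err:
    print(f"replace() failed: {err}")
```

Whether a given tool trips over this depends on whether it rebuilds instances through __init__. The underlying problem is that the subclass's constructor and its inherited fields disagree.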
HOW WE APPROACHED THE SOLUTION: COMPOSITION OVER INHERITANCE
To adhere to the abstract/final pattern, we had to rethink our approach. The Equinox documentation suggests: “prefer composition over inheritance. Write a wrapper that forwards each method as appropriate.”
We needed a Circle class that fully satisfied the AbstractShape interface without inheriting from Ellipse. Instead, Circle would contain an instance of Ellipse and delegate the heavy lifting to it. This approach is commonly called delegation, or the wrapper pattern. It is related to the Decorator pattern (in the architectural sense, not Python decorators), although here the wrapper only forwards calls and adds no behavior.
When companies hire AI developers for production deployment, understanding these nuanced framework constraints is what separates prototype code from production-grade, maintainable software. By adopting composition, we decoupled the internal mathematical logic of the ellipse from the outward-facing API of the circle. We debated the performance cost of delegating method calls, but under jax.jit that cost disappears. JAX traces the Python code once, records only the underlying array operations, and compiles those to XLA (Accelerated Linear Algebra), so the forwarding calls never appear in the compiled program. Outside of JIT, the cost is one extra Python function call.
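The claim that forwarding disappears under JIT can be checked by comparing traced programs. This sketch uses plain functions rather than the Equinox classes; ellipse_area, circle_area_forwarded, and circle_area_direct are hypothetical names. It uses jax.make_jaxpr to print what JAX actually records:

```python
import jax
import jax.numpy as jnp


def ellipse_area(a, b):
    # The wrapped computation: pi * a * b on semi-axes
    return jnp.pi * a * b


def circle_area_forwarded(r):
    # Extra Python call layer, analogous to Circle forwarding to Ellipse
    return ellipse_area(r, r)


def circle_area_direct(r):
    # Same arithmetic written inline, with no forwarding
    return jnp.pi * r * r


forwarded = jax.make_jaxpr(circle_area_forwarded)(2.0)
direct = jax.make_jaxpr(circle_area_direct)(2.0)

print(forwarded)
# The call to ellipse_area only exists while tracing, so the two
# recorded programs should be identical
print(str(forwarded) == str(direct))
```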
FINAL IMPLEMENTATION: THE WRAPPER PATTERN
Here is the sanitized, generalized code demonstrating the final implementation. We created the Circle as a direct subclass of AbstractShape (maintaining the abstract/final rule) and wrapped the Ellipse internally.
```python
from jaxtyping import Array, Float


class Circle(AbstractShape):
    # Delegate internal representation to the concrete Ellipse class
    _delegate: Ellipse

    def __init__(self, radius: Float[Array, ""]):
        # Initialize the wrapped concrete class with equal semi-axes
        self._delegate = Ellipse(major_axis=radius, minor_axis=radius)

    @property
    def area(self) -> Float[Array, ""]:
        # Forward the property access
        return self._delegate.area

    def compute_area(self) -> Float[Array, ""]:
        # Forward the method call
        return self._delegate.compute_area()
```
Validation Steps:
- We passed instances of Circle through jax.jit to verify that PyTree flattening worked correctly. Because _delegate is itself an equinox.Module, Equinox recognizes it as a valid nested PyTree node, and JAX flattened and unflattened the nested structure seamlessly.
- We verified that static type checking passed without warnings. Circle fully implemented the AbstractShape contract.
- We ensured no mutations were attempted, preserving the immutability expected of JAX data structures.
LESSONS FOR ENGINEERING TEAMS
When an enterprise decides to hire software developer talent, they expect engineers to solve problems while aligning with framework philosophies, not fighting them. Here are the actionable takeaways from this architectural adjustment:
- Respect Framework Paradigms: If a tool like Equinox demands the abstract/final pattern, lean into it. Fighting the framework leads to fragile, unmaintainable code.
- Embrace Composition Over Inheritance: Deep inheritance hierarchies create brittle code. Composition provides flexible, modular designs that are easier to test and reason about.
- Forwarding is Explicit: Writing wrapper methods (like forwarding compute_area) feels like boilerplate, but it creates highly explicit, readable contracts.
- Predictable PyTree Structure: Composition adds one explicit level of nesting (a Circle's only field is its Ellipse). JAX flattens nested modules transparently, so the PyTree schema stays clean and easy to inspect.
- Type Safety is Non-Negotiable: When concrete classes subclass only abstract bases, any concrete class that fails to implement the required interface gets caught. Python rejects it at instantiation time, and static type checkers flag it at the call site.
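The type-safety point can be illustrated with Python's standard abc module, which the AbstractShape above already relies on for abstractmethod. This sketch mirrors that interface with a plain ABC so it runs without Equinox (a simplifying assumption). The hypothetical IncompleteCircle wrapper forwards compute_area but forgets area:

```python
import math
from abc import ABC, abstractmethod


class AbstractShape(ABC):
    @property
    @abstractmethod
    def area(self) -> float: ...

    @abstractmethod
    def compute_area(self) -> float: ...


class Ellipse(AbstractShape):
    def __init__(self, major_axis: float, minor_axis: float) -> None:
        self.major_axis = major_axis  # semi-axis length
        self.minor_axis = minor_axis  # semi-axis length

    @property
    def area(self) -> float:
        return self.compute_area()

    def compute_area(self) -> float:
        return math.pi * self.major_axis * self.minor_axis


class IncompleteCircle(AbstractShape):
    # Hypothetical wrapper that forwards compute_area but forgets `area`
    def __init__(self, radius: float) -> None:
        self._delegate = Ellipse(radius, radius)

    def compute_area(self) -> float:
        return self._delegate.compute_area()


try:
    IncompleteCircle(1.0)
except TypeError as err:
    # The message names the missing abstract member; wording varies by Python version
    print(err)
```

Static checkers such as mypy and Pyright report the same mistake at the IncompleteCircle(1.0) call site, before the code runs.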
WRAP UP
Applying the abstract/final design pattern requires a shift in mindset for engineers accustomed to deep OOP inheritance trees. By using class composition and method forwarding, we extended our geometric modeling system within the design rules of the Equinox library. This approach maintained type safety, preserved JAX PyTree compatibility, and improved the modularity of our predictive AI platform. If your team is navigating complex Python architectural challenges and you want to ensure your systems scale securely, contact us.
Frequently Asked Questions
Why does Equinox recommend the abstract/final design pattern?
Equinox recommends the abstract/final pattern to keep class structures predictable. When concrete classes can be subclassed, a subclass may inherit dataclass fields and an __init__ that no longer match its own constructor. That makes it harder to reason about how the module is flattened, reconstructed, and transformed by JAX operations like jit and vmap.
Does forwarding methods through a wrapper add runtime overhead?
Outside of JIT, standard Python incurs a tiny cost for each extra function call. Under jax.jit, however, JAX traces the Python code once and records only the array operations. The forwarding calls therefore never appear in the compiled XLA program and add no runtime overhead.
When should a team seek specialized Python developers for JAX work?
You should seek specialized Python developers when your legacy numerical or scientific computing pipelines need to be modernized for GPU/TPU acceleration. JAX and Equinox require a functional programming mindset and an understanding of immutable state, which differs from traditional backend Python development.
What is the difference between inheritance and composition in this example?
Inheritance makes Circle a literal subtype of Ellipse (an "is-a" relationship). Composition makes Circle a standalone shape that internally uses an Ellipse to perform calculations (a "has-a" relationship), which avoids the framework's restrictions on concrete subclassing.