
Part 2: Database

We create a new MyPilotDB class from scratch (not extending an existing DB). This shows the full lifecycle of adding a database to DiracX — from schema definition through to dependency injection.

Schema

The schema defines our two tables and the status enum:

gubbins-db/src/gubbins/db/sql/my_pilot_db/schema.py
from __future__ import annotations

from enum import StrEnum, auto

from diracx.db.sql.utils import datetime_now, str255
from sqlalchemy import Float, ForeignKey, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class MyPilotStatus(StrEnum):
    SUBMITTED = auto()
    RUNNING = auto()
    DONE = auto()
    FAILED = auto()


class Base(DeclarativeBase):
    type_annotation_map = {
        str255: String(255),
    }


class MyComputeElements(Base):
    __tablename__ = "MyComputeElements"
    name: Mapped[str255] = mapped_column("Name", primary_key=True)
    capacity: Mapped[int] = mapped_column("Capacity")
    success_rate: Mapped[float] = mapped_column("SuccessRate", type_=Float)
    enabled: Mapped[bool] = mapped_column("Enabled")


class MyPilotSubmissions(Base):
    __tablename__ = "MyPilotSubmissions"
    pilot_id: Mapped[int] = mapped_column(
        "PilotID", primary_key=True, autoincrement=True
    )
    ce_name: Mapped[str255] = mapped_column(
        "CEName", ForeignKey(MyComputeElements.name)
    )
    status: Mapped[str255] = mapped_column("Status")
    submitted_at: Mapped[datetime_now] = mapped_column("SubmittedAt")
    updated_at: Mapped[datetime_now] = mapped_column("UpdatedAt")

Let's unpack the DiracX database conventions used here:

  • DeclarativeBase — Each database module defines its own Base subclass. This keeps table metadata isolated so that different databases don't interfere with each other.
  • str255 — A DiracX type alias that maps to String(255) via type_annotation_map. Use it for any short text column.
  • datetime_now — Provides a server-default UTC timestamp, so you don't need to pass timestamps explicitly on insert.
  • metadata — The Base.metadata object tracks all tables belonging to this database. You'll assign it to your BaseSQLDB subclass in the next step.

The type_annotation_map pattern

The type_annotation_map on Base tells SQLAlchemy how to translate Python type annotations into SQL column types. When you write name: Mapped[str255], SQLAlchemy looks up str255 in this map and uses String(255). This keeps column type information in one place rather than repeating type_=String(255) on every column.

Why StrEnum instead of a SQLAlchemy Enum column?

We store status as a plain string rather than using a SQL ENUM type. This makes the schema portable across database backends (SQLite doesn't support ENUM) and avoids costly ALTER TABLE commands when adding new statuses. The StrEnum on the Python side still gives you autocompletion and validation.

DB class

The DB class wraps SQLAlchemy queries behind a clean async interface:

gubbins-db/src/gubbins/db/sql/my_pilot_db/db.py
from __future__ import annotations

from datetime import datetime, timezone

from diracx.db.sql.utils import BaseSQLDB
from sqlalchemy import func, insert, select, update

from .schema import Base as MyPilotDBBase
from .schema import MyComputeElements, MyPilotStatus, MyPilotSubmissions


class MyPilotDB(BaseSQLDB):
    """Database for managing pilot submissions to compute elements."""

    metadata = MyPilotDBBase.metadata

    async def add_ce(
        self, name: str, capacity: int, success_rate: float, enabled: bool = True
    ) -> None:
        stmt = insert(MyComputeElements).values(
            name=name, capacity=capacity, success_rate=success_rate, enabled=enabled
        )
        await self.conn.execute(stmt)

    async def get_available_ces(self) -> list[dict]:
        active_counts = (
            select(
                MyPilotSubmissions.ce_name,
                func.count().label("active"),
            )
            .where(
                MyPilotSubmissions.status.in_(
                    [MyPilotStatus.SUBMITTED, MyPilotStatus.RUNNING]
                )
            )
            .group_by(MyPilotSubmissions.ce_name)
            .subquery()
        )

        active = func.coalesce(active_counts.c.active, 0)
        stmt = (
            select(
                MyComputeElements.name.label("name"),
                MyComputeElements.capacity.label("capacity"),
                MyComputeElements.success_rate.label("success_rate"),
                active.label("active_pilots"),
            )
            .outerjoin(
                active_counts,
                MyComputeElements.name == active_counts.c.ce_name,
            )
            .where(
                MyComputeElements.enabled.is_(True),
                MyComputeElements.capacity > active,
            )
        )
        result = await self.conn.execute(stmt)
        return [
            {
                "name": row.name,
                "capacity": row.capacity,
                "success_rate": row.success_rate,
                "available_slots": row.capacity - row.active_pilots,
            }
            for row in result
        ]

    async def submit_pilot(self, ce_name: str) -> int:
        stmt = insert(MyPilotSubmissions).values(
            ce_name=ce_name,
            status=MyPilotStatus.SUBMITTED,
        )
        result = await self.conn.execute(stmt)
        return result.lastrowid

    async def update_pilot_status(self, pilot_id: int, status: MyPilotStatus) -> None:
        stmt = (
            update(MyPilotSubmissions)
            .where(MyPilotSubmissions.pilot_id == pilot_id)
            .values(
                status=status,
                updated_at=datetime.now(tz=timezone.utc),
            )
        )
        await self.conn.execute(stmt)

    async def get_pilots_by_status(self, status: MyPilotStatus) -> list[dict]:
        stmt = select(
            MyPilotSubmissions.pilot_id.label("pilot_id"),
            MyPilotSubmissions.ce_name.label("ce_name"),
            MyPilotSubmissions.status.label("status"),
            MyPilotSubmissions.submitted_at.label("submitted_at"),
        ).where(MyPilotSubmissions.status == status)
        result = await self.conn.execute(stmt)
        return [
            {
                "pilot_id": row.pilot_id,
                "ce_name": row.ce_name,
                "status": row.status,
                "submitted_at": row.submitted_at,
            }
            for row in result
        ]

    async def get_ce_success_rate(self, ce_name: str) -> float:
        stmt = select(MyComputeElements.success_rate).where(
            MyComputeElements.name == ce_name
        )
        result = await self.conn.execute(stmt)
        row = result.one()
        return row[0]

    async def get_pilot_summary(self) -> dict[str, int]:
        stmt = select(
            MyPilotSubmissions.status.label("status"),
            func.count().label("total"),
        ).group_by(MyPilotSubmissions.status)
        result = await self.conn.execute(stmt)
        return {row.status: row.total for row in result}

BaseSQLDB gives you several things out of the box:

  • self.conn — An async database connection, scoped to the current transaction
  • Transaction lifecycle — Transactions are opened when you enter async with db: and committed (or rolled back) on exit
  • metadata — Links the DB class to its schema so DiracX can auto-create tables
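The transaction lifecycle follows the standard async context-manager pattern. This self-contained sketch (not DiracX code — `FakeTransactionDB` is made up to illustrate the idea) shows the "commit on clean exit, roll back on exception" behaviour:

```python
import asyncio


class FakeTransactionDB:
    """Toy stand-in illustrating BaseSQLDB's transaction lifecycle."""

    def __init__(self):
        self.log = []

    async def __aenter__(self):
        self.log.append("begin")  # real code would open a connection/transaction
        return self

    async def __aexit__(self, exc_type, exc, tb):
        self.log.append("rollback" if exc_type else "commit")
        return False  # don't swallow exceptions


async def main():
    db = FakeTransactionDB()
    async with db:
        pass  # queries would run here via db.conn
    try:
        async with db:
            raise RuntimeError("boom")
    except RuntimeError:
        pass
    return db.log


log = asyncio.run(main())
print(log)  # ['begin', 'commit', 'begin', 'rollback']
```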

The subquery + coalesce pattern in get_available_ces()

The get_available_ces() method is the most interesting query here. It uses a subquery to count active pilots per CE, then outer-joins this to the CE table. The func.coalesce(active_counts.c.active, 0) handles CEs with no active pilots (where the outer join produces NULL). This is a common SQLAlchemy pattern for "count related rows and filter by the result".
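To see what the pattern boils down to, here is a standalone raw-SQL version run against an in-memory SQLite database (stdlib only; the table names mirror the schema above, but this is an illustrative approximation, not the exact SQL that SQLAlchemy emits):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE MyComputeElements (
    Name TEXT PRIMARY KEY, Capacity INTEGER, SuccessRate REAL, Enabled INTEGER
);
CREATE TABLE MyPilotSubmissions (
    PilotID INTEGER PRIMARY KEY AUTOINCREMENT, CEName TEXT, Status TEXT
);
INSERT INTO MyComputeElements VALUES ('ce-a', 2, 0.9, 1), ('ce-b', 5, 0.8, 1);
-- two active pilots: ce-a is now at capacity
INSERT INTO MyPilotSubmissions (CEName, Status) VALUES
    ('ce-a', 'submitted'), ('ce-a', 'running');
""")

rows = conn.execute("""
SELECT ce.Name,
       ce.Capacity - COALESCE(ac.active, 0) AS available_slots
FROM MyComputeElements AS ce
LEFT OUTER JOIN (
    -- the subquery: count active pilots per CE
    SELECT CEName, COUNT(*) AS active
    FROM MyPilotSubmissions
    WHERE Status IN ('submitted', 'running')
    GROUP BY CEName
) AS ac ON ce.Name = ac.CEName
WHERE ce.Enabled = 1
  AND ce.Capacity > COALESCE(ac.active, 0)
""").fetchall()

print(rows)  # [('ce-b', 5)] — ce-a is full, ce-b has all 5 slots free
```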

For a deeper understanding of how transactions work, see the DB transaction model reference. The Databases explanation covers the broader architecture.

Create the __init__.py

Create an empty my_pilot_db/__init__.py file in the same directory.

Register the entry point, export, and dependency

The next three steps connect your database to the rest of DiracX. Each step serves a different purpose in the registration pipeline.

Why three registration steps?

  1. Entry point (in pyproject.toml) — Tells DiracX's plugin system that this DB exists. The entry point name becomes the DB's identifier in configuration and connection URLs.
  2. Package export (in __init__.py) — Makes the DB class importable from the top-level package so other code (routers, tasks) can reference it.
  3. Dependency injection (in depends.py) — Creates an Annotated type that FastAPI and the task worker can resolve automatically, wrapping the DB in a transaction context.

See Entrypoints and Dependency injection for the full picture.

Entry point

Add under [project.entry-points."diracx.dbs.sql"]:

gubbins-db/pyproject.toml
MyPilotDB = "gubbins.db.sql:MyPilotDB"
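In context, the entry-point table looks like this (assuming MyPilotDB is the only SQL DB your extension registers; your pyproject.toml may already contain other entries under the same table):

```toml
[project.entry-points."diracx.dbs.sql"]
MyPilotDB = "gubbins.db.sql:MyPilotDB"
```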

Package export

gubbins-db/src/gubbins/db/sql/__init__.py
from .my_pilot_db.db import MyPilotDB  # noqa: F401

__all__ += ("MyPilotDB",)  # type: ignore[assignment]

Dependency injection

gubbins-tasks/src/gubbins/tasks/depends.py
from gubbins.db.sql import MyPilotDB as _MyPilotDB  # noqa: E402

MyPilotDB = Annotated[_MyPilotDB, DBDepends(_MyPilotDB.transaction)]

__all__ += ("MyPilotDB",)  # type: ignore[assignment]

The DBDepends wrapper ensures the transaction commits before the HTTP response is sent. When used in task code, the same type annotation lets the task worker inject a database connection automatically.

Helm chart

For deployed environments (including CI), the database must also be listed in the extension's Helm chart values. This tells the infrastructure to create the database and set the DIRACX_DB_URL_MYPILOTDB environment variable that the application reads at startup.

Add under diracx.diracx.sqlDbs.dbs in your chart's values.yaml:

gubbins-charts/values.yaml
        MyPilotDB:

Local dev doesn't need this

pixi run local-start uses generate-local-urls which auto-discovers databases from entry points. The Helm chart step is only required for Kubernetes-based deployments.

Checkpoint

At this point, verify the database layer works before moving on:

pixi run pytest-gubbins-db -- -k my_pilot