How to set up and tear down a database between tests in FastAPI?

pytest fixtures (`@pytest.fixture`) let you run setup code before a test and cleanup code after it, and the cleanup runs even when the test fails.

In your case you want to create all tables before each test, and drop them again afterwards. This can be achieved with the following fixture:

@pytest.fixture()
def test_db():
    """Create every table before the test and drop them all afterwards."""
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    yield
    # pytest resumes the generator once the test is done, pass or fail.
    metadata.drop_all(bind=engine)

And then use it in your tests like so:

def test_get_empty_todos_list(test_db):
    """With freshly created tables, listing todos yields an empty list."""
    listing = client.get('/todos/')

    assert listing.status_code == 200
    assert listing.json() == []

For each test that has test_db in its argument list, pytest first runs Base.metadata.create_all(bind=engine), then yields to the test code, and afterwards makes sure that Base.metadata.drop_all(bind=engine) gets run, even when the test fails.

The full code:

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from main import app, get_db
from database import Base


# File-backed SQLite database used only while the tests run.
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

# check_same_thread=False lets a sqlite3 connection be used from threads other
# than the one that opened it.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args=dict(check_same_thread=False),
)
# Factory for sessions bound to the test engine.
TestingSessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)


def override_get_db():
    """Yield a session on the test database and close it afterwards.

    Installed in place of the application's ``get_db`` dependency.
    """
    # Create the session *before* entering ``try``: if construction raised
    # inside the ``try``, ``finally`` would reference an unbound ``db`` and
    # the resulting UnboundLocalError would hide the real error.
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()


@pytest.fixture()
def test_db():
    """Create every table before the test and drop them all afterwards."""
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    yield
    # pytest resumes the generator once the test is done, pass or fail.
    metadata.drop_all(bind=engine)

# Requests that depend on get_db get a test-database session instead.
# Nothing in this module removes the override again.
app.dependency_overrides.update({get_db: override_get_db})

client = TestClient(app)


def test_get_todos(test_db):
    """Two created todos share a user_id and are listed back in creation order."""
    first = client.post("/todos/", json={"text": "some new todo"}).json()
    second = client.post("/todos/", json={"text": "some even newer todo"}).json()

    assert first["user_id"] == second["user_id"]

    listing = client.get("/todos/")

    assert listing.status_code == 200
    assert listing.json() == [
        {"id": todo["id"], "user_id": todo["user_id"], "text": todo["text"]}
        for todo in (first, second)
    ]


def test_get_empty_todos_list(test_db):
    """With freshly created tables, listing todos yields an empty list."""
    listing = client.get("/todos/")

    assert listing.status_code == 200
    assert listing.json() == []

As your application grows, setting up and tearing down the whole database for each test might get slow.

One solution is to create the tables only once and never actually commit anything to the database: each test runs inside a transaction that is rolled back when the test ends.
This can be achieved using nested transactions (SAVEPOINTs) and rollbacks:

import pytest
import sqlalchemy as sa
from fastapi.testclient import TestClient
from sqlalchemy.orm import sessionmaker

from database import Base
from main import app, get_db

SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

engine = sa.create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


# These two event listeners are only needed for sqlite for proper
# SAVEPOINT / nested transaction support. Other databases like postgres
# don't need them.
# They are registered immediately after the engine is created, before any
# connection is opened. The "connect" event fires only when a new DBAPI
# connection is created, so a connection the pool may keep from the table
# setup below would otherwise never pass through do_connect.
# From: https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
@sa.event.listens_for(engine, "connect")
def do_connect(dbapi_connection, connection_record):
    # disable pysqlite's emitting of the BEGIN statement entirely.
    # also stops it from emitting COMMIT before any DDL.
    dbapi_connection.isolation_level = None


@sa.event.listens_for(engine, "begin")
def do_begin(conn):
    # emit our own BEGIN
    conn.exec_driver_sql("BEGIN")


# Set up the database once, now that every connection gets the hooks above.
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)


# This fixture is the main difference to before. It creates a nested
# transaction, recreates it when the application code calls session.commit
# and rolls it back at the end.
# Based on: https://docs.sqlalchemy.org/en/14/orm/session_transaction.html#joining-a-session-into-an-external-transaction-such-as-for-test-suites
@pytest.fixture()
def session():
    """Yield a Session whose work is thrown away when the test ends.

    The session is bound to a dedicated connection that holds an outer
    transaction. Work runs inside a SAVEPOINT on that connection. When the
    application calls ``session.commit()``, the SAVEPOINT ends and a new one
    is started immediately. At teardown the outer transaction is rolled back,
    restoring the state from before the test.
    """
    # One connection per test, so a single outer transaction can span it.
    connection = engine.connect()
    # Outer transaction: never committed, only rolled back at teardown.
    transaction = connection.begin()
    # Bind to the connection (not the engine) so the session's work happens
    # inside the outer transaction above.
    session = TestingSessionLocal(bind=connection)

    # Begin a nested transaction (using SAVEPOINT).
    nested = connection.begin_nested()

    # If the application code calls session.commit, it will end the nested
    # transaction. Need to start a new one when that happens.
    # The listener's ``session``/``transaction`` parameters shadow the outer
    # names; from the enclosing scope it only uses ``nested`` and
    # ``connection``.
    @sa.event.listens_for(session, "after_transaction_end")
    def end_savepoint(session, transaction):
        # Rebind the enclosing ``nested`` so later checks see the new SAVEPOINT.
        nonlocal nested
        if not nested.is_active:
            nested = connection.begin_nested()

    yield session

    # pytest runs the code below after the test, whether it passed or failed.
    # Rollback the overall transaction, restoring the state before the test ran.
    session.close()
    transaction.rollback()
    connection.close()


# FastAPI test client fixture built on top of the ``session`` fixture.
# The dependency override does not open a fresh session per request; it
# hands out the very session that fixture yields, so the test and the
# application share one session.
@pytest.fixture()
def client(session):
    """Yield a TestClient whose ``get_db`` dependency provides ``session``."""

    def _yield_test_session():
        yield session

    app.dependency_overrides[get_db] = _yield_test_session
    test_client = TestClient(app)
    yield test_client
    # Remove the override again so it does not outlive this test.
    app.dependency_overrides.pop(get_db)


def test_get_empty_todos_list(client):
    """Listing todos returns an empty list when the test has added nothing."""
    reply = client.get("/todos/")

    assert reply.status_code == 200
    assert reply.json() == []

Having two fixtures (session and client) here has an additional advantage:

If a test only talks to the API, you don’t need to remember to add the session fixture explicitly: the client fixture depends on it, so it is still invoked implicitly.
And if you want to write a test that directly talks to the db, you can do that as well:

def test_something(session):
    # Talk to the database directly through the test session.
    # Replace ``...`` with a real query.
    session.query(...)

Or both, for example if you want to prepare the db state before an API call:

def test_something_else(client, session):
    # Prepare database state directly, then exercise the API.
    # Replace each ``...`` with real model objects / request paths.
    session.add(...)
    # With the session fixture, commit only ends the current SAVEPOINT;
    # everything is still rolled back when the test finishes.
    session.commit()
    client.get(...)

Both the application code and the test code will see the same state of the db, because the client fixture hands the application the very session the test uses.

Leave a Comment