From 65cf19f2c4cd70f07ee132543341784e73c5bf8e Mon Sep 17 00:00:00 2001 From: eclipse Date: Fri, 25 Jul 2025 12:25:06 +0200 Subject: [PATCH] minor bugfixes and impromvements --- tests/conftest.py | 2 +- tests/integration/test_int_genre.py | 15 +++++++-------- tests/unit/test_unit_genre.py | 27 +++++++++++++-------------- tests/unit/test_unit_home.py | 8 ++++++++ the_works/views/home.py | 23 +++++++++++------------ 5 files changed, 40 insertions(+), 35 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 9c98ea2..4648ef0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ from the_works.database import db as _db TEST_DATABASE_URI = "sqlite:///:memory:" @pytest.fixture() -def app(): +def _app(): test_config = { "ENV": "Testing", "SQLALCHEMY_DATABASE_URI": TEST_DATABASE_URI, diff --git a/tests/integration/test_int_genre.py b/tests/integration/test_int_genre.py index d337998..59823db 100644 --- a/tests/integration/test_int_genre.py +++ b/tests/integration/test_int_genre.py @@ -1,12 +1,12 @@ from sqlalchemy import select -from the_works.database import db -from the_works.models import Genre from sqlalchemy.exc import IntegrityError import pytest +from the_works.database import db +from the_works.models import Genre -def test_genre_create(client, app): +def test_genre_create(client, _app): """Integrated testing of adding a Genre record.""" - response = client.post("/genre/create", data={"form_Genre": "Test-Genre"}, follow_redirects=True) + response = client.post("/genre/create", data={"form_Genre": "spam"}, follow_redirects=True) # assert there was exactly 1 redirect assert len(response.history) == 1 @@ -15,13 +15,12 @@ def test_genre_create(client, app): assert response.status_code == 200 # assert record was successfully added to DB - with app.app_context(): - genre = db.session.scalars(select(Genre).where(Genre.Genre == "Test-Genre")).all() + with _app.app_context(): + genre = db.session.scalars(select(Genre).where(Genre.Genre == "spam")).all() assert len(genre) == 1 assert isinstance(genre[0], Genre) # assert uniqueness of records with pytest.raises(IntegrityError) as excinfo: - response = client.post("/genre/create", data={"form_Genre": "Test-Genre"}) + response = client.post("/genre/create", data={"form_Genre": "spam"}) assert "UNIQUE constraint failed" in str(excinfo.value) - diff --git a/tests/unit/test_unit_genre.py b/tests/unit/test_unit_genre.py index 34a21e9..8bea3f9 100644 --- a/tests/unit/test_unit_genre.py +++ b/tests/unit/test_unit_genre.py @@ -1,17 +1,16 @@ -from the_works.models import Genre import pytest +from the_works.models import Genre def test_genre_all(client, mocker): """Test view all() from genre.py.""" # mock database function - # Note: The original method returns an sqlalchemy.engine.Result.ScalarResult, not a list, but the template code - # uses the return value in a way that works for both ScalarResult and list -# mocker.patch("the_works.database.db.session.scalars", return_value=[ - mocker.patch("flask_sqlalchemy.SQLAlchemy.session.scalars", return_value=[ - Genre(ID=4, Genre="bla"), - Genre(ID=26, Genre="blubb") + # Note: The original scalars() method returns an sqlalchemy.engine.Result.ScalarResult, not a list + # but the template code uses the return value in a way which works for both ScalarResult and list + mocker.patch("flask_sqlalchemy.session.Session.scalars", return_value=[ + Genre(ID=4, Genre="spam"), + Genre(ID=26, Genre="eggs") ]) # test case: get request @@ -19,7 +18,7 @@ def test_genre_all(client, mocker): assert response.status_code == 200 assert response.data.count(b'\n") + +# more test cases: POST, \ No newline at end of file diff --git a/the_works/views/home.py b/the_works/views/home.py index 3a6f070..8af7f4a 100644 --- a/the_works/views/home.py +++ b/the_works/views/home.py @@ -1,40 +1,39 @@ +import inspect from flask import Blueprint, render_template, request, jsonify from sqlalchemy import select from the_works.database import db import the_works.models -import inspect bp = Blueprint("home", __name__) -# prepare list of ORM classes to be searched by search_all() -tables = [] -for name, obj in inspect.getmembers(the_works.models): - if inspect.isclass(obj) and issubclass(obj, the_works.models.Base) and obj.__name__ != "Base": - tables.append(obj) -print(f"tables is {tables}") #DEBUG +# prepare list of ORM classes to be searched by search_all() +def __tables(): + return [obj for name, obj in inspect.getmembers(the_works.models) if inspect.isclass(obj) and issubclass(obj, the_works.models.Base) and obj.__name__ != "Base"] + @bp.route("/") def startpage(): return render_template("views/home.html") + @bp.route("/search") def search_all(): # return when query is empty if not request.args.get("query"): return jsonify({}) - + # get URL parameters s = request.args.get("query") matchCase = True if request.args.get("case").lower() == "match" else False result = {} # loop over database tables - for table in tables: - text_columns = [column.key for column in table.__table__.columns if type(column.type) == db.types.TEXT] + for table in __tables(): + text_columns = [column.key for column in table.__table__.columns if isinstance(column.type, db.types.TEXT)] hits = [] # loop over table rows for row in db.session.execute(select(table)): - # loop over each text column in row + # loop over each text column in row for column in text_columns: if row[0].__getattribute__(column) is None: continue @@ -46,7 +45,7 @@ def search_all(): if s.lower() in row[0].__getattribute__(column).lower(): hits.append(row[0].asdict()) break - if hits != []: + if hits: result[table.__table__.fullname] = hits # return results return jsonify(result)