From f62e05fd06215320fd7929e134a63dd73c173569 Mon Sep 17 00:00:00 2001 From: Ilya Bezrukov Date: Tue, 30 Jul 2024 02:37:36 +0300 Subject: [PATCH] Add DatabaseMiddleware --- mybot/__init__.py | 2 +- mybot/database/__init__.py | 1 + mybot/database/models.py | 5 ++++- mybot/middlewares/__init__.py | 6 ++++-- mybot/middlewares/database.py | 29 +++++++++++++++++++++++++++++ 5 files changed, 39 insertions(+), 4 deletions(-) create mode 100644 mybot/middlewares/database.py diff --git a/mybot/__init__.py b/mybot/__init__.py index 0ff840a..51ff870 100644 --- a/mybot/__init__.py +++ b/mybot/__init__.py @@ -19,7 +19,7 @@ def create_bot(config: Config, i18n: I18N, engine): use_class_middlewares=True, state_storage=state_storage) register_handlers(bot) - setup_middlewares(bot, i18n) + setup_middlewares(bot, i18n, engine) add_custom_filters(bot, config) return bot diff --git a/mybot/database/__init__.py b/mybot/database/__init__.py index b0d6015..5891c9a 100644 --- a/mybot/database/__init__.py +++ b/mybot/database/__init__.py @@ -2,6 +2,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import DeclarativeBase from ..config import DatabaseConfig +from .models import User def get_engine(config: DatabaseConfig): diff --git a/mybot/database/models.py b/mybot/database/models.py index 2a54791..485df99 100644 --- a/mybot/database/models.py +++ b/mybot/database/models.py @@ -8,5 +8,8 @@ class User (Base): __tablename__ = "user" id: Mapped[int] = mapped_column(BIGINT, primary_key=True, unique=True, autoincrement=False) - username: Mapped[int] = mapped_column(String(32), unique=True, nullable=True) + username: Mapped[str] = mapped_column(String(32), unique=True, nullable=True) # additional fields go here + + def __init__(self, id: int, username: str): + super().__init__(id=id, username=username) diff --git a/mybot/middlewares/__init__.py b/mybot/middlewares/__init__.py index 6efea50..4c1afbc 100644 --- a/mybot/middlewares/__init__.py +++ b/mybot/middlewares/__init__.py @@ -1,8 +1,10 @@ from telebot import TeleBot -from .arguments import ArgumentsMiddleware from ..i18n import I18N +from .arguments import ArgumentsMiddleware +from .database import DatabaseMiddleware -def setup_middlewares(bot: TeleBot, i18n: I18N): +def setup_middlewares(bot: TeleBot, i18n: I18N, engine): bot.setup_middleware(ArgumentsMiddleware(i18n)) + bot.setup_middleware(DatabaseMiddleware(engine)) diff --git a/mybot/middlewares/database.py b/mybot/middlewares/database.py new file mode 100644 index 0000000..560e709 --- /dev/null +++ b/mybot/middlewares/database.py @@ -0,0 +1,29 @@ +from telebot.handler_backends import BaseMiddleware +from telebot.types import Message, CallbackQuery + +from sqlalchemy.orm import Session + +from ..database import User + + +class DatabaseMiddleware (BaseMiddleware): + def __init__(self, engine): + super().__init__() + self.engine = engine + self.update_types = ["message", "callback_query"] + + def pre_process(self, obj: [Message, CallbackQuery], data: dict): + session = Session(self.engine) + user = session.get(User, obj.from_user.id) + if user is None: + user = User(id=obj.from_user.id, username=obj.from_user.username) + session.add(user) + session.commit() + data["db"] = session + data["user"] = user + + def post_process(self, message, data: dict, exception: BaseException): + if "db" in data: + session: Session = data["db"] + session.rollback() + session.close()