Compare commits
4 Commits
d2c6b972e5
...
f8c09d963e
| Author | SHA1 | Date | |
|---|---|---|---|
| f8c09d963e | |||
| cd5845c180 | |||
| f62e05fd06 | |||
| 3ec83e408c |
@ -19,7 +19,7 @@ def create_bot(config: Config, i18n: I18N, engine):
|
|||||||
use_class_middlewares=True,
|
use_class_middlewares=True,
|
||||||
state_storage=state_storage)
|
state_storage=state_storage)
|
||||||
register_handlers(bot)
|
register_handlers(bot)
|
||||||
setup_middlewares(bot, i18n)
|
setup_middlewares(bot, i18n, engine)
|
||||||
add_custom_filters(bot, config)
|
add_custom_filters(bot, config)
|
||||||
return bot
|
return bot
|
||||||
|
|
||||||
|
|||||||
@ -54,10 +54,14 @@ class StateStorageConfig:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class DatabaseConfig:
|
class DatabaseConfig:
|
||||||
url: str
|
url: str
|
||||||
|
pool_recycle: int
|
||||||
|
pool_pre_ping: bool
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_env(cls):
|
def from_env(cls):
|
||||||
return cls(os.getenv("DATABASE_URL", "sqlite:///bot.db"))
|
return cls(os.getenv("DB_URL", "sqlite:///bot.db"),
|
||||||
|
int(os.getenv("DB_POOL_RECYCLE", 3600)),
|
||||||
|
bool(int(os.getenv("DB_POOL_PRE_PING", True))))
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@ -6,8 +6,8 @@ from ..config import DatabaseConfig
|
|||||||
|
|
||||||
def get_engine(config: DatabaseConfig):
|
def get_engine(config: DatabaseConfig):
|
||||||
engine = create_engine(config.url,
|
engine = create_engine(config.url,
|
||||||
pool_recycle=3600,
|
pool_recycle=config.pool_recycle,
|
||||||
pool_pre_ping=True)
|
pool_pre_ping=config.pool_pre_ping)
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -8,5 +8,8 @@ class User (Base):
|
|||||||
__tablename__ = "user"
|
__tablename__ = "user"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(BIGINT, primary_key=True, unique=True, autoincrement=False)
|
id: Mapped[int] = mapped_column(BIGINT, primary_key=True, unique=True, autoincrement=False)
|
||||||
username: Mapped[int] = mapped_column(String(32), unique=True, nullable=True)
|
username: Mapped[str] = mapped_column(String(32), unique=True, nullable=True)
|
||||||
# additional fields go here
|
# additional fields go here
|
||||||
|
|
||||||
|
def __init__(self, id: int, username: str):
|
||||||
|
super().__init__(id=id, username=username)
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
from telebot import TeleBot
|
from telebot import TeleBot
|
||||||
|
|
||||||
from .arguments import ExtraArguments
|
|
||||||
from ..i18n import I18N
|
from ..i18n import I18N
|
||||||
|
from .arguments import ArgumentsMiddleware
|
||||||
|
from .database import DatabaseMiddleware
|
||||||
|
|
||||||
|
|
||||||
def setup_middlewares(bot: TeleBot, i18n: I18N):
|
def setup_middlewares(bot: TeleBot, i18n: I18N, engine):
|
||||||
bot.setup_middleware(ExtraArguments(i18n))
|
bot.setup_middleware(ArgumentsMiddleware(i18n))
|
||||||
|
bot.setup_middleware(DatabaseMiddleware(engine))
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from telebot.handler_backends import BaseMiddleware
|
|||||||
from telebot.types import Message, CallbackQuery
|
from telebot.types import Message, CallbackQuery
|
||||||
|
|
||||||
|
|
||||||
class ExtraArguments(BaseMiddleware):
|
class ArgumentsMiddleware (BaseMiddleware):
|
||||||
def __init__(self, i18n):
|
def __init__(self, i18n):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.i18n = i18n
|
self.i18n = i18n
|
||||||
|
|||||||
29
mybot/middlewares/database.py
Normal file
29
mybot/middlewares/database.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from telebot.handler_backends import BaseMiddleware
|
||||||
|
from telebot.types import Message, CallbackQuery
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from ..database.models import User
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseMiddleware (BaseMiddleware):
|
||||||
|
def __init__(self, engine):
|
||||||
|
super().__init__()
|
||||||
|
self.engine = engine
|
||||||
|
self.update_types = ["message", "callback_query"]
|
||||||
|
|
||||||
|
def pre_process(self, obj: [Message, CallbackQuery], data: dict):
|
||||||
|
session = Session(self.engine)
|
||||||
|
user = session.get(User, obj.from_user.id)
|
||||||
|
if user is None:
|
||||||
|
user = User(id=obj.from_user.id, username=obj.from_user.username)
|
||||||
|
session.add(user)
|
||||||
|
session.commit()
|
||||||
|
data["db"] = session
|
||||||
|
data["user"] = user
|
||||||
|
|
||||||
|
def post_process(self, message, data: dict, exception: BaseException):
|
||||||
|
if "db" in data:
|
||||||
|
session: Session = data["db"]
|
||||||
|
session.rollback()
|
||||||
|
session.close()
|
||||||
Loading…
x
Reference in New Issue
Block a user