209 lines
7.1 KiB
Python
209 lines
7.1 KiB
Python
from __future__ import annotations
|
|
from abc import ABC, ABCMeta, abstractmethod
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
import aiohttp
|
|
|
|
from app.permit.engine import get_session, init_database
|
|
from sqlalchemy import select, and_, or_
|
|
from app.permit.models import PMSectionUser, PMSections
|
|
|
|
|
|
# ---------------------- SINGLETON META ----------------------
|
|
class SingletonMeta(type):
    """Metaclass that hands out a single shared instance per class.

    The first call to ``SomeClass(...)`` constructs and caches the instance;
    every later call returns that cached object unchanged, ignoring its
    arguments and never re-running ``__init__``.
    """

    # Registry shared by every class using this metaclass, keyed by class.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        # Only reached on the first instantiation of ``cls``.
        registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]
|
|
|
|
# ---------------------- ABC WITH SINGLETON SUPPORT ----------------------
|
|
class ABCSingletonMeta(ABCMeta, SingletonMeta):
    """Metaclass combining ``ABCMeta`` with ``SingletonMeta``.

    Classes built with it get abstract-method enforcement from ``ABCMeta``
    and one-cached-instance-per-class construction from ``SingletonMeta``
    (whose ``__call__`` is reached through the MRO
    ``ABCSingletonMeta -> ABCMeta -> SingletonMeta -> type``).
    """
|
|
|
|
# ---------------------- ABSTRACT CONNECTION ----------------------
|
|
class Connection(ABC, metaclass=ABCSingletonMeta):
    """Abstract base for every connection type.

    Because the metaclass is ``ABCSingletonMeta``, each concrete subclass
    can be instantiated at most once; later constructions return the same
    object.
    """

    @abstractmethod
    async def ping(self):
        """Check connection status."""
|
|
|
|
|
|
# ---------------------- HTTP CONNECTION ----------------------
|
|
class HTTPConnection(Connection):
    """Singleton HTTP connection wrapping an ``aiohttp.ClientSession``.

    The metaclass (``ABCSingletonMeta``, inherited from ``Connection``)
    caches the first instance, so constructing the class again returns the
    same object without re-running ``__init__``.
    """

    def __init__(self, base_url: str):
        self.base_url = base_url
        # Opened by create(); reset to None by close().
        self.session: aiohttp.ClientSession | None = None

    @classmethod
    async def create(cls, base_url: str) -> HTTPConnection:
        """Async factory returning the shared HTTPConnection with an open session.

        If the singleton already holds a session that is still open, it is
        reused. Replacing it would leave the old ``ClientSession`` unclosed.
        When a new session is opened, ``self.base_url`` is set to the URL it
        was opened with, so the attribute and the session never disagree.

        Args:
            base_url: Base URL passed to ``aiohttp.ClientSession``.

        Returns:
            The singleton ``HTTPConnection`` instance.
        """
        self = cls(base_url)
        if self.session is not None and not self.session.closed:
            return self

        self.base_url = base_url
        self.session = aiohttp.ClientSession(base_url=base_url)
        print(f"🌐 HTTPConnection created for {base_url}")
        return self

    async def ping(self) -> bool:
        """Report the connection as alive (no network request is made)."""
        print(f"🌐 Ping from HTTPConnection: {self.base_url}")
        return True

    async def check_perm(self, action: str, user: str) -> bool:
        """Placeholder permission check that currently grants everything.

        TODO(review): this fails open (always returns True) and takes its
        arguments in a different order than ``DBConnection.check_perm``
        (``user_id, action_tag``). Confirm it is not used to gate access.
        """
        print(f"🌐 Checking permission for user '{user}' with action '{action}'")
        return True

    async def close(self):
        """Close the session (if any) and clear it so create() can reopen."""
        if self.session is not None:
            await self.session.close()
            self.session = None
            print("🌐 HTTPConnection session closed")
|
|
|
|
|
|
# ---------------------- DATABASE CONNECTION ----------------------
|
|
class DBConnection(Connection):
    """Singleton async database connection holding one SQLAlchemy ``AsyncSession``.

    The metaclass (``ABCSingletonMeta``, inherited from ``Connection``)
    caches the first instance, so constructor arguments supplied after the
    first instantiation are ignored.

    NOTE(review): the single session is shared by every caller of the
    singleton. SQLAlchemy documents ``AsyncSession`` as unsafe for concurrent
    use from multiple tasks, so confirm that callers serialize access.
    """

    # Supported ``db_type`` values mapped to their async SQLAlchemy dialect+driver.
    _DRIVERS = {
        "mysql": "mysql+aiomysql",
        "psql": "postgresql+asyncpg",
        "postgresql": "postgresql+asyncpg",
    }

    def __init__(self, db_type: str, host: str, port: int, username: str, password: str, db_name: str):
        self.db_type = db_type  # one of "mysql", "psql", "postgresql"
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.db_name = db_name
        # Acquired by create(); reset to None by close().
        self.session: AsyncSession | None = None

    def get_connection_string(self) -> str:
        """Build a valid async SQLAlchemy connection URL.

        Username and password are percent-encoded, so characters such as
        ``@``, ``:`` or ``/`` in a password do not corrupt the URL.

        Raises:
            ValueError: if ``db_type`` is not a supported backend.
        """
        return self._build_url(redact=False)

    def _build_url(self, *, redact: bool) -> str:
        """Format the connection URL, masking the password when *redact* is true."""
        from urllib.parse import quote

        try:
            driver = self._DRIVERS[self.db_type]
        except KeyError:
            raise ValueError(f"Unsupported DB type: {self.db_type}") from None

        user = quote(str(self.username), safe="")
        secret = "***" if redact else quote(str(self.password), safe="")
        return f"{driver}://{user}:{secret}@{self.host}:{self.port}/{self.db_name}"

    @classmethod
    async def create(
        cls,
        db_type: str,
        host: str,
        port: int,
        username: str,
        password: str,
        db_name: str,
    ) -> DBConnection:
        """Async factory returning the shared DBConnection with an open session.

        ``init_database`` runs and a session is acquired only when the
        singleton has no session yet. Later calls reuse that session instead
        of re-initializing the database and leaking the previous session.
        Because of the singleton metaclass, the arguments only take effect
        on the first instantiation.

        Raises:
            ValueError: if ``db_type`` is not a supported backend.
        """
        self = cls(db_type, host, port, username, password, db_name)
        if self.session is not None:
            return self

        await init_database(self.get_connection_string())
        self.session = await get_session()
        # Log a redacted URL: the real one contains the password.
        print(f"🗄️ Database connection initialized: {self._build_url(redact=True)}")
        return self

    async def ping(self) -> bool:
        """Report the connection as alive (no query is executed)."""
        print(f"🗄️ Ping from DBConnection: {self.host}:{self.port}")
        return True

    async def check_perm(self, user_id: int, action_tag: str, section_tag: str | None = None) -> bool:
        """Check a user's access using the ORM.

        Loads the user's section grants where the link and the section are
        not soft-deleted and ``PMSectionUser.state == 0``. Returns True if
        any grant's ``action_tag`` equals *action_tag*. When *section_tag* is
        truthy, that grant's ``section_tag`` must match as well.

        Fails closed: on any error the message is printed, the session is
        rolled back, and False is returned.

        NOTE(review): successful reads never end the transaction, so the
        session stays inside one transaction. Under REPEATABLE READ (MySQL's
        default isolation), later grant changes may not be visible. Confirm.
        """
        if self.session is None:
            print("Error checking permission: Database session not initialized")
            return False

        try:
            query = (
                select(PMSections.section_tag, PMSections.action_tag)
                .join(PMSectionUser, PMSectionUser.section_id == PMSections.id)
                .where(
                    and_(
                        PMSectionUser.user_id == user_id,
                        PMSectionUser.deleted_at.is_(None),
                        PMSections.deleted_at.is_(None),
                        PMSectionUser.state == 0,
                    )
                )
            )
            result = await self.session.execute(query)
            permissions = result.fetchall()

            # Matching stays in Python so comparisons are exact (case-sensitive)
            # regardless of the database column collation.
            return any(
                perm.action_tag == action_tag
                and (not section_tag or perm.section_tag == section_tag)
                for perm in permissions
            )
        except Exception as e:
            print(f"Error checking permission: {str(e)}")
            await self._rollback_after_error()
            return False

    async def _rollback_after_error(self) -> None:
        """Best-effort rollback so one failed statement doesn't break later queries.

        On PostgreSQL, a failed statement aborts the current transaction, and
        every later statement on the same session fails until it is rolled back.
        """
        if self.session is None:
            return
        try:
            await self.session.rollback()
        except Exception as rollback_error:
            print(f"Error rolling back session: {rollback_error}")

    async def close(self):
        """Close the session (if any) and clear it so create() can open a new one.

        NOTE(review): any engine created by ``init_database`` is not disposed
        here. Confirm whether it needs explicit disposal.
        """
        if self.session is not None:
            await self.session.close()
            self.session = None
            print("🗄️ Database session closed")
|
|
|
|
|
|
# ---------------------- FACTORIES ----------------------
|
|
class ConnectionFactory(ABC):
    """Interface for objects that asynchronously build a ``Connection``."""

    @abstractmethod
    async def create_connection(self) -> Connection:
        """Build and return a connection; concrete factories must override this."""
|
|
|
|
|
|
class FactoryHTTPConnection(ConnectionFactory):
    """Factory that remembers a base URL and produces an ``HTTPConnection`` for it."""

    def __init__(self, base_url: str):
        # Held until create_connection() is awaited.
        self.base_url = base_url

    async def create_connection(self) -> HTTPConnection:
        """Delegate to ``HTTPConnection.create`` with the stored base URL."""
        connection = await HTTPConnection.create(self.base_url)
        return connection
|
|
|
|
|
|
class FactoryDBConnection(ConnectionFactory):
    """Factory that stores database settings and produces a ``DBConnection``."""

    def __init__(self, db_type: str, host: str, port: int, username: str, password: str, db_name: str):
        # Settings are kept verbatim and forwarded unchanged to DBConnection.create.
        (
            self.db_type,
            self.host,
            self.port,
            self.username,
            self.password,
            self.db_name,
        ) = (db_type, host, port, username, password, db_name)

    async def create_connection(self) -> DBConnection:
        """Delegate to ``DBConnection.create`` with the stored settings."""
        return await DBConnection.create(
            db_type=self.db_type,
            host=self.host,
            port=self.port,
            username=self.username,
            password=self.password,
            db_name=self.db_name,
        )
|
|
|
|
|
|
# ---------------------- FACTORY SELECTOR ----------------------
|
|
class FactorySelector:
    """Choose a connection factory by name and return the connection it builds."""

    @staticmethod
    async def get_factory(connection_type: str, **kwargs) -> Connection:
        """Create the connection for *connection_type*.

        Despite its name, this awaits the selected factory's
        ``create_connection()`` and returns the resulting ``Connection``,
        not the factory.

        Args:
            connection_type: ``"http"`` or ``"db"``.
            **kwargs: For ``"http"``: ``base_url``. For ``"db"``: ``type``
                (or its alias ``db_type``), ``host``, ``port``, ``username``,
                ``password``, ``db_name``. Missing keys fall back to the
                defaults below; unrecognized keys are ignored.

        Raises:
            ValueError: if *connection_type* is neither ``"http"`` nor ``"db"``.
        """
        if connection_type == "http":
            factory = FactoryHTTPConnection(kwargs.get("base_url", "http://localhost"))
        elif connection_type == "db":
            factory = FactoryDBConnection(
                # "type" keeps precedence for backward compatibility. "db_type"
                # matches the parameter name used elsewhere and was previously
                # ignored without warning.
                db_type=kwargs.get("type", kwargs.get("db_type", "postgresql")),
                host=kwargs.get("host", "localhost"),
                port=kwargs.get("port", 5432),
                # NOTE(review): hard-coded fallback credentials. Confirm these
                # are intended for local development only.
                username=kwargs.get("username", "admin"),
                password=kwargs.get("password", "password"),
                db_name=kwargs.get("db_name", "test_db"),
            )
        else:
            raise ValueError(f"Unknown connection type: {connection_type}")

        return await factory.create_connection()
|