MelonBlog

SQLAlchemy 的简单用法

SQLAlchemy 是 Python 的一个 ORM 框架，作用类似 Java 中的 MyBatis / Hibernate。

安装

# mapped_column (used by the entity below) only exists in SQLAlchemy 2.0+, so require it explicitly
pip install "sqlalchemy>=2.0"
# This tutorial uses a MySQL database, accessed through the PyMySQL driver
pip install pymysql

集成

1. 获取engine对象

from sqlalchemy import create_engine
import pymysql
# The engine URL below selects PyMySQL explicitly via the "mysql+pymysql" dialect, so it no
# longer depends on pymysql posing as MySQLdb. This call is kept only for backward
# compatibility with any code that imports MySQLdb directly (presumably none — verify, then drop).
pymysql.install_as_MySQLdb()
# charset=utf8mb4 makes the connection encoding explicit, so Chinese company names round-trip
# regardless of the driver/server default. echo=True logs every emitted SQL statement
# (useful for a tutorial, noisy in production). "xxx" is a placeholder password.
engine = create_engine("mysql+pymysql://root:xxx@127.0.0.1:3306/company?charset=utf8mb4", echo=True)

2. 创建实体类（entity）

from sqlalchemy import String, Integer
from sqlalchemy.orm import Mapped, mapped_column
from entity.base import Base
class Company(Base):
    """ORM entity mapped to the ``company`` table.

    NOTE(review): ``id`` and ``deleted`` are read by ``__repr__`` (and queried as
    ``Company.id`` / ``Company.deleted`` by the service code) but are not declared
    on this class; presumably they come from ``entity.base.Base`` — confirm, since
    a declarative mapping with no primary-key column fails to configure.
    """
    __tablename__ = "company"
    # business key; VARCHAR(32), unique per the table DDL (constraint `code`)
    code: Mapped[str] = mapped_column(String(32))
    # company name; VARCHAR(32)
    name: Mapped[str] = mapped_column(String(32))
    # exchange code; the DDL column comment lists 1=Shanghai (上证), 2=Shenzhen (深证),
    # 3=ChiNext (创业板), 4=STAR Market (科创板)
    exchange: Mapped[int] = mapped_column(Integer)
    # processing status; the DDL column comment lists 0=not fetched, 1=fetched, 2=structured
    data_status: Mapped[int] = mapped_column(Integer)
    # flag column; the service queries only return rows where this is 0
    disabled: Mapped[int] = mapped_column(Integer)

    def __repr__(self) -> str:
        """Return a debug string with every column value, including ``id`` and ``deleted``."""
        return f'id:{self.id}, code:{self.code}, name:{self.name}, exchange:{self.exchange}, data_status:{self.data_status}, disabled:{self.disabled}, deleted:{self.deleted}'

3. 使用 engine 对象进行 CRUD 操作

from typing import Optional, Sequence

from sqlalchemy import select, ScalarResult
from sqlalchemy.orm import Session

from entity.company import Company
from integration.sql_alchemy import engine
class CompanyService:
    """Query and update ``Company`` rows.

    Each method opens its own ``Session`` bound to ``engine`` and closes it when
    its ``with`` block exits.
    """

    def __init__(self):
        pass

    @staticmethod
    def _active_companies_stmt():
        """Return a SELECT of companies whose ``disabled`` and ``deleted`` are both 0."""
        # Multiple criteria passed to where() are combined with AND.
        return select(Company).where(Company.disabled == 0, Company.deleted == 0)

    def all(self) -> Sequence[Company]:
        """Return every company that is neither disabled nor deleted."""
        with Session(engine) as session:
            result: ScalarResult[Company] = session.scalars(self._active_companies_stmt())
            return result.all()

    def list_page(self, page: int, page_size: int, data_status: Optional[int] = None) -> Sequence[Company]:
        """Return one page of active (not disabled, not deleted) companies, ordered by id.

        :param page: 1-based page number.
        :param page_size: maximum number of rows on the page.
        :param data_status: when not None, only rows with this ``data_status`` are returned.
        :raises ValueError: if ``page`` < 1 or ``page_size`` < 0 (these would produce a
            negative OFFSET/LIMIT that the database rejects).
        """
        if page < 1:
            raise ValueError(f"page must be >= 1, got {page}")
        if page_size < 0:
            raise ValueError(f"page_size must be >= 0, got {page_size}")
        with Session(engine) as session:
            stmt = self._active_companies_stmt()
            if data_status is not None:
                stmt = stmt.where(Company.data_status == data_status)
            # Without ORDER BY the database may return rows in any order, so successive
            # LIMIT/OFFSET pages could overlap or skip rows; order by the primary key.
            stmt = stmt.order_by(Company.id).limit(page_size).offset((page - 1) * page_size)
            result: ScalarResult[Company] = session.scalars(stmt)
            return result.all()

    def update_data_status(self, _id: int, data_status: int):
        """Set ``data_status`` on the company whose id is ``_id`` and commit.

        :raises sqlalchemy.exc.NoResultFound: if no company has that id (raised by ``.one()``).
        """
        with Session(engine) as session:
            stmt = select(Company).where(Company.id == _id)
            company: Company = session.scalars(stmt).one()
            company.data_status = data_status
            session.commit()

4. 测试

from service.company_service import CompanyService
company_service = CompanyService()
# Fetch the first page (up to 10 rows) and print each entity through its __repr__.
company_list = company_service.list_page(page=1, page_size=10)
for company in company_list:
    print(company)

教程中使用的 MySQL 表结构

-- Table mapped by the Company entity (__tablename__ = "company").
-- NOTE(review): `disabled` / `deleted` look like soft flags; CompanyService only
-- reads rows where both are 0 — confirm the intended semantics.
create table company
(
    id          bigint unsigned auto_increment primary key,
    code        varchar(32)                not null,  -- unique, see constraint `code` below
    name        varchar(32)                not null,
    exchange    tinyint unsigned default 1 not null comment '1:上证,2:深证,3:创业板,4:科创板',  -- 1=Shanghai, 2=Shenzhen, 3=ChiNext, 4=STAR Market
    data_status tinyint unsigned default 0 not null comment '0:未抓取,1:已抓取,2:已结构化',  -- 0=not fetched, 1=fetched, 2=structured
    disabled    tinyint unsigned default 0 not null,
    deleted     tinyint unsigned default 0 not null,
    constraint code unique (code)
);