posts - 225, comments - 62, trackbacks - 0, articles - 0
   :: 首页 :: 新随笔 :: 联系 :: 聚合  :: 管理

db.py
# -*- coding: utf-8 -*-
from sqlalchemy import Table, Column, Integer, Float, Numeric, String, DateTime, Date, Time, Text
from sqlalchemy.dialects.postgresql import BIGINT as Int8
from sqlalchemy import ForeignKey, PrimaryKeyConstraint, UniqueConstraint, CheckConstraint
from sqlalchemy import text, func, and_, or_, not_, asc, desc, distinct, inspect
from sqlalchemy.orm import relationship, backref, raiseload
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from contextlib import contextmanager
import re
import datetime
import types

class dbModelBase(object):
    @declared_attr
    def __tablename__(cls):
        # return cls.__name__.lower()
        return re.sub(r'([A-Z])', r'_\1', cls.__name__[0].lower()+cls.__name__[1:] ).lower()

    def __repr__(self):
        attrs = []
        for c in self.__table__.columns:
            attr = getattr(self, c.name)
            if type(attr) in (str, datetime.date, datetime.time, datetime.datetime):
                attrs.append(f"{c.name}='{attr}'")
            else:
                attrs.append(f"{c.name}={attr}")
        return f"{self.__class__.__name__}({', '.join(attrs)})"

    def keys(self):
        return [c.name for c in self.__table__.columns]

    def __getitem__(self, item):
        return getattr(self, item)

    def to_dict(self):
        return {c.name: getattr(self, c.name) for c in self.__table__.columns}

def to_dict(db_obj):
    return db_obj.to_dict()

def to_list(db_objs):
    return [db_obj.to_dict() for db_obj in db_objs]

# 把ORM对象转成可序列化成JSON的对象,对于ORM对象的list转换为dict的list,对于ORM对象转换成dict
def to_jsonable(o):
    if type(o) == list:
        return to_list(o)
    else:
        return to_dict(o)

engines = {}

def init_engine(
    url=None, name="main",
    dialect=None, username=None, password=None, server=None, dbname=None,
    **kwargs):
    if url is None:
        url = '{}://{}:{}@{}/{}'.format(dialect, username, password, server, dbname)
    engine = create_engine(url, **kwargs)
    Session = sessionmaker(expire_on_commit=False)
    Session.configure(bind=engine)
    @contextmanager
    def session_scope(Session):
        session = Session()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            raise
        finally:
            session.close()
    engine.session_scope = types.MethodType(session_scope, Session)
    engines[name] = engine
    return engine

def get_engine(name="main"):
    if name in engines:
        return engines[name]
    else:
        return None

def session_scope(name="main"):
    if name in engines:
        return engines[name].session_scope()
    else:
        raise Exception("engine未初始化")

if __name__ == '__main__':
    Base = declarative_base(cls=dbModelBase)
    class LogRecord(Base):
        id = Column(Integer(), primary_key=True, autoincrement=True)
        strategy_id = Column(Integer(), comment="策略ID")
        account_id = Column(Integer(), comment="资金账户ID")
        content = Column(Text(), comment="内容")
        create_time = Column(DateTime(), comment="创建时间")
    # 连接一个数据库时
    init_engine('sqlite:///:memory:', echo=False)
    Base.metadata.create_all(get_engine())
    with session_scope() as session:
        record = LogRecord(
            strategy_id=1,
            account_id=1,
            content="Hello world"
            )
        session.add(record)
        r = session.query(LogRecord).all()
        print(r)


    # 同时连接两个不同数据库时
    engine1 = init_engine('sqlite:///:memory:', echo=False)
    Base.metadata.create_all(engine1)
    with engine1.session_scope() as session:
        record = LogRecord(
            strategy_id=1,
            account_id=1,
            content="Hello world"
            )
        session.add(record)
        r = session.query(LogRecord).all()
        print(r)

    Base2 = declarative_base(cls=dbModelBase)
    class Stu(Base2):
        id = Column(Integer(), primary_key=True, autoincrement=True)
        name = Column(String(), comment="姓名")
    engine2 = init_engine('sqlite:///:memory:', echo=True)
    Base2.metadata.create_all(engine2)
    with engine2.session_scope() as session:
        record = Stu(
            name="Tom"
            )
        session.add(record)
        r = session.query(Stu).all()
        print(r)

        print(session.query(LogRecord).all()) # 会报错,因为这个数据库没有LogRecord表,还会看到ROLLBACK

补充一下在另一个文件引入db.py定义数据库模型,并定义 唯一键、索引、外键、多对一(一对多)映射、多对多映射
# -*- coding: utf-8 -*-
import logging
import re
import datetime
from sqlalchemy import Column, Integer, Date, DateTime, Text, String, Float
from sqlalchemy import UniqueConstraint, Index, ForeignKey
from sqlalchemy.orm import relationship, backref, raiseload
from sqlalchemy.ext.declarative import declarative_base
from .db import ModelBase, init_engine as _init_engine, session_scope as _session_scope

Base = declarative_base(cls=ModelBase)


def init_engine(*args, **kwargs):
    engine = _init_engine(name="strategy_db", *args, **kwargs)
    Base.metadata.create_all(engine)
    return engine


def session_scope():
    return _session_scope(name="strategy_db")

class RotationSignalPosition(Base):
    __table_args__ = (
        UniqueConstraint('strategy_code''account_code''signal_name''tradingday''cycle_num'),  # Unique
        Index('ix_rotation_signal_position_stra_acco_trad''strategy_code''account_code''tradingday'),  # 联合Index
        )
    id = Column(Integer(), primary_key=True, autoincrement=True)
    strategy_code = Column(String(64), nullable=False, comment="策略编码")
    account_code = Column(String(64), nullable=False, comment="资金账户编码")
    signal_name = Column(String(64), nullable=False, comment="信号名称")
    tradingday = Column(Date(), nullable=False, index=True, comment="交易日")  # 单字段Index
    cycle_num = Column(Integer(), nullable=False, comment="日内:序号(1~30);日间:序号238")
    contract_ic = Column(String(64), nullable=False, comment="IC合约代码")
    contract_ifh = Column(String(64), nullable=False, comment="IFH合约代码")
    position_ic = Column(Integer(), nullable=False, comment="IC当前持仓量")
    position_ifh = Column(Integer(), nullable=False, comment="IFH当前持仓量")
    open_tradingday = Column(Date(), comment="开仓交易日")


class Fund(Base):
    id = Column(Integer(), primary_key=True, autoincrement=True)
    code = Column(String(), unique=True, nullable=False, comment="基金编码")
    name = Column(String(), comment="基金名称")

class Strategy(Base):
    id = Column(Integer(), primary_key=True, autoincrement=True)
    code = Column(String(), unique=True, nullable=False, comment="策略编码")
    name = Column(String(), comment="策略名称")
    fund_id = Column(Integer(), ForeignKey('fund.id'), comment="基金ID")  # 外键

    fund = relationship('Fund', backref=backref('strategies', order_by=id))  # 多对一(一对多)映射关系


class Trader(Base):
    id = Column(Integer(), primary_key=True, autoincrement=True)
    username = Column(String(), unique=True, nullable=False, comment="交易员用户名")
    password = Column(String(), nullable=False, comment="密码(加密存储)")
    name = Column(String(), comment="交易员姓名")

    strategies = relationship('Strategy', secondary='strategy_trader', backref=backref('traders'))  # 多对多映射关系

# 多对多关系表
strategy_trader = Table('strategy_trader', Base.metadata,
   Column('strategy_id', Integer(), ForeignKey('strategy.id'), primary_key=True),
   Column('trader_id', Integer(), ForeignKey('trader.id'), primary_key=True))

下面是以前写的例子,上面的例子关于建立映射关系的部分没写,可以参照下面
--------------------分割线---------------------------
用例包含了:
连接数据库
建立数据库Table和Python class的对应关系
建立多对一,多对多的映射关系
事务自动提交和自动回滚(遇到未处理异常时)
对象的json序列化

from sqlalchemy import Table, Column, Integer, Float, Numeric, String, DateTime, Date, Time
from sqlalchemy.dialects.postgresql import BIGINT as Int8
from sqlalchemy import ForeignKey, PrimaryKeyConstraint, UniqueConstraint, CheckConstraint
from sqlalchemy import text, func, and_, or_, not_, asc, inspect, desc, distinct
from sqlalchemy.orm import relationship, backref
from sqlalchemy.ext.declarative import declarative_base, declared_attr
import re
import datetime

def xor_(exp1, exp2):
    
return and_(or_(exp1, exp2), not_(and_(exp1, exp2)))

def nxor_(exp1, exp2):
    
return not_(xor_(exp1, exp2))

class dbModelBase(object):
    @declared_attr
    
def __tablename__(cls):
        clsname 
= cls.__name__
        
# 数据库表名 对应 类名(驼峰转小写加下划线)
        return re.sub(r'([A-Z])', r'_\1', clsname[0].lower()+clsname[1:] ).lower()

    
def __repr__(self):
        attrs 
= []
        
for c in self.__table__.columns:
            attr 
= getattr(self, c.name)
            
if type(attr) in (str, datetime.date, datetime.time, datetime.datetime):
                attrs.append(f
"{c.name}='{attr}'")
            
else:
                attrs.append(f
"{c.name}={attr}")
        
return f"{self.__class__.__name__}({', '.join(attrs)})"

    
def keys(self):
        
return [c.name for c in self.__table__.columns]

    
def __getitem__(self, item):
        
return getattr(self, item)

    
def to_dict(self):
        
return {c.name: getattr(self, c.name) for c in self.__table__.columns}

def to_dict(db_obj):
    
return db_obj.to_dict()

def to_list(db_objs):
    
return [db_obj.to_dict() for db_obj in db_objs]

# 把ORM对象转成可序列化成JSON的对象,对于ORM对象的list转换为dict的list,对于ORM对象转换成dict
def to_jsonable(o):
    
if type(o) == list:
        
return to_list(o)
    
else:
        
return to_dict(o)

Base 
= declarative_base(cls=dbModelBase)

class Fund(Base):
    id 
= Column(Integer(), primary_key=True, autoincrement=True)
    code 
= Column(String(), unique=True, nullable=False, comment="基金编码")
    name 
= Column(String(), comment="基金名称")

class Strategy(Base):
    id 
= Column(Integer(), primary_key=True, autoincrement=True)
    code 
= Column(String(), unique=True, nullable=False, comment="策略编码")
    name 
= Column(String(), comment="策略名称")
    fund_id 
= Column(Integer(), ForeignKey('fund.id'), comment="基金ID")

    fund 
= relationship('Fund', backref=backref('strategies', order_by=id))

class Trader(Base):
    id 
= Column(Integer(), primary_key=True, autoincrement=True)
    username 
= Column(String(), unique=True, nullable=False, comment="交易员用户名")
    password 
= Column(String(), nullable=False, comment="密码")
    name 
= Column(String(), comment="交易员姓名")

    strategies 
= relationship('Strategy', secondary='strategy_trader', backref=backref('traders')) # 多对多映射关系

# 多对多关系表
strategy_trader = Table('strategy_trader', Base.metadata,
   Column(
'strategy_id', Integer(), ForeignKey('strategy.id'), primary_key=True),
   Column(
'trader_id', Integer(), ForeignKey('trader.id'), primary_key=True))

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from contextlib import contextmanager

Session 
= sessionmaker(expire_on_commit=False)
@contextmanager
def session_scope():
    
"""Provide a transactional scope around a series of operations."""
    session 
= Session()
    
try:
        
yield session
        session.commit()
    
except Exception as e:
        session.rollback()
        
raise
    
finally:
        session.close()

def init_engine(dialect=None, username=None, password=None, server=None, dbname=None,
    url
=None, recreate_all=False, **kwargs):
    
if url == None:
        url 
= '{}://{}:{}@{}/{}'.format(dialect, username, password, server, dbname)
    engine 
= create_engine(url, **kwargs)
    Session.configure(bind
=engine)
    
if recreate_all:
        Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)
    
return engine

if __name__ == '__main__':
    
# init_engine(dialect = 'postgresql+psycopg2', username = 'dbuser', password = 'xxxxxx', server = 'ip_address', dbname = 'dbname',
    #     client_encoding='utf8',
    #     echo=True,
    #     isolation_level="REPEATABLE_READ")
    init_engine(url='sqlite:///:memory:', echo=True)
    with session_scope() as session:
        fund 
= Fund(
            code 
= "fund001",
            name 
= "基金一号"
            ) 
# 创建一个fund
        fund.strategies.append(Strategy(
            code 
= "s01",
            name 
= "策略一号"
            )) 
# 创建strategy加入到fund.strategies中,在fund插入的时候会被级联插入
        fund.strategies.append(Strategy(
            code 
= "s02",
            name 
= "策略二号"
            ))
        
assert(inspect(fund).transient) # 创建出来但未调用session.add的对象处于transient态
        session.add(fund)
        
assert(inspect(fund).pending)   # 调用过session.add之后还尚未flush到数据库中,这时处于pending态
        session.flush()
        
assert(inspect(fund).persistent)# 手动调用flush 或者 等session结束(确切的说应该是事务被commit的时候),会变成persistent态
        strategy = Strategy(
            code 
= "s03",
            name 
= "三号策略",
            fund_id 
= fund.id # 只有fund对象真的写入了数据库也就是变成persistent态的时候才能取得fund.id,因为其自增key是数据库维护的
            )
        session.add(strategy) 
# 单独插入一个strategy,手动维护fund_id

    with session_scope() as session:
        strategies 
= session.query(Strategy).filter(
            Strategy.name.like(
'%策略%')
            ).all() 
# 返回查询结果list
        for s in strategies:
            
print(s)

    with session_scope() as session:
        trader 
= Trader(
            username 
= "daimingzhuang",
            password 
= "xxxxxx"
            )
        strategy 
= session.query(Strategy).filter(
            Strategy.code 
== "s01"
            ).first() 
# 返回第一个查询结果
        trader.strategies.append(strategy)  # 如果操作trader端建立关系

        session.add(trader)
        session.flush()

        strategy 
= session.query(Strategy).filter(
            Strategy.code 
== "s02"
            ).first()
        strategy.traders.append(trader) 
# 如果操作strategy端建立关系

        
# session 结束自动commit,上面的strategy的修改不需要显式保存,会自动保存

    with session_scope() as session:
        trader 
= session.query(Trader).filter(Trader.username=="daimingzhuang").first()
        
print(trader)
        
for s in trader.strategies:
            
print(s) # 可以看到s01、s02都被打印了
只有注册用户登录后才能发表评论。