posts - 225, comments - 62, trackbacks - 0, articles - 0
   :: 首页 :: 新随笔 :: 联系 :: 聚合  :: 管理

通用的插入更新(Upsert)实现

Posted on 2020-08-31 18:31 魔のkyo 阅读(168) 评论(0)  编辑 收藏 引用
# -*- coding: utf-8 -*-

import math
import datetime
import numpy as np
import pandas as pd
import sqlalchemy


'''
engine: SQLAlchemy Engine
buffer_size: 缓存条目数,当缓存满时自动flush
update_on_duplicate: 当唯一键重复时的行为,默认是update,设置为False表示不更新,即忽略插入失败。
'''
def create_upsert_handler(engine, buffer_size=5000, update_on_duplicate=True):
    if engine.dialect.name.lower().find("mysql") != -1:
        return MySQLUpsertHandler(engine, buffer_size, update_on_duplicate)
    elif engine.dialect.name.lower().find("postgresql") != -1:
        return PSQLUpsertHandler(engine, buffer_size, update_on_duplicate)
    else:
        print(f"没有为{engine.dialect.name}实现特殊的Upsert,使用默认版本,请确认可以正常工作,建议特化一个专门版本")
        return UpsertHandlerBase(engine, buffer_size, update_on_duplicate)


def is_duplicate_key(e):
    for T in UpsertHandlerBase.__subclasses__():
        if T.is_duplicate_key(e):
            return True
    return UpsertHandlerBase.is_duplicate_key(e)


'''
class UpsertHandler:

    # 传入的engine类型应该和使用的UpsertHandler支持的数据库类型相匹配
    # buffer_size表示插入或更新数据缓存到多少才flush(即向数据库插入或更新),None表示在析构时flush,0表示不缓存
    # update_on_duplicate当唯一键重复时的行为,默认是update,设置为False表示不更新,即忽略插入失败。
    def __init__(self, engine, buffer_size=None, update_on_duplicate=True):
        pass

    # tablename为数据库表名
    # pk为主键的元组,可以不是真正的表主键,但是可以用来判重决定insert还是update,例如('exchange_id', 'trade_id')
    # data为单条数据,dict的形式,例如{'exchange_id': 'DCE', 'trade_id': '  1', 'price': 1.2, 'volume': 1}
    def upsert(self, tablename, pk, data):
        pass

    # 立即把缓冲器的数据推到数据库,会在buffer_size满了或者析构时自动调用,也可以手动调用
    def flush(self):
        pass
'''
class UpsertHandlerBase:
    def __init__(self, engine, buffer_size=None, update_on_duplicate=True):
        self.engine = engine
        self.tablename2pk = {}
        self.tablename2datas = {}
        self.buffer_size = buffer_size
        self.update_on_duplicate = update_on_duplicate

    def __del__(self):
        self.flush()

    def flush(self):
        for (tablename, pk) in self.tablename2pk.items():
            datas = self.tablename2datas[tablename]
            if len(datas) > 0:
                with self.engine.connect() as conn:
                    self._flush(conn, tablename, pk, datas)
                self.tablename2datas[tablename] = []

    def _flush(self, conn, tablename, pk, datas):
        columns = datas[0].keys()
        sql = f"""INSERT INTO {tablename}({", ".join(columns)}) VALUES\n"""
        for i, data in enumerate(datas):
            if i != len(datas) - 1:
                sql += f"""  ({self._format_values(data.values())}),\n"""
            else:
                sql += f"""  ({self._format_values(data.values())});\n"""
        try:
            conn.execute(sql)
        except sqlalchemy.exc.IntegrityError as e:
            if self.is_duplicate_key(e):
                # 插入遇到重复KEY
                if len(datas) <= 500:
                    for data in datas:
                        self.upsert_one(conn, tablename, pk, data)
                else:
                    l = len(datas)
                    p = int(l // 2)
                    self._flush(conn, tablename, pk, datas[:p])
                    self._flush(conn, tablename, pk, datas[p:])
            else:
                raise e

    def upsert_one(self, conn, tablename, pk, data):
        r = None
        if self.update_on_duplicate:
            update_str = self._format_update_values(pk, data)
        if self.update_on_duplicate and update_str.strip():
            r = conn.execute(f"UPDATE {tablename} SET {update_str} WHERE {self._format_update_conditions(pk, data)}")
        if not r or r.rowcount == 0:
            try:
                r = conn.execute(f"INSERT INTO {tablename}({', '.join(data.keys())}) VALUES({self._format_values(data.values())})")
            except sqlalchemy.exc.IntegrityError as e:
                if self.is_duplicate_key(e):
                    pass
                else:
                    raise e

    @staticmethod
    def is_duplicate_key(e):
        if type(e) != sqlalchemy.exc.IntegrityError:
            return False
        return (str(e.orig).lower().find("duplicate") != -1)

    def _isinf(self, x):
        return x>=9223372036854775807 or x<=-9223372036854775808

    def _format_value(self, v):
        if v is None:
            return "null"
        elif type(v) == float:
            if math.isnan(v) or math.isinf(v) or self._isinf(v):
                return "null"
            else:
                return f"{v}"
        elif type(v) == int:
            if self._isinf(v):
                return "null"
            else:
                return f"{v}"
        elif type(v) == datetime.datetime:
            return "'"+v.strftime("%Y-%m-%d %H:%M:%S")+"'"
        elif type(v) == datetime.date:
            return "'"+v.strftime("%Y-%m-%d")+"'"
        elif type(v) == pd.Timestamp:
            return "'"+v.strftime("%Y-%m-%d %H:%M:%S")+"'"
        elif type(v) == str:
            return f"'{v}'"
        else:
            return f"'{v}'"

    def _format_values(self, data):
        s = ''
        for i, e in enumerate(data):
            s += self._format_value(e)
            s += ''
        return s[:-2]

    def _format_update_values(self, pk, data):
        s = ''
        for i, (k, v) in enumerate(data.items()):
            if k not in pk:
                s += f"{k}={self._format_value(v)}, "
        return s[:-2]

    def _format_update_conditions(self, pk, data):
        s = ''
        for i, (k, v) in enumerate(data.items()):
            if k in pk:
                s += f"{k}={self._format_value(v)} and "
        return s[:-4]

    def upsert(self, tablename, pk, data):
        if self.buffer_size is not None and self.buffer_size == 0:
            with self.engine.connect() as conn:
                self.upsert_one(conn, tablename, pk, data)
        else:
            if pk:
                self.tablename2pk[tablename] = pk
            if tablename not in self.tablename2datas:
                self.tablename2datas[tablename] = []
            self.tablename2datas[tablename].append(data)

            if self.buffer_size is not None and len(self.tablename2datas[tablename]) >= self.buffer_size:
                with self.engine.connect() as conn:
                    self._flush(conn, tablename, self.tablename2pk[tablename], self.tablename2datas[tablename])
                self.tablename2datas[tablename] = []

    def upsert_dataframe(self, tablename, pk, df):
        if len(df) <= 2000:
            with self.engine.connect() as conn:
                for index, row in df.iterrows():
                    self.upsert_one(conn, tablename, pk, row.to_dict())
        else:
            l = len(df)
            p = int(l // 2)
            self.upsert_dataframe(tablename, pk, df[:p])
            self.upsert_dataframe(tablename, pk, df[p:])


class MySQLUpsertHandler(UpsertHandlerBase):
    def __init__(self, engine, buffer_size=None, update_on_duplicate=True):
        super().__init__(engine, buffer_size, update_on_duplicate)

    def __del__(self):
        super().__del__()

    @staticmethod
    def is_duplicate_key(e):
        if type(e) != sqlalchemy.exc.IntegrityError:
            return False
        if len(e.orig.args) > 1 and str(e.orig.args[1]).startswith("Duplicate entry"):
            return True
        return False

    def upsert_one(self, conn, tablename, pk, data):
        if self.update_on_duplicate:
            update_str = self._format_update_values(pk, data)
        if self.update_on_duplicate and update_str.strip():
            duplicate_do_str = f"UPDATE {update_str}"
        else:
            duplicate_do_str = f"UPDATE {pk[0]}=VALUES({pk[0]})"  # 等价于do nothing
        sql = f"""INSERT INTO {tablename}({", ".join(data.keys())}) VALUES
            ({self._format_values(data.values())})
            ON DUPLICATE KEY
            {duplicate_do_str}\n
"""
        conn.execute(sql)


class PSQLUpsertHandler(UpsertHandlerBase):
    def __init__(self, engine, buffer_size=None, update_on_duplicate=True):
        super().__init__(engine, buffer_size, update_on_duplicate)

    def __del__(self):
        super().__del__()

    @staticmethod
    def is_duplicate_key(e):
        if type(e) != sqlalchemy.exc.IntegrityError:
            return False
        if str(e.orig).startswith("duplicate key"):
            return True
        return False

    def upsert_one(self, conn, tablename, pk, data):
        if self.update_on_duplicate:
            update_str = self._format_update_values(pk, data)
        if self.update_on_duplicate and update_str.strip():
            duplicate_do_str = f"do update set {update_str}"
        else:
            duplicate_do_str = f"do nothing"
        sql = f"""INSERT INTO {tablename}({", ".join(data.keys())}) VALUES
            ({self._format_values(data.values())})
            on conflict ({", ".join(pk)})
            {duplicate_do_str}\n
"""
        conn.execute(sql)
只有注册用户登录后才能发表评论。