posts - 210, comments - 61, trackbacks - 0, articles - 0
   :: 首页 :: 新随笔 :: 联系 :: 聚合  :: 管理
# -*- coding: utf-8 -*-

import math
import datetime
import numpy as np
import pandas as pd
import sqlalchemy

'''
class UpsertHandler:

    # engine sqlalchemy的engine
    # buffer_size表示插入或更新数据缓存到多少才flush(即向数据库插入或更新),None表示在析构时flush,0表示不缓存
    def __init__(self, engine, buffer_size=None):
        pass

    # tablename为数据库表名
    # pk为主键的元组,可以不是真正的表主键,但是可以用来判重决定insert还是update,例如('exchange_id', 'trade_id')
    # data为单条数据,dict的形式,例如{'exchange_id': 'DCE', 'trade_id': '  1', 'price': 1.2, 'volume': 1}
    def upsert(self, tablename, pk, data):
        pass

    # 立即把缓冲器的数据推到数据库,会在buffer_size满了或者析构时自动调用,也可以手动调用
    def flush(self):
        pass
'''

def create_upsert_handler(engine, buffer_size=None):
    if engine.dialect.name.lower().find("mysql") != -1:
        return MySQLUpsertHandler(engine, buffer_size)
    elif engine.dialect.name.lower().find("postgresql") != -1:
        return PSQLUpsertHandler(engine, buffer_size)
    else:
        print(f"没有为{engine.dialect.name}实现特殊的Upsert,使用默认版本,请确认可以正常工作,建议特化一个专门版本")
        return UpsertHandlerBase(engine, buffer_size)


def is_duplicate_key(e):
    for T in UpsertHandlerBase.__subclasses__():
        if T.is_duplicate_key(e):
            return True
    return UpsertHandlerBase.is_duplicate_key(e)


class UpsertHandlerBase:
    def __init__(self, engine, buffer_size=5000):
        self.engine = engine
        self.tablename2pk = {}
        self.tablename2datas = {}
        self.buffer_size = buffer_size

    def __del__(self):
        self.flush()

    def flush(self):
        for (tablename, pk) in self.tablename2pk.items():
            datas = self.tablename2datas[tablename]
            if len(datas) > 0:
                with self.engine.connect() as conn:
                    self._flush(conn, tablename, pk, datas)
                self.tablename2datas[tablename] = []

    def _flush(self, conn, tablename, pk, datas):
        columns = datas[0].keys()
        sql = f"""INSERT INTO {tablename}({", ".join(columns)}) VALUES\n"""
        for i, data in enumerate(datas):
            if i != len(datas) - 1:
                sql += f"""  ({self._format_values(data.values())}),\n"""
            else:
                sql += f"""  ({self._format_values(data.values())});\n"""
        try:
            conn.execute(sql)
        except sqlalchemy.exc.IntegrityError as e:
            if self.is_duplicate_key(e):
                # 插入遇到重复KEY
                if len(datas) <= 500:
                    for data in datas:
                        self.upsert_one(conn, tablename, pk, data)
                else:
                    l = len(datas)
                    p = int(l // 2)
                    self._flush(conn, tablename, pk, datas[:p])
                    self._flush(conn, tablename, pk, datas[p:])
            else:
                raise e

    def upsert_one(self, conn, tablename, pk, data):
        r = conn.execute(f"UPDATE {tablename} SET {self._format_update_values(pk, data)} WHERE {self._format_update_conditions(pk, data)}")
        if not r or r.rowcount == 0:
            r = conn.execute(f"INSERT INTO {tablename}({', '.join(data.keys())}) VALUES({self._format_values(data.values())})")

    @staticmethod
    def is_duplicate_key(e):
        if type(e) != sqlalchemy.exc.IntegrityError:
            return False
        return (str(e.orig).lower().find("duplicate") != -1)

    def _isinf(self, x):
        return x>=9223372036854775807 or x<=-9223372036854775808

    def _format_value(self, v):
        if v is None:
            return "null"
        elif type(v) == float:
            if math.isnan(v) or math.isinf(v) or self._isinf(v):
                return "null"
            else:
                return f"{v}"
        elif type(v) == int:
            if self._isinf(v):
                return "null"
            else:
                return f"{v}"
        elif type(v) == datetime.datetime:
            return "'"+v.strftime("%Y-%m-%d %H:%M:%S")+"'"
        elif type(v) == datetime.date:
            return "'"+v.strftime("%Y-%m-%d")+"'"
        elif type(v) == pd.Timestamp:
            return "'"+v.strftime("%Y-%m-%d %H:%M:%S")+"'"
        elif type(v) == str:
            return f"'{v}'"
        else:
            return f"'{v}'"

    def _format_values(self, data):
        s = ''
        for i, e in enumerate(data):
            s += self._format_value(e)
            s += ''
        return s[:-2]

    def _format_update_values(self, pk, data):
        s = ''
        for i, (k, v) in enumerate(data.items()):
            if k not in pk:
                s += f"{k}={self._format_value(v)}, "
        return s[:-2]

    def _format_update_conditions(self, pk, data):
        s = ''
        for i, (k, v) in enumerate(data.items()):
            if k in pk:
                s += f"{k}={self._format_value(v)} and "
        return s[:-4]

    def upsert(self, tablename, pk, data):
        if self.buffer_size is not None and self.buffer_size == 0:
            with self.engine.connect() as conn:
                self.upsert_one(conn, tablename, pk, data)
        else:
            if pk:
                self.tablename2pk[tablename] = pk
            if tablename not in self.tablename2datas:
                self.tablename2datas[tablename] = []
            self.tablename2datas[tablename].append(data)

            if self.buffer_size is not None and len(self.tablename2datas[tablename]) >= self.buffer_size:
                with self.engine.connect() as conn:
                    self._flush(conn, tablename, self.tablename2pk[tablename], self.tablename2datas[tablename])
                self.tablename2datas[tablename] = []


class MySQLUpsertHandler(UpsertHandlerBase):
    def __init__(self, engine, buffer_size=None):
        super().__init__(engine, buffer_size)

    def __del__(self):
        super().__del__()

    @staticmethod
    def is_duplicate_key(e):
        if type(e) != sqlalchemy.exc.IntegrityError:
            return False
        if len(e.orig.args) > 1 and str(e.orig.args[1]).startswith("Duplicate entry"):
            return True
        return False

    def upsert_one(self, conn, tablename, pk, data):
        sql = f"""INSERT INTO {tablename}({", ".join(data.keys())}) VALUES
  ({self._format_values(data.values())})
  ON DUPLICATE KEY UPDATE
  {self._format_update_values(pk, data)}\n
"""
        conn.execute(sql)


class PSQLUpsertHandler(UpsertHandlerBase):
    def __init__(self, engine, buffer_size=None):
        super().__init__(engine, buffer_size)

    def __del__(self):
        super().__del__()

    @staticmethod
    def is_duplicate_key(e):
        if type(e) != sqlalchemy.exc.IntegrityError:
            return False
        if str(e.orig).startswith("duplicate key"):
            return True
        return False

    def upsert_one(self, conn, tablename, pk, data):
        sql = f"""INSERT INTO {tablename}({", ".join(data.keys())}) VALUES
  ({self._format_values(data.values())})
  on conflict ({", ".join(pk)})
  do update set {self._format_update_values(pk, data)}\n
"""
        conn.execute(sql)
只有注册用户登录后才能发表评论。