InfluxDB 自带可视化的 Web 页面，查看和处理数据都很方便。
将本模块放在 site-packages\vnpy\database\influxdb2 目录下。

下载 InfluxDB v2，并将以下可执行文件放在 site-packages\vnpy\database\influxdb2\bin 目录下：
influx.exe
influxd.exe

site-packages\vnpy\database\influxdb2\__init__.py

from .influxdb2_database import database_manager

site-packages\vnpy\database\influxdb2\influxdb2_database.py

""""""
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from pathlib import Path
import inspect
import os
import pickle
import shelve
import shlex
import subprocess
import time

import psutil
from influxdb_client import InfluxDBClient, Point, WritePrecision
from influxdb_client.client.write_api import SYNCHRONOUS
from influxdb_client.domain.write_precision import WritePrecision

from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData, TickData
from vnpy.trader.database import (
    BaseDatabase,
    BarOverview,
    DB_TZ,
    convert_tz
)
from vnpy.trader.setting import SETTINGS
from vnpy.trader.utility import (
    generate_vt_symbol,
    extract_vt_symbol,
    extract_symbol,
    get_file_path,
    TRADER_DIR,
)


class Influxdb2Database(BaseDatabase):
    """vnpy database backend that stores bars and ticks in InfluxDB v2.

    Storage layout used by the save/load/delete methods, with
    ``code, date_str = extract_symbol(symbol)``:

    * bucket      -> ``f"{code}.{exchange.value}"``
    * measurement -> ``date_str``
    * tag ``interval`` -> ``Interval.value`` (``Interval.TICK.value`` for ticks)
    * one field per attribute listed in ``_BAR_FIELDS`` / ``_TICK_FIELDS``

    Bar overviews are cached in a local ``shelve`` file so that
    ``get_bar_overview`` does not have to scan every bucket on each call.
    """
    overview_filename = "influxdb2_overview"
    overview_filepath = str(get_file_path(overview_filename))

    # Fields written for every bar / tick; the load_* methods read the same names back.
    _BAR_FIELDS = (
        'open_price', 'high_price', 'low_price', 'close_price',
        'volume', 'open_interest',
    )
    _TICK_FIELDS = (
        'name', 'volume', 'open_interest', 'last_price', 'last_volume',
        'limit_up', 'limit_down',
        'open_price', 'high_price', 'low_price', 'pre_close',
        'bid_price_1', 'bid_price_2', 'bid_price_3', 'bid_price_4', 'bid_price_5',
        'ask_price_1', 'ask_price_2', 'ask_price_3', 'ask_price_4', 'ask_price_5',
        'bid_volume_1', 'bid_volume_2', 'bid_volume_3', 'bid_volume_4', 'bid_volume_5',
        'ask_volume_1', 'ask_volume_2', 'ask_volume_3', 'ask_volume_4', 'ask_volume_5',
    )

    # Counting is done on one field per point type. Each stored field is its own
    # series, so counting all fields would multiply the result by the field count.
    # Bars have no "last_price" field and ticks have no "close_price" field.
    _BAR_COUNT_FIELD = "close_price"
    _TICK_COUNT_FIELD = "last_price"

    # Maximum number of points sent in one write request.
    _WRITE_BATCH_SIZE = 1000

    def __init__(self) -> None:
        """Connect to InfluxDB v2, launching the bundled influxd if it is unreachable.

        Settings read from vnpy ``SETTINGS``:

        * ``database.user``: organization name
        * ``database.authentication_source``: API token
        * ``database.host`` / ``database.port``: HTTP address
        * ``database.database``: data directory under ``TRADER_DIR``, used only
          when influxd has to be started locally

        Blocks until the server reports ``ready``; there is no timeout.
        """
        self.org = SETTINGS["database.user"]
        database = SETTINGS["database.database"]
        host = SETTINGS["database.host"]
        port = SETTINGS["database.port"]
        token = SETTINGS["database.authentication_source"]

        self.client = InfluxDBClient(f'http://{host}:{port}', token, org=self.org, timeout=60_000)
        while True:
            try:
                if self.client.ready().status == 'ready':
                    break
            except Exception:
                # Server not reachable yet: try to launch a local influxd.
                # start_influxd() is a no-op when one is already running.
                # Catching Exception (not a bare except) keeps Ctrl+C working.
                self.start_influxd(host, port, database)
            time.sleep(1)

        self.write_api = self.client.write_api(write_options=SYNCHRONOUS)
        self.query_api = self.client.query_api()
        self.delete_api = self.client.delete_api()
        self.bucket_api = self.client.buckets_api()
        self.organizations_api = self.client.organizations_api()
        self.org_id = self.organizations_api.find_organizations(org=self.org)[0].id

        # Persistent key -> BarOverview cache. Keys are f"{vt_symbol}_{interval.value}".
        self.overviews: Dict[str, BarOverview] = shelve.open(
            self.overview_filepath, protocol=pickle.HIGHEST_PROTOCOL, writeback=True
        )

    def start_influxd(self, host, port, database):
        """Launch the influxd binary shipped in ``bin/`` next to this module.

        Does nothing if a process with the influxd executable name is already
        running. When ``database`` is non-empty, the bolt file and the storage
        engine are placed under ``TRADER_DIR / database``.
        """
        bin_name = 'influxd.exe' if os.name == 'nt' else 'influxd'

        # process_iter(['name']) pre-fetches the name. It yields None instead of
        # raising when a process exits or is inaccessible during iteration.
        running = {proc.info['name'] for proc in psutil.process_iter(['name'])}
        if bin_name in running:
            return

        bin_path = Path(inspect.getfile(self.__class__)).parent / 'bin' / bin_name
        args = [str(bin_path), f'--http-bind-address={host}:{port}']
        if database:
            database_path = TRADER_DIR.joinpath(database)
            args += [
                f'--bolt-path={database_path / "influxd.bolt"}',
                f'--engine-path={database_path / "engine"}',
            ]

        if os.name == 'nt':
            DETACHED_PROCESS = 0x00000008
            subprocess.Popen(args, shell=True, close_fds=True, creationflags=DETACHED_PROCESS)
        else:
            # Quote every argument so paths containing spaces or shell
            # metacharacters survive the shell command line.
            os.system('nohup ' + ' '.join(shlex.quote(arg) for arg in args) + ' &')

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _split_bucket(name: str) -> Optional[Tuple[str, Exchange]]:
        """Parse a bucket name ``"{code}.{exchange}"``; return None if it is not one."""
        code, sep, exchange_str = name.rpartition('.')
        if not sep or not code:
            return None
        try:
            return code, Exchange(exchange_str)
        except ValueError:
            return None

    def _ensure_bucket(self, bucket: str) -> None:
        """Create ``bucket`` in this organization if it does not exist yet."""
        if self.bucket_api.find_bucket_by_name(bucket) is None:
            self.bucket_api.create_bucket(bucket_name=bucket, org_id=self.org_id)

    def _write_points(self, bucket_points: Dict[str, List[Point]]) -> None:
        """Write the collected points bucket by bucket, in batches of _WRITE_BATCH_SIZE."""
        n = self._WRITE_BATCH_SIZE
        for bucket, points in bucket_points.items():
            for i in range(0, len(points), n):
                self.write_api.write(bucket=bucket, record=points[i:i + n])

    def _count_points(self, bucket: str, measurement: str, interval: str, field: str) -> int:
        """Return how many points of ``field`` exist for one measurement/interval (0 if none)."""
        query = (
            f'from(bucket: "{bucket}")'
            f'  |> range(start: 0)'
            f'  |> filter(fn: (r) => r._measurement == "{measurement}" and r.interval == "{interval}" and r._field == "{field}")'
            f'  |> count()'
        )
        count = 0
        for table in self.query_api.query(query):
            for record in table.records:
                count += record.get_value()
        return count

    def _delete_series(self, bucket: str, measurement: str, interval: str) -> None:
        """Delete every point of one measurement/interval, from the epoch up to now."""
        self.delete_api.delete(
            datetime.fromtimestamp(0, tz=DB_TZ).isoformat(),
            datetime.now(tz=DB_TZ).isoformat(),
            f'_measurement="{measurement}" and interval="{interval}"',
            bucket=bucket,
            org=self.org,
        )

    @staticmethod
    def _build_load_query(bucket: str, measurement: str, interval: str,
                          start: datetime, end: datetime) -> str:
        """Build the Flux query returning one pivoted row per timestamp in [start, end]."""
        # Flux range() excludes `stop`. Pushing it 1 microsecond past `end` makes
        # `end` inclusive, because points are written with millisecond precision.
        stop = end + timedelta(microseconds=1)
        return (
            f'from(bucket: "{bucket}")'
            f'  |> range(start: time(v: "{start.astimezone(DB_TZ).isoformat()}"), stop: time(v: "{stop.astimezone(DB_TZ).isoformat()}"))'
            f'  |> filter(fn: (r) => r._measurement == "{measurement}" and r.interval == "{interval}")'
            f'  |> drop(columns: ["_start", "_stop", "_measurement", "interval"])'
            f'  |> pivot(rowKey:["_time"], columnKey: ["_field"], valueColumn: "_value")'
        )

    # --------------------------------------------------------------------- save

    def save_bar_data(self, bars: List[BarData]) -> bool:
        """Write ``bars`` to InfluxDB and update the cached bar overviews.

        Missing buckets are created on the fly. For every touched
        ``vt_symbol``/interval, the overview start/end are widened to cover the
        new bars. The count is re-read from the database after the write.

        Returns:
            True once all points have been written.
        """
        bucket_points: Dict[str, List[Point]] = {}
        key_bars: Dict[str, List[BarData]] = {}
        key_info: Dict[str, Tuple[str, str]] = {}

        for bar in bars:
            code, date_str = extract_symbol(bar.symbol)
            bucket = f'{code}.{bar.exchange.value}'
            if bucket not in bucket_points:
                self._ensure_bucket(bucket)
                bucket_points[bucket] = []

            point = Point(measurement_name=date_str).tag('interval', bar.interval.value)
            for name in self._BAR_FIELDS:
                point.field(name, getattr(bar, name))
            point.time(bar.datetime.isoformat(), write_precision=WritePrecision.MS)
            bucket_points[bucket].append(point)

            key = f'{bar.vt_symbol}_{bar.interval.value}'
            key_bars.setdefault(key, []).append(bar)
            key_info[key] = (bucket, date_str)

        self._write_points(bucket_points)

        # Update bar overview
        for key, grouped in key_bars.items():
            first = grouped[0]
            # Use min/max rather than first/last so unsorted input is handled.
            start = min(b.datetime for b in grouped)
            end = max(b.datetime for b in grouped)

            overview = self.overviews.get(key, None)
            if not overview:
                overview = BarOverview(
                    symbol=first.symbol,
                    exchange=first.exchange,
                    interval=first.interval
                )
                overview.start = start
                overview.end = end
            else:
                overview.start = min(overview.start, start)
                overview.end = max(overview.end, end)

            # Count what is actually stored. This also covers data that existed
            # before this call and bars that overwrote existing timestamps.
            bucket, date_str = key_info[key]
            overview.count = self._count_points(
                bucket, date_str, first.interval.value, self._BAR_COUNT_FIELD
            )
            self.overviews[key] = overview

        self.overviews.sync()
        return True

    def save_tick_data(self, ticks: List[TickData]) -> bool:
        """Write ``ticks`` to InfluxDB, tagged ``interval=Interval.TICK.value``.

        Missing buckets are created on the fly. Tick overviews are not updated here.

        Returns:
            True once all points have been written.
        """
        bucket_points: Dict[str, List[Point]] = {}

        for tick in ticks:
            code, date_str = extract_symbol(tick.symbol)
            bucket = f'{code}.{tick.exchange.value}'
            if bucket not in bucket_points:
                # Check the bucket once per call, and always register the point
                # list, including when the bucket already existed.
                self._ensure_bucket(bucket)
                bucket_points[bucket] = []

            point = Point(measurement_name=date_str).tag('interval', Interval.TICK.value)
            for name in self._TICK_FIELDS:
                point.field(name, getattr(tick, name))
            point.time(tick.datetime.isoformat(), write_precision=WritePrecision.MS)
            bucket_points[bucket].append(point)

        self._write_points(bucket_points)
        return True

    # --------------------------------------------------------------------- load

    def load_bar_data(
        self,
        symbol: str,
        exchange: Exchange,
        interval: Interval,
        start: datetime,
        end: datetime
    ) -> List[BarData]:
        """Load the bars of one symbol/interval whose timestamps lie in [start, end]."""
        code, date_str = extract_symbol(symbol)
        bucket = f'{code}.{exchange.value}'
        query = self._build_load_query(bucket, date_str, interval.value, start, end)

        bars: List[BarData] = []
        for table in self.query_api.query(query):
            for row in table.records:
                bars.append(BarData(
                    symbol=symbol,
                    exchange=exchange,
                    interval=interval,
                    datetime=row.get_time().astimezone(DB_TZ),
                    gateway_name="DB",
                    **{name: row[name] for name in self._BAR_FIELDS},
                ))
        return bars

    def load_tick_data(
        self,
        symbol: str,
        exchange: Exchange,
        start: datetime,
        end: datetime
    ) -> List[TickData]:
        """Load the ticks of one symbol whose timestamps lie in [start, end]."""
        code, date_str = extract_symbol(symbol)
        bucket = f'{code}.{exchange.value}'
        query = self._build_load_query(bucket, date_str, Interval.TICK.value, start, end)

        ticks: List[TickData] = []
        for table in self.query_api.query(query):
            for row in table.records:
                ticks.append(TickData(
                    symbol=symbol,
                    exchange=exchange,
                    datetime=row.get_time().astimezone(DB_TZ),
                    gateway_name="DB",
                    **{name: row[name] for name in self._TICK_FIELDS},
                ))
        return ticks

    # ------------------------------------------------------------------- delete

    def delete_bar_data(
        self,
        symbol: str,
        exchange: Exchange,
        interval: Interval
    ) -> int:
        """Delete all bars of one symbol/interval and its overview.

        Returns:
            The number of bars that were stored before deletion.
        """
        code, date_str = extract_symbol(symbol)
        bucket = f'{code}.{exchange.value}'

        count = self._count_points(bucket, date_str, interval.value, self._BAR_COUNT_FIELD)
        self._delete_series(bucket, date_str, interval.value)

        key = f"{generate_vt_symbol(symbol, exchange)}_{interval.value}"
        self.overviews.pop(key, None)
        self.overviews.sync()

        return count

    def delete_tick_data(
        self,
        symbol: str,
        exchange: Exchange
    ) -> int:
        """Delete all ticks of one symbol and its tick overview, if any.

        Returns:
            The number of ticks that were stored before deletion.
        """
        code, date_str = extract_symbol(symbol)
        bucket = f'{code}.{exchange.value}'

        count = self._count_points(bucket, date_str, Interval.TICK.value, self._TICK_COUNT_FIELD)
        self._delete_series(bucket, date_str, Interval.TICK.value)

        # get_bar_overview also caches tick series under this key; drop it so it is not stale.
        key = f"{generate_vt_symbol(symbol, exchange)}_{Interval.TICK.value}"
        self.overviews.pop(key, None)
        self.overviews.sync()

        return count

    # ----------------------------------------------------------------- overview

    def get_bar_overview(self) -> List[BarOverview]:
        """
        Return data available in database.

        The cached overviews are rebuilt from InfluxDB whenever the set of
        ``"{code}.{exchange}"`` buckets on the server differs from the set the
        cache refers to. Buckets with no data never appear in the cache, so
        they keep triggering a rebuild.
        """
        # Collect vnpy buckets. System buckets ("_...") and buckets whose name
        # is not "{code}.{valid exchange}" (e.g. created by hand) are ignored.
        buckets: Dict[str, Tuple[str, Exchange]] = {}
        page_size = 100
        last_id = None
        while True:
            kwargs = {'limit': page_size}
            if last_id:
                kwargs['after'] = last_id
            page = self.bucket_api.find_buckets(**kwargs).buckets
            if not page:
                break
            for item in page:
                if item.name.startswith('_'):
                    continue
                parsed = self._split_bucket(item.name)
                if parsed is not None:
                    buckets[item.name] = parsed
            if len(page) < page_size:
                break
            last_id = page[-1].id

        overview_buckets = set()
        for key in self.overviews.keys():
            # Keys are f"{vt_symbol}_{interval}"; split on the last "_" only.
            symbol, exchange = extract_vt_symbol(key.rsplit('_', 1)[0])
            code, _ = extract_symbol(symbol)
            overview_buckets.add(f'{code}.{exchange.value}')

        if set(buckets) != overview_buckets:
            self.overviews.clear()
            for bucket, (code, exchange) in buckets.items():
                query = (
                    f'from(bucket: "{bucket}")'
                    f'  |> range(start: 0)'
                    f'  |> filter(fn: (r) => r._field == "{self._BAR_COUNT_FIELD}" or r._field == "{self._TICK_COUNT_FIELD}")'
                    f'  |> group(columns: ["_measurement", "interval"])'
                    f'  |> count()'
                )

                for table in self.query_api.query(query):
                    for record in table.records:
                        date_str = record.get_measurement()
                        interval = record['interval']
                        # Assumes symbol == code + date_str, i.e. the inverse of
                        # extract_symbol(). TODO confirm for all symbol formats.
                        symbol = f'{code}{date_str}'

                        overview = BarOverview(
                            symbol=symbol,
                            exchange=exchange,
                            interval=Interval(interval),
                            count=record.get_value()
                        )
                        overview.start = self.get_bar_datetime(bucket, date_str, interval, 'first')
                        overview.end = self.get_bar_datetime(bucket, date_str, interval, 'last')

                        self.overviews[f'{symbol}.{exchange.value}_{interval}'] = overview
            self.overviews.sync()

        return list(self.overviews.values())

    def get_bar_datetime(self, bucket: str, measurement: str, interval: str, order: str) -> datetime:
        """Return the timestamp of the ``first()`` or ``last()`` point of a series.

        Args:
            order: name of the Flux selector, ``'first'`` or ``'last'``.
        """
        query = (
            f'from(bucket: "{bucket}")'
            f'  |> range(start: 0)'
            f'  |> filter(fn: (r) => r._measurement == "{measurement}" and r.interval == "{interval}")'
            f'  |> {order}()'
        )
        return self.query_api.query(query)[0].records[0].get_time().astimezone(DB_TZ)


# Module-level singleton, re-exported by the package __init__
# (`from .influxdb2_database import database_manager`). Creating it at import
# time blocks until the InfluxDB server reports ready.
database_manager: Influxdb2Database = Influxdb2Database()