InfluxDB ships with a web UI, which makes working with the data quite convenient.
Put the adapter under site-packages\vnpy\database\influxdb2
Download InfluxDB v2 and put its binaries in site-packages\vnpy\database\influxdb2\bin:
influx.exe
influxd.exe
site-packages\vnpy\database\influxdb2\__init__.py
from .influxdb2_database import database_manager
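For reference, the adapter's __init__ method below pulls its connection settings from vnpy's global SETTINGS. A minimal sketch of the mapping, assuming your vnpy version selects the database driver via the "database.driver" setting (all values are placeholders):

from vnpy.trader.setting import SETTINGS

SETTINGS["database.driver"] = "influxdb2"                # loads vnpy.database.influxdb2
SETTINGS["database.host"] = "localhost"                  # influxd HTTP bind host
SETTINGS["database.port"] = 8086                         # influxd HTTP bind port
SETTINGS["database.user"] = "my-org"                     # reused as the InfluxDB organization
SETTINGS["database.database"] = "influxdb2"              # folder under TRADER_DIR for the bolt/engine files
SETTINGS["database.authentication_source"] = "my-token"  # reused as the API token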
site-packages\vnpy\database\influxdb2\influxdb2_database.py
""""""
from datetime import datetime
from typing import List, Dict
from pathlib import Path
import shelve
import pickle
import os
import inspect
import time
import psutil
import subprocess
from influxdb_client import InfluxDBClient, Point, WritePrecision
from influxdb_client.client.write_api import SYNCHRONOUS
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData, TickData
from vnpy.trader.database import (
BaseDatabase,
BarOverview,
DB_TZ,
convert_tz
)
from vnpy.trader.setting import SETTINGS
from vnpy.trader.utility import (
generate_vt_symbol,
extract_vt_symbol,
extract_symbol,
get_file_path,
TRADER_DIR,
)
class Influxdb2Database(BaseDatabase):
""""""
overview_filename = "influxdb2_overview"
overview_filepath = str(get_file_path(overview_filename))
def __init__(self) -> None:
""""""
self.org = SETTINGS["database.user"]
database = SETTINGS["database.database"]
host = SETTINGS["database.host"]
port = SETTINGS["database.port"]
token = SETTINGS["database.authentication_source"]
self.client = InfluxDBClient(f'http://{host}:{port}', token, org=self.org, timeout=60_000)
        # Ping the server until it reports ready; if it cannot be reached,
        # try to launch a local influxd process and retry once per second.
        while True:
            try:
                if self.client.ready().status == 'ready':
                    break
            except Exception:
                self.start_influxd(host, port, database)
            time.sleep(1)
self.write_api = self.client.write_api(write_options=SYNCHRONOUS)
self.query_api = self.client.query_api()
self.delete_api = self.client.delete_api()
self.bucket_api = self.client.buckets_api()
self.organizations_api = self.client.organizations_api()
self.org_id = self.organizations_api.find_organizations(org=self.org)[0].id
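        # Cache BarOverview objects in a shelve file so they survive restarts.
        # Note: with writeback=True, mutations live in memory until the shelf
        # is sync()'d or closed.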
self.overviews: Dict[str, BarOverview] = shelve.open(self.overview_filepath, protocol=pickle.HIGHEST_PROTOCOL, writeback=True)
def start_influxd(self, host, port, database):
bin_name = 'influxd'
if os.name == 'nt':
bin_name += '.exe'
        if bin_name not in (p.info['name'] for p in psutil.process_iter(['name'])):
bin_path = str(Path(inspect.getfile(self.__class__)).parent.joinpath(f'bin/{bin_name}'))
args = [bin_path, f'--http-bind-address={host}:{port}']
if database:
database_path = TRADER_DIR.joinpath(database)
args += [f'--bolt-path={database_path}/influxd.bolt', f'--engine-path={database_path}/engine']
            if os.name == 'nt':
                # Detach the child so influxd outlives this Python process
                # (shell=True is unnecessary when passing an argument list).
                DETACHED_PROCESS = 0x00000008
                subprocess.Popen(args, close_fds=True, creationflags=DETACHED_PROCESS)
else:
os.system(f'nohup {" ".join(args)} &')
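    # The Popen/os.system call above is roughly equivalent to running, by hand
    # (illustrative paths):
    #   influxd --http-bind-address=localhost:8086 \
    #       --bolt-path=<TRADER_DIR>/<database>/influxd.bolt \
    #       --engine-path=<TRADER_DIR>/<database>/engine
    #
    # Storage layout used by the methods below: one bucket per contract named
    # "{code}.{exchange}", the contract's date part as the measurement name,
    # and the bar/tick interval stored in an "interval" tag.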
def save_bar_data(self, bars: List[BarData]) -> bool:
""""""
bucket_points = {}
key_bars = {}
key_info = {}
for bar in bars:
code, date_str = extract_symbol(bar.symbol)
bucket = f'{code}.{bar.exchange.value}'
if bucket not in bucket_points:
if self.bucket_api.find_bucket_by_name(bucket) is None:
self.bucket_api.create_bucket(bucket_name=bucket, org_id=self.org_id)
bucket_points[bucket] = []
point = (
Point(measurement_name=date_str)
.tag('interval', bar.interval.value)
.field('open_price', bar.open_price)
.field('high_price', bar.high_price)
.field('low_price', bar.low_price)
.field('close_price', bar.close_price)
.field('volume', bar.volume)
.field('open_interest', bar.open_interest)
.time(bar.datetime.isoformat(), write_precision=WritePrecision.MS)
)
bucket_points[bucket].append(point)
key = f'{bar.vt_symbol}_{bar.interval.value}'
if key not in key_bars:
key_bars[key] = []
key_info[key] = (bucket, date_str)
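            # Keep only the first and the most recent bar per key; that is all
            # the overview update below needs for its start/end timestamps.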
if len(key_bars[key]) < 2:
key_bars[key].append(bar)
else:
key_bars[key][-1] = bar
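        # Flush the accumulated points in batches of 1000 per write request.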
for bucket, record in bucket_points.items():
n = 1000
for i in range(0, len(record), n):
self.write_api.write(bucket=bucket, record=record[i:i + n])
# Update bar overview
for key, bars in key_bars.items():
overview = self.overviews.get(key, None)
if not overview:
overview = BarOverview(
symbol=bars[0].symbol,
exchange=bars[0].exchange,
interval=bars[0].interval
)
overview.start = bars[0].datetime
overview.end = bars[-1].datetime
else:
overview.start = min(overview.start, bars[0].datetime)
overview.end = max(overview.end, bars[-1].datetime)
bucket, date_str = key_info[key]
query = (
f'from(bucket: "{bucket}")'
f' |> range(start: 0)'
f' |> filter(fn: (r) => r._measurement == "{date_str}" and r.interval == "{bars[0].interval.value}")'
f' |> count()'
)
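            # count() returns one table per field with one row per bar, so the
            # first record's value equals the number of bars for this key.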
overview.count = self.query_api.query(query)[0].records[0].get_value()
self.overviews[key] = overview
        return True
def save_tick_data(self, ticks: List[TickData]) -> bool:
""""""
bucket_points = {}
for tick in ticks:
code, date_str = extract_symbol(tick.symbol)
bucket = f'{code}.{tick.exchange.value}'
if bucket not in bucket_points:
if self.bucket_api.find_bucket_by_name(bucket) is None:
self.bucket_api.create_bucket(bucket_name=bucket, org_id=self.org_id)
bucket_points[bucket] = []
point = (
Point(measurement_name=date_str)
.tag('interval', Interval.TICK.value)
.field('name', tick.name)
.field('volume', tick.volume)
.field('open_interest', tick.open_interest)
.field('last_price', tick.last_price)
.field('last_volume', tick.last_volume)
.field('limit_up', tick.limit_up)
.field('limit_down', tick.limit_down)
.field('open_price', tick.open_price)
.field('high_price', tick.high_price)
.field('low_price', tick.low_price)
.field('pre_close', tick.pre_close)
.field('bid_price_1', tick.bid_price_1)
.field('bid_price_2', tick.bid_price_2)
.field('bid_price_3', tick.bid_price_3)
.field('bid_price_4', tick.bid_price_4)
.field('bid_price_5', tick.bid_price_5)
.field('ask_price_1', tick.ask_price_1)
.field('ask_price_2', tick.ask_price_2)
.field('ask_price_3', tick.ask_price_3)
.field('ask_price_4', tick.ask_price_4)
.field('ask_price_5', tick.ask_price_5)
.field('bid_volume_1', tick.bid_volume_1)
.field('bid_volume_2', tick.bid_volume_2)
.field('bid_volume_3', tick.bid_volume_3)
.field('bid_volume_4', tick.bid_volume_4)
.field('bid_volume_5', tick.bid_volume_5)
.field('ask_volume_1', tick.ask_volume_1)
.field('ask_volume_2', tick.ask_volume_2)
.field('ask_volume_3', tick.ask_volume_3)
.field('ask_volume_4', tick.ask_volume_4)
.field('ask_volume_5', tick.ask_volume_5)
.time(tick.datetime.isoformat(), write_precision=WritePrecision.MS)
)
bucket_points[bucket].append(point)
for bucket, record in bucket_points.items():
n = 1000
for i in range(0, len(record), n):
                self.write_api.write(bucket=bucket, record=record[i:i + n])
        return True
def load_bar_data(
self,
symbol: str,
exchange: Exchange,
interval: Interval,
start: datetime,
end: datetime
) -> List[BarData]:
""""""
code, date_str = extract_symbol(symbol)
bucket = f'{code}.{exchange.value}'
query = (
f'from(bucket: "{bucket}")'
f' |> range(start: time(v: "{start.astimezone(DB_TZ).isoformat()}"), stop: time(v: "{end.astimezone(DB_TZ).isoformat()}"))'
f' |> filter(fn: (r) => r._measurement == "{date_str}" and r.interval == "{interval.value}")'
f' |> drop(columns: ["_start", "_stop", "_measurement", "interval"])'
f' |> pivot(rowKey:["_time"], columnKey: ["_field"], valueColumn: "_value")'
)
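        # pivot() reshapes the per-field rows into one record per timestamp,
        # with one column per field, so each row below can be read like a dict.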
bars: List[BarData] = []
for tb in self.query_api.query(query):
for row in tb.records:
bar = BarData(
symbol=symbol,
exchange=exchange,
interval=interval,
datetime=row.get_time().astimezone(DB_TZ),
open_price=row['open_price'],
high_price=row['high_price'],
low_price=row['low_price'],
close_price=row['close_price'],
volume=row['volume'],
open_interest=row['open_interest'],
gateway_name="DB"
)
bars.append(bar)
return bars
def load_tick_data(
self,
symbol: str,
exchange: Exchange,
start: datetime,
end: datetime
) -> List[TickData]:
""""""
code, date_str = extract_symbol(symbol)
bucket = f'{code}.{exchange.value}'
query = (
f'from(bucket: "{bucket}")'
f' |> range(start: time(v: "{start.astimezone(DB_TZ).isoformat()}"), stop: time(v: "{end.astimezone(DB_TZ).isoformat()}"))'
f' |> filter(fn: (r) => r._measurement == "{date_str}" and r.interval == "{Interval.TICK.value}")'
f' |> drop(columns: ["_start", "_stop", "_measurement", "interval"])'
f' |> pivot(rowKey:["_time"], columnKey: ["_field"], valueColumn: "_value")'
)
ticks: List[TickData] = []
for tb in self.query_api.query(query):
for row in tb.records:
tick = TickData(
symbol=symbol,
exchange=exchange,
datetime=row.get_time().astimezone(DB_TZ),
name=row['name'],
volume=row["volume"],
open_interest=row["open_interest"],
last_price=row["last_price"],
last_volume=row["last_volume"],
limit_up=row["limit_up"],
limit_down=row["limit_down"],
open_price=row["open_price"],
high_price=row["high_price"],
low_price=row["low_price"],
pre_close=row["pre_close"],
bid_price_1=row["bid_price_1"],
bid_price_2=row["bid_price_2"],
bid_price_3=row["bid_price_3"],
bid_price_4=row["bid_price_4"],
bid_price_5=row["bid_price_5"],
ask_price_1=row["ask_price_1"],
ask_price_2=row["ask_price_2"],
ask_price_3=row["ask_price_3"],
ask_price_4=row["ask_price_4"],
ask_price_5=row["ask_price_5"],
bid_volume_1=row["bid_volume_1"],
bid_volume_2=row["bid_volume_2"],
bid_volume_3=row["bid_volume_3"],
bid_volume_4=row["bid_volume_4"],
bid_volume_5=row["bid_volume_5"],
ask_volume_1=row["ask_volume_1"],
ask_volume_2=row["ask_volume_2"],
ask_volume_3=row["ask_volume_3"],
ask_volume_4=row["ask_volume_4"],
ask_volume_5=row["ask_volume_5"],
gateway_name="DB"
)
ticks.append(tick)
return ticks
def delete_bar_data(
self,
symbol: str,
exchange: Exchange,
interval: Interval
) -> int:
""""""
code, date_str = extract_symbol(symbol)
bucket = f'{code}.{exchange.value}'
# Query data count
query = (
f'from(bucket: "{bucket}")'
f' |> range(start: 0)'
f' |> filter(fn: (r) => r._measurement == "{date_str}" and r.interval == "{interval.value}")'
f' |> count()'
)
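        # Each field forms its own table with an identical point count, so the
        # last value read in the loop equals the number of bars to be deleted.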
count = 0
for tb in self.query_api.query(query):
for row in tb.records:
count = row.get_value()
# Delete data
self.delete_api.delete(
datetime.fromtimestamp(0, tz=DB_TZ).isoformat(),
datetime.now(tz=DB_TZ).isoformat(),
f'_measurement="{date_str}" and interval="{interval.value}"',
bucket=bucket,
org=self.org,
)
# Delete overview
vt_symbol = generate_vt_symbol(symbol, exchange)
key = f"{vt_symbol}_{interval.value}"
if key in self.overviews:
self.overviews.pop(key)
return count
def delete_tick_data(
self,
symbol: str,
exchange: Exchange
) -> int:
""""""
code, date_str = extract_symbol(symbol)
bucket = f'{code}.{exchange.value}'
# Query data count
query = (
f'from(bucket: "{bucket}")'
f' |> range(start: 0)'
f' |> filter(fn: (r) => r._measurement == "{date_str}" and r.interval == "{Interval.TICK.value}")'
f' |> count()'
)
count = 0
for tb in self.query_api.query(query):
for row in tb.records:
count = row.get_value()
# Delete data
self.delete_api.delete(
datetime.fromtimestamp(0, tz=DB_TZ).isoformat(),
datetime.now(tz=DB_TZ).isoformat(),
f'_measurement="{date_str}" and interval="{Interval.TICK.value}"',
bucket=bucket,
org=self.org,
)
return count
def get_bar_overview(self) -> List[BarOverview]:
"""
        Return an overview of the bar data available in the database.
"""
# Init bar overview if not exists
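        # Page through all buckets 100 at a time using the id cursor; names
        # starting with "_" are InfluxDB system buckets and are skipped.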
buckets = set()
last_id = ''
while True:
_buckets = self.bucket_api.find_buckets(after=last_id, limit=100).buckets
if not _buckets:
break
buckets.update([bucket.name for bucket in _buckets if not bucket.name.startswith('_')])
last_id = _buckets[-1].id
overview_buckets = set()
for key in self.overviews.keys():
symbol, exchange = extract_vt_symbol(key.split('_')[0])
code, date_str = extract_symbol(symbol)
overview_buckets.add(f'{code}.{exchange.value}')
if buckets != overview_buckets:
self.overviews.clear()
for bucket in buckets:
code, exchange = bucket.split('.')
query = (
f'from(bucket: "{bucket}")'
f' |> range(start: 0)'
f' |> group(columns: ["_measurement", "interval"])'
f' |> count()'
)
for data in self.query_api.query(query):
for row in data.records:
date_str = row.get_measurement()
interval = row['interval']
symbol = f'{code}{date_str}'
key = f'{symbol}.{exchange}_{interval}'
overview = BarOverview(
symbol=symbol,
exchange=Exchange(exchange),
interval=Interval(interval),
                            # row['interval'] is a plain string, so compare against
                            # the enum's value; divide the raw row count (one row
                            # per field) by the number of fields written per point.
                            count=int(row.get_value() / (30 if interval == Interval.TICK.value else 6))
)
overview.start = self.get_bar_datetime(bucket, date_str, interval, 'first')
overview.end = self.get_bar_datetime(bucket, date_str, interval, 'last')
self.overviews[key] = overview
return list(self.overviews.values())
    def get_bar_datetime(self, bucket: str, measurement: str, interval: str, order: str) -> datetime:
""""""
query = (
f'from(bucket: "{bucket}")'
f' |> range(start: 0)'
f' |> filter(fn: (r) => r._measurement == "{measurement}" and r.interval == "{interval}")'
f' |> {order}()'
)
return self.query_api.query(query)[0].records[0].get_time().astimezone(DB_TZ)
database_manager = Influxdb2Database()
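A minimal round-trip smoke test, assuming the module is importable as vnpy.database.influxdb2, influxd can be reached or started locally, and extract_symbol splits a contract symbol such as rb2110 into ("rb", "2110") as the code above implies (symbol, dates and prices are placeholders):

from datetime import datetime
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData
from vnpy.database.influxdb2 import database_manager

# Save a single one-minute bar.
bar = BarData(
    symbol="rb2110",
    exchange=Exchange.SHFE,
    interval=Interval.MINUTE,
    datetime=datetime(2021, 6, 1, 9, 30),
    open_price=5000.0,
    high_price=5010.0,
    low_price=4995.0,
    close_price=5005.0,
    volume=1200,
    open_interest=34000,
    gateway_name="DB",
)
database_manager.save_bar_data([bar])

# Load it back for the surrounding date range.
bars = database_manager.load_bar_data(
    "rb2110", Exchange.SHFE, Interval.MINUTE,
    start=datetime(2021, 6, 1), end=datetime(2021, 6, 2),
)
print(len(bars), bars[0].close_price if bars else None)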