from typing import TextIO,List,Dict, Union
from datetime import datetime,timedelta
from datetime import time as dtime
from time import time
import numpy as np
import pandas as pd
import csv
import os
from multiprocessing import cpu_count
from concurrent.futures import ProcessPoolExecutor
from vnpy.trader.database import database_manager
from vnpy.trader.constant import (Exchange, Interval)
from vnpy.trader.object import (BarData, TickData)
from vnpy.trader.utility import (load_json, save_json,remain_alpha,extract_vt_symbol,round_to)
from peewee import chunked
from vnpy.app.cta_strategy.usual_method import CheckFutureNews
from vnpy.event import EventEngine
from vnpy.trader.engine import MainEngine
# Both settings come from the same CheckFutureNews helper, so build it only once.
_check_future_news = CheckFutureNews()
one_token_data_path = _check_future_news.one_token_data_path  # 1Token export directory; the TDX directory is derived from it in __main__
LINK_SIGN = _check_future_news.link_sign  # path separator
del _check_future_news
class PostgresOperation():
    """Import market-data files into the vnpy database and wrap a few database queries.

    Every ``save_*`` loader does the same four things:
    reads one file, builds ``BarData``/``TickData`` objects, saves them through
    ``database_manager`` in chunks of 10000 records, and prints a summary line.
    The summary gives the first/last datetime, the record count and the elapsed seconds.
    """

    def __init__(self):
        pass

    # ------------------------------------------------------------------
    # Private helpers shared by the loaders
    # ------------------------------------------------------------------
    @staticmethod
    def _iter_csv_rows(file_path: str):
        """Yield the non-empty rows of a UTF-8, comma-separated file.

        NUL characters are removed from every line before parsing, because csv.reader rejects them.
        """
        with open(file_path, 'r', encoding="utf-8") as f:
            reader = csv.reader((line.replace('\0', '') for line in f), delimiter=",")
            for row in reader:
                # csv.reader yields [] for blank lines; skip them instead of failing on row[0].
                if row:
                    yield row

    @staticmethod
    def _has_datetime_text(row: List[str], column: int) -> bool:
        """Return True when ``row[column]`` exists and holds at least 10 characters.

        The loaders use this to tell data rows (which start with a full date) from header rows.
        """
        return len(row) > column and len(row[column]) >= 10

    @staticmethod
    def _parse_postgres_datetime(text: str) -> datetime:
        """Parse a ``%Y-%m-%d %H:%M:%S`` timestamp that may carry fractional seconds."""
        if "." in text:
            return datetime.strptime(text, "%Y-%m-%d %H:%M:%S.%f")
        return datetime.strptime(text, "%Y-%m-%d %H:%M:%S")

    @staticmethod
    def _save_bars(bars: List[BarData], update_bar: bool):
        """Save bars through database_manager in chunks of 10000."""
        for bar_data in chunked(bars, 10000):
            database_manager.save_bar_data(bar_data, update_bar)

    @staticmethod
    def _save_ticks(ticks: List[TickData], update_tick: bool):
        """Save ticks through database_manager in chunks of 10000."""
        for tick_data in chunked(ticks, 10000):
            database_manager.save_tick_data(tick_data, update_tick)

    @staticmethod
    def _report_empty(file_path):
        """Print a notice for a file that produced no records (nothing is saved)."""
        print(f"文件{file_path}中没有可载入的数据")

    @staticmethod
    def _align_tdx_future_datetime(bar_datetime):
        """Apply the futures datetime shift used for TDX night-session bars.

        The shift aligns futures datetimes to Wenhua Caijing. Rules:
        * Monday, at or after 21:00   -> minus 3 days
        * Monday, at or before 02:30  -> minus 2 days
        * other weekdays, at or after 21:00 -> minus 1 day

        NOTE(review): presumably TDX stamps night-session bars with the next trading day's date; confirm.
        """
        if bar_datetime.weekday() == 0:
            if bar_datetime.time() >= dtime(21, 0):
                bar_datetime -= timedelta(days=3)
            if bar_datetime.time() <= dtime(2, 30):
                bar_datetime -= timedelta(days=2)
        elif bar_datetime.time() >= dtime(21, 0):
            bar_datetime -= timedelta(days=1)
        return bar_datetime

    # ------------------------------------------------------------------
    def load_bar_data(self, symbol: str, exchange: Exchange, interval: Interval, start_time, end_time):
        """Read bars of one contract/interval between start_time and end_time via database_manager."""
        return database_manager.load_bar_data(symbol, exchange, interval, start_time, end_time)

    # ------------------------------------------------------------------
    def load_tick_data(self, symbol: str, exchange: Exchange, start_time, end_time):
        """Read ticks of one contract between start_time and end_time via database_manager."""
        return database_manager.load_tick_data(symbol, exchange, start_time, end_time)

    # ------------------------------------------------------------------
    def save_day_data(self, file_path: str, symbol: str, exchange: Exchange, update_bar: bool,
                      interval: Interval = Interval.DAILY):
        """Load daily bars from a CSV file and save them to the database.

        Expected columns: date (``%Y/%m/%d``), open, high, low, close, volume, open_interest.
        Rows whose first column is shorter than 10 characters (e.g. a header) are skipped.
        Field values are passed to ``BarData`` as the raw strings read from the file.

        Args:
            file_path: path of the CSV file.
            symbol: contract symbol stored on every bar.
            exchange: exchange stored on every bar.
            update_bar: forwarded to ``database_manager.save_bar_data``.
            interval: bar interval stored on every bar.
        """
        time_consuming_start = time()
        bars = [
            BarData(
                symbol=symbol,
                exchange=exchange,
                datetime=datetime.strptime(item[0], '%Y/%m/%d'),
                interval=interval,
                open_price=item[1],
                high_price=item[2],
                low_price=item[3],
                close_price=item[4],
                volume=item[5],
                open_interest=item[6],
                gateway_name="DATABASE",
            )
            for item in self._iter_csv_rows(file_path)
            if self._has_datetime_text(item, 0)
        ]
        if not bars:
            self._report_empty(file_path)
            return
        self._save_bars(bars, update_bar)
        start_time, end_time, count = bars[0].datetime, bars[-1].datetime, len(bars)
        time_consuming_end = time()
        print(f'载入{bars[-1].vt_symbol}:日线数据,开始时间:{start_time},结束时间:{end_time},数据量:{count},耗时:{round(time_consuming_end-time_consuming_start,3)}秒')

    # ------------------------------------------------------------------
    def save_bar_data(self, file_path: str, symbol: str, exchange: Exchange, update_bar: bool,
                      interval: Interval = Interval.MINUTE):
        """Load intraday bars from a CSV file and save them to the database.

        Expected columns: datetime (``%Y/%m/%d %H:%M``), open, high, low, close, volume,
        open_interest. Rows whose first column is shorter than 10 characters are skipped.

        Args:
            file_path: path of the CSV file.
            symbol: contract symbol stored on every bar.
            exchange: exchange stored on every bar.
            update_bar: forwarded to ``database_manager.save_bar_data``.
            interval: bar interval stored on every bar.
        """
        time_consuming_start = time()
        bars = [
            BarData(
                symbol=symbol,
                exchange=exchange,
                datetime=datetime.strptime(item[0], '%Y/%m/%d %H:%M'),
                interval=interval,
                open_price=item[1],
                high_price=item[2],
                low_price=item[3],
                close_price=item[4],
                volume=item[5],
                open_interest=item[6],
                gateway_name="DATABASE",
            )
            for item in self._iter_csv_rows(file_path)
            if self._has_datetime_text(item, 0)
        ]
        if not bars:
            self._report_empty(file_path)
            return
        self._save_bars(bars, update_bar)
        start_time, end_time, count = bars[0].datetime, bars[-1].datetime, len(bars)
        time_consuming_end = time()
        print(f'载入{bars[-1].vt_symbol}:bar数据,开始时间:{start_time},结束时间:{end_time},数据量:{count},耗时:{round(time_consuming_end-time_consuming_start,3)}秒')

    # ------------------------------------------------------------------
    def save_json_data(self, file_path: str, symbol: str, exchange: Exchange, update_bar: bool,
                       interval: Interval = Interval.MINUTE):
        """Load bars exported by 1Token as JSON and save them to the database.

        Each record must provide ``timestamp`` (seconds, converted with
        ``datetime.fromtimestamp``, i.e. local time), ``open``, ``high``, ``low``,
        ``close`` and ``volume``.

        Args:
            file_path: path handed to vnpy's ``load_json``.
            symbol: contract symbol stored on every bar.
            exchange: exchange stored on every bar.
            update_bar: forwarded to ``database_manager.save_bar_data``.
            interval: bar interval stored on every bar.
        """
        time_consuming_start = time()
        records = load_json(file_path)
        bars = [
            BarData(
                symbol=symbol,
                exchange=exchange,
                datetime=datetime.fromtimestamp(item['timestamp']),
                interval=interval,
                open_price=item['open'],
                high_price=item['high'],
                low_price=item['low'],
                close_price=item['close'],
                volume=item['volume'],
                gateway_name="DATABASE",
            )
            for item in records
        ]
        if not bars:
            self._report_empty(file_path)
            return
        self._save_bars(bars, update_bar)
        start_time, end_time, count = bars[0].datetime, bars[-1].datetime, len(bars)
        time_consuming_end = time()
        print(f'载入{bars[-1].vt_symbol}:bar数据,开始时间:{start_time},结束时间:{end_time},数据量:{count},耗时:{round(time_consuming_end-time_consuming_start,3)}秒')

    # ------------------------------------------------------------------
    def save_postgres_csv(self, file_path, update_bar: bool, interval: Interval = Interval.MINUTE):
        """Load a bar CSV exported from postgres and save it to the database.

        Column layout used:
        * 1 symbol, 2 exchange value, 3 datetime
        * 5 volume, 6 open_interest
        * 7-10 open/high/low/close

        Rows whose column 3 is missing or shorter than 10 characters are skipped.

        Args:
            file_path: path of the CSV file.
            update_bar: forwarded to ``database_manager.save_bar_data``.
            interval: bar interval stored on every bar.
        """
        time_consuming_start = time()
        bars = [
            BarData(
                symbol=item[1],
                exchange=Exchange(item[2]),
                datetime=self._parse_postgres_datetime(item[3]),
                interval=interval,
                open_price=item[7],
                high_price=item[8],
                low_price=item[9],
                close_price=item[10],
                volume=item[5],
                open_interest=item[6],
                gateway_name="DATABASE",
            )
            for item in self._iter_csv_rows(file_path)
            if self._has_datetime_text(item, 3)
        ]
        if not bars:
            self._report_empty(file_path)
            return
        self._save_bars(bars, update_bar)
        start_time, end_time, count = bars[0].datetime, bars[-1].datetime, len(bars)
        time_consuming_end = time()
        print(f'载入postgres csv分钟数据,开始时间:{start_time},结束时间:{end_time},数据量:{count},耗时:{round(time_consuming_end-time_consuming_start,3)}秒')

    # ------------------------------------------------------------------
    def save_postgres_tick(self, file_path: str, update_tick: bool):
        """Load a tick CSV exported from postgres and save it to the database.

        Column layout used:
        * 1 symbol, 2 exchange value, 3 datetime, 4 name
        * 5 volume, 6 open_interest, 7 last_price
        * 11-14 open/high/low/pre_close
        * 15-19 bid prices, 20-24 ask prices
        * 25-29 bid volumes, 30-34 ask volumes

        Ticks with an empty ``bid_price_2`` are skipped.

        Args:
            file_path: path of the CSV file.
            update_tick: forwarded to ``database_manager.save_tick_data``.
        """
        time_consuming_start = time()
        ticks = []
        for item in self._iter_csv_rows(file_path):
            if not self._has_datetime_text(item, 3):
                continue
            tick = TickData(
                symbol=item[1],
                exchange=Exchange(item[2]),
                datetime=self._parse_postgres_datetime(item[3]),
                name=item[4],
                volume=item[5],
                open_interest=item[6],
                last_price=item[7],
                open_price=item[11],
                high_price=item[12],
                low_price=item[13],
                pre_close=item[14],
                bid_price_1=item[15],
                bid_price_2=item[16],
                bid_price_3=item[17],
                bid_price_4=item[18],
                bid_price_5=item[19],
                ask_price_1=item[20],
                ask_price_2=item[21],
                ask_price_3=item[22],
                ask_price_4=item[23],
                ask_price_5=item[24],
                bid_volume_1=item[25],
                bid_volume_2=item[26],
                bid_volume_3=item[27],
                bid_volume_4=item[28],
                bid_volume_5=item[29],
                ask_volume_1=item[30],
                ask_volume_2=item[31],
                ask_volume_3=item[32],
                ask_volume_4=item[33],
                ask_volume_5=item[34],
                gateway_name="DATABASE",
            )
            # No data in bid_price_2: skip this tick.
            if not tick.bid_price_2:
                continue
            ticks.append(tick)
        if not ticks:
            self._report_empty(file_path)
            return
        self._save_ticks(ticks, update_tick)
        start_time, end_time, count = ticks[0].datetime, ticks[-1].datetime, len(ticks)
        time_consuming_end = time()
        print(f'载入postgres csv tick数据,开始时间:{start_time},结束时间:{end_time},数据量:{count},耗时:{round(time_consuming_end-time_consuming_start,3)}秒')

    # ------------------------------------------------------------------
    def save_tdx_data(self, file_path: str, vt_symbol: str, price_tick: Union[int, float], update_bar: bool,
                      future_download: bool, interval: Interval = Interval.MINUTE):
        """Load a TDX (通达信) ``.lc1`` one-minute binary file and save its bars to the database.

        OHLC prices are rounded to ``price_tick`` and volume to 1 with ``round_to``.
        When ``future_download`` is True, timestamps are realigned by
        ``_align_tdx_future_datetime``, which aligns futures datetimes to Wenhua Caijing.
        Every bar is then moved one minute earlier.

        Args:
            file_path: path of the ``.lc1`` file.
            vt_symbol: contract identifier, split with ``extract_vt_symbol``.
            price_tick: price increment used to round OHLC prices.
            update_bar: forwarded to ``database_manager.save_bar_data``.
            future_download: True for futures files, False for stock files.
            interval: bar interval stored on every bar.
        """
        time_consuming_start = time()
        symbol, exchange, _gateway_name = extract_vt_symbol(vt_symbol)
        # Fixed-size binary record layout read from the .lc1 file.
        dt = np.dtype([
            ('date', 'u2'),
            ('time', 'u2'),
            ('open_price', 'f4'),
            ('high_price', 'f4'),
            ('low_price', 'f4'),
            ('close_price', 'f4'),
            ('amount', 'f4'),
            ('volume', 'u4'),
            ('reserve', 'u4')])
        data = np.fromfile(file_path, dtype=dt)
        if data.size == 0:
            self._report_empty(file_path)
            return
        df = pd.DataFrame(data, columns=data.dtype.names)
        # Decode the packed fields: date -> year/month/day, time (minute of day) -> hour/minute.
        df.eval('''
year=floor(date/2048)+2004
month=floor((date%2048)/100)
day=floor(date%2048%100)
hour = floor(time/60)
minute = time%60
''',inplace=True)
        df.index = pd.to_datetime(df.loc[:, ['year', 'month', 'day', 'hour', 'minute']])
        df.drop(['date', 'time', 'year', 'month', 'day', 'hour', 'minute', "amount", "reserve"], axis=1, inplace=True)

        # Positional NumPy arrays yield the same scalars as df["col"][i] did, without pandas'
        # deprecated positional fallback for integer keys on a DatetimeIndex.
        columns = [df[name].to_numpy() for name in ("open_price", "high_price", "low_price", "close_price", "volume")]
        bars = []
        for bar_datetime, open_price, high_price, low_price, close_price, volume in zip(df.index, *columns):
            if future_download:
                bar_datetime = self._align_tdx_future_datetime(bar_datetime)
            # NOTE(review): applied to every bar (stocks and futures); confirm it should not be futures-only.
            bar_datetime -= timedelta(minutes=1)
            bars.append(BarData(
                symbol=symbol,
                exchange=exchange,
                interval=interval,
                open_price=round_to(open_price, price_tick),
                high_price=round_to(high_price, price_tick),
                low_price=round_to(low_price, price_tick),
                close_price=round_to(close_price, price_tick),
                volume=round_to(volume, 1),
                datetime=bar_datetime,
                gateway_name="DATABASE",
            ))
        self._save_bars(bars, update_bar)
        start_time, end_time, count = bars[0].datetime, bars[-1].datetime, len(bars)
        time_consuming_end = time()
        msg = f'载入通达信标的:{vt_symbol} 分钟数据,开始时间:{start_time},结束时间:{end_time},数据量:{count},耗时:{round(time_consuming_end-time_consuming_start,3)}秒'
        print(msg)

    # ------------------------------------------------------------------
    def save_tick_data(self, file_path: str, symbol: str, exchange: Exchange, update_tick: bool):
        """Load ticks from a CSV file and save them to the database.

        Column layout used:
        * 1 ask_price_1, 6 ask_volume_1, 11 bid_price_1, 16 bid_volume_1
        * 21 date (``%Y%m%d``), 35 time (``%H:%M:%S`` with optional fraction)
        * 25 high, 26 last_price, 27 last_volume, 28 limit_down, 29 low
        * 30 open_interest, 31 open, 32 pre_close, 36 limit_up, 37 volume

        Rows whose first column is shorter than 10 characters are skipped.

        Args:
            file_path: path of the CSV file.
            symbol: contract symbol stored on every tick.
            exchange: exchange stored on every tick.
            update_tick: forwarded to ``database_manager.save_tick_data``.
        """
        time_consuming_start = time()
        ticks = []
        for item in self._iter_csv_rows(file_path):
            if not self._has_datetime_text(item, 0):
                continue
            date_text, time_text = item[21], item[35]
            tick_format = '%Y%m%d %H:%M:%S.%f' if '.' in time_text else '%Y%m%d %H:%M:%S'
            ticks.append(TickData(
                symbol=symbol,
                exchange=exchange,
                datetime=datetime.strptime(f"{date_text} {time_text}", tick_format),
                ask_price_1=item[1],
                ask_volume_1=item[6],
                bid_price_1=item[11],
                bid_volume_1=item[16],
                open_price=item[31],
                high_price=item[25],
                last_price=item[26],
                last_volume=item[27],
                limit_down=item[28],
                limit_up=item[36],
                low_price=item[29],
                pre_close=item[32],
                volume=item[37],
                open_interest=item[30],
                gateway_name="DATABASE",
            ))
        if not ticks:
            self._report_empty(file_path)
            return
        self._save_ticks(ticks, update_tick)
        start_time, end_time, count = ticks[0].datetime, ticks[-1].datetime, len(ticks)
        time_consuming_end = time()
        # Previously referenced an undefined `bar`, raising NameError on every call.
        print(f"载入{ticks[-1].vt_symbol}:tick数据,开始时间:{start_time},结束时间:{end_time},数据量:{count},耗时:{round(time_consuming_end-time_consuming_start,3)}秒")

    # ------------------------------------------------------------------
    def delete_bar(self, symbol: str, interval: "Interval", start_time: datetime):
        """Delete bar data by delegating to ``database_manager.delete_bar``."""
        database_manager.delete_bar(symbol, interval, start_time)

    # ------------------------------------------------------------------
    def delet_tick(self, symbol: str):
        """Delete tick data of ``symbol`` by delegating to ``database_manager.delet_tick``."""
        database_manager.delet_tick(symbol)

    # ------------------------------------------------------------------
    def get_bar_vt_symbol(self, interval: "Interval"):
        """Return the contract list for ``interval`` from ``database_manager.get_bar_vt_symbol``."""
        return database_manager.get_bar_vt_symbol(interval)

    # ------------------------------------------------------------------
    def get_tick_vt_symbol(self):
        """Return the tick contract list from ``database_manager.get_tick_vt_symbol``."""
        return database_manager.get_tick_vt_symbol()
#--------------------------------------------------------------------------------------------
if __name__ == '__main__':
    # Read bars back from the database and print them.
    read_data = False
    if read_data:
        start_time = datetime(2019, 1, 1)
        end_time = datetime.now()
        load_data = PostgresOperation().load_bar_data('a99', Exchange.DCE, Interval.MINUTE, start_time, end_time)
        for data in load_data:
            print(data)
    # ------------------------------------------------------------------
    # Delete minute bars of the listed symbols.
    delet_data = False
    if delet_data:
        delet_symbol = ['bu99', 'rb99']
        # The previous call targeted a non-existent `delet_data` method.
        # NOTE(review): delete_bar requires a start time; this cutoff is a placeholder,
        # so confirm how database_manager.delete_bar uses it before enabling this section.
        delete_start_time = datetime(2019, 1, 1)
        for delete_symbol in delet_symbol:
            PostgresOperation().delete_bar(delete_symbol, Interval.MINUTE, delete_start_time)
    # ------------------------------------------------------------------
    # Print the contract list of minute bars stored in the database.
    get_vt_symbol = False
    if get_vt_symbol:
        print(PostgresOperation().get_bar_vt_symbol(Interval.MINUTE))
    # ------------------------------------------------------------------
    # Save every TDX .lc1 file under the TDX data directory to the database.
    save_all = True
    if save_all:
        file_path = one_token_data_path.replace("one_token数据", "tdx数据")
        future_download = True  # True: futures files; False: stock files

        event_engine = EventEngine()
        main_engine = MainEngine(event_engine)
        contracts = main_engine.load_contracts()
        symbol_contract_map = {contract.symbol.upper(): contract for contract in contracts.values()}
        # Only the contract map is needed; stop the engine so its threads do not keep the process alive.
        main_engine.close()

        # One (file path, vt_symbol, price_tick) entry per file, so a file can never be paired with
        # another file's contract.
        tasks = []
        queued_vt_symbols = set()
        for dirpath, dirnames, filenames in os.walk(file_path):
            for file_name in filenames:
                # Only .lc1 files are loaded; archives (.rar/.7z) and other files are skipped.
                if not file_name.endswith("lc1"):
                    continue
                stem = file_name.split(".")[0]
                if future_download:
                    # Futures file names carry the symbol after "#".
                    _, sign, symbol = stem.partition("#")
                    if not sign:
                        print(f"跳过无法识别的文件:{file_name}")
                        continue
                    # Index contracts: TDX "L9" suffix becomes the "99" suffix used in the contract map.
                    if symbol.endswith("L9"):
                        symbol = symbol.split("L9")[0] + "99"
                    contract = symbol_contract_map.get(symbol)
                    if contract is None:
                        print(f"跳过未找到合约信息的文件:{file_name}")
                        continue
                    vt_symbol = contract.vt_symbol
                    price_tick = contract.price_tick
                else:
                    if stem.startswith("sh"):
                        exchange_str = "SSE"
                    elif stem.startswith("sz"):
                        exchange_str = "SZSE"
                    else:
                        print(f"跳过无法识别的文件:{file_name}")
                        continue
                    vt_symbol = stem[-6:] + "_" + exchange_str + "/XTP"
                    price_tick = 0.01
                if vt_symbol in queued_vt_symbols:
                    continue
                queued_vt_symbols.add(vt_symbol)
                tasks.append((os.path.join(dirpath, file_name), vt_symbol, price_tick))
        if not tasks:
            raise ValueError("tdx数据文件夹下合约列表为空")

        # Run one save_tdx_data job per file; wait for all and report failures instead of losing them.
        with ProcessPoolExecutor(max_workers=cpu_count()) as pool:
            futures = {
                pool.submit(PostgresOperation().save_tdx_data, tdx_file, vt_symbol, price_tick, False, future_download): tdx_file
                for tdx_file, vt_symbol, price_tick in tasks
            }
            for future, tdx_file in futures.items():
                try:
                    future.result()
                except Exception as error:
                    print(f"载入{tdx_file}失败:{error!r}")

        # Save the stock contract list.
        if not future_download:
            # NOTE(review): presumably load_json returns an empty dict for a missing file; `or []` keeps .append usable.
            stock_vt_symbols: List[str] = load_json("stock_vt_symbols.json") or []
            for _, vt_symbol, _ in tasks:
                if vt_symbol not in stock_vt_symbols:
                    stock_vt_symbols.append(vt_symbol)
            save_json("stock_vt_symbols.json", stock_vt_symbols)