文章太长,再分一贴吧。
4. 交易时间段的处理
4.1 交易时间段处理的复杂性
一个合约的交易时间段信息,就包含在一个字符串中。通常看起来是这样的:
"21:00-23:00,09:00-10:15,10:30-11:30,13:30-15:00"
它看似简单,实则非常复杂!简单在于它只是一个字符串,其实它能够表达非常复杂的交易时间规定。例如交易时间可以少到只有1段,也可以4到5个段,可跨日,也可以跨多日,如遇到周末或者长假。但是长假太难处理了,我们这也不处理各种各样的假日规定,因为那个太复杂了!不过好在时下很多软件,著名的和非著名的软件,几乎都不处理跨长假的问题,不处理的原因也是和我分析的一样,不过这也没有影响他们多软件被广大用户接受的程度。所以我们也不处理跨长假的问题。
当然想处理跨长假也不成,条件不具备呀。因为毕竟我们不是交易所,不知道各种各样的休假规定,不同市场,不同国家的节假日,千奇百怪,太难处理了。而且我们也不能说不处理哪个市场或者国家的投资品种吧?绝大部分软件都不处理长假对K线对齐方式的影响,原因就在于此,没有什么别的说辞!
4.2 交易时间段处理具有的功能
- 从交易时间段字符串中提取出各段的起止时间(天内的秒数) 列表
- 给定一个时间,可以得到其交易日及日期时间格式的交易时间段列表,无效交易时间返回空
- 给定一个时间,得到其在日内的交易时间、所在窗口的索引、窗口开始时间和截止时间
4.3 交易时间段处理的实现
在vnpy\usertools下创建一个名称为trading_hours.py,其代码如下:
"""
本文件主要实现合约的交易时间段:TradingHours
作者:hxxjava
日期:2022-03-28
修改:2022-06-09 修改内容:TradingHours的get_intraday_window()处理时间段错误
"""
from calendar import month
from typing import Callable,List,Dict, Tuple, Union
from enum import Enum
from datetime import datetime,date,timedelta, tzinfo
import numpy as np
import pytz
CHINA_TZ = pytz.timezone("Asia/Shanghai")
from vnpy.trader.constant import Interval
def to_china_tz(dt: datetime) -> datetime:
"""
Convert a datetime object to a CHINA_TZ localize datetime object
"""
return CHINA_TZ.localize(dt.replace(tzinfo=None))
INTERVAL_MAP = {
Interval.MINUTE:60,
Interval.HOUR:3600,
Interval.DAILY:3600*24,
Interval.WEEKLY:3600*24*7,
}
def get_time_segments(trading_hours:str) -> List:
"""
从交易时间段字符串中提取出各段的起止时间(天内的秒数) 列表
"""
time_sepments = []
# 提取各段
str_segments = trading_hours.split(',')
pre_start,day_offset = None,0
for s in reversed(str_segments): # 反向遍历各段
# 提取段的起止时间
start,stop = s.split('-')
# 计算开始时间天内秒
hh,mm = start.split(':')
start_s = int(hh)*3600+int(mm)*60
# 计算截止时间天内秒
hh,mm = stop.split(':')
stop_s = int(hh)*3600+int(mm)*60
if pre_start and start > pre_start:
day_offset -= 1
pre_start = start
# 加入列表
time_sepments.insert(0,(day_offset,start_s,stop_s))
return time_sepments
def in_segments(trade_segments:List,trade_dt:datetime):
""" 判断一个时间是否在一个交易时间段列表中 """
trade_dt = to_china_tz(trade_dt)
for start,stop in trade_segments:
if start <= trade_dt < stop:
return True
return False
class TradingHours(object):
"""
交易时间段处理
"""
def __init__(self,trading_hours:str):
"""
初始化函数 。
参数说明:
trading_hours:交易时间段字符串,例如:"21:00-23:00,09:00-10:15,10:30-11:30,13:30-15:00"
pre_open: 集合竞价时段长度,单位分钟。例如:国内期货pre_open=5
after_close: 交易日收盘后结算时长。例如国内期货持续到15:20,那么after_close=20
"""
self.time_segments:List[Tuple[int,int,int]] = get_time_segments(trading_hours)
def day_trade_time(self,interval:Interval) -> int:
"""
一个交易日的交易时长,单位由interval 规定,不足的部分+1
"""
seconds = 0.0
for _,start,stop in self.time_segments:
seconds += stop - start + (0 if start < stop else INTERVAL_MAP[Interval.DAILY])
if not interval:
return seconds
else:
return np.ceil(seconds/INTERVAL_MAP[interval])
def get_auction_closes_segments(self,trade_dt:datetime) -> Tuple[date,List]:
"""
得到一个交易时间所在的交易日及集合竞价时间段和所有休市时段的列表
"""
if not self.auction_closes:
return (None,[])
trade_dt = to_china_tz(trade_dt)
dates = [trade_dt.date()+timedelta(days=i) for i in range(-3,4)]
# 根据 self.auction_closes 构造出一周内的日期时间格式的非连续交易时间段字典
week_seqments = {
dt:
[(to_china_tz(datetime(dt.year,dt.month,dt.day))+timedelta(days=days-(2 if days == -1 and dt.weekday()==0 else 0),seconds=start),
to_china_tz(datetime(dt.year,dt.month,dt.day))+timedelta(days=days-(2 if days == -1 and dt.weekday()==0 else 0)+(1 if start>stop else 0),seconds=stop))
for days,start,stop in self.auction_closes]
for dt in dates if dt.weekday() not in [5,6]
}
# 在非交易时间段字典中查找trade_dt所在集合竞价时间段,确定所属交易日
for dt,datetime_segments in week_seqments.items():
# 遍历一周中的每日
if in_segments(datetime_segments,trade_dt):
return (dt,datetime_segments)
return (None,[])
def get_trade_hours(self,trade_dt:datetime) -> Tuple[date,List[Tuple[datetime,datetime]]]:
"""
得到一个时间的交易日及日期时间格式的交易时间段列表,无效交易时间返回空
"""
# 构造trade_dt加前后三天共7的日期
trade_dt = to_china_tz(trade_dt)
dates = [trade_dt.date()+timedelta(days=i) for i in range(-3,4)]
# 根据 self.time_segments 构造出一周内的日期时间格式的交易时间段字典
week_seqments = {
dt:
[(to_china_tz(datetime(dt.year,dt.month,dt.day))+timedelta(days=days-(2 if days == -1 and dt.weekday()==0 else 0),seconds=start),
to_china_tz(datetime(dt.year,dt.month,dt.day))+timedelta(days=days-(2 if days == -1 and dt.weekday()==0 else 0)+(1 if start>stop else 0),seconds=stop))
for days,start,stop in self.time_segments]
for dt in dates if dt.weekday() not in [5,6]
}
trade_day,trading_segments = None,[]
# 在交易时间段字典中查找trade_dt所在交易时间段,确定所属交易日
for dt,datetime_segments in week_seqments.items():
# 遍历一周中的每日
for start,stop in datetime_segments:
# 遍历一日中的每个交易时间段
if start <= trade_dt < stop:
# 找到了,确定dt为trade_dt的交易日
trade_day = dt
break
if trade_day:
# 已经找到,停止
trading_segments = datetime_segments
break
return (trade_day,trading_segments)
def get_trading_segments(self,tradeday:date): # List[Tuple[datetime,datetime]]
"""
得到某个交易日的就要时间段。注:只考虑周末,不考虑法定假
"""
segments = []
weekday = tradeday.weekday()
if weekday not in [5,6]:
# 周一至周五
# 周一跨日需插入2天
insert_days = -2 if tradeday.weekday() == 0 else 0
y,m,d = tradeday.year,tradeday.month,tradeday.day
for day,start,stop in self.time_segments:
days = insert_days + day if day < 0 else day
start_dt = datetime(y,m,d,0,0,0) + timedelta(days=days,seconds=start)
stop_dt = datetime(y,m,d,0,0,0) + timedelta(days=days+(0 if start < stop else 1),seconds=stop)
segments.append((start_dt,stop_dt))
return segments
def get_intraday_window(self,trade_dt:datetime,window:int) -> Tuple[date,List[Tuple[datetime,datetime]]]:
"""
得到一个时间的日内交易时间、窗口索引、窗口开始时间和截止时间
"""
trade_dt = to_china_tz(trade_dt)
interval = Interval.MINUTE
oneday_minutes = self.day_trade_time(interval)
if window > oneday_minutes:
raise f"In day window can't exceed {oneday_minutes} minutes !"
result = (None,[])
if window == 0:
# window==0 无意义
return result
# 求dt的交易日
trade_day,segment_datetimes = self.get_trade_hours(trade_dt)
if not trade_day:
# 无效的交易日
return result
if np.sum([start <= trade_dt < stop for start,stop in segment_datetimes]) == 0:
# 如果dt不在各个交易时间段内为无效的交易时间
return result
# 交易日的开盘时间
t0 = segment_datetimes[0][0]
# 构造各个交易时间段的起止数组
starts = np.array([(seg_dt[0]-t0).seconds*1.0 for seg_dt in segment_datetimes])
stops = np.array([(seg_dt[1]-t0).seconds*1.0 for seg_dt in segment_datetimes])
# 求dt在交易日中的自然时间
nature_t = (trade_dt - t0).seconds
# 求dt已经走过的交易时间
traded_t = np.sum(nature_t - starts[starts<=nature_t]) - np.sum(nature_t-stops[stops<nature_t])
if traded_t < 0:
# 开盘之前的为无效交易时间
return result
# 求当前所在窗口的宽度、索引、开始交易时间及截止时间
window_width = window * INTERVAL_MAP[interval]
window_idx = np.floor(traded_t/window_width)
window_start = window_idx * window_width
window_stop = window_start + window_width
# 求各个交易时间段的宽度
segment_widths = stops - starts
# print("!!!3",window_start,window_stop,segment_widths)
# 求各个交易时间段累计日内交易时间
sums = [np.sum(segment_widths[:(i+1)]) for i in range(len(segment_widths))]
if window_stop > sums[-1]:
# 不可以跨日处理
window_stop = sums[-1]
# 累计日内交易时间数组
seg_sum = np.array(sums)
# 每段开始累计日内交易时间数组
seg_start_sum = np.array([0] + sums)
# 求窗口开始和截止时间的时间段索引
s1,s2 = seg_sum - window_start,seg_sum - window_stop
start_idx,stop_idx = np.sum(s1 <= 0),np.sum(s2<0)
# 求窗口开始和截止时间的在其时间段中的偏移量
start_offset = (window_start-seg_start_sum)[start_idx]
stop_offset = (window_stop-seg_start_sum)[stop_idx]
# 求窗口包含的时间片段列表
window_segments = []
for idx in range(start_idx,stop_idx+1):
start,stop = segment_datetimes[idx]
t1 = start + timedelta(seconds=start_offset) if idx == start_idx else start
t2 = start + timedelta(seconds=stop_offset) if idx == stop_idx else stop
window_segments.append((t1,t2))
# 窗口所属交易日及包含的时间片段列表
result = (trade_day,window_segments)
return result
def get_week_tradedays(self,trade_dt:datetime) -> List[date]:
""" 得到一个交易时间所在周的交易日 """
trade_dt = to_china_tz(trade_dt)
trade_day,trade_segments = self.get_trade_hours(trade_dt)
if not trade_day:
return []
monday = trade_dt.date() - timedelta(days=trade_dt.weekday())
week_dates = [monday + timedelta(days=i) for i in range(5)]
if trade_day not in week_dates:
next_7days = [(trade_dt + timedelta(days=i+1)) for i in range(7)]
week_dates = [day.date() for day in next_7days if day.weekday() not in [5,6]]
return week_dates
def get_month_tradedays(self,trade_dt:datetime) -> List[date]:
""" 得到一个交易时间所在月的交易日 """
trade_dt = to_china_tz(trade_dt)
trade_day,trade_segments = self.get_trade_hours(trade_dt)
if not trade_day:
return []
first_day = date(year=trade_day.year,month=trade_day.month,day=1)
this_month = trade_day.month
days32 = [first_day + timedelta(days = i) for i in range(32)]
month_dates = [day for day in days32 if day.weekday() not in [5,6] and day.month==this_month]
return month_dates
def get_year_tradedays(self,trade_dt:datetime) -> List[date]:
""" 得到一个交易时间所在年的交易日 """
trade_dt = to_china_tz(trade_dt)
trade_day,trade_segments = self.get_trade_hours(trade_dt)
if not trade_day:
return []
new_years_day = date(year=trade_day.year,month=1,day=1)
this_year = trade_day.year
days366 = [new_years_day + timedelta(days = i) for i in range(366)]
trade_dates = [day for day in days366 if day.weekday() not in [5,6] and day.year==this_year]
return trade_dates
def has_night_tradetime(self) -> bool:
""" 有夜盘交易时间吗? """
for (days,start,stop) in self.time_segments:
if start >= 18*INTERVAL_MAP(Interval.HOUR):
return True
return False
def has_day_tradetime(self) -> bool:
""" 有日盘交易时间吗 ? """
for (days,start,stop) in self.time_segments:
if start < 18*INTERVAL_MAP(Interval.HOUR):
return True
return False
5. 日内对齐等交易时长K线生成器的实现
5.1 确定K线生成器MyBarGenerator的生成规则
5.1.1 一步到位地解决问题
先给它取个名称,就叫MyBarGenerator吧,它是对BarGenerator的扩展。
不过在构思MyBarGenerator的时候,我发现它其实不应该叫“日内对齐等交易时长K线生成器”。因为我们不应该只局限于日内的n分钟K线生成器,难道vnpy系统就不应该、不能够或者不使用日线以上的K线了吗?我们只能够使用日内K线进行量化交易吗?难道大家都没有过这方面的需求吗?我想答案是否定的。
那好,所幸就设计一个全功能的K线生成器:MyBarGenerator。
为此我们需要扩展Interval的定义,因为Interval是表示K线周期的常量,可是它的格局不够,最大只能到周一级WEEKLY。也就是说您用目前的Interval是没有办法表达月和年这样的周期的。
class Interval(Enum):
"""
Interval of bar data.
"""
MINUTE = "1m"
HOUR = "1h"
DAILY = "d"
WEEKLY = "w"
TICK = "tick"
MONTHLY = "month" # hxxjava add
YEARLY = "year" # hxxjava add
顺便在这里吐槽一下BarGenerator:
- 目前的BarData中包含了一个interval字段的,可是它在BarGenerator的时候根本就没有使用过,而使用它本是信手拈来的事情,但是没有却没有使用。如果不信,你可以去看看用它产生出来的bar的内容。
- 另外本来还应该增加一个秒单位(SECONDLY = "1s")的,这个单位其实对高频交易也是很有需求的,可是现在却没有。不知道大家对此有什么看法。
5.1.2 按周期对K线分类
在系统且并详细分析之后,把K线分类为:日内K线、日K线,周K线、月K线及年K线等周期K线五类。
1)日内K线包括1~n分钟K线,如1分钟、n分钟两类,其中n小于正常交易日的最大交易分钟数。日内K线取消对小时周期单位支持,因为可以通过n分钟的方式来实现。如:
- 1小时K线可以通过60分钟来表达
- 2小时K线可以通过120分钟来表达
- 4小时K线可以通过240分钟来表达
这么做的好处是:非常容易地实现90分钟的日内K线,而这是系统自带BarGenerator无法做到的。
2)日K线:每个交易日产生一个,它包含一到多个交易时间段。根据是否包含夜盘交易时间段,又可以分为跨日K线和不跨日K线。
3)周K线:由周一至周五中所有交易日的交易数据合成得到,它其实是一种特殊的n日K线,只是n<=5而已。
4)月K线:由每月1日至月末最后一个交易日的交易数据合成得到,除去所有周末,它最多包含23个交易日,遇到本月有长假日,其所包含的交易日会更少。
5)年K线:由每年1月1日至12月31日中的所有交易日的交易数据合成得到,除去所有周末。它可以理解为由一年中的所有交易日数据合成的,也可以理解为由一年中的12个月的交易日数据合成的。
5.1.3 确定K线生成规则:
1)日内K线(包括1~n分钟K线)生成规则:
- K线对齐交易日的开盘
- 等交易时长生成
- 忽略中间休市时间
- 不跨日生成,遇收市强行截止
- 周期单位必须为分钟,n小于日交易最大分钟数
2)日K线生成规则:
- 对齐其交易日的开盘时间
- 休市时间收到的数据为非法数据
- 交易日收盘时间生成或者在收到下收到大于收盘时间交易数据时生成
3)周K线生成规则:
- 对齐周一开盘时间
- 收到周一或者第一个交易日的日K线时创建
- 收到周二到四等交易日日K线时继续合成
- 收到周五或者下周交易日日K线时生成
4)月K线生成规则:
- 对齐当月1日的开盘时间,去除所有周末构成本月可能的交易日期
- 收到当月第一个交易日的日K线时创建
- 月K线创建后在未收到本月可能的交易日期的日K线时继续合成
- 收到月可能的交易日期的最后一个交易日K线或者下个月的第一交易日日K线时生成
5)年K线生成规则:
年K线可以由两种方式进行合成:一种是用日K线合成,另一种是用月K线合成。我们这里选择用日K线来合成。
- 年K线对齐每年的1月1日,从1月1日至12月31日,去除所有周末,构成所有的可能的交易日
- 遇到当年的第一个日K线时创建年K线
- 年K线创建后,在收到日K线的交易日期未到最后一个可能的交易日时继续合成
- 收到日K线的交易日为本年可能的交易日期或者下一年的交易日时生成
5.2 MyBargenerator的实现
在vnpy\usertools\utility.py中加入如下面的两个部分:
5.2.1 加入引用部分:
from copy import deepcopy
from typing import List,Dict,Tuple,Optional,Sequence,Callable
from datetime import date,datetime,timedelta
from vnpy.trader.constant import Interval
from vnpy.trader.object import TickData,BarData
from vnpy.trader.utility import extract_vt_symbol
from vnpy.usertools.trading_hours import TradingHours,in_segments
from vnpy.usertools.trade_hours import CHINA_TZ
def generate_temp_bar(small:BarData,big:BarData,interval:Interval):
""" get temp intra day small_bar """
small_bar:BarData = deepcopy(small) # 1 minute small_bar
big_bar:BarData = deepcopy(big)
if big_bar and small_bar:
big_bar.high_price = max(big_bar.high_price,small_bar.high_price)
big_bar.low_price = min(big_bar.low_price,small_bar.low_price)
big_bar.close_price = small_bar.close_price
big_bar.open_interest = small_bar.open_interest
big_bar.volume += small_bar.volume
big_bar.turnover += small_bar.turnover
elif not big_bar and small_bar:
big_bar = BarData(
symbol=small_bar.symbol,
exchange=small_bar.exchange,
interval=interval,
datetime=small_bar.datetime,
gateway_name=small_bar.gateway_name,
open_price=small_bar.open_price,
high_price=small_bar.high_price,
low_price=small_bar.low_price,
close_price = small_bar.close_price,
open_interest = small_bar.open_interest,
volume = small_bar.volume,
turnover = small_bar.turnover
)
return big_bar
5.2.2 MyBarGenerator的完整代码
class MyBarGenerator():
"""
An align bar generator.
Comment's for parameters:
on_bar : callback function on 1 minute bar is generated.
window : window bar's width.
on_window_bar : callback function on x interval bar is generated.
interval : window bar's unit.
trading_hours: trading hours with which the window bar can be generated.
"""
def __init__(
self,
on_bar: Callable,
window: int = 0,
on_window_bar: Callable = None,
interval: Interval = Interval.MINUTE,
trading_hours:str = ""
):
""" Constructor """
self.bar: BarData = None
self.on_bar: Callable = on_bar
self.interval: Interval = interval
self.interval_count: int = 0
self.intra_day_bar: BarData = None
self.day_bar: BarData = None
self.week_bar: BarData = None
self.month_bar: BarData = None
self.year_bar: BarData = None
self.day_bar_cnt:int = 0 # 日K线的1分钟K线计数
self.week_daybar_cnt:int = 0 # 周K线的日K线计数
self.window: int = window
self.on_window_bar: Callable = on_window_bar
self.last_tick: TickData = None
if interval not in [Interval.MINUTE,Interval.DAILY,Interval.WEEKLY,Interval.MONTHLY,Interval.YEARLY]:
raise ValueError(f"MyBarGenerator support MINUTE,DAILY,WEEKLY,MONTHLY and YEARLY bar generation only , please check it !")
if not trading_hours:
raise ValueError(f"MyBarGenerator need trading hours setting , please check it !")
# trading hours object
self.trading_hours = TradingHours(trading_hours)
self.day_total_minutes = int(self.trading_hours.day_trade_time(Interval.MINUTE))
self.tick_windows = (None,[])
# current intraday window bar's contains trading day and time segments list
self.intraday_bar_window = (None,[]) # (trade_day,[])
# current daily bar's window containts trading day and time segment list
self.daily_bar_window = (None,[])
# current weekly bar's window containts all trade days
self.weekly_bar_window = []
# current monthly bar's window containts all trade days
self.monthly_bar_window = []
# current yearly bar's window containts all trade days
self.yearly_bar_window = []
def update_tick(self, tick: TickData) -> None:
"""
Update new tick data into generator.
"""
new_minute = False
# Filter tick data with 0 last price
if not tick.last_price:
return
# Filter tick data with older timestamp
if self.last_tick and tick.datetime < self.last_tick.datetime:
print(f"特别tick【{tick}】!")
return
if self.tick_windows == (None,[]) or not in_segments(self.tick_windows[1],tick.datetime):
# 判断tick是否在连续交易时间段或者集合竞价时间段中
self.tick_windows = self.trading_hours.get_trade_hours(tick.datetime)
if self.tick_windows == (None,[]):
# 不在连续交易时间段
print(f"特别tick【{tick}】")
return
if not self.bar:
new_minute = True
elif (
(self.bar.datetime.minute != tick.datetime.minute)
or (self.bar.datetime.hour != tick.datetime.hour)
):
self.bar.datetime = self.bar.datetime.replace(
second=0, microsecond=0
)
self.on_bar(self.bar)
new_minute = True
if new_minute:
self.bar = BarData(
symbol=tick.symbol,
exchange=tick.exchange,
interval=Interval.MINUTE,
datetime=to_china_tz(tick.datetime),
gateway_name=tick.gateway_name,
open_price=tick.last_price,
high_price=tick.last_price,
low_price=tick.last_price,
close_price=tick.last_price,
open_interest=tick.open_interest
)
else:
self.bar.high_price = max(self.bar.high_price, tick.last_price)
if tick.high_price > self.last_tick.high_price:
self.bar.high_price = max(self.bar.high_price, tick.high_price)
self.bar.low_price = min(self.bar.low_price, tick.last_price)
if tick.low_price < self.last_tick.low_price:
self.bar.low_price = min(self.bar.low_price, tick.low_price)
self.bar.close_price = tick.last_price
self.bar.open_interest = tick.open_interest
self.bar.datetime = to_china_tz(tick.datetime)
if self.last_tick:
volume_change = tick.volume - self.last_tick.volume
self.bar.volume += max(volume_change, 0)
turnover_change = tick.turnover - self.last_tick.turnover
self.bar.turnover += max(turnover_change, 0)
self.last_tick = tick
def update_bar(self, bar: BarData) -> None:
"""
Update 1 minute bar into generator
"""
if self.interval == Interval.MINUTE and self.window > 0:
# update inday bar
self.update_intraday_bar(bar)
elif self.interval in [Interval.DAILY,Interval.WEEKLY,Interval.MONTHLY,Interval.YEARLY]:
# update daily,weekly,monthly or yearly bar
self.update_daily_bar(bar)
def update_intraday_bar(self, bar: BarData) -> None:
""" update intra day x window bar """
if bar:
bar.datetime = to_china_tz(bar.datetime)
if self.interval != Interval.MINUTE or self.window <= 1:
return
if self.intraday_bar_window == (None,[]):
# 首次调用日内K线更新函数
trade_day,time_segments = self.trading_hours.get_intraday_window(bar.datetime,self.window)
if (trade_day,time_segments) == (None,[]):
# 无效的1分钟K线
return
# 更新当前日内K线交易时间
self.intraday_bar_window = (trade_day,time_segments)
# 创建新的日内K线
self.intra_day_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.MINUTE,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
elif not in_segments(self.intraday_bar_window[1],bar.datetime):
# 1分钟K线不属于当前日内K线
str1 = f"bar.datetime={bar.datetime}\nintraday_bar_window:{self.intraday_bar_window}"
trade_day,time_segments = self.trading_hours.get_intraday_window(bar.datetime,self.window)
if (trade_day,time_segments) == (None,[]):
# 无效的1分钟K线
return
# 当前日内K线已经生成,推送当前日内K线
if self.on_window_bar:
self.on_window_bar(self.intra_day_bar)
# 更新当前日内K线交易时间
self.intraday_bar_window = (trade_day,time_segments)
str1 += f"\nintraday_bar_window:{self.intraday_bar_window}"
print(str1)
# 创建新的日内K线
self.intra_day_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.MINUTE,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
# 1分钟K线属于当前日内K线
# 更新当前日内K线
self.intra_day_bar.high_price = max(self.intra_day_bar.high_price,bar.high_price)
self.intra_day_bar.low_price = min(self.intra_day_bar.low_price,bar.low_price)
self.intra_day_bar.close_price = bar.close_price
self.intra_day_bar.open_interest = bar.open_interest
self.intra_day_bar.volume += bar.volume
self.intra_day_bar.turnover += bar.turnover
# 判断当前日内K线是否结束
close_time = self.intraday_bar_window[1][-1][1]
next_minute_dt = bar.datetime + timedelta(minutes=1)
if close_time <= next_minute_dt:
# 当前日K内线已经结束
# 当前日内K线已经生成,推送之
if self.on_window_bar:
print(f"close_time={close_time},next_minute_dt={next_minute_dt}")
self.on_window_bar(self.intra_day_bar)
self.intraday_bar_window = (None,[])
self.intra_day_bar = None
def update_daily_bar(self, bar: BarData) -> bool:
""" update daily bar using 1 minute bar """
result = False
if bar:
bar.datetime = to_china_tz(bar.datetime)
if self.daily_bar_window == (None,[]):
# 首次调用日K线更新函数
trade_day,trade_segments = self.trading_hours.get_trade_hours(bar.datetime)
if (trade_day,trade_segments) == (None,[]):
# 无效的1分钟K线
return result
# 更新当前日K线交易时间
self.daily_bar_window = (trade_day,trade_segments)
# 创建新的日K线
self.day_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.DAILY,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
self.day_bar_cnt = 1
if not in_segments(self.daily_bar_window[1],bar.datetime):
# 1分钟K线不属于当前日K线
trade_day,trade_segments = self.trading_hours.get_trade_hours(bar.datetime)
if (trade_day,trade_segments) == (None,[]):
# 无效的1分钟K线
return
# 当前日K线已经生成
if self.interval == Interval.DAILY:
# 推送当前日K线
if self.on_window_bar:
self.on_window_bar(self.day_bar)
self.day_bar_cnt = 0
else:
# 更新更大周期K线
if self.update_weekly_bar(self.day_bar):
self.week_daybar_cnt += 1
self.update_monthly_bar(self.day_bar)
self.update_yearly_bar(self.day_bar)
# 更新当前日K线交易时间
self.daily_bar_window = (trade_day,trade_segments)
# 创建新的日K线
self.day_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.DAILY,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
# 1分钟K线属于当前交易日
# 更新当前日K线
self.day_bar.high_price = max(self.day_bar.high_price,bar.high_price)
self.day_bar.low_price = min(self.day_bar.low_price,bar.low_price)
self.day_bar.close_price = bar.close_price
self.day_bar.open_interest = bar.open_interest
self.day_bar.volume += bar.volume
self.day_bar.turnover += bar.turnover
result = True
# 判断当前日K线是否结束
close_time = self.daily_bar_window[1][-1][1]
next_minute_dt = bar.datetime + timedelta(minutes=1)
if close_time <= next_minute_dt or self.day_total_minutes == self.day_bar_cnt:
# 当前日K线已经结束
# 当前日K线已经生成
if self.interval == Interval.DAILY:
# 推送当前日K线
if self.on_window_bar:
self.on_window_bar(self.day_bar)
else:
# 更新更大周期K线
if self.update_weekly_bar(self.day_bar):
self.week_daybar_cnt += 1
self.update_monthly_bar(self.day_bar)
self.update_yearly_bar(self.day_bar)
self.daily_bar_window = (None,[])
self.day_bar = None
self.day_bar_cnt = 0
return result
def update_weekly_bar(self, bar: BarData) -> bool:
""" update weekly bar using a daily bar """
result = False
if bar:
bar.datetime = to_china_tz(bar.datetime)
if self.interval != Interval.WEEKLY:
# 设定周期单位不是周,不处理
return result
if not self.weekly_bar_window:
# 首次调用周K线更新函数
week_tradedays = self.trading_hours.get_week_tradedays(bar.datetime)
if not week_tradedays:
# 无效的日K线
return result
# 更新当前周K线交易日列表
self.weekly_bar_window = week_tradedays
# 创建新的周K线
self.week_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.WEEKLY,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
if bar.datetime not in self.weekly_bar_window:
# 日线不属于当前周K线
week_tradedays = self.trading_hours.get_week_tradedays(bar.datetime)
if not week_tradedays:
# 无效的日K线
return result
# 当前周K线已经生成,推送
if self.on_window_bar:
self.on_window_bar(self.week_bar)
self.week_daybar_cnt = 0
# 更新当前周K线交易日列表
self.weekly_bar_window = week_tradedays
# 创建新的周K线
self.week_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.WEEKLY,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
# 更新当前周K线
self.week_bar.high_price = max(self.week_bar.high_price,bar.high_price)
self.week_bar.low_price = min(self.week_bar.low_price,bar.low_price)
self.week_bar.close_price = bar.close_price
self.week_bar.open_interest = bar.open_interest
self.week_bar.volume += bar.volume
self.week_bar.turnover += bar.turnover
result = True
# 判断当前周K线是否结束
trade_day,_ = self.trading_hours.get_trade_hours(bar.datetime)
if trade_day >= self.weekly_bar_window[-1] or self.week_daybar_cnt == 5:
# 当前周K线已经结束,推送当前周K线
if self.on_window_bar:
self.on_window_bar(self.week_bar)
self.week_daybar_cnt = 0
# 复位当前周交易日列表及周K线
self.weekly_bar_window = []
self.week_bar = None
return result
def update_monthly_bar(self, bar: BarData) -> bool:
""" update monthly bar using a daily bar """
result = False
if bar:
bar.datetime = to_china_tz(bar.datetime)
if self.interval != Interval.MONTHLY:
# 设定周期单位不是月,不处理
return result
if not self.monthly_bar_window:
# 首次调用月K线更新函数
month_tradedays = self.trading_hours.get_month_tradedays(bar.datetime)
if not month_tradedays:
# 无效的日K线
return result
# 更新当前月K线交易日列表
self.monthly_bar_window = month_tradedays
# 创建新的月K线
self.month_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.MONTHLY,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
if bar.datetime not in self.monthly_bar_window:
# 日线不属于当前月K线
month_tradedays = self.trading_hours.get_month_tradedays(bar.datetime)
if not month_tradedays:
# 无效的日K线
return result
# 当前月K线已经生成,推送
if self.on_window_bar:
self.on_window_bar(self.month_bar)
# 更新当前月交易日列表
self.monthly_bar_window = month_tradedays
# 创建新的月K线
self.month_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.MONTHLY,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
# 更新当前月K线
self.month_bar.high_price = max(self.month_bar.high_price,bar.high_price)
self.month_bar.low_price = min(self.month_bar.low_price,bar.low_price)
self.month_bar.close_price = bar.close_price
self.month_bar.open_interest = bar.open_interest
self.month_bar.volume += bar.volume
self.month_bar.turnover += bar.turnover
result = True
# 判断当前月K线是否结束
trade_day,_ = self.trading_hours.get_trade_hours(bar.datetime)
if trade_day >= self.monthly_bar_window[-1]:
# 当前月K线已经结束,推送当前月K线
if self.on_window_bar:
self.on_window_bar(self.month_bar)
# 复位当前月交易日列表及月K线
self.monthly_bar_window = []
self.month_bar = None
return result
def update_yearly_bar(self, bar: BarData) -> bool:
""" update yearly bar using a daily bar """
result = False
if bar:
bar.datetime = to_china_tz(bar.datetime)
if self.interval != Interval.YEARLY:
# 设定周期单位不是年,不处理
return result
if not self.yearly_bar_window:
# 首次调用年K线更新函数
year_tradedays = self.trading_hours.get_year_tradedays(bar.datetime)
if not year_tradedays:
# 无效的日K线
return result
# 更新当前年K线交易日列表
self.yearly_bar_window = year_tradedays
# 创建新的年K线
self.year_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.YEARLY,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
if bar.datetime not in self.yearly_bar_window:
# 日线不属于当前年K线
year_tradedays = self.trading_hours.get_year_tradedays(bar.datetime)
if not year_tradedays:
# 无效的日K线
return result
# 当前年K线已经生成,推送
if self.on_window_bar:
self.on_window_bar(self.year_bar)
# 更新当前年交易日列表
self.yearly_bar_window = year_tradedays
# 创建新的年K线
self.year_bar = BarData(
symbol=bar.symbol,
exchange=bar.exchange,
interval=Interval.YEARLY,
datetime=bar.datetime,
gateway_name=bar.gateway_name,
open_price=bar.open_price,
high_price=bar.high_price,
low_price=bar.low_price,
)
# 更新当前年K线
self.year_bar.high_price = max(self.year_bar.high_price,bar.high_price)
self.year_bar.low_price = min(self.year_bar.low_price,bar.low_price)
self.year_bar.close_price = bar.close_price
self.year_bar.open_interest = bar.open_interest
self.year_bar.volume += bar.volume
self.year_bar.turnover += bar.turnover
result = True
# 判断当前年K线是否结束
trade_day,_ = self.trading_hours.get_trade_hours(bar.datetime)
if trade_day >= self.yearly_bar_window[-1]:
# 当前年K线已经结束,推送当前年K线
if self.on_window_bar:
self.on_window_bar(self.year_bar)
# 复位当前年交易日列表及年K线
self.yearly_bar_window = []
self.year_bar = None
return result
def get_temp_bar(self) -> BarData:
""" 返回临时1分钟K线 """
bar = deepcopy(self.bar)
if bar:
bar.datetime = bar.datetime.replace(second=0,microsecond=0)
return bar
def get_temp_window_bar(self,bar:BarData = None) -> BarData:
"""
返回临时窗口K线
"""
temp_bar:BarData = None
if not bar:
# 如果没有传入1分钟K线,取当前生成器的1分钟K线
bar = self.bar
if self.interval == Interval.MINUTE:
if self.window == 0:
temp_bar = deepcopy(self.bar)
else:
temp_bar = generate_temp_bar(bar,self.intra_day_bar,Interval.MINUTE)
elif self.interval == Interval.DAILY:
temp_bar = generate_temp_bar(bar,self.day_bar,Interval.DAILY)
elif self.interval == Interval.WEEKLY:
day_bar = generate_temp_bar(bar,self.day_bar,Interval.DAILY)
temp_bar = generate_temp_bar(day_bar,self.week_bar,Interval.WEEKLY)
elif self.interval == Interval.MONTHLY:
day_bar = generate_temp_bar(bar,self.day_bar,Interval.DAILY)
temp_bar = generate_temp_bar(day_bar,self.month_bar,Interval.MONTHLY)
elif self.interval == Interval.YEARLY:
day_bar = generate_temp_bar(bar,self.day_bar,Interval.DAILY)
temp_bar = generate_temp_bar(day_bar,self.year_bar,Interval.YEARLY)
return temp_bar
def generate(self) -> Optional[BarData]:
"""
Generate the bar data and call callback immediately.
"""
bar = self.bar
if self.bar:
bar.datetime = bar.datetime.replace(second=0, microsecond=0)
self.on_bar(bar)
self.bar = None
return bar
6. 对集合竞价tick和休市期间收到的tick的特别处理
交易时间段是交易所对一个合约连续交易时间的规定,它只规定了在哪些时间段内市场是可以连续交易的,也就是说投资者交易开仓、平仓和撤单的。
但是交易时间段不包括一个合约交易的所有交易时间的规定,例如集合竞价时间段、日内中间休市时间段和交易日收盘休市时间段这三类时间段的规定。
6.1 集合竞价时间段
集合竞价时间段在交易日的开盘时间之前。能够该时间段的参与的投资者可能有资格的限制,就是说可能不是市场的参与者都有资格能够在在集合竞价时段中进行交易的。
而且不同市场,不同合约的集合竞价时间段的长度是不一样的,不同的交易日也可能不同,例如:
1)国内市场
- 国内期货,期权,有夜盘品种是20:55-21:00,遇有长假则位于日盘的第一个交易时段前5分钟;只有日盘品种是8:55-9:00。
- 国内股票,9:25-9:30,因为A股全部是日盘,所以没有长假带来的问题集合竞价发生变化的问题。
- 国内市场的集合竞价通常包括前4分钟为撮合成交,在开盘前1分钟推送一个包括集合竞价tick,它的last_price就是该交易日的开盘价。
2)国外市场
- WTI原油期货的Pre-Open时间,其实就是集合竞价时段。它更加复制,周日的盘前议价期为开盘前1小时,其他交易日的盘前议价期为开盘前15分钟。在此期间,客户可以输入、修改和撤销报单,但报单在该时段不会被撮合成交。此外,在盘前议价期快结束时,即开盘前30秒,不可以进行修改和撤销报单,但是可以下新的报单。所有报单在开盘后的连续交易时段才会被撮合成交。

总之,集合竞价时段变化多端,非常复杂,在K线时长上需要特别关注和处理,否则您生成的是什么K线,正确与否是无从谈起的。没准您多了个莫名其妙的K线都不知道。
6.2 集合竞价和日内的休市时间段对K线合成处理的影响
6.2.1 在一个交易日中,用户接口收到的交易所中的tick为4类:
- 上一交易日结算时间之后~本交易日集合竞价开始之前收到的tick,为垃圾无效tick数据;
- 本交易日集合竞价期间收到唯一个包含本交易日开盘价的tick,为集合竞价tick数据;
- 在各个连续竞价时间段收到的tick,为连续竞价tick数据;
- 在日内的各个休市时间段收到tick,为休市tick数据。
6.2.2 以上4类tick的在K线合成方面的不同处理
- 收到无效tick,直接做丢弃处理
- 收到连续竞价tick,进行正常K线合成处理
- 收到集合竞价tick,将其时间修改为集合竞价时间段的截止时间,之后与连续竞价tick一样处理
- 收到休市tick,将其时间修改为所在休市时间段的开始时间减去1毫秒,之后与连续竞价tick一样处理
6.2.3 也可以参考利用合约交易状态信息来处理集合竞价tick
这种特别处理请参考:分析一下盘中启动CTA策略带来的第一根$K线错误
7. 该K线生成器的使用
7.1 解决回测中缺少合约信息和交易时间段信息
修改vnpy_ctastrategy\backtesting.py,修改后全部内容如下:
from collections import defaultdict
from datetime import date, datetime, timedelta
import imp
from pipes import Template
from typing import Callable, List
from functools import lru_cache, partial
import traceback
import numpy as np
from pandas import DataFrame
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from vnpy.trader.constant import (Direction, Offset, Exchange,
Interval, Status)
from vnpy.trader.database import get_database
from vnpy.trader.object import OrderData, TradeData, BarData, TickData, ContractData # hxxjava add ContractData
from vnpy.trader.utility import round_to
from vnpy.trader.optimize import (
OptimizationSetting,
check_optimization_setting,
run_bf_optimization,
run_ga_optimization
)
from .base import (
BacktestingMode,
EngineType,
STOPORDER_PREFIX,
StopOrder,
StopOrderStatus,
INTERVAL_DELTA_MAP
)
from .template import CtaTemplate
class BacktestingEngine:
""""""
engine_type = EngineType.BACKTESTING
gateway_name = "BACKTESTING"
def __init__(self):
""""""
self.vt_symbol = ""
self.symbol = ""
self.exchange = None
self.start = None
self.end = None
self.rate = 0
self.slippage = 0
self.size = 1
self.pricetick = 0
self.capital = 1_000_000
self.risk_free: float = 0
self.annual_days: int = 240
self.mode = BacktestingMode.BAR
self.inverse = False
self.strategy_class = None
self.strategy = None
self.tick: TickData
self.bar: BarData
self.datetime = None
self.interval = None
self.days = 0
self.callback = None
self.history_data = []
self.stop_order_count = 0
self.stop_orders = {}
self.active_stop_orders = {}
self.limit_order_count = 0
self.limit_orders = {}
self.active_limit_orders = {}
self.trade_count = 0
self.trades = {}
self.logs = []
self.daily_results = {}
self.daily_df = None
self.load_all_trading_hours() # hxxjava add
self.load_contracts() # hxxjava add
def clear_data(self):
"""
Clear all data of last backtesting.
"""
self.strategy = None
self.tick = None
self.bar = None
self.datetime = None
self.stop_order_count = 0
self.stop_orders.clear()
self.active_stop_orders.clear()
self.limit_order_count = 0
self.limit_orders.clear()
self.active_limit_orders.clear()
self.trade_count = 0
self.trades.clear()
self.logs.clear()
self.daily_results.clear()
def set_parameters(
self,
vt_symbol: str,
interval: Interval,
start: datetime,
rate: float,
slippage: float,
size: float,
pricetick: float,
capital: int = 0,
end: datetime = None,
mode: BacktestingMode = BacktestingMode.BAR,
inverse: bool = False,
risk_free: float = 0,
annual_days: int = 240
):
""""""
self.mode = mode
self.vt_symbol = vt_symbol
self.interval = Interval(interval)
self.rate = rate
self.slippage = slippage
self.size = size
self.pricetick = pricetick
self.start = start
self.symbol, exchange_str = self.vt_symbol.split(".")
self.exchange = Exchange(exchange_str)
self.capital = capital
self.end = end
self.mode = mode
self.inverse = inverse
self.risk_free = risk_free
self.annual_days = annual_days
def load_all_trading_hours(self) -> None: # hxxjava add end
""" """
from vnpy.trader.datafeed import get_datafeed
df = get_datafeed()
if not df.inited:
df.init()
self.all_trading_hours = df.load_all_trading_hours()
print(f"BachtestingEngine.all_trading_hours len={len(self.all_trading_hours)}")
def load_contracts(self) -> None: # hxxjava add end
""" """
database = get_database()
contracts:List[ContractData] = database.load_contract_data()
self.contracts = {}
for c in contracts:
self.contracts[c.vt_symbol] = c
print(f"BachtestingEngine.contracts len={len(self.contracts)}")
def get_trading_hours(self,strategy:CtaTemplate) -> str: # hxxjava add
"""
get vt_symbol's trading hours
"""
ths = self.all_trading_hours.get(strategy.vt_symbol.upper(),"")
return ths["trading_hours"] if ths else ""
def get_contract(self, strategy:CtaTemplate) :# -> Optional[ContractData]:
"""
Get contract data by vt_symbol.
"""
return self.contracts.get(strategy.vt_symbol,None)
def add_strategy(self, strategy_class: type, setting: dict):
""""""
self.strategy_class = strategy_class
self.strategy = strategy_class(
self, strategy_class.__name__, self.vt_symbol, setting
)
def load_data(self):
""""""
self.output("开始加载历史数据")
if not self.end:
self.end = datetime.now()
if self.start >= self.end:
self.output("起始日期必须小于结束日期")
return
self.history_data.clear() # Clear previously loaded history data
# Load 30 days of data each time and allow for progress update
total_days = (self.end - self.start).days
progress_days = max(int(total_days / 10), 1)
progress_delta = timedelta(days=progress_days)
interval_delta = INTERVAL_DELTA_MAP[self.interval]
start = self.start
end = self.start + progress_delta
progress = 0
while start < self.end:
progress_bar = "#" * int(progress * 10 + 1)
self.output(f"加载进度:{progress_bar} [{progress:.0%}]")
end = min(end, self.end) # Make sure end time stays within set range
if self.mode == BacktestingMode.BAR:
data = load_bar_data(
self.symbol,
self.exchange,
self.interval,
start,
end
)
else:
data = load_tick_data(
self.symbol,
self.exchange,
start,
end
)
self.history_data.extend(data)
progress += progress_days / total_days
progress = min(progress, 1)
start = end + interval_delta
end += progress_delta
self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")
def run_backtesting(self):
""""""
if self.mode == BacktestingMode.BAR:
func = self.new_bar
else:
func = self.new_tick
self.strategy.on_init()
# Use the first [days] of history data for initializing strategy
day_count = 0
ix = 0
for ix, data in enumerate(self.history_data):
if self.datetime and data.datetime.day != self.datetime.day:
day_count += 1
if day_count >= self.days:
break
self.datetime = data.datetime
try:
self.callback(data)
except Exception:
self.output("触发异常,回测终止")
self.output(traceback.format_exc())
return
self.strategy.inited = True
self.output("策略初始化完成")
self.strategy.on_start()
self.strategy.trading = True
self.output("开始回放历史数据")
# Use the rest of history data for running backtesting
backtesting_data = self.history_data[ix:]
if len(backtesting_data) <= 1:
self.output("历史数据不足,回测终止")
return
total_size = len(backtesting_data)
batch_size = max(int(total_size / 10), 1)
for ix, i in enumerate(range(0, total_size, batch_size)):
batch_data = backtesting_data[i: i + batch_size]
for data in batch_data:
try:
func(data)
except Exception:
self.output("触发异常,回测终止")
self.output(traceback.format_exc())
return
progress = min(ix / 10, 1)
progress_bar = "=" * (ix + 1)
self.output(f"回放进度:{progress_bar} [{progress:.0%}]")
self.strategy.on_stop()
self.output("历史数据回放结束")
def calculate_result(self):
""""""
self.output("开始计算逐日盯市盈亏")
if not self.trades:
self.output("成交记录为空,无法计算")
return
# Add trade data into daily reuslt.
for trade in self.trades.values():
d = trade.datetime.date()
daily_result = self.daily_results[d]
daily_result.add_trade(trade)
# Calculate daily result by iteration.
pre_close = 0
start_pos = 0
for daily_result in self.daily_results.values():
daily_result.calculate_pnl(
pre_close,
start_pos,
self.size,
self.rate,
self.slippage,
self.inverse
)
pre_close = daily_result.close_price
start_pos = daily_result.end_pos
# Generate dataframe
results = defaultdict(list)
for daily_result in self.daily_results.values():
for key, value in daily_result.__dict__.items():
results[key].append(value)
self.daily_df = DataFrame.from_dict(results).set_index("date")
self.output("逐日盯市盈亏计算完成")
return self.daily_df
def calculate_statistics(self, df: DataFrame = None, output=True):
""""""
self.output("开始计算策略统计指标")
# Check DataFrame input exterior
if df is None:
df = self.daily_df
# Check for init DataFrame
if df is None:
# Set all statistics to 0 if no trade.
start_date = ""
end_date = ""
total_days = 0
profit_days = 0
loss_days = 0
end_balance = 0
max_drawdown = 0
max_ddpercent = 0
max_drawdown_duration = 0
total_net_pnl = 0
daily_net_pnl = 0
total_commission = 0
daily_commission = 0
total_slippage = 0
daily_slippage = 0
total_turnover = 0
daily_turnover = 0
total_trade_count = 0
daily_trade_count = 0
total_return = 0
annual_return = 0
daily_return = 0
return_std = 0
sharpe_ratio = 0
return_drawdown_ratio = 0
else:
# Calculate balance related time series data
df["balance"] = df["net_pnl"].cumsum() + self.capital
# When balance falls below 0, set daily return to 0
pre_balance = df["balance"].shift(1)
pre_balance.iloc[0] = self.capital
x = df["balance"] / pre_balance
x[x <= 0] = np.nan
df["return"] = np.log(x).fillna(0)
df["highlevel"] = (
df["balance"].rolling(
min_periods=1, window=len(df), center=False).max()
)
df["drawdown"] = df["balance"] - df["highlevel"]
df["ddpercent"] = df["drawdown"] / df["highlevel"] * 100
# Calculate statistics value
start_date = df.index[0]
end_date = df.index[-1]
total_days = len(df)
profit_days = len(df[df["net_pnl"] > 0])
loss_days = len(df[df["net_pnl"] < 0])
end_balance = df["balance"].iloc[-1]
max_drawdown = df["drawdown"].min()
max_ddpercent = df["ddpercent"].min()
max_drawdown_end = df["drawdown"].idxmin()
if isinstance(max_drawdown_end, date):
max_drawdown_start = df["balance"][:max_drawdown_end].idxmax()
max_drawdown_duration = (max_drawdown_end - max_drawdown_start).days
else:
max_drawdown_duration = 0
total_net_pnl = df["net_pnl"].sum()
daily_net_pnl = total_net_pnl / total_days
total_commission = df["commission"].sum()
daily_commission = total_commission / total_days
total_slippage = df["slippage"].sum()
daily_slippage = total_slippage / total_days
total_turnover = df["turnover"].sum()
daily_turnover = total_turnover / total_days
total_trade_count = df["trade_count"].sum()
daily_trade_count = total_trade_count / total_days
total_return = (end_balance / self.capital - 1) * 100
annual_return = total_return / total_days * self.annual_days
daily_return = df["return"].mean() * 100
return_std = df["return"].std() * 100
if return_std:
daily_risk_free = self.risk_free / np.sqrt(self.annual_days)
sharpe_ratio = (daily_return - daily_risk_free) / return_std * np.sqrt(self.annual_days)
else:
sharpe_ratio = 0
return_drawdown_ratio = -total_return / max_ddpercent
# Output
if output:
self.output("-" * 30)
self.output(f"首个交易日:\t{start_date}")
self.output(f"最后交易日:\t{end_date}")
self.output(f"总交易日:\t{total_days}")
self.output(f"盈利交易日:\t{profit_days}")
self.output(f"亏损交易日:\t{loss_days}")
self.output(f"起始资金:\t{self.capital:,.2f}")
self.output(f"结束资金:\t{end_balance:,.2f}")
self.output(f"总收益率:\t{total_return:,.2f}%")
self.output(f"年化收益:\t{annual_return:,.2f}%")
self.output(f"最大回撤: \t{max_drawdown:,.2f}")
self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%")
self.output(f"最长回撤天数: \t{max_drawdown_duration}")
self.output(f"总盈亏:\t{total_net_pnl:,.2f}")
self.output(f"总手续费:\t{total_commission:,.2f}")
self.output(f"总滑点:\t{total_slippage:,.2f}")
self.output(f"总成交金额:\t{total_turnover:,.2f}")
self.output(f"总成交笔数:\t{total_trade_count}")
self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}")
self.output(f"日均手续费:\t{daily_commission:,.2f}")
self.output(f"日均滑点:\t{daily_slippage:,.2f}")
self.output(f"日均成交金额:\t{daily_turnover:,.2f}")
self.output(f"日均成交笔数:\t{daily_trade_count}")
self.output(f"日均收益率:\t{daily_return:,.2f}%")
self.output(f"收益标准差:\t{return_std:,.2f}%")
self.output(f"Sharpe Ratio:\t{sharpe_ratio:,.2f}")
self.output(f"收益回撤比:\t{return_drawdown_ratio:,.2f}")
statistics = {
"start_date": start_date,
"end_date": end_date,
"total_days": total_days,
"profit_days": profit_days,
"loss_days": loss_days,
"capital": self.capital,
"end_balance": end_balance,
"max_drawdown": max_drawdown,
"max_ddpercent": max_ddpercent,
"max_drawdown_duration": max_drawdown_duration,
"total_net_pnl": total_net_pnl,
"daily_net_pnl": daily_net_pnl,
"total_commission": total_commission,
"daily_commission": daily_commission,
"total_slippage": total_slippage,
"daily_slippage": daily_slippage,
"total_turnover": total_turnover,
"daily_turnover": daily_turnover,
"total_trade_count": total_trade_count,
"daily_trade_count": daily_trade_count,
"total_return": total_return,
"annual_return": annual_return,
"daily_return": daily_return,
"return_std": return_std,
"sharpe_ratio": sharpe_ratio,
"return_drawdown_ratio": return_drawdown_ratio,
}
# Filter potential error infinite value
for key, value in statistics.items():
if value in (np.inf, -np.inf):
value = 0
statistics[key] = np.nan_to_num(value)
self.output("策略统计指标计算完成")
return statistics
def show_chart(self, df: DataFrame = None):
""""""
# Check DataFrame input exterior
if df is None:
df = self.daily_df
# Check for init DataFrame
if df is None:
return
fig = make_subplots(
rows=4,
cols=1,
subplot_titles=["Balance", "Drawdown", "Daily Pnl", "Pnl Distribution"],
vertical_spacing=0.06
)
balance_line = go.Scatter(
x=df.index,
y=df["balance"],
mode="lines",
name="Balance"
)
drawdown_scatter = go.Scatter(
x=df.index,
y=df["drawdown"],
fillcolor="red",
fill='tozeroy',
mode="lines",
name="Drawdown"
)
pnl_bar = go.Bar(y=df["net_pnl"], name="Daily Pnl")
pnl_histogram = go.Histogram(x=df["net_pnl"], nbinsx=100, name="Days")
fig.add_trace(balance_line, row=1, col=1)
fig.add_trace(drawdown_scatter, row=2, col=1)
fig.add_trace(pnl_bar, row=3, col=1)
fig.add_trace(pnl_histogram, row=4, col=1)
fig.update_layout(height=1000, width=1000)
fig.show()
def run_bf_optimization(self, optimization_setting: OptimizationSetting, output=True):
""""""
if not check_optimization_setting(optimization_setting):
return
evaluate_func: callable = wrap_evaluate(self, optimization_setting.target_name)
results = run_bf_optimization(
evaluate_func,
optimization_setting,
get_target_value,
output=self.output
)
if output:
for result in results:
msg: str = f"参数:{result[0]}, 目标:{result[1]}"
self.output(msg)
return results
run_optimization = run_bf_optimization
def run_ga_optimization(self, optimization_setting: OptimizationSetting, output=True):
""""""
if not check_optimization_setting(optimization_setting):
return
evaluate_func: callable = wrap_evaluate(self, optimization_setting.target_name)
results = run_ga_optimization(
evaluate_func,
optimization_setting,
get_target_value,
output=self.output
)
if output:
for result in results:
msg: str = f"参数:{result[0]}, 目标:{result[1]}"
self.output(msg)
return results
def update_daily_close(self, price: float):
""""""
d = self.datetime.date()
daily_result = self.daily_results.get(d, None)
if daily_result:
daily_result.close_price = price
else:
self.daily_results[d] = DailyResult(d, price)
def new_bar(self, bar: BarData):
""""""
self.bar = bar
self.datetime = bar.datetime
self.cross_limit_order()
self.cross_stop_order()
self.strategy.on_bar(bar)
self.update_daily_close(bar.close_price)
def new_tick(self, tick: TickData):
""""""
self.tick = tick
self.datetime = tick.datetime
self.cross_limit_order()
self.cross_stop_order()
self.strategy.on_tick(tick)
self.update_daily_close(tick.last_price)
def cross_limit_order(self):
"""
Cross limit order with last bar/tick data.
"""
if self.mode == BacktestingMode.BAR:
long_cross_price = self.bar.low_price
short_cross_price = self.bar.high_price
long_best_price = self.bar.open_price
short_best_price = self.bar.open_price
else:
long_cross_price = self.tick.ask_price_1
short_cross_price = self.tick.bid_price_1
long_best_price = long_cross_price
short_best_price = short_cross_price
for order in list(self.active_limit_orders.values()):
# Push order update with status "not traded" (pending).
if order.status == Status.SUBMITTING:
order.status = Status.NOTTRADED
self.strategy.on_order(order)
# Check whether limit orders can be filled.
long_cross = (
order.direction == Direction.LONG
and order.price >= long_cross_price
and long_cross_price > 0
)
short_cross = (
order.direction == Direction.SHORT
and order.price <= short_cross_price
and short_cross_price > 0
)
if not long_cross and not short_cross:
continue
# Push order udpate with status "all traded" (filled).
order.traded = order.volume
order.status = Status.ALLTRADED
self.strategy.on_order(order)
if order.vt_orderid in self.active_limit_orders:
self.active_limit_orders.pop(order.vt_orderid)
# Push trade update
self.trade_count += 1
if long_cross:
trade_price = min(order.price, long_best_price)
pos_change = order.volume
else:
trade_price = max(order.price, short_best_price)
pos_change = -order.volume
trade = TradeData(
symbol=order.symbol,
exchange=order.exchange,
orderid=order.orderid,
tradeid=str(self.trade_count),
direction=order.direction,
offset=order.offset,
price=trade_price,
volume=order.volume,
datetime=self.datetime,
gateway_name=self.gateway_name,
)
self.strategy.pos += pos_change
self.strategy.on_trade(trade)
self.trades[trade.vt_tradeid] = trade
def cross_stop_order(self):
"""
Cross stop order with last bar/tick data.
"""
if self.mode == BacktestingMode.BAR:
long_cross_price = self.bar.high_price
short_cross_price = self.bar.low_price
long_best_price = self.bar.open_price
short_best_price = self.bar.open_price
else:
long_cross_price = self.tick.last_price
short_cross_price = self.tick.last_price
long_best_price = long_cross_price
short_best_price = short_cross_price
for stop_order in list(self.active_stop_orders.values()):
# Check whether stop order can be triggered.
long_cross = (
stop_order.direction == Direction.LONG
and stop_order.price <= long_cross_price
)
short_cross = (
stop_order.direction == Direction.SHORT
and stop_order.price >= short_cross_price
)
if not long_cross and not short_cross:
continue
# Create order data.
self.limit_order_count += 1
order = OrderData(
symbol=self.symbol,
exchange=self.exchange,
orderid=str(self.limit_order_count),
direction=stop_order.direction,
offset=stop_order.offset,
price=stop_order.price,
volume=stop_order.volume,
traded=stop_order.volume,
status=Status.ALLTRADED,
gateway_name=self.gateway_name,
datetime=self.datetime
)
self.limit_orders[order.vt_orderid] = order
# Create trade data.
if long_cross:
trade_price = max(stop_order.price, long_best_price)
pos_change = order.volume
else:
trade_price = min(stop_order.price, short_best_price)
pos_change = -order.volume
self.trade_count += 1
trade = TradeData(
symbol=order.symbol,
exchange=order.exchange,
orderid=order.orderid,
tradeid=str(self.trade_count),
direction=order.direction,
offset=order.offset,
price=trade_price,
volume=order.volume,
datetime=self.datetime,
gateway_name=self.gateway_name,
)
self.trades[trade.vt_tradeid] = trade
# Update stop order.
stop_order.vt_orderids.append(order.vt_orderid)
stop_order.status = StopOrderStatus.TRIGGERED
if stop_order.stop_orderid in self.active_stop_orders:
self.active_stop_orders.pop(stop_order.stop_orderid)
# Push update to strategy.
self.strategy.on_stop_order(stop_order)
self.strategy.on_order(order)
self.strategy.pos += pos_change
self.strategy.on_trade(trade)
def load_bar(
self,
vt_symbol: str,
days: int,
interval: Interval,
callback: Callable,
use_database: bool
) -> List[BarData]:
""""""
self.days = days
self.callback = callback
return []
def load_tick(self, vt_symbol: str, days: int, callback: Callable) -> List[TickData]:
""""""
self.days = days
self.callback = callback
return []
def send_order(
self,
strategy: CtaTemplate,
direction: Direction,
offset: Offset,
price: float,
volume: float,
stop: bool,
lock: bool,
net: bool
):
""""""
price = round_to(price, self.pricetick)
if stop:
vt_orderid = self.send_stop_order(direction, offset, price, volume)
else:
vt_orderid = self.send_limit_order(direction, offset, price, volume)
return [vt_orderid]
def send_stop_order(
self,
direction: Direction,
offset: Offset,
price: float,
volume: float
):
""""""
self.stop_order_count += 1
stop_order = StopOrder(
vt_symbol=self.vt_symbol,
direction=direction,
offset=offset,
price=price,
volume=volume,
datetime=self.datetime,
stop_orderid=f"{STOPORDER_PREFIX}.{self.stop_order_count}",
strategy_name=self.strategy.strategy_name,
)
self.active_stop_orders[stop_order.stop_orderid] = stop_order
self.stop_orders[stop_order.stop_orderid] = stop_order
return stop_order.stop_orderid
def send_limit_order(
self,
direction: Direction,
offset: Offset,
price: float,
volume: float
):
""""""
self.limit_order_count += 1
order = OrderData(
symbol=self.symbol,
exchange=self.exchange,
orderid=str(self.limit_order_count),
direction=direction,
offset=offset,
price=price,
volume=volume,
status=Status.SUBMITTING,
gateway_name=self.gateway_name,
datetime=self.datetime
)
self.active_limit_orders[order.vt_orderid] = order
self.limit_orders[order.vt_orderid] = order
return order.vt_orderid
def cancel_order(self, strategy: CtaTemplate, vt_orderid: str):
"""
Cancel order by vt_orderid.
"""
if vt_orderid.startswith(STOPORDER_PREFIX):
self.cancel_stop_order(strategy, vt_orderid)
else:
self.cancel_limit_order(strategy, vt_orderid)
def cancel_stop_order(self, strategy: CtaTemplate, vt_orderid: str):
""""""
if vt_orderid not in self.active_stop_orders:
return
stop_order = self.active_stop_orders.pop(vt_orderid)
stop_order.status = StopOrderStatus.CANCELLED
self.strategy.on_stop_order(stop_order)
def cancel_limit_order(self, strategy: CtaTemplate, vt_orderid: str):
""""""
if vt_orderid not in self.active_limit_orders:
return
order = self.active_limit_orders.pop(vt_orderid)
order.status = Status.CANCELLED
self.strategy.on_order(order)
def cancel_all(self, strategy: CtaTemplate):
"""
Cancel all orders, both limit and stop.
"""
vt_orderids = list(self.active_limit_orders.keys())
for vt_orderid in vt_orderids:
self.cancel_limit_order(strategy, vt_orderid)
stop_orderids = list(self.active_stop_orders.keys())
for vt_orderid in stop_orderids:
self.cancel_stop_order(strategy, vt_orderid)
def write_log(self, msg: str, strategy: CtaTemplate = None):
"""
Write log message.
"""
msg = f"{self.datetime}\t{msg}"
self.logs.append(msg)
def send_email(self, msg: str, strategy: CtaTemplate = None):
"""
Send email to default receiver.
"""
pass
def sync_strategy_data(self, strategy: CtaTemplate):
"""
Sync strategy data into json file.
"""
pass
def get_engine_type(self):
"""
Return engine type.
"""
return self.engine_type
def get_pricetick(self, strategy: CtaTemplate):
"""
Return contract pricetick data.
"""
return self.pricetick
def put_strategy_event(self, strategy: CtaTemplate):
"""
Put an event to update strategy status.
"""
pass
def output(self, msg):
"""
Output message of backtesting engine.
"""
print(f"{datetime.now()}\t{msg}")
def get_all_trades(self):
"""
Return all trade data of current backtesting result.
"""
return list(self.trades.values())
def get_all_orders(self):
"""
Return all limit order data of current backtesting result.
"""
return list(self.limit_orders.values())
def get_all_daily_results(self):
"""
Return all daily result data.
"""
return list(self.daily_results.values())
class DailyResult:
""""""
def __init__(self, date: date, close_price: float):
""""""
self.date = date
self.close_price = close_price
self.pre_close = 0
self.trades = []
self.trade_count = 0
self.start_pos = 0
self.end_pos = 0
self.turnover = 0
self.commission = 0
self.slippage = 0
self.trading_pnl = 0
self.holding_pnl = 0
self.total_pnl = 0
self.net_pnl = 0
def add_trade(self, trade: TradeData):
""""""
self.trades.append(trade)
def calculate_pnl(
self,
pre_close: float,
start_pos: float,
size: int,
rate: float,
slippage: float,
inverse: bool
):
""""""
# If no pre_close provided on the first day,
# use value 1 to avoid zero division error
if pre_close:
self.pre_close = pre_close
else:
self.pre_close = 1
# Holding pnl is the pnl from holding position at day start
self.start_pos = start_pos
self.end_pos = start_pos
if not inverse: # For normal contract
self.holding_pnl = self.start_pos * \
(self.close_price - self.pre_close) * size
else: # For crypto currency inverse contract
self.holding_pnl = self.start_pos * \
(1 / self.pre_close - 1 / self.close_price) * size
# Trading pnl is the pnl from new trade during the day
self.trade_count = len(self.trades)
for trade in self.trades:
if trade.direction == Direction.LONG:
pos_change = trade.volume
else:
pos_change = -trade.volume
self.end_pos += pos_change
# For normal contract
if not inverse:
turnover = trade.volume * size * trade.price
self.trading_pnl += pos_change * \
(self.close_price - trade.price) * size
self.slippage += trade.volume * size * slippage
# For crypto currency inverse contract
else:
turnover = trade.volume * size / trade.price
self.trading_pnl += pos_change * \
(1 / trade.price - 1 / self.close_price) * size
self.slippage += trade.volume * size * slippage / (trade.price ** 2)
self.turnover += turnover
self.commission += turnover * rate
# Net pnl takes account of commission and slippage cost
self.total_pnl = self.trading_pnl + self.holding_pnl
self.net_pnl = self.total_pnl - self.commission - self.slippage
@lru_cache(maxsize=999)
def load_bar_data(
symbol: str,
exchange: Exchange,
interval: Interval,
start: datetime,
end: datetime
):
""""""
database = get_database()
return database.load_bar_data(
symbol, exchange, interval, start, end
)
@lru_cache(maxsize=999)
def load_tick_data(
symbol: str,
exchange: Exchange,
start: datetime,
end: datetime
):
""""""
database = get_database()
return database.load_tick_data(
symbol, exchange, start, end
)
def evaluate(
target_name: str,
strategy_class: CtaTemplate,
vt_symbol: str,
interval: Interval,
start: datetime,
rate: float,
slippage: float,
size: float,
pricetick: float,
capital: int,
end: datetime,
mode: BacktestingMode,
inverse: bool,
setting: dict
):
"""
Function for running in multiprocessing.pool
"""
engine = BacktestingEngine()
engine.set_parameters(
vt_symbol=vt_symbol,
interval=interval,
start=start,
rate=rate,
slippage=slippage,
size=size,
pricetick=pricetick,
capital=capital,
end=end,
mode=mode,
inverse=inverse
)
engine.add_strategy(strategy_class, setting)
engine.load_data()
engine.run_backtesting()
engine.calculate_result()
statistics = engine.calculate_statistics(output=False)
target_value = statistics[target_name]
return (str(setting), target_value, statistics)
def wrap_evaluate(engine: BacktestingEngine, target_name: str) -> callable:
"""
Wrap evaluate function with given setting from backtesting engine.
"""
func: callable = partial(
evaluate,
target_name,
engine.strategy_class,
engine.vt_symbol,
engine.interval,
engine.start,
engine.rate,
engine.slippage,
engine.size,
engine.pricetick,
engine.capital,
engine.end,
engine.mode,
engine.inverse
)
return func
def get_target_value(result: list) -> float:
"""
Get target value for sorting optimization results.
"""
return result[1]