
相关文件:
1. vnpy/chart/item.py
from abc import abstractmethod
from typing import List, Dict, Tuple
import numpy as np
import pyqtgraph as pg
import talib
from vnpy.trader.ui import QtCore, QtGui, QtWidgets
from vnpy.trader.object import BarData
from .base import BLACK_COLOR, UP_COLOR, DOWN_COLOR, PEN_WIDTH, BAR_WIDTH
from .manager import BarManager
class ChartItem(pg.GraphicsObject):
    """
    Base class for items drawn on a chart plot.

    Every bar is rendered into its own cached QPicture (``_bar_picutures``);
    the bars of the currently exposed x-range are then combined into one
    QPicture (``_item_picuture``) which is replayed in ``paint``. Subclasses
    implement the per-bar drawing and the y-range / cursor-text queries.
    """

    def __init__(self, manager: BarManager):
        """
        Create the item bound to the bar manager shared by the chart.
        """
        super().__init__()

        self._manager: BarManager = manager

        # Per-bar picture cache: bar index -> QPicture, or None when the bar
        # has to be (re)drawn on the next paint.
        self._bar_picutures: Dict[int, QtGui.QPicture] = {}
        # Combined picture of the bars in the last painted x-range.
        self._item_picuture: QtGui.QPicture = None

        # Fill used for rising candle bodies (see CandleItem); it uses the
        # BLACK_COLOR constant the attribute is named after.
        self._black_brush: QtGui.QBrush = pg.mkBrush(color=BLACK_COLOR)

        self._up_pen: QtGui.QPen = pg.mkPen(
            color=UP_COLOR, width=PEN_WIDTH
        )
        self._up_brush: QtGui.QBrush = pg.mkBrush(color=UP_COLOR)

        self._down_pen: QtGui.QPen = pg.mkPen(
            color=DOWN_COLOR, width=PEN_WIDTH
        )
        self._down_brush: QtGui.QBrush = pg.mkBrush(color=DOWN_COLOR)

        # (min_ix, max_ix) that _item_picuture was built for; None forces
        # the combined picture to be rebuilt on the next paint.
        self._rect_area: Tuple[float, float] = None

        # Very important! Only redraw the visible part and improve speed a lot.
        self.setFlag(self.ItemUsesExtendedStyleOption)

    @abstractmethod
    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """
        Draw picture for specific bar.
        """
        pass

    @abstractmethod
    def boundingRect(self) -> QtCore.QRectF:
        """
        Get bounding rectangles for item.
        """
        pass

    @abstractmethod
    def get_y_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
        """
        Get range of y-axis with given x-axis range.

        If min_ix and max_ix not specified, then return range with whole data set.
        """
        pass

    @abstractmethod
    def get_info_text(self, ix: int) -> str:
        """
        Get information text to show by cursor.
        """
        pass

    def update_history(self, history: List[BarData]) -> None:
        """
        Reset the per-bar picture cache after a full history load.

        Bars are read back from the manager (the chart widget updates the
        manager before calling this); ``history`` itself is not used and is
        kept for interface compatibility.
        """
        bars = self._manager.get_all_bars()

        self._bar_picutures.clear()
        # One empty slot per bar: every picture is drawn lazily on paint.
        self._bar_picutures.update(dict.fromkeys(range(len(bars))))

        self.update()

    def update_bar(self, bar: BarData) -> None:
        """
        Mark the picture of a single (new or changed) bar for redrawing.
        """
        ix = self._manager.get_index(bar.datetime)
        self._bar_picutures[ix] = None

        self.update()

    def update(self) -> None:
        """
        Refresh the item.

        The combined picture is invalidated so that bars whose pictures were
        reset are redrawn even when the exposed x-range did not change
        (e.g. the last bar being updated in place).
        """
        self._rect_area = None

        if self.scene():
            self.scene().update()

    def paint(
        self,
        painter: QtGui.QPainter,
        opt: QtWidgets.QStyleOptionGraphicsItem,
        w: QtWidgets.QWidget
    ):
        """
        Reimplement the paint method of parent class.

        This function is called by external QGraphicsView. Only bars inside
        the exposed rectangle are assembled; the combined picture is rebuilt
        when that range changes or after ``update`` invalidated it.
        """
        rect = opt.exposedRect

        min_ix = int(rect.left())
        max_ix = int(rect.right())
        max_ix = min(max_ix, len(self._bar_picutures))

        rect_area = (min_ix, max_ix)
        if rect_area != self._rect_area or not self._item_picuture:
            self._rect_area = rect_area
            self._draw_item_picture(min_ix, max_ix)

        self._item_picuture.play(painter)

    def _draw_item_picture(self, min_ix: int, max_ix: int) -> None:
        """
        Draw the picture of item in range [min_ix, max_ix).

        Missing bar pictures are drawn and cached on the way.
        """
        self._item_picuture = QtGui.QPicture()
        painter = QtGui.QPainter(self._item_picuture)

        try:
            for ix in range(min_ix, max_ix):
                # Indices without a slot (outside the loaded data) are skipped
                # instead of raising KeyError.
                if ix not in self._bar_picutures:
                    continue

                bar_picture = self._bar_picutures[ix]
                if bar_picture is None:
                    bar = self._manager.get_bar(ix)
                    bar_picture = self._draw_bar_picture(ix, bar)
                    self._bar_picutures[ix] = bar_picture

                bar_picture.play(painter)
        finally:
            # Always release the painter, even if a subclass drawing fails.
            painter.end()

    def clear_all(self) -> None:
        """
        Clear all data in the item.
        """
        self._item_picuture = None
        self._rect_area = None
        self._bar_picutures.clear()
        self.update()
class CandleItem(ChartItem):
    """
    Candlestick item: a shadow line from low to high plus an open/close body.
    """

    def __init__(self, manager: BarManager):
        """"""
        super().__init__(manager)

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """
        Render one candle centred on x position ``ix``.
        """
        rising = bar.close_price >= bar.open_price
        pen = self._up_pen if rising else self._down_pen
        brush = self._black_brush if rising else self._down_brush

        candle_picture = QtGui.QPicture()
        painter = QtGui.QPainter(candle_picture)
        painter.setPen(pen)
        painter.setBrush(brush)

        # Shadow: only drawn when the bar has a non-zero high-low span.
        if bar.low_price < bar.high_price:
            painter.drawLine(
                QtCore.QPointF(ix, bar.high_price),
                QtCore.QPointF(ix, bar.low_price)
            )

        left = ix - BAR_WIDTH
        if bar.open_price != bar.close_price:
            body = QtCore.QRectF(
                left,
                bar.open_price,
                BAR_WIDTH * 2,
                bar.close_price - bar.open_price
            )
            painter.drawRect(body)
        else:
            # Open == close: the body collapses to a horizontal line.
            painter.drawLine(
                QtCore.QPointF(left, bar.open_price),
                QtCore.QPointF(ix + BAR_WIDTH, bar.open_price),
            )

        painter.end()
        return candle_picture

    def boundingRect(self) -> QtCore.QRectF:
        """
        Rectangle spanning all bars horizontally and the full price range vertically.
        """
        low, high = self._manager.get_price_range()
        return QtCore.QRectF(0, low, len(self._bar_picutures), high - low)

    def get_y_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
        """
        Get range of y-axis with given x-axis range.

        If min_ix and max_ix not specified, then return range with whole data set.
        """
        low, high = self._manager.get_price_range(min_ix, max_ix)
        return low, high

    def get_info_text(self, ix: int) -> str:
        """
        Get information text to show by cursor: date, time and OHLC prices.
        """
        bar = self._manager.get_bar(ix)
        if not bar:
            return ""

        fields = (
            ("Date", bar.datetime.strftime("%Y-%m-%d")),
            ("Time", bar.datetime.strftime("%H:%M")),
            ("Open", str(bar.open_price)),
            ("High", str(bar.high_price)),
            ("Low", str(bar.low_price)),
            ("Close", str(bar.close_price)),
        )
        # Label and value separated by one space, fields by two.
        return "  ".join(f"{label} {value}" for label, value in fields)
class VolumeItem(ChartItem):
    """
    Volume histogram item: one bar per candle, coloured by candle direction.
    """

    def __init__(self, manager: BarManager):
        """"""
        super().__init__(manager)

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """
        Render the volume bar at ``ix``, rising from 0 to ``bar.volume``.
        """
        if bar.close_price >= bar.open_price:
            pen, brush = self._up_pen, self._up_brush
        else:
            pen, brush = self._down_pen, self._down_brush

        volume_picture = QtGui.QPicture()
        painter = QtGui.QPainter(volume_picture)
        painter.setPen(pen)
        painter.setBrush(brush)
        painter.drawRect(
            QtCore.QRectF(ix - BAR_WIDTH, 0, BAR_WIDTH * 2, bar.volume)
        )
        painter.end()

        return volume_picture

    def boundingRect(self) -> QtCore.QRectF:
        """
        Rectangle spanning all bars horizontally and the full volume range vertically.
        """
        low, high = self._manager.get_volume_range()
        return QtCore.QRectF(0, low, len(self._bar_picutures), high - low)

    def get_y_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
        """
        Get range of y-axis with given x-axis range.

        If min_ix and max_ix not specified, then return range with whole data set.
        """
        low, high = self._manager.get_volume_range(min_ix, max_ix)
        return low, high

    def get_info_text(self, ix: int) -> str:
        """
        Get information text to show by cursor.
        """
        bar = self._manager.get_bar(ix)
        return f"Volume {bar.volume}" if bar else ""
class LineItem(CandleItem):
    """
    Close-price line drawn with a yellow dash-dot pen on the candle plot.

    Inherits bounding rectangle and y-range from CandleItem.
    """

    def __init__(self, manager: BarManager):
        """"""
        super().__init__(manager)
        self.yellow_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 0), width=1, style=QtCore.Qt.DashDotLine)

    def get_info_text(self, ix: int) -> str:
        """
        This item contributes no cursor text.
        """
        return ""

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """
        Draw the segment from the previous close to this bar's close.

        For the first bar (no previous bar) a zero-length segment is drawn.
        """
        previous = self._manager.get_bar(ix - 1)

        end_point = QtCore.QPointF(ix, bar.close_price)
        start_point = (
            QtCore.QPointF(ix - 1, previous.close_price) if previous else end_point
        )

        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)
        painter.setPen(self.yellow_pen)
        painter.drawLine(start_point, end_point)
        painter.end()

        return picture
class SmaItem(CandleItem):
    """
    Simple moving average of close prices, drawn as a blue line on the
    candle plot. Values are cached per bar index in ``sma_data`` and the
    cache is invalidated whenever the underlying bars change.
    """

    def __init__(self, manager: BarManager):
        """"""
        super().__init__(manager)

        self.blue_pen: QtGui.QPen = pg.mkPen(color=(100, 100, 255), width=2)

        self.sma_window = 10
        self.sma_data: Dict[int, float] = {}

    def update_history(self, history: List[BarData]) -> None:
        """
        New history: every cached SMA value may be stale, so drop them all.
        """
        self.sma_data.clear()
        super().update_history(history)

    def update_bar(self, bar: BarData) -> None:
        """
        The bar's close may have changed: recompute its SMA on next use.
        """
        ix = self._manager.get_index(bar.datetime)
        self.sma_data.pop(ix, None)
        super().update_bar(bar)

    def clear_all(self) -> None:
        """
        Clear pictures and cached SMA values.
        """
        self.sma_data.clear()
        super().clear_all()

    def get_sma_value(self, ix: int) -> float:
        """
        Return the SMA at bar index ``ix`` (NaN while it is undefined).

        Negative indices return 0.
        """
        if ix < 0:
            return 0

        # When initialize, calculate all sma value
        if not self.sma_data:
            bars = self._manager.get_all_bars()
            if bars:
                close_data = np.array([bar.close_price for bar in bars], dtype=float)
                sma_array = talib.SMA(close_data, self.sma_window)
                self.sma_data.update(enumerate(sma_array))

        # Return if already calculated
        if ix in self.sma_data:
            return self.sma_data[ix]

        # Else calculate only this value from the trailing window. The SMA
        # depends on the last sma_window closes only, so this matches the
        # batch calculation above.
        start = ix - self.sma_window + 1
        if start < 0:
            return float("nan")

        bars = [self._manager.get_bar(n) for n in range(start, ix + 1)]
        if not all(bars):
            # Part of the window is not loaded; don't cache a guess.
            return float("nan")

        close_data = np.array([bar.close_price for bar in bars], dtype=float)
        sma_value = talib.SMA(close_data, self.sma_window)[-1]
        self.sma_data[ix] = sma_value
        return sma_value

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """
        Draw the SMA segment from ``ix - 1`` to ``ix``.
        """
        sma_value = self.get_sma_value(ix)
        last_sma_value = self.get_sma_value(ix - 1)

        # Create objects
        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # Skip the segment while either end is NaN (the first bars of the
        # series have no average yet), as RsiItem/MacdItem do.
        if not (np.isnan(sma_value) or np.isnan(last_sma_value)):
            painter.setPen(self.blue_pen)
            painter.drawLine(
                QtCore.QPointF(ix - 1, last_sma_value),
                QtCore.QPointF(ix, sma_value)
            )

        # Finish
        painter.end()
        return picture

    def get_info_text(self, ix: int) -> str:
        """
        Cursor text with the cached SMA value, or a dash if not computed.
        """
        if ix in self.sma_data:
            sma_value = self.sma_data[ix]
            text = "SMA {:.8f}".format(sma_value)
        else:
            text = "SMA -"

        return text
class RsiItem(ChartItem):
    """
    RSI indicator item on a fixed 0-100 y-range: a yellow RSI line plus
    white reference lines at 70 and 30. Values are cached per bar index in
    ``rsi_data`` and invalidated whenever the underlying bars change.
    """

    def __init__(self, manager: BarManager):
        """"""
        super().__init__(manager)

        self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)
        self.yellow_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 0), width=2)

        self.rsi_window = 14
        self.rsi_data: Dict[int, float] = {}

    def update_history(self, history: List[BarData]) -> None:
        """
        New history: every cached RSI value may be stale, so drop them all.
        """
        self.rsi_data.clear()
        super().update_history(history)

    def update_bar(self, bar: BarData) -> None:
        """
        The bar's close may have changed: recompute its RSI on next use.
        """
        ix = self._manager.get_index(bar.datetime)
        self.rsi_data.pop(ix, None)
        super().update_bar(bar)

    def clear_all(self) -> None:
        """
        Clear pictures and cached RSI values.
        """
        self.rsi_data.clear()
        super().clear_all()

    def _calculate_rsi(self, bars: List[BarData]) -> np.ndarray:
        """
        Run talib.RSI over the close prices of ``bars``.
        """
        # talib requires a float64 input array.
        close_data = np.array([bar.close_price for bar in bars], dtype=float)
        return talib.RSI(close_data, self.rsi_window)

    def get_rsi_value(self, ix: int) -> float:
        """
        Return the RSI at bar index ``ix`` (NaN while it is undefined).

        Negative indices return 50.
        """
        if ix < 0:
            return 50

        # When initialize, calculate all rsi value
        if not self.rsi_data:
            bars = self._manager.get_all_bars()
            if bars:
                self.rsi_data.update(enumerate(self._calculate_rsi(bars)))

        # Return if already calculated
        if ix in self.rsi_data:
            return self.rsi_data[ix]

        # Else calculate the new value. RSI uses Wilder smoothing, so its
        # value depends on the whole preceding series; computing it from only
        # the last rsi_window + 1 bars would disagree with the batch values
        # above. Recompute over all bars up to ix instead.
        bars = list(self._manager.get_all_bars())[: ix + 1]
        if len(bars) <= ix:
            return float("nan")

        rsi_value = self._calculate_rsi(bars)[-1]
        self.rsi_data[ix] = rsi_value
        return rsi_value

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """
        Draw the RSI segment ending at ``ix`` and the 70/30 reference lines.
        """
        rsi_value = self.get_rsi_value(ix)
        last_rsi_value = self.get_rsi_value(ix - 1)

        # Create objects
        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # Draw RSI line, skipping segments with an undefined (NaN) end.
        if not (np.isnan(last_rsi_value) or np.isnan(rsi_value)):
            painter.setPen(self.yellow_pen)
            painter.drawLine(
                QtCore.QPointF(ix - 1, last_rsi_value),
                QtCore.QPointF(ix, rsi_value)
            )

        # Draw oversold/overbought line
        painter.setPen(self.white_pen)
        painter.drawLine(
            QtCore.QPointF(ix, 70),
            QtCore.QPointF(ix - 1, 70),
        )
        painter.drawLine(
            QtCore.QPointF(ix, 30),
            QtCore.QPointF(ix - 1, 30),
        )

        # Finish
        painter.end()
        return picture

    def boundingRect(self) -> QtCore.QRectF:
        """
        Rectangle spanning all bars horizontally and 0-100 vertically.
        """
        rect = QtCore.QRectF(
            0,
            0,
            len(self._bar_picutures),
            100
        )
        return rect

    def get_y_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
        """
        RSI is bounded, so the y range is always 0-100.
        """
        return 0, 100

    def get_info_text(self, ix: int) -> str:
        """
        Cursor text with the cached RSI value, or a dash if not computed.
        """
        if ix in self.rsi_data:
            rsi_value = self.rsi_data[ix]
            text = "RSI {:.8f}".format(rsi_value)
        else:
            text = "RSI -"

        return text
def to_int(value: float) -> int:
    """
    Round ``value`` to the nearest whole number and return it as ``int``.

    Python's ``round`` is used, so exact halves go to the even neighbour.
    """
    rounded = round(value, 0)
    return int(rounded)
def adjust_range(in_range: Tuple[float, float]) -> Tuple[float, float]:
    """
    Expand a y-direction display range to 1.1 times its size.

    5% of the span is added below the first value and above the second.
    """
    low, high = in_range[0], in_range[1]
    margin = abs(low - high) * 0.05
    return (low - margin, high + margin)
class MacdItem(ChartItem):
    """
    MACD indicator item.

    Draws the DIFF line (white), the DEA line (yellow) and the MACD
    histogram (green when positive, red otherwise), computed by talib.MACD
    with ``short_window`` / ``long_window`` / ``M`` as fast / slow / signal
    periods. Values are cached per bar index in ``macd_data`` as
    (diff, dea, macd); y ranges are cached per (min_ix, max_ix) in
    ``_values_ranges``. Both caches are invalidated when the bars change.
    """

    # Per-instance cache (created in __init__) so that separate MacdItems
    # never share each other's cached ranges.
    _values_ranges: Dict[Tuple[int, int], Tuple[float, float]]
    last_range: Tuple[int, int] = (-1, -1)  # Most recently resolved index range

    def __init__(self, manager: BarManager):
        """"""
        super().__init__(manager)

        self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)
        self.yellow_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 0), width=1)
        self.red_pen: QtGui.QPen = pg.mkPen(color=(237, 97, 96), width=1)
        self.green_pen: QtGui.QPen = pg.mkPen(color=(0, 192, 135), width=1)
        self.red_brush: QtGui.QBrush = pg.mkBrush(237, 97, 96)
        self.green_brush: QtGui.QBrush = pg.mkBrush(0, 192, 135)

        self.short_window = 12
        self.long_window = 26
        self.M = 9

        self.macd_data: Dict[int, Tuple[float, float, float]] = {}
        self._values_ranges: Dict[Tuple[int, int], Tuple[float, float]] = {}
        self.last_range: Tuple[int, int] = (-1, -1)

    def _reset_cache(self) -> None:
        """
        Drop all cached MACD values and y ranges.
        """
        self.macd_data.clear()
        self._values_ranges.clear()
        self.last_range = (-1, -1)

    def update_history(self, history: List[BarData]) -> None:
        """
        New history: all cached values and ranges may be stale.
        """
        self._reset_cache()
        super().update_history(history)

    def update_bar(self, bar: BarData) -> None:
        """
        The bar's close may have changed: recompute its MACD on next use
        and forget cached ranges that may include it.
        """
        ix = self._manager.get_index(bar.datetime)
        self.macd_data.pop(ix, None)
        self._values_ranges.clear()
        super().update_bar(bar)

    def clear_all(self) -> None:
        """
        Clear pictures, cached values and cached ranges.
        """
        self._reset_cache()
        super().clear_all()

    def _calculate_macd(self, bars: List[BarData]):
        """
        Run talib.MACD over the close prices of ``bars``.

        Returns the (diffs, deas, macds) arrays produced by talib.
        """
        # talib requires a float64 input array.
        close_data = np.array([bar.close_price for bar in bars], dtype=float)
        return talib.MACD(
            close_data,
            fastperiod=self.short_window,
            slowperiod=self.long_window,
            signalperiod=self.M
        )

    def _ensure_macd_data(self) -> None:
        """
        Fill ``macd_data`` for all loaded bars if it is still empty.
        """
        if self.macd_data:
            return

        bars = self._manager.get_all_bars()
        if not bars:
            return

        diffs, deas, macds = self._calculate_macd(bars)
        for n, values in enumerate(zip(diffs, deas, macds)):
            self.macd_data[n] = values

    def get_macd_value(self, ix: int) -> Tuple[float, float, float]:
        """
        Return (diff, dea, macd) at bar index ``ix``.

        Values are NaN while undefined; negative indices return zeros.
        """
        if ix < 0:
            return 0.0, 0.0, 0.0

        # When initialize, calculate all macd value
        self._ensure_macd_data()

        # Return if already calculated
        if ix in self.macd_data:
            return self.macd_data[ix]

        # Else calculate the new value. DIFF and DEA are EMAs whose values
        # depend on the whole preceding series, so a short trailing window
        # would disagree with the batch values; recompute over all bars.
        bars = list(self._manager.get_all_bars())[: ix + 1]
        if len(bars) <= ix:
            nan = float("nan")
            return nan, nan, nan

        diffs, deas, macds = self._calculate_macd(bars)
        value = (diffs[-1], deas[-1], macds[-1])
        self.macd_data[ix] = value
        return value

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """
        Draw the DIFF/DEA segments ending at ``ix`` and the MACD bar at ``ix``.

        Any part whose values are NaN (warm-up period) is skipped.
        """
        diff, dea, macd = self.get_macd_value(ix)
        last_diff, last_dea, _ = self.get_macd_value(ix - 1)

        # Create objects
        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # DIFF line
        if not (np.isnan(diff) or np.isnan(last_diff)):
            painter.setPen(self.white_pen)
            painter.drawLine(
                QtCore.QPointF(ix - 1, last_diff),
                QtCore.QPointF(ix, diff)
            )

        # DEA line
        if not (np.isnan(dea) or np.isnan(last_dea)):
            painter.setPen(self.yellow_pen)
            painter.drawLine(
                QtCore.QPointF(ix - 1, last_dea),
                QtCore.QPointF(ix, dea)
            )

        # MACD histogram bar
        if not np.isnan(macd):
            if macd > 0:
                painter.setPen(self.green_pen)
                painter.setBrush(self.green_brush)
            else:
                painter.setPen(self.red_pen)
                painter.setBrush(self.red_brush)
            painter.drawRect(QtCore.QRectF(ix - 0.3, 0, 0.6, macd))

        painter.end()
        return picture

    def boundingRect(self) -> QtCore.QRectF:
        """
        Rectangle spanning all bars horizontally and the whole MACD value
        range vertically.
        """
        min_y, max_y = self.get_y_range()
        rect = QtCore.QRectF(
            0,
            min_y,
            len(self._bar_picutures),
            max_y - min_y
        )
        return rect

    def get_y_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
        """
        Get the y range covering DIFF, DEA and MACD values.

        With no bounds given, the whole data set is used (the ChartItem
        contract, relied on for plot limits); otherwise bars min_ix..max_ix
        inclusive, clamped to the loaded data. The result is widened by 5%
        on each side via adjust_range. (0.0, 1.0) is returned when the range
        contains no finite value (no data, or only the warm-up period).
        """
        self._ensure_macd_data()

        count = self._manager.get_count()
        if not count:
            return 0.0, 1.0

        lo = 0 if min_ix is None else max(int(min_ix), 0)
        hi = count - 1 if max_ix is None else min(int(max_ix), count - 1)
        key = (lo, hi)

        result = self._values_ranges.get(key)
        if result is None:
            # Shape (k, 3) array, or empty when lo > hi.
            values = np.array(
                [self.get_macd_value(ix) for ix in range(lo, hi + 1)],
                dtype=float
            )
            finite = values[~np.isnan(values)]
            if not finite.size:
                return 0.0, 1.0

            result = (float(finite.min()), float(finite.max()))
            self._values_ranges[key] = result

        self.last_range = key
        return adjust_range(result)

    def get_info_text(self, ix: int) -> str:
        """
        Cursor text with the cached diff/dea/macd values, or dashes.
        """
        if ix in self.macd_data:
            diff, dea, macd = self.macd_data[ix]
            words = [
                "diff {:.8f}".format(diff), " ",
                "dea {:.8f}".format(dea), " ",
                "macd {:.8f}".format(macd)
            ]
            text = " ".join(words)
        else:
            text = "diff - \ndea - \nmacd -"

        return text
2. vnpy/chart/base.py
from vnpy.trader.ui import QtGui
# RGB colour tuples shared by the chart widgets.
WHITE_COLOR = (255, 255, 255)
BLACK_COLOR = (0, 0, 0)
GREY_COLOR = (100, 100, 100)

# Colours of rising (close >= open) and falling bars.
UP_COLOR = (0, 192, 135)
DOWN_COLOR = (237, 97, 96)

# Colour of the crosshair cursor lines and the per-plot info text.
CURSOR_COLOR = (100, 110, 120)

# Width of the up/down pens used to draw bars.
PEN_WIDTH = 1
# Half-width of a candle/volume body in x-axis units: a body spans
# ix - BAR_WIDTH .. ix + BAR_WIDTH.
BAR_WIDTH = 0.3

AXIS_WIDTH = 0.8

# Font for cursor labels, info text and axis ticks.
NORMAL_FONT = QtGui.QFont("Arial", 9)
def to_int(value: float) -> int:
    """
    Convert a float coordinate to the nearest integer index.

    Rounding follows Python's ``round``: exact halves go to the even neighbour.
    """
    nearest = round(value, 0)
    return int(nearest)
3. vnpy/chart/widget.py
from typing import List, Dict, Type
import pyqtgraph as pg
from vnpy.trader.ui import QtGui, QtWidgets, QtCore
from vnpy.trader.object import BarData
from vnpy.event import Event, EventEngine
from .manager import BarManager
from .base import (
GREY_COLOR, WHITE_COLOR, CURSOR_COLOR, BLACK_COLOR,
to_int, NORMAL_FONT
)
from .axis import DatetimeAxis
from .item import ChartItem
from vnpy.app.cta_strategy.base import (
EVENT_CTA_TICK,
EVENT_CTA_BAR,
EVENT_CTA_ORDER,
EVENT_CTA_TRADE,
EVENT_CTA_HISTORY_BAR
)
# Enable anti-aliasing in pyqtgraph's global configuration options
# (executed once, as a side effect of importing this module).
pg.setConfigOptions(antialias=True)
class ChartWidget(pg.PlotWidget):
    """
    Chart made of vertically stacked plots sharing one x-axis.

    All chart items read their data from one shared BarManager. The widget
    scrolls with the left/right keys, zooms with the up/down keys or the
    mouse wheel, and can show an optional crosshair cursor (add_cursor).
    """

    MIN_BAR_COUNT = 100

    def __init__(self, parent: QtWidgets.QWidget = None):
        """"""
        super().__init__(parent)

        self._manager: BarManager = BarManager()

        self._plots: Dict[str, pg.PlotItem] = {}
        self._items: Dict[str, ChartItem] = {}
        self._item_plot_map: Dict[ChartItem, pg.PlotItem] = {}

        self._first_plot: pg.PlotItem = None
        # Optional crosshair cursor; stays None until add_cursor() is called.
        self._cursor: ChartCursor = None

        self._right_ix: int = 0                     # Index of most right data
        self._bar_count: int = self.MIN_BAR_COUNT   # Total bar visible in chart

        self._init_ui()

    def _init_ui(self) -> None:
        """
        Set the window title and create the layout holding the plots.
        """
        self.setWindowTitle("ChartWidget of vn.py")

        self._layout = pg.GraphicsLayout()
        self._layout.setContentsMargins(10, 10, 10, 10)
        self._layout.setSpacing(0)
        self._layout.setBorder(color=GREY_COLOR, width=0.8)
        self._layout.setZValue(0)
        self.setCentralItem(self._layout)

    def _get_new_x_axis(self):
        """
        Create a bottom datetime axis bound to the bar manager.
        """
        return DatetimeAxis(self._manager, orientation='bottom')

    def add_cursor(self) -> None:
        """
        Create the crosshair cursor if it does not exist yet.

        The cursor creates its lines, labels and info texts for the plots
        existing at this moment, so call this after all add_plot calls.
        """
        if not self._cursor:
            self._cursor = ChartCursor(
                self, self._manager, self._plots, self._item_plot_map)

    def add_plot(
        self,
        plot_name: str,
        minimum_height: int = 80,
        maximum_height: int = None,
        hide_x_axis: bool = False
    ) -> None:
        """
        Add plot area below the existing ones.

        Every plot's x-range change triggers a y-range update; plots after
        the first are x-linked to the first one.
        """
        # Create plot object
        plot = pg.PlotItem(axisItems={'bottom': self._get_new_x_axis()})
        plot.setMenuEnabled(False)
        plot.setClipToView(True)
        plot.hideAxis('left')
        plot.showAxis('right')
        plot.setDownsampling(mode='peak')
        plot.setRange(xRange=(0, 1), yRange=(0, 1))
        plot.hideButtons()
        plot.setMinimumHeight(minimum_height)

        if maximum_height:
            plot.setMaximumHeight(maximum_height)

        if hide_x_axis:
            plot.hideAxis("bottom")

        if not self._first_plot:
            self._first_plot = plot

        # Connect view change signal to update y range function
        view = plot.getViewBox()
        view.sigXRangeChanged.connect(self._update_y_range)
        view.setMouseEnabled(x=True, y=False)

        # Set right axis
        right_axis = plot.getAxis('right')
        right_axis.setWidth(60)
        right_axis.tickFont = NORMAL_FONT

        # Connect x-axis link
        if self._plots:
            first_plot = list(self._plots.values())[0]
            plot.setXLink(first_plot)

        # Store plot object in dict
        self._plots[plot_name] = plot

        # Add plot onto the layout
        self._layout.nextRow()
        self._layout.addItem(plot)

    def add_item(
        self,
        item_class: Type[ChartItem],
        item_name: str,
        plot_name: str
    ):
        """
        Create a chart item of ``item_class`` and add it to a plot.
        """
        item = item_class(self._manager)
        self._items[item_name] = item

        plot = self._plots.get(plot_name)
        plot.addItem(item)

        self._item_plot_map[item] = plot

    def get_plot(self, plot_name: str) -> pg.PlotItem:
        """
        Get specific plot with its name, or None if it does not exist.
        """
        return self._plots.get(plot_name, None)

    def get_all_plots(self) -> List[pg.PlotItem]:
        """
        Get all plot objects, in the order they were added.
        """
        return list(self._plots.values())

    def clear_all(self) -> None:
        """
        Clear all data.
        """
        self._manager.clear_all()

        for item in self._items.values():
            item.clear_all()

        if self._cursor:
            self._cursor.clear_all()

    def update_history(self, history: List[BarData]) -> None:
        """
        Update a list of bar data.
        """
        self._manager.update_history(history)

        for item in self._items.values():
            item.update_history(history)

        self._update_plot_limits()

        self.move_to_right()

    def update_bar(self, bar: BarData) -> None:
        """
        Update single bar data.

        The view follows the newest bar only when it is already showing
        the right half of the data.
        """
        self._manager.update_bar(bar)

        for item in self._items.values():
            item.update_bar(bar)

        self._update_plot_limits()

        if self._right_ix >= (self._manager.get_count() - self._bar_count / 2):
            self.move_to_right()

    def _update_plot_limits(self) -> None:
        """
        Update the limit of plots.
        """
        for item, plot in self._item_plot_map.items():
            min_value, max_value = item.get_y_range()

            plot.setLimits(
                xMin=-1,
                xMax=self._manager.get_count(),
                yMin=min_value,
                yMax=max_value
            )

    def _update_x_range(self) -> None:
        """
        Update the x-axis range of plots.
        """
        max_ix = self._right_ix
        min_ix = self._right_ix - self._bar_count

        for plot in self._plots.values():
            plot.setRange(xRange=(min_ix, max_ix), padding=0)

    def _update_y_range(self) -> None:
        """
        Update the y-axis range of plots from the visible x-range.
        """
        view = self._first_plot.getViewBox()
        view_range = view.viewRange()

        min_ix = max(0, int(view_range[0][0]))
        max_ix = min(self._manager.get_count(), int(view_range[0][1]))

        # Update limit for y-axis
        for item, plot in self._item_plot_map.items():
            y_range = item.get_y_range(min_ix, max_ix)
            plot.setRange(yRange=y_range)

    def paintEvent(self, event: QtGui.QPaintEvent) -> None:
        """
        Reimplement this method of parent to update current max_ix value.
        """
        # No plot added yet: nothing to read the view range from.
        if self._first_plot:
            view = self._first_plot.getViewBox()
            view_range = view.viewRange()
            self._right_ix = max(0, view_range[0][1])

        super().paintEvent(event)

    def keyPressEvent(self, event: QtGui.QKeyEvent) -> None:
        """
        Reimplement this method of parent to move chart horizontally and zoom in/out.
        """
        if event.key() == QtCore.Qt.Key_Left:
            self._on_key_left()
        elif event.key() == QtCore.Qt.Key_Right:
            self._on_key_right()
        elif event.key() == QtCore.Qt.Key_Up:
            self._on_key_up()
        elif event.key() == QtCore.Qt.Key_Down:
            self._on_key_down()

    def wheelEvent(self, event: QtGui.QWheelEvent) -> None:
        """
        Reimplement this method of parent to zoom in/out.
        """
        delta = event.angleDelta()

        if delta.y() > 0:
            self._on_key_up()
        elif delta.y() < 0:
            self._on_key_down()

    def _update_cursor_info(self) -> None:
        """
        Refresh the cursor info text, if a cursor has been added.
        """
        if self._cursor:
            self._cursor.update_info()

    def _on_key_left(self) -> None:
        """
        Move chart to left.
        """
        self._right_ix -= 1
        self._right_ix = max(self._right_ix, self._bar_count)

        self._update_x_range()
        if self._cursor:
            self._cursor.move_left()
        self._update_cursor_info()

    def _on_key_right(self) -> None:
        """
        Move chart to right.
        """
        self._right_ix += 1
        self._right_ix = min(self._right_ix, self._manager.get_count())

        self._update_x_range()
        if self._cursor:
            self._cursor.move_right()
        self._update_cursor_info()

    def _on_key_down(self) -> None:
        """
        Zoom out the chart.
        """
        self._bar_count *= 1.2
        self._bar_count = min(int(self._bar_count), self._manager.get_count())

        self._update_x_range()
        self._update_cursor_info()

    def _on_key_up(self) -> None:
        """
        Zoom in the chart.
        """
        self._bar_count /= 1.2
        self._bar_count = max(int(self._bar_count), self.MIN_BAR_COUNT)

        self._update_x_range()
        self._update_cursor_info()

    def move_to_right(self) -> None:
        """
        Move chart to the most right.
        """
        self._right_ix = self._manager.get_count()
        self._update_x_range()
        self._update_cursor_info()
class ChartCursor(QtCore.QObject):
    """
    Crosshair cursor for a ChartWidget.

    Shows a vertical line on every plot, a horizontal line and y-label on
    the plot under the mouse, a datetime label on the bottom plot and the
    items' info text in the top-left corner of each plot.
    """

    def __init__(
        self,
        widget: ChartWidget,
        manager: BarManager,
        plots: Dict[str, pg.GraphicsObject],
        item_plot_map: Dict[ChartItem, pg.GraphicsObject]
    ):
        """"""
        super().__init__()

        self._widget: ChartWidget = widget
        self._manager: BarManager = manager
        self._plots: Dict[str, pg.GraphicsObject] = plots
        self._item_plot_map: Dict[ChartItem, pg.GraphicsObject] = item_plot_map

        self._x: int = 0
        self._y: int = 0
        self._plot_name: str = ""

        self._init_ui()
        self._connect_signal()

    def _init_ui(self):
        """
        Create lines, axis labels and info texts.
        """
        self._init_line()
        self._init_label()
        self._init_info()

    def _init_line(self) -> None:
        """
        Create line objects.
        """
        self._v_lines: Dict[str, pg.InfiniteLine] = {}
        self._h_lines: Dict[str, pg.InfiniteLine] = {}
        self._views: Dict[str, pg.ViewBox] = {}

        pen = pg.mkPen(CURSOR_COLOR, style=QtCore.Qt.DashDotLine)

        for plot_name, plot in self._plots.items():
            v_line = pg.InfiniteLine(angle=90, movable=False, pen=pen)
            h_line = pg.InfiniteLine(angle=0, movable=False, pen=pen)
            view = plot.getViewBox()

            for line in [v_line, h_line]:
                line.setZValue(0)
                line.hide()
                view.addItem(line)

            self._v_lines[plot_name] = v_line
            self._h_lines[plot_name] = h_line
            self._views[plot_name] = view

    def _init_label(self) -> None:
        """
        Create label objects on axis.
        """
        self._y_labels: Dict[str, pg.TextItem] = {}
        for plot_name, plot in self._plots.items():
            label = pg.TextItem(
                plot_name, fill=GREY_COLOR, color=WHITE_COLOR)
            label.hide()
            label.setZValue(2)
            label.setFont(NORMAL_FONT)
            plot.addItem(label, ignoreBounds=True)
            self._y_labels[plot_name] = label

        self._x_label: pg.TextItem = pg.TextItem(
            "datetime", fill=GREY_COLOR, color=WHITE_COLOR)
        self._x_label.hide()
        self._x_label.setZValue(2)
        self._x_label.setFont(NORMAL_FONT)

        # The datetime label lives on the bottom (last added) plot.
        if self._plots:
            bottom_plot = list(self._plots.values())[-1]
            bottom_plot.addItem(self._x_label, ignoreBounds=True)

    def _init_info(self) -> None:
        """
        Create the per-plot info text objects.
        """
        self._infos: Dict[str, pg.TextItem] = {}
        for plot_name, plot in self._plots.items():
            info = pg.TextItem(
                "info",
                color=CURSOR_COLOR
            )
            info.hide()
            info.setZValue(2)
            info.setFont(NORMAL_FONT)
            plot.addItem(info, ignoreBounds=True)
            self._infos[plot_name] = info

    def _connect_signal(self) -> None:
        """
        Connect mouse move signal to update function.
        """
        self._widget.scene().sigMouseMoved.connect(self._mouse_moved)

    def _mouse_moved(self, evt: tuple) -> None:
        """
        Callback function when mouse is moved.

        ``evt`` is used directly as a scene position.
        """
        if not self._manager.get_count():
            return

        # First get current mouse point
        pos = evt

        for plot_name, view in self._views.items():
            rect = view.sceneBoundingRect()

            if rect.contains(pos):
                mouse_point = view.mapSceneToView(pos)
                self._x = to_int(mouse_point.x())
                self._y = mouse_point.y()
                self._plot_name = plot_name
                break

        # Then update cursor component
        self._update_line()
        self._update_label()
        self.update_info()

    def _update_line(self) -> None:
        """
        Show the vertical line everywhere and the horizontal line only on
        the active plot.
        """
        for v_line in self._v_lines.values():
            v_line.setPos(self._x)
            v_line.show()

        for plot_name, h_line in self._h_lines.items():
            if plot_name == self._plot_name:
                h_line.setPos(self._y)
                h_line.show()
            else:
                h_line.hide()

    def _update_label(self) -> None:
        """
        Position the y-label on the active plot and the datetime label on
        the bottom plot.
        """
        bottom_plot = list(self._plots.values())[-1]
        axis_width = bottom_plot.getAxis("right").width()
        axis_height = bottom_plot.getAxis("bottom").height()
        axis_offset = QtCore.QPointF(axis_width, axis_height)

        bottom_view = list(self._views.values())[-1]
        bottom_right = bottom_view.mapSceneToView(
            bottom_view.sceneBoundingRect().bottomRight() - axis_offset
        )

        for plot_name, label in self._y_labels.items():
            if plot_name == self._plot_name:
                label.setText(str(self._y))
                label.show()
                label.setPos(bottom_right.x(), self._y)
            else:
                label.hide()

        dt = self._manager.get_datetime(self._x)
        if dt:
            self._x_label.setText(dt.strftime("%Y-%m-%d %H:%M:%S"))
            self._x_label.show()
            self._x_label.setPos(self._x, bottom_right.y())
            self._x_label.setAnchor((0, 0))

    def update_info(self) -> None:
        """
        Refresh each plot's info text from its items at the cursor index.

        Non-empty item texts are joined with a blank line; a plot without
        items (or without text) shows an empty info text.
        """
        texts: Dict[pg.GraphicsObject, List[str]] = {}
        for item, plot in self._item_plot_map.items():
            item_info_text = item.get_info_text(self._x)
            bucket = texts.setdefault(plot, [])
            if item_info_text:
                bucket.append(item_info_text)

        for plot_name, plot in self._plots.items():
            info = self._infos.get(plot_name)
            # Plots added after this cursor was created have no info text.
            if info is None:
                continue

            info.setText("\n\n".join(texts.get(plot, [])))
            info.show()

            view = self._views[plot_name]
            top_left = view.mapSceneToView(view.sceneBoundingRect().topLeft())
            info.setPos(top_left)

    def move_right(self) -> None:
        """
        Move cursor index to right by 1, stopping at the last bar.
        """
        if self._x >= self._manager.get_count() - 1:
            return
        self._x += 1

        self._update_after_move()

    def move_left(self) -> None:
        """
        Move cursor index to left by 1, stopping at the first bar.
        """
        if self._x <= 0:
            return
        self._x -= 1

        self._update_after_move()

    def _update_after_move(self) -> None:
        """
        Update cursor after moved by left/right.

        The horizontal position follows the close price of the bar under
        the cursor, when that bar exists.
        """
        bar = self._manager.get_bar(self._x)
        if bar:
            self._y = bar.close_price

        self._update_line()
        self._update_label()

    def clear_all(self) -> None:
        """
        Clear all data.
        """
        self._x = 0
        self._y = 0
        self._plot_name = ""

        for line in list(self._v_lines.values()) + list(self._h_lines.values()):
            line.hide()

        for label in list(self._y_labels.values()) + [self._x_label]:
            label.hide()
class NewChartWidget(ChartWidget):
    """
    ChartWidget bound to one CTA strategy through the event engine.

    History bars, live bars and ticks whose strategy name matches
    ``strategy_name`` are pushed into the chart, and an optional horizontal
    last-price line follows the latest price.
    """

    MIN_BAR_COUNT = 100
    strategy_name: str = ""

    signal_cta_history_bar: QtCore.pyqtSignal = QtCore.pyqtSignal(Event)
    signal_cta_tick: QtCore.pyqtSignal = QtCore.pyqtSignal(Event)
    signal_cta_bar: QtCore.pyqtSignal = QtCore.pyqtSignal(Event)

    def __init__(self, parent: QtWidgets.QWidget = None, event_engine: EventEngine = None, strategy_name: str = ""):
        """"""
        super().__init__(parent)

        self.strategy_name = strategy_name
        self.event_engine = event_engine
        self.last_price_line: pg.InfiniteLine = None

    def register_event(self) -> None:
        """
        Route CTA history-bar, tick and bar events to the handlers below.

        Events are re-emitted through Qt signals (presumably so the handlers
        run in the GUI thread rather than the event engine's thread).
        """
        self.signal_cta_history_bar.connect(self.process_cta_history_bar)
        self.event_engine.register(EVENT_CTA_HISTORY_BAR, self.signal_cta_history_bar.emit)

        self.signal_cta_tick.connect(self.process_tick_event)
        self.event_engine.register(EVENT_CTA_TICK, self.signal_cta_tick.emit)

        self.signal_cta_bar.connect(self.process_cta_bar)
        self.event_engine.register(EVENT_CTA_BAR, self.signal_cta_bar.emit)

    def process_cta_history_bar(self, event: Event) -> None:
        """
        Handle pushed history bars; event.data is (strategy_name, bars).
        """
        strategy_name, history_bars = event.data
        if strategy_name == self.strategy_name:
            self.update_history(history_bars)

    def process_tick_event(self, event: Event) -> None:
        """
        Handle a pushed tick; event.data is (strategy_name, tick).

        Only the last-price line is moved; bars are not modified.
        """
        strategy_name, tick = event.data
        if strategy_name == self.strategy_name:
            if self.last_price_line:
                self.last_price_line.setValue(tick.last_price)

    def process_cta_bar(self, event: Event) -> None:
        """
        Handle a pushed bar; event.data is (strategy_name, bar).
        """
        strategy_name, bar = event.data
        if strategy_name == self.strategy_name:
            self.update_bar(bar)

    def add_last_price_line(self):
        """
        Add a yellow dash-dot horizontal line with a value label to the
        first plot.
        """
        plot = list(self._plots.values())[0]
        color = (255, 255, 0)

        self.last_price_line = pg.InfiniteLine(
            angle=0,
            movable=False,
            label="{value}",
            pen=pg.mkPen(color, width=1, style=QtCore.Qt.DashDotLine),
            labelOpts={
                "color": color,
                "position": 1,
                "anchors": [(1, 1), (1, 1)]
            }
        )
        self.last_price_line.label.setFont(NORMAL_FONT)
        plot.addItem(self.last_price_line)

    def update_history(self, history: List[BarData]) -> None:
        """
        Update a list of bar data and move the last-price line to the
        close of the newest bar (if any).
        """
        super().update_history(history)

        # An empty history has no newest bar to take the price from.
        if history:
            self.update_last_price_line(history[-1])

    def update_bar(self, bar: BarData) -> None:
        """
        Update single bar data and move the last-price line to its close.
        """
        super().update_bar(bar)
        self.update_last_price_line(bar)

    def update_last_price_line(self, bar: BarData) -> None:
        """
        Move the last-price line (if added) to ``bar.close_price``.
        """
        if self.last_price_line:
            self.last_price_line.setValue(bar.close_price)
4. vnpy/app/cta_backtester/ui/widget.py
import csv
from datetime import datetime, timedelta
from copy import copy
import numpy as np
import pyqtgraph as pg
from vnpy.trader.constant import Interval, Direction, Exchange
from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import QtCore, QtWidgets, QtGui
from vnpy.trader.ui.widget import BaseMonitor, BaseCell, DirectionCell, EnumCell
from vnpy.trader.ui.editor import CodeEditor
from vnpy.event import Event, EventEngine
from vnpy.chart import NewChartWidget, CandleItem, VolumeItem, LineItem, SmaItem, RsiItem, MacdItem
from vnpy.trader.utility import load_json, save_json
from vnpy.trader.database import DB_TZ
from ..engine import (
APP_NAME,
EVENT_BACKTESTER_LOG,
EVENT_BACKTESTER_BACKTESTING_FINISHED,
EVENT_BACKTESTER_OPTIMIZATION_FINISHED,
OptimizationSetting
)
class BacktesterManager(QtWidgets.QWidget):
    """
    Main widget of the CTA backtester app.

    Collects backtesting inputs from a form, forwards them to the
    backtester engine and shows statistics, charts and result dialogs
    when the engine reports completion through events.
    """

    setting_filename = "cta_backtester_setting.json"

    # Engine events are re-emitted as Qt signals so the slots run through
    # Qt's signal delivery (presumably to get back onto the GUI thread
    # from the event engine's thread - confirm against EventEngine).
    signal_log = QtCore.pyqtSignal(Event)
    signal_backtesting_finished = QtCore.pyqtSignal(Event)
    signal_optimization_finished = QtCore.pyqtSignal(Event)

    def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
        """
        Build the UI, hook up engine events and restore the last setting.
        """
        super().__init__()

        self.main_engine = main_engine
        self.event_engine = event_engine

        self.backtester_engine = main_engine.get_engine(APP_NAME)
        self.class_names = []
        self.settings = {}

        self.target_display = ""

        self.init_ui()
        self.register_event()
        self.backtester_engine.init_engine()
        self.init_strategy_settings()
        self.load_backtesting_setting()

    def init_strategy_settings(self):
        """
        Load strategy class names and their default settings into the form.
        """
        self.class_names = self.backtester_engine.get_strategy_class_names()

        for class_name in self.class_names:
            setting = self.backtester_engine.get_default_setting(class_name)
            self.settings[class_name] = setting

        self.class_combo.addItems(self.class_names)

    def init_ui(self):
        """
        Create the input form, action buttons, result widgets and dialogs.
        """
        self.setWindowTitle("CTA回测")

        # Setting Part
        self.class_combo = QtWidgets.QComboBox()

        self.symbol_line = QtWidgets.QLineEdit("IF88.CFFEX")

        self.interval_combo = QtWidgets.QComboBox()
        for interval in Interval:
            self.interval_combo.addItem(interval.value)

        end_dt = datetime.now()
        start_dt = end_dt - timedelta(days=3 * 365)

        self.start_date_edit = QtWidgets.QDateEdit(
            QtCore.QDate(
                start_dt.year,
                start_dt.month,
                start_dt.day
            )
        )
        self.end_date_edit = QtWidgets.QDateEdit(
            QtCore.QDate.currentDate()
        )

        self.rate_line = QtWidgets.QLineEdit("0.000025")
        self.slippage_line = QtWidgets.QLineEdit("0.2")
        self.size_line = QtWidgets.QLineEdit("300")
        self.pricetick_line = QtWidgets.QLineEdit("0.2")
        self.capital_line = QtWidgets.QLineEdit("1000000")

        self.inverse_combo = QtWidgets.QComboBox()
        self.inverse_combo.addItems(["正向", "反向"])

        backtesting_button = QtWidgets.QPushButton("开始回测")
        backtesting_button.clicked.connect(self.start_backtesting)

        optimization_button = QtWidgets.QPushButton("参数优化")
        optimization_button.clicked.connect(self.start_optimization)

        self.result_button = QtWidgets.QPushButton("优化结果")
        self.result_button.clicked.connect(self.show_optimization_result)
        self.result_button.setEnabled(False)

        downloading_button = QtWidgets.QPushButton("下载数据")
        downloading_button.clicked.connect(self.start_downloading)

        self.order_button = QtWidgets.QPushButton("委托记录")
        self.order_button.clicked.connect(self.show_backtesting_orders)
        self.order_button.setEnabled(False)

        self.trade_button = QtWidgets.QPushButton("成交记录")
        self.trade_button.clicked.connect(self.show_backtesting_trades)
        self.trade_button.setEnabled(False)

        self.daily_button = QtWidgets.QPushButton("每日盈亏")
        self.daily_button.clicked.connect(self.show_daily_results)
        self.daily_button.setEnabled(False)

        self.candle_button = QtWidgets.QPushButton("K线图表")
        self.candle_button.clicked.connect(self.show_candle_chart)
        self.candle_button.setEnabled(False)

        edit_button = QtWidgets.QPushButton("代码编辑")
        edit_button.clicked.connect(self.edit_strategy_code)

        reload_button = QtWidgets.QPushButton("策略重载")
        reload_button.clicked.connect(self.reload_strategy_class)

        for button in [
            backtesting_button,
            optimization_button,
            downloading_button,
            self.result_button,
            self.order_button,
            self.trade_button,
            self.daily_button,
            self.candle_button,
            edit_button,
            reload_button
        ]:
            button.setFixedHeight(button.sizeHint().height() * 2)

        form = QtWidgets.QFormLayout()
        form.addRow("交易策略", self.class_combo)
        form.addRow("本地代码", self.symbol_line)
        form.addRow("K线周期", self.interval_combo)
        form.addRow("开始日期", self.start_date_edit)
        form.addRow("结束日期", self.end_date_edit)
        form.addRow("手续费率", self.rate_line)
        form.addRow("交易滑点", self.slippage_line)
        form.addRow("合约乘数", self.size_line)
        form.addRow("价格跳动", self.pricetick_line)
        form.addRow("回测资金", self.capital_line)
        form.addRow("合约模式", self.inverse_combo)

        result_grid = QtWidgets.QGridLayout()
        result_grid.addWidget(self.trade_button, 0, 0)
        result_grid.addWidget(self.order_button, 0, 1)
        result_grid.addWidget(self.daily_button, 1, 0)
        result_grid.addWidget(self.candle_button, 1, 1)

        left_vbox = QtWidgets.QVBoxLayout()
        left_vbox.addLayout(form)
        left_vbox.addWidget(backtesting_button)
        left_vbox.addWidget(downloading_button)
        left_vbox.addStretch()
        left_vbox.addLayout(result_grid)
        left_vbox.addStretch()
        left_vbox.addWidget(optimization_button)
        left_vbox.addWidget(self.result_button)
        left_vbox.addStretch()
        left_vbox.addWidget(edit_button)
        left_vbox.addWidget(reload_button)

        # Result part
        self.statistics_monitor = StatisticsMonitor()

        self.log_monitor = QtWidgets.QTextEdit()
        self.log_monitor.setMaximumHeight(400)

        self.chart = BacktesterChart()
        self.chart.setMinimumWidth(1000)

        self.trade_dialog = BacktestingResultDialog(
            self.main_engine,
            self.event_engine,
            "回测成交记录",
            BacktestingTradeMonitor
        )
        self.order_dialog = BacktestingResultDialog(
            self.main_engine,
            self.event_engine,
            "回测委托记录",
            BacktestingOrderMonitor
        )
        self.daily_dialog = BacktestingResultDialog(
            self.main_engine,
            self.event_engine,
            "回测每日盈亏",
            DailyResultMonitor
        )

        # Candle Chart
        self.candle_dialog = CandleChartDialog()

        # Layout
        vbox = QtWidgets.QVBoxLayout()
        vbox.addWidget(self.statistics_monitor)
        vbox.addWidget(self.log_monitor)

        hbox = QtWidgets.QHBoxLayout()
        hbox.addLayout(left_vbox)
        hbox.addLayout(vbox)
        hbox.addWidget(self.chart)
        self.setLayout(hbox)

        # Code Editor
        self.editor = CodeEditor(self.main_engine, self.event_engine)

    def load_backtesting_setting(self):
        """
        Restore the form from the last saved backtesting setting file.

        Keys missing from the file (e.g. one written by an older version)
        leave the corresponding widget unchanged instead of raising
        KeyError during construction.
        """
        setting = load_json(self.setting_filename)
        if not setting:
            return

        if "class_name" in setting:
            self.class_combo.setCurrentIndex(
                self.class_combo.findText(setting["class_name"])
            )

        if "vt_symbol" in setting:
            self.symbol_line.setText(setting["vt_symbol"])

        if "interval" in setting:
            self.interval_combo.setCurrentIndex(
                self.interval_combo.findText(setting["interval"])
            )

        start_str = setting.get("start", "")
        if start_str:
            # start is saved via datetime.isoformat() ("YYYY-MM-DDTHH:MM:SS"),
            # which does not match the "yyyy-MM-dd" format; keep only the
            # date part so the parsed QDate is valid.
            date_str = start_str.split("T")[0]
            start_dt = QtCore.QDate.fromString(date_str, "yyyy-MM-dd")
            self.start_date_edit.setDate(start_dt)

        line_edits = {
            "rate": self.rate_line,
            "slippage": self.slippage_line,
            "size": self.size_line,
            "pricetick": self.pricetick_line,
            "capital": self.capital_line,
        }
        for key, line_edit in line_edits.items():
            if key in setting:
                line_edit.setText(str(setting[key]))

        if not setting.get("inverse", False):
            self.inverse_combo.setCurrentIndex(0)
        else:
            self.inverse_combo.setCurrentIndex(1)

    def register_event(self):
        """
        Route backtester engine events to the matching Qt signals and slots.
        """
        self.signal_log.connect(self.process_log_event)
        self.signal_backtesting_finished.connect(
            self.process_backtesting_finished_event)
        self.signal_optimization_finished.connect(
            self.process_optimization_finished_event)

        self.event_engine.register(EVENT_BACKTESTER_LOG, self.signal_log.emit)
        self.event_engine.register(
            EVENT_BACKTESTER_BACKTESTING_FINISHED, self.signal_backtesting_finished.emit)
        self.event_engine.register(
            EVENT_BACKTESTER_OPTIMIZATION_FINISHED, self.signal_optimization_finished.emit)

    def process_log_event(self, event: Event):
        """
        Write the log message carried by event.data.
        """
        msg = event.data
        self.write_log(msg)

    def write_log(self, msg):
        """
        Append msg to the log monitor, prefixed with the current time.
        """
        timestamp = datetime.now().strftime("%H:%M:%S")
        msg = f"{timestamp}\t{msg}"
        self.log_monitor.append(msg)

    def process_backtesting_finished_event(self, event: Event):
        """
        Show statistics and chart of the finished backtest and enable result buttons.
        """
        statistics = self.backtester_engine.get_result_statistics()
        self.statistics_monitor.set_data(statistics)

        df = self.backtester_engine.get_result_df()
        self.chart.set_data(df)

        self.trade_button.setEnabled(True)
        self.order_button.setEnabled(True)
        self.daily_button.setEnabled(True)

        # Tick data can not be displayed using candle chart
        interval = self.interval_combo.currentText()
        if interval != Interval.TICK.value:
            self.candle_button.setEnabled(True)

    def process_optimization_finished_event(self, event: Event):
        """
        Tell the user the optimization result is ready and enable its button.
        """
        self.write_log("请点击[优化结果]按钮查看")
        self.result_button.setEnabled(True)

    def _check_vt_symbol(self, vt_symbol: str) -> bool:
        """
        Validate the exchange suffix of vt_symbol, logging any problem.

        Only the text after the last "." is treated as the exchange, so a
        symbol containing extra dots no longer raises ValueError while
        unpacking. Returns True when the suffix names an Exchange member.
        """
        if "." not in vt_symbol:
            self.write_log("本地代码缺失交易所后缀,请检查")
            return False

        _, exchange_str = vt_symbol.rsplit(".", 1)
        if exchange_str not in Exchange.__members__:
            self.write_log("本地代码的交易所后缀不正确,请检查")
            return False

        return True

    def _get_backtesting_inputs(self):
        """
        Read and validate the form inputs shared by backtesting and optimization.

        Returns a tuple (class_name, vt_symbol, interval, start, end, rate,
        slippage, size, pricetick, capital, inverse), or None after writing
        a log message when vt_symbol or a numeric field is invalid.
        """
        class_name = self.class_combo.currentText()
        vt_symbol = self.symbol_line.text()
        interval = self.interval_combo.currentText()
        start = self.start_date_edit.dateTime().toPyDateTime()
        end = self.end_date_edit.dateTime().toPyDateTime()

        if not self._check_vt_symbol(vt_symbol):
            return None

        # float() raises ValueError on empty or malformed text; report it
        # instead of letting the exception escape the Qt slot.
        try:
            rate = float(self.rate_line.text())
            slippage = float(self.slippage_line.text())
            size = float(self.size_line.text())
            pricetick = float(self.pricetick_line.text())
            capital = float(self.capital_line.text())
        except ValueError:
            self.write_log("回测参数数值格式不正确,请检查")
            return None

        inverse = self.inverse_combo.currentText() != "正向"

        return (
            class_name, vt_symbol, interval, start, end,
            rate, slippage, size, pricetick, capital, inverse
        )

    def start_backtesting(self):
        """
        Validate and save the inputs, edit strategy parameters, then start a backtest.
        """
        inputs = self._get_backtesting_inputs()
        if inputs is None:
            return

        (
            class_name, vt_symbol, interval, start, end,
            rate, slippage, size, pricetick, capital, inverse
        ) = inputs

        # Save backtesting parameters
        backtesting_setting = {
            "class_name": class_name,
            "vt_symbol": vt_symbol,
            "interval": interval,
            "start": start.isoformat(),
            "rate": rate,
            "slippage": slippage,
            "size": size,
            "pricetick": pricetick,
            "capital": capital,
            "inverse": inverse,
        }
        save_json(self.setting_filename, backtesting_setting)

        # Get strategy setting
        old_setting = self.settings[class_name]
        dialog = BacktestingSettingEditor(class_name, old_setting)
        i = dialog.exec()
        if i != dialog.Accepted:
            return

        new_setting = dialog.get_setting()
        self.settings[class_name] = new_setting

        result = self.backtester_engine.start_backtesting(
            class_name,
            vt_symbol,
            interval,
            start,
            end,
            rate,
            slippage,
            size,
            pricetick,
            capital,
            inverse,
            new_setting
        )

        # Reset previous results once the engine has accepted the new run.
        if result:
            self.statistics_monitor.clear_data()
            self.chart.clear_data()

            self.trade_button.setEnabled(False)
            self.order_button.setEnabled(False)
            self.daily_button.setEnabled(False)
            self.candle_button.setEnabled(False)

            self.trade_dialog.clear_data()
            self.order_dialog.clear_data()
            self.daily_dialog.clear_data()
            self.candle_dialog.clear_data()

    def start_optimization(self):
        """
        Validate the inputs, edit optimization ranges, then start an optimization.

        Performs the same vt_symbol and numeric checks as start_backtesting.
        """
        inputs = self._get_backtesting_inputs()
        if inputs is None:
            return

        (
            class_name, vt_symbol, interval, start, end,
            rate, slippage, size, pricetick, capital, inverse
        ) = inputs

        parameters = self.settings[class_name]
        dialog = OptimizationSettingEditor(class_name, parameters)
        i = dialog.exec()
        if i != dialog.Accepted:
            return

        optimization_setting, use_ga = dialog.get_setting()
        self.target_display = dialog.target_display

        self.backtester_engine.start_optimization(
            class_name,
            vt_symbol,
            interval,
            start,
            end,
            rate,
            slippage,
            size,
            pricetick,
            capital,
            inverse,
            optimization_setting,
            use_ga
        )

        self.result_button.setEnabled(False)

    def start_downloading(self):
        """
        Download history data for the selected symbol and date range.

        The range runs from 00:00:00 of the start date to 23:59:59 of the
        end date, localized to DB_TZ.
        """
        vt_symbol = self.symbol_line.text()
        if not self._check_vt_symbol(vt_symbol):
            return

        interval = self.interval_combo.currentText()
        start_date = self.start_date_edit.date()
        end_date = self.end_date_edit.date()

        start = datetime(
            start_date.year(),
            start_date.month(),
            start_date.day(),
        )
        start = DB_TZ.localize(start)

        end = datetime(
            end_date.year(),
            end_date.month(),
            end_date.day(),
            23,
            59,
            59,
        )
        end = DB_TZ.localize(end)

        self.backtester_engine.start_downloading(
            vt_symbol,
            interval,
            start,
            end
        )

    def show_optimization_result(self):
        """
        Open a dialog listing the optimization results.
        """
        result_values = self.backtester_engine.get_result_values()

        dialog = OptimizationResultMonitor(
            result_values,
            self.target_display
        )
        dialog.exec_()

    def show_backtesting_trades(self):
        """
        Show backtesting trades, loading them into the dialog on first use.
        """
        if not self.trade_dialog.is_updated():
            trades = self.backtester_engine.get_all_trades()
            self.trade_dialog.update_data(trades)

        self.trade_dialog.exec_()

    def show_backtesting_orders(self):
        """
        Show backtesting orders, loading them into the dialog on first use.
        """
        if not self.order_dialog.is_updated():
            orders = self.backtester_engine.get_all_orders()
            self.order_dialog.update_data(orders)

        self.order_dialog.exec_()

    def show_daily_results(self):
        """
        Show daily results, loading them into the dialog on first use.
        """
        if not self.daily_dialog.is_updated():
            results = self.backtester_engine.get_all_daily_results()
            self.daily_dialog.update_data(results)

        self.daily_dialog.exec_()

    def show_candle_chart(self):
        """
        Show the candle chart, loading history and trades on first use.
        """
        if not self.candle_dialog.is_updated():
            history = self.backtester_engine.get_history_data()
            self.candle_dialog.update_history(history)

            trades = self.backtester_engine.get_all_trades()
            self.candle_dialog.update_trades(trades)

        self.candle_dialog.exec_()

    def edit_strategy_code(self):
        """
        Open the selected strategy's source file in the code editor.
        """
        class_name = self.class_combo.currentText()
        file_path = self.backtester_engine.get_strategy_class_file(class_name)

        self.editor.open_editor(file_path)
        self.editor.show()

    def reload_strategy_class(self):
        """
        Reload strategy classes and keep the current selection if it still exists.
        """
        self.backtester_engine.reload_strategy_class()

        current_strategy_name = self.class_combo.currentText()

        self.class_combo.clear()
        self.init_strategy_settings()

        ix = self.class_combo.findText(current_strategy_name)
        self.class_combo.setCurrentIndex(ix)

    def show(self):
        """
        Show the widget maximized.
        """
        self.showMaximized()
class StatisticsMonitor(QtWidgets.QTableWidget):
    """
    Single-column table showing the statistics of a backtesting run.
    """

    KEY_NAME_MAP = {
        "start_date": "首个交易日",
        "end_date": "最后交易日",

        "total_days": "总交易日",
        "profit_days": "盈利交易日",
        "loss_days": "亏损交易日",

        "capital": "起始资金",
        "end_balance": "结束资金",

        "total_return": "总收益率",
        "annual_return": "年化收益",
        "max_drawdown": "最大回撤",
        "max_ddpercent": "百分比最大回撤",

        "total_net_pnl": "总盈亏",
        "total_commission": "总手续费",
        "total_slippage": "总滑点",
        "total_turnover": "总成交额",
        "total_trade_count": "总成交笔数",

        "daily_net_pnl": "日均盈亏",
        "daily_commission": "日均手续费",
        "daily_slippage": "日均滑点",
        "daily_turnover": "日均成交额",
        "daily_trade_count": "日均成交笔数",

        "daily_return": "日均收益率",
        "return_std": "收益标准差",
        "sharpe_ratio": "夏普比率",
        "return_drawdown_ratio": "收益回撤比"
    }

    # Display format of numeric statistics; keys not listed are shown
    # with str() as they are.
    VALUE_FORMAT_MAP = {
        "capital": "{:,.2f}",
        "end_balance": "{:,.2f}",
        "total_return": "{:,.2f}%",
        "annual_return": "{:,.2f}%",
        "max_drawdown": "{:,.2f}",
        "max_ddpercent": "{:,.2f}%",
        "total_net_pnl": "{:,.2f}",
        "total_commission": "{:,.2f}",
        "total_slippage": "{:,.2f}",
        "total_turnover": "{:,.2f}",
        "daily_net_pnl": "{:,.2f}",
        "daily_commission": "{:,.2f}",
        "daily_slippage": "{:,.2f}",
        "daily_turnover": "{:,.2f}",
        "daily_trade_count": "{:,.2f}",
        "daily_return": "{:,.2f}%",
        "return_std": "{:,.2f}%",
        "sharpe_ratio": "{:,.2f}",
        "return_drawdown_ratio": "{:,.2f}",
    }

    def __init__(self):
        """
        Create the table with one empty cell per statistic.
        """
        super().__init__()

        self.cells = {}

        self.init_ui()

    def init_ui(self):
        """
        Lay out one read-only row per entry of KEY_NAME_MAP.
        """
        self.setRowCount(len(self.KEY_NAME_MAP))
        self.setVerticalHeaderLabels(list(self.KEY_NAME_MAP.values()))

        self.setColumnCount(1)
        self.horizontalHeader().setVisible(False)
        self.horizontalHeader().setSectionResizeMode(
            QtWidgets.QHeaderView.Stretch
        )
        self.setEditTriggers(self.NoEditTriggers)

        for row, key in enumerate(self.KEY_NAME_MAP.keys()):
            cell = QtWidgets.QTableWidgetItem()
            self.setItem(row, 0, cell)
            self.cells[key] = cell

    def clear_data(self):
        """
        Empty the text of every cell.
        """
        for cell in self.cells.values():
            cell.setText("")

    def set_data(self, data: dict):
        """
        Display the statistics in data.

        The caller's dict is left untouched; formatted values are computed
        locally. Missing keys are shown as empty cells, and values that
        cannot take their numeric format are shown with str().
        """
        for key, cell in self.cells.items():
            value = data.get(key, "")

            value_format = self.VALUE_FORMAT_MAP.get(key)
            if value_format is not None:
                try:
                    value = value_format.format(value)
                except (TypeError, ValueError):
                    # e.g. "" for a missing key or None; keep raw value.
                    pass

            cell.setText(str(value))
class BacktestingSettingEditor(QtWidgets.QDialog):
    """
    Dialog for editing the parameters of a strategy class before a backtest.

    Each parameter gets a line edit; get_setting() converts the texts back
    to the types of the values originally supplied.
    """

    def __init__(
        self, class_name: str, parameters: dict
    ):
        """
        Keep the class name and its current parameter values, then build the UI.
        """
        super(BacktestingSettingEditor, self).__init__()

        self.class_name = class_name
        self.parameters = parameters
        # name -> (line edit, original value type)
        self.edits = {}

        self.init_ui()

    def init_ui(self):
        """
        Build one line edit per parameter inside a scrollable form.
        """
        self.setWindowTitle(f"策略参数配置:{self.class_name}")

        form = QtWidgets.QFormLayout()

        for param_name, param_value in self.parameters.items():
            param_type = type(param_value)
            line_edit = QtWidgets.QLineEdit(str(param_value))

            # Restrict numeric parameters to input of the matching kind.
            if param_type is int:
                line_edit.setValidator(QtGui.QIntValidator())
            elif param_type is float:
                line_edit.setValidator(QtGui.QDoubleValidator())

            form.addRow(f"{param_name} {param_type}", line_edit)
            self.edits[param_name] = (line_edit, param_type)

        confirm_button = QtWidgets.QPushButton("确定")
        confirm_button.clicked.connect(self.accept)
        form.addRow(confirm_button)

        container = QtWidgets.QWidget()
        container.setLayout(form)

        scroll_area = QtWidgets.QScrollArea()
        scroll_area.setWidgetResizable(True)
        scroll_area.setWidget(container)

        layout = QtWidgets.QVBoxLayout()
        layout.addWidget(scroll_area)
        self.setLayout(layout)

    def get_setting(self):
        """
        Return a dict of the edited values converted to their original types.

        A bool parameter is True only when its text is exactly "True".
        """
        setting = {}

        for param_name, (line_edit, param_type) in self.edits.items():
            text = line_edit.text()

            if param_type == bool:
                setting[param_name] = text == "True"
            else:
                setting[param_name] = param_type(text)

        return setting
class BacktesterChart(pg.GraphicsWindow):
    """
    Stacked plots summarising a backtest: account balance, drawdown,
    daily pnl bars and a histogram of daily pnl.
    """

    def __init__(self):
        """
        Create the window and all plots.
        """
        super().__init__(title="Backtester Chart")

        # Shared by reference with every DateAxis; set_data refills it in
        # place so the axes pick up the new labels.
        self.dates = {}

        self.init_ui()

    def init_ui(self):
        """
        Create the four plots (one per row) and their curves and bar items.
        """
        pg.setConfigOptions(antialias=True)

        def add_dated_plot(title):
            # Plot whose bottom axis shows dates from self.dates.
            return self.addPlot(
                title=title,
                axisItems={"bottom": DateAxis(self.dates, orientation="bottom")}
            )

        self.balance_plot = add_dated_plot("账户净值")
        self.nextRow()

        self.drawdown_plot = add_dated_plot("净值回撤")
        self.nextRow()

        self.pnl_plot = add_dated_plot("每日盈亏")
        self.nextRow()

        self.distribution_plot = self.addPlot(title="盈亏分布")

        self.balance_curve = self.balance_plot.plot(
            pen=pg.mkPen("#ffc107", width=3)
        )

        drawdown_color = "#303f9f"
        self.drawdown_curve = self.drawdown_plot.plot(
            fillLevel=-0.3, brush=drawdown_color, pen=drawdown_color
        )

        # Red bars for profitable days, green bars for losing days.
        self.profit_pnl_bar = pg.BarGraphItem(
            x=[], height=[], width=0.3, brush="r", pen="r"
        )
        self.loss_pnl_bar = pg.BarGraphItem(
            x=[], height=[], width=0.3, brush="g", pen="g"
        )
        self.pnl_plot.addItem(self.profit_pnl_bar)
        self.pnl_plot.addItem(self.loss_pnl_bar)

        histogram_color = "#6d4c41"
        self.distribution_curve = self.distribution_plot.plot(
            fillLevel=-0.3, brush=histogram_color, pen=histogram_color
        )

    def clear_data(self):
        """
        Remove all data from the curves and bar items.
        """
        for curve in (self.balance_curve, self.drawdown_curve, self.distribution_curve):
            curve.setData([], [])

        for bar_item in (self.profit_pnl_bar, self.loss_pnl_bar):
            bar_item.setOpts(x=[], height=[])

    def set_data(self, df):
        """
        Plot a backtest result DataFrame.

        df needs "balance", "drawdown" and "net_pnl" columns; its index
        provides the date labels. Nothing happens when df is None.
        """
        if df is None:
            return

        self.dates.clear()
        self.dates.update(enumerate(df.index))

        self.balance_curve.setData(df["balance"])
        self.drawdown_curve.setData(df["drawdown"])

        # Split daily pnl into non-negative and negative bars by position.
        profit_x, profit_height = [], []
        loss_x, loss_height = [], []
        for position, pnl in enumerate(df["net_pnl"]):
            if pnl >= 0:
                profit_x.append(position)
                profit_height.append(pnl)
            else:
                loss_x.append(position)
                loss_height.append(pnl)

        self.profit_pnl_bar.setOpts(x=profit_x, height=profit_height)
        self.loss_pnl_bar.setOpts(x=loss_x, height=loss_height)

        # Plot histogram counts against the left edge of each bin.
        counts, bin_edges = np.histogram(df["net_pnl"], bins="auto")
        self.distribution_curve.setData(bin_edges[:-1], counts)
class DateAxis(pg.AxisItem):
    """Axis item that labels x positions with dates from a lookup dict."""

    def __init__(self, dates: dict, *args, **kwargs):
        """
        Keep a reference to dates (x position -> date); later changes made
        to that dict by its owner are reflected in the tick labels.
        """
        super().__init__(*args, **kwargs)
        self.dates = dates

    def tickStrings(self, values, scale, spacing):
        """
        Return the date label for each tick value, or "" if none is known.
        """
        return [str(self.dates.get(value, "")) for value in values]
class OptimizationSettingEditor(QtWidgets.QDialog):
    """
    For setting up parameters for optimization.

    Shows a start/step/end row for every int or float strategy parameter
    and lets the user pick the optimization target and the algorithm
    (multi-process or genetic).
    """

    DISPLAY_NAME_MAP = {
        "总收益率": "total_return",
        "夏普比率": "sharpe_ratio",
        "收益回撤比": "return_drawdown_ratio",
        "日均盈亏": "daily_net_pnl"
    }

    def __init__(
        self, class_name: str, parameters: dict
    ):
        """
        Keep the class name and its current parameter values, then build the UI.
        """
        super().__init__()

        self.class_name = class_name
        self.parameters = parameters
        self.edits = {}

        self.optimization_setting = None
        self.use_ga = False
        # Display name of the chosen target, set by generate_setting.
        self.target_display = ""

        self.init_ui()

    def init_ui(self):
        """
        Build the target selector, the parameter range grid and the run buttons.
        """
        QLabel = QtWidgets.QLabel

        self.target_combo = QtWidgets.QComboBox()
        self.target_combo.addItems(list(self.DISPLAY_NAME_MAP.keys()))

        grid = QtWidgets.QGridLayout()
        grid.addWidget(QLabel("目标"), 0, 0)
        grid.addWidget(self.target_combo, 0, 1, 1, 3)
        grid.addWidget(QLabel("参数"), 1, 0)
        grid.addWidget(QLabel("开始"), 1, 1)
        grid.addWidget(QLabel("步进"), 1, 2)
        grid.addWidget(QLabel("结束"), 1, 3)

        self.setWindowTitle(f"优化参数配置:{self.class_name}")

        # Int parameters get an int validator so texts like "1.5" cannot be
        # entered and later fail in int() inside generate_setting.
        int_validator = QtGui.QIntValidator()
        float_validator = QtGui.QDoubleValidator()

        row = 2
        for name, value in self.parameters.items():
            type_ = type(value)
            if type_ is int:
                validator = int_validator
            elif type_ is float:
                validator = float_validator
            else:
                continue

            start_edit = QtWidgets.QLineEdit(str(value))
            step_edit = QtWidgets.QLineEdit(str(1))
            end_edit = QtWidgets.QLineEdit(str(value))

            for edit in [start_edit, step_edit, end_edit]:
                edit.setValidator(validator)

            grid.addWidget(QLabel(name), row, 0)
            grid.addWidget(start_edit, row, 1)
            grid.addWidget(step_edit, row, 2)
            grid.addWidget(end_edit, row, 3)

            self.edits[name] = {
                "type": type_,
                "start": start_edit,
                "step": step_edit,
                "end": end_edit
            }

            row += 1

        parallel_button = QtWidgets.QPushButton("多进程优化")
        parallel_button.clicked.connect(self.generate_parallel_setting)
        grid.addWidget(parallel_button, row, 0, 1, 4)

        row += 1
        ga_button = QtWidgets.QPushButton("遗传算法优化")
        ga_button.clicked.connect(self.generate_ga_setting)
        grid.addWidget(ga_button, row, 0, 1, 4)

        widget = QtWidgets.QWidget()
        widget.setLayout(grid)

        scroll = QtWidgets.QScrollArea()
        scroll.setWidgetResizable(True)
        scroll.setWidget(widget)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addWidget(scroll)
        self.setLayout(vbox)

    def generate_ga_setting(self):
        """
        Build the setting for genetic algorithm optimization.
        """
        self.use_ga = True
        self.generate_setting()

    def generate_parallel_setting(self):
        """
        Build the setting for multi-process optimization.
        """
        self.use_ga = False
        self.generate_setting()

    def generate_setting(self):
        """
        Build an OptimizationSetting from the inputs and accept the dialog.

        A parameter whose start equals its end is added as a single fixed
        value. If any field cannot be converted to its parameter type
        (e.g. left empty), a warning is shown and the dialog stays open.
        """
        self.optimization_setting = OptimizationSetting()

        self.target_display = self.target_combo.currentText()
        target_name = self.DISPLAY_NAME_MAP[self.target_display]
        self.optimization_setting.set_target(target_name)

        for name, d in self.edits.items():
            type_ = d["type"]

            try:
                start_value = type_(d["start"].text())
                step_value = type_(d["step"].text())
                end_value = type_(d["end"].text())
            except ValueError:
                QtWidgets.QMessageBox.warning(
                    self, "参数错误", f"参数{name}的输入格式不正确"
                )
                return

            if start_value == end_value:
                self.optimization_setting.add_parameter(name, start_value)
            else:
                self.optimization_setting.add_parameter(
                    name,
                    start_value,
                    end_value,
                    step_value
                )

        self.accept()

    def get_setting(self):
        """
        Return (optimization_setting, use_ga).
        """
        return self.optimization_setting, self.use_ga
class OptimizationResultMonitor(QtWidgets.QDialog):
    """
    For viewing optimization result.

    Shows one row per parameter combination with its target value and can
    export the table to a CSV file.
    """

    def __init__(
        self, result_values: list, target_display: str
    ):
        """
        result_values holds (setting, target_value, extra) tuples, which are
        unpacked that way below; target_display is the target column header.
        """
        super().__init__()

        self.result_values = result_values
        self.target_display = target_display

        self.init_ui()

    def init_ui(self):
        """
        Build the result table and the save button.
        """
        self.setWindowTitle("参数优化结果")
        self.resize(1100, 500)

        # Create table to show result
        table = QtWidgets.QTableWidget()

        table.setColumnCount(2)
        table.setRowCount(len(self.result_values))
        table.setHorizontalHeaderLabels(["参数", self.target_display])
        table.setEditTriggers(table.NoEditTriggers)
        table.verticalHeader().setVisible(False)

        table.horizontalHeader().setSectionResizeMode(
            0, QtWidgets.QHeaderView.ResizeToContents
        )
        table.horizontalHeader().setSectionResizeMode(
            1, QtWidgets.QHeaderView.Stretch
        )

        for n, tp in enumerate(self.result_values):
            setting, target_value, _ = tp
            setting_cell = QtWidgets.QTableWidgetItem(str(setting))
            target_cell = QtWidgets.QTableWidgetItem(f"{target_value:.2f}")

            setting_cell.setTextAlignment(QtCore.Qt.AlignCenter)
            target_cell.setTextAlignment(QtCore.Qt.AlignCenter)

            table.setItem(n, 0, setting_cell)
            table.setItem(n, 1, target_cell)

        # Create layout
        button = QtWidgets.QPushButton("保存")
        button.clicked.connect(self.save_csv)

        hbox = QtWidgets.QHBoxLayout()
        hbox.addStretch()
        hbox.addWidget(button)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addWidget(table)
        vbox.addLayout(hbox)

        self.setLayout(vbox)

    def save_csv(self) -> None:
        """
        Save table data into a csv file.

        The file is written as UTF-8 with a BOM so the Chinese header is
        encodable regardless of the platform's default encoding (the BOM
        also helps spreadsheet applications detect it). It is opened with
        newline="" as the csv module requires.
        """
        path, _ = QtWidgets.QFileDialog.getSaveFileName(
            self, "保存数据", "", "CSV(*.csv)")

        if not path:
            return

        with open(path, "w", newline="", encoding="utf-8-sig") as f:
            writer = csv.writer(f, lineterminator="\n")

            writer.writerow(["参数", self.target_display])

            for setting, target_value, _ in self.result_values:
                writer.writerow([str(setting), str(target_value)])
class BacktestingTradeMonitor(BaseMonitor):
    """
    Monitor for backtesting trade data.

    Columns are generated from (field, display name, cell class) triples,
    each with "update" set to False.
    """

    headers = {
        name: {"display": display, "cell": cell, "update": False}
        for name, display, cell in (
            ("tradeid", "成交号 ", BaseCell),
            ("orderid", "委托号", BaseCell),
            ("symbol", "代码", BaseCell),
            ("exchange", "交易所", EnumCell),
            ("direction", "方向", DirectionCell),
            ("offset", "开平", EnumCell),
            ("price", "价格", BaseCell),
            ("volume", "数量", BaseCell),
            ("datetime", "时间", BaseCell),
            ("gateway_name", "接口", BaseCell),
        )
    }
class BacktestingOrderMonitor(BaseMonitor):
    """
    Monitor for backtesting order data.

    Columns are generated from (field, display name, cell class) triples,
    each with "update" set to False.
    """

    headers = {
        name: {"display": display, "cell": cell, "update": False}
        for name, display, cell in (
            ("orderid", "委托号", BaseCell),
            ("symbol", "代码", BaseCell),
            ("exchange", "交易所", EnumCell),
            ("type", "类型", EnumCell),
            ("direction", "方向", DirectionCell),
            ("offset", "开平", EnumCell),
            ("price", "价格", BaseCell),
            ("volume", "总数量", BaseCell),
            ("traded", "已成交", BaseCell),
            ("status", "状态", EnumCell),
            ("datetime", "时间", BaseCell),
            ("gateway_name", "接口", BaseCell),
        )
    }
class FloatCell(BaseCell):
    """
    Cell that shows a numeric value (e.g. pnl) with two decimal places.
    """

    def __init__(self, content, data):
        """
        Format content to two decimals before passing it to BaseCell.
        """
        super().__init__(f"{content:.2f}", data)
class DailyResultMonitor(BaseMonitor):
    """
    Monitor for backtesting daily result.

    Columns are generated from (field, display name, cell class) triples,
    each with "update" set to False; money columns use FloatCell.
    """

    headers = {
        name: {"display": display, "cell": cell, "update": False}
        for name, display, cell in (
            ("date", "日期", BaseCell),
            ("trade_count", "成交笔数", BaseCell),
            ("start_pos", "开盘持仓", BaseCell),
            ("end_pos", "收盘持仓", BaseCell),
            ("turnover", "成交额", FloatCell),
            ("commission", "手续费", FloatCell),
            ("slippage", "滑点", FloatCell),
            ("trading_pnl", "交易盈亏", FloatCell),
            ("holding_pnl", "持仓盈亏", FloatCell),
            ("total_pnl", "总盈亏", FloatCell),
            ("net_pnl", "净盈亏", FloatCell),
        )
    }
class BacktestingResultDialog(QtWidgets.QDialog):
    """
    Dialog wrapping a monitor table that lists backtesting results
    (trades, orders or daily results).
    """

    def __init__(
        self,
        main_engine: MainEngine,
        event_engine: EventEngine,
        title: str,
        table_class: QtWidgets.QTableWidget
    ):
        """
        table_class is a monitor class (not an instance); it is called with
        (main_engine, event_engine) to create the table.
        """
        super().__init__()

        self.main_engine = main_engine
        self.event_engine = event_engine
        self.title = title
        self.table_class = table_class

        # True once update_data has filled the table; reset by clear_data.
        self.updated = False

        self.init_ui()

    def init_ui(self):
        """
        Create the table and place it in the dialog.
        """
        self.setWindowTitle(self.title)
        self.resize(1100, 600)

        self.table = self.table_class(self.main_engine, self.event_engine)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addWidget(self.table)

        self.setLayout(vbox)

    def clear_data(self):
        """
        Remove all rows and mark the dialog as not updated.
        """
        self.updated = False
        self.table.setRowCount(0)

    def update_data(self, data: list):
        """
        Insert every object of data as a row, last element first.

        The caller's list is iterated in reverse instead of being reversed
        in place, so it is no longer modified.
        """
        self.updated = True

        for obj in reversed(data):
            self.table.insert_new_row(obj)

    def is_updated(self):
        """
        Return whether the table currently holds data from update_data.
        """
        return self.updated
class CandleChartDialog(QtWidgets.QDialog):
"""
"""
def __init__(self):
""""""
super().__init__()
self.updated = False
self.dt_ix_map = {}
self.ix_bar_map = {}
self.high_price = 0
self.low_price = 0
self.price_range = 0
self.items = []
self.init_ui()
def init_ui(self):
""""""
self.setWindowTitle("回测K线图表")
self.resize(1400, 800)
# Create chart widget
self.chart = NewChartWidget()
self.chart.add_plot("candle", hide_x_axis=True)
self.chart.add_plot("volume", maximum_height=200)
self.chart.add_plot("rsi", maximum_height=150)
self.chart.add_plot("macd", maximum_height=150)
self.chart.add_item(CandleItem, "candle", "candle")
self.chart.add_item(VolumeItem, "volume", "volume")
self.chart.add_item(SmaItem, "sma", "candle")
self.chart.add_item(RsiItem, "rsi", "rsi")
self.chart.add_item(MacdItem, "macd", "macd")
self.chart.add_last_price_line()
self.chart.add_cursor()
# Create help widget
text1 = "红色虚线 —— 盈利交易"
label1 = QtWidgets.QLabel(text1)
label1.setStyleSheet("color:red")
text2 = "绿色虚线 —— 亏损交易"
label2 = QtWidgets.QLabel(text2)
label2.setStyleSheet("color:#00FF00")
text3 = "黄色向上箭头 —— 买入开仓 Buy"
label3 = QtWidgets.QLabel(text3)
label3.setStyleSheet("color:yellow")
text4 = "黄色向下箭头 —— 卖出平仓 Sell"
label4 = QtWidgets.QLabel(text4)
label4.setStyleSheet("color:yellow")
text5 = "紫红向下箭头 —— 卖出开仓 Short"
label5 = QtWidgets.QLabel(text5)
label5.setStyleSheet("color:magenta")
text6 = "紫红向上箭头 —— 买入平仓 Cover"
label6 = QtWidgets.QLabel(text6)
label6.setStyleSheet("color:magenta")
hbox1 = QtWidgets.QHBoxLayout()
hbox1.addStretch()
hbox1.addWidget(label1)
hbox1.addStretch()
hbox1.addWidget(label2)
hbox1.addStretch()
hbox2 = QtWidgets.QHBoxLayout()
hbox2.addStretch()
hbox2.addWidget(label3)
hbox2.addStretch()
hbox2.addWidget(label4)
hbox2.addStretch()
hbox3 = QtWidgets.QHBoxLayout()
hbox3.addStretch()
hbox3.addWidget(label5)
hbox3.addStretch()
hbox3.addWidget(label6)
hbox3.addStretch()
# Set layout
vbox = QtWidgets.QVBoxLayout()
vbox.addWidget(self.chart)
vbox.addLayout(hbox1)
vbox.addLayout(hbox2)
vbox.addLayout(hbox3)
self.setLayout(vbox)
def update_history(self, history: list):
""""""
self.updated = True
self.chart.update_history(history)
for ix, bar in enumerate(history):
self.ix_bar_map[ix] = bar
self.dt_ix_map[bar.datetime] = ix
if not self.high_price:
self.high_price = bar.high_price
self.low_price = bar.low_price
else:
self.high_price = max(self.high_price, bar.high_price)
self.low_price = min(self.low_price, bar.low_price)
self.price_range = self.high_price - self.low_price
def update_trades(self, trades: list):
""""""
trade_pairs = generate_trade_pairs(trades)
candle_plot = self.chart.get_plot("candle")
scatter_data = []
y_adjustment = self.price_range * 0.001
for d in trade_pairs:
open_ix = self.dt_ix_map[d["open_dt"]]
close_ix = self.dt_ix_map[d["close_dt"]]
open_price = d["open_price"]
close_price = d["close_price"]
# Trade Line
x = [open_ix, close_ix]
y = [open_price, close_price]
if d["direction"] == Direction.LONG and close_price >= open_price:
color = "r"
elif d["direction"] == Direction.SHORT and close_price <= open_price:
color = "r"
else:
color = "g"
pen = pg.mkPen(color, width=1.5, style=QtCore.Qt.DashLine)
item = pg.PlotCurveItem(x, y, pen=pen)
self.items.append(item)
candle_plot.addItem(item)
# Trade Scatter
open_bar = self.ix_bar_map[open_ix]
close_bar = self.ix_bar_map[close_ix]
if d["direction"] == Direction.LONG:
scatter_color = "yellow"
open_symbol = "t1"
close_symbol = "t"
open_side = 1
close_side = -1
open_y = open_bar.low_price
close_y = close_bar.high_price
else:
scatter_color = "magenta"
open_symbol = "t"
close_symbol = "t1"
open_side = -1
close_side = 1
open_y = open_bar.high_price
close_y = close_bar.low_price
pen = pg.mkPen(QtGui.QColor(scatter_color))
brush = pg.mkBrush(QtGui.QColor(scatter_color))
size = 10
open_scatter = {
"pos": (open_ix, open_y - open_side * y_adjustment),
"size": size,
"pen": pen,
"brush": brush,
"symbol": open_symbol
}
close_scatter = {
"pos": (close_ix, close_y - close_side * y_adjustment),
"size": size,
"pen": pen,
"brush": brush,
"symbol": close_symbol
}
scatter_data.append(open_scatter)
scatter_data.append(close_scatter)
# Trade text
volume = d["volume"]
text_color = QtGui.QColor(scatter_color)
open_text = pg.TextItem(f"[{volume}]", color=text_color, anchor=(0.5, 0.5))
close_text = pg.TextItem(f"[{volume}]", color=text_color, anchor=(0.5, 0.5))
open_text.setPos(open_ix, open_y - open_side * y_adjustment * 3)
close_text.setPos(close_ix, close_y - close_side * y_adjustment * 3)
self.items.append(open_text)
self.items.append(close_text)
candle_plot.addItem(open_text)
candle_plot.addItem(close_text)
trade_scatter = pg.ScatterPlotItem(scatter_data)
self.items.append(trade_scatter)
candle_plot.addItem(trade_scatter)
def clear_data(self):
    """
    Remove all drawn trade items and reset the cached chart state.

    Steps, in order: reset the ``updated`` flag to False, detach every item
    in ``self.items`` from the "candle" plot and empty that list, clear the
    chart's data, then empty the datetime->index and index->bar maps.
    """
    self.updated = False

    candle = self.chart.get_plot("candle")
    for drawn in self.items:
        candle.removeItem(drawn)
    self.items.clear()

    self.chart.clear_all()
    for lookup in (self.dt_ix_map, self.ix_bar_map):
        lookup.clear()
def is_updated(self):
    """
    Report the current value of the ``updated`` flag.

    ``clear_data`` resets this flag to False.
    """
    flag = self.updated
    return flag
def generate_trade_pairs(trades: list) -> list:
    """
    Match trades first-in-first-out into open/close pairs.

    Each trade first closes volume against the oldest still-open trades of
    the opposite direction. Any volume left over is queued as a new open
    position in its own direction.

    Args:
        trades: trade objects exposing ``direction``, ``volume``, ``price``
            and ``datetime``. They are processed in the given order
            (presumably chronological; the caller should confirm).
            The input objects are never mutated, because each one is
            shallow-copied first.

    Returns:
        A list of dicts, one per matched (possibly partial) pair, with keys
        ``open_dt``, ``open_price``, ``close_dt``, ``close_price``,
        ``direction`` (direction of the opening trade) and ``volume``.
        Volume that is never closed produces no entry.
    """
    from collections import deque

    # Open positions per direction, oldest first. A deque gives O(1) removal
    # from the front; list.pop(0) was O(n) per fully-closed position.
    long_trades = deque()
    short_trades = deque()
    trade_pairs = []

    for trade in trades:
        # Work on a copy so decrementing volume never touches caller objects.
        trade = copy(trade)

        if trade.direction == Direction.LONG:
            same_direction = long_trades
            opposite_direction = short_trades
        else:
            same_direction = short_trades
            opposite_direction = long_trades

        # Close against the oldest opposite positions until this trade's
        # volume is used up or nothing is left to close.
        while trade.volume and opposite_direction:
            open_trade = opposite_direction[0]
            close_volume = min(open_trade.volume, trade.volume)

            trade_pairs.append({
                "open_dt": open_trade.datetime,
                "open_price": open_trade.price,
                "close_dt": trade.datetime,
                "close_price": trade.price,
                "direction": open_trade.direction,
                "volume": close_volume,
            })

            open_trade.volume -= close_volume
            if not open_trade.volume:
                opposite_direction.popleft()

            trade.volume -= close_volume

        # Unmatched remainder becomes a new open position.
        if trade.volume:
            same_direction.append(trade)

    return trade_pairs
5. vnpy/app/chart_wizard/ui/widget.py
from typing import List, Dict, Type
import pyqtgraph as pg
from vnpy.trader.ui import QtGui, QtWidgets, QtCore
from vnpy.trader.object import BarData
from vnpy.event import Event, EventEngine
from .manager import BarManager
from .base import (
GREY_COLOR, WHITE_COLOR, CURSOR_COLOR, BLACK_COLOR,
to_int, NORMAL_FONT
)
from .axis import DatetimeAxis
from .item import ChartItem
from vnpy.app.cta_strategy.base import (
EVENT_CTA_TICK,
EVENT_CTA_BAR,
EVENT_CTA_ORDER,
EVENT_CTA_TRADE,
EVENT_CTA_HISTORY_BAR
)
# Enable anti-aliasing globally for all pyqtgraph drawing in this process.
pg.setConfigOptions(antialias=True)
class ChartWidget(pg.PlotWidget):
    """
    Stacked multi-plot chart widget driven by a shared BarManager.

    Plots are added top-to-bottom with ``add_plot`` and x-linked to the first
    plot. Chart items added with ``add_item`` receive every history/bar
    update. A crosshair cursor is optional (see ``add_cursor``), so every
    cursor call below is guarded and the widget also works without one.
    """

    MIN_BAR_COUNT = 100

    def __init__(self, parent: QtWidgets.QWidget = None):
        """
        Create an empty chart with no plots, items or cursor.
        """
        super().__init__(parent)

        self._manager: BarManager = BarManager()

        self._plots: Dict[str, pg.PlotItem] = {}
        self._items: Dict[str, ChartItem] = {}
        self._item_plot_map: Dict[ChartItem, pg.PlotItem] = {}

        self._first_plot: pg.PlotItem = None
        self._cursor: ChartCursor = None

        self._right_ix: int = 0                     # Index of most right data
        self._bar_count: int = self.MIN_BAR_COUNT   # Total bar visible in chart

        self._init_ui()

    def _init_ui(self) -> None:
        """
        Set the window title and install the layout that holds all plots.
        """
        self.setWindowTitle("ChartWidget of vn.py")

        self._layout = pg.GraphicsLayout()
        self._layout.setContentsMargins(10, 10, 10, 10)
        self._layout.setSpacing(0)
        self._layout.setBorder(color=GREY_COLOR, width=0.8)
        self._layout.setZValue(0)
        self.setCentralItem(self._layout)

    def _get_new_x_axis(self):
        """
        Create a new datetime x-axis bound to the shared bar manager.
        """
        return DatetimeAxis(self._manager, orientation='bottom')

    def add_cursor(self) -> None:
        """
        Create the crosshair cursor, at most once.

        Call this after all plots are added, because the cursor creates its
        lines and labels for the plots that exist when it is constructed.
        """
        if not self._cursor:
            self._cursor = ChartCursor(
                self, self._manager, self._plots, self._item_plot_map)

    def add_plot(
        self,
        plot_name: str,
        minimum_height: int = 80,
        maximum_height: int = None,
        hide_x_axis: bool = False
    ) -> None:
        """
        Add a plot area below the existing ones.

        Args:
            plot_name: key used by ``get_plot`` and ``add_item``.
            minimum_height: minimum plot height in pixels.
            maximum_height: optional maximum plot height in pixels.
            hide_x_axis: hide the bottom (datetime) axis of this plot.
        """
        # Create plot object
        plot = pg.PlotItem(axisItems={'bottom': self._get_new_x_axis()})
        plot.setMenuEnabled(False)
        plot.setClipToView(True)
        plot.hideAxis('left')
        plot.showAxis('right')
        plot.setDownsampling(mode='peak')
        plot.setRange(xRange=(0, 1), yRange=(0, 1))
        plot.hideButtons()
        plot.setMinimumHeight(minimum_height)

        if maximum_height:
            plot.setMaximumHeight(maximum_height)

        if hide_x_axis:
            plot.hideAxis("bottom")

        if not self._first_plot:
            self._first_plot = plot

        # Connect view change signal to update y range function
        view = plot.getViewBox()
        view.sigXRangeChanged.connect(self._update_y_range)
        view.setMouseEnabled(x=True, y=False)

        # Set right axis
        right_axis = plot.getAxis('right')
        right_axis.setWidth(60)
        right_axis.tickFont = NORMAL_FONT

        # Connect x-axis link so all plots scroll/zoom together
        if self._plots:
            first_plot = list(self._plots.values())[0]
            plot.setXLink(first_plot)

        # Store plot object in dict
        self._plots[plot_name] = plot

        # Add plot onto the layout
        self._layout.nextRow()
        self._layout.addItem(plot)

    def add_item(
        self,
        item_class: Type[ChartItem],
        item_name: str,
        plot_name: str
    ):
        """
        Instantiate ``item_class`` with the shared manager and add it to the
        plot named ``plot_name``.
        """
        item = item_class(self._manager)
        self._items[item_name] = item

        plot = self._plots.get(plot_name)
        plot.addItem(item)

        self._item_plot_map[item] = plot

    def get_plot(self, plot_name: str) -> pg.PlotItem:
        """
        Get specific plot with its name, or None if it does not exist.
        """
        return self._plots.get(plot_name, None)

    def get_all_plots(self) -> List[pg.PlotItem]:
        """
        Get all plot objects, in the order they were added.
        """
        # Return a real list, as annotated, rather than a live dict view.
        return list(self._plots.values())

    def clear_all(self) -> None:
        """
        Clear all data from the manager, every item and the cursor.
        """
        self._manager.clear_all()

        for item in self._items.values():
            item.clear_all()

        if self._cursor:
            self._cursor.clear_all()

    def update_history(self, history: List[BarData]) -> None:
        """
        Update a list of bar data, then scroll to the most recent bar.
        """
        self._manager.update_history(history)

        for item in self._items.values():
            item.update_history(history)

        self._update_plot_limits()

        self.move_to_right()

    def update_bar(self, bar: BarData) -> None:
        """
        Update single bar data.

        Auto-scrolls to the right only when the view is already showing
        the latest part of the data (within half a screen of the end).
        """
        self._manager.update_bar(bar)

        for item in self._items.values():
            item.update_bar(bar)

        self._update_plot_limits()

        if self._right_ix >= (self._manager.get_count() - self._bar_count / 2):
            self.move_to_right()

    def _update_plot_limits(self) -> None:
        """
        Update the pan/zoom limits of plots from each item's full y-range.
        """
        for item, plot in self._item_plot_map.items():
            min_value, max_value = item.get_y_range()

            plot.setLimits(
                xMin=-1,
                xMax=self._manager.get_count(),
                yMin=min_value,
                yMax=max_value
            )

    def _update_x_range(self) -> None:
        """
        Show ``_bar_count`` bars ending at ``_right_ix`` on every plot.
        """
        max_ix = self._right_ix
        min_ix = self._right_ix - self._bar_count

        for plot in self._plots.values():
            plot.setRange(xRange=(min_ix, max_ix), padding=0)

    def _update_y_range(self) -> None:
        """
        Fit each plot's y-axis to the data currently visible on the x-axis.
        """
        view = self._first_plot.getViewBox()
        view_range = view.viewRange()

        min_ix = max(0, int(view_range[0][0]))
        max_ix = min(self._manager.get_count(), int(view_range[0][1]))

        # Update limit for y-axis
        for item, plot in self._item_plot_map.items():
            y_range = item.get_y_range(min_ix, max_ix)
            plot.setRange(yRange=y_range)

    def paintEvent(self, event: QtGui.QPaintEvent) -> None:
        """
        Reimplement this method of parent to update current max_ix value.
        """
        # A paint can happen before any plot is added; skip the sync then.
        if self._first_plot:
            view = self._first_plot.getViewBox()
            view_range = view.viewRange()
            self._right_ix = max(0, view_range[0][1])

        super().paintEvent(event)

    def keyPressEvent(self, event: QtGui.QKeyEvent) -> None:
        """
        Reimplement this method of parent to move chart horizontally and zoom in/out.
        """
        if event.key() == QtCore.Qt.Key_Left:
            self._on_key_left()
        elif event.key() == QtCore.Qt.Key_Right:
            self._on_key_right()
        elif event.key() == QtCore.Qt.Key_Up:
            self._on_key_up()
        elif event.key() == QtCore.Qt.Key_Down:
            self._on_key_down()

    def wheelEvent(self, event: QtGui.QWheelEvent) -> None:
        """
        Reimplement this method of parent to zoom in/out.
        """
        delta = event.angleDelta()

        if delta.y() > 0:
            self._on_key_up()
        elif delta.y() < 0:
            self._on_key_down()

    def _on_key_left(self) -> None:
        """
        Move chart to left.
        """
        self._right_ix -= 1
        self._right_ix = max(self._right_ix, self._bar_count)

        self._update_x_range()

        if self._cursor:
            self._cursor.move_left()
            self._cursor.update_info()

    def _on_key_right(self) -> None:
        """
        Move chart to right.
        """
        self._right_ix += 1
        self._right_ix = min(self._right_ix, self._manager.get_count())

        self._update_x_range()

        if self._cursor:
            self._cursor.move_right()
            self._cursor.update_info()

    def _on_key_down(self) -> None:
        """
        Zoom out the chart (show more bars, capped at the total bar count).
        """
        self._bar_count *= 1.2
        self._bar_count = min(int(self._bar_count), self._manager.get_count())

        self._update_x_range()

        if self._cursor:
            self._cursor.update_info()

    def _on_key_up(self) -> None:
        """
        Zoom in the chart (show fewer bars, at least MIN_BAR_COUNT).
        """
        self._bar_count /= 1.2
        self._bar_count = max(int(self._bar_count), self.MIN_BAR_COUNT)

        self._update_x_range()

        if self._cursor:
            self._cursor.update_info()

    def move_to_right(self) -> None:
        """
        Move chart to the most right.
        """
        self._right_ix = self._manager.get_count()
        self._update_x_range()

        # The cursor is optional; update_history calls this even without one.
        if self._cursor:
            self._cursor.update_info()
class ChartCursor(QtCore.QObject):
    """
    Crosshair cursor for a ChartWidget.

    Draws a vertical line on every plot, and a horizontal line plus a
    y-value label on the plot under the mouse. It also draws a datetime
    label on the bottom plot and, at the top-left of every plot, the info
    text returned by that plot's items (``get_info_text``).
    """

    def __init__(
        self,
        widget: ChartWidget,
        manager: BarManager,
        plots: Dict[str, pg.GraphicsObject],
        item_plot_map: Dict[ChartItem, pg.GraphicsObject]
    ):
        """
        Create cursor graphics for every plot in ``plots`` and start
        tracking mouse movement on ``widget``'s scene.
        """
        super().__init__()

        self._widget: ChartWidget = widget
        self._manager: BarManager = manager
        self._plots: Dict[str, pg.GraphicsObject] = plots
        self._item_plot_map: Dict[ChartItem, pg.GraphicsObject] = item_plot_map

        self._x: int = 0
        self._y: int = 0
        self._plot_name: str = ""

        self._init_ui()
        self._connect_signal()

    def _init_ui(self):
        """
        Create lines, axis labels and info texts.
        """
        self._init_line()
        self._init_label()
        self._init_info()

    def _init_line(self) -> None:
        """
        Create one hidden vertical and one hidden horizontal line per plot.
        """
        self._v_lines: Dict[str, pg.InfiniteLine] = {}
        self._h_lines: Dict[str, pg.InfiniteLine] = {}
        self._views: Dict[str, pg.ViewBox] = {}

        pen = pg.mkPen(CURSOR_COLOR, style=QtCore.Qt.DashDotLine)

        for plot_name, plot in self._plots.items():
            v_line = pg.InfiniteLine(angle=90, movable=False, pen=pen)
            h_line = pg.InfiniteLine(angle=0, movable=False, pen=pen)
            view = plot.getViewBox()

            for line in [v_line, h_line]:
                line.setZValue(0)
                line.hide()
                view.addItem(line)

            self._v_lines[plot_name] = v_line
            self._h_lines[plot_name] = h_line
            self._views[plot_name] = view

    def _init_label(self) -> None:
        """
        Create a hidden y-value label per plot and one datetime label.
        """
        self._y_labels: Dict[str, pg.TextItem] = {}
        for plot_name, plot in self._plots.items():
            label = pg.TextItem(
                plot_name, fill=GREY_COLOR, color=WHITE_COLOR)
            label.hide()
            label.setZValue(2)
            label.setFont(NORMAL_FONT)
            plot.addItem(label, ignoreBounds=True)
            self._y_labels[plot_name] = label

        self._x_label: pg.TextItem = pg.TextItem(
            "datetime", fill=GREY_COLOR, color=WHITE_COLOR)
        self._x_label.hide()
        self._x_label.setZValue(2)
        self._x_label.setFont(NORMAL_FONT)

        # The datetime label belongs to the bottom plot (the last one added).
        # Previously this relied on the loop variable leaking out of the loop.
        bottom_plot = list(self._plots.values())[-1]
        bottom_plot.addItem(self._x_label, ignoreBounds=True)

    def _init_info(self) -> None:
        """
        Create one hidden info text item per plot.
        """
        self._infos: Dict[str, pg.TextItem] = {}
        for plot_name, plot in self._plots.items():
            info = pg.TextItem(
                "info",
                color=CURSOR_COLOR
            )
            info.hide()
            info.setZValue(2)
            info.setFont(NORMAL_FONT)
            plot.addItem(info, ignoreBounds=True)
            self._infos[plot_name] = info

    def _connect_signal(self) -> None:
        """
        Connect mouse move signal to update function.
        """
        self._widget.scene().sigMouseMoved.connect(self._mouse_moved)

    def _mouse_moved(self, evt: tuple) -> None:
        """
        Callback function when mouse is moved.

        ``evt`` is used directly as the scene position of the mouse.
        """
        if not self._manager.get_count():
            return

        # First get current mouse point
        pos = evt

        for plot_name, view in self._views.items():
            rect = view.sceneBoundingRect()

            if rect.contains(pos):
                mouse_point = view.mapSceneToView(pos)
                self._x = to_int(mouse_point.x())
                self._y = mouse_point.y()
                self._plot_name = plot_name
                break

        # Then update cursor component
        self._update_line()
        self._update_label()
        self.update_info()

    def _update_line(self) -> None:
        """
        Show the vertical line on all plots and the horizontal line only on
        the plot under the cursor.
        """
        for v_line in self._v_lines.values():
            v_line.setPos(self._x)
            v_line.show()

        for plot_name, h_line in self._h_lines.items():
            if plot_name == self._plot_name:
                h_line.setPos(self._y)
                h_line.show()
            else:
                h_line.hide()

    def _update_label(self) -> None:
        """
        Position the y-value label at the right edge of the active plot and
        the datetime label at the bottom of the bottom plot.
        """
        bottom_plot = list(self._plots.values())[-1]
        axis_width = bottom_plot.getAxis("right").width()
        axis_height = bottom_plot.getAxis("bottom").height()
        axis_offset = QtCore.QPointF(axis_width, axis_height)

        bottom_view = list(self._views.values())[-1]
        bottom_right = bottom_view.mapSceneToView(
            bottom_view.sceneBoundingRect().bottomRight() - axis_offset
        )

        for plot_name, label in self._y_labels.items():
            if plot_name == self._plot_name:
                label.setText(str(self._y))
                label.show()
                label.setPos(bottom_right.x(), self._y)
            else:
                label.hide()

        dt = self._manager.get_datetime(self._x)
        if dt:
            self._x_label.setText(dt.strftime("%Y-%m-%d %H:%M:%S"))
            self._x_label.show()
            self._x_label.setPos(self._x, bottom_right.y())
            self._x_label.setAnchor((0, 0))

    def update_info(self) -> None:
        """
        Refresh the info text of every plot from its items' ``get_info_text``
        at the current cursor index. Texts of items sharing a plot are joined
        by a blank line.
        """
        buf = {}

        for item, plot in self._item_plot_map.items():
            item_info_text = item.get_info_text(self._x)

            if plot not in buf:
                buf[plot] = item_info_text
            else:
                if item_info_text:
                    buf[plot] += ("\n\n" + item_info_text)

        for plot_name, plot in self._plots.items():
            # A plot without any chart item has no entry in buf.
            plot_info_text = buf.get(plot, "")
            info = self._infos[plot_name]
            info.setText(plot_info_text)
            info.show()

            view = self._views[plot_name]
            top_left = view.mapSceneToView(view.sceneBoundingRect().topLeft())
            info.setPos(top_left)

    def move_right(self) -> None:
        """
        Move cursor index to right by 1, stopping at the last bar.
        """
        # >= also covers an index already past the data (e.g. set by the
        # mouse) and the empty case (count - 1 == -1).
        if self._x >= self._manager.get_count() - 1:
            return
        self._x += 1

        self._update_after_move()

    def move_left(self) -> None:
        """
        Move cursor index to left by 1, stopping at the first bar.
        """
        # <= also covers a negative index set by the mouse.
        if self._x <= 0:
            return
        self._x -= 1

        self._update_after_move()

    def _update_after_move(self) -> None:
        """
        Update cursor after moved by left/right: snap y to the bar's close.
        """
        bar = self._manager.get_bar(self._x)
        if bar is None:
            # No bar at this index; keep the current graphics unchanged.
            return
        self._y = bar.close_price

        self._update_line()
        self._update_label()

    def clear_all(self) -> None:
        """
        Reset the cursor position and hide all cursor graphics.
        """
        self._x = 0
        self._y = 0
        self._plot_name = ""

        for line in list(self._v_lines.values()) + list(self._h_lines.values()):
            line.hide()

        for label in list(self._y_labels.values()) + [self._x_label]:
            label.hide()

        # Info texts would otherwise keep showing data that was just cleared.
        for info in self._infos.values():
            info.hide()
class NewChartWidget(ChartWidget):
    """
    ChartWidget that is fed by CTA strategy events for one strategy.

    History bars, live bars and ticks are received through Qt signals bound
    to the event engine (see ``register_event``). Only events whose strategy
    name equals ``strategy_name`` are applied. An optional horizontal
    last-price line follows bar closes and tick prices.
    """

    MIN_BAR_COUNT = 100

    strategy_name: str = ""

    signal_cta_history_bar: QtCore.pyqtSignal = QtCore.pyqtSignal(Event)
    signal_cta_tick: QtCore.pyqtSignal = QtCore.pyqtSignal(Event)
    signal_cta_bar: QtCore.pyqtSignal = QtCore.pyqtSignal(Event)

    def __init__(self, parent: QtWidgets.QWidget = None, event_engine: EventEngine = None, strategy_name: str = ""):
        """
        Args:
            parent: optional parent widget.
            event_engine: engine used by ``register_event``. It must be set
                before that method is called.
            strategy_name: only events for this strategy are drawn.
        """
        super().__init__(parent)

        self.strategy_name = strategy_name
        self.event_engine = event_engine
        self.last_price_line: pg.InfiniteLine = None

    def register_event(self) -> None:
        """
        Route CTA history-bar, tick and bar events to this widget.

        Events are re-emitted as Qt signals so the handlers run through Qt's
        signal/slot mechanism rather than directly from the engine callback.
        """
        self.signal_cta_history_bar.connect(self.process_cta_history_bar)
        self.event_engine.register(EVENT_CTA_HISTORY_BAR, self.signal_cta_history_bar.emit)

        self.signal_cta_tick.connect(self.process_tick_event)
        self.event_engine.register(EVENT_CTA_TICK, self.signal_cta_tick.emit)

        self.signal_cta_bar.connect(self.process_cta_bar)
        self.event_engine.register(EVENT_CTA_BAR, self.signal_cta_bar.emit)

    def process_cta_history_bar(self, event: Event) -> None:
        """
        Handle pushed history bars.

        ``event.data`` is unpacked as ``(strategy_name, history_bars)``.
        """
        strategy_name, history_bars = event.data
        if strategy_name == self.strategy_name:
            self.update_history(history_bars)

    def process_tick_event(self, event: Event) -> None:
        """
        Handle pushed tick data by moving the last-price line, if any.

        ``event.data`` is unpacked as ``(strategy_name, tick)``.
        """
        strategy_name, tick = event.data
        if strategy_name == self.strategy_name:
            if self.last_price_line:
                self.last_price_line.setValue(tick.last_price)

    def process_cta_bar(self, event: Event) -> None:
        """
        Handle pushed bar data.

        ``event.data`` is unpacked as ``(strategy_name, bar)``.
        """
        strategy_name, bar = event.data
        if strategy_name == self.strategy_name:
            self.update_bar(bar)

    def add_last_price_line(self):
        """
        Add a dashed horizontal line with a value label to the first plot.
        """
        plot = list(self._plots.values())[0]
        color = (255, 255, 0)

        self.last_price_line = pg.InfiniteLine(
            angle=0,
            movable=False,
            label="{value}",
            pen=pg.mkPen(color, width=1, style=QtCore.Qt.DashDotLine),
            labelOpts={
                "color": color,
                "position": 1,
                "anchors": [(1, 1), (1, 1)]
            }
        )
        self.last_price_line.label.setFont(NORMAL_FONT)
        plot.addItem(self.last_price_line)

    def update_history(self, history: List[BarData]) -> None:
        """
        Update a list of bar data, then move the last-price line to the
        close of the newest bar.
        """
        super().update_history(history)

        # An empty history push has no last bar to read.
        if history:
            self.update_last_price_line(history[-1])

    def update_bar(self, bar: BarData) -> None:
        """
        Update single bar data, then move the last-price line to its close.
        """
        super().update_bar(bar)
        self.update_last_price_line(bar)

    def update_last_price_line(self, bar: BarData) -> None:
        """
        Set the last-price line (if one was added) to ``bar.close_price``.
        """
        if self.last_price_line:
            self.last_price_line.setValue(bar.close_price)
特别感谢hxxjava