VeighNa Quant Community
Your open-source, community-driven quantitative trading platform
Member
Joined:
Posts: 419
Reputation: 170

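This thread turned out to be a failed attempt. For the correct approach, see: https://www.vnpy.com/forum/topic/3920-wei-kxian-tu-biao-tian-zhuan-jia-wa-rang-ctace-lue-de-yun-xing-kan-de-jian-1

Since I cannot delete my own thread, I opened a separate one, marked (1).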

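CTA strategies need a K-line chart

VNPY really is an excellent development platform! But anyone who trades CTA strategies will know this feeling: the strategy runs quietly, yet you cannot see its K-lines or get an intuitive view of its trades. You end up switching back and forth to a third-party market data client just to see what the market looks like from the strategy's side. How nice it would be if every CTA strategy came with a chart you could open!
After about a week of work, I have basically made CTA strategies visible.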

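Implementation approach

1 Implement a K-line chart widget, NewChartWidget, based on ChartWidget. It accepts pushed historical K-line data, tick data, bar data, order data, and trade data
2 Define five new event types for market data and trading (the stop order event already exists in the system):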


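EVENT_CTA_HISTORY_BAR: historical K-line event
EVENT_CTA_TICK: tick event
EVENT_CTA_BAR: bar event
EVENT_CTA_ORDER: order event
EVENT_CTA_TRADE: trade (fill) event
EVENT_CTA_STOPORDER: stop order event (already provided by the system)

3 Use the strategy's on_start(), on_tick(), on_bar(), on_order(), on_stop_order(), and on_trade() callbacks to send these events
4 The chart can display the real-time, still-forming K-line
5 The chart can display market data and also where orders and trades occurred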
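
The flow in steps 1 to 5 can be sketched without vn.py installed. This is only a minimal illustration, not the code from this post: `Event`, `MiniEventEngine`, `ChartStub`, and `DemoStrategy` are hypothetical stand-ins for vn.py's event classes, the NewChartWidget, and a real CTA strategy. The bus here dispatches synchronously, and bars are reduced to plain closing prices.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, DefaultDict, List

# Event type strings, matching the constants added to base.py in this post.
EVENT_CTA_HISTORY_BAR = "eCtaHistoryBar"
EVENT_CTA_BAR = "eCtaBar"
EVENT_CTA_TRADE = "eCtaTrade"


@dataclass
class Event:
    """Hypothetical stand-in for an event: a type string plus a payload."""
    type: str
    data: Any = None


class MiniEventEngine:
    """Hypothetical synchronous bus; it only illustrates register/put routing."""

    def __init__(self) -> None:
        self._handlers: DefaultDict[str, List[Callable[[Event], None]]] = defaultdict(list)

    def register(self, type_: str, handler: Callable[[Event], None]) -> None:
        self._handlers[type_].append(handler)

    def put(self, event: Event) -> None:
        for handler in self._handlers[event.type]:
            handler(event)


class ChartStub:
    """Stand-in for the chart widget: it only records what it receives."""

    def __init__(self, engine: MiniEventEngine) -> None:
        self.bars: List[float] = []
        self.trades: List[dict] = []
        engine.register(EVENT_CTA_HISTORY_BAR, self.on_history)
        engine.register(EVENT_CTA_BAR, self.on_bar)
        engine.register(EVENT_CTA_TRADE, self.on_trade)

    def on_history(self, event: Event) -> None:
        self.bars = list(event.data)      # replace everything with the history

    def on_bar(self, event: Event) -> None:
        self.bars.append(event.data)      # append one newly finished bar

    def on_trade(self, event: Event) -> None:
        self.trades.append(event.data)    # remember where a fill happened


class DemoStrategy:
    """Hypothetical strategy whose callbacks forward data to the bus."""

    def __init__(self, engine: MiniEventEngine, history: List[float]) -> None:
        self.engine = engine
        self.history = history

    def on_start(self) -> None:
        self.engine.put(Event(EVENT_CTA_HISTORY_BAR, self.history))

    def on_bar(self, bar: float) -> None:
        self.engine.put(Event(EVENT_CTA_BAR, bar))

    def on_trade(self, trade: dict) -> None:
        self.engine.put(Event(EVENT_CTA_TRADE, trade))


engine = MiniEventEngine()
chart = ChartStub(engine)
strategy = DemoStrategy(engine, history=[100.0, 101.0])
strategy.on_start()
strategy.on_bar(102.0)
strategy.on_trade({"price": 102.0, "direction": "long"})
print(chart.bars, len(chart.trades))
```

The point of the design is that the strategy never touches the chart directly. It only publishes events, so a chart can be opened or left closed without changing the strategy.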

First, a look at the result:

[Screenshot of the result]

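Easy to use

After the modification, each strategy in the CTA strategy management window gets an extra "K线图表" (K-line chart) button, as the screenshot above shows. Once a strategy has been created, the "K线图表" button and the "启动" (Start) button become enabled together, and only after the strategy has been initialized. To view the strategy's K-line chart, click "K线图表" first, which opens an empty chart window, and then click "启动" as usual. The chart immediately shows the historical K-lines that the strategy loaded during initialization, along with the main-chart and sub-chart indicators you care about.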
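
The button rules described here can be written down as a small pure function. This is a sketch that mirrors what `update_data` does in the widget.py listing further down; the function name `button_states` and the dictionary keys are made up for illustration.

```python
from typing import Dict


def button_states(inited: bool, trading: bool) -> Dict[str, bool]:
    """Return which buttons of one strategy row should be enabled."""
    if not inited:
        # Freshly added strategy: it can be initialized, edited, or removed.
        return {"init": True, "chart": False, "start": False,
                "stop": False, "edit": True, "remove": True}
    if trading:
        # Running: only "stop" is available; the chart button is locked too.
        return {"init": False, "chart": False, "start": False,
                "stop": True, "edit": False, "remove": False}
    # Initialized but not running: chart and start become available together.
    return {"init": False, "chart": True, "start": True,
            "stop": False, "edit": True, "remove": True}


print(button_states(inited=True, trading=False))
```

Because the chart button is disabled while the strategy is trading, the chart window has to be opened before "启动" is clicked, which matches the order of steps described above.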

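Next steps:

So far only a few representative indicators are implemented, and they cannot be configured freely. In the future, each strategy should be able to configure its own chart independently: how many main-chart and sub-chart panes it has and which indicators go on them. Different strategies use different algorithms and may run on different bar periods.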
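
One way such a per-strategy configuration might look is sketched below. None of this exists in the post: `PlotSpec`, `ItemSpec`, `ChartLayout`, and the strategy names are all hypothetical. The field names only loosely follow the `add_plot` and `add_item` calls used in the widget.py listing.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class PlotSpec:
    """One pane of the chart (main chart or a sub-chart)."""
    name: str
    maximum_height: Optional[int] = None
    hide_x_axis: bool = False


@dataclass
class ItemSpec:
    """One drawable item and the pane it is drawn in."""
    item_class: str   # class name as text, e.g. "CandleItem"
    name: str
    plot: str


@dataclass
class ChartLayout:
    plots: List[PlotSpec] = field(default_factory=list)
    items: List[ItemSpec] = field(default_factory=list)

    def validate(self) -> None:
        """Raise ValueError if an item points at a pane that is not declared."""
        known = {plot.name for plot in self.plots}
        for item in self.items:
            if item.plot not in known:
                raise ValueError(
                    f"item {item.name!r} uses unknown plot {item.plot!r}"
                )


# Hypothetical per-strategy table; the strategy names are made up.
layouts: Dict[str, ChartLayout] = {
    "trend_demo": ChartLayout(
        plots=[PlotSpec("candle", hide_x_axis=True),
               PlotSpec("macd", maximum_height=150)],
        items=[ItemSpec("CandleItem", "candle", "candle"),
               ItemSpec("SmaItem", "sma", "candle"),
               ItemSpec("MacdItem", "macd", "macd")],
    ),
    "reversal_demo": ChartLayout(
        plots=[PlotSpec("candle", hide_x_axis=True),
               PlotSpec("rsi", maximum_height=150)],
        items=[ItemSpec("CandleItem", "candle", "candle"),
               ItemSpec("RsiItem", "rsi", "rsi")],
    ),
}

for layout in layouts.values():
    layout.validate()
```

Keeping the layout as data would let the code that opens the chart loop over the specs instead of hard-coding the panes and indicators.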

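In the next post I will publish the implementation and the code for the current features. I am still tidying it up, so please bear with me.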

Implementation:

1. Back up these files first

vnpy\app\cta_strategy\base.py
vnpy\app\cta_strategy\ui\widget.py

2. Modify vnpy\app\cta_strategy\base.py

The contents are as follows:

"""
Defines constants and objects used in CtaStrategy App.
"""

from dataclasses import dataclass, field
from enum import Enum
from datetime import timedelta

from vnpy.trader.constant import Direction, Offset, Interval

APP_NAME = "CtaStrategy"
STOPORDER_PREFIX = "STOP"


class StopOrderStatus(Enum):
    WAITING = "等待中"
    CANCELLED = "已撤销"
    TRIGGERED = "已触发"


class EngineType(Enum):
    LIVE = "实盘"
    BACKTESTING = "回测"


class BacktestingMode(Enum):
    BAR = 1
    TICK = 2


@dataclass
class StopOrder:
    vt_symbol: str
    direction: Direction
    offset: Offset
    price: float
    volume: float
    stop_orderid: str
    strategy_name: str
    lock: bool = False
    vt_orderids: list = field(default_factory=list)
    status: StopOrderStatus = StopOrderStatus.WAITING


EVENT_CTA_LOG = "eCtaLog"
EVENT_CTA_STRATEGY = "eCtaStrategy"
EVENT_CTA_STOPORDER = "eCtaStopOrder"

EVENT_CTA_TICK = "eCtaTick"                     # hxxjava add
EVENT_CTA_HISTORY_BAR = "eCtaHistoryBar"        # hxxjava add
EVENT_CTA_BAR = "eCtaBar"                       # hxxjava add
EVENT_CTA_ORDER = "eCtaOrder"                   # hxxjava add  
EVENT_CTA_TRADE = "eCtaTrade"                   # hxxjava add


INTERVAL_DELTA_MAP = {
    Interval.MINUTE: timedelta(minutes=1),
    Interval.HOUR: timedelta(hours=1),
    Interval.DAILY: timedelta(days=1),
}
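
The five new constants are plain strings shared by every strategy, so a subscriber to "eCtaBar" would receive bars from all running strategies unless something tells them apart. This part of the thread does not show how that is handled. One possible approach, purely as an assumption, is to derive a per-strategy event type from the base constant. `strategy_event_type` and `publish` below are hypothetical helpers, and the handler table is a plain dictionary rather than a real event engine.

```python
from collections import defaultdict
from typing import Callable, DefaultDict, List

EVENT_CTA_BAR = "eCtaBar"   # same value as in base.py above


def strategy_event_type(base: str, strategy_name: str) -> str:
    """Hypothetical helper: derive a per-strategy event type from a base type."""
    return f"{base}.{strategy_name}"


# Handler table keyed by event type; each "chart" is just a list of prices.
handlers: DefaultDict[str, List[Callable[[float], None]]] = defaultdict(list)
chart_a: List[float] = []
chart_b: List[float] = []

handlers[strategy_event_type(EVENT_CTA_BAR, "strategy_a")].append(chart_a.append)
handlers[strategy_event_type(EVENT_CTA_BAR, "strategy_b")].append(chart_b.append)


def publish(strategy_name: str, close_price: float) -> None:
    """Deliver a bar only to the handlers registered for that strategy."""
    for handler in handlers[strategy_event_type(EVENT_CTA_BAR, strategy_name)]:
        handler(close_price)


publish("strategy_a", 3500.0)
publish("strategy_b", 71.5)
print(chart_a, chart_b)
```

An alternative with the same effect would be to keep the shared event type and put the strategy name into the payload, letting each chart ignore events that are not its own.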

3. Modify vnpy\app\cta_strategy\ui\widget.py

The changes are as follows:

from vnpy.event import Event, EventEngine
from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import QtCore, QtGui, QtWidgets
from vnpy.trader.ui.widget import (
    BaseCell,
    EnumCell,
    MsgCell,
    TimeCell,
    BaseMonitor
)

from ..base import (
    APP_NAME,
    EVENT_CTA_LOG,
    EVENT_CTA_STOPORDER,
    EVENT_CTA_STRATEGY,
)

from ..engine import CtaEngine

from vnpy.usertools.kx_chart import (   # hxxjava add
    NewChartWidget,
    CandleItem,
    VolumeItem, 
    LineItem,
    SmaItem,
    RsiItem,
    MacdItem,
)


class CtaManager(QtWidgets.QWidget):
    """"""

    signal_log = QtCore.pyqtSignal(Event)
    signal_strategy = QtCore.pyqtSignal(Event)

    def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
        super(CtaManager, self).__init__()

        self.main_engine = main_engine
        self.event_engine = event_engine
        self.cta_engine = main_engine.get_engine(APP_NAME)

        self.managers = {}

        self.init_ui()
        self.register_event()
        self.cta_engine.init_engine()
        self.update_class_combo()

    def init_ui(self):
        """"""
        self.setWindowTitle("CTA策略")

        # Create widgets
        self.class_combo = QtWidgets.QComboBox()

        add_button = QtWidgets.QPushButton("添加策略")
        add_button.clicked.connect(self.add_strategy)

        init_button = QtWidgets.QPushButton("全部初始化")
        init_button.clicked.connect(self.cta_engine.init_all_strategies)

        start_button = QtWidgets.QPushButton("全部启动")
        start_button.clicked.connect(self.cta_engine.start_all_strategies)

        stop_button = QtWidgets.QPushButton("全部停止")
        stop_button.clicked.connect(self.cta_engine.stop_all_strategies)

        clear_button = QtWidgets.QPushButton("清空日志")
        clear_button.clicked.connect(self.clear_log)

        self.scroll_layout = QtWidgets.QVBoxLayout()
        self.scroll_layout.addStretch()

        scroll_widget = QtWidgets.QWidget()
        scroll_widget.setLayout(self.scroll_layout)

        scroll_area = QtWidgets.QScrollArea()
        scroll_area.setWidgetResizable(True)
        scroll_area.setWidget(scroll_widget)

        self.log_monitor = LogMonitor(self.main_engine, self.event_engine)

        self.stop_order_monitor = StopOrderMonitor(
            self.main_engine, self.event_engine
        )

        # Set layout
        hbox1 = QtWidgets.QHBoxLayout()
        hbox1.addWidget(self.class_combo)
        hbox1.addWidget(add_button)
        hbox1.addStretch()
        hbox1.addWidget(init_button)
        hbox1.addWidget(start_button)
        hbox1.addWidget(stop_button)
        hbox1.addWidget(clear_button)

        grid = QtWidgets.QGridLayout()
        grid.addWidget(scroll_area, 0, 0, 2, 1)
        grid.addWidget(self.stop_order_monitor, 0, 1)
        grid.addWidget(self.log_monitor, 1, 1)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addLayout(hbox1)
        vbox.addLayout(grid)

        self.setLayout(vbox)

    def update_class_combo(self):
        """"""
        self.class_combo.addItems(
            self.cta_engine.get_all_strategy_class_names()
        )

    def register_event(self):
        """"""
        self.signal_strategy.connect(self.process_strategy_event)

        self.event_engine.register(
            EVENT_CTA_STRATEGY, self.signal_strategy.emit
        )

    def process_strategy_event(self, event):
        """
        Update strategy status onto its monitor.
        """
        data = event.data
        strategy_name = data["strategy_name"]

        if strategy_name in self.managers:
            manager = self.managers[strategy_name]
            manager.update_data(data)
        else:
            manager = StrategyManager(self, self.cta_engine, data)
            self.scroll_layout.insertWidget(0, manager)
            self.managers[strategy_name] = manager

    def remove_strategy(self, strategy_name):
        """"""
        manager = self.managers.pop(strategy_name)
        manager.deleteLater()

    def add_strategy(self):
        """"""
        class_name = str(self.class_combo.currentText())
        if not class_name:
            return

        parameters = self.cta_engine.get_strategy_class_parameters(class_name)
        editor = SettingEditor(parameters, class_name=class_name)
        n = editor.exec_()

        if n == editor.Accepted:
            setting = editor.get_setting()
            vt_symbol = setting.pop("vt_symbol")
            strategy_name = setting.pop("strategy_name")

            self.cta_engine.add_strategy(
                class_name, strategy_name, vt_symbol, setting
            )

    def clear_log(self):
        """"""
        self.log_monitor.setRowCount(0)

    def show(self):
        """"""
        self.showMaximized()



class StrategyManager(QtWidgets.QFrame):
    """
    Manager for a strategy
    """
    kx_chart:NewChartWidget = None  # hxxjava add

    def __init__(
        self, cta_manager: CtaManager, cta_engine: CtaEngine, data: dict
    ):
        """"""
        super(StrategyManager, self).__init__()

        self.cta_manager = cta_manager
        self.cta_engine = cta_engine

        self.strategy_name = data["strategy_name"]
        self._data = data

        self.init_ui()

    def init_ui(self):
        """"""
        self.setFixedHeight(300)
        self.setFrameShape(self.Box)
        self.setLineWidth(1)

        self.init_button = QtWidgets.QPushButton("初始化")
        self.init_button.clicked.connect(self.init_strategy)

        self.start_button = QtWidgets.QPushButton("启动")
        self.start_button.clicked.connect(self.start_strategy)
        self.start_button.setEnabled(False)

        # hxxjava add start
        self.kx_button = QtWidgets.QPushButton("K线图表")   
        self.kx_button.clicked.connect(self.open_kx_chart)
        self.kx_button.setEnabled(False)
        # hxxjava add end

        self.stop_button = QtWidgets.QPushButton("停止")
        self.stop_button.clicked.connect(self.stop_strategy)
        self.stop_button.setEnabled(False)

        self.edit_button = QtWidgets.QPushButton("编辑")
        self.edit_button.clicked.connect(self.edit_strategy)

        self.remove_button = QtWidgets.QPushButton("移除")
        self.remove_button.clicked.connect(self.remove_strategy)

        strategy_name = self._data["strategy_name"]
        vt_symbol = self._data["vt_symbol"]
        class_name = self._data["class_name"]
        author = self._data["author"]

        label_text = (
            f"{strategy_name}  -  {vt_symbol}  ({class_name} by {author})"
        )
        label = QtWidgets.QLabel(label_text)
        label.setAlignment(QtCore.Qt.AlignCenter)

        self.parameters_monitor = DataMonitor(self._data["parameters"])
        self.variables_monitor = DataMonitor(self._data["variables"])

        hbox = QtWidgets.QHBoxLayout()
        hbox.addWidget(self.init_button)
        hbox.addWidget(self.kx_button)      # hxxjava add
        hbox.addWidget(self.start_button)
        hbox.addWidget(self.stop_button)
        hbox.addWidget(self.edit_button)
        hbox.addWidget(self.remove_button)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addWidget(label)
        vbox.addLayout(hbox)
        vbox.addWidget(self.parameters_monitor)
        vbox.addWidget(self.variables_monitor)
        self.setLayout(vbox)

    def update_data(self, data: dict):
        """"""
        self._data = data

        self.parameters_monitor.update_data(data["parameters"])
        self.variables_monitor.update_data(data["variables"])

        # Update button status
        variables = data["variables"]
        inited = variables["inited"]
        trading = variables["trading"]

        if not inited:
            return
        self.init_button.setEnabled(False)

        if trading:
            self.kx_button.setEnabled(False)     # hxxjava add
            self.start_button.setEnabled(False)
            self.stop_button.setEnabled(True)
            self.edit_button.setEnabled(False)
            self.remove_button.setEnabled(False)
        else:
            self.kx_button.setEnabled(True)     # hxxjava add
            self.start_button.setEnabled(True)
            self.stop_button.setEnabled(False)
            self.edit_button.setEnabled(True)
            self.remove_button.setEnabled(True)

    def init_strategy(self):
        """"""
        self.cta_engine.init_strategy(self.strategy_name)

    def start_strategy(self):
        """"""
        self.cta_engine.start_strategy(self.strategy_name)

    def stop_strategy(self):
        """"""
        self.cta_engine.stop_strategy(self.strategy_name)

    def edit_strategy(self):
        """"""
        strategy_name = self._data["strategy_name"]

        parameters = self.cta_engine.get_strategy_parameters(strategy_name)
        editor = SettingEditor(parameters, strategy_name=strategy_name)
        n = editor.exec_()

        if n == editor.Accepted:
            setting = editor.get_setting()
            self.cta_engine.edit_strategy(strategy_name, setting)

    def remove_strategy(self):
        """"""
        result = self.cta_engine.remove_strategy(self.strategy_name)

        # Only remove strategy gui manager if it has been removed from engine
        if result:
            self.cta_manager.remove_strategy(self.strategy_name)

    def open_kx_chart(self):    # hxx add
        event_engine = self.cta_engine.event_engine
        if not self.kx_chart:
            self.kx_chart = NewChartWidget(self,event_engine)

            self.kx_chart.setWindowTitle(f"K线图表——{self.strategy_name}")

            self.kx_chart.add_plot("candle", hide_x_axis=True)
            self.kx_chart.add_plot("volume", maximum_height=150)
            self.kx_chart.add_plot("rsi", maximum_height=150)
            self.kx_chart.add_plot("macd", maximum_height=150)
            self.kx_chart.add_item(CandleItem, "candle", "candle")
            self.kx_chart.add_item(VolumeItem, "volume", "volume")

            self.kx_chart.add_item(LineItem, "line", "candle")
            self.kx_chart.add_item(SmaItem, "sma", "candle")
            self.kx_chart.add_item(RsiItem, "rsi", "rsi")
            self.kx_chart.add_item(MacdItem, "macd", "macd")
            self.kx_chart.add_last_price_line()
            self.kx_chart.add_cursor()

            self.kx_chart.register_event()

        self.kx_chart.show()

class DataMonitor(QtWidgets.QTableWidget):
    """
    Table monitor for parameters and variables.
    """

    def __init__(self, data: dict):
        """"""
        super(DataMonitor, self).__init__()

        self._data = data
        self.cells = {}

        self.init_ui()

    def init_ui(self):
        """"""
        labels = list(self._data.keys())
        self.setColumnCount(len(labels))
        self.setHorizontalHeaderLabels(labels)

        self.setRowCount(1)
        self.verticalHeader().setSectionResizeMode(
            QtWidgets.QHeaderView.Stretch
        )
        self.verticalHeader().setVisible(False)
        self.setEditTriggers(self.NoEditTriggers)

        for column, name in enumerate(self._data.keys()):
            value = self._data[name]

            cell = QtWidgets.QTableWidgetItem(str(value))
            cell.setTextAlignment(QtCore.Qt.AlignCenter)

            self.setItem(0, column, cell)
            self.cells[name] = cell

    def update_data(self, data: dict):
        """"""
        for name, value in data.items():
            cell = self.cells[name]
            cell.setText(str(value))


class StopOrderMonitor(BaseMonitor):
    """
    Monitor for local stop order.
    """

    event_type = EVENT_CTA_STOPORDER
    data_key = "stop_orderid"
    sorting = True

    headers = {
        "stop_orderid": {"display": "停止委托号","cell": BaseCell,"update": False,},
        "vt_orderids": {"display": "限价委托号", "cell": BaseCell, "update": True},
        "vt_symbol": {"display": "本地代码", "cell": BaseCell, "update": False},
        "direction": {"display": "方向", "cell": EnumCell, "update": False},
        "offset": {"display": "开平", "cell": EnumCell, "update": False},
        "price": {"display": "价格", "cell": BaseCell, "update": False},
        "volume": {"display": "数量", "cell": BaseCell, "update": False},
        "status": {"display": "状态", "cell": EnumCell, "update": True},
        "lock": {"display": "锁仓", "cell": BaseCell, "update": False},
        "strategy_name": {"display": "策略名", "cell": BaseCell, "update": False},
    }


class LogMonitor(BaseMonitor):
    """
    Monitor for log data.
    """

    event_type = EVENT_CTA_LOG
    data_key = ""
    sorting = False

    headers = {
        "time": {"display": "时间", "cell": TimeCell, "update": False},
        "msg": {"display": "信息", "cell": MsgCell, "update": False},
    }

    def init_ui(self):
        """
        Stretch last column.
        """
        super(LogMonitor, self).init_ui()

        self.horizontalHeader().setSectionResizeMode(
            1, QtWidgets.QHeaderView.Stretch
        )

    def insert_new_row(self, data):
        """
        Insert a new row at the top of table.
        """
        super(LogMonitor, self).insert_new_row(data)
        self.resizeRowToContents(0)


class SettingEditor(QtWidgets.QDialog):
    """
    For creating new strategy and editing strategy parameters.
    """

    def __init__(
        self, parameters: dict, strategy_name: str = "", class_name: str = ""
    ):
        """"""
        super(SettingEditor, self).__init__()

        self.parameters = parameters
        self.strategy_name = strategy_name
        self.class_name = class_name

        self.edits = {}

        self.init_ui()

    def init_ui(self):
        """"""
        form = QtWidgets.QFormLayout()

        # Add vt_symbol and name edit if add new strategy
        if self.class_name:
            self.setWindowTitle(f"添加策略:{self.class_name}")
            button_text = "添加"
            parameters = {"strategy_name": "", "vt_symbol": ""}
            parameters.update(self.parameters)
        else:
            self.setWindowTitle(f"参数编辑:{self.strategy_name}")
            button_text = "确定"
            parameters = self.parameters

        for name, value in parameters.items():
            type_ = type(value)

            edit = QtWidgets.QLineEdit(str(value))
            if type_ is int:
                validator = QtGui.QIntValidator()
                edit.setValidator(validator)
            elif type_ is float:
                validator = QtGui.QDoubleValidator()
                edit.setValidator(validator)

            form.addRow(f"{name} {type_}", edit)

            self.edits[name] = (edit, type_)

        button = QtWidgets.QPushButton(button_text)
        button.clicked.connect(self.accept)
        form.addRow(button)

        widget = QtWidgets.QWidget()
        widget.setLayout(form)

        scroll = QtWidgets.QScrollArea()
        scroll.setWidgetResizable(True)
        scroll.setWidget(widget)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addWidget(scroll)
        self.setLayout(vbox)

    def get_setting(self):
        """"""
        setting = {}

        if self.class_name:
            setting["class_name"] = self.class_name

        for name, tp in self.edits.items():
            edit, type_ = tp
            value_text = edit.text()

            if type_ == bool:
                if value_text == "True":
                    value = True
                else:
                    value = False
            else:
                value = type_(value_text)

            setting[name] = value

        return setting
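
A note on why the listing registers `self.signal_strategy.emit` with the event engine rather than the handler method itself. The usual reason is that vn.py's event engine calls handlers from its own worker thread, while Qt widgets should only be touched from the GUI thread. Going through a signal lets Qt deliver the call on the widget's thread. The sketch below imitates that hand-off with the standard library only; `MainThreadBridge` is a hypothetical stand-in for a Qt signal, and `drain()` plays the role of the GUI event loop.

```python
import queue
import threading
from typing import Any, Callable, List, Tuple


class MainThreadBridge:
    """Hypothetical stand-in for a Qt signal.

    emit() may be called from any thread; the slot only runs when the
    owning thread calls drain().
    """

    def __init__(self, slot: Callable[[Any], None]) -> None:
        self._slot = slot
        self._pending: "queue.Queue[Any]" = queue.Queue()

    def emit(self, payload: Any) -> None:
        self._pending.put(payload)          # thread-safe hand-off

    def drain(self) -> int:
        """Run the slot for everything queued so far; return the count."""
        handled = 0
        while True:
            try:
                payload = self._pending.get_nowait()
            except queue.Empty:
                return handled
            self._slot(payload)
            handled += 1


received: List[Tuple[Any, str]] = []


def on_strategy_event(payload: Any) -> None:
    # Record which thread actually ran the handler.
    received.append((payload, threading.current_thread().name))


bridge = MainThreadBridge(on_strategy_event)

# The worker plays the role of the event engine thread.
worker = threading.Thread(
    target=lambda: [bridge.emit(i) for i in range(3)],
    name="event-engine",
)
worker.start()
worker.join()

handled = bridge.drain()    # runs in the calling (GUI-like) thread
print(handled, received)
```

The events are produced on the "event-engine" thread, but every handler call is recorded under the name of the thread that called `drain()`, which is the property the signal indirection is there to provide.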

4. Create vnpy\usertools\kx_chart.py with the following contents:

from datetime import datetime
from typing import List, Tuple, Dict

import numpy as np
import pyqtgraph as pg
import talib
import copy

from vnpy.trader.ui import create_qapp, QtCore, QtGui, QtWidgets
from vnpy.trader.database import database_manager
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData

from vnpy.chart import ChartWidget, VolumeItem, CandleItem
from vnpy.chart.item import ChartItem
from vnpy.chart.manager import BarManager
from vnpy.chart.base import NORMAL_FONT

from vnpy.trader.engine import MainEngine
from vnpy.event import Event, EventEngine

from vnpy.trader.event import (
    EVENT_TICK,
    EVENT_TRADE,
    EVENT_ORDER,
    EVENT_POSITION,
    EVENT_ACCOUNT,
    EVENT_LOG
)

from vnpy.app.cta_strategy.base import (    # hxxjava add
    EVENT_CTA_TICK,     
    EVENT_CTA_BAR,      
    EVENT_CTA_ORDER,    
    EVENT_CTA_TRADE,     
    EVENT_CTA_HISTORY_BAR
)

from vnpy.trader.object import TickData,BarData  # hxxjava add


class LineItem(CandleItem):
    """"""

    def __init__(self, manager: BarManager):
        """"""
        super().__init__(manager)

        self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """"""
        last_bar = self._manager.get_bar(ix - 1)

        # Create objects
        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # Set painter color
        painter.setPen(self.white_pen)

        # Draw Line
        end_point = QtCore.QPointF(ix, bar.close_price)

        if last_bar:
            start_point = QtCore.QPointF(ix - 1, last_bar.close_price)
        else:
            start_point = end_point

        painter.drawLine(start_point, end_point)

        # Finish
        painter.end()
        return picture


class SmaItem(CandleItem):
    """"""

    def __init__(self, manager: BarManager):
        """"""
        super().__init__(manager)

        self.blue_pen: QtGui.QPen = pg.mkPen(color=(100, 100, 255), width=2)

        self.sma_window = 10
        self.sma_data: Dict[int, float] = {}

    def get_sma_value(self, ix: int) -> float:
        """"""
        if ix < 0:
            return 0

        # On first call, calculate all SMA values
        if not self.sma_data:
            bars = self._manager.get_all_bars()
            close_data = [bar.close_price for bar in bars]
            sma_array = talib.SMA(np.array(close_data), self.sma_window)

            for n, value in enumerate(sma_array):
                self.sma_data[n] = value

        # Return if already calculated
        if ix in self.sma_data:
            return self.sma_data[ix]

        # Else calculate new value
        close_data = []
        for n in range(ix - self.sma_window, ix + 1):
            bar = self._manager.get_bar(n)
            close_data.append(bar.close_price)

        sma_array = talib.SMA(np.array(close_data), self.sma_window)
        sma_value = sma_array[-1]
        self.sma_data[ix] = sma_value

        return sma_value

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """"""
        sma_value = self.get_sma_value(ix)
        last_sma_value = self.get_sma_value(ix - 1)

        # Create objects
        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # Set painter color
        painter.setPen(self.blue_pen)

        # Draw Line
        start_point = QtCore.QPointF(ix-1, last_sma_value)
        end_point = QtCore.QPointF(ix, sma_value)
        painter.drawLine(start_point, end_point)

        # Finish
        painter.end()
        return picture

    def get_info_text(self, ix: int) -> str:
        """"""
        if ix in self.sma_data:
            sma_value = self.sma_data[ix]
            text = f"SMA {sma_value:.1f}"
        else:
            text = "SMA -"

        return text


class RsiItem(ChartItem):
    """"""

    def __init__(self, manager: BarManager):
        """"""
        super().__init__(manager)

        self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)
        self.yellow_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 0), width=2)

        self.rsi_window = 14
        self.rsi_data: Dict[int, float] = {}

    def get_rsi_value(self, ix: int) -> float:
        """"""
        if ix < 0:
            return 50

        # When initialize, calculate all rsi value
        if not self.rsi_data:
            bars = self._manager.get_all_bars()
            close_data = [bar.close_price for bar in bars]
            rsi_array = talib.RSI(np.array(close_data), self.rsi_window)

            for n, value in enumerate(rsi_array):
                self.rsi_data[n] = value

        # Return if already calculated
        if ix in self.rsi_data:
            return self.rsi_data[ix]

        # Else calculate new value
        close_data = []
        for n in range(ix - self.rsi_window, ix + 1):
            bar = self._manager.get_bar(n)
            close_data.append(bar.close_price)

        rsi_array = talib.RSI(np.array(close_data), self.rsi_window)
        rsi_value = rsi_array[-1]
        self.rsi_data[ix] = rsi_value

        return rsi_value

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """"""
        rsi_value = self.get_rsi_value(ix)
        last_rsi_value = self.get_rsi_value(ix - 1)

        # Create objects
        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # Draw RSI line
        painter.setPen(self.yellow_pen)

        if np.isnan(last_rsi_value) or np.isnan(rsi_value):
            # print(ix - 1, last_rsi_value,ix, rsi_value,)
            pass
        else:
            end_point = QtCore.QPointF(ix, rsi_value)
            start_point = QtCore.QPointF(ix - 1, last_rsi_value)
            painter.drawLine(start_point, end_point)

        # Draw oversold/overbought line
        painter.setPen(self.white_pen)

        painter.drawLine(
            QtCore.QPointF(ix, 70),
            QtCore.QPointF(ix - 1, 70),
        )

        painter.drawLine(
            QtCore.QPointF(ix, 30),
            QtCore.QPointF(ix - 1, 30),
        )

        # Finish
        painter.end()
        return picture

    def boundingRect(self) -> QtCore.QRectF:
        """"""
        # min_price, max_price = self._manager.get_price_range()
        rect = QtCore.QRectF(
            0,
            0,
            len(self._bar_picutures),
            100
        )
        return rect

    def get_y_range( self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
        """  """
        return 0, 100

    def get_info_text(self, ix: int) -> str:
        """"""
        if ix in self.rsi_data:
            rsi_value = self.rsi_data[ix]
            text = f"RSI {rsi_value:.1f}"
            # print(text)
        else:
            text = "RSI -"

        return text


def to_int(value: float) -> int:
    """"""
    return int(round(value, 0))


def adjust_range(in_range: Tuple[float, float]) -> Tuple[float, float]:
    """Widen the y-axis display range to 1.1x by padding 5% on each side."""
    diff = abs(in_range[0] - in_range[1])
    ret_range = (in_range[0] - diff * 0.05, in_range[1] + diff * 0.05)
    return ret_range

class MacdItem(ChartItem):
    """"""
    _values_ranges: Dict[Tuple[int, int], Tuple[float, float]] = {}

    last_range:Tuple[int, int] = (-1,-1)    # index range of the most recently displayed K-lines

    def __init__(self, manager: BarManager):
        """"""
        super().__init__(manager)

        self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)
        self.yellow_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 0), width=1)
        self.red_pen: QtGui.QPen = pg.mkPen(color=(255, 0, 0), width=1)
        self.green_pen: QtGui.QPen = pg.mkPen(color=(0, 255, 0), width=1)

        self.short_window = 12
        self.long_window = 26
        self.M = 9

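        self.macd_data: Dict[int, Tuple[float,float,float]] = {}

        # Per-instance copies, so that charts opened for different strategies
        # do not share the class-level y-range cache declared above.
        self._values_ranges = {}
        self.last_range = (-1, -1)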

    def get_macd_value(self, ix: int) -> Tuple[float,float,float]:
        """"""
        if ix < 0:
            return (0.0,0.0,0.0)

        # When initialize, calculate all macd value
        if not self.macd_data:
            bars = self._manager.get_all_bars()
            close_data = [bar.close_price for bar in bars]

            diffs,deas,macds = talib.MACD(np.array(close_data), 
                                    fastperiod=self.short_window, 
                                    slowperiod=self.long_window, 
                                    signalperiod=self.M)

            for n in range(0,len(diffs)):
                self.macd_data[n] = (diffs[n],deas[n],macds[n])

        # Return if already calculated
        if ix in self.macd_data:
            return self.macd_data[ix]

        # Else calculate new value
        close_data = []
        for n in range(ix-self.long_window-self.M+1, ix + 1):
            bar = self._manager.get_bar(n)
            close_data.append(bar.close_price)

        diffs,deas,macds = talib.MACD(np.array(close_data), 
                                            fastperiod=self.short_window, 
                                            slowperiod=self.long_window, 
                                            signalperiod=self.M) 
        diff,dea,macd = diffs[-1],deas[-1],macds[-1]
        self.macd_data[ix] = (diff,dea,macd)

        return (diff,dea,macd)

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """"""
        macd_value = self.get_macd_value(ix)
        last_macd_value = self.get_macd_value(ix - 1)

        # # Create objects
        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # # Draw macd lines
        if np.isnan(macd_value[0]) or np.isnan(last_macd_value[0]):
            # print("skipping macd line 0")
            pass
        else:
            end_point0 = QtCore.QPointF(ix, macd_value[0])
            start_point0 = QtCore.QPointF(ix - 1, last_macd_value[0])
            painter.setPen(self.white_pen)
            painter.drawLine(start_point0, end_point0)

        if np.isnan(macd_value[1]) or np.isnan(last_macd_value[1]):
            # print("skipping macd line 1")
            pass
        else:
            end_point1 = QtCore.QPointF(ix, macd_value[1])
            start_point1 = QtCore.QPointF(ix - 1, last_macd_value[1])
            painter.setPen(self.yellow_pen)
            painter.drawLine(start_point1, end_point1)

        if not np.isnan(macd_value[2]):
            if (macd_value[2]>0):
                painter.setPen(self.red_pen)
                painter.setBrush(pg.mkBrush(255,0,0))
            else:
                painter.setPen(self.green_pen)
                painter.setBrush(pg.mkBrush(0,255,0))
            painter.drawRect(QtCore.QRectF(ix-0.3,0,0.6,macd_value[2]))
        else:
            # print("skipping macd line 2")
            pass

        painter.end()
        return picture

    def boundingRect(self) -> QtCore.QRectF:
        """"""
        min_y, max_y = self.get_y_range()
        # QRectF takes (x, y, width, height), so pass the height, not max_y
        rect = QtCore.QRectF(
            0,
            min_y,
            len(self._bar_picutures),
            max_y - min_y
        )
        return rect

    def get_y_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
        #   Get the y-axis range covered by the three indicator values
        #   Modified by hxxjava, 2020-6-29
        #   When the visible range changes, min_ix and max_ix are not None;
        #   when the visible range is unchanged, both are None.

        offset = max(self.short_window,self.long_window) + self.M - 1

        if not self.macd_data or len(self.macd_data) < offset:
            return 0.0, 1.0

        # print("len of range dict:",len(self._values_ranges),",macd_data:",len(self.macd_data),(min_ix,max_ix))

        if min_ix is not None:          # clamp the smallest K-line index
            min_ix = max(min_ix,offset)

        if max_ix is not None:          # clamp the largest K-line index
            max_ix = min(max_ix, len(self.macd_data)-1)

        last_range = (min_ix,max_ix)    # the range requested by this call

        if last_range == (None,None):   # visible range unchanged
            if self.last_range in self._values_ranges:
                # The y range for the last visible range is cached: reuse it
                result = self._values_ranges[self.last_range]
                # print("1:",self.last_range,result)
                return adjust_range(result)
            else:
                # Not cached: recompute the y range from all of macd_data
                min_ix,max_ix = 0,len(self.macd_data)-1

                macd_list = list(self.macd_data.values())[min_ix:max_ix + 1]
                ndarray = np.array(macd_list)
                max_price = np.nanmax(ndarray)
                min_price = np.nanmin(ndarray)

                # Cache the y range and return it
                result = (min_price, max_price)
                self.last_range = (min_ix,max_ix)
                self._values_ranges[self.last_range] = result
                # print("2:",self.last_range,result)
                return adjust_range(result)

        # From here on, the visible range has changed

        if last_range in self._values_ranges:
            # The y range for this index range is already cached: return it
            result = self._values_ranges[last_range]
            # print("3:",last_range,result)
            return adjust_range(result)

        # Not cached: recompute the y range from macd_data
        macd_list = list(self.macd_data.values())[min_ix:max_ix + 1]
        ndarray = np.array(macd_list)
        max_price = np.nanmax(ndarray)
        min_price = np.nanmin(ndarray)

        # Cache the y range and return it
        result = (min_price, max_price)
        self.last_range = last_range
        self._values_ranges[self.last_range] = result
        # print("4:",self.last_range,result)
        return adjust_range(result)


    def get_info_text(self, ix: int) -> str:
        """Return the diff/dea/macd text shown for the bar at index ix."""
        if ix in self.macd_data:
            diff,dea,macd = self.macd_data[ix]
            words = [
                f"diff {diff:.3f}"," ",
                f"dea {dea:.3f}"," ",
                f"macd {macd:.3f}"
                ]
            text = "\n".join(words)
        else:
            text = "diff - \ndea - \nmacd -"

        return text



class NewChartWidget(ChartWidget):
    """Chart widget that can also be updated by CTA events (history bars, ticks, bars)."""
    MIN_BAR_COUNT = 100

    signal_cta_history_bar: QtCore.pyqtSignal = QtCore.pyqtSignal(Event)
    signal_cta_tick: QtCore.pyqtSignal = QtCore.pyqtSignal(Event)
    signal_cta_bar: QtCore.pyqtSignal = QtCore.pyqtSignal(Event)

    def __init__(self, main_engine: MainEngine = None, event_engine: EventEngine = None):
        """main_engine is accepted but not used in this class; event_engine delivers the CTA events."""
        super().__init__()

        self.event_engine = event_engine

        self.last_price_line: pg.InfiniteLine = None

        # Event registration is left disabled in the constructor
        # self.register_event()

        # self.event_engine.start()

    def register_event(self) -> None:
        """Route CTA events through Qt signals so the handlers run in the GUI thread."""
        self.signal_cta_history_bar.connect(self.process_cta_history_bar)
        self.event_engine.register(EVENT_CTA_HISTORY_BAR, self.signal_cta_history_bar.emit)

        self.signal_cta_tick.connect(self.process_tick_event)
        self.event_engine.register(EVENT_CTA_TICK, self.signal_cta_tick.emit)

        self.signal_cta_bar.connect(self.process_cta_bar)
        self.event_engine.register(EVENT_CTA_BAR, self.signal_cta_bar.emit)

    def process_cta_history_bar(self, event: Event) -> None:
        """Handle a pushed list of historical bars."""
        print(" I got an EVENT_CTA_HISTORY_BAR")
        history_bars: List[BarData] = event.data
        self.update_history(history_bars)

    def process_tick_event(self, event: Event) -> None:
        """Handle a pushed tick by moving the last-price line."""
        tick: TickData = event.data
        if self.last_price_line:
            self.last_price_line.setValue(tick.last_price)
        # print("I got an event")

    def process_cta_bar(self, event: Event) -> None:
        """Handle a pushed bar (finished or still forming)."""
        print(" I got an EVENT_CTA_BAR")
        bar: BarData = event.data
        self.update_bar(bar)

    def add_last_price_line(self) -> None:
        """Add a horizontal line to the first plot that tracks the last price."""
        plot = list(self._plots.values())[0]
        color = (255, 255, 255)

        self.last_price_line = pg.InfiniteLine(
            angle=0,
            movable=False,
            label="{value:.1f}",
            pen=pg.mkPen(color, width=1),
            labelOpts={
                "color": color,
                "position": 1,
                "anchors": [(1, 1), (1, 1)]
            }
        )
        self.last_price_line.label.setFont(NORMAL_FONT)
        plot.addItem(self.last_price_line)

    def update_history(self, history: List[BarData]) -> None:
        """
        Update a list of bar data.
        """
        if not history:     # nothing to draw; also avoids history[-1] on an empty list
            return

        self._manager.update_history(history)

        for item in self._items.values():
            item.update_history(history)

        self._update_plot_limits()

        self.move_to_right()

        self.update_last_price_line(history[-1])

    def update_bar(self, bar: BarData) -> None:
        """
        Update single bar data.
        """
        self._manager.update_bar(bar)

        for item in self._items.values():
            item.update_bar(bar)

        self._update_plot_limits()

        if self._right_ix >= (self._manager.get_count() - self._bar_count / 2):
            self.move_to_right()

        self.update_last_price_line(bar)

    def update_last_price_line(self, bar: BarData) -> None:
        """Move the last-price line to the bar's close price."""
        if self.last_price_line:
            self.last_price_line.setValue(bar.close_price)


if __name__ == "__main__":
    app = create_qapp()

    # bars = database_manager.load_bar_data(
    #     "IF888",
    #     Exchange.CFFEX,
    #     interval=Interval.MINUTE,
    #     start=datetime(2019, 7, 1),
    #     end=datetime(2019, 7, 17)
    # )

    symbol = "rb2010"
    exchange = Exchange.SHFE
    interval = Interval.MINUTE
    start = datetime(2020, 6, 1)
    end = datetime(2020, 6, 30)

    dynamic = False  # whether to replay bars one at a time
    n = 200          # number of bars loaded as initial history


    bars = database_manager.load_bar_data(
        symbol=symbol,
        exchange=exchange,
        interval=interval,
        start=start,
        end=end
    )

    event_engine = EventEngine()

    widget = NewChartWidget(event_engine=event_engine)
    widget.setWindowTitle(f"K-line chart: {symbol}.{exchange.value}, {interval}, {start} to {end}")
    widget.add_plot("candle", hide_x_axis=True)
    widget.add_plot("volume", maximum_height=150)
    widget.add_plot("rsi", maximum_height=150)
    widget.add_plot("macd", maximum_height=150)
    widget.add_item(CandleItem, "candle", "candle")
    widget.add_item(VolumeItem, "volume", "volume")

    widget.add_item(LineItem, "line", "candle")
    widget.add_item(SmaItem, "sma", "candle")
    widget.add_item(RsiItem, "rsi", "rsi")
    widget.add_item(MacdItem, "macd", "macd")
    widget.add_last_price_line()
    widget.add_cursor()

    if dynamic:
        history = bars[:n]      # the earliest n bars serve as history
        new_data = bars[n:]     # the rest are replayed by the timer
    else:
        history = bars[-n:]     # the latest n bars serve as history
        new_data = []           # nothing to replay

    widget.update_history(history)

    def update_bar():
        if new_data:
            bar = new_data.pop(0)
            widget.update_bar(bar)

        # event = Event(EVENT_TICK,None)
        # event_engine.put(event)

    timer = QtCore.QTimer()
    timer.timeout.connect(update_bar)
    if dynamic:
        timer.start(100)

    widget.show()

    # event_engine.start()
    app.exec_()
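
The `get_y_range` method in the listing above keeps a dictionary of y ranges keyed by the visible index range, so a range is computed only once per window. Below is a minimal, stdlib-only sketch of that caching idea. `RangeCache`, `invalidate` and the fallback range `(0.0, 1.0)` are hypothetical names and choices made for illustration, not part of vnpy or of the listing, and it looks values up by bar index rather than slicing a list as the listing does.

```python
import math
from typing import Dict, Optional, Tuple

Range = Tuple[float, float]


class RangeCache:
    """Hypothetical helper: caches (min, max) of indicator values per index range."""

    def __init__(self, data: Dict[int, Tuple[float, ...]]) -> None:
        self.data = data                                  # bar index -> indicator tuple
        self._ranges: Dict[Tuple[int, int], Range] = {}   # (min_ix, max_ix) -> y range

    def get_y_range(self, min_ix: Optional[int] = None,
                    max_ix: Optional[int] = None) -> Range:
        last = len(self.data) - 1
        min_ix = 0 if min_ix is None else max(min_ix, 0)
        max_ix = last if max_ix is None else min(max_ix, last)

        key = (min_ix, max_ix)
        if key in self._ranges:           # cache hit: no recomputation
            return self._ranges[key]

        # Flatten the tuples in the window and skip NaN values,
        # which is what np.nanmin / np.nanmax do in the listing.
        values = [
            v
            for ix in range(min_ix, max_ix + 1)
            for v in self.data.get(ix, ())
            if not math.isnan(v)
        ]
        result = (min(values), max(values)) if values else (0.0, 1.0)
        self._ranges[key] = result
        return result

    def invalidate(self) -> None:
        """Drop all cached ranges, e.g. after the values of a bar change."""
        self._ranges.clear()


cache = RangeCache({
    0: (float("nan"), 1.0, -2.0),
    1: (0.5, 3.0, 0.0),
    2: (4.0, -5.0, 1.0),
})
first_two = cache.get_y_range(0, 1)
everything = cache.get_y_range()
```

One consequence of caching by index range is that a cached entry goes stale when the last, still-forming bar changes its values, so some form of invalidation is needed whenever a bar inside a cached window is updated.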
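Member
Posts: 419
Reputation: 170

5. Create [user directory]\strategies\Kx_Monitor.py

Here the user directory is the default user directory of your Windows account. Unless you have configured it otherwise, [user directory]\strategies is also where your other strategy files are stored.
The file contents are as follows: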

from typing import Any,List,Dict,Tuple
import copy

from vnpy.app.cta_strategy import (
    CtaTemplate,
    BarGenerator,
    ArrayManager,
    StopOrder,
    Direction
)

from vnpy.trader.engine import MainEngine,EventEngine
from vnpy.app.cta_strategy.engine import CtaEngine
from vnpy.event.engine import Event

from vnpy.trader.object import (
    LogData,
    TickData,
    BarData,
    TradeData,
    OrderData,
)

from vnpy.trader.constant import Interval

from vnpy.app.cta_strategy.base import (
    APP_NAME,
    EVENT_CTA_LOG,
    EVENT_CTA_TICK,
    EVENT_CTA_HISTORY_BAR,
    EVENT_CTA_BAR,
    EVENT_CTA_ORDER,
    EVENT_CTA_TRADE,    
    EVENT_CTA_STOPORDER,
    EVENT_CTA_STRATEGY,
)

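class _kx_strategy(CtaTemplate):
    """Strategy that places no orders; it only forwards bars, ticks, orders and trades to the K-line chart as events."""

    author = "hxxjava"
    kx_interval = 1
    parameters = [
        "kx_interval"
    ]

    kx_count: int = 0
    all_bars: List[BarData] = []    # class-level default; shared by all instances unless rebound in __init__

    current_tick: TickData = None
    current_bar: BarData = None

    variables = ["kx_count"]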

    def __init__(
        self,
        cta_engine: Any,
        strategy_name: str,
        vt_symbol: str,
        setting: dict,
    ):
        super().__init__(cta_engine, strategy_name, vt_symbol, setting)
        self.bg = BarGenerator(self.on_bar, self.kx_interval, self.on_Nmin_bar)
        self.am = ArrayManager()

        # Give each strategy instance its own bar list; the class-level list
        # would otherwise be shared by every instance of this strategy
        self.all_bars = []

        cta_engine: CtaEngine = self.cta_engine
        self.even_engine = cta_engine.main_engine.event_engine

    def on_init(self):
        """
        Callback when strategy is inited.
        """
        self.write_log("_kx_strategy initialized")
        self.load_bar(20)

    def on_start(self):
        """
        Callback when strategy is started: push the bars collected so far to the chart.
        """
        self.write_log("_kx_strategy started")
        if len(self.all_bars) > 0:
            self.even_engine.put(Event(EVENT_CTA_HISTORY_BAR, self.all_bars))

    def on_stop(self):
        """
        Callback when strategy is stopped.
        """
        self.write_log("_kx_strategy stopped")

    def on_tick(self, tick: TickData):
        """
        Callback of new tick data update.
        """
        self.current_tick = tick    # remember the latest tick

        if self.inited:
            # Build the current, still-unfinished bar first
            cur_bar = self.get_cur_bar(tick)
            if cur_bar:
                # Push the unfinished bar so the chart can redraw it
                self.even_engine.put(Event(EVENT_CTA_BAR, cur_bar))

        # Then feed the tick to the generator, which produces 1-minute and N-minute bars
        self.bg.update_tick(tick)
        self.even_engine.put(Event(EVENT_CTA_TICK, tick))

    def on_bar(self, bar: BarData):
        """
        Callback of new bar data update.
        """
        if self.inited:
            self.write_log("I got a 1min BarData")
        self.bg.update_bar(bar)

    def on_Nmin_bar(self, bar: BarData):
        """
        Callback of new N-minute bar data update.
        """
        self.all_bars.append(bar)
        self.kx_count = len(self.all_bars)

        if self.inited:
            self.write_log(f"I got a {self.kx_interval}min BarData")
            self.even_engine.put(Event(EVENT_CTA_BAR, bar))

            if self.current_tick:
                # A new N-minute bar has just closed: start a fresh unfinished bar right away
                self.current_bar = None
                self.get_cur_bar(self.current_tick)

        self.put_event()

    def on_trade(self, trade: TradeData):
        """
        Callback of new trade data update.
        """
        self.even_engine.put(Event(EVENT_CTA_TRADE,trade))

    def on_order(self, order: OrderData):
        """
        Callback of new order data update.
        """
        self.even_engine.put(Event(EVENT_CTA_ORDER,order))

    def on_stop_order(self, stop_order: StopOrder):
        """
        Callback of stop order update.
        """
        self.even_engine.put(Event(EVENT_CTA_STOPORDER,stop_order))

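    def get_cur_bar(self, tick: TickData) -> BarData:
        """
        Build the temporary (unfinished) bar, updated on every tick.
        Unless self.current_bar is set to None, no new bar is started;
        only the volume and prices of the existing bar are updated.
        Returns None when no bar can be built yet.
        Note: last_tick is declared and updated inside BarGenerator.
        """
        last_tick = self.bg.last_tick
        if not self.inited or not last_tick:
            return None

        if last_tick and tick.datetime < last_tick.datetime:
            return None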

        if not self.current_bar:
            # Generate timestamp for bar data
            if self.bg.interval == Interval.MINUTE:
                dt = tick.datetime.replace(second=0, microsecond=0)
            else:
                dt = tick.datetime.replace(minute=0, second=0, microsecond=0)

            self.current_bar = BarData(
                symbol=tick.symbol,
                exchange=tick.exchange,
                datetime=dt,
                gateway_name=tick.gateway_name,
                open_price=tick.last_price,
                high_price=tick.last_price,
                low_price=tick.last_price,
            )
        # Otherwise, update high/low price into window bar
        else:
            self.current_bar.high_price = max(self.current_bar.high_price, tick.last_price)
            self.current_bar.low_price = min(self.current_bar.low_price, tick.last_price)

        # Update last price/volume into the temporary bar
        self.current_bar.close_price = tick.last_price
        volume_change = tick.volume - last_tick.volume
        # Ignore a negative change, e.g. when the cumulative volume resets
        self.current_bar.volume += max(volume_change, 0)
        self.current_bar.open_interest = tick.open_interest

        return copy.deepcopy(self.current_bar)
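
The `get_cur_bar` method above folds each tick into a bar that is still forming. The same idea can be tried outside vnpy with the stdlib-only sketch below. `Tick`, `Bar` and `TempBarBuilder` are hypothetical stand-ins for `TickData`, `BarData` and the strategy's own state, and unlike the listing this sketch accepts the very first tick instead of waiting for the generator's `last_tick`.

```python
from dataclasses import dataclass, replace
from datetime import datetime
from typing import Optional


@dataclass
class Tick:                # hypothetical stand-in for vnpy's TickData
    datetime: datetime
    last_price: float
    volume: float          # cumulative traded volume


@dataclass
class Bar:                 # hypothetical stand-in for vnpy's BarData
    datetime: datetime
    open_price: float
    high_price: float
    low_price: float
    close_price: float
    volume: float = 0.0


class TempBarBuilder:
    """Hypothetical helper that keeps one unfinished bar up to date."""

    def __init__(self) -> None:
        self.bar: Optional[Bar] = None
        self.last_tick: Optional[Tick] = None

    def reset(self) -> None:
        """Call after a finished bar is emitted so the next tick opens a new one."""
        self.bar = None

    def update(self, tick: Tick) -> Optional[Bar]:
        prev = self.last_tick
        if prev is not None and tick.datetime < prev.datetime:
            return None                      # ignore out-of-order ticks
        self.last_tick = tick

        if self.bar is None:
            # Open a new bar stamped at the start of the tick's minute
            start = tick.datetime.replace(second=0, microsecond=0)
            self.bar = Bar(start, tick.last_price, tick.last_price,
                           tick.last_price, tick.last_price)
        else:
            self.bar.high_price = max(self.bar.high_price, tick.last_price)
            self.bar.low_price = min(self.bar.low_price, tick.last_price)

        self.bar.close_price = tick.last_price
        if prev is not None:
            # Cumulative volume may reset, so never subtract from the bar
            self.bar.volume += max(tick.volume - prev.volume, 0.0)

        return replace(self.bar)             # snapshot copy for the chart


builder = TempBarBuilder()
snap1 = builder.update(Tick(datetime(2020, 6, 1, 9, 0, 1), 100.0, 10.0))
snap2 = builder.update(Tick(datetime(2020, 6, 1, 9, 0, 2), 102.0, 15.0))
```

Returning a copy matters because the event is consumed later on another thread; handing out the live object would let the chart see a bar that has already moved on.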
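Member
Posts: 419
Reputation: 170

6. Add a strategy that can be displayed

Start VnTrader, open the CTA strategy management interface, and complete the following steps:
1) Select the _kx_strategy strategy from the strategy drop-down box.
2) Click the add-strategy button to open the settings dialog used in step 3.
3) Enter the strategy name, vt_symbol and kx_interval. Note that kx_interval is the K-line period you want, in minutes.
4) Add the strategy; it is now a strategy whose K-line can be displayed.
5) Initialize the strategy. Once that completes, both the "K-line chart" and "Start" buttons are enabled.
6) Click the "K-line chart" button to create an empty K-line chart window.
7) Click the "Start" button to start the strategy.
[Screenshot]
8) Go back to the K-line chart window; it should now look like the screenshot below.
Note: during trading hours the chart keeps updating the K-line and every sub-chart indicator, and as tick data arrives it also updates the last, still-unfinished K-line.

[Screenshot: K-line chart window]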

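Member
Posts: 4
Reputation: 0

Thanks for sharing. One question: kx_interval is in minutes, so if the strategy's period is something like 2 hours or 3 hours, does the following part of Kx_Monitor.py need to be modified?

        # Generate timestamp for bar data
        if self.bg.interval == Interval.MINUTE:    <----- change this to hours?!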
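
As a rough illustration of what the timestamp part of such a change might involve, here is a stdlib-only sketch that floors a tick time to the start of an N-minute or N-hour window. `floor_bar_time` and its `hourly` flag are hypothetical and not part of vnpy or of Kx_Monitor.py. The alignment is to the wall clock, which is an assumption: it may not match how `BarGenerator` groups hour bars or how exchange sessions are split. Note also that the listing builds `BarGenerator` without an interval argument, so `self.bg.interval` presumably stays at its minute default there.

```python
from datetime import datetime


def floor_bar_time(dt: datetime, window: int, hourly: bool = False) -> datetime:
    """Return the start time of the clock-aligned window that contains dt.

    window is the bar length in minutes, or in hours when hourly is True.
    """
    if hourly:
        hour = dt.hour - dt.hour % window
        return dt.replace(hour=hour, minute=0, second=0, microsecond=0)

    minute = dt.minute - dt.minute % window
    return dt.replace(minute=minute, second=0, microsecond=0)


tick_time = datetime(2020, 6, 1, 11, 37, 12)
five_min_start = floor_bar_time(tick_time, 5)
two_hour_start = floor_bar_time(tick_time, 2, hourly=True)
```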
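Member
Posts: 419
Reputation: 170

Ran into a problem

Apologies to everyone: this write-up has hit a problem that makes the K-line chart's data and display come out wrong. Please be patient while it is being resolved.
For details, see: https://www.vnpy.com/forum/topic/3893-ctace-lue-de-wen-ti-:liang-ge-he-yue-bu-ke-yi-gong-yong-yi-ge-ce-lue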
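
One Python behaviour worth keeping in mind when several instances of the same strategy class run side by side: a list declared in the class body, such as `all_bars` in `_kx_strategy`, is a single object shared by every instance unless `__init__` rebinds it. Whether this is the cause of the problem described in the linked topic is not confirmed here; the sketch below only demonstrates the language behaviour. `SharedBars` and `OwnBars` are hypothetical classes, and the bar values are plain strings for brevity.

```python
from typing import List


class SharedBars:
    all_bars: List[str] = []            # class attribute: one list for all instances

    def on_bar(self, bar: str) -> None:
        self.all_bars.append(bar)       # mutates the shared list


class OwnBars:
    def __init__(self) -> None:
        self.all_bars: List[str] = []   # instance attribute: one list per instance

    def on_bar(self, bar: str) -> None:
        self.all_bars.append(bar)


a, b = SharedBars(), SharedBars()
a.on_bar("rb2010 bar")
b.on_bar("IF888 bar")
shared_view = list(a.all_bars)          # a also sees the bar pushed through b

c, d = OwnBars(), OwnBars()
c.on_bar("rb2010 bar")
d.on_bar("IF888 bar")
own_view = list(c.all_bars)             # c sees only its own bar
```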

Member
Posts: 1
Reputation: 0

Impressive! Would the original poster consider merging this directly into the vnpy GitHub repository?

Member
Posts: 7
Reputation: 1

Thanks to the original poster, this is a very good introductory article.

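Member
Posts: 2
Reputation: 0

Hello, original poster. A question: when backtesting a strategy, can the K-line chart be displayed dynamically (5-minute, 15-minute, and so on)?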
