VeighNa量化社区
你的开源社区量化交易平台
Member
avatar
加入于:
帖子: 10
声望: 0

引言

首先感谢官方开源的功能，但 PortfolioStrategy 模块目前缺少移仓换月功能，而这个功能在实盘中非常需要。

功能描述

  • 合约自动匹配
  • 合约移仓换月

流程图

description

原型图

description

代码

  • 代码位置：vnpy_portfoliostrategy/vnpy_portfoliostrategy/ui/rollover.py
from dataclasses import dataclass
from datetime import datetime
import re
from time import sleep
from typing import Any, List, Optional, TYPE_CHECKING
from copy import copy

from vnpy.trader.datafeed import BaseDatafeed, get_datafeed
from vnpy.trader.engine import MainEngine
from vnpy.trader.constant import OrderType
from vnpy.trader.object import ContractData, OrderRequest, SubscribeRequest, TickData
from vnpy.trader.object import Direction, Offset
from vnpy.trader.ui import QtWidgets, QtCore
from vnpy.trader.converter import OffsetConverter, PositionHolding

from ..base import APP_NAME
from ..engine import StrategyEngine
from ..template import StrategyTemplate

if TYPE_CHECKING:
    from .widget import PortfolioStrategyManager

# Fixed size (in Qt pixels) of every cell of the rollover table:
# row height and column width respectively.
CELL_HEIGHT: int = 50
CELL_WIDTH: int = 120


@dataclass
class RolloverData:
    """Rollover settings of one contract (one column of the rollover table)."""

    # vt_symbol of the contract to roll out of
    vt_symbol: str = ""
    # vt_symbol of the target contract; "" means no target selected
    new_vt_symbol: str = ""
    # Price offset in price ticks beyond the best quote used for orders
    payup: int = 5
    # Maximum volume of a single order ("单比上限" row of the table)
    max_volume: int = 100


class RolloverTool(QtWidgets.QDialog):
    """
    Rollover assistant dialog for portfolio strategies.

    Lists every vt_symbol used by at least one strategy, lets the user pick a
    target contract for each, then closes the account position of the old
    contract, opens the same volume on the target contract and re-creates
    every affected strategy with the target contract in its symbol list.
    """
    header_cols: list
    new_vt_symbols: list
    data_list: list

    def __init__(self, portfolioStrategyManager: "PortfolioStrategyManager") -> None:
        """
        Create the dialog.

        :param portfolioStrategyManager: manager widget providing the strategy
            engine and main engine; notified when a strategy is removed.
        """
        super().__init__()

        self.portfolioStrategyManager: "PortfolioStrategyManager" = portfolioStrategyManager

        self.strategy_engine: StrategyEngine = portfolioStrategyManager.strategy_engine
        self.main_engine: MainEngine = portfolioStrategyManager.main_engine

        # Row labels of the (transposed) table: one column per contract
        self.header_cols = ["移仓合约", "目标合约", "委托超价", "单比上限"]
        # Candidate target contracts; after an import, index 0 is "" (= no target)
        self.new_vt_symbols = []
        self.data_list = []

        self.init_ui()
        self.refresh_data()
        self.refresh_ui()

    def init_ui(self) -> None:
        """Build the input line, the rollover table, the rollover button and the log view."""
        self.setWindowTitle("移仓助手")

        self.text_edit = QtWidgets.QLineEdit()
        self.text_edit.setPlaceholderText("请输入目标合约,多个合约用,号分隔")

        import_button = QtWidgets.QPushButton("导入")
        import_button.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed)
        import_button.clicked.connect(self.handle_import_button_clicked)

        top_hbox = QtWidgets.QHBoxLayout()
        top_hbox.addWidget(self.text_edit, 2)
        top_hbox.addWidget(import_button)
        top_hbox.setContentsMargins(0, 0, 0, 0)
        top_hbox.setSpacing(15)

        # Rows are the fixed header_cols, so the table height is fixed too;
        # contracts are laid out as columns and scroll horizontally.
        self.table = QtWidgets.QTableWidget()
        self.table.horizontalHeader().setVisible(False)
        self.table.verticalHeader().setSectionResizeMode(QtWidgets.QHeaderView.Fixed)
        self.table.setEditTriggers(QtWidgets.QAbstractItemView.NoEditTriggers)
        self.table.setSelectionMode(QtWidgets.QAbstractItemView.NoSelection)
        self.table.verticalScrollBar().setEnabled(False)
        self.table.verticalHeader().setDefaultSectionSize(CELL_HEIGHT)
        self.table.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed)
        self.table.setFixedHeight(CELL_HEIGHT * len(self.header_cols) + self.table.horizontalScrollBar().height())
        self.table.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
        self.table.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)

        self.rollover_button = QtWidgets.QPushButton("移仓")
        self.rollover_button.clicked.connect(self.handle_rollover_button_clicked)

        self.log_edit: QtWidgets.QTextEdit = QtWidgets.QTextEdit()
        self.log_edit.setReadOnly(True)
        self.log_edit.setMinimumWidth(500)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addLayout(top_hbox)
        vbox.addWidget(self.table)
        vbox.addWidget(self.rollover_button)
        vbox.addWidget(self.log_edit)

        self.setLayout(vbox)

    def refresh_ui(self) -> None:
        """Redraw the table and update the rollover button state."""
        self.refresh_table()
        self.refresh_button()

    def extract_non_digits(self, string) -> str:
        """Return string with every digit removed, e.g. "rb2310.SHFE" -> "rb.SHFE"."""
        pattern = r'\D+'
        result = re.findall(pattern, string)
        return ''.join(result)

    @staticmethod
    def _is_selected(data: RolloverData) -> bool:
        """Return True when the row has a target contract that differs from its current one."""
        return bool(data.new_vt_symbol) and data.new_vt_symbol != data.vt_symbol

    @staticmethod
    def _parse_symbols(text: str) -> List[str]:
        """
        Split user input into vt_symbols.

        Accepts ASCII or full-width commas, strips surrounding blanks and drops
        empty and duplicate entries while keeping the input order.
        """
        symbols: List[str] = []
        for item in text.replace("，", ",").split(","):
            symbol: str = item.strip()
            if symbol and symbol not in symbols:
                symbols.append(symbol)
        return symbols

    def refresh_data(self) -> None:
        """
        Rebuild data_list from the strategy engine.

        Every vt_symbol with at least one strategy gets a row. A target contract
        is pre-filled when an imported symbol has the same product/exchange
        part, i.e. the same string once digits are removed.
        """
        self.data_list: list[RolloverData] = []

        symbol_dict: dict[str, str] = {}
        for new_vt_symbol in self.new_vt_symbols:
            symbol_dict[self.extract_non_digits(new_vt_symbol)] = new_vt_symbol

        for vt_symbol, strategies in self.strategy_engine.symbol_strategy_map.items():
            if not strategies:
                continue

            data = RolloverData()
            data.vt_symbol = vt_symbol

            non_digits_symbol = self.extract_non_digits(vt_symbol)
            if non_digits_symbol in symbol_dict and vt_symbol != symbol_dict[non_digits_symbol]:
                data.new_vt_symbol = symbol_dict[non_digits_symbol]

            self.data_list.append(data)

        self.sort_data()

    def sort_data(self) -> None:
        """Move rows with a selected target contract to the front (stable sort)."""
        self.data_list.sort(key=lambda data: not self._is_selected(data))

    def filter_data(self) -> list:
        """
        Return the rows that will actually be rolled.

        Uses the same rule as the table highlight, so a target equal to the
        current contract is never rolled (that would close and reopen the same
        contract).
        """
        return [data for data in self.data_list if self._is_selected(data)]

    def refresh_table(self) -> None:
        """Rebuild every cell widget of the table from data_list."""
        self.table.setRowCount(len(self.header_cols))
        self.table.setColumnCount(len(self.data_list))

        self.table.setVerticalHeaderLabels(self.header_cols)
        for row_index in range(len(self.header_cols)):
            self.table.setRowHeight(row_index, CELL_HEIGHT)

        for col_index, data in enumerate(self.data_list):
            self.table.setColumnWidth(col_index, CELL_WIDTH)

            highlighted: bool = self._is_selected(data)

            for row_index, row in enumerate(self.header_cols):
                cell_widget = self.table.cellWidget(row_index, col_index)
                if cell_widget:
                    self.table.removeCellWidget(row_index, col_index)
                    cell_widget.deleteLater()

                # Payup / max volume editors are only shown for rows being rolled
                cell: RolloverBaseTableWidget = None
                if row == "移仓合约":
                    cell = RolloverTextTableWidget()
                    cell.set_content(data.vt_symbol)
                elif row == "目标合约":
                    cell = RolloverComboTableWidget()
                    cell.set_setting(self.new_vt_symbols)
                    cell.set_value(data.new_vt_symbol)
                    cell.addTarget(self.combobox_activated)
                elif row == "委托超价" and highlighted:
                    cell = RolloverSpinTableWidget()
                    cell.set_setting(5, 1000)
                    cell.set_value(data.payup)
                    cell.addTarget(self.spinbox_valueChanged)
                elif row == "单比上限" and highlighted:
                    cell = RolloverSpinTableWidget()
                    cell.set_setting(1, 10000)
                    cell.set_value(data.max_volume)
                    cell.addTarget(self.spinbox_valueChanged)

                if cell:
                    self.table.setCellWidget(row_index, col_index, cell)
                    cell.set_highlighted(highlighted)

    def refresh_button(self) -> None:
        """Enable the rollover button only when at least one row will be rolled."""
        self.rollover_button.setEnabled(bool(self.filter_data()))

    def write_log(self, text: str) -> None:
        """Append text to the log view, prefixed with the current local time."""
        now: datetime = datetime.now()
        text: str = now.strftime("%H:%M:%S\t") + text
        self.log_edit.append(text)

    def subscribe(self, vt_symbol: str) -> None:
        """Subscribe market data of vt_symbol; silently ignored if the contract is unknown."""
        contract: Optional[ContractData] = self.main_engine.get_contract(vt_symbol)
        if not contract:
            return

        req: SubscribeRequest = SubscribeRequest(contract.symbol, contract.exchange)
        self.main_engine.subscribe(req, contract.gateway_name)

    def _get_holding(self, vt_symbol: str) -> Optional[PositionHolding]:
        """Return the offset-converter position holding of vt_symbol, or None if unavailable."""
        contract: Optional[ContractData] = self.main_engine.get_contract(vt_symbol)
        if not contract:
            return None

        converter: Optional[OffsetConverter] = self.main_engine.get_converter(contract.gateway_name)
        if not converter:
            return None

        return converter.get_position_holding(vt_symbol)

    def roll_all(self) -> None:
        """
        Roll every selected contract: account positions first, then strategies.

        All preconditions are checked before any order is sent, so a missing
        quote cannot leave the rollover half done:
        every affected strategy must be inited and stopped, the old contract's
        holding must be available (with a quote when a position exists), and
        the target contract must have a quote.
        """
        selected_list: List[RolloverData] = self.filter_data()
        if not selected_list:
            return

        # Check strategy states
        for selected in selected_list:
            # All strategies must be inited (pos data loaded from disk) and not trading
            strategies = self.strategy_engine.symbol_strategy_map[selected.vt_symbol]
            for strategy in strategies:
                if not strategy.inited:
                    self.write_log(f"策略{strategy.strategy_name}尚未初始化,无法执行移仓")
                    return

                if strategy.trading:
                    self.write_log(f"策略{strategy.strategy_name}正在运行中,无法执行移仓")
                    return

        # Subscribe every target contract, then wait once for quotes to arrive
        for selected in selected_list:
            self.subscribe(selected.new_vt_symbol)
        sleep(1)

        # Validate holdings and market data of both legs before sending any order
        for selected in selected_list:
            old_symbol: str = selected.vt_symbol
            new_symbol: str = selected.new_vt_symbol

            holding: Optional[PositionHolding] = self._get_holding(old_symbol)
            if holding is None:
                self.write_log(f"无法获取移仓合约{old_symbol}的持仓数据,无法执行移仓")
                return

            # The old quote is only needed when there is a position to close
            if (holding.long_pos or holding.short_pos) and not self.main_engine.get_tick(old_symbol):
                self.write_log(f"无法获取移仓合约{old_symbol}的盘口数据,请先订阅行情")
                return

            if not self.main_engine.get_tick(new_symbol):
                self.write_log(f"无法获取目标合约{new_symbol}的盘口数据,请先订阅行情")
                return

        # Roll position first
        for selected in selected_list:
            self.roll_position(
                selected.vt_symbol,
                selected.new_vt_symbol,
                selected.payup,
                selected.max_volume
            )

        # Then roll strategy (iterate a copy: removing a strategy mutates the map list)
        for selected in selected_list:
            strategies = self.strategy_engine.symbol_strategy_map[selected.vt_symbol]
            for strategy in copy(strategies):
                self.roll_strategy(strategy, selected.vt_symbol, selected.new_vt_symbol)

        # Disable self to prevent rolling twice
        self.setEnabled(False)

    def roll_position(
        self,
        old_symbol: str,
        new_symbol: str,
        payup: int,
        max_volume: Optional[int] = None
    ) -> None:
        """
        Move the account position of old_symbol onto new_symbol.

        Long volume is closed with a sell and reopened with a buy; short volume
        is closed with a buy and reopened with a sell.

        :param payup: price ticks added beyond the best quote.
        :param max_volume: per-order volume cap; None uses the RolloverData default.
        """
        holding: Optional[PositionHolding] = self._get_holding(old_symbol)
        if holding is None:
            self.write_log(f"无法获取移仓合约{old_symbol}的持仓数据,无法执行移仓")
            return

        # Roll long position
        if holding.long_pos:
            volume: float = holding.long_pos
            self.send_order(old_symbol, Direction.SHORT, Offset.CLOSE, payup, volume, max_volume)
            self.send_order(new_symbol, Direction.LONG, Offset.OPEN, payup, volume, max_volume)

        # Roll short position
        if holding.short_pos:
            volume: float = holding.short_pos
            self.send_order(old_symbol, Direction.LONG, Offset.CLOSE, payup, volume, max_volume)
            self.send_order(new_symbol, Direction.SHORT, Offset.OPEN, payup, volume, max_volume)

    def roll_strategy(self, strategy: StrategyTemplate, vt_symbol: str, new_vt_symbol: str) -> None:
        """
        Re-create strategy with vt_symbol replaced by new_vt_symbol.

        The strategy is removed, added again with the same class, name and
        parameters, inited, and given its previous positions. The position
        recorded for vt_symbol is moved to new_vt_symbol.
        """
        if not strategy.inited:
            self.strategy_engine._init_strategy(strategy.strategy_name)

        # Save data of old strategy
        name: str = strategy.strategy_name
        parameters: dict = strategy.get_parameters()
        vt_symbols = [new_vt_symbol if item == vt_symbol else item for item in strategy.vt_symbols]

        # NOTE(review): assumes pos_data is keyed by vt_symbol (as read by
        # get_pos in vnpy_portfoliostrategy). Copy it (copy keeps a defaultdict's
        # factory) and move the old contract's position onto the new contract.
        pos_data = copy(strategy.pos_data)
        old_pos = pos_data.pop(vt_symbol, 0)
        if old_pos:
            pos_data[new_vt_symbol] = pos_data.get(new_vt_symbol, 0) + old_pos

        # Remove old strategy
        result: bool = self.strategy_engine.remove_strategy(name)
        if result:
            self.portfolioStrategyManager.remove_strategy(name)

        self.write_log(f"移除老策略{name}[{strategy.vt_symbols}]")

        # Add new strategy
        self.strategy_engine.add_strategy(
            strategy.__class__.__name__,
            name,
            vt_symbols,
            parameters
        )
        self.write_log(f"创建策略{name}[{vt_symbols}]")

        # Init new strategy
        # NOTE(review): init_strategy may run asynchronously in the engine; confirm
        # it cannot overwrite pos_data assigned below.
        self.strategy_engine.init_strategy(name)
        self.write_log(f"初始化策略{name}[{vt_symbols}]")

        # Update pos to new strategy
        new_strategy: StrategyTemplate = self.strategy_engine.strategies[name]
        new_strategy.pos_data = pos_data
        new_strategy.sync_data()
        self.write_log(f"更新策略仓位{name}[{vt_symbols}]")

    def send_order(
        self,
        vt_symbol: str,
        direction: Direction,
        offset: Offset,
        payup: int,
        volume: float,
        max_volume: Optional[int] = None,
    ) -> None:
        """
        Send limit orders for volume of vt_symbol, split into chunks of at most max_volume.

        The price is the best opposite quote plus payup price ticks, so buys
        pay above the ask and sells accept below the bid.

        :param max_volume: per-order volume cap; None uses the RolloverData
            default, and values below 1 are raised to 1 so the loop always ends.
        """
        if max_volume is None:
            max_volume = RolloverData.max_volume
        max_volume = max(max_volume, 1)

        contract: Optional[ContractData] = self.main_engine.get_contract(vt_symbol)
        tick: Optional[TickData] = self.main_engine.get_tick(vt_symbol)
        if not contract or not tick:
            self.write_log(f"无法获取合约{vt_symbol}的合约或盘口数据,委托未发出")
            return

        if direction == Direction.LONG:
            price = tick.ask_price_1 + contract.pricetick * payup
        else:
            price = tick.bid_price_1 - contract.pricetick * payup

        while volume > 0:
            order_volume = min(volume, max_volume)

            original_req: OrderRequest = OrderRequest(
                symbol=contract.symbol,
                exchange=contract.exchange,
                direction=direction,
                offset=offset,
                type=OrderType.LIMIT,
                price=price,
                volume=order_volume,
                reference=f"{APP_NAME}_Rollover"
            )

            # The converter may split a close into close-today / close-yesterday requests
            req_list: List[OrderRequest] = self.main_engine.convert_order_request(
                original_req,
                contract.gateway_name,
                False,
                False
            )

            for req in req_list:
                vt_orderid: str = self.main_engine.send_order(req, contract.gateway_name)
                if not vt_orderid:
                    continue

                self.main_engine.update_order_request(req, vt_orderid, contract.gateway_name)

                # Log the volume of this request, not the remaining total
                msg: str = f"发出委托{vt_symbol},{direction.value} {offset.value},{req.volume}@{price}"
                self.write_log(msg)

            volume -= order_volume

    def import_symbols(self, text: str) -> None:
        """
        Import target contracts from comma-separated text and refresh the dialog.

        Shows a prompt instead when the text contains no symbol.
        """
        result = self.check_input(text)
        if result:
            self.show_message_box(result, self.check_new_symbols_message_box_action)
            return

        # Leading "" lets the combo box express "no target contract"
        self.new_vt_symbols = [""] + self._parse_symbols(text)
        self.refresh_data()
        self.refresh_ui()

    def check_input(self, text: str) -> str:
        """
        Validate the target-contract input.

        :return: "" when text contains at least one symbol, otherwise the
            error message to show.
        """
        if text and self._parse_symbols(text):
            return ""
        return "目标合约不能为空!"

    def handle_enter_pressed(self) -> None:
        """Import the symbols currently typed in the input line (not connected to a signal in this file)."""
        text = self.text_edit.text()
        self.import_symbols(text)

    def handle_import_button_clicked(self) -> None:
        """Import the symbols currently typed in the input line."""
        self.import_symbols(self.text_edit.text())

    def handle_download_button_clicked(self) -> None:
        """Clear the input line and import "" (which shows the empty-input prompt)."""
        self.text_edit.setText("")
        self.import_symbols("")

    def handle_rollover_button_clicked(self) -> None:
        """Ask the user to confirm the rollover plan; run roll_all on confirmation."""
        selected_result = self.filter_data()
        if not selected_result:
            return

        message = "请确认以下移仓信息: \n"
        for selected in selected_result:
            row_str = f"移仓合约: {selected.vt_symbol}\t目标合约: {selected.new_vt_symbol}\t委托超价: {selected.payup}\t单比上限: {selected.max_volume}"
            message += row_str + "\n"

        self.show_dialog(message, self.check_rollerover_message_box_action)

    def _sender_index(self) -> Optional[QtCore.QModelIndex]:
        """Return the table index of the cell hosting the signal's sender, or None if not found."""
        editor = self.sender()
        # The editor lives inside a Rollover*TableWidget set as the cell widget
        index = self.table.indexAt(editor.parent().pos())
        if not index.isValid():
            return None
        return index

    def combobox_activated(self) -> None:
        """Store the chosen target contract, then re-sort and redraw the table."""
        index = self._sender_index()
        if index is None:
            return

        combobox: QtWidgets.QComboBox = self.sender()
        data = self.data_list[index.column()]
        row = self.header_cols[index.row()]

        if row == "目标合约":
            data.new_vt_symbol = combobox.currentText()

        self.sort_data()
        self.refresh_ui()

    def spinbox_valueChanged(self) -> None:
        """
        Store an edited payup / per-order volume cap.

        These values affect neither the row order nor the highlight, so the
        table is not rebuilt here; a rebuild would destroy the spin box being
        edited after every keystroke.
        """
        index = self._sender_index()
        if index is None:
            return

        spinbox: QtWidgets.QSpinBox = self.sender()
        data = self.data_list[index.column()]
        row = self.header_cols[index.row()]

        if row == "委托超价":
            data.payup = spinbox.value()
        elif row == "单比上限":
            data.max_volume = spinbox.value()

    def show_message_box(self, message, response_callback) -> None:
        """
        Show a modal confirm/cancel message box.

        Calls response_callback(True) when "确认" is clicked, otherwise
        response_callback(False).
        """
        msg_box = QtWidgets.QMessageBox()
        msg_box.setWindowTitle("提示")
        msg_box.setText(message)

        msg_box.setIcon(QtWidgets.QMessageBox.Information)

        accept_button = msg_box.addButton("确认", QtWidgets.QMessageBox.AcceptRole)
        msg_box.addButton("取消", QtWidgets.QMessageBox.RejectRole)

        # With custom buttons exec() returns an opaque value; ask for the clicked button
        msg_box.exec()
        response_callback(msg_box.clickedButton() is accept_button)

    def show_dialog(self, message, response_callback) -> None:
        """
        Show a modal dialog with message and OK/cancel buttons.

        Calls response_callback(True) when accepted, otherwise response_callback(False).
        """
        dialog = QtWidgets.QDialog()
        dialog.setWindowTitle("提示")
        dialog.setMinimumSize(200, 100)

        layout = QtWidgets.QVBoxLayout(dialog)
        label = QtWidgets.QLabel(message)
        layout.addWidget(label)

        ok_button = QtWidgets.QPushButton("确定")
        ok_button.clicked.connect(dialog.accept)
        cancel_button = QtWidgets.QPushButton("取消")
        cancel_button.clicked.connect(dialog.reject)
        layout.addWidget(ok_button)
        layout.addWidget(cancel_button)

        result = dialog.exec()
        response_callback(result == QtWidgets.QDialog.Accepted)

    def check_new_symbols_message_box_action(self, result) -> None:
        """Callback of the import-error prompt; nothing to do for either answer."""
        pass

    def check_rollerover_message_box_action(self, result) -> None:
        """Callback of the rollover confirmation; runs roll_all when confirmed."""
        if result:
            self.roll_all()


class RolloverDialog(QtWidgets.QDialog):
    """
    Placeholder dialog intended to confirm the configured rollover settings.

    Not referenced anywhere in this file; init_ui is still empty.
    """

    def __init__(self) -> None:
        QtWidgets.QDialog.__init__(self)

    def init_ui(self) -> None:
        """Build the dialog layout (not implemented yet)."""
        pass


class RolloverBaseTableWidget(QtWidgets.QWidget):
    """Common base of the widgets placed into cells of the rollover table."""

    def __init__(self) -> None:
        """Create the widget and let it expand to fill its table cell."""
        super().__init__()

        expanding = QtWidgets.QSizePolicy.Expanding
        self.setSizePolicy(expanding, expanding)

    def set_highlighted(self, highlighted: bool) -> None:
        """Paint a teal background when highlighted, otherwise clear the style sheet."""
        style: str = "background-color: rgb(25,100,107);" if highlighted else ""
        self.setStyleSheet(style)


class RolloverTextTableWidget(RolloverBaseTableWidget):
    """Read-only text cell of the rollover table."""

    def __init__(self) -> None:
        """Create the cell and its label."""
        super().__init__()

        self.init_ui()

    def init_ui(self) -> None:
        """Place an expanding label with no margins inside the cell."""
        self.label = QtWidgets.QLabel()
        self.label.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)

        hbox = QtWidgets.QHBoxLayout()
        hbox.addWidget(self.label)
        hbox.setContentsMargins(0, 0, 0, 0)

        self.setLayout(hbox)

    def set_content(self, content: object) -> None:
        """
        Show content in the label.

        QLabel.setText only accepts str, so non-str values are converted with str().
        """
        self.label.setText(str(content))


class RolloverComboTableWidget(RolloverBaseTableWidget):
    """Drop-down cell of the rollover table, used to choose the target contract."""

    def __init__(self) -> None:
        """Create the cell and its combo box."""
        super().__init__()

        self.init_ui()

    def init_ui(self) -> None:
        """Place an expanding combo box with no margins inside the cell."""
        combo: QtWidgets.QComboBox = QtWidgets.QComboBox()
        combo.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
        self.vt_symbol_combo: QtWidgets.QComboBox = combo

        layout = QtWidgets.QHBoxLayout()
        layout.addWidget(combo)
        layout.setContentsMargins(0, 0, 0, 0)

        self.setLayout(layout)

    def set_setting(self, content: object) -> None:
        """Append the entries of content (a sequence of str) as combo items."""
        self.vt_symbol_combo.addItems(content)

    def set_value(self, content: object) -> None:
        """Select the item whose text equals content."""
        self.vt_symbol_combo.setCurrentText(content)

    def addTarget(self, target: Any) -> None:
        """Connect target to the combo box's activated signal."""
        self.vt_symbol_combo.activated.connect(target)


class RolloverSpinTableWidget(RolloverBaseTableWidget):
    """Integer spin-box cell of the rollover table (payup / per-order volume cap)."""

    def __init__(self) -> None:
        """Create the cell and its spin box."""
        super().__init__()

        self.init_ui()

    def init_ui(self) -> None:
        """Place an expanding spin box with no margins inside the cell."""
        # Attribute name "span" kept for compatibility with existing callers
        spin: QtWidgets.QSpinBox = QtWidgets.QSpinBox()
        spin.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
        self.span: QtWidgets.QSpinBox = spin

        layout = QtWidgets.QHBoxLayout()
        layout.addWidget(spin)
        layout.setContentsMargins(0, 0, 0, 0)

        self.setLayout(layout)

    def set_setting(self, minimum: int, maximum: int) -> None:
        """Set the allowed value range (minimum first, then maximum)."""
        spin = self.span
        spin.setMinimum(minimum)
        spin.setMaximum(maximum)

    def set_value(self, content: int) -> None:
        """Set the displayed value."""
        self.span.setValue(content)

    def addTarget(self, target: Any) -> None:
        """Connect target to the spin box's valueChanged signal."""
        self.span.valueChanged.connect(target)
  • 代码位置：vnpy_portfoliostrategy/vnpy_portfoliostrategy/ui/widget.py（在 PortfolioStrategyManager 中新增“移仓助手”按钮）
class PortfolioStrategyManager(QtWidgets.QWidget):
    """
    Portfolio strategy manager widget (excerpt).

    Only the additions needed for the rollover assistant are shown; the
    "..." statements stand for the existing code of widget.py.
    """

    ...

    def init_ui(self) -> None:
        """Build the manager UI (excerpt: adds the rollover assistant button)."""
        ...
        init_button: QtWidgets.QPushButton = QtWidgets.QPushButton("全部初始化")
        init_button.clicked.connect(self.strategy_engine.init_all_strategies)

        # New button opening the rollover assistant; it still has to be added to
        # the button layout in the elided part of init_ui.
        roll_button: QtWidgets.QPushButton = QtWidgets.QPushButton("移仓助手")
        roll_button.clicked.connect(self.roll)
        ...

    def roll(self) -> None:
        """Open the rollover assistant as a modal dialog."""
        # Local import: rollover.py only imports this module under TYPE_CHECKING,
        # so importing it here cannot create an import cycle.
        from .rollover import RolloverTool

        dialog: RolloverTool = RolloverTool(self)
        dialog.exec_()
Member
avatar
加入于:
帖子: 1472
声望: 105

感谢分享!

© 2015-2022 上海韦纳软件科技有限公司
备案服务号:沪ICP备18006526号

沪公网安备 31011502017034号

【用户协议】
【隐私政策】
【免责条款】