VeighNa量化社区
你的开源社区量化交易平台
Member
avatar
加入于:
帖子: 11
声望: 0

引言

首先感谢官方开源的功能。不过 PortfolioStrategy(组合策略)模块非常需要移仓换月功能。

功能描述

  • 合约自动匹配
  • 合约移仓换月

流程图

description

原型图

description

代码

  • 代码位置:vnpy_portfoliostrategy/vnpy_portfoliostrategy/ui/rollover.py

```python
from dataclasses import dataclass
from datetime import datetime
import re
from time import sleep
from typing import Any, List, Optional, TYPE_CHECKING, Tuple
from copy import copy

import rqdatac

from vnpy.trader.datafeed import BaseDatafeed, get_datafeed
from vnpy.trader.engine import MainEngine
from vnpy.trader.constant import OrderType
from vnpy.trader.object import ContractData, OrderRequest, SubscribeRequest, TickData
from vnpy.trader.object import Direction, Offset
from vnpy.trader.ui import QtWidgets, QtCore
from vnpy.trader.converter import OffsetConverter, PositionHolding

from ..base import APP_NAME
from ..engine import StrategyEngine
from ..template import StrategyTemplate

if TYPE_CHECKING:
    # Needed only for type annotations. widget.py is the module that opens
    # this dialog, so it is kept out of the runtime import graph.
    from .widget import PortfolioStrategyManager

# Pixel size of each cell of the rollover table (row height / column width).
CELL_HEIGHT: int = 50
CELL_WIDTH: int = 120

@dataclass
class RolloverData:
    """One column of the rollover table: a traded contract and its roll settings."""

    # Contract currently used by the strategies (the one rolled out of)
    vt_symbol: str = ""
    # Target contract to roll into; "" means no target chosen yet
    new_vt_symbol: str = ""
    # Number of price ticks added beyond the best quote when ordering
    payup: int = 5
    # Maximum volume per single order; larger volumes are split
    max_volume: int = 100

class RolloverTool(QtWidgets.QDialog):
    """Rollover (移仓换月) assistant for the portfolio strategy app.

    The table shows one column per vt_symbol currently used by at least one
    strategy. For each column the user picks a target contract. Rolling then:

    1. closes the position held in the old contract and opens the same volume
       in the target contract;
    2. re-creates every affected strategy with its vt_symbols and pos_data
       moved onto the target contracts.
    """

    header_cols: list
    new_vt_symbols: list
    data_list: list

    def __init__(self, portfolioStrategyManager: "PortfolioStrategyManager") -> None:
        """Bind to the manager widget's engines and build the initial table.

        :param portfolioStrategyManager: widget owning the strategy engine and
            main engine; its strategy monitors are removed when rolling.
        """
        super().__init__()

        self.portfolioStrategyManager: "PortfolioStrategyManager" = portfolioStrategyManager

        self.strategy_engine: StrategyEngine = portfolioStrategyManager.strategy_engine
        self.main_engine: MainEngine = portfolioStrategyManager.main_engine

        # Row labels: the table is transposed, one column per RolloverData.
        self.header_cols = ["移仓合约", "目标合约", "委托超价", "单比上限"]
        # Options of the target-contract combo boxes; index 0 is "" (= none).
        self.new_vt_symbols = []
        self.data_list = []

        self.init_ui()
        self.refresh_data()
        self.refresh_ui()

    def init_ui(self) -> None:
        """Create the widgets: symbol input row, rollover table, button and log."""
        self.setWindowTitle("移仓助手")

        self.text_edit = QtWidgets.QLineEdit()
        self.text_edit.setPlaceholderText("请输入目标合约,多个合约用,号分隔")
        # Pressing Enter in the input behaves like clicking the import button.
        self.text_edit.returnPressed.connect(self.handle_enter_pressed)

        import_button = QtWidgets.QPushButton("导入")
        import_button.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed)
        import_button.clicked.connect(self.handle_import_button_clicked)

        top_hbox = QtWidgets.QHBoxLayout()
        top_hbox.addWidget(self.text_edit, 2)
        top_hbox.addWidget(import_button)
        top_hbox.setContentsMargins(0, 0, 0, 0)
        top_hbox.setSpacing(15)

        self.table = QtWidgets.QTableWidget()
        self.table.horizontalHeader().setVisible(False)
        self.table.verticalHeader().setSectionResizeMode(QtWidgets.QHeaderView.Fixed)
        self.table.setEditTriggers(QtWidgets.QAbstractItemView.NoEditTriggers)
        self.table.setSelectionMode(QtWidgets.QAbstractItemView.NoSelection)
        self.table.verticalScrollBar().setEnabled(False)
        self.table.verticalHeader().setDefaultSectionSize(CELL_HEIGHT)
        self.table.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed)
        # Exactly tall enough for all rows plus the horizontal scroll bar.
        self.table.setFixedHeight(CELL_HEIGHT * len(self.header_cols) + self.table.horizontalScrollBar().height())
        self.table.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
        self.table.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)

        self.rollover_button = QtWidgets.QPushButton("移仓")
        self.rollover_button.clicked.connect(self.handle_rollover_button_clicked)

        self.log_edit: QtWidgets.QTextEdit = QtWidgets.QTextEdit()
        self.log_edit.setReadOnly(True)
        self.log_edit.setMinimumWidth(500)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addLayout(top_hbox)
        vbox.addWidget(self.table)
        vbox.addWidget(self.rollover_button)
        vbox.addWidget(self.log_edit)

        self.setLayout(vbox)

    def refresh_ui(self) -> None:
        """Redraw the table and update the rollover button state."""
        self.refresh_table()
        self.refresh_button()

    def extract_non_digits(self, string: str) -> str:
        """Return ``string`` with every digit removed (e.g. rb2401.SHFE -> rb.SHFE).

        Used as a product key to match an old contract with its target.
        """
        return re.sub(r"\d+", "", string)

    @staticmethod
    def _parse_symbols(text: str) -> List[str]:
        """Split input on ASCII or full-width commas into unique, stripped vt_symbols."""
        symbols: List[str] = []
        for part in re.split(r"[,\uff0c]", text or ""):
            symbol = part.strip()
            if symbol and symbol not in symbols:
                symbols.append(symbol)
        return symbols

    @staticmethod
    def _needs_roll(data: RolloverData) -> bool:
        """Return True when a target different from the current contract is chosen."""
        return bool(data.new_vt_symbol) and data.new_vt_symbol != data.vt_symbol

    def refresh_data(self) -> None:
        """Rebuild data_list from the engine's symbol -> strategies map.

        Each vt_symbol with at least one strategy attached becomes a column.
        Its target is pre-filled with the imported symbol sharing the same
        digit-free product key, unless that is the very same contract.
        """
        self.data_list: list[RolloverData] = []

        # Product key -> imported target symbol (later entries win).
        symbol_dict: dict[str, str] = {
            self.extract_non_digits(new_vt_symbol): new_vt_symbol
            for new_vt_symbol in self.new_vt_symbols
            if new_vt_symbol
        }

        for vt_symbol, strategies in self.strategy_engine.symbol_strategy_map.items():
            if not strategies:
                continue

            data = RolloverData(vt_symbol=vt_symbol)
            target: str = symbol_dict.get(self.extract_non_digits(vt_symbol), "")
            if target != vt_symbol:
                data.new_vt_symbol = target
            self.data_list.append(data)

        self.sort_data()

    def sort_data(self) -> None:
        """Move columns that will be rolled to the front (stable sort)."""
        self.data_list.sort(key=lambda data: not self._needs_roll(data))

    def filter_data(self) -> list:
        """Return the columns that will actually be rolled.

        A target equal to the current contract is excluded, matching the
        highlighting and sorting rules.
        """
        return [data for data in self.data_list if self._needs_roll(data)]

    def refresh_table(self) -> None:
        """Recreate every cell widget from data_list.

        The payup and max-volume editors are only shown for columns that will
        be rolled; those columns are also highlighted.
        """
        self.table.setRowCount(len(self.header_cols))
        self.table.setColumnCount(len(self.data_list))
        self.table.setVerticalHeaderLabels(self.header_cols)

        for row_index in range(len(self.header_cols)):
            self.table.setRowHeight(row_index, CELL_HEIGHT)

        for col_index, data in enumerate(self.data_list):
            self.table.setColumnWidth(col_index, CELL_WIDTH)
            highlighted: bool = self._needs_roll(data)

            for row_index, row in enumerate(self.header_cols):
                old_widget = self.table.cellWidget(row_index, col_index)
                if old_widget:
                    self.table.removeCellWidget(row_index, col_index)
                    old_widget.deleteLater()

                cell: Optional[RolloverBaseTableWidget] = None
                if row == "移仓合约":
                    cell = RolloverTextTableWidget()
                    cell.set_content(data.vt_symbol)
                elif row == "目标合约":
                    cell = RolloverComboTableWidget()
                    cell.set_setting(self.new_vt_symbols)
                    cell.set_value(data.new_vt_symbol)
                    cell.addTarget(self.combobox_activated)
                elif row == "委托超价" and highlighted:
                    cell = RolloverSpinTableWidget()
                    cell.set_setting(5, 1000)
                    cell.set_value(data.payup)
                    cell.addTarget(self.spinbox_valueChanged)
                elif row == "单比上限" and highlighted:
                    cell = RolloverSpinTableWidget()
                    cell.set_setting(1, 10000)
                    cell.set_value(data.max_volume)
                    cell.addTarget(self.spinbox_valueChanged)

                if cell:
                    self.table.setCellWidget(row_index, col_index, cell)
                    cell.set_highlighted(highlighted)

    def refresh_button(self) -> None:
        """Enable the rollover button only when at least one column will be rolled."""
        self.rollover_button.setEnabled(bool(self.filter_data()))

    def write_log(self, text: str) -> None:
        """Append ``text`` to the log box, prefixed with the current time."""
        timestamp: str = datetime.now().strftime("%H:%M:%S\t")
        self.log_edit.append(timestamp + text)

    def subscribe(self, vt_symbol: str) -> None:
        """Subscribe market data for ``vt_symbol``; silently skip unknown contracts."""
        contract: Optional[ContractData] = self.main_engine.get_contract(vt_symbol)
        if not contract:
            return

        req: SubscribeRequest = SubscribeRequest(contract.symbol, contract.exchange)
        self.main_engine.subscribe(req, contract.gateway_name)

    def roll_all(self) -> None:
        """Roll every selected contract, then rebuild the strategies using them.

        All preconditions (strategies inited and stopped, quotes available for
        both legs of every roll) are verified before the first order is sent.
        A failed check therefore never leaves some positions rolled while the
        strategies still reference the old contracts.
        """
        selected_list: List[RolloverData] = self.filter_data()
        if not selected_list:
            return

        # Affected strategies, de-duplicated in first-seen order. A strategy
        # trading several of the rolled contracts is rebuilt only once.
        strategies: List[StrategyTemplate] = list(dict.fromkeys(
            strategy
            for selected in selected_list
            for strategy in self.strategy_engine.symbol_strategy_map[selected.vt_symbol]
        ))

        if not self._check_strategies(strategies):
            return

        if not self._check_ticks(selected_list):
            return

        # Roll positions first ...
        for selected in selected_list:
            self.roll_position(
                selected.vt_symbol,
                selected.new_vt_symbol,
                selected.payup,
                selected.max_volume
            )

        # ... then rebuild the strategies on the new contracts.
        symbol_pairs: List[Tuple[str, str]] = [
            (selected.vt_symbol, selected.new_vt_symbol) for selected in selected_list
        ]
        for strategy in strategies:
            self.roll_strategy(strategy, symbol_pairs)

        # Rolling is one-shot: disable the dialog to prevent a second run.
        self.setEnabled(False)

    def _check_strategies(self, strategies: List[StrategyTemplate]) -> bool:
        """Return True if every strategy is inited and stopped; log the first offender."""
        for strategy in strategies:
            if not strategy.inited:
                self.write_log(f"策略{strategy.strategy_name}尚未初始化,无法执行移仓")
                return False

            if strategy.trading:
                self.write_log(f"策略{strategy.strategy_name}正在运行中,无法执行移仓")
                return False
        return True

    def _check_ticks(self, selected_list: List[RolloverData]) -> bool:
        """Subscribe both legs of every roll and check a tick is available for each."""
        for selected in selected_list:
            self.subscribe(selected.vt_symbol)
            self.subscribe(selected.new_vt_symbol)

        # Give the gateway a moment to push first ticks (blocks the GUI thread).
        sleep(1)

        for selected in selected_list:
            if not self.main_engine.get_tick(selected.new_vt_symbol):
                self.write_log(f"无法获取目标合约{selected.new_vt_symbol}的盘口数据,请先订阅行情")
                return False

            if not self.main_engine.get_tick(selected.vt_symbol):
                self.write_log(f"无法获取移仓合约{selected.vt_symbol}的盘口数据,请先订阅行情")
                return False
        return True

    def roll_position(self, old_symbol: str, new_symbol: str, payup: int, max_volume: int) -> None:
        """Move the position held in ``old_symbol`` over to ``new_symbol``.

        Volumes come from the gateway's OffsetConverter holding for the old
        contract. A long holding is closed with a sell and re-opened with a
        buy; a short holding mirrors that.
        """
        contract: Optional[ContractData] = self.main_engine.get_contract(old_symbol)
        if not contract:
            self.write_log(f"找不到合约{old_symbol},无法移仓")
            return

        converter: OffsetConverter = self.main_engine.get_converter(contract.gateway_name)
        holding: PositionHolding = converter.get_position_holding(old_symbol)

        # (held volume, direction closing the old leg, direction opening the new leg)
        legs = (
            (holding.long_pos, Direction.SHORT, Direction.LONG),
            (holding.short_pos, Direction.LONG, Direction.SHORT),
        )
        for volume, close_direction, open_direction in legs:
            if not volume:
                continue

            self.send_order(old_symbol, close_direction, Offset.CLOSE, payup, volume, max_volume)
            self.send_order(new_symbol, open_direction, Offset.OPEN, payup, volume, max_volume)

    def roll_strategy(self, strategy: StrategyTemplate, symbol_pairs: List[Tuple[str, str]]) -> None:
        """Re-create ``strategy`` on the rolled contracts, carrying positions over.

        The strategy is removed and added again under the same class, name and
        parameters. Every vt_symbol listed in ``symbol_pairs`` (old, new) is
        replaced by its target. pos_data is re-keyed the same way, so the
        position recorded for an old contract is attributed to its target.
        """
        if not strategy.inited:
            self.strategy_engine._init_strategy(strategy.strategy_name)

        # Snapshot everything needed from the old instance before removing it.
        name: str = strategy.strategy_name
        old_vt_symbols: List[str] = list(strategy.vt_symbols)
        pos_data: dict = dict(strategy.pos_data)
        parameters: dict = strategy.get_parameters()

        symbol_map: dict = dict(symbol_pairs)
        new_vt_symbols: List[str] = [symbol_map.get(vt_symbol, vt_symbol) for vt_symbol in old_vt_symbols]

        # Remove old strategy; re-adding under the same name would fail otherwise.
        if not self.strategy_engine.remove_strategy(name):
            self.write_log(f"移除老策略{name}失败,无法更新策略")
            return
        self.portfolioStrategyManager.remove_strategy(name)
        self.write_log(f"移除老策略{name}[{old_vt_symbols}]")

        # Add new strategy
        self.strategy_engine.add_strategy(
            strategy.__class__.__name__,
            name,
            new_vt_symbols,
            parameters
        )
        new_strategy: Optional[StrategyTemplate] = self.strategy_engine.strategies.get(name)
        if not new_strategy:
            self.write_log(f"创建策略{name}失败,无法更新策略仓位")
            return
        self.write_log(f"创建策略{name}[{new_vt_symbols}]")

        # Init synchronously (as above for the old strategy). In
        # vnpy_portfoliostrategy the public init_strategy() submits
        # _init_strategy to a thread pool. State restored there could then
        # overwrite the pos_data assigned below.
        self.strategy_engine._init_strategy(name)
        self.write_log(f"初始化策略{name}[{new_vt_symbols}]")

        # Re-key positions from the old contracts to their targets.
        new_strategy.pos_data.clear()
        for vt_symbol, pos in pos_data.items():
            target: str = symbol_map.get(vt_symbol, vt_symbol)
            new_strategy.pos_data[target] = new_strategy.pos_data.get(target, 0) + pos
        new_strategy.sync_data()
        self.write_log(f"更新策略仓位{name}[{new_vt_symbols}]")

    def send_order(
        self,
        vt_symbol: str,
        direction: Direction,
        offset: Offset,
        payup: int,
        volume: float,
        max_volume: int
    ) -> None:
        """Send limit orders for ``volume``, split into chunks of at most ``max_volume``.

        The price crosses the book by ``payup`` price ticks (ask + payup ticks
        for buys, bid - payup ticks for sells). It is clamped to the tick's
        limit prices when those are non-zero. Each chunk goes through the main
        engine's offset conversion, which may split it further.
        """
        contract: Optional[ContractData] = self.main_engine.get_contract(vt_symbol)
        tick: Optional[TickData] = self.main_engine.get_tick(vt_symbol)
        if not contract or not tick:
            self.write_log(f"无法获取合约{vt_symbol}的合约或盘口数据,委托未发出")
            return

        if direction == Direction.LONG:
            price = tick.ask_price_1 + contract.pricetick * payup
            if tick.limit_up:
                price = min(price, tick.limit_up)
        else:
            price = tick.bid_price_1 - contract.pricetick * payup
            if tick.limit_down:
                price = max(price, tick.limit_down)

        # A non-positive chunk size would never reduce the remaining volume.
        if max_volume <= 0:
            max_volume = volume

        while volume > 0:
            order_volume = min(volume, max_volume)

            original_req: OrderRequest = OrderRequest(
                symbol=contract.symbol,
                exchange=contract.exchange,
                direction=direction,
                offset=offset,
                type=OrderType.LIMIT,
                price=price,
                volume=order_volume,
                reference=f"{APP_NAME}_Rollover"
            )

            req_list: List[OrderRequest] = self.main_engine.convert_order_request(
                original_req,
                contract.gateway_name,
                False,
                False
            )

            for req in req_list:
                vt_orderid: str = self.main_engine.send_order(req, contract.gateway_name)
                if not vt_orderid:
                    continue

                self.main_engine.update_order_request(req, vt_orderid, contract.gateway_name)

                # Log the converted request itself (offset/volume may differ from the chunk).
                msg: str = f"发出委托{vt_symbol},{direction.value} {req.offset.value},{req.volume}@{req.price}"
                self.write_log(msg)

            volume -= order_volume

    def import_symbols(self, text: str) -> None:
        """Parse ``text`` into target contracts and refresh the table, or warn if invalid."""
        result = self.check_input(text)
        if result:
            self.show_message_box(result, self.check_new_symbols_message_box_action)
            return

        # Leading "" lets a combo box express "no target".
        self.new_vt_symbols = [""] + self._parse_symbols(text)
        self.refresh_data()
        self.refresh_ui()

    def check_input(self, text: str) -> str:
        """Validate the target-contract input.

        :return: "" when at least one symbol can be parsed from ``text``,
            otherwise an error message to show to the user.
        """
        if self._parse_symbols(text):
            return ""
        return "目标合约不能为空!"

    def handle_enter_pressed(self) -> None:
        """Import the symbols typed in the input when Enter is pressed."""
        self.import_symbols(self.text_edit.text())

    def handle_import_button_clicked(self) -> None:
        """Import the symbols typed in the input."""
        self.import_symbols(self.text_edit.text())

    def handle_download_button_clicked(self) -> None:
        """Handler for a "网络获取" button (not currently shown): clears the input and imports ""."""
        self.text_edit.setText("")
        self.import_symbols("")

    def handle_rollover_button_clicked(self) -> None:
        """Ask for confirmation of all selected rolls, then run them on accept."""
        message = "请确认以下移仓信息: \n"
        for selected in self.filter_data():
            row_str = f"移仓合约: {selected.vt_symbol}\t目标合约: {selected.new_vt_symbol}\t委托超价: {selected.payup}\t单比上限: {selected.max_volume}"
            message += row_str + "\n"

        self.show_dialog(message, self.check_rollerover_message_box_action)

    def combobox_activated(self) -> None:
        """Store the target contract chosen in a combo box and redraw the table."""
        combobox: QtWidgets.QComboBox = self.sender()
        index = self.table.indexAt(combobox.parent().pos())
        if not index.isValid():
            return

        data = self.data_list[index.column()]
        if self.header_cols[index.row()] == "目标合约":
            data.new_vt_symbol = combobox.currentText()

        # The target decides highlighting, ordering and the button state.
        self.sort_data()
        self.refresh_ui()

    def spinbox_valueChanged(self) -> None:
        """Store a payup / max-volume value edited in a spin box.

        The table is deliberately not rebuilt. These values affect neither
        ordering nor highlighting, and rebuilding would destroy the spin box
        while the user is still typing into it.
        """
        spinbox: QtWidgets.QSpinBox = self.sender()
        index = self.table.indexAt(spinbox.parent().pos())
        if not index.isValid():
            return

        data = self.data_list[index.column()]
        row = self.header_cols[index.row()]

        if row == "委托超价":
            data.payup = spinbox.value()
        elif row == "单比上限":
            data.max_volume = spinbox.value()

    def show_message_box(self, message, response_callback) -> None:
        """Show an information box with 确认/取消 and pass the choice to ``response_callback``.

        With custom buttons, QMessageBox.exec() returns an opaque value. The
        clicked button therefore decides the answer (True only for 确认).
        """
        msg_box = QtWidgets.QMessageBox(self)
        msg_box.setWindowTitle("提示")
        msg_box.setText(message)
        msg_box.setIcon(QtWidgets.QMessageBox.Information)

        confirm_button = msg_box.addButton("确认", QtWidgets.QMessageBox.AcceptRole)
        msg_box.addButton("取消", QtWidgets.QMessageBox.RejectRole)

        msg_box.exec()
        response_callback(msg_box.clickedButton() == confirm_button)

    def show_dialog(self, message, response_callback) -> None:
        """Show ``message`` in a modal 确定/取消 dialog; callback gets True if accepted."""
        dialog = QtWidgets.QDialog(self)
        dialog.setWindowTitle("提示")
        dialog.setMinimumSize(200, 100)

        layout = QtWidgets.QVBoxLayout(dialog)
        layout.addWidget(QtWidgets.QLabel(message))

        ok_button = QtWidgets.QPushButton("确定")
        ok_button.clicked.connect(dialog.accept)
        cancel_button = QtWidgets.QPushButton("取消")
        cancel_button.clicked.connect(dialog.reject)
        layout.addWidget(ok_button)
        layout.addWidget(cancel_button)

        response_callback(dialog.exec() == QtWidgets.QDialog.Accepted)

    def check_new_symbols_message_box_action(self, result) -> None:
        """Callback for the invalid-input box; nothing to do for either answer."""
        return None

    def check_rollerover_message_box_action(self, result) -> None:
        """Callback for the rollover confirmation: roll everything when confirmed."""
        if result:
            self.roll_all()


class RolloverDialog(QtWidgets.QDialog):
    """Placeholder dialog meant to confirm the configured rollover settings.

    No UI is built yet, and nothing in this file instantiates the class.
    RolloverTool.show_dialog performs the actual confirmation.
    """

    def __init__(self) -> None:
        """Construct the bare QDialog (init_ui() is not invoked here)."""
        super().__init__()

    def init_ui(self) -> None:
        """Build the dialog's widgets -- not implemented yet."""
        return None


class RolloverBaseTableWidget(QtWidgets.QWidget):
    """Base class for the widgets placed inside cells of the rollover table."""

    def __init__(self) -> None:
        """Create the widget and let it expand to fill its table cell."""
        # Must be __init__ (the pasted "init" was never called by Qt/Python).
        super().__init__()

        self.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)

    def set_highlighted(self, highlighted: bool) -> None:
        """Apply (True) or clear (False) the highlight background colour."""
        if highlighted:
            self.setStyleSheet("background-color: rgb(25,100,107);")
        else:
            self.setStyleSheet("")


class RolloverTextTableWidget(RolloverBaseTableWidget):
    """Read-only text cell of the rollover table."""

    def __init__(self) -> None:
        """Create the cell and its label.

        This must be __init__: with the pasted "init" name, init_ui() never ran
        and set_content() failed because self.label did not exist.
        """
        super().__init__()

        self.init_ui()

    def init_ui(self) -> None:
        """Place an expanding QLabel that fills the cell without margins."""
        self.label = QtWidgets.QLabel()
        self.label.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)

        hbox = QtWidgets.QHBoxLayout()
        hbox.addWidget(self.label)
        hbox.setContentsMargins(0, 0, 0, 0)

        self.setLayout(hbox)

    def set_content(self, content: object) -> None:
        """Display ``content``; non-string values are converted with str()."""
        self.label.setText(str(content))


class RolloverComboTableWidget(RolloverBaseTableWidget):
    """Drop-down cell of the rollover table (used to pick the target contract)."""

    def __init__(self) -> None:
        """Create the cell and its combo box (must be __init__, not "init")."""
        super().__init__()

        self.init_ui()

    def init_ui(self) -> None:
        """Place an expanding QComboBox that fills the cell without margins."""
        self.vt_symbol_combo: QtWidgets.QComboBox = QtWidgets.QComboBox()
        self.vt_symbol_combo.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)

        hbox = QtWidgets.QHBoxLayout()
        hbox.addWidget(self.vt_symbol_combo)
        hbox.setContentsMargins(0, 0, 0, 0)

        self.setLayout(hbox)

    def set_setting(self, content: object) -> None:
        """Append the given strings as selectable items."""
        self.vt_symbol_combo.addItems(content)

    def set_value(self, content: object) -> None:
        """Select the item whose text equals ``content``."""
        self.vt_symbol_combo.setCurrentText(content)

    def addTarget(self, target: Any) -> None:
        """Connect ``target`` to the combo box's ``activated`` signal (user selection)."""
        self.vt_symbol_combo.activated.connect(target)


class RolloverSpinTableWidget(RolloverBaseTableWidget):
    """Integer spin-box cell of the rollover table (payup / max volume)."""

    def __init__(self) -> None:
        """Create the cell and its spin box (must be __init__, not "init")."""
        super().__init__()

        self.init_ui()

    def init_ui(self) -> None:
        """Place an expanding QSpinBox that fills the cell without margins."""
        # Attribute name "span" kept as-is; it holds the spin box.
        self.span: QtWidgets.QSpinBox = QtWidgets.QSpinBox()
        self.span.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)

        hbox = QtWidgets.QHBoxLayout()
        hbox.addWidget(self.span)
        hbox.setContentsMargins(0, 0, 0, 0)

        self.setLayout(hbox)

    def set_setting(self, minimum: int, maximum: int) -> None:
        """Set the allowed value range."""
        self.span.setMinimum(minimum)
        self.span.setMaximum(maximum)

    def set_value(self, content: int) -> None:
        """Set the current value."""
        self.span.setValue(content)

    def addTarget(self, target: Any) -> None:
        """Connect ``target`` to the spin box's ``valueChanged`` signal."""
        self.span.valueChanged.connect(target)

```

  • 代码位置:vnpy_portfoliostrategy/vnpy_portfoliostrategy/ui/widget.py(节选,`...` 表示省略的原有代码)

```python

class PortfolioStrategyManager(QtWidgets.QWidget):
    ...

    def init_ui(self) -> None:
        ...
        init_button: QtWidgets.QPushButton = QtWidgets.QPushButton("全部初始化")
        init_button.clicked.connect(self.strategy_engine.init_all_strategies)

        # Opens the rollover assistant (see roll() below).
        roll_button: QtWidgets.QPushButton = QtWidgets.QPushButton("移仓助手")
        roll_button.clicked.connect(self.roll)
        # NOTE: creating the button is not enough for it to appear. It must
        # also be added to the button-row layout built in the elided part of
        # init_ui, e.g. ``hbox1.addWidget(roll_button)`` next to init_button.
        ...

    def roll(self) -> None:
        """Open the rollover assistant as a modal dialog bound to this manager."""
        # Local import keeps the patch self-contained. rollover.py imports this
        # module only under TYPE_CHECKING, so no runtime import cycle arises.
        from .rollover import RolloverTool

        dialog: RolloverTool = RolloverTool(self)
        dialog.exec()
```

Member
avatar
加入于:
帖子: 1631
声望: 118

感谢分享!

Member
avatar
加入于:
帖子: 3
声望: 0

感谢

Member
avatar
加入于:
帖子: 39
声望: 0

感谢分享,但是我添加之后没有显示按钮,不知道是什么原因

Member
avatar
加入于:
帖子: 11
声望: 0

wrote:

感谢分享,但是我添加之后没有显示按钮,不知道是什么原因

请用 pip install 重新安装一下你修改后的 vnpy_portfoliostrategy

© 2015-2022 上海韦纳软件科技有限公司
备案服务号:沪ICP备18006526号

沪公网安备 31011502017034号

【用户协议】
【隐私政策】
【免责条款】