上一主题失败了
上一主题“为K线图表添砖加瓦——让CTA策略的运行看得见”可以说是失败了。原因已经找到:掉进了类变量和实例变量的坑里。具体过程请参见这个主题:
https://www.vnpy.com/forum/topic/3860-wei-kxian-tu-biao-tian-zhuan-jia-wa-rang-ctace-lue-de-yun-xing-kan-de-jian
问题已经解决,并且做了以下改进:
1 解决了不同合约同时运行_kx_strategy策略时,K线图会互相影响的问题;
2 去掉策略管理器中的“K线图表”按钮,保持与原来的界面一致;在_kx_strategy策略中增加一个show_chart参数,如果想显示K线图,将它配置为True,否则不会显示K线图;
3 增加在策略被移除时删除该策略K线图表的功能;
4 K线图表中的显示内容在_kx_strategy策略中配置,而不是一个固定的主图和附图搭配。参照我的init_kx_chart()方法,您也可以为自己的策略配置自己的K线主图和附图指标;
5 添加最后一根临时K线的显示。
改进后的实现方法
先备份以下文件:
vnpy\app\cta_strategy\base.py
vnpy\app\cta_strategy\engine.py
vnpy\app\cta_strategy\ui\widget.py
vnpy\app\cta_backtester\engine.py
修改vnpy\app\cta_strategy\base.py
"""
Defines constants and objects used in CtaStrategy App.
"""
from dataclasses import dataclass, field
from enum import Enum
from datetime import timedelta
from vnpy.trader.constant import Direction, Offset, Interval
APP_NAME = "CtaStrategy"

# Prefix of locally generated stop order ids ("STOP.1", "STOP.2", ...).
STOPORDER_PREFIX = "STOP"


class StopOrderStatus(Enum):
    """Lifecycle status of a (local or server-side) stop order."""

    WAITING = "等待中"
    CANCELLED = "已撤销"
    TRIGGERED = "已触发"


class EngineType(Enum):
    """Whether strategies are running in live trading or in a backtest."""

    LIVE = "实盘"
    BACKTESTING = "回测"


class BacktestingMode(Enum):
    """Data granularity replayed by the backtesting engine."""

    BAR = 1
    TICK = 2
@dataclass
class StopOrder:
    """
    A stop order tracked by the CTA engine.

    stop_orderid is either a local id ("STOP.n") or, for server-side stop
    orders, the gateway vt_orderid. Once triggered, vt_orderids holds the
    ids of the limit orders that were sent.
    """

    vt_symbol: str
    direction: Direction
    offset: Offset
    price: float
    volume: float
    stop_orderid: str
    strategy_name: str
    lock: bool = False
    # default_factory gives each instance its own list (no shared default).
    vt_orderids: list = field(default_factory=list)
    status: StopOrderStatus = StopOrderStatus.WAITING
# Event types published by the CTA strategy engine.
EVENT_CTA_LOG = "eCtaLog"
EVENT_CTA_STRATEGY = "eCtaStrategy"
EVENT_CTA_STOPORDER = "eCtaStopOrder"

# Extra event types added by hxxjava, presumably for the K-line chart
# feature; their publishers/subscribers are not visible in this file.
EVENT_CTA_TICK = "eCtaTick"
EVENT_CTA_HISTORY_BAR = "eCtaHistoryBar"
EVENT_CTA_BAR = "eCtaBar"
EVENT_CTA_ORDER = "eCtaOrder"
EVENT_CTA_TRADE = "eCtaTrade"
# One unit of each supported bar Interval expressed as a timedelta.
INTERVAL_DELTA_MAP = {
    Interval.MINUTE: timedelta(minutes=1),
    Interval.HOUR: timedelta(hours=1),
    Interval.DAILY: timedelta(days=1),
}
修改vnpy\app\cta_strategy\engine.py
""""""
import importlib
import os
import traceback
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable
from datetime import datetime, timedelta
from concurrent.futures import ThreadPoolExecutor
from copy import copy
from tzlocal import get_localzone
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.object import (
OrderRequest,
SubscribeRequest,
HistoryRequest,
LogData,
TickData,
BarData,
ContractData
)
from vnpy.trader.event import (
EVENT_TICK,
EVENT_ORDER,
EVENT_TRADE,
EVENT_POSITION
)
from vnpy.trader.constant import (
Direction,
OrderType,
Interval,
Exchange,
Offset,
Status
)
from vnpy.trader.utility import load_json, save_json, extract_vt_symbol, round_to
from vnpy.trader.database import database_manager
from vnpy.trader.rqdata import rqdata_client
from vnpy.trader.converter import OffsetConverter
from .base import (
APP_NAME,
EVENT_CTA_LOG,
EVENT_CTA_STRATEGY,
EVENT_CTA_STOPORDER,
EngineType,
StopOrder,
StopOrderStatus,
STOPORDER_PREFIX
)
from .template import CtaTemplate
# Translate a server-side stop order's OrderData.status into the
# StopOrderStatus reported to strategies via on_stop_order.
STOP_STATUS_MAP = {
    # still resting on the server
    Status.SUBMITTING: StopOrderStatus.WAITING,
    Status.NOTTRADED: StopOrderStatus.WAITING,
    # any fill means the stop has fired
    Status.PARTTRADED: StopOrderStatus.TRIGGERED,
    Status.ALLTRADED: StopOrderStatus.TRIGGERED,
    # no longer live
    Status.CANCELLED: StopOrderStatus.CANCELLED,
    Status.REJECTED: StopOrderStatus.CANCELLED,
}
class CtaEngine(BaseEngine):
    """
    Live-trading CTA strategy engine.

    Loads strategy classes and their saved settings/data, routes tick,
    order and trade events to the strategies that own them, manages local
    stop orders, and publishes strategy / stop-order / log events.
    """

    engine_type = EngineType.LIVE  # live trading engine

    setting_filename = "cta_strategy_setting.json"
    data_filename = "cta_strategy_data.json"

    def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
        """Create empty registries; nothing is loaded until init_engine()."""
        super(CtaEngine, self).__init__(
            main_engine, event_engine, APP_NAME)

        self.strategy_setting = {}  # strategy_name: dict
        self.strategy_data = {}  # strategy_name: dict

        self.classes = {}  # class_name: stategy_class
        self.strategies = {}  # strategy_name: strategy

        self.symbol_strategy_map = defaultdict(
            list)  # vt_symbol: strategy list
        self.orderid_strategy_map = {}  # vt_orderid: strategy
        self.strategy_orderid_map = defaultdict(
            set)  # strategy_name: orderid list

        self.stop_order_count = 0  # for generating stop_orderid
        self.stop_orders = {}  # stop_orderid: stop_order

        # Single worker so strategies are initialised one at a time.
        self.init_executor = ThreadPoolExecutor(max_workers=1)

        self.rq_client = None
        self.rq_symbols = set()

        self.vt_tradeids = set()  # for filtering duplicate trade

        self.offset_converter = OffsetConverter(self.main_engine)

    def init_engine(self):
        """Init RQData, load classes/settings/data and register event handlers."""
        self.init_rqdata()
        self.load_strategy_class()
        self.load_strategy_setting()
        self.load_strategy_data()
        self.register_event()
        self.write_log("CTA策略引擎初始化成功")

    def close(self):
        """Stop every running strategy."""
        self.stop_all_strategies()

    def register_event(self):
        """Subscribe to tick, order, trade and position events."""
        self.event_engine.register(EVENT_TICK, self.process_tick_event)
        self.event_engine.register(EVENT_ORDER, self.process_order_event)
        self.event_engine.register(EVENT_TRADE, self.process_trade_event)
        self.event_engine.register(EVENT_POSITION, self.process_position_event)

    def init_rqdata(self):
        """
        Init RQData client.
        """
        result = rqdata_client.init()
        if result:
            self.write_log("RQData数据接口初始化成功")

    def query_bar_from_rq(
        self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime
    ):
        """
        Query bar data from RQData.
        """
        req = HistoryRequest(
            symbol=symbol,
            exchange=exchange,
            interval=interval,
            start=start,
            end=end
        )
        data = rqdata_client.query_history(req)
        return data

    def process_tick_event(self, event: Event):
        """Check local stop orders, then pass the tick to inited strategies of its symbol."""
        tick = event.data

        strategies = self.symbol_strategy_map[tick.vt_symbol]
        if not strategies:
            return

        self.check_stop_order(tick)

        for strategy in strategies:
            if strategy.inited:
                self.call_strategy_func(strategy, strategy.on_tick, tick)

    def process_order_event(self, event: Event):
        """Update offset converter and forward an order update to its strategy."""
        order = event.data

        self.offset_converter.update_order(order)

        strategy = self.orderid_strategy_map.get(order.vt_orderid, None)
        if not strategy:
            return

        # Remove vt_orderid if order is no longer active.
        vt_orderids = self.strategy_orderid_map[strategy.strategy_name]
        if order.vt_orderid in vt_orderids and not order.is_active():
            vt_orderids.remove(order.vt_orderid)

        # For server stop order, call strategy on_stop_order function
        if order.type == OrderType.STOP:
            so = StopOrder(
                vt_symbol=order.vt_symbol,
                direction=order.direction,
                offset=order.offset,
                price=order.price,
                volume=order.volume,
                stop_orderid=order.vt_orderid,
                strategy_name=strategy.strategy_name,
                status=STOP_STATUS_MAP[order.status],
                vt_orderids=[order.vt_orderid],
            )
            self.call_strategy_func(strategy, strategy.on_stop_order, so)

        # Call strategy on_order function
        self.call_strategy_func(strategy, strategy.on_order, order)

    def process_trade_event(self, event: Event):
        """Apply a (de-duplicated) trade to its strategy's pos and notify it."""
        trade = event.data

        # Filter duplicate trade push
        if trade.vt_tradeid in self.vt_tradeids:
            return
        self.vt_tradeids.add(trade.vt_tradeid)

        self.offset_converter.update_trade(trade)

        strategy = self.orderid_strategy_map.get(trade.vt_orderid, None)
        if not strategy:
            return

        # Update strategy pos before calling on_trade method
        if trade.direction == Direction.LONG:
            strategy.pos += trade.volume
        else:
            strategy.pos -= trade.volume

        self.call_strategy_func(strategy, strategy.on_trade, trade)

        # Sync strategy variables to data file
        self.sync_strategy_data(strategy)

        # Update GUI
        self.put_strategy_event(strategy)

    def process_position_event(self, event: Event):
        """Feed position updates into the offset converter."""
        position = event.data

        self.offset_converter.update_position(position)

    def check_stop_order(self, tick: TickData):
        """Trigger local stop orders of tick.vt_symbol whose price was crossed."""
        # Iterate over a copy: triggered orders are popped from the dict.
        for stop_order in list(self.stop_orders.values()):
            if stop_order.vt_symbol != tick.vt_symbol:
                continue

            long_triggered = (
                stop_order.direction == Direction.LONG and tick.last_price >= stop_order.price
            )
            short_triggered = (
                stop_order.direction == Direction.SHORT and tick.last_price <= stop_order.price
            )

            if long_triggered or short_triggered:
                strategy = self.strategies[stop_order.strategy_name]

                # To get executed immediately after stop order is
                # triggered, use limit price if available, otherwise
                # use ask_price_5 or bid_price_5
                if stop_order.direction == Direction.LONG:
                    if tick.limit_up:
                        price = tick.limit_up
                    else:
                        price = tick.ask_price_5
                else:
                    if tick.limit_down:
                        price = tick.limit_down
                    else:
                        price = tick.bid_price_5

                contract = self.main_engine.get_contract(stop_order.vt_symbol)

                vt_orderids = self.send_limit_order(
                    strategy,
                    contract,
                    stop_order.direction,
                    stop_order.offset,
                    price,
                    stop_order.volume,
                    stop_order.lock
                )

                # Update stop order status if placed successfully
                if vt_orderids:
                    # Remove from relation map.
                    self.stop_orders.pop(stop_order.stop_orderid)

                    strategy_vt_orderids = self.strategy_orderid_map[strategy.strategy_name]
                    if stop_order.stop_orderid in strategy_vt_orderids:
                        strategy_vt_orderids.remove(stop_order.stop_orderid)

                    # Change stop order status to triggered and update to strategy.
                    stop_order.status = StopOrderStatus.TRIGGERED
                    stop_order.vt_orderids = vt_orderids

                    self.call_strategy_func(
                        strategy, strategy.on_stop_order, stop_order
                    )
                    self.put_stop_order_event(stop_order)

    def send_server_order(
        self,
        strategy: CtaTemplate,
        contract: ContractData,
        direction: Direction,
        offset: Offset,
        price: float,
        volume: float,
        type: OrderType,
        lock: bool
    ):
        """
        Send a new order to server.

        The request may be split by the offset converter; returns the list
        of vt_orderids that were accepted by the gateway.
        """
        # Create request and send order.
        original_req = OrderRequest(
            symbol=contract.symbol,
            exchange=contract.exchange,
            direction=direction,
            offset=offset,
            type=type,
            price=price,
            volume=volume,
        )

        # Convert with offset converter
        req_list = self.offset_converter.convert_order_request(original_req, lock)

        # Send Orders
        vt_orderids = []

        for req in req_list:
            req.reference = strategy.strategy_name  # Add strategy name as order reference

            vt_orderid = self.main_engine.send_order(
                req, contract.gateway_name)

            # Check if sending order successful
            if not vt_orderid:
                continue

            vt_orderids.append(vt_orderid)

            self.offset_converter.update_order_request(req, vt_orderid)

            # Save relationship between orderid and strategy.
            self.orderid_strategy_map[vt_orderid] = strategy
            self.strategy_orderid_map[strategy.strategy_name].add(vt_orderid)

        return vt_orderids

    def send_limit_order(
        self,
        strategy: CtaTemplate,
        contract: ContractData,
        direction: Direction,
        offset: Offset,
        price: float,
        volume: float,
        lock: bool
    ):
        """
        Send a limit order to server.
        """
        return self.send_server_order(
            strategy,
            contract,
            direction,
            offset,
            price,
            volume,
            OrderType.LIMIT,
            lock
        )

    def send_server_stop_order(
        self,
        strategy: CtaTemplate,
        contract: ContractData,
        direction: Direction,
        offset: Offset,
        price: float,
        volume: float,
        lock: bool
    ):
        """
        Send a stop order to server.

        Should only be used if stop order supported
        on the trading server.
        """
        return self.send_server_order(
            strategy,
            contract,
            direction,
            offset,
            price,
            volume,
            OrderType.STOP,
            lock
        )

    def send_local_stop_order(
        self,
        strategy: CtaTemplate,
        direction: Direction,
        offset: Offset,
        price: float,
        volume: float,
        lock: bool
    ):
        """
        Create a new local stop order, checked on every tick by check_stop_order().
        """
        self.stop_order_count += 1
        stop_orderid = f"{STOPORDER_PREFIX}.{self.stop_order_count}"

        stop_order = StopOrder(
            vt_symbol=strategy.vt_symbol,
            direction=direction,
            offset=offset,
            price=price,
            volume=volume,
            stop_orderid=stop_orderid,
            strategy_name=strategy.strategy_name,
            lock=lock
        )

        self.stop_orders[stop_orderid] = stop_order

        vt_orderids = self.strategy_orderid_map[strategy.strategy_name]
        vt_orderids.add(stop_orderid)

        self.call_strategy_func(strategy, strategy.on_stop_order, stop_order)
        self.put_stop_order_event(stop_order)

        return [stop_orderid]

    def cancel_server_order(self, strategy: CtaTemplate, vt_orderid: str):
        """
        Cancel existing order by vt_orderid.
        """
        order = self.main_engine.get_order(vt_orderid)
        if not order:
            self.write_log(f"撤单失败,找不到委托{vt_orderid}", strategy)
            return

        req = order.create_cancel_request()
        self.main_engine.cancel_order(req, order.gateway_name)

    def cancel_local_stop_order(self, strategy: CtaTemplate, stop_orderid: str):
        """
        Cancel a local stop order.
        """
        stop_order = self.stop_orders.get(stop_orderid, None)
        if not stop_order:
            return
        strategy = self.strategies[stop_order.strategy_name]

        # Remove from relation map.
        self.stop_orders.pop(stop_orderid)

        vt_orderids = self.strategy_orderid_map[strategy.strategy_name]
        if stop_orderid in vt_orderids:
            vt_orderids.remove(stop_orderid)

        # Change stop order status to cancelled and update to strategy.
        stop_order.status = StopOrderStatus.CANCELLED

        self.call_strategy_func(strategy, strategy.on_stop_order, stop_order)
        self.put_stop_order_event(stop_order)

    def send_order(
        self,
        strategy: CtaTemplate,
        direction: Direction,
        offset: Offset,
        price: float,
        volume: float,
        stop: bool,
        lock: bool
    ):
        """
        Round price/volume to the contract's ticks and route the order.

        Stop orders go to the server if the contract supports them, else
        become local stop orders. Returns a list of order ids, or "" if the
        contract cannot be found.
        """
        contract = self.main_engine.get_contract(strategy.vt_symbol)
        if not contract:
            self.write_log(f"委托失败,找不到合约:{strategy.vt_symbol}", strategy)
            return ""

        # Round order price and volume to nearest incremental value
        price = round_to(price, contract.pricetick)
        volume = round_to(volume, contract.min_volume)

        if stop:
            if contract.stop_supported:
                return self.send_server_stop_order(strategy, contract, direction, offset, price, volume, lock)
            else:
                return self.send_local_stop_order(strategy, direction, offset, price, volume, lock)
        else:
            return self.send_limit_order(strategy, contract, direction, offset, price, volume, lock)

    def cancel_order(self, strategy: CtaTemplate, vt_orderid: str):
        """Cancel a local stop order (id starts with STOPORDER_PREFIX) or a server order."""
        if vt_orderid.startswith(STOPORDER_PREFIX):
            self.cancel_local_stop_order(strategy, vt_orderid)
        else:
            self.cancel_server_order(strategy, vt_orderid)

    def cancel_all(self, strategy: CtaTemplate):
        """
        Cancel all active orders of a strategy.
        """
        vt_orderids = self.strategy_orderid_map[strategy.strategy_name]
        if not vt_orderids:
            return

        # Iterate a copy: cancelling local stop orders mutates the set.
        for vt_orderid in copy(vt_orderids):
            self.cancel_order(strategy, vt_orderid)

    def get_engine_type(self):
        """Return EngineType.LIVE for this engine."""
        return self.engine_type

    def get_pricetick(self, strategy: CtaTemplate):
        """
        Return contract pricetick data, or None if the contract is unknown.
        """
        contract = self.main_engine.get_contract(strategy.vt_symbol)

        if contract:
            return contract.pricetick
        else:
            return None

    def load_bar(
        self,
        vt_symbol: str,
        days: int,
        interval: Interval,
        callback: Callable[[BarData], None],
        use_database: bool
    ):
        """
        Load `days` of history bars and feed each one to callback.

        Unless use_database is set, try the gateway (if it supports history)
        or RQData first; fall back to the local database if nothing came back.
        """
        symbol, exchange = extract_vt_symbol(vt_symbol)
        end = datetime.now(get_localzone())
        start = end - timedelta(days)
        bars = []

        # Pass gateway and RQData if use_database set to True
        if not use_database:
            # Query bars from gateway if available
            contract = self.main_engine.get_contract(vt_symbol)

            if contract and contract.history_data:
                req = HistoryRequest(
                    symbol=symbol,
                    exchange=exchange,
                    interval=interval,
                    start=start,
                    end=end
                )
                bars = self.main_engine.query_history(req, contract.gateway_name)

            # Try to query bars from RQData, if not found, load from database.
            else:
                bars = self.query_bar_from_rq(symbol, exchange, interval, start, end)

        if not bars:
            bars = database_manager.load_bar_data(
                symbol=symbol,
                exchange=exchange,
                interval=interval,
                start=start,
                end=end,
            )

        for bar in bars:
            callback(bar)

    def load_tick(
        self,
        vt_symbol: str,
        days: int,
        callback: Callable[[TickData], None]
    ):
        """Load `days` of ticks from the database and feed each one to callback."""
        symbol, exchange = extract_vt_symbol(vt_symbol)
        # NOTE(review): naive datetime here vs. tz-aware in load_bar();
        # confirm which form database_manager expects.
        end = datetime.now()
        start = end - timedelta(days)

        ticks = database_manager.load_tick_data(
            symbol=symbol,
            exchange=exchange,
            start=start,
            end=end,
        )

        for tick in ticks:
            callback(tick)

    def call_strategy_func(
        self, strategy: CtaTemplate, func: Callable, params: Any = None
    ):
        """
        Call function of a strategy and catch any exception raised.

        On exception the strategy is marked not trading / not inited and the
        traceback is logged.
        """
        try:
            # Compare with None so a falsy-but-real argument is still passed.
            if params is not None:
                func(params)
            else:
                func()
        except Exception:
            strategy.trading = False
            strategy.inited = False

            msg = f"触发异常已停止\n{traceback.format_exc()}"
            self.write_log(msg, strategy)

    def add_strategy(
        self, class_name: str, strategy_name: str, vt_symbol: str, setting: dict
    ):
        """
        Add a new strategy.
        """
        if strategy_name in self.strategies:
            self.write_log(f"创建策略失败,存在重名{strategy_name}")
            return

        strategy_class = self.classes.get(class_name, None)
        if not strategy_class:
            self.write_log(f"创建策略失败,找不到策略类{class_name}")
            return

        strategy = strategy_class(self, strategy_name, vt_symbol, setting)
        self.strategies[strategy_name] = strategy

        # Add vt_symbol to strategy map.
        strategies = self.symbol_strategy_map[vt_symbol]
        strategies.append(strategy)

        # Update to setting file.
        self.update_strategy_setting(strategy_name, setting)

        self.put_strategy_event(strategy)

    def init_strategy(self, strategy_name: str):
        """
        Init a strategy asynchronously on the single-worker init executor.
        """
        self.init_executor.submit(self._init_strategy, strategy_name)

    def _init_strategy(self, strategy_name: str):
        """
        Init strategies in queue.
        """
        strategy = self.strategies[strategy_name]

        if strategy.inited:
            self.write_log(f"{strategy_name}已经完成初始化,禁止重复操作")
            return

        self.write_log(f"{strategy_name}开始执行初始化")

        # Call on_init function of strategy
        self.call_strategy_func(strategy, strategy.on_init)

        # Restore strategy data(variables)
        data = self.strategy_data.get(strategy_name, None)
        if data:
            for name in strategy.variables:
                value = data.get(name, None)
                # Only skip missing/None values: 0, False and "" are valid
                # saved states and must be restored too.
                if value is not None:
                    setattr(strategy, name, value)

        # Subscribe market data
        contract = self.main_engine.get_contract(strategy.vt_symbol)
        if contract:
            req = SubscribeRequest(
                symbol=contract.symbol, exchange=contract.exchange)
            self.main_engine.subscribe(req, contract.gateway_name)
        else:
            self.write_log(f"行情订阅失败,找不到合约{strategy.vt_symbol}", strategy)

        # Put event to update init completed status.
        strategy.inited = True
        self.put_strategy_event(strategy)
        self.write_log(f"{strategy_name}初始化完成")

    def start_strategy(self, strategy_name: str):
        """
        Start a strategy.
        """
        strategy = self.strategies[strategy_name]
        if not strategy.inited:
            self.write_log(f"策略{strategy.strategy_name}启动失败,请先初始化")
            return

        if strategy.trading:
            self.write_log(f"{strategy_name}已经启动,请勿重复操作")
            return

        self.call_strategy_func(strategy, strategy.on_start)
        strategy.trading = True

        self.put_strategy_event(strategy)

    def stop_strategy(self, strategy_name: str):
        """
        Stop a strategy.
        """
        strategy = self.strategies[strategy_name]
        if not strategy.trading:
            return

        # Call on_stop function of the strategy
        self.call_strategy_func(strategy, strategy.on_stop)

        # Change trading status of strategy to False
        strategy.trading = False

        # Cancel all orders of the strategy
        self.cancel_all(strategy)

        # Sync strategy variables to data file
        self.sync_strategy_data(strategy)

        # Update GUI
        self.put_strategy_event(strategy)

    def edit_strategy(self, strategy_name: str, setting: dict):
        """
        Edit parameters of a strategy.
        """
        strategy = self.strategies[strategy_name]
        strategy.update_setting(setting)

        self.update_strategy_setting(strategy_name, setting)
        self.put_strategy_event(strategy)

    def remove_strategy(self, strategy_name: str):
        """
        Remove a strategy.

        Returns True on success; returns None (and logs) if it is still trading.
        """
        strategy = self.strategies[strategy_name]
        if strategy.trading:
            self.write_log(f"策略{strategy.strategy_name}移除失败,请先停止")
            return

        # Remove setting
        self.remove_strategy_setting(strategy_name)

        # Remove from symbol strategy map
        strategies = self.symbol_strategy_map[strategy.vt_symbol]
        strategies.remove(strategy)

        # Remove from active orderid map
        if strategy_name in self.strategy_orderid_map:
            vt_orderids = self.strategy_orderid_map.pop(strategy_name)

            # Remove vt_orderid strategy map
            for vt_orderid in vt_orderids:
                if vt_orderid in self.orderid_strategy_map:
                    self.orderid_strategy_map.pop(vt_orderid)

        # Remove from strategies
        self.strategies.pop(strategy_name)

        return True

    def load_strategy_class(self):
        """
        Load strategy class from source code.
        """
        path1 = Path(__file__).parent.joinpath("strategies")
        self.load_strategy_class_from_folder(
            path1, "vnpy.app.cta_strategy.strategies")

        path2 = Path.cwd().joinpath("strategies")
        self.load_strategy_class_from_folder(path2, "strategies")

    def load_strategy_class_from_folder(self, path: Path, module_name: str = ""):
        """
        Load strategy class from certain folder.
        """
        for dirpath, dirnames, filenames in os.walk(str(path)):
            for filename in filenames:
                if filename.split(".")[-1] in ("py", "pyd", "so"):
                    strategy_module_name = ".".join([module_name, filename.split(".")[0]])
                    self.load_strategy_class_from_module(strategy_module_name)

    def load_strategy_class_from_module(self, module_name: str):
        """
        Load strategy class from module file.

        Every CtaTemplate subclass found in the module (other than
        CtaTemplate itself) is registered by class name. Import errors are
        logged rather than raised so one broken file does not stop loading.
        """
        try:
            module = importlib.import_module(module_name)

            for name in dir(module):
                value = getattr(module, name)
                if (isinstance(value, type) and issubclass(value, CtaTemplate) and value is not CtaTemplate):
                    self.classes[value.__name__] = value
        except Exception:
            msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}"
            self.write_log(msg)

    def load_strategy_data(self):
        """
        Load strategy data from json file.
        """
        self.strategy_data = load_json(self.data_filename)

    def sync_strategy_data(self, strategy: CtaTemplate):
        """
        Sync strategy data into json file.
        """
        data = strategy.get_variables()
        data.pop("inited")  # Strategy status (inited, trading) should not be synced.
        data.pop("trading")

        self.strategy_data[strategy.strategy_name] = data
        save_json(self.data_filename, self.strategy_data)

    def get_all_strategy_class_names(self):
        """
        Return names of strategy classes loaded.
        """
        return list(self.classes.keys())

    def get_strategy_class_parameters(self, class_name: str):
        """
        Get default parameters of a strategy class.
        """
        strategy_class = self.classes[class_name]

        parameters = {}
        for name in strategy_class.parameters:
            parameters[name] = getattr(strategy_class, name)

        return parameters

    def get_strategy_parameters(self, strategy_name):
        """
        Get parameters of a strategy.
        """
        strategy = self.strategies[strategy_name]
        return strategy.get_parameters()

    def init_all_strategies(self):
        """Queue initialisation of every strategy."""
        for strategy_name in self.strategies.keys():
            self.init_strategy(strategy_name)

    def start_all_strategies(self):
        """Start every strategy."""
        for strategy_name in self.strategies.keys():
            self.start_strategy(strategy_name)

    def stop_all_strategies(self):
        """Stop every strategy."""
        for strategy_name in self.strategies.keys():
            self.stop_strategy(strategy_name)

    def load_strategy_setting(self):
        """
        Load setting file and re-create every saved strategy.
        """
        self.strategy_setting = load_json(self.setting_filename)

        for strategy_name, strategy_config in self.strategy_setting.items():
            self.add_strategy(
                strategy_config["class_name"],
                strategy_name,
                strategy_config["vt_symbol"],
                strategy_config["setting"]
            )

    def update_strategy_setting(self, strategy_name: str, setting: dict):
        """
        Update setting file.
        """
        strategy = self.strategies[strategy_name]

        self.strategy_setting[strategy_name] = {
            "class_name": strategy.__class__.__name__,
            "vt_symbol": strategy.vt_symbol,
            "setting": setting,
        }
        save_json(self.setting_filename, self.strategy_setting)

    def remove_strategy_setting(self, strategy_name: str):
        """
        Remove a strategy from the setting file.
        """
        if strategy_name not in self.strategy_setting:
            return

        self.strategy_setting.pop(strategy_name)
        save_json(self.setting_filename, self.strategy_setting)

    def put_stop_order_event(self, stop_order: StopOrder):
        """
        Put an event to update stop order status.
        """
        event = Event(EVENT_CTA_STOPORDER, stop_order)
        self.event_engine.put(event)

    def put_strategy_event(self, strategy: CtaTemplate):
        """
        Put an event to update strategy status.
        """
        data = strategy.get_data()
        event = Event(EVENT_CTA_STRATEGY, data)
        self.event_engine.put(event)

    # --------------------------------------------------------------------------------------------------
    def get_position_detail(self, vt_symbol: str):
        """
        Query position details of a contract from the offset converter.

        Callers read long_pos/short_pos (positions), long_pnl/short_pnl
        (PnL) and active_orders (unfilled orders) from the result, which is
        the converter's position holding object. If that lookup raises, the
        error is logged and a placeholder with the same attributes (all
        zero, active_orders empty) is returned instead.
        """
        from types import SimpleNamespace

        try:
            return self.offset_converter.get_position_holding(vt_symbol)
        except Exception as e:
            # Do not call get_position_holding() again in the message: it is
            # the call that just failed and would raise out of this handler.
            self.write_log(f"获取{vt_symbol}持仓信息失败:{e},等待获取持仓信息")
            return SimpleNamespace(
                active_orders={},
                long_pos=0,
                long_pnl=0,
                long_yd=0,
                long_td=0,
                long_pos_frozen=0,
                long_price=0,
                short_pos=0,
                short_pnl=0,
                short_yd=0,
                short_td=0,
                short_price=0,
                short_pos_frozen=0,
            )

    def write_log(self, msg: str, strategy: CtaTemplate = None):
        """
        Create cta engine log event, prefixed with the strategy name if given.
        """
        if strategy:
            msg = f"{strategy.strategy_name}: {msg}"

        log = LogData(msg=msg, gateway_name="CtaStrategy")
        event = Event(type=EVENT_CTA_LOG, data=log)
        self.event_engine.put(event)

    def send_email(self, msg: str, strategy: CtaTemplate = None):
        """
        Send email to default receiver.
        """
        if strategy:
            subject = f"{strategy.strategy_name}"
        else:
            subject = "CTA策略引擎"

        self.main_engine.send_email(subject, msg)
修改vnpy\app\cta_strategy\ui\widget.py
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import QtCore, QtGui, QtWidgets
from vnpy.trader.ui.widget import (
BaseCell,
EnumCell,
MsgCell,
TimeCell,
BaseMonitor
)
from ..base import (
APP_NAME,
EVENT_CTA_LOG,
EVENT_CTA_STOPORDER,
EVENT_CTA_STRATEGY,
)
from ..engine import CtaEngine
from vnpy.usertools.kx_chart import NewChartWidget # hxxjava add
class CtaManager(QtWidgets.QWidget):
    """
    Main widget of the CTA strategy app.

    Holds the strategy-class selector, the bulk init/start/stop buttons,
    one StrategyManager panel per strategy, the stop-order monitor and the
    log monitor.
    """

    signal_log = QtCore.pyqtSignal(Event)
    signal_strategy = QtCore.pyqtSignal(Event)

    def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
        super(CtaManager, self).__init__()

        self.main_engine = main_engine
        self.event_engine = event_engine
        self.cta_engine = main_engine.get_engine(APP_NAME)

        self.managers = {}  # strategy_name: StrategyManager

        self.init_ui()
        self.register_event()
        self.cta_engine.init_engine()
        self.update_class_combo()

    def init_ui(self):
        """Build buttons, the scrollable strategy area and the monitors."""
        self.setWindowTitle("CTA策略")

        # Create widgets
        self.class_combo = QtWidgets.QComboBox()

        add_button = QtWidgets.QPushButton("添加策略")
        add_button.clicked.connect(self.add_strategy)

        init_button = QtWidgets.QPushButton("全部初始化")
        init_button.clicked.connect(self.cta_engine.init_all_strategies)

        start_button = QtWidgets.QPushButton("全部启动")
        start_button.clicked.connect(self.cta_engine.start_all_strategies)

        stop_button = QtWidgets.QPushButton("全部停止")
        stop_button.clicked.connect(self.cta_engine.stop_all_strategies)

        clear_button = QtWidgets.QPushButton("清空日志")
        clear_button.clicked.connect(self.clear_log)

        self.scroll_layout = QtWidgets.QVBoxLayout()
        self.scroll_layout.addStretch()

        scroll_widget = QtWidgets.QWidget()
        scroll_widget.setLayout(self.scroll_layout)

        scroll_area = QtWidgets.QScrollArea()
        scroll_area.setWidgetResizable(True)
        scroll_area.setWidget(scroll_widget)

        self.log_monitor = LogMonitor(self.main_engine, self.event_engine)

        self.stop_order_monitor = StopOrderMonitor(
            self.main_engine, self.event_engine
        )

        # Set layout
        hbox1 = QtWidgets.QHBoxLayout()
        hbox1.addWidget(self.class_combo)
        hbox1.addWidget(add_button)
        hbox1.addStretch()
        hbox1.addWidget(init_button)
        hbox1.addWidget(start_button)
        hbox1.addWidget(stop_button)
        hbox1.addWidget(clear_button)

        grid = QtWidgets.QGridLayout()
        grid.addWidget(scroll_area, 0, 0, 2, 1)
        grid.addWidget(self.stop_order_monitor, 0, 1)
        grid.addWidget(self.log_monitor, 1, 1)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addLayout(hbox1)
        vbox.addLayout(grid)

        self.setLayout(vbox)

    def update_class_combo(self):
        """Fill the class selector with the loaded strategy class names."""
        self.class_combo.addItems(
            self.cta_engine.get_all_strategy_class_names()
        )

    def register_event(self):
        """Route strategy events to the GUI thread through a Qt signal."""
        self.signal_strategy.connect(self.process_strategy_event)

        self.event_engine.register(
            EVENT_CTA_STRATEGY, self.signal_strategy.emit
        )

    def process_strategy_event(self, event):
        """
        Update strategy status onto its monitor, creating the panel on first sight.
        """
        data = event.data
        strategy_name = data["strategy_name"]

        if strategy_name in self.managers:
            manager = self.managers[strategy_name]
            manager.update_data(data)
        else:
            manager = StrategyManager(self, self.cta_engine, data)
            self.scroll_layout.insertWidget(0, manager)
            self.managers[strategy_name] = manager

    def remove_strategy(self, strategy_name):
        """Drop and schedule deletion of a strategy's panel."""
        manager = self.managers.pop(strategy_name)
        manager.deleteLater()

    def add_strategy(self):
        """Ask for settings of the selected class and add a new strategy."""
        class_name = str(self.class_combo.currentText())
        if not class_name:
            return

        parameters = self.cta_engine.get_strategy_class_parameters(class_name)
        editor = SettingEditor(parameters, class_name=class_name)
        n = editor.exec_()

        if n == editor.Accepted:
            setting = editor.get_setting()
            vt_symbol = setting.pop("vt_symbol")
            strategy_name = setting.pop("strategy_name")

            self.cta_engine.add_strategy(
                class_name, strategy_name, vt_symbol, setting
            )

    def clear_log(self):
        """Remove all rows from the log monitor."""
        self.log_monitor.setRowCount(0)

    def show(self):
        """Show the widget maximized."""
        self.showMaximized()
class StrategyManager(QtWidgets.QFrame):
    """
    Manager for a strategy.

    Shows the strategy's parameters/variables with init/start/stop/edit/
    remove buttons, and optionally a K-line chart window (hxxjava add)
    when the strategy setting has show_chart enabled.
    """

    def __init__(
        self, cta_manager: CtaManager, cta_engine: CtaEngine, data: dict
    ):
        """"""
        super(StrategyManager, self).__init__()

        self.cta_manager = cta_manager
        self.cta_engine = cta_engine

        self.strategy_name = data["strategy_name"]
        self._data = data

        # K-line chart window, created by open_kx_chart(). Initialised here
        # so remove_strategy() works even for a never-initialised strategy.
        self.kx_chart = None

        self.init_ui()

    def init_ui(self):
        """Build the label, button row and parameter/variable tables."""
        self.setFixedHeight(300)
        self.setFrameShape(self.Box)
        self.setLineWidth(1)

        self.init_button = QtWidgets.QPushButton("初始化")
        self.init_button.clicked.connect(self.init_strategy)

        self.start_button = QtWidgets.QPushButton("启动")
        self.start_button.clicked.connect(self.start_strategy)
        self.start_button.setEnabled(False)

        self.stop_button = QtWidgets.QPushButton("停止")
        self.stop_button.clicked.connect(self.stop_strategy)
        self.stop_button.setEnabled(False)

        self.edit_button = QtWidgets.QPushButton("编辑")
        self.edit_button.clicked.connect(self.edit_strategy)

        self.remove_button = QtWidgets.QPushButton("移除")
        self.remove_button.clicked.connect(self.remove_strategy)

        strategy_name = self._data["strategy_name"]
        vt_symbol = self._data["vt_symbol"]
        class_name = self._data["class_name"]
        author = self._data["author"]

        label_text = (
            f"{strategy_name} - {vt_symbol} ({class_name} by {author})"
        )
        label = QtWidgets.QLabel(label_text)
        label.setAlignment(QtCore.Qt.AlignCenter)

        self.parameters_monitor = DataMonitor(self._data["parameters"])
        self.variables_monitor = DataMonitor(self._data["variables"])

        hbox = QtWidgets.QHBoxLayout()
        hbox.addWidget(self.init_button)
        hbox.addWidget(self.start_button)
        hbox.addWidget(self.stop_button)
        hbox.addWidget(self.edit_button)
        hbox.addWidget(self.remove_button)

        vbox = QtWidgets.QVBoxLayout()
        vbox.addWidget(label)
        vbox.addLayout(hbox)
        vbox.addWidget(self.parameters_monitor)
        vbox.addWidget(self.variables_monitor)
        self.setLayout(vbox)

    def update_data(self, data: dict):
        """Refresh tables and enable/disable buttons from inited/trading."""
        self._data = data

        self.parameters_monitor.update_data(data["parameters"])
        self.variables_monitor.update_data(data["variables"])

        # Update button status
        variables = data["variables"]
        inited = variables["inited"]
        trading = variables["trading"]

        if not inited:
            return
        self.init_button.setEnabled(False)

        if trading:
            self.start_button.setEnabled(False)
            self.stop_button.setEnabled(True)
            self.edit_button.setEnabled(False)
            self.remove_button.setEnabled(False)
        else:
            self.start_button.setEnabled(True)
            self.stop_button.setEnabled(False)
            self.edit_button.setEnabled(True)
            self.remove_button.setEnabled(True)

    def init_strategy(self):
        """Open the K-line chart (if enabled), then queue strategy init."""
        self.open_kx_chart()  # hxxjava add
        self.cta_engine.init_strategy(self.strategy_name)

    def start_strategy(self):
        """"""
        self.cta_engine.start_strategy(self.strategy_name)

    def stop_strategy(self):
        """"""
        self.cta_engine.stop_strategy(self.strategy_name)

    def edit_strategy(self):
        """Open a setting editor and apply the accepted parameters."""
        strategy_name = self._data["strategy_name"]

        parameters = self.cta_engine.get_strategy_parameters(strategy_name)
        editor = SettingEditor(parameters, strategy_name=strategy_name)
        n = editor.exec_()

        if n == editor.Accepted:
            setting = editor.get_setting()
            self.cta_engine.edit_strategy(strategy_name, setting)

    def remove_strategy(self):
        """Remove the strategy from the engine, then its panel and chart."""
        result = self.cta_engine.remove_strategy(self.strategy_name)

        # Only remove strategy gui manager if it has been removed from engine
        if result:
            self.cta_manager.remove_strategy(self.strategy_name)
            self._close_kx_chart()  # hxxjava add

    def open_kx_chart(self):  # hxxjava add
        """
        Create and show the K-line chart if the setting has show_chart truthy.

        The window title uses the setting's kx_interval; the strategy fills
        the chart via its init_kx_chart() method.
        """
        # Close a chart left over from an earlier click (init runs async, so
        # the button can be pressed again) instead of orphaning it.
        self._close_kx_chart()

        setting = self.cta_engine.strategy_setting[self.strategy_name]["setting"]
        if not setting.get("show_chart", None):
            return

        strategy = self.cta_engine.strategies[self.strategy_name]
        kx_interval = setting.get("kx_interval", None)

        self.kx_chart = NewChartWidget(
            event_engine=self.cta_engine.event_engine,
            strategy_name=self.strategy_name,
        )
        self.kx_chart.setWindowTitle(f"K线图表:{self.strategy_name},周期:{kx_interval}")
        strategy.init_kx_chart(self.kx_chart)
        self.kx_chart.register_event()  # register event handlers
        self.kx_chart.show()  # show the K-line chart

    def _close_kx_chart(self):
        """Close the K-line chart window, if any, and forget it."""
        # NOTE(review): whether close() also unregisters the chart's event
        # handlers depends on NewChartWidget, which is not shown here.
        if self.kx_chart is not None:
            self.kx_chart.close()
        self.kx_chart = None
class DataMonitor(QtWidgets.QTableWidget):
    """
    Single-row table showing a name -> value mapping (strategy parameters or variables).

    There is one column per key of the dict given at construction.  The item of
    each column is kept in ``self.cells`` (keyed by name) so that
    update_data() can refresh the displayed values later.
    """

    def __init__(self, data: dict):
        """"""
        super().__init__()

        self._data = data
        self.cells = {}

        self.init_ui()

    def init_ui(self):
        """Build the header from the dict keys and fill the single data row."""
        headers = list(self._data)
        self.setColumnCount(len(headers))
        self.setHorizontalHeaderLabels(headers)
        self.setRowCount(1)

        row_header = self.verticalHeader()
        row_header.setSectionResizeMode(QtWidgets.QHeaderView.Stretch)
        row_header.setVisible(False)
        self.setEditTriggers(self.NoEditTriggers)

        for col, (key, val) in enumerate(self._data.items()):
            item = QtWidgets.QTableWidgetItem(str(val))
            item.setTextAlignment(QtCore.Qt.AlignCenter)
            self.setItem(0, col, item)
            self.cells[key] = item

    def update_data(self, data: dict):
        """Replace the displayed text of every cell named in ``data``."""
        for key, val in data.items():
            self.cells[key].setText(str(val))
class StopOrderMonitor(BaseMonitor):
    """
    Table monitor for local stop orders, keyed by stop_orderid.
    """

    event_type = EVENT_CTA_STOPORDER
    data_key = "stop_orderid"
    sorting = True

    # Column definitions in display order.  Only vt_orderids and status are
    # flagged "update": True; how that flag is used is up to BaseMonitor.
    headers = dict(
        stop_orderid={"display": "停止委托号", "cell": BaseCell, "update": False},
        vt_orderids={"display": "限价委托号", "cell": BaseCell, "update": True},
        vt_symbol={"display": "本地代码", "cell": BaseCell, "update": False},
        direction={"display": "方向", "cell": EnumCell, "update": False},
        offset={"display": "开平", "cell": EnumCell, "update": False},
        price={"display": "价格", "cell": BaseCell, "update": False},
        volume={"display": "数量", "cell": BaseCell, "update": False},
        status={"display": "状态", "cell": EnumCell, "update": True},
        lock={"display": "锁仓", "cell": BaseCell, "update": False},
        strategy_name={"display": "策略名", "cell": BaseCell, "update": False},
    )
class LogMonitor(BaseMonitor):
    """
    Table monitor for CTA log messages (time and message columns).
    """

    event_type = EVENT_CTA_LOG
    data_key = ""
    sorting = False

    headers = dict(
        time={"display": "时间", "cell": TimeCell, "update": False},
        msg={"display": "信息", "cell": MsgCell, "update": False},
    )

    def init_ui(self):
        """Build the table via the base class, then let the message column (index 1) stretch."""
        super().init_ui()
        header = self.horizontalHeader()
        header.setSectionResizeMode(1, QtWidgets.QHeaderView.Stretch)

    def insert_new_row(self, data):
        """Insert a new row at the top of the table, then fit row 0's height to its contents."""
        super().insert_new_row(data)
        self.resizeRowToContents(0)
class SettingEditor(QtWidgets.QDialog):
    """
    Dialog for adding a new strategy or editing an existing strategy's parameters.

    Passing ``class_name`` puts the dialog in "add" mode.  In that mode,
    ``strategy_name`` and ``vt_symbol`` fields are shown in addition to the
    parameters.  Otherwise the dialog edits the parameters of ``strategy_name``.
    """

    def __init__(
        self, parameters: dict, strategy_name: str = "", class_name: str = ""
    ):
        """"""
        super().__init__()

        self.parameters = parameters
        self.strategy_name = strategy_name
        self.class_name = class_name

        # name -> (line edit widget, type of the original value)
        self.edits = {}

        self.init_ui()

    def init_ui(self):
        """Build one labelled line edit per parameter inside a scrollable form."""
        if self.class_name:
            # Add vt_symbol and name edit if add new strategy
            title = f"添加策略:{self.class_name}"
            button_text = "添加"
            fields = {"strategy_name": "", "vt_symbol": ""}
            fields.update(self.parameters)
        else:
            title = f"参数编辑:{self.strategy_name}"
            button_text = "确定"
            fields = self.parameters
        self.setWindowTitle(title)

        # Numeric fields get an input validator; other types are free text.
        validator_types = {int: QtGui.QIntValidator, float: QtGui.QDoubleValidator}

        form = QtWidgets.QFormLayout()
        for name, value in fields.items():
            value_type = type(value)
            line_edit = QtWidgets.QLineEdit(str(value))

            validator_cls = validator_types.get(value_type)
            if validator_cls is not None:
                line_edit.setValidator(validator_cls())

            form.addRow(f"{name} {value_type}", line_edit)
            self.edits[name] = (line_edit, value_type)

        confirm_button = QtWidgets.QPushButton(button_text)
        confirm_button.clicked.connect(self.accept)
        form.addRow(confirm_button)

        container = QtWidgets.QWidget()
        container.setLayout(form)

        scroll_area = QtWidgets.QScrollArea()
        scroll_area.setWidgetResizable(True)
        scroll_area.setWidget(container)

        outer_layout = QtWidgets.QVBoxLayout()
        outer_layout.addWidget(scroll_area)
        self.setLayout(outer_layout)

    def get_setting(self):
        """
        Collect the edited values, converted back to each parameter's original type.

        A bool field is True only if its text is exactly "True".  Every other
        type is rebuilt by calling the type on the text.  In add mode the
        result also contains "class_name".
        """
        setting = {"class_name": self.class_name} if self.class_name else {}

        for name, (line_edit, value_type) in self.edits.items():
            text = line_edit.text()
            if value_type is bool:
                setting[name] = text == "True"
            else:
                setting[name] = value_type(text)

        return setting
修改vnpy\app\cta_backtester\engine.py
import os
import importlib
import traceback
from datetime import datetime
from threading import Thread
from pathlib import Path
from inspect import getfile
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.constant import Interval
from vnpy.trader.utility import extract_vt_symbol
from vnpy.trader.object import HistoryRequest
from vnpy.trader.rqdata import rqdata_client
from vnpy.trader.database import database_manager
from vnpy.app.cta_strategy import CtaTemplate
from vnpy.app.cta_strategy.backtesting import BacktestingEngine, OptimizationSetting
from vnpy.app.cta_strategy.base import EngineType  # hxxjava add

# Application name and the event types this app publishes.
APP_NAME = "CtaBacktester"

EVENT_BACKTESTER_LOG = "eBacktesterLog"
EVENT_BACKTESTER_BACKTESTING_FINISHED = "eBacktesterBacktestingFinished"
EVENT_BACKTESTER_OPTIMIZATION_FINISHED = "eBacktesterOptimizationFinished"
class BacktesterEngine(BaseEngine):
    """
    For running CTA strategy backtesting.

    Backtesting, parameter optimization and history-data downloading each run
    in a background thread.  Only one such task may run at a time: the
    start_* methods refuse to start while ``self.thread`` is set.  Every
    run_* method clears it when it finishes, including when it fails, so a
    crashed task no longer blocks all later tasks.
    """

    # hxxjava add --- for use by strategies during backtesting
    engine_type = EngineType.BACKTESTING

    def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
        """"""
        super().__init__(main_engine, event_engine, APP_NAME)

        self.classes = {}  # strategy class name -> strategy class
        self.backtesting_engine = None
        self.thread = None

        # Backtesting result
        self.result_df = None
        self.result_statistics = None

        # Optimization result
        self.result_values = None

    def init_engine(self):
        """Create the backtesting engine, load strategy classes and init the RQData client."""
        self.write_log("初始化CTA回测引擎")

        self.backtesting_engine = BacktestingEngine()
        # Redirect log from backtesting engine outside.
        self.backtesting_engine.output = self.write_log

        self.load_strategy_class()
        self.write_log("策略文件加载完成")

        self.init_rqdata()

    def init_rqdata(self):
        """
        Init RQData client.
        """
        result = rqdata_client.init()
        if result:
            self.write_log("RQData数据接口初始化成功")

    def write_log(self, msg: str):
        """Publish ``msg`` as an EVENT_BACKTESTER_LOG event."""
        event = Event(EVENT_BACKTESTER_LOG)
        event.data = msg
        self.event_engine.put(event)

    def load_strategy_class(self):
        """
        Load strategy classes from the bundled ``cta_strategy/strategies``
        folder and from ``strategies`` under the current working directory.
        """
        app_path = Path(__file__).parent.parent
        path1 = app_path.joinpath("cta_strategy", "strategies")
        self.load_strategy_class_from_folder(
            path1, "vnpy.app.cta_strategy.strategies")

        path2 = Path.cwd().joinpath("strategies")
        self.load_strategy_class_from_folder(path2, "strategies")

    def load_strategy_class_from_folder(self, path: Path, module_name: str = ""):
        """
        Load strategy class from certain folder.

        ``.py`` source files and compiled ``.pyd`` binaries found while walking
        ``path`` are imported as ``<module_name>.<file stem>``.
        """
        for dirpath, dirnames, filenames in os.walk(path):
            for filename in filenames:
                if filename.endswith(".py"):
                    # Load python source code file.  Strip only the trailing
                    # suffix: str.replace(".py", "") would also remove ".py"
                    # occurring elsewhere in the name.
                    stem = filename[:-len(".py")]
                elif filename.endswith(".pyd"):
                    # Load compiled pyd binary file (e.g. name.cp37-win_amd64.pyd)
                    stem = filename.split(".")[0]
                else:
                    continue

                strategy_module_name = ".".join([module_name, stem])
                self.load_strategy_class_from_module(strategy_module_name)

    def load_strategy_class_from_module(self, module_name: str):
        """
        Load strategy class from module file.

        Every CtaTemplate subclass (other than CtaTemplate itself) found in
        the module is registered in ``self.classes`` under its class name.
        A module that fails to import is logged, not raised.
        """
        try:
            module = importlib.import_module(module_name)
            importlib.reload(module)

            for name in dir(module):
                value = getattr(module, name)
                if (
                    isinstance(value, type)
                    and issubclass(value, CtaTemplate)
                    and value is not CtaTemplate
                ):
                    self.classes[value.__name__] = value
        except Exception:
            # Exception (not a bare except) so KeyboardInterrupt/SystemExit
            # are not swallowed.
            msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}"
            self.write_log(msg)

    def reload_strategy_class(self):
        """Forget all loaded strategy classes and load them again from disk."""
        self.classes.clear()
        self.load_strategy_class()
        self.write_log("策略文件重载刷新完成")

    def get_strategy_class_names(self):
        """Return the names of all loaded strategy classes."""
        return list(self.classes.keys())

    def run_backtesting(
        self,
        class_name: str,
        vt_symbol: str,
        interval: str,
        start: datetime,
        end: datetime,
        rate: float,
        slippage: float,
        size: int,
        pricetick: float,
        capital: int,
        inverse: bool,
        setting: dict
    ):
        """
        Run one backtest (worker-thread body of start_backtesting).

        On success the results are stored in ``result_df`` /
        ``result_statistics`` and EVENT_BACKTESTER_BACKTESTING_FINISHED is
        published.  On failure the traceback is logged and no finished event
        is sent.  In both cases ``self.thread`` is cleared.
        """
        self.result_df = None
        self.result_statistics = None

        engine = self.backtesting_engine
        try:
            engine.clear_data()
            engine.set_parameters(
                vt_symbol=vt_symbol,
                interval=interval,
                start=start,
                end=end,
                rate=rate,
                slippage=slippage,
                size=size,
                pricetick=pricetick,
                capital=capital,
                inverse=inverse
            )

            strategy_class = self.classes[class_name]
            engine.add_strategy(
                strategy_class,
                setting
            )

            engine.load_data()
            engine.run_backtesting()
            self.result_df = engine.calculate_result()
            self.result_statistics = engine.calculate_statistics(output=False)
        except Exception:
            msg = f"策略回测失败,触发异常:\n{traceback.format_exc()}"
            self.write_log(msg)
            return
        finally:
            # Clear thread object handler -- also on failure, otherwise every
            # later task would be rejected as "already running".
            self.thread = None

        # Put backtesting done event
        event = Event(EVENT_BACKTESTER_BACKTESTING_FINISHED)
        self.event_engine.put(event)

    def start_backtesting(
        self,
        class_name: str,
        vt_symbol: str,
        interval: str,
        start: datetime,
        end: datetime,
        rate: float,
        slippage: float,
        size: int,
        pricetick: float,
        capital: int,
        inverse: bool,
        setting: dict
    ):
        """Start run_backtesting in a background thread; return False if a task is already running."""
        if self.thread:
            self.write_log("已有任务在运行中,请等待完成")
            return False

        self.write_log("-" * 40)
        self.thread = Thread(
            target=self.run_backtesting,
            args=(
                class_name,
                vt_symbol,
                interval,
                start,
                end,
                rate,
                slippage,
                size,
                pricetick,
                capital,
                inverse,
                setting
            )
        )
        self.thread.start()

        return True

    def get_result_df(self):
        """Return the daily result DataFrame of the last backtest (or None)."""
        return self.result_df

    def get_result_statistics(self):
        """Return the statistics of the last backtest (or None)."""
        return self.result_statistics

    def get_result_values(self):
        """Return the result of the last optimization (or None)."""
        return self.result_values

    def get_default_setting(self, class_name: str):
        """Return the default parameters of the named strategy class."""
        strategy_class = self.classes[class_name]
        return strategy_class.get_class_parameters()

    def run_optimization(
        self,
        class_name: str,
        vt_symbol: str,
        interval: str,
        start: datetime,
        end: datetime,
        rate: float,
        slippage: float,
        size: int,
        pricetick: float,
        capital: int,
        inverse: bool,
        optimization_setting: OptimizationSetting,
        use_ga: bool
    ):
        """
        Run a parameter optimization (worker-thread body of start_optimization).

        ``use_ga`` selects the genetic-algorithm search instead of the
        exhaustive multi-process one.  On success the result is stored in
        ``result_values`` and EVENT_BACKTESTER_OPTIMIZATION_FINISHED is
        published.  On failure the traceback is logged.  In both cases
        ``self.thread`` is cleared.
        """
        if use_ga:
            self.write_log("开始遗传算法参数优化")
        else:
            self.write_log("开始多进程参数优化")

        self.result_values = None

        engine = self.backtesting_engine
        try:
            engine.clear_data()
            engine.set_parameters(
                vt_symbol=vt_symbol,
                interval=interval,
                start=start,
                end=end,
                rate=rate,
                slippage=slippage,
                size=size,
                pricetick=pricetick,
                capital=capital,
                inverse=inverse
            )

            strategy_class = self.classes[class_name]
            engine.add_strategy(
                strategy_class,
                {}
            )

            if use_ga:
                self.result_values = engine.run_ga_optimization(
                    optimization_setting,
                    output=False
                )
            else:
                self.result_values = engine.run_optimization(
                    optimization_setting,
                    output=False
                )
        except Exception:
            msg = f"参数优化失败,触发异常:\n{traceback.format_exc()}"
            self.write_log(msg)
            return
        finally:
            # Clear thread object handler -- also on failure.
            self.thread = None

        # Report the method that was actually used.
        if use_ga:
            self.write_log("遗传算法参数优化完成")
        else:
            self.write_log("多进程参数优化完成")

        # Put optimization done event
        event = Event(EVENT_BACKTESTER_OPTIMIZATION_FINISHED)
        self.event_engine.put(event)

    def start_optimization(
        self,
        class_name: str,
        vt_symbol: str,
        interval: str,
        start: datetime,
        end: datetime,
        rate: float,
        slippage: float,
        size: int,
        pricetick: float,
        capital: int,
        inverse: bool,
        optimization_setting: OptimizationSetting,
        use_ga: bool
    ):
        """Start run_optimization in a background thread; return False if a task is already running."""
        if self.thread:
            self.write_log("已有任务在运行中,请等待完成")
            return False

        self.write_log("-" * 40)
        self.thread = Thread(
            target=self.run_optimization,
            args=(
                class_name,
                vt_symbol,
                interval,
                start,
                end,
                rate,
                slippage,
                size,
                pricetick,
                capital,
                inverse,
                optimization_setting,
                use_ga
            )
        )

        self.thread.start()

        return True

    def run_downloading(
        self,
        vt_symbol: str,
        interval: str,
        start: datetime,
        end: datetime
    ):
        """
        Query bar data from the gateway (if it provides history) or RQData
        and save it to the database.  Always clears ``self.thread``.
        """
        self.write_log(f"{vt_symbol}-{interval}开始下载历史数据")

        try:
            symbol, exchange = extract_vt_symbol(vt_symbol)
        except ValueError:
            self.write_log(f"{vt_symbol}解析失败,请检查交易所后缀")
            self.thread = None
            return

        try:
            # Built inside the try: Interval(interval) raises on an unknown
            # interval string, which previously left self.thread set.
            req = HistoryRequest(
                symbol=symbol,
                exchange=exchange,
                interval=Interval(interval),
                start=start,
                end=end
            )

            contract = self.main_engine.get_contract(vt_symbol)

            # If history data provided in gateway, then query
            if contract and contract.history_data:
                data = self.main_engine.query_history(
                    req, contract.gateway_name
                )
            # Otherwise use RQData to query data
            else:
                data = rqdata_client.query_history(req)

            if data:
                database_manager.save_bar_data(data)
                self.write_log(f"{vt_symbol}-{interval}历史数据下载完成")
            else:
                self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据")
        except Exception:
            msg = f"数据下载失败,触发异常:\n{traceback.format_exc()}"
            self.write_log(msg)
        finally:
            # Clear thread object handler.
            self.thread = None

    def start_downloading(
        self,
        vt_symbol: str,
        interval: str,
        start: datetime,
        end: datetime
    ):
        """Start run_downloading in a background thread; return False if a task is already running."""
        if self.thread:
            self.write_log("已有任务在运行中,请等待完成")
            return False

        self.write_log("-" * 40)
        self.thread = Thread(
            target=self.run_downloading,
            args=(
                vt_symbol,
                interval,
                start,
                end
            )
        )

        self.thread.start()

        return True

    def get_all_trades(self):
        """Return all trades of the last backtest."""
        return self.backtesting_engine.get_all_trades()

    def get_all_orders(self):
        """Return all orders of the last backtest."""
        return self.backtesting_engine.get_all_orders()

    def get_all_daily_results(self):
        """Return all daily results of the last backtest."""
        return self.backtesting_engine.get_all_daily_results()

    def get_history_data(self):
        """Return the history data loaded by the backtesting engine."""
        return self.backtesting_engine.history_data

    def get_strategy_class_file(self, class_name: str):
        """Return the path of the source file defining the named strategy class."""
        strategy_class = self.classes[class_name]
        file_path = getfile(strategy_class)
        return file_path
修改用户策略KxMonitor.py
from typing import Any,List,Dict,Tuple
import copy
from vnpy.app.cta_strategy import (
CtaTemplate,
BarGenerator,
ArrayManager,
StopOrder,
Direction
)
from vnpy.trader.engine import MainEngine,EventEngine
from vnpy.app.cta_strategy.engine import CtaEngine
from vnpy.event.engine import Event
from vnpy.trader.object import (
LogData,
TickData,
BarData,
TradeData,
OrderData,
)
from vnpy.app.cta_strategy import StopOrder
from vnpy.app.cta_strategy.base import EngineType
from vnpy.trader.constant import Interval
from vnpy.app.cta_strategy.base import (
APP_NAME,
EVENT_CTA_LOG,
EVENT_CTA_TICK,
EVENT_CTA_HISTORY_BAR,
EVENT_CTA_BAR,
EVENT_CTA_ORDER,
EVENT_CTA_TRADE,
EVENT_CTA_STOPORDER,
EVENT_CTA_STRATEGY,
)
from vnpy.usertools.kx_chart import ( # hxxjava add
NewChartWidget,
CandleItem,
VolumeItem,
LineItem,
SmaItem,
RsiItem,
MacdItem,
)
from vnpy.usertools.kx_chart import NewChartWidget # hxxjava add
class _kx_strategy(CtaTemplate):
    """
    Demo CTA strategy that feeds a K-line chart window.

    Ticks are aggregated into 1-minute bars and those into
    ``kx_interval``-minute bars.  Every N-minute bar is kept in ``all_bars``.
    While running live with ``show_chart`` True, the strategy publishes the
    following on the event engine, each tagged with the strategy name:
    history bars, finished and temporary (in-progress) bars, ticks, orders,
    trades and stop orders.
    """

    author = "hxxjava"

    kx_interval = 1  # K-line period in minutes
    show_chart = False  # whether to show the K-line chart

    parameters = [
        "kx_interval",
        "show_chart",
    ]

    kx_count: int = 0
    cta_manager = None

    variables = ["kx_count"]

    def __init__(
        self,
        cta_engine: Any,
        strategy_name: str,
        vt_symbol: str,
        setting: dict,
    ):
        super().__init__(cta_engine, strategy_name, vt_symbol, setting)

        self.bg = BarGenerator(self.on_bar, self.kx_interval, self.on_Nmin_bar)
        self.am = ArrayManager()

        cta_engine: CtaEngine = self.cta_engine
        self.engine_type = cta_engine.engine_type

        # Publish on the CTA engine's own event engine -- the same one the
        # strategy manager hands to the chart widget.  getattr: an engine
        # without an event_engine (NOTE(review): believed to be the case for
        # the backtesting engine, which also has no main_engine) must not make
        # the strategy crash on construction; send_event then just does nothing.
        self.event_engine = getattr(cta_engine, "event_engine", None)
        self.even_engine = self.event_engine  # backward-compatible alias (old misspelled name)

        # Must be created here, because they are instance variables: class
        # attributes would be shared by every strategy instance.
        self.all_bars: List[BarData] = []
        self.current_tick: "TickData | None" = None
        self.current_bar: "BarData | None" = None
        self.last_tick: "TickData | None" = None

    def on_init(self):
        """
        Callback when strategy is inited.

        Loads 20 days of history and, if any N-minute bars were built, sends
        them to the chart in one EVENT_CTA_HISTORY_BAR event.
        """
        self.load_bar(20)
        if self.all_bars:
            self.send_event(EVENT_CTA_HISTORY_BAR, self.all_bars)

    def on_start(self):
        """Callback when strategy is started."""
        self.write_log("已开始")

    def on_stop(self):
        """Callback when strategy is stopped."""
        self.write_log("_kx_strategy 已停止")

    def on_tick(self, tick: TickData):
        """
        Callback of new tick data update.

        Publishes the temporary bar including this tick and the tick itself,
        and feeds the tick into the bar generator.
        """
        self.current_tick = tick  # remember the latest tick

        if self.inited:
            # First build the current temporary bar (including this tick)
            # and publish it.
            temp_bar = self.get_cur_bar(tick)
            if temp_bar:
                self.send_event(EVENT_CTA_BAR, temp_bar)

        # Then update the generator, which builds 1-minute and N-minute bars.
        self.bg.update_tick(tick)
        self.send_event(EVENT_CTA_TICK, tick)
        self.last_tick = tick

    def on_bar(self, bar: BarData):
        """
        Callback of new 1-minute bar data update; forwards it to the N-minute aggregation.
        """
        if self.inited:
            self.write_log("I got a 1min BarData")
        self.bg.update_bar(bar)

    def on_Nmin_bar(self, bar: BarData):
        """
        Callback of new N-minute bar data update.

        Records the bar, publishes it when inited, and restarts the temporary
        bar from the latest tick.
        """
        self.all_bars.append(bar)
        self.kx_count = len(self.all_bars)

        if self.inited:
            self.write_log(f"I got a {self.kx_interval}min BarData")
            self.send_event(EVENT_CTA_BAR, bar)

            if self.current_tick:
                # When a new N-minute bar is finished, start a new temporary
                # bar immediately.
                self.current_bar = None
                self.get_cur_bar(self.current_tick)

        self.put_event()

    def on_trade(self, trade: TradeData):
        """
        Callback of new trade data update.
        """
        self.send_event(EVENT_CTA_TRADE, trade)

    def on_order(self, order: OrderData):
        """
        Callback of new order data update.
        """
        self.send_event(EVENT_CTA_ORDER, order)

    def on_stop_order(self, stop_order: StopOrder):
        """
        Callback of stop order update.
        """
        self.send_event(EVENT_CTA_STOPORDER, stop_order)

    def get_cur_bar(self, tick: TickData) -> "BarData | None":
        """
        Build or update the temporary (unfinished) bar from ``tick``.

        A new bar is started when ``self.current_bar`` is None; on_Nmin_bar
        resets it whenever an N-minute bar completes.  Otherwise the existing
        bar's high/low are extended.  In both cases close, volume and open
        interest are then updated.  The volume increment is computed against
        ``self.last_tick``, which on_tick maintains.

        Returns a deep copy of the temporary bar.  Returns None if the strategy
        is not inited, there is no previous tick yet, or ``tick`` is older than
        the previous tick.
        """
        if not self.inited or not self.last_tick:
            return None

        if tick.datetime < self.last_tick.datetime:
            return None

        bar = self.current_bar
        if not bar:
            # Generate timestamp for bar data
            if self.bg.interval == Interval.MINUTE:
                dt = tick.datetime.replace(second=0, microsecond=0)
            else:
                dt = tick.datetime.replace(minute=0, second=0, microsecond=0)

            bar = BarData(
                symbol=tick.symbol,
                exchange=tick.exchange,
                datetime=dt,
                gateway_name=tick.gateway_name,
                open_price=tick.last_price,
                high_price=tick.last_price,
                low_price=tick.last_price,
            )
            self.current_bar = bar
        # Otherwise, update high/low price into the temporary bar
        else:
            bar.high_price = max(bar.high_price, tick.last_price)
            bar.low_price = min(bar.low_price, tick.last_price)

        # Update last price/volume into the temporary bar
        bar.close_price = tick.last_price
        # tick.volume is presumably cumulative (hence the difference) and can
        # drop, e.g. across a session reset; never let that subtract volume,
        # matching vnpy's BarGenerator.
        bar.volume += max(tick.volume - self.last_tick.volume, 0)
        bar.open_interest = tick.open_interest

        return copy.deepcopy(bar)

    def send_event(self, event_type: str, data: Any):
        """
        Publish ``Event(event_type, (strategy_name, data))``.

        The event is only put when running live with ``show_chart`` enabled
        and an event engine is available.
        """
        if self.engine_type != EngineType.LIVE or not self.show_chart:
            return
        if self.event_engine is None:
            return
        self.event_engine.put(Event(event_type, (self.strategy_name, data)))

    def init_kx_chart(self, kx_chart: NewChartWidget = None):  # hxxjava add ----- for external callers
        """
        Lay out this strategy's chart: candle main plot plus volume, RSI and
        MACD sub-plots, a last-price line and a cursor.  Does nothing when
        ``kx_chart`` is None.
        """
        if kx_chart:
            kx_chart.add_plot("candle", hide_x_axis=True)
            kx_chart.add_plot("volume", maximum_height=150)
            kx_chart.add_plot("rsi", maximum_height=150)
            kx_chart.add_plot("macd", maximum_height=150)
            kx_chart.add_item(CandleItem, "candle", "candle")
            kx_chart.add_item(VolumeItem, "volume", "volume")

            kx_chart.add_item(LineItem, "line", "candle")
            kx_chart.add_item(SmaItem, "sma", "candle")
            kx_chart.add_item(RsiItem, "rsi", "rsi")
            kx_chart.add_item(MacdItem, "macd", "macd")
            kx_chart.add_last_price_line()
            kx_chart.add_cursor()
添加可以显示的策略
启动VnTrader,进入策略管理界面,完成如下步骤:
1)从策略下拉框中选择_kx_strategy策略
2)点击“添加策略”按钮,进入第3步的参数设置界面
3)输入策略名称、vt_symbol、kx_interval和show_chart的值。注意kx_interval是你想要的K线周期,单位是分钟;show_chart参数为True表示需要显示K线图表,其他值则不显示。
4)初始化策略。如果show_chart参数为True,初始化完成后会显示K线图表窗口,并显示最近20天的历史K线
5)按启动按钮启动策略。如果是交易时段,K线图表就会显示最新收到的K线,同时还会实时显示尚未完成的临时K线