1. 合约信息是一种基础信息,应该在本地持久化保存
策略中使用交易合约的合约信息是非常常见的事情,例如使用其中的合约乘数、最小交易单位、最小价格变动等。
如果您的策略中使用到了交易合约的合约信息,而这些信息只有在连接Gateway之后才能获得,那么当您在休市时无法连接Gateway,回测就将无法进行下去。
怎么办?坐等交易所开市吗?本帖就来讨论这个问题。
2. 解决方法
2.1 合约信息包含哪些内容
vnpy\trader\object.py中ContractData是这样定义的(其中open_date和expire_date两个字段是本人新增的):
@dataclass
class ContractData(BaseData):
    """
    Contract data contains basic information about each contract traded.

    hxxjava: open_date and expire_date are added to the standard definition.
    NOTE(review): because they are inserted right after pricetick, every field
    from min_volume onward moved two positions to the right; construct this
    class with keyword arguments only.
    """

    symbol: str
    exchange: Exchange
    name: str
    product: Product
    size: float  # contract multiplier
    pricetick: float  # minimum price change
    open_date:datetime = None # listing date (open date), None when not provided. hxxjava add
    expire_date:datetime = None # expiry date, None when not provided. hxxjava add
    min_volume: float = 1 # minimum trading volume of the contract
    stop_supported: bool = False # whether server supports stop order
    net_position: bool = False # whether gateway uses net position volume
    history_data: bool = False # whether gateway provides bar history data

    # Option-specific fields; they keep these defaults for non-option contracts.
    option_strike: float = 0
    option_underlying: str = "" # vt_symbol of underlying contract
    option_type: OptionType = None
    option_listed: datetime = None
    option_expiry: datetime = None
    option_portfolio: str = ""
    option_index: str = "" # for identifying options with same strike price

    def __post_init__(self):
        """Derive vt_symbol as "<symbol>.<exchange value>"."""
        self.vt_symbol = f"{self.symbol}.{self.exchange.value}"
2.2 修改一下CTP的网关:
修改vnpy_ctp\gateway\ctp_gateway.py中的CtpTdApi接口onRspQryInstrument():
def onRspQryInstrument(self, data: dict, error: dict, reqid: int, last: bool) -> None:
    """
    Contract query response (one CTP instrument record per call).

    Converts the record into ContractData (including the hxxjava-added
    open_date / expire_date) and pushes it to the gateway. When `last` is
    True, marks contracts as initialised and replays the order/trade data
    held in self.order_data / self.trade_data.
    """
    def parse_date(text: str):
        # CTP dates are "YYYYMMDD". An empty or malformed value yields None
        # instead of raising inside the API callback, which would abort
        # processing of this record and of the `last` handling below.
        try:
            return datetime.strptime(text, "%Y%m%d")
        except (TypeError, ValueError):
            return None

    product: Product = PRODUCT_CTP2VT.get(data["ProductClass"], None)
    if product:
        contract: ContractData = ContractData(
            symbol=data["InstrumentID"],
            exchange=EXCHANGE_CTP2VT[data["ExchangeID"]],
            name=data["InstrumentName"],
            product=product,
            size=data["VolumeMultiple"],
            pricetick=data["PriceTick"],
            open_date=parse_date(data["OpenDate"]),      # hxxjava add
            expire_date=parse_date(data["ExpireDate"]),  # hxxjava add
            gateway_name=self.gateway_name
        )

        # Option-related fields
        if contract.product == Product.OPTION:
            # Strip the C/P suffix carried by CZCE option product names
            if contract.exchange == Exchange.CZCE:
                contract.option_portfolio = data["ProductID"][:-1]
            else:
                contract.option_portfolio = data["ProductID"]

            contract.option_underlying = data["UnderlyingInstrID"]
            contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None)
            contract.option_strike = data["StrikePrice"]
            contract.option_index = str(data["StrikePrice"])
            # Same source fields as open_date / expire_date: reuse parsed values
            contract.option_listed = contract.open_date
            contract.option_expiry = contract.expire_date

        self.gateway.on_contract(contract)

        symbol_contract_map[contract.symbol] = contract

    if last:
        self.contract_inited = True
        self.gateway.write_log("合约信息查询成功")

        # Distinct loop names so the `data` parameter is not shadowed
        for order_data in self.order_data:
            self.onRtnOrder(order_data)
        self.order_data.clear()

        for trade_data in self.trade_data:
            self.onRtnTrade(trade_data)
        self.trade_data.clear()
2.3 选择用数据库保存合约信息
在vnpy\trader\database.py中为BaseDatabase增加两个与合约信息相关的接口:
@abstractmethod
def save_constract_data(self, constracts: List[ContractData]) -> bool:
    """
    Persist the given contracts into the database. hxxjava add

    The method name keeps its historical spelling ("constract") because
    callers invoke it under that name. Implementations report the outcome
    as a bool.
    """
    ...
@abstractmethod
def load_contract_data(self, vt_symbol: str = "") -> List[ContractData]:
    """
    Read contracts back from the database. hxxjava add

    vt_symbol: "symbol.exchange" to restrict the result to one contract;
    the default empty string asks for every stored contract.
    """
    ...
2.4 为MysqlDatabase实现save_constract_data()和load_contract_data():
本人使用的是MySQL Server,所以就在vnpy_mysql\mysql_database.py中实现这两个接口:
引用部分添加:
from vnpy.trader.constant import Exchange, Interval, Product # hxxjava add Product
定义合约信息数据表模型
class DbContractData(Model):  # hxxjava add
    """
    ORM model (table mapping) for contract data.

    Exchange and product are stored as their enum .value strings. The unique
    index on (open_date, exchange, symbol) presumably lets a symbol that is
    re-listed with a different open date be kept as a separate row.

    Date columns and option_type are nullable because ContractData leaves
    them as None (option fields for non-option contracts, missing dates);
    a NOT NULL column would reject those inserts.
    NOTE(review): create_tables() does not alter an existing table, so a
    table created from the previous definition needs ALTER TABLE to pick
    up null=True. MySQL unique indexes also accept several rows whose
    open_date is NULL, so REPLACE cannot de-duplicate those rows.
    """

    id = AutoField()

    symbol: str = CharField()
    exchange: str = CharField()
    name: str = CharField()
    product: str = CharField()
    size: float = FloatField()
    pricetick: float = FloatField()
    open_date: datetime = DateTimeField(null=True, default=None)
    expire_date: datetime = DateTimeField(null=True, default=None)
    min_volume: float = FloatField(default=1)
    stop_supported: bool = BooleanField(default=False)
    net_position: bool = BooleanField(default=False)
    history_data: bool = BooleanField(default=False)

    option_strike: float = FloatField(default=0)
    option_underlying: str = CharField(default="")
    option_type: str = CharField(null=True, default="")
    option_listed: datetime = DateTimeField(null=True, default=None)
    option_expiry: datetime = DateTimeField(null=True, default=None)
    option_portfolio: str = CharField(default="")
    option_index: str = CharField(default="")

    gateway_name: str = CharField()

    class Meta:
        database = db
        indexes = ((("open_date", "exchange", "symbol"), True),)
修改MysqlDatabase的__init__(),并实现上面两个接口函数:
class MysqlDatabase(BaseDatabase):
    """MySQL database interface (hxxjava: extended with contract persistence)."""

    def __init__(self) -> None:
        """Connect to MySQL and create the tables if they do not exist yet."""
        self.db = db
        self.db.connect()
        # hxxjava add DbContractData
        self.db.create_tables([DbContractData, DbBarData, DbTickData, DbBarOverview])

    def save_constract_data(self, contracts: List[ContractData]) -> bool:
        """
        Save (upsert) contract data into database. hxxjava add

        Each contract becomes an explicit DbContractData row and is written
        with on_conflict_replace(), so a row colliding on the table's unique
        index (open_date, exchange, symbol) is replaced. Always returns True.
        """
        rows: List[dict] = [self._contract_to_row(contract) for contract in contracts]

        with self.db.atomic():
            for batch in chunked(rows, 50):
                DbContractData.insert_many(batch).on_conflict_replace().execute()

        return True

    def load_contract_data(self, vt_symbol: str = "") -> List[ContractData]:
        """
        Load contract data from database. hxxjava add

        vt_symbol: "symbol.exchange" to load a single contract, or "" (the
        default) to load all contracts. Results are ordered by open_date,
        exchange, symbol. Stored naive datetimes are returned as DB_TZ-aware
        datetimes; NULL dates come back as None.
        """
        query: ModelSelect = DbContractData.select()
        if vt_symbol:
            # The exchange suffix never contains ".", so split on the last one
            symbol, exchange = vt_symbol.rsplit(".", 1)
            query = query.where(
                (DbContractData.symbol == symbol)
                & (DbContractData.exchange == exchange)
            )
        query = query.order_by(
            DbContractData.open_date,
            DbContractData.exchange,
            DbContractData.symbol,
        )

        contracts: List[ContractData] = []
        for db_c in query:
            contract = ContractData(
                symbol=db_c.symbol,
                exchange=Exchange(db_c.exchange),
                name=db_c.name,
                product=Product(db_c.product),
                size=db_c.size,
                pricetick=db_c.pricetick,
                open_date=self._load_datetime(db_c.open_date),
                expire_date=self._load_datetime(db_c.expire_date),
                min_volume=db_c.min_volume,
                stop_supported=db_c.stop_supported,
                net_position=db_c.net_position,
                history_data=db_c.history_data,
                option_strike=db_c.option_strike,
                option_underlying=db_c.option_underlying,
                option_type=self._load_option_type(db_c.option_type),
                option_listed=self._load_datetime(db_c.option_listed),
                option_expiry=self._load_datetime(db_c.option_expiry),
                option_portfolio=db_c.option_portfolio,
                option_index=db_c.option_index,
                gateway_name=db_c.gateway_name,
            )
            contracts.append(contract)

        return contracts

    @staticmethod
    def _contract_to_row(contract: ContractData) -> dict:
        """Map a ContractData onto DbContractData column values (enums -> .value)."""
        # Explicit columns: dumping __dict__ would also pass non-column
        # attributes such as vt_symbol.
        option_type = contract.option_type
        return {
            "symbol": contract.symbol,
            "exchange": contract.exchange.value,
            "name": contract.name,
            "product": contract.product.value,
            "size": contract.size,
            "pricetick": contract.pricetick,
            "open_date": contract.open_date,
            "expire_date": contract.expire_date,
            "min_volume": contract.min_volume,
            "stop_supported": contract.stop_supported,
            "net_position": contract.net_position,
            "history_data": contract.history_data,
            "option_strike": contract.option_strike,
            "option_underlying": contract.option_underlying,
            # Store the enum value (not str(enum)); "" for non-option contracts
            "option_type": option_type.value if option_type is not None else "",
            "option_listed": contract.option_listed,
            "option_expiry": contract.option_expiry,
            "option_portfolio": contract.option_portfolio,
            "option_index": contract.option_index,
            "gateway_name": contract.gateway_name,
        }

    @staticmethod
    def _load_datetime(value):
        """Convert a stored naive datetime to a DB_TZ-aware one; anything else -> None."""
        # isinstance also filters NULL and any non-datetime value the driver returns
        if not isinstance(value, datetime):
            return None
        return datetime.fromtimestamp(value.timestamp(), DB_TZ)

    @staticmethod
    def _load_option_type(text: str):
        """Map a stored option_type string back to OptionType (None if empty/unknown)."""
        from vnpy.trader.constant import OptionType  # not in this module's import line

        if not text:
            return None
        try:
            return OptionType(text)
        except ValueError:
            # NOTE: rows written by the earlier save code may hold str(enum),
            # e.g. "OptionType.CALL"; fall back to a lookup by member name.
            return OptionType.__members__.get(text.rsplit(".", 1)[-1])
2.5 修改OmsEngine,使之能够持久化保存合约信息
对vnpy\trader\engine.py做如下修改:
2.5.1 添加引用
from threading import Thread # hxxjava add
from copy import deepcopy # hxxjava add
from .database import get_database # hxxjava add
2.5.2 为OmsEngine添加下面函数:
修改OmsEngine的__init__():
class OmsEngine(BaseEngine):
    """
    Provides order management system function.

    hxxjava: contracts are additionally loaded from the database at start-up,
    and newly received contracts are saved back to it.
    """

    # hxxjava add. Not referenced by the database-based code shown here;
    # kept so any existing reference to it keeps working.
    contract_file = "contracts.json"

    def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
        """Create the engine, restore saved contracts and register event handlers."""
        super(OmsEngine, self).__init__(main_engine, event_engine, "oms")

        self.ticks: Dict[str, TickData] = {}
        self.orders: Dict[str, OrderData] = {}
        self.trades: Dict[str, TradeData] = {}
        self.positions: Dict[str, PositionData] = {}
        self.accounts: Dict[str, AccountData] = {}
        self.contracts: Dict[str, ContractData] = {}
        self.quotes: Dict[str, QuoteData] = {}

        self.active_orders: Dict[str, OrderData] = {}
        self.active_quotes: Dict[str, QuoteData] = {}

        # hxxjava add: contracts received from gateways that still have to be
        # written to the database (drained by save_contracts_changed()).
        # Declared here with the rest of the instance state rather than only
        # as a side effect of load_contracts().
        self.contracts_changed: List[ContractData] = []

        self.db = get_database()
        self.load_contracts()  # hxxjava add: read all saved contracts at start-up

        self.add_function()
        self.register_event()
在收到账户信息时,才保存新增的合约信息:
def process_account_event(self, event: Event) -> None:
    """
    Cache the latest account snapshot, then flush newly received contracts.
    """
    account_data: AccountData = event.data
    account_key: str = account_data.vt_accountid
    self.accounts[account_key] = account_data

    # hxxjava add: account updates act as the trigger for persisting the
    # contracts queued by process_contract_event()
    self.save_contracts_changed()
def process_contract_event(self, event: Event) -> None:
    """
    Cache contract data pushed by a gateway. hxxjava change

    The live gateway data always replaces any copy loaded from the database,
    because fields such as gateway_name or the gateway capability flags may
    differ from what was stored. Only contracts not seen before are queued
    in self.contracts_changed for saving by save_contracts_changed().
    """
    contract: ContractData = event.data

    is_new: bool = contract.vt_symbol not in self.contracts
    self.contracts[contract.vt_symbol] = contract

    if is_new:
        self.contracts_changed.append(contract)
启动一个线程对新增的合约进行增量保存,否则数据库写入会阻塞系统的事件循环:
def save_contracts_changed(self):
    """
    Hand every queued new contract to a background thread for saving.

    Deep copies are taken before the queue is swapped for an empty list, so
    the worker thread never shares objects with the caller. Nothing is
    started when the queue is empty.
    """
    snapshot = list(map(deepcopy, self.contracts_changed))
    self.contracts_changed = []

    if not snapshot:
        return

    # Database I/O runs off the calling thread so it does not block it
    worker = Thread(target=self.save_contracts, kwargs={"contracts": snapshot})
    worker.start()
def save_contracts(self, contracts: List):  # hxxjava add
    """
    Write the given contracts to the database and print how many were saved.

    This is the target of the worker thread started by save_contracts_changed().
    """
    self.db.save_constract_data(contracts)
    saved_count = len(contracts)
    print(f"一共保存了{saved_count}个合约 !")
def load_contracts(self):  # hxxjava add
    """
    Load every contract stored in the database into self.contracts.

    Called from __init__, so saved contract data is available before any
    gateway connects. Also resets self.contracts_changed (the queue of
    contracts waiting to be saved) to an empty list, then prints how many
    contracts were read.
    """
    contracts = self.db.load_contract_data()
    self.contracts.update((contract.vt_symbol, contract) for contract in contracts)

    self.contracts_changed: List[ContractData] = []

    print(f"一共读取了{len(contracts)}个合约 !")
2.6 为CtaEngine增加获取合约信息函数get_contract()
修改vnpy_ctastrategy\engine.py中的CtaEngine,添加get_contract():
def get_contract(self, strategy: CtaTemplate) -> Optional[ContractData]:  # hxxjava add
    """
    Look up the contract data for the strategy's vt_symbol.

    Delegates to the main engine and returns its result unchanged, which may
    be None when no contract data is known for that symbol.
    """
    strategy_symbol: str = strategy.vt_symbol
    return self.main_engine.get_contract(strategy_symbol)
2.7 为CTA策略模板添加合约信息获取函数get_contract()
修改vnpy_ctastrategy\template.py中的CtaTemplate,添加get_contract():
def get_contract(self):
    """
    Return the contract data (ContractData) of this strategy's trading contract. # hxxjava add

    Delegates to the CTA engine, which looks up this strategy's vt_symbol via
    the main engine; the result may be None when no contract data is available.
    """
    return self.cta_engine.get_contract(self)
3. 策略如何使用get_contract()?
经过上面这些步骤,您的CTA策略就具备了直接读取交易合约信息的能力。由于OmsEngine在启动时就会从数据库中读取所有已保存的合约,即使休市时无法连接Gateway,也能获得之前保存过的合约信息。
如何使用呢?非常简单,这里给出一段用户策略的例子代码:
""" 得到交易合约的合约信息 """
contract: ContractData = self.get_contract()
# get_contract() can return None (no contract data received from a gateway
# nor saved in the database yet), so check before reading fields.
if contract is not None:
    size_n = contract.size  # contract multiplier
    price_tick = contract.pricetick  # minimum price change; the field is "pricetick", not "price_tick"