1. 合约信息是一种基础信息,应该在本地持久化保存

策略中使用交易合约的合约信息是非常常见的事情,例如使用其中的合约乘数、最小交易单位、最小价格变动等。
如果您的策略使用了交易合约的合约信息,而休市期间又无法连接Gateway获取这些信息,那么回测将无法进行下去。
怎么办,坐等交易所开市吗?本帖就讨论这样一个问题。

2. 解决方法

2.1 合约信息包含哪些内容

vnpy\trader\object.py中ContractData的定义如下(其中标注了 hxxjava add 的 open_date 和 expire_date 两个字段为本文新增):

@dataclass
class ContractData(BaseData):
    """
    Contract data contains basic information about each contract traded.

    hxxjava change: adds the optional open_date / expire_date fields, which
    stay None unless the gateway fills them in. vt_symbol is not a dataclass
    field; it is derived in __post_init__ as "<symbol>.<exchange value>".
    """

    symbol: str
    exchange: Exchange
    name: str
    product: Product
    size: float                     # contract multiplier (the CTP gateway fills it from VolumeMultiple)
    pricetick: float                # minimum price increment (the CTP gateway fills it from PriceTick)

    open_date:datetime = None       # listing date, None if unknown (hxxjava add)
    expire_date:datetime = None     # expiry date, None if unknown (hxxjava add)

    min_volume: float = 1           # minimum trading volume of the contract
    stop_supported: bool = False    # whether server supports stop order
    net_position: bool = False      # whether gateway uses net position volume
    history_data: bool = False      # whether gateway provides bar history data

    option_strike: float = 0
    option_underlying: str = ""     # vt_symbol of underlying contract
    option_type: OptionType = None
    option_listed: datetime = None
    option_expiry: datetime = None
    option_portfolio: str = ""
    option_index: str = ""          # for identifying options with same strike price

    def __post_init__(self):
        """Derive vt_symbol from symbol and the exchange's value."""
        self.vt_symbol = f"{self.symbol}.{self.exchange.value}"

2.2 修改一下CTP的网关:

修改vnpy_ctp\gateway\ctp_gateway.py中的CtpTdApi接口onRspQryInstrument():

    def onRspQryInstrument(self, data: dict, error: dict, reqid: int, last: bool) -> None:
        """
        Handle one contract-query response from CTP.

        ``data`` describes a single instrument. Instruments whose ProductClass
        maps to a vn.py Product are converted to ContractData, pushed through
        the gateway and cached in ``symbol_contract_map``. When ``last`` is
        True, the entries held in ``self.order_data`` / ``self.trade_data`` are
        replayed through onRtnOrder / onRtnTrade and the lists are cleared.
        """
        product: Product = PRODUCT_CTP2VT.get(data["ProductClass"], None)
        if product:
            # Parsed once and reused below. An empty or malformed date becomes
            # None instead of raising inside this callback, which would also
            # skip the ``last`` handling further down.
            open_date = self._parse_ctp_date(data["OpenDate"])        # hxxjava add
            expire_date = self._parse_ctp_date(data["ExpireDate"])    # hxxjava add

            contract: ContractData = ContractData(
                symbol=data["InstrumentID"],
                exchange=EXCHANGE_CTP2VT[data["ExchangeID"]],
                name=data["InstrumentName"],
                product=product,
                size=data["VolumeMultiple"],
                pricetick=data["PriceTick"],
                open_date=open_date,        # hxxjava add
                expire_date=expire_date,    # hxxjava add
                gateway_name=self.gateway_name
            )

            # Option-specific fields
            if contract.product == Product.OPTION:
                # Strip the trailing C/P that CZCE option product IDs carry
                if contract.exchange == Exchange.CZCE:
                    contract.option_portfolio = data["ProductID"][:-1]
                else:
                    contract.option_portfolio = data["ProductID"]

                contract.option_underlying = data["UnderlyingInstrID"]
                contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None)
                contract.option_strike = data["StrikePrice"]
                contract.option_index = str(data["StrikePrice"])
                contract.option_listed = open_date
                contract.option_expiry = expire_date

            self.gateway.on_contract(contract)

            symbol_contract_map[contract.symbol] = contract

        if last:
            self.contract_inited = True
            self.gateway.write_log("合约信息查询成功")

            # Distinct loop names so the ``data`` parameter is not shadowed
            for order_data in self.order_data:
                self.onRtnOrder(order_data)
            self.order_data.clear()

            for trade_data in self.trade_data:
                self.onRtnTrade(trade_data)
            self.trade_data.clear()

    @staticmethod
    def _parse_ctp_date(value: str):
        """
        Parse a CTP ``YYYYMMDD`` date string into a naive datetime.

        Returns None for an empty, missing or malformed value instead of raising.
        """
        try:
            return datetime.strptime(value, "%Y%m%d")
        except (TypeError, ValueError):
            return None

2.3 选择用数据库保存合约信息

在vnpy\trader\database.py中为BaseDatabase增加两个与合约信息相关的接口:

    @abstractmethod
    def save_constract_data(self, constracts: List[ContractData]) -> bool:
        """
        Persist a batch of contracts and report success.    hxxjava add

        The "constract" spelling is part of the interface: concrete
        databases must override exactly this method name.
        """
        ...

    @abstractmethod
    def load_contract_data(self,vt_symbol:str="") -> List[ContractData]:
        """
        Read stored contracts back from the database.    hxxjava add

        An empty vt_symbol asks for every stored contract; otherwise only the
        contract identified by that vt_symbol is requested.
        """
        ...

2.4 为MySqlDatabase扩展save_constract_data()和load_contract_data():

本人使用的是MySQL Server,所以在vnpy_mysql\mysql_database.py中实现这两个接口:

引用部分添加(另外,如果该文件尚未从peewee导入BooleanField,也需要一并导入,因为下面的数据表模型用到了它):

from vnpy.trader.constant import Exchange, Interval, Product, OptionType    # hxxjava add Product, OptionType

定义合约信息数据表模型

class DbContractData(Model):     # hxxjava add
    """Contract information table mapping object (one row per contract)."""

    id = AutoField()

    symbol: str = CharField()
    exchange: str = CharField()         # Exchange value
    name: str = CharField()
    product: str = CharField()          # Product value
    size: float = FloatField()
    pricetick: float = FloatField()

    # Date columns must be nullable. Peewee creates NOT NULL columns unless
    # null=True, but contracts may carry None here: option_listed/option_expiry
    # are None for every non-option contract. Under strict SQL mode, inserting
    # NULL into a NOT NULL column is rejected.
    # NOTE: create_tables() does not alter a table that already exists; a table
    # created with the old NOT NULL definition has to be altered by hand.
    open_date: datetime = DateTimeField(null=True, default=None)
    expire_date: datetime = DateTimeField(null=True, default=None)

    min_volume: float = FloatField(default=1)
    stop_supported: bool = BooleanField(default=False)
    net_position: bool = BooleanField(default=False)
    history_data: bool = BooleanField(default=False)

    option_strike: float = FloatField(default=0)
    option_underlying: str = CharField(default="")
    option_type: str = CharField(default="")        # option type as text, "" when absent
    option_listed: datetime = DateTimeField(null=True, default=None)
    option_expiry: datetime = DateTimeField(null=True, default=None)

    option_portfolio: str = CharField(default="")
    option_index: str = CharField(default="")

    gateway_name: str = CharField()

    class Meta:
        database = db
        # Unique key used by INSERT ... REPLACE. open_date is presumably part of
        # it so a re-listed symbol code gets its own row. MySQL treats NULLs as
        # distinct in unique indexes, so rows with a NULL open_date never conflict.
        indexes = ((("open_date", "exchange", "symbol"), True),)

修改MysqlDatabase的__init__(),使其同时创建合约信息数据表,并实现上面两个接口函数:

class MysqlDatabase(BaseDatabase):
    """MySQL database interface."""

    def __init__(self) -> None:
        """Connect to MySQL and create the tables used here, including the contract table."""
        self.db = db
        self.db.connect()
        self.db.create_tables([DbContractData, DbBarData, DbTickData, DbBarOverview])   # hxxjava add DbContractData

    def save_constract_data(self, contracts: List[ContractData]) -> bool:
        """
        Save contract data into database, replacing rows on unique-key conflict.    hxxjava add

        The method name keeps the "constract" spelling declared by BaseDatabase.
        Only the columns of DbContractData are written. Enums are stored by
        their value so that load_contract_data() can rebuild them.

        Returns True after all rows have been written.
        """
        rows: List[dict] = [self._contract_to_row(contract) for contract in contracts]

        # Upsert in batches of 50 rows inside one transaction
        with self.db.atomic():
            for batch in chunked(rows, 50):
                DbContractData.insert_many(batch).on_conflict_replace().execute()

        return True

    def load_contract_data(self, vt_symbol: str = "") -> List[ContractData]:
        """
        Load contract data from database.    hxxjava add

        vt_symbol: "symbol.exchange" to load a single contract, or "" to load
        every stored contract.

        Returns the contracts ordered by open_date, exchange and symbol.
        """
        query: ModelSelect = DbContractData.select()
        if vt_symbol:
            # Split on the last dot so a dot inside the symbol cannot break parsing
            symbol, exchange = vt_symbol.rsplit(".", 1)
            query = query.where(
                (DbContractData.symbol == symbol)
                & (DbContractData.exchange == exchange)
            )
        query = query.order_by(
            DbContractData.open_date,
            DbContractData.exchange,
            DbContractData.symbol,
        )

        return [self._row_to_contract(db_c) for db_c in query]

    @staticmethod
    def _contract_to_row(contract: ContractData) -> dict:
        """Convert a ContractData into a column dict for DbContractData."""
        option_type = contract.option_type
        return {
            "symbol": contract.symbol,
            "exchange": contract.exchange.value,
            "name": contract.name,
            "product": contract.product.value,
            "size": contract.size,
            "pricetick": contract.pricetick,
            "open_date": contract.open_date,
            "expire_date": contract.expire_date,
            "min_volume": contract.min_volume,
            "stop_supported": contract.stop_supported,
            "net_position": contract.net_position,
            "history_data": contract.history_data,
            "option_strike": contract.option_strike,
            "option_underlying": contract.option_underlying,
            # Store the enum's value ("" when absent) so it can be rebuilt on load
            "option_type": option_type.value if option_type else "",
            "option_listed": contract.option_listed,
            "option_expiry": contract.option_expiry,
            "option_portfolio": contract.option_portfolio,
            "option_index": contract.option_index,
            "gateway_name": contract.gateway_name,
        }

    @staticmethod
    def _to_db_tz(value):
        """Return ``value`` as a DB_TZ-aware datetime, or None if the column holds no datetime."""
        if not isinstance(value, datetime):
            return None
        return datetime.fromtimestamp(value.timestamp(), DB_TZ)

    @staticmethod
    def _to_option_type(value: str):
        """Rebuild an OptionType from its stored text; empty text gives None."""
        if not value:
            return None
        try:
            return OptionType(value)
        except ValueError:
            # Rows written by the earlier version may hold str(enum),
            # e.g. "OptionType.CALL"; fall back to lookup by member name.
            return OptionType[value.rsplit(".", 1)[-1]]

    def _row_to_contract(self, db_c: DbContractData) -> ContractData:
        """Build a ContractData from one DbContractData row."""
        product = Product(db_c.product)
        is_option = product == Product.OPTION

        return ContractData(
            symbol=db_c.symbol,
            exchange=Exchange(db_c.exchange),
            name=db_c.name,
            product=product,
            size=db_c.size,
            pricetick=db_c.pricetick,

            open_date=self._to_db_tz(db_c.open_date),
            expire_date=self._to_db_tz(db_c.expire_date),

            min_volume=db_c.min_volume,
            stop_supported=db_c.stop_supported,
            net_position=db_c.net_position,
            history_data=db_c.history_data,

            option_strike=db_c.option_strike,
            option_underlying=db_c.option_underlying,
            option_type=self._to_option_type(db_c.option_type),

            # Option dates are only restored for options, as before
            option_listed=self._to_db_tz(db_c.option_listed) if is_option else None,
            option_expiry=self._to_db_tz(db_c.option_expiry) if is_option else None,

            option_portfolio=db_c.option_portfolio,
            option_index=db_c.option_index,
            gateway_name=db_c.gateway_name,
        )

2.5 修改OmsEngine,使之能够持久化保存合约信息

对vnpy\trader\engine.py做如下修改:

2.5.1 添加引用

from threading import Thread            # hxxjava add
from copy import deepcopy               # hxxjava add
from .database import get_database      # hxxjava add

2.5.2 为OmsEngine添加下面函数:

修改OmsEngine的__init__():

class OmsEngine(BaseEngine):
    """
    Provides order management system function.

    hxxjava change: contracts are also kept in the database, and all stored
    contracts are loaded at startup.
    """

    contract_file = "contracts.json"    # hxxjava add (not referenced by the database-backed code in this post)

    def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
        """Create the caches, load stored contracts and register event handlers."""
        super().__init__(main_engine, event_engine, "oms")

        self.ticks: Dict[str, TickData] = {}
        self.orders: Dict[str, OrderData] = {}
        self.trades: Dict[str, TradeData] = {}
        self.positions: Dict[str, PositionData] = {}
        self.accounts: Dict[str, AccountData] = {}
        self.contracts: Dict[str, ContractData] = {}
        self.quotes: Dict[str, QuoteData] = {}

        self.active_orders: Dict[str, OrderData] = {}
        self.active_quotes: Dict[str, QuoteData] = {}

        # Contracts first seen since the last database save (hxxjava add).
        # Declared here with the other instance state instead of only
        # inside load_contracts().
        self.contracts_changed: List[ContractData] = []

        self.db = get_database()

        self.load_contracts()   # hxxjava add: load all stored contracts at startup

        self.add_function()
        self.register_event()

只在收到账户信息事件时,才保存新增的合约信息:

    def process_account_event(self, event: Event) -> None:
        """
        Cache the pushed account and then flush newly seen contracts.

        hxxjava add: account events are used as the trigger for saving the
        contracts queued by process_contract_event().
        """
        data: AccountData = event.data
        accounts: Dict[str, AccountData] = self.accounts
        accounts[data.vt_accountid] = data

        self.save_contracts_changed()   # hxxjava add

    def process_contract_event(self, event: Event) -> None:
        """
        Store the latest contract pushed by a gateway.

        hxxjava change: the pushed contract always replaces the cached one, so a
        copy loaded from the database never shadows what the gateway reports
        (for example its gateway_name). Only contracts not seen before are
        queued in ``contracts_changed`` for the next incremental database save.
        Changed fields of contracts that are already cached are not re-saved.
        """
        contract: ContractData = event.data
        is_new: bool = contract.vt_symbol not in self.contracts

        self.contracts[contract.vt_symbol] = contract
        if is_new:
            self.contracts_changed.append(contract)

启动单独的线程对新增的合约进行增量保存,避免数据库写入阻塞系统的事件循环:

    def save_contracts_changed(self):
        """
        Hand all contracts queued since the last call to a background thread.

        The queue is reset first. The thread receives deep copies, so the saved
        objects are independent of the live ones.
        """
        pending = self.contracts_changed
        self.contracts_changed = []

        if not pending:
            return

        snapshot = [deepcopy(item) for item in pending]
        worker = Thread(target=self.save_contracts, kwargs={"contracts": snapshot})
        worker.start()

    def save_contracts(self, contracts: List):   # hxxjava add
        """Write the given contracts to the database and print how many were saved."""
        database = self.db
        database.save_constract_data(contracts)
        total = len(contracts)
        print(f"一共保存了{total}个合约 !")

    def load_contracts(self):   # hxxjava add
        """
        Load every contract stored in the database into ``self.contracts``.

        Also resets ``self.contracts_changed``: the contracts loaded here are
        already stored, so only contracts first seen afterwards need saving.
        """
        contracts = self.db.load_contract_data()
        self.contracts.update({c.vt_symbol: c for c in contracts})

        self.contracts_changed: List[ContractData] = []

        print(f"一共读取了{len(contracts)}个合约 !")
        # print(f"self.contracts={self.contracts}个合约 !")

2.6 为CtaEngine增加获取合约信息函数get_contract()

修改vnpy_ctastrategy\engine.py中的CtaEngine,添加get_contract():

    def get_contract(self, strategy: CtaTemplate) -> Optional[ContractData]:    # hxxjava add
        """
        Return the contract data of the symbol traded by ``strategy``.

        The lookup is delegated to the main engine; as the return type states,
        the result may be None.
        """
        vt_symbol: str = strategy.vt_symbol
        contract: Optional[ContractData] = self.main_engine.get_contract(vt_symbol)
        return contract

2.7 为CTA策略模板添加合约信息获取函数get_contract()

修改vnpy_ctastrategy\template.py中的CtaTemplate,添加get_contract():

    def get_contract(self):
        """
        Return the contract data of this strategy's vt_symbol.   # hxxjava add

        Delegates to the CTA engine; the result may be None when no contract
        information is available for the symbol.
        """
        return self.cta_engine.get_contract(self)

3 策略如何使用get_contract()?

经过上面这些步骤,您的CTA策略就具备了直接读取交易合约的合约信息的能力。
如何使用呢?非常简单,这里给出一段用户策略的例子代码:

        """ 得到交易合约的合约信息 """              
        contract:ContractData = self.get_contract()
        size_n = contract.size   # 合约乘数
        price_tick = contract.price_tick # 最小价格变动

本人这么修改了之后,觉得非常方便,呼吁官方能够把这样的功能合并到系统中,惠及广大vnpy会员。