vn.py量化社区
By Traders, For Traders.

置顶主题

获得属于自己的保证金率和手续费(率)

1. 合约信息中包含保证金率

1.1 合约信息查询命令:

ReqQryInstrument : 请求查询合约,填空可以查询到所有合约。
响应:OnRspQryInstrument
◇ 1.函数原型
virtual int ReqQryInstrument(CThostFtdcQryInstrumentField *pQryInstrument, int nRequestID) = 0;
◇ 2.参数
pQryInstrument:查询合约
struct CThostFtdcQryInstrumentField
{
    TThostFtdcInstrumentIDType InstrumentID; ///合约代码
    TThostFtdcExchangeIDType ExchangeID; ///交易所代码
    TThostFtdcExchangeInstIDType ExchangeInstID; ///合约在交易所的代码
    TThostFtdcInstrumentIDType ProductID;///产品代码
};
nRequestID:请求ID,对应响应里的nRequestID,无递增规则,由用户自行维护。
◇ 3.返回
0,代表成功。
-1,表示网络连接失败;
-2,表示未处理请求超过许可数;
-3,表示每秒发送请求数超过许可数。

1.2 合约信息查询结果:

请求查询合约响应,当执行ReqQryInstrument后,该方法被调用。
◇ 1.函数原型
virtual void OnRspQryInstrument(CThostFtdcInstrumentField *pInstrument, CThostFtdcRspInfoField *pRspInfo, int nRequestID, bool bIsLast) {};
◇ 2.参数pInstrument:
合约
struct CThostFtdcInstrumentField
{
    TThostFtdcInstrumentIDType InstrumentID;///合约代码
    TThostFtdcExchangeIDType ExchangeID; ///交易所代码
    TThostFtdcInstrumentNameType InstrumentName; ///合约名称
    TThostFtdcExchangeInstIDType ExchangeInstID;///合约在交易所的代码
    TThostFtdcInstrumentIDType ProductID; ///产品代码
    TThostFtdcProductClassType ProductClass; ///产品类型
    TThostFtdcYearType DeliveryYear; ///交割年份
    TThostFtdcMonthType DeliveryMonth;///交割月
    TThostFtdcVolumeType MaxMarketOrderVolume; ///市价单最大下单量
    TThostFtdcVolumeType MinMarketOrderVolume;///市价单最小下单量
    TThostFtdcVolumeType MaxLimitOrderVolume; ///限价单最大下单量
    TThostFtdcVolumeType MinLimitOrderVolume; ///限价单最小下单量
    TThostFtdcVolumeMultipleType VolumeMultiple; ///合约数量乘数
    TThostFtdcPriceType PriceTick; ///最小变动价位
    TThostFtdcDateType CreateDate; ///创建日
    TThostFtdcDateType OpenDate; ///上市日
    TThostFtdcDateType ExpireDate;///到期日
    TThostFtdcDateType StartDelivDate; ///开始交割日
    TThostFtdcDateType EndDelivDate; ///结束交割日
    TThostFtdcInstLifePhaseType InstLifePhase; ///合约生命周期状态
    TThostFtdcBoolType IsTrading;///当前是否交易
    TThostFtdcPositionTypeType PositionType; ///持仓类型
    TThostFtdcPositionDateTypeType PositionDateType;///持仓日期类型
    TThostFtdcRatioType LongMarginRatio;///多头保证金率
    TThostFtdcRatioType ShortMarginRatio; ///空头保证金率
    TThostFtdcMaxMarginSideAlgorithmType MaxMarginSideAlgorithm;///是否使用大额单边保证金算法
    TThostFtdcInstrumentIDType UnderlyingInstrID;///基础商品代码
    TThostFtdcPriceType StrikePrice;///执行价
    TThostFtdcOptionsTypeType OptionsType;///期权类型
    TThostFtdcUnderlyingMultipleType UnderlyingMultiple; ///合约基础商品乘数
    TThostFtdcCombinationTypeType CombinationType;///组合类型
};
VolumeMultiple:合约乘数(同交易所)
PriceTick:最小变动价位(同交易所)
IsTrading:是否活跃(同交易所)
DeliveryYear:交割年份(同交易所)
DeliveryMonth:交割月(同交易所)
OpenDate:上市日(同交易所)
CreateDate:创建日(同交易所)
ExpireDate:到期日(同交易所)
StartDeliveDate:开始交割日(同交易所)
EndDelivDate:结束交割日(同交易所)

同交易所表示这些字段每天更新自交易所,其余字段为柜台设置值。如果发现有些字段值有误,则以此来判断是交易所问题还是CTP柜台设置问题。
pRspInfo:响应信息
struct CThostFtdcRspInfoField
{
    TThostFtdcErrorIDType ErrorID; ///错误代码
    TThostFtdcErrorMsgType ErrorMsg;///错误信息
};
nRequestID:返回用户操作请求的ID,该ID 由用户在操作请求时指定。
bIsLast:指示该次返回是否为针对nRequestID的最后一次返回。

2. 保证金率查询结果中包含保证金

2.1 保证金率查询命令

ReqQryInstrumentMarginRate
请求查询合约保证金率,对应响应OnRspQryInstrumentMarginRate。如果InstrumentID填空,则返回持仓对应的合约保证金率,否则返回相应InstrumentID的保证金率。
目前无法通过一次查询得到所有合约保证金率,如果要查询所有,则需要通过多次查询得到。

◇ 1.函数原型
virtual int ReqQryInstrumentMarginRate(CThostFtdcQryInstrumentMarginRateField *pQryInstrumentMarginRate, int nRequestID) = 0;
◇ 2.参数pQryInstrumentMarginRate:
查询合约保证金率
struct CThostFtdcQryInstrumentMarginRateField
{
    ///经纪公司代码
    TThostFtdcBrokerIDType BrokerID;
    ///投资者代码
    TThostFtdcInvestorIDType InvestorID;
    ///合约代码
    TThostFtdcInstrumentIDType InstrumentID;
    ///投机套保标志
    TThostFtdcHedgeFlagType HedgeFlag;
    ///交易所代码
    TThostFtdcExchangeIDType ExchangeID;
    ///投资单元代码
    TThostFtdcInvestUnitIDType InvestUnitID;
};
nRequestID:请求ID,对应响应里的nRequestID,无递增规则,由用户自行维护。
◇ 3.返回
0,代表成功。
-1,表示网络连接失败;
-2,表示未处理请求超过许可数;
-3,表示每秒发送请求数超过许可数。

2.2 保证金率查询结果

OnRspQryInstrumentMarginRate
请求查询合约保证金率响应,当执行ReqQryInstrumentMarginRate后,该方法被调用。

◇ 1.函数原型
virtual void OnRspQryInstrumentMarginRate(CThostFtdcInstrumentMarginRateField *pInstrumentMarginRate, CThostFtdcRspInfoField *pRspInfo, int nRequestID, bool bIsLast) {};

◇ 2.参数    ///:
合约保证金率
struct CThostFtdcInstrumentMarginRateField
{
    TThostFtdcInstrumentIDType InstrumentID;///合约代码
    TThostFtdcInvestorRangeType InvestorRange;///投资者范围
    TThostFtdcBrokerIDType BrokerID; ///经纪公司代码
    TThostFtdcInvestorIDType InvestorID;///投资者代码
    TThostFtdcHedgeFlagType HedgeFlag; ///投机套保标志
    TThostFtdcRatioType LongMarginRatioByMoney;///多头保证金率
    TThostFtdcMoneyType LongMarginRatioByVolume;///多头保证金费
    TThostFtdcRatioType ShortMarginRatioByMoney; ///空头保证金率
    TThostFtdcMoneyType ShortMarginRatioByVolume; ///空头保证金费
    TThostFtdcBoolType IsRelative;///是否相对交易所收取
    TThostFtdcExchangeIDType ExchangeID;///交易所代码
    TThostFtdcInvestUnitIDType InvestUnitID; ///投资单元代码
};
pRspInfo:响应信息

struct CThostFtdcRspInfoField
{
    TThostFtdcErrorIDType ErrorID;///错误代码
    TThostFtdcErrorMsgType ErrorMsg;///错误信息
};
nRequestID:返回用户操作请求的ID,该ID 由用户在操作请求时指定。

bIsLast:指示该次返回是否为针对nRequestID的最后一次返回。

3. 手续费(率)查询结果中包含手续费

3.1 手续费(率)查询命令

ReqQryInstrumentCommissionRate
请求查询合约手续费率,对应响应OnRspQryInstrumentCommissionRate。如果InstrumentID填空,则返回持仓对应的合约手续费率。
目前无法通过一次查询得到所有合约手续费率,如果要查询所有,则需要通过多次查询得到。
◇ 1.函数原型
virtual int ReqQryInstrumentCommissionRate(CThostFtdcQryInstrumentCommissionRateField *pQryInstrumentCommissionRate, int nRequestID) = 0;
◇ 2.参数pQryInstrumentCommissionRate:
查询手续费率
struct CThostFtdcQryInstrumentCommissionRateField
{
    TThostFtdcBrokerIDType BrokerID; ///经纪公司代码
    TThostFtdcInvestorIDType InvestorID;///投资者代码
    TThostFtdcInstrumentIDType InstrumentID;///合约代码
    TThostFtdcExchangeIDType ExchangeID;///交易所代码
    TThostFtdcInvestUnitIDType InvestUnitID;///投资单元代码
};

InstrumentID:返回手续费率对应的合约。
但是如果在柜台没有设置具体合约的手续费率,则默认会返回产品的手续费率,InstrumentID就为对应产品ID。
nRequestID:请求ID,对应响应里的nRequestID,无递增规则,由用户自行维护。
◇ 3.返回
0,代表成功。
-1,表示网络连接失败;
-2,表示未处理请求超过许可数;
-3,表示每秒发送请求数超过许可数。

3.3 手续费(率)查询结果

OnRspQryInstrumentCommissionRate
请求查询合约手续费率响应,当执行ReqQryInstrumentCommissionRate后,该方法被调用。

◇ 1.函数原型
virtual void OnRspQryInstrumentCommissionRate(CThostFtdcInstrumentCommissionRateField *pInstrumentCommissionRate, CThostFtdcRspInfoField *pRspInfo, int nRequestID, bool bIsLast) {};
◇ 2.参数pInstrumentCommissionRate:合约手续费率
struct CThostFtdcInstrumentCommissionRateField
{
    TThostFtdcInstrumentIDType InstrumentID; ///合约代码
    TThostFtdcInvestorRangeType InvestorRange; ///投资者范围
    TThostFtdcBrokerIDType BrokerID;///经纪公司代码
    TThostFtdcInvestorIDType InvestorID; ///投资者代码
    TThostFtdcRatioType OpenRatioByMoney; ///开仓手续费率
    TThostFtdcRatioType OpenRatioByVolume; ///开仓手续费
    TThostFtdcRatioType CloseRatioByMoney;///平仓手续费率
    TThostFtdcRatioType CloseRatioByVolume;///平仓手续费
    TThostFtdcRatioType CloseTodayRatioByMoney;///平今手续费率
    TThostFtdcRatioType CloseTodayRatioByVolume;///平今手续费
    TThostFtdcExchangeIDType ExchangeID; ///交易所代码
    TThostFtdcBizTypeType BizType;///业务类型    
    TThostFtdcInvestUnitIDType InvestUnitID;///投资单元代码
};

pRspInfo:
响应信息
struct CThostFtdcRspInfoField
{
    TThostFtdcErrorIDType ErrorID; ///错误代码
    TThostFtdcErrorMsgType ErrorMsg; ///错误信息
};
nRequestID:返回用户操作请求的ID,该ID 由用户在操作请求时指定。

bIsLast:指示该次返回是否为针对nRequestID的最后一次返回。

4. 合约+保证金率+手续费(率)= 完整的合约参数

令:
合约查询结果 = C
保证金率查询结果 = M
手续费查询结果 = S
则:

合约乘数:

C["VolumeMultiple"]

保证金率:

if M["Is_Relative"] == 1:
    多头保证金率 = C["LongMarginRatio"] + M["LongMarginRatioByMoney"] 
    空头保证金率 = C["ShortMarginRatio"] + M["ShortMarginRatioByMoney"] 

else:
多头保证金率 = M["LongMarginRatioByMoney"]
空头保证金率 = M["ShortMarginRatioByMoney"]

手续费(率):

        if S.open_ratio_bymoney == 0.0:
            开仓手续费= [FeeType.LOT,S["OpenRatioByVolume"] ]
            平仓手续费= [FeeType.LOT,S["CloseRatioByVolume"] ]
            平今手续费= [FeeType.LOT,S["CloseTodayRatioByVolume"] ]
        else:
            开仓手续费 = [FeeType.RATE,S["OpenRatioByMoney"] ]
            平仓手续费 = [FeeType.RATE,S["CloseRatioByMoney"] ]
            平今手续费 = [FeeType.RATE,S["CloseTodayRatioByMoney"] ]      


【VNPY进阶】on_tick函数内撤单追单详解,实盘在用的代码,没有坑哦

0.修改OrderData如下:

@dataclass
class OrderData(BaseData):
    """
    Order data contains information for tracking lastest status 
    of a specific order.
    """

    symbol: str
    exchange: Exchange
    orderid: str

    type: OrderType = OrderType.LIMIT
    direction: Direction = Direction.NET
    offset: Offset = Offset.NONE
    price: float = 0
    volume: float = 0
    traded: float = 0
    status: Status = Status.SUBMITTING
    datetime: datetime = None

    cancel_time: str = ""
    def __post_init__(self):
        """"""
        self.vt_symbol = f"{self.symbol}_{self.exchange.value}/{self.gateway_name}"
        self.vt_orderid = f"{self.gateway_name}_{self.orderid}"
        #未成交量
        self.untrade = self.volume - self.traded

1.策略init初始化参数

        #状态控制初始化
        self.chase_long_trigger = False
        self.chase_sell_trigger = False
        self.chase_short_trigger = False
        self.chase_cover_trigger = False  
        self.cancel_status = False
        self.long_trade_volume = 0
        self.short_trade_volume = 0
        self.sell_trade_volume = 0
        self.cover_trade_volume = 0 
        self.chase_interval   =    10    #拆单间隔:秒

2.on_tick里面的代码如下

get_position_detail参考这个帖子 https://www.vnpy.com/forum/topic/2167-cha-xun-cang-wei-chi-cang-jun-jie-wei-cheng-jiao-wei-tuo-dan-yi-ge-han-shu-gao-ding

    def on_tick(self, tick: TickData):
        working_order_dict = self.get_position_detail(tick.vt_symbol).active_orders
        #working_order_dict = self.order_dict
        if working_order_dict:
            #委托完成状态
            order_finished = False
            vt_orderid = list(working_order_dict.items())[0][0]         #委托单vt_orderid
            working_order = list(working_order_dict.items())[0][1]      #委托单字典 
            #开平仓追单,部分交易没有平仓指令(Offset.NONE)
            """获取到未成交委托单后检查未成交委托量>0,tick.datetime - 未成交委托单的datetime>追单间隔(chase_interval),同时chase_long_trigger状态未触发和有vt_orderid的判定(之前有收过空vt_orderid,所有要加个过滤),撤销该未成交委托单,赋值chase_long_trigger为True.chase_long_trigger为True且没有未成交委托单时执行追单,如有未成交委托单则调用cancel_surplus_order取消所有未成交委托单,追单的委托单发送出去后初始化chase_long_trigger.其他方向的撤单追单也是一样的流程"""
            if working_order.offset in (Offset.NONE,Offset.OPEN):
                if working_order.direction == Direction.LONG:
                    self.long_trade_volume = working_order.untrade
                    if (tick.datetime - working_order.datetime).seconds > self.chase_interval and self.long_trade_volume > 0 and (not self.chase_long_trigger) and vt_orderid:
                        #撤销之前发出的未成交订单
                        self.cancel_order(vt_orderid)
                        self.chase_long_trigger = True
                elif working_order.direction == Direction.SHORT:
                    self.short_trade_volume = working_order.untrade
                    if (tick.datetime - working_order.datetime).seconds > self.chase_interval and self.short_trade_volume > 0 and (not self.chase_short_trigger) and vt_orderid:  
                        self.cancel_order(vt_orderid)
                        self.chase_short_trigger = True
            #平仓追单
            elif working_order.offset in (Offset.CLOSE,Offset.CLOSETODAY):
                if working_order.direction == Direction.SHORT: 
                    self.sell_trade_volume = working_order.untrade
                    if (tick.datetime - working_order.datetime).seconds > self.chase_interval and self.sell_trade_volume > 0 and (not self.chase_sell_trigger) and vt_orderid: 
                        self.cancel_order(vt_orderid)
                        self.chase_sell_trigger = True                                                    
                if working_order.direction == Direction.LONG:
                    self.cover_trade_volume = working_order.untrade
                    if (tick.datetime - working_order.datetime).seconds > self.chase_interval and self.cover_trade_volume > 0 and (not self.chase_cover_trigger) and vt_orderid:                                                       
                        self.cancel_order(vt_orderid)
                        self.chase_cover_trigger = True   
        else:
            order_finished = True
            self.cancel_status = False
        if self.chase_long_trigger:
            if order_finished:
                self.buy(tick.ask_price_1,self.long_trade_volume)
                self.chase_long_trigger = False  
            else:
                self.cancel_surplus_order(list(working_order_dict))
        elif self.chase_short_trigger:
            if  order_finished:
                self.short(tick.bid_price_1,self.short_trade_volume)
                self.chase_short_trigger = False 
            else:
                self.cancel_surplus_order(list(working_order_dict))
        elif self.chase_sell_trigger:
            if order_finished:
                self.sell(tick.bid_price_1,self.sell_trade_volume)
                self.chase_sell_trigger = False                      
            else:
                self.cancel_surplus_order(list(working_order_dict))
        elif self.chase_cover_trigger:
            if order_finished:
                self.cover(tick.ask_price_1,self.cover_trade_volume)
                self.chase_cover_trigger = False
            else:
                self.cancel_surplus_order(list(working_order_dict))
    #------------------------------------------------------------------------------------
    def cancel_surplus_order(self,orderids:list):
        """
        撤销剩余活动委托单
        """
        if not self.cancel_status:
            for vt_orderid in  orderids:
                self.cancel_order(vt_orderid)
            self.cancel_status = True


vnpy获取手续费率合并到contract

首选修改vnpy\trader,object.py里面的ContractData

@dataclass
class ContractData(BaseData):
    """
    Contract data contains basic information about each contract traded.
    """
    price_tick: float = 0                               #最小价格变动

    margin_ratio : float = 0                            #保证金率
    max_order_volume : float = 0                        #限价单最大单次委托量
    open_commission_ratio : float = 0                   #开仓手续费率
    open_commission : float = 0                         #开仓手续费
    close_commission_ratio : float = 0                  #平仓手续费率
    close_commission : float = 0                        #平仓手续费  
    close_commission_today_ratio : float = 0            #平今手续费率
    close_commission_today : float = 0                  #平今手续费

CtpGateway里面加上

import shelve
#------------------------------------------------------------------------------------
def remain_alpha(convert_contract:str) -> str:
    """
    返回vt_symbol或者symbol的字母字符串
    """
    if "." in convert_contract:
        convert_contract = extract_vt_symbol(convert_contract)[0]
    symbol_mark = "".join(list(filter(str.isalpha,convert_contract)))
    return symbol_mark
#-------------------------------------------------------------------------------------------------
class CtpGateway(BaseGateway):
    #-------------------------------------------------------------------------------------------------
    def query_commission(self):
        """查询手续费数据"""
        self.td_api.query_commission()
    #-------------------------------------------------------------------------------------------------
    def save_commission(self):
        """保存手续费数据"""
        self.td_api.save_commission()
    #-------------------------------------------------------------------------------------------------
    def query_margin_ratio(self):
        """查询保证金率数据"""
        self.td_api.query_margin_ratio()
    #-------------------------------------------------------------------------------------------------
    def save_margin_ratio(self):
        """保存保证金率数据"""
        self.td_api.save_margin_ratio()
    #-------------------------------------------------------------------------------------------------
    def close(self):
        """"""
        #关闭ctp api前保存手续费,保证金率数据到硬盘
        self.save_commission()
        self.save_margin_ratio()
        self.td_api.close()
        self.md_api.close()
        self.query_functions = [self.query_account, self.query_position,self.query_commission,self.query_margin_ratio]

tdapi里面加上

class CtpTdApi(TdApi):
    """"""

    def __init__(self, gateway):
        self.commission_file_name = 'commission_data'
        self.commission_file_path = get_folder_path(self.commission_file_name)        
        self.commission_req = {}        #手续费查询字典   
        self.commission_data = {}       #手续费字典
        self.margin_ratio_file_name = 'margin_ratio_data'
        self.margin_ratio_file_path = get_folder_path(self.margin_ratio_file_name)        
        self.margin_ratio_req = {}        #保证金率查询字典   
        self.margin_ratio_data = {}       #保证金率字典 
    def onRspQryInstrument(self, data: dict, error: dict, reqid: int, last: bool): 
        #读取硬盘存储手续费数据,保证金率数据
        self.load_commission() 
        self.load_margin_ratio()
        """合约查询回报"""        
        contract = ContractData(
            symbol=data["InstrumentID"],
            exchange=EXCHANGE_CTP2VT[data["ExchangeID"]],
            name=data["InstrumentName"],
            product=PRODUCT_CTP2VT.get(data["ProductClass"], None),
            size=data["VolumeMultiple"],
            price_tick=data["PriceTick"],                    #合约最小价格变动
            max_order_volume=data["MaxLimitOrderVolume"],  #限价单次最大委托量
            gateway_name=self.gateway_name)
        #手续费数据合并到contract
        for symbol in self.commission_data.keys():
            if symbol == contract.symbol:
                contract.open_commission_ratio=self.commission_data[symbol]['OpenRatioByMoney']                  #开仓手续费率
                contract.open_commission=self.commission_data[symbol]['OpenRatioByVolume']                       #开仓手续费
                contract.close_commission_ratio=self.commission_data[symbol]['CloseRatioByMoney']                #平仓手续费率
                contract.close_commission=self.commission_data[symbol]['CloseRatioByVolume']                     #平仓手续费
                contract.close_commission_today_ratio=self.commission_data[symbol]['CloseTodayRatioByMoney']     #平今手续费率
                contract.close_commission_today=self.commission_data[symbol]['CloseTodayRatioByVolume']          #平今手续费                
            elif remain_alpha(symbol) == remain_alpha(contract.symbol):
                contract.open_commission_ratio=self.commission_data[symbol]['OpenRatioByMoney']                  #开仓手续费率
                contract.open_commission=self.commission_data[symbol]['OpenRatioByVolume']                       #开仓手续费
                contract.close_commission_ratio=self.commission_data[symbol]['CloseRatioByMoney']                #平仓手续费率
                contract.close_commission=self.commission_data[symbol]['CloseRatioByVolume']                     #平仓手续费
                contract.close_commission_today_ratio=self.commission_data[symbol]['CloseTodayRatioByMoney']     #平今手续费率
                contract.close_commission_today=self.commission_data[symbol]['CloseTodayRatioByVolume']          #平今手续费

        for symbol in self.margin_ratio_data.keys():
            if symbol == contract.symbol:
                contract.margin_ratio = self.margin_ratio_data[symbol]['LongMarginRatioByMoney']                 #合约保证金比率
    #-------------------------------------------------------------------------------------------------
    def onRspQryInstrumentCommissionRate(self, data: dict, error: dict, reqid: int, last: bool):
        """查询合约手续费率"""
        symbol = data.get('InstrumentID',None)
        if symbol:
            self.commission_data[symbol] = data
    #-------------------------------------------------------------------------------------------------
    def onRspQryInstrumentMarginRate(self, data: dict, error: dict, reqid: int, last: bool):
        """查询保证金率"""
        symbol = data.get('InstrumentID',None)
        if symbol:
            self.margin_ratio_data[symbol] = data
    #-------------------------------------------------------------------------------------------------
    def load_commission(self):
        """从硬盘读取手续费数据"""
        f = shelve.open(f"{self.commission_file_path}\\commission_data.vt")
        if 'data' in f:
            d = f['data']
            for key, value in list(d.items()):
                self.commission_data[key] = value
        f.close()
    #-------------------------------------------------------------------------------------------------
    def save_commission(self):
        """保存手续费数据到硬盘"""
        f = shelve.open(f"{self.commission_file_path}\\commission_data.vt")
        f['data'] = self.commission_data
        f.close()         
    #-------------------------------------------------------------------------------------------------
    def load_margin_ratio(self):
        """从硬盘读取保证金率数据"""
        f = shelve.open(f"{self.margin_ratio_file_path}\\margin_ratio_data.vt")
        if 'data' in f:
            d = f['data']
            for key, value in list(d.items()):
                self.margin_ratio_data[key] = value
        f.close()
    #-------------------------------------------------------------------------------------------------
    def save_margin_ratio(self):
        """保存保证金率数据到硬盘"""
        f = shelve.open(f"{self.margin_ratio_file_path}\\margin_ratio_data.vt")
        f['data'] = self.margin_ratio_data
        f.close() 
    #-------------------------------------------------------------------------------------------------
    #commission_vt_symbol,margin_ratio_vt_symbol都是全市场合约列表,需要自己维护
    #我用的合约连接交易所符号是'_',用'.'自己替换
    #-------------------------------------------------------------------------------------------------
    def query_commission(self):
        #查询手续费率
        if len(commission_vt_symbol) > 0:
            symbol = commission_vt_symbol[0].split('_')[0]
            #手续费率查询字典
            self.commission_req['BrokerID'] = self.brokerid
            self.commission_req['InvestorID'] = self.userid
            self.commission_req['InstrumentID'] = symbol
            self.reqid += 1 
            #请求查询手续费率
            self.reqQryInstrumentCommissionRate(self.commission_req,self.reqid)  
            commission_vt_symbol.pop(0)
    def query_margin_ratio(self):
        if len(margin_ratio_vt_symbol) > 0:
            symbol = margin_ratio_vt_symbol[0].split('_')[0]
            #保证金率查询字典
            self.margin_ratio_req['BrokerID'] = self.brokerid
            self.margin_ratio_req['InvestorID'] = self.userid
            self.margin_ratio_req['InstrumentID'] = symbol
            self.margin_ratio_req['HedgeFlag'] = THOST_FTDC_HF_Speculation
            self.reqid += 1 
            #请求查询保证金率
            self.reqQryInstrumentMarginRate(self.margin_ratio_req,self.reqid)  
            margin_ratio_vt_symbol.pop(0)

trader\ui,widget.py里面修改

class ContractManager(QtWidgets.QWidget):
    """
    Query contract data available to trade in system.
    """

    headers = {
        "vt_symbol": "本地代码",
        "symbol": "代码",
        "exchange": "交易所",
        "name": "名称",
        "product": "合约分类",
        "size": "合约乘数",
        "price_tick": "价格跳动",
        "min_volume": "最小委托量",
        "margin_ratio": "保证金率",
        "open_commission_ratio": "开仓手续费率",
        "open_commission": "开仓手续费",
        "close_commission_ratio": "平仓手续费率",
        "close_commission": "平仓手续费",
        "close_commission_today_ratio": "平今手续费率",
        "close_commission_today": "平今手续费",
        "gateway_name": "交易接口",
    }


如何解决VXcode调试python代码每次都出现select a debug configuration,实现F5一键调试

每次使用VX code调试python代码时,都会跳出如下界面,按照下面步骤操作只需设置一次,以后就可以实现F5调试,不再跳出来下面界面:

description
解决步骤:

第一步,将编写的代码保存为py格式文件。

第二步,点击菜单打开文件夹,定位到代码文件所在文件夹,点击选择文件夹。

description

第三步,在资源管理器内打开文件夹下的python代码文件

description

第四步,点击菜单栏运行——添加配置

description

第五步,选择python file

description

第六步,出现launch.json配置文件,表示添加成功

description



一个可以交易夜盘的RBreakerStrategy

1. 为什么要修改?

原因见这个帖子:https://www.vnpy.com/forum/topic/4461-shuo-shi-r-breakerce-lue-de-wen-ti

2. 修改步骤:

2.1 添加文件vnpy\usertools\trade_hour.py

内容如下:

"""
本文件主要实现合约的交易时间段
作者:hxxjava
日期:2020-8-1
"""
from typing import Callable,List,Dict, Tuple, Union
from enum import Enum

import datetime
import pytz
CHINA_TZ = pytz.timezone("Asia/Shanghai")

from vnpy.trader.utility import extract_vt_symbol
from vnpy.trader.constant import Interval

from rqdatac.utils import to_date
import rqdatac as rq


def get_listed_date(symbol:str):
    ''' 
    获得上市日期 
    '''
    info = rq.instruments(symbol)
    return to_date(info.listed_date)

def get_de_listed_date(symbol:str):
    ''' 
    获得交割日期 
    '''
    info = rq.instruments(symbol)
    return to_date(info.de_listed_date)

class Timeunit(Enum):
    """ 
    时间单位 
    """
    SECOND = '1s'
    MINUTE = '1m'
    HOUR = '1h'

class TradeHours(object):
    """ 合约交易时间段 """
    def __init__(self,symbol:str):
        self.symbol = symbol.upper()
        self.init()

    def init(self):
        """ 
        初始化交易日字典及交易时间段数据列表 
        """
        self.listed_date = get_listed_date(self.symbol)
        self.de_listed_date = get_de_listed_date(self.symbol)

        self.trade_date_index = {}   # 合约的交易日索引字典
        self.trade_index_date = {}   # 交易天数与交易日字典

        trade_dates = rq.get_trading_dates(self.listed_date,self.de_listed_date) # 合约的所有的交易日
        days = 0
        for td in trade_dates:
            self.trade_date_index[td] = days
            self.trade_index_date[days] = td
            days += 1

        trading_hours = rq.get_trading_hours(self.symbol,date=self.listed_date,frequency='tick',expected_fmt='datetime')

        self.time_dn_pairs = self._get_trading_times_dn(trading_hours)

        trading_hours0 = [(CHINA_TZ.localize(start),CHINA_TZ.localize(stop)) for start,stop in trading_hours]
        self.trade_date_index[self.listed_date] = (0,trading_hours0)
        for day in range(1,days):
            td = self.trade_index_date[day]
            trade_datetimes = []
            for (start,dn1),(stop,dn2) in self.time_dn_pairs:
                #start:开始时间,dn1:相对交易日前推天数,
                #stop :开始时间,dn2:相对开始时间后推天数     
                d = self.trade_index_date[day+dn1]
                start_dt = CHINA_TZ.localize(datetime.datetime.combine(d,start))
                stop_dt = CHINA_TZ.localize(datetime.datetime.combine(d,stop))
                trade_datetimes.append((start_dt,stop_dt+datetime.timedelta(days=dn2)))
            self.trade_date_index[td] = (day,trade_datetimes)

    def _get_trading_times_dn(self,trading_hours:List[Tuple[datetime.datetime,datetime.datetime]]): 
        """ 
        交易时间跨天处理,不推荐外部使用 。
        产生的结果:[((start1,dn11),(stop1,dn21)),((start2,dn12),(stop2,dn22)),...,((startN,dn1N),(stopN,dn2N))]
        其中:
            startN:开始时间,dn1N:相对交易日前推天数,
            stopN:开始时间,dn2N:相对开始时间后推天数      
        """
        ilen = len(trading_hours)
        if ilen == 0:
            return []
        start_stops = []
        for start,stop in trading_hours:
            start_stops.insert(0,(start.time(),stop.time()))

        pre_start,pre_stop = start_stops[0]
        dn1 = 0
        dn2 = 1 if pre_start > pre_stop else 0
        time_dn_pairs = [((pre_start,dn1),(pre_stop,dn2))]
        for start,stop in start_stops[1:]:
            if start > pre_start:
                dn1 -= 1
            dn2 = 1 if start > stop else 0
            time_dn_pairs.insert(0,((start,dn1),(stop,dn2)))
            pre_start,pre_stop = start,stop

        return time_dn_pairs

    def get_date_tradetimes(self,date:datetime.date):
        """ 
        得到合约date日期的交易时间段 
        """
        idx,trade_times = self.trade_date_index.get(date,(None,[]))
        return idx,trade_times

    def get_trade_datetimes(self,dt:datetime,allday:bool=False):
        """ 
        得到合约date日期的交易时间段 
        """
        # 得到最早的交易时间
        idx0,trade_times0 = self.get_date_tradetimes(self.listed_date)
        start0,stop0 = trade_times0[0]
        if dt < start0:
            return None,[]

        # 首先找到dt日期自上市以来的交易天数
        date,dn = dt.date(),0
        days = None
        while date < self.de_listed_date:
            days,ths = self.trade_date_index.get(date,(None,[]))
            if not days:
                dn += 1
                date = (dt+datetime.timedelta(days=dn)).date()
            else:
                break
        # 如果超出交割日也没有找到,那这就不是一个有效的交易时间
        if days is None:
            return (None,[])

        index_3 = [days,days+1,days-1]  # 前后三天的

        date_3d = []
        for day in index_3: 
            date = self.trade_index_date.get(day,None)
            date_3d.append(date)

        # print(date_3d)

        for date in date_3d:
            if not date:
                # print(f"{date} is not trade date")
                continue

            idx,trade_dts = self.get_date_tradetimes(date)
            # print(f"{date} tradetimes {trade_dts}")
            ilen = len(trade_dts)
            if ilen > 0:
                start0,stop = trade_dts[0]      # start0 是date交易日的开始时间
                start,stop0 = trade_dts[-1]
            if dt<start0 or dt>stop0:
                continue

            for start,stop in trade_dts:
                if dt>=start and dt < stop:
                    if allday:
                        return idx,trade_dts
                    else:
                        return idx,[(start,stop)]

        return None,[]

    def get_trade_time_perday(self):
        """ 
        计算每日的交易总时长(单位:分钟) 
        """
        TTPD = datetime.timedelta(0,0,0)

        datetimes = []
        today = datetime.datetime.now().date()

        for (start,dn1),(stop,dn2) in self.time_dn_pairs:
            start_dt = CHINA_TZ.localize(datetime.datetime.combine(today,start)) + datetime.timedelta(days=dn1)
            stop_dt = CHINA_TZ.localize(datetime.datetime.combine(today,stop)) + datetime.timedelta(days=dn2)
            time_delta = stop_dt - start_dt
            TTPD = TTPD + time_delta
        return int(TTPD.seconds/60)

    def get_trade_time_inday(self,dt:datetime,unit:Timeunit=Timeunit.MINUTE):
        """ 
        计算dt在交易日内的分钟数 
        unit: '1s':second;'1m':minute;'1h';1h
        """
        TTID = datetime.timedelta(0,0,0)

        day,trade_times = self.get_trade_datetimes(dt,allday=True)
        if not trade_times:
            return None

        for start,stop in trade_times:
            if dt > stop:
                time_delta = stop - start
                TTID += time_delta
            elif dt > start:
                time_delta = dt - start
                TTID += time_delta     
                break
            else:
                break          

        if unit == Timeunit.SECOND:
            return TTID.seconds
        elif unit == Timeunit.MINUTE:
            return int(TTID.seconds/60) 
        elif unit == Timeunit.HOUR:
            return int(TTID.seconds/3600) 
        else:
            return TTID

    def get_day_tradetimes(self,dt:datetime):
        """ 
        得到合约日盘的交易时间段 
        """
        index,trade_times = self.get_trade_datetimes(dt,allday=True)
        trade_times1 = []
        if trade_times:
            for start_dt,stop_dt in trade_times:
                if start_dt.time() < datetime.time(18,0,0):
                    trade_times1.append((start_dt,stop_dt))
            return index,trade_times1
        return (index,trade_times1)

    def get_night_tradetimes(self,dt:datetime):
        """ 
        得到合约夜盘的交易时间段 
        """
        index,trade_times = self.get_trade_datetimes(dt,allday=True)
        trade_times1 = []
        if trade_times:
            for start_dt,stop_dt in trade_times:
                if start_dt.time() > datetime.time(18,0,0):
                    trade_times1.append((start_dt,stop_dt))
            return index,trade_times1
        return (index,trade_times1)

    def convet_to_datetime(self,day:int,minutes:int):
        """ 
        计算minutes在第day交易日内的datetime形式的时间 
        """
        date = self.trade_index_date.get(day,None)
        if date is None:
            return None
        idx,trade_times = self.trade_date_index.get(date,(None,[]))
        if not trade_times:     # 不一定必要
            return None
        for (start,stop) in trade_times:
            timedelta = stop - start 
            if minutes < int(timedelta.seconds/60):
                return start + datetime.timedelta(minutes=minutes)
            else:
                minutes -= int(timedelta.seconds/60)
        return None

    def get_bar_window(self,dt:datetime,window:int,interval:Interval=Interval.MINUTE):
        """ 
        计算dt所在K线的起止时间 
        """
        bar_windows = (None,None)

        day,trade_times = self.get_trade_datetimes(dt,allday=True)
        if not trade_times:
            # print(f"day={day} trade_times={trade_times}")
            return bar_windows

        # 求每个交易日的交易时间分钟数
        TTPD = self.get_trade_time_perday()

        # 求dt在交易日内的分钟数
        TTID = self.get_trade_time_inday(dt,unit=Timeunit.MINUTE)

        # 得到dt时刻K线的起止时间 
        total_minites = day*TTPD + TTID

        # 计算K线宽度(分钟数)
        if interval == Interval.MINUTE:
            bar_width = window
        elif interval == Interval.HOUR:
            bar_width = 60*window
        elif interval == Interval.DAILY:
            bar_width = TTPD*window
        elif interval == Interval.WEEKLY:
            bar_width = TTPD*window*5
        else:
            return bar_windows

        # 求K线的开始时间的和结束的分钟形式
        start_m = int(total_minites/bar_width)*bar_width
        stop_m = start_m + bar_width

        # 计算K开始时间的datetime形式
        start_d = int(start_m / TTPD)
        minites = start_m % TTPD
        start_dt = self.convet_to_datetime(start_d,minites)
        # print(f"start_d={start_d} minites={minites}---->{start_dt}")

        # 计算K结束时间的datetime形式
        stop_d = int(stop_m / TTPD)
        minites = stop_m % TTPD
        stop_dt = self.convet_to_datetime(stop_d,minites)
        # print(f"stop_d={stop_d} minites={minites}---->{stop_dt}")

        return start_dt,stop_dt

    def get_date_start_stop(self,dt:datetime):
        """
        获得dt所在交易日的开始和停止时间
        """
        index,trade_times = self.get_trade_datetimes(dt,allday=True)
        if trade_times:
            valid_dt = False
            for t1,t2 in trade_times:
                if t1 < dt and dt < t2:
                    valid_dt = True
                    break
            if valid_dt:
                start_dt = trade_times[0][0]
                stop_dt = trade_times[-1][1]
                return True,(start_dt,stop_dt)
        return False,(None,None)

    def get_day_start_stop(self,dt:datetime):
        """
        获得dt所在交易日日盘的开始和停止时间
        """
        index,trade_times = self.get_day_tradetimes(dt)
        if trade_times:
            valid_dt = False
            for t1,t2 in trade_times:
                if t1 < dt and dt < t2:
                    valid_dt = True
                    break
            if valid_dt:
                start_dt = trade_times[0][0]
                stop_dt = trade_times[-1][1]
                return True,(start_dt,stop_dt)
        return False,(None,None)

    def get_night_start_stop(self,dt:datetime):
        """
        获得dt所在交易日夜盘的开始和停止时间
        """
        index,trade_times = self.get_night_tradetimes(dt)
        if trade_times:
            valid_dt = False
            for t1,t2 in trade_times:
                if t1 < dt and dt < t2:
                    valid_dt = True
                    break
            if valid_dt:
                start_dt = trade_times[0][0]
                stop_dt = trade_times[-1][1]
                return True,(start_dt,stop_dt)
        return False,(None,None)


if __name__ == "__main__":
    rq.init('xxxxx','******',("rqdatad-pro.ricequant.com",16011))

    # vt_symbols = ["rb2010.SHFE","ag2012.SHFE","i2010.DCE"]
    vt_symbols = ["ag2012.SHFE"]
    date0 = datetime.date(2020,8,31)
    dt0 = CHINA_TZ.localize(datetime.datetime(2020,8,31,9,20,15))
    for vt_symbol in vt_symbols:
        symbol,exchange = extract_vt_symbol(vt_symbol)
        th = TradeHours(symbol)
        # trade_hours = th.get_date_tradetimes(date0)
        # print(f"\n{vt_symbol} {date0} trade_hours={trade_hours}")

        days,trade_hours = th.get_trade_datetimes(dt0,allday=True)

        print(f"\n{vt_symbol} {dt0} days:{days} trade_hours={trade_hours}")

        if trade_hours:
            day_start = trade_hours[0][0]
            day_end = trade_hours[-1][1]
            print(f"day_start={day_start} day_end={day_end}")
            exit_time = day_end + datetime.timedelta(minutes=-5)
            print(f"exit_time={exit_time}")

        dt1 = CHINA_TZ.localize(datetime.datetime(2020,8,31,9,20,15))
        dt2 = CHINA_TZ.localize(datetime.datetime(2020,9,1,1,1,15))

        for dt in [dt1,dt2]:
            in_trade,(start,stop) = th.get_date_start_stop(dt)
            if (in_trade):
                print(f"\n{vt_symbol} 时间 {dt} 交易日起止:{start,stop}")
            else:
                print(f"\n{vt_symbol} 时间 {dt} 非交易时间")

            in_day,(start,stop) = th.get_day_start_stop(dt)
            if (in_day):
                print(f"\n{vt_symbol} 时间 {dt} 日盘起止:{start,stop}")
            else:
                print(f"\n{vt_symbol} 时间 {dt} 非日盘时间")

            in_night,(start,stop) = th.get_night_start_stop(dt)
            if in_night:
                print(f"\n{vt_symbol} 时间 {dt} 夜盘起止:{start,stop}")
            else:
                print(f"\n{vt_symbol} 时间 {dt} 非夜盘时间")

2.2 修改策略文件 RBreakerStrategy.py

代码如下:

from datetime import datetime,time,timedelta
from vnpy.app.cta_strategy import (
    CtaTemplate,
    StopOrder,
    TickData,
    BarData,
    TradeData,
    OrderData,
    BarGenerator,
    ArrayManager
)

from vnpy.trader.utility import extract_vt_symbol
from vnpy.usertools.trade_hour import TradeHours

class RBreakStrategy2(CtaTemplate):
    """"""
    author = "KeKe"

    setup_coef = 0.25
    break_coef = 0.2
    enter_coef_1 = 1.07
    enter_coef_2 = 0.07
    fixed_size = 1
    donchian_window = 30

    trailing_long = 0.4
    trailing_short = 0.4
    multiplier = 3

    buy_break = 0   # 突破买入价
    sell_setup = 0  # 观察卖出价
    sell_enter = 0  # 反转卖出价
    buy_enter = 0   # 反转买入价
    buy_setup = 0   # 观察买入价
    sell_break = 0  # 突破卖出价

    intra_trade_high = 0
    intra_trade_low = 0

    day_high = 0
    day_open = 0
    day_close = 0
    day_low = 0
    tend_high = 0
    tend_low = 0

    parameters = ["setup_coef", "break_coef", "enter_coef_1", "enter_coef_2", "fixed_size", "donchian_window"]
    variables = ["buy_break", "sell_setup", "sell_enter", "buy_enter", "buy_setup", "sell_break"]

    def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
        """"""
        super(RBreakStrategy2, self).__init__(
            cta_engine, strategy_name, vt_symbol, setting
        )

        self.bg = BarGenerator(self.on_bar)
        self.am = ArrayManager()
        self.bars = []

        symbol,exchange = vt_symbol.split('.')
        self.trade_hour = TradeHours(symbol)
        self.trade_datetimes = None
        self.exit_time = None

    def on_init(self):
        """
        Callback when strategy is inited.
        """
        self.write_log("策略初始化")
        self.load_bar(10)

    def on_start(self):
        """
        Callback when strategy is started.
        """
        self.write_log("策略启动")

    def on_stop(self):
        """
        Callback when strategy is stopped.
        """
        self.write_log("策略停止")

    def on_tick(self, tick: TickData):
        """
        Callback of new tick data update.
        """
        self.bg.update_tick(tick)

    def is_new_day(self,dt:datetime):
        """
        判断dt时间是否在当天的交易时间段内
        """
        if not self.trade_datetimes: 
            return True
        day_start = self.trade_datetimes[0][0]
        day_end = self.trade_datetimes[-1][1]
        if day_start<=dt and dt < day_end:
            return False
        return True

    def on_bar(self, bar: BarData):
        """
        Callback of new bar data update.
        """
        self.cancel_all()

        am = self.am
        am.update_bar(bar)
        if not am.inited:
            return

        # 判断是否是下一交易日
        self.new_day = self.is_new_day(bar.datetime)
        if self.new_day:
            # 计算下一交易日的交易时间段
            days,self.trade_datetimes = self.trade_hour.get_trade_datetimes(bar.datetime,allday=True) 

            # 计算退出时间
            # print(f"trade_datetimes={self.trade_datetimes}")
            if self.trade_datetimes:
                day_end = self.trade_datetimes[-1][1]
                self.exit_time = day_end + timedelta(minutes=-5)

        if not self.trade_datetimes:
            # 不是个有效的K线,不可以处理,
            # 为什么会有K线推送?因为非交易时段接口的行为是不可理喻的
            return

        self.bars.append(bar)
        if len(self.bars) <= 2:
            return
        else:
            self.bars.pop(0)
        last_bar = self.bars[-2]

        # New Day
        if self.new_day:    # 如果是新交易日
            if self.day_open:
                self.buy_setup = self.day_low - self.setup_coef * (self.day_high - self.day_close)  # 观察买入价
                self.sell_setup = self.day_high + self.setup_coef * (self.day_close - self.day_low)  # 观察卖出价

                self.buy_enter = (self.enter_coef_1 / 2) * (self.day_high + self.day_low) - self.enter_coef_2 * self.day_high  # 反转买入价
                self.sell_enter = (self.enter_coef_1 / 2) * (self.day_high + self.day_low) - self.enter_coef_2 * self.day_low  # 反转卖出价

                self.buy_break = self.buy_setup + self.break_coef * (self.sell_setup - self.buy_setup)  # 突破买入价
                self.sell_break = self.sell_setup - self.break_coef * (self.sell_setup - self.buy_setup)  # 突破卖出价

            self.day_open = bar.open_price
            self.day_high = bar.high_price
            self.day_close = bar.close_price
            self.day_low = bar.low_price

        # Today
        else:
            self.day_high = max(self.day_high, bar.high_price)
            self.day_low = min(self.day_low, bar.low_price)
            self.day_close = bar.close_price

        if not self.sell_setup:
            return

        self.tend_high, self.tend_low = am.donchian(self.donchian_window)

        if bar.datetime < self.exit_time:

            if self.pos == 0:
                self.intra_trade_low = bar.low_price
                self.intra_trade_high = bar.high_price

                if self.tend_high > self.sell_setup:
                    long_entry = max(self.buy_break, self.day_high)
                    self.buy(long_entry, self.fixed_size, stop=True)

                    self.short(self.sell_enter, self.multiplier * self.fixed_size, stop=True)

                elif self.tend_low < self.buy_setup:
                    short_entry = min(self.sell_break, self.day_low)
                    self.short(short_entry, self.fixed_size, stop=True)

                    self.buy(self.buy_enter, self.multiplier * self.fixed_size, stop=True)

            elif self.pos > 0:
                self.intra_trade_high = max(self.intra_trade_high, bar.high_price)
                long_stop = self.intra_trade_high * (1 - self.trailing_long / 100)
                self.sell(long_stop, abs(self.pos), stop=True)

            elif self.pos < 0:
                self.intra_trade_low = min(self.intra_trade_low, bar.low_price)
                short_stop = self.intra_trade_low * (1 + self.trailing_short / 100)
                self.cover(short_stop, abs(self.pos), stop=True)

        # Close existing position
        else:
            if self.pos > 0:
                self.sell(bar.close_price * 0.99, abs(self.pos))
            elif self.pos < 0:
                self.cover(bar.close_price * 1.01, abs(self.pos))

        self.put_event()

    def on_order(self, order: OrderData):
        """
        Callback of new order data update.
        """
        pass

    def on_trade(self, trade: TradeData):
        """
        Callback of new trade data update.
        """
        self.put_event()

    def on_stop_order(self, stop_order: StopOrder):
        """
        Callback of stop order update.
        """
        pass


自用判断标的日内交易时间、以及收盘平仓

1.vnpy.app新建user_tools文件夹trading_hour.py放入user_tools
trading_hour.py如下:

from datetime import datetime,time,timedelta
import json

class TRADINGHOUR(object):
PATH = "***.vntrader/trading_hour.json" #PATH填trading_hour.json在电脑的绝对路径,trading_hour.json放入在vntrader目录下,如:c:/users/电脑名/vntrader/trading_hour.json
time_switch = 0
start_time = ""
end_time = ""

def get_trading_time(self,symbol):   #读取存入json中的标的开收盘时间
    with open(self.PATH,"r",encoding="utf_8") as f:            
        trading_hour = json.load(f)
    for key_symbol in trading_hour: 
        if key_symbol == symbol:                
            start = trading_hour[symbol][0]
            end = trading_hour[symbol][1]
            break
    self.start_time = datetime.strptime(start,"%H:%M:%S")   #datetime对象,后面比较时候要转换成time对象
    self.end_time = datetime.strptime(end,"%H:%M:%S")
    return self.start_time,self.end_time

def day_night_switch(self):             #time_switch = 0 白盘标的,1为夜盘0:00前收盘的标的(如:rb),2为0:00后收盘的标的(如:ag,cu)
    if self.start_time.time() < time(hour=20,minute=0):
        self.time_switch = 0

    elif self.end_time.time() > time(hour=3,minute=0):
        self.time_switch = 1

    else:
        self.time_switch = 2 
    return self.time_switch

def trading_period(self,bar):   #判断标的日内交易时间

        DAY_START = time(hour=9,minute=0)
        DAY_END = time(hour=14,minute=57)

        if self.day_night_switch() == 0:                
            return bar.datetime.time() < (self.end_time + timedelta(minutes=-3)).time()

        elif self.day_night_switch() == 1:
            return bar.datetime.time() >= self.start_time.time() and bar.datetime.time() < (self.end_time + timedelta(minutes=-3)).time() \
                or (bar.datetime.time() >= DAY_START and bar.datetime.time() < DAY_END )
        else:
            return bar.datetime.time() >= self.start_time.time() or bar.datetime.time() < (self.end_time + timedelta(minutes=-3)).time() \
                or (bar.datetime.time() >= DAY_START and bar.datetime.time() < DAY_END )

2.trading_hour.json如下: #记录了期货白盘期货标的开收盘时间,夜盘标的夜盘开盘以及夜盘收盘时间,遇到特殊日子(比如疫情期间,ag没有夜盘了,那时间要修正“AG”:["9:30:00","15:00:00"]),直接过来修改开收盘时间即可
{
"IF":["9:30:00","15:00:00"],
"IC":["9:30:00","15:00:00"],
"IH":["9:30:00","15:00:00"],
"T":["9:30:00","15:15:00"],
"AU":["21:00:00","2:30:00"],
"AG":["21:00:00","2:30:00"],
"CU":["21:00:00","1:00:00"],
"AL":["21:00:00","1:00:00"],
"ZN":["21:00:00","1:00:00"],
"PB":["21:00:00","1:00:00"],
"NI":["21:00:00","1:00:00"],
"SN":["21:00:00","1:00:00"],
"RB":["21:00:00","23:00:00"],
"I":["21:00:00","23:00:00"],
"HC":["21:00:00","23:00:00"],
"SS":["21:00:00","1:00:00"],
"SF":["9:00:00","15:00:00"],
"SM":["9:00:00","15:00:00"],
"JM":["21:00:00","23:00:00"],
"J":["21:00:00","23:00:00"],
"ZC":["21:00:00","23:00:00"],
"FG":["21:00:00","23:00:00"],
"SP":["21:00:00","23:00:00"],
"FU":["21:00:00","23:00:00"],
"LU":["21:00:00","23:00:00"],
"SC":["21:00:00","2:30:00"],
"BU":["21:00:00","23:00:00"],
"PG":["21:00:00","23:00:00"],
"RU":["21:00:00","23:00:00"],
"NR":["21:00:00","23:00:00"],
"L":["21:00:00","23:00:00"],
"TA":["21:00:00","23:00:00"],
"V":["21:00:00","23:00:00"],
"EG":["21:00:00","23:00:00"],
"MA":["21:00:00","23:00:00"],
"PP":["21:00:00","23:00:00"],
"EB":["21:00:00","23:00:00"],
"UR":["9:00:00","15:00:00"],
"SA":["21:00:00","23:00:00"],
"C":["21:00:00","23:00:00"],
"A":["21:00:00","23:00:00"],
"CS":["21:00:00","23:00:00"],
"B":["21:00:00","23:00:00"],
"M":["21:00:00","23:00:00"],
"Y":["21:00:00","23:00:00"],
"RM":["21:00:00","23:00:00"],
"OI":["21:00:00","23:00:00"],
"P":["21:00:00","23:00:00"],
"CF":["21:00:00","23:00:00"],
"SR":["21:00:00","23:00:00"],
"JD":["9:00:00","15:00:00"],
"AP":["9:00:00","15:00:00"],
"CJ":["9:00:00","15:00:00"]
}

3.在自己的策略中import TRADINHOUR from vnpy.app.user_tools.trading_hour import TRADINGHOUR 同时import re
def --init--():加入
self.tradingtime = TRADINGHOUR()
self.symbol = "".join(re.findall(r"\D+",self.get_data()["vt_symbol"].split(".")[0])).upper() #获取标的代码
self.start_time,self.end_time = self.tradingtime.get_trading_time(self.symbol)
这些完成后,就可以在你的on_bar函数里面调用self.tradingtime.trading_period(bar)获取标的日内交易时段了
例如:def on_bar(self, bar: BarData):
### ###
if self.tradingtime.trading_period(bar): #获取当下交易标的日内交易时段,平仓时间设置为收盘前3min

        if self.pos == 0:
                 ###****
        if self.pos > 0: 
               ### ****
        if self.pos < 0:
                ###*****

else: #收盘平仓

            if self.pos > 0:
                   self.sell(bar.close_price *0.99, abs(self.pos))
            elif self.pos < 0:
                   self.cover(bar.close_price*1.01, abs(self.pos))


海龟策略深入研究-策略回测系列-17 基于遗传算法的信号优化

海龟策略的信号来源于唐奇安通道突破。简单的说就是若突破上轨则做多,突破下轨则做空。我们可以对信号进行改良,如换成布林带通道,金肯特纳通道等等,并且增加过滤条件和离场条件。

但是呢?新的问题又来了:若用了新的指标,需要通过不断调试来得到“最优”参数,这样会耗费大量的时间。

那么,有没有办法在尽量少得时间内,尽可能得到全局最优解或者次优解呢?

答案就是遗传算法啦!

 
 
 

遗传算法原理


具体原理详见:
遗传算法原理简介

一文读懂遗传算法工作原理(附Python实现)

 

遗传算法要做的事情并不复杂:

  1. 随机生成一大推策略参数(称之为族群,族群内的个体对应某一组策略参数)
  2. 族群内个体间进行3类活动:个体间两两交叉互换;个体某个参数发生变异;个体繁殖(即直接复制参数)
  3. 形成子代
  4. 通过目标函数(如最大化夏普比率,最大化总盈亏等)对母代族群和子代族群进行评分
  5. 通过特点筛选标准(如NSGA-Ⅱ)从母代和子代中筛选个体,形成第二代族群。(类似进化论中的“自然选择”)
  6. 新的族群在特定的评分标准和筛选标准中不断迭代(即重复1-5步骤),得到最优解/次优解。

 
 
 

遗传算法代码示例


1.随机生成待优化的策略参数

def parameter_generate():
    '''
    根据设置的起始值,终止值和步进,随机生成待优化的策略参数
    '''
    parameter_list = []
    p1 = random.randrange(4,50,2)      #入场窗口
    p2 = random.randrange(4,50,2)      #出场窗口
    p3 = random.randrange(4,50,2)      #基于ATR窗口止损窗
    p4 = random.randrange(18,40,2)     #基于ATR的动态调仓 

    parameter_list.append(p1)
    parameter_list.append(p2)
    parameter_list.append(p3)
    parameter_list.append(p4)

    return parameter_list

 

  1. 设置目标优化函数(收益回撤比和夏普比率)
def object_func(strategy_avg):
    """
    本函数为优化目标函数,根据随机生成的策略参数,运行回测后自动返回2个结果指标:收益回撤比和夏普比率
    """
    # 创建回测引擎对象
    engine = BacktestingEngine()
    # 设置回测使用的数据                       
    engine.setBacktestingMode(engine.BAR_MODE)      # 设置引擎的回测模式为K线
    engine.setDatabase("VnTrader_Daily_Db", 'XBTHOUR')  # 设置使用的历史数据库
    engine.setStartDate('20170401')                 # 设置回测用的数据起始日期
    engine.setEndDate('20181230')                   # 设置回测用的数据起始日期

    # 配置回测引擎参数
    engine.setSlippage(0.5)                        
    engine.setRate(0.2/100)                     
    engine.setSize(10)                            
    engine.setPriceTick(0.5)                      
    engine.setCapital(1000000) 

    setting = {'entryWindow': strategy_avg[0],       #布林带窗口
               'exitWindow': strategy_avg[1],        #布林带通道阈值
               'atrWindow': strategy_avg[2],         #CCI窗口
               'artWindowUnit': strategy_avg[3],}    #ATR窗口               

    #加载策略          
    engine.initStrategy(TurtleTradingStrategy, setting)    
    # 运行回测,返回指定的结果指标   
    engine.runBacktesting()          # 运行回测
    #逐日回测   
    engine.calculateDailyResult()
    backresult = engine.calculateDailyStatistics()[1] 

    returnDrawdownRatio = round(backresult['returnDrawdownRatio'],2)  #收益回撤比
    sharpeRatio= round(backresult['sharpeRatio'],2)                   #夏普比率
    return returnDrawdownRatio , sharpeRatio

 

3.运行基于Deap库的遗传算法(具体步骤看代码中文注释)

#设置优化方向:最大化收益回撤比,最大化夏普比率
creator.create("FitnessMulti", base.Fitness, weights=(1.0, 1.0)) # 1.0 求最大值;-1.0 求最小值
creator.create("Individual", list, fitness=creator.FitnessMulti)

def optimize():
    """"""   
    toolbox = base.Toolbox()  #Toolbox是deap库内置的工具箱,里面包含遗传算法中所用到的各种函数

    # 初始化     
    toolbox.register("individual", tools.initIterate, creator.Individual,parameter_generate) # 注册个体:随机生成的策略参数parameter_generate()                                          
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)               #注册种群:个体形成种群                                    
    toolbox.register("mate", tools.cxTwoPoint)                                               #注册交叉:两点交叉  
    toolbox.register("mutate", tools.mutUniformInt,low = 4,up = 40,indpb=0.6)                #注册变异:随机生成一定区间内的整数
    toolbox.register("evaluate", object_func)                                                #注册评估:优化目标函数object_func()    
    toolbox.register("select", tools.selNSGA2)                                               #注册选择:NSGA-II(带精英策略的非支配排序的遗传算法)


    #遗传算法参数设置
    MU = 40                                  #设置每一代选择的个体数
    LAMBDA = 160                             #设置每一代产生的子女数
    pop = toolbox.population(400)            #设置族群里面的个体数量
    CXPB, MUTPB, NGEN = 0.5, 0.35, 40        #分别为种群内部个体的交叉概率、变异概率、产生种群代数
    hof = tools.ParetoFront()                #解的集合:帕累托前沿(非占优最优集)

    #解的集合的描述统计信息
    #集合内平均值,标准差,最小值,最大值可以体现集合的收敛程度
    #收敛程度低可以增加算法的迭代次数
    stats = tools.Statistics(lambda ind: ind.fitness.values)
    np.set_printoptions(suppress=True)            #对numpy默认输出的科学计数法转换
    stats.register("mean", np.mean, axis=0)       #统计目标优化函数结果的平均值
    stats.register("std", np.std, axis=0)         #统计目标优化函数结果的标准差
    stats.register("min", np.min, axis=0)         #统计目标优化函数结果的最小值
    stats.register("max", np.max, axis=0)         #统计目标优化函数结果的最大值

    #运行算法
    algorithms.eaMuPlusLambda(pop, toolbox, MU, LAMBDA, CXPB, MUTPB, NGEN, stats,
                              halloffame=hof)     #esMuPlusLambda是一种基于(μ+λ)选择策略的多目标优化分段遗传算法

    return pop

 
 
 

遗传算法效果


夏普比率1.9,总收益率1958%,最大百分比回撤37%,收益回撤比达53。

enter image description here
其解集收敛程度如下:
enter image description here

在得到一个好的曲线后,还要检查一下这些参数是否符合市场逻辑,尽量去避免过拟合的情况。下面举个反例:

这里使用金肯特纳通道+基于固定百分比移动止损策略。不管从曲线的形态和收敛程度来看都是挺正常的

enter image description here
enter image description here
但是在这种优化后参数中我们观察到trailingPercent=18%,这意味着价格从最高点回落18个点才平仓离场。在正常情况,这会带来非常糟糕的盈亏比。
enter image description here

 
 
 

总结

遗传算法本质上是一种加快策略研究的技术,相对于暴力穷举,它可以大大节省电脑运算时间。我们可以使用它,但不能过度依赖它,因为有可能输出的仅仅是一些很巧合的参数。所以,针对这些参数,需要做更加细致的回测。

毕竟,在策略研究中,细心与耐心也是非常重要的。



修改数据保存方式 —— 保存至数据库中指定的数据表/集合

在使用VNPY时,只有两张表,一个是bar数据,一个是tick,实际使用中,这个数据库文件相当大,检索很费时。并看到有人也遇到了同样的困惑,建议vnpy-2.0.9 用关系型数据库

于是,我尝试基于vnpy 2.0.8,修改数据为 MongoDB,并将数据保存至指定的 Collection,同时保证当未指定 Collection 时,数据仍保存在原有的数据表中。适用于 tick数据 和 bar 数据。分享出来,如有错误,希望指正。本文也同时发于我的 CSDN 博客

更改数据库为 MongoDB

注意在MongoDB中需要创建新数据库,如“vnpytest”,然后在全局配置对话框中,修改相关配置(或直接修改vnpy运行目录 .vntrader 下的 vt_setting.json 文件):

"database.driver": "mongodb",
"database.database": "vnpytest",
"database.host": "localhost",
"database.port": 27017,
"database.user": "",
"database.password": "",
"database.authentication_source": ""

注意输入上述内容到配置对话框中时,请忽略引号。修改完毕保存后,请重新启动VN Trader,检查相关配置是否已经修改成功。

保存数据至指定的 mongodb collection

数据来源可以是 csv 文件,也可以从开源数据如 tushare (https://tushare.pro/register?reg=347489) 获取,等等,这里不再给出。

  1. 新建py文件,写入主函数
from vnpy.trader.database import database_manager
from vnpy.trader.object import BarData
from vnpy.trader.constant import Interval, Exchange


def ts2bar(df, collection_name=None):
    bars = []
    for i in range(df.shape[0]):
        bar = BarData(
            symbol=df.ts_code[i].split('.')[0],
            exchange=exc,
            datetime=df.trade_date[i],
            interval=Interval.DAILY,
            volume=df.vol[i],
            open_price=df.open[i],
            high_price=df.high[i],
            low_price=df.low[i],
            close_price=df.close[i],
            gateway_name='DB',
        )
        bars.append(bar)
    collection_name = collection_name
    print('Saving data to database ...')
    database_manager.save_bar_data(bars, collection_name)
  1. 修改 vnpy/trader/database/database.py 中的 save_bar_datasave_tick_data
@abstractmethod
def save_bar_data(
    self,
    datas: Sequence["BarData"],
    collection_name: str = None,
):
    pass

@abstractmethod
def save_tick_data(
    self,
    datas: Sequence["TickData"],
    collection_name: str = None,
):
    pass
  1. 修改 vnpy/trader/database/database_mongo.py 中的 save_bar_datasave_tick_data
def save_bar_data(self, datas: Sequence[BarData], collection_name: str = None):
        for d in datas:
            ...

            if collection_name is None:
                (
                    DbBarData.objects(
                        symbol=d.symbol, interval=d.interval.value, datetime=d.datetime
                    ).update_one(upsert=True, **updates)
                )
            else:
                with switch_collection(DbBarData, collection_name):
                    (
                        DbBarData.objects(
                            symbol=d.symbol, interval=d.interval.value, datetime=d.datetime
                        ).update_one(upsert=True, **updates)
                    )

    def save_tick_data(self, datas: Sequence[TickData], collection_name: str = None):
        for d in datas:
            ...

            if collection_name is None:
                (
                    DbTickData.objects(
                        symbol=d.symbol, exchange=d.exchange.value, datetime=d.datetime
                    ).update_one(upsert=True, **updates)
                )
            else:
                with switch_collection(DbTickData, collection_name):
                    (
                        DbTickData.objects(
                            symbol=d.symbol, exchange=d.exchange.value, datetime=d.datetime
                        ).update_one(upsert=True, **updates)
                    )

对于 MySQL 数据库,修改方法应该是类似的。



为K线图表添砖加瓦——MACD

看完了陈老师的线上公开课,化了2天时间终于把MACD幅图曲线给添加上了。
MACD曲线和RSI,SMA之类的不同之处在于它的y方向显示范围是可变的,需要根据K线显示范围的变化及时做出调整,有执行效率问题。
本人采用了字典保存了已经计算的y方向显示范围计算结果,避免了重复计算,执行效率还是相当流畅的。当然会需要一定的存储开销,但
是不大,而且也是值得开销的。代码如下:

from datetime import datetime
from typing import List, Tuple, Dict

import numpy as np
import pyqtgraph as pg
import talib
import copy

from vnpy.trader.ui import create_qapp, QtCore, QtGui, QtWidgets
from vnpy.trader.database import database_manager
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData

from vnpy.chart import ChartWidget, VolumeItem, CandleItem
from vnpy.chart.item import ChartItem
from vnpy.chart.manager import BarManager
from vnpy.chart.base import NORMAL_FONT


class LineItem(CandleItem):
    """"""

    def __init__(self, manager: BarManager):
        """"""
        super().__init__(manager)

        self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """"""
        last_bar = self._manager.get_bar(ix - 1)

        # Create objects
        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # Set painter color
        painter.setPen(self.white_pen)

        # Draw Line
        end_point = QtCore.QPointF(ix, bar.close_price)

        if last_bar:
            start_point = QtCore.QPointF(ix - 1, last_bar.close_price)
        else:
            start_point = end_point

        painter.drawLine(start_point, end_point)

        # Finish
        painter.end()
        return picture


class SmaItem(CandleItem):
    """"""

    def __init__(self, manager: BarManager):
        """"""
        super().__init__(manager)

        self.blue_pen: QtGui.QPen = pg.mkPen(color=(100, 100, 255), width=2)

        self.sma_window = 10
        self.sma_data: Dict[int, float] = {}

    def get_sma_value(self, ix: int) -> float:
        """"""
        if ix < 0:
            return 0

        # When initialize, calculate all rsi value
        if not self.sma_data:
            bars = self._manager.get_all_bars()
            close_data = [bar.close_price for bar in bars]
            sma_array = talib.SMA(np.array(close_data), self.sma_window)

            for n, value in enumerate(sma_array):
                self.sma_data[n] = value

        # Return if already calcualted
        if ix in self.sma_data:
            return self.sma_data[ix]

        # Else calculate new value
        close_data = []
        for n in range(ix - self.sma_window, ix + 1):
            bar = self._manager.get_bar(n)
            close_data.append(bar.close_price)

        sma_array = talib.SMA(np.array(close_data), self.sma_window)
        sma_value = sma_array[-1]
        self.sma_data[ix] = sma_value

        return sma_value

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """"""
        sma_value = self.get_sma_value(ix)
        last_sma_value = self.get_sma_value(ix - 1)

        # Create objects
        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # Set painter color
        painter.setPen(self.blue_pen)

        # Draw Line
        start_point = QtCore.QPointF(ix-1, last_sma_value)
        end_point = QtCore.QPointF(ix, sma_value)
        painter.drawLine(start_point, end_point)

        # Finish
        painter.end()
        return picture

    def get_info_text(self, ix: int) -> str:
        """"""
        if ix in self.sma_data:
            sma_value = self.sma_data[ix]
            text = f"SMA {sma_value:.1f}"
        else:
            text = "SMA -"

        return text


class RsiItem(ChartItem):
    """"""

    def __init__(self, manager: BarManager):
        """"""
        super().__init__(manager)

        self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)
        self.yellow_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 0), width=2)

        self.rsi_window = 14
        self.rsi_data: Dict[int, float] = {}

    def get_rsi_value(self, ix: int) -> float:
        """"""
        if ix < 0:
            return 50

        # When initialize, calculate all rsi value
        if not self.rsi_data:
            bars = self._manager.get_all_bars()
            close_data = [bar.close_price for bar in bars]
            rsi_array = talib.RSI(np.array(close_data), self.rsi_window)

            for n, value in enumerate(rsi_array):
                self.rsi_data[n] = value

        # Return if already calcualted
        if ix in self.rsi_data:
            return self.rsi_data[ix]

        # Else calculate new value
        close_data = []
        for n in range(ix - self.rsi_window, ix + 1):
            bar = self._manager.get_bar(n)
            close_data.append(bar.close_price)

        rsi_array = talib.RSI(np.array(close_data), self.rsi_window)
        rsi_value = rsi_array[-1]
        self.rsi_data[ix] = rsi_value

        return rsi_value

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """"""
        rsi_value = self.get_rsi_value(ix)
        last_rsi_value = self.get_rsi_value(ix - 1)

        # Create objects
        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # Draw RSI line
        painter.setPen(self.yellow_pen)

        if np.isnan(last_rsi_value) or np.isnan(rsi_value):
            # print(ix - 1, last_rsi_value,ix, rsi_value,)
            pass
        else:
            end_point = QtCore.QPointF(ix, rsi_value)
            start_point = QtCore.QPointF(ix - 1, last_rsi_value)
            painter.drawLine(start_point, end_point)

        # Draw oversold/overbought line
        painter.setPen(self.white_pen)

        painter.drawLine(
            QtCore.QPointF(ix, 70),
            QtCore.QPointF(ix - 1, 70),
        )

        painter.drawLine(
            QtCore.QPointF(ix, 30),
            QtCore.QPointF(ix - 1, 30),
        )

        # Finish
        painter.end()
        return picture

    def boundingRect(self) -> QtCore.QRectF:
        """"""
        # min_price, max_price = self._manager.get_price_range()
        rect = QtCore.QRectF(
            0,
            0,
            len(self._bar_picutures),
            100
        )
        return rect

    def get_y_range( self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
        """  """
        return 0, 100

    def get_info_text(self, ix: int) -> str:
        """"""
        if ix in self.rsi_data:
            rsi_value = self.rsi_data[ix]
            text = f"RSI {rsi_value:.1f}"
            # print(text)
        else:
            text = "RSI -"

        return text


def to_int(value: float) -> int:
    """"""
    return int(round(value, 0))

""" 将y方向的显示范围扩大到1.1 """
def adjust_range(in_range:Tuple[float, float])->Tuple[float, float]:
    ret_range:Tuple[float, float]
    diff = abs(in_range[0] - in_range[1])
    ret_range = (in_range[0]-diff*0.05,in_range[1]+diff*0.05)
    return ret_range

class MacdItem(ChartItem):
    """"""
    _values_ranges: Dict[Tuple[int, int], Tuple[float, float]] = {}

    last_range:Tuple[int, int] = (-1,-1)    # 最新显示K线索引范围

    def __init__(self, manager: BarManager):
        """"""
        super().__init__(manager)

        self.white_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 255), width=1)
        self.yellow_pen: QtGui.QPen = pg.mkPen(color=(255, 255, 0), width=1)
        self.red_pen: QtGui.QPen = pg.mkPen(color=(255, 0, 0), width=1)
        self.green_pen: QtGui.QPen = pg.mkPen(color=(0, 255, 0), width=1)

        self.short_window = 12
        self.long_window = 26
        self.M = 9

        self.macd_data: Dict[int, Tuple[float,float,float]] = {}

    def get_macd_value(self, ix: int) -> Tuple[float,float,float]:
        """"""
        if ix < 0:
            return (0.0,0.0,0.0)

        # When initialize, calculate all macd value
        if not self.macd_data:
            bars = self._manager.get_all_bars()
            close_data = [bar.close_price for bar in bars]

            diffs,deas,macds = talib.MACD(np.array(close_data), 
                                    fastperiod=self.short_window, 
                                    slowperiod=self.long_window, 
                                    signalperiod=self.M)

            for n in range(0,len(diffs)):
                self.macd_data[n] = (diffs[n],deas[n],macds[n])

        # Return if already calcualted
        if ix in self.macd_data:
            return self.macd_data[ix]

        # Else calculate new value
        close_data = []
        for n in range(ix-self.long_window-self.M+1, ix + 1):
            bar = self._manager.get_bar(n)
            close_data.append(bar.close_price)

        diffs,deas,macds = talib.MACD(np.array(close_data), 
                                            fastperiod=self.short_window, 
                                            slowperiod=self.long_window, 
                                            signalperiod=self.M) 
        diff,dea,macd = diffs[-1],deas[-1],macds[-1]
        self.macd_data[ix] = (diff,dea,macd)

        return (diff,dea,macd)

    def _draw_bar_picture(self, ix: int, bar: BarData) -> QtGui.QPicture:
        """"""
        macd_value = self.get_macd_value(ix)
        last_macd_value = self.get_macd_value(ix - 1)

        # # Create objects
        picture = QtGui.QPicture()
        painter = QtGui.QPainter(picture)

        # # Draw macd lines
        if np.isnan(macd_value[0]) or np.isnan(last_macd_value[0]):
            # print("略过macd lines0")
            pass
        else:
            end_point0 = QtCore.QPointF(ix, macd_value[0])
            start_point0 = QtCore.QPointF(ix - 1, last_macd_value[0])
            painter.setPen(self.white_pen)
            painter.drawLine(start_point0, end_point0)

        if np.isnan(macd_value[1]) or np.isnan(last_macd_value[1]):
            # print("略过macd lines1")
            pass
        else:
            end_point1 = QtCore.QPointF(ix, macd_value[1])
            start_point1 = QtCore.QPointF(ix - 1, last_macd_value[1])
            painter.setPen(self.yellow_pen)
            painter.drawLine(start_point1, end_point1)

        if not np.isnan(macd_value[2]):
            if (macd_value[2]>0):
                painter.setPen(self.red_pen)
                painter.setBrush(pg.mkBrush(255,0,0))
            else:
                painter.setPen(self.green_pen)
                painter.setBrush(pg.mkBrush(0,255,0))
            painter.drawRect(QtCore.QRectF(ix-0.3,0,0.6,macd_value[2]))
        else:
            # print("略过macd lines2")
            pass

        painter.end()
        return picture

    def boundingRect(self) -> QtCore.QRectF:
        """"""
        min_y, max_y = self.get_y_range()
        rect = QtCore.QRectF(
            0,
            min_y,
            len(self._bar_picutures),
            max_y
        )
        return rect

    def get_y_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]:
        #   获得3个指标在y轴方向的范围   
        #   hxxjava 修改,2020-6-29
        #   当显示范围改变时,min_ix,max_ix的值不为None,当显示范围不变时,min_ix,max_ix的值不为None,

        offset = max(self.short_window,self.long_window) + self.M - 1

        if not self.macd_data or len(self.macd_data) < offset:
            return 0.0, 1.0

        # print("len of range dict:",len(self._values_ranges),",macd_data:",len(self.macd_data),(min_ix,max_ix))

        if min_ix != None:          # 调整最小K线索引
            min_ix = max(min_ix,offset)

        if max_ix != None:          # 调整最大K线索引
            max_ix = min(max_ix, len(self.macd_data)-1)

        last_range = (min_ix,max_ix)    # 请求的最新范围   

        if last_range == (None,None):   # 当显示范围不变时
            if self.last_range in self._values_ranges:  
                # 如果y方向范围已经保存
                # 读取y方向范围
                result = self._values_ranges[self.last_range]
                # print("1:",self.last_range,result)
                return adjust_range(result)
            else:
                # 如果y方向范围没有保存
                # 从macd_data重新计算y方向范围
                min_ix,max_ix = 0,len(self.macd_data)-1

                macd_list = list(self.macd_data.values())[min_ix:max_ix + 1]
                ndarray = np.array(macd_list)           
                max_price = np.nanmax(ndarray)
                min_price = np.nanmin(ndarray)

                # 保存y方向范围,同时返回结果
                result = (min_price, max_price)
                self.last_range = (min_ix,max_ix)
                self._values_ranges[self.last_range] = result
                # print("2:",self.last_range,result)
                return adjust_range(result)

        """ 以下为显示范围变化时 """

        if last_range in self._values_ranges:
            # 该范围已经保存过y方向范围
            # 取得y方向范围,返回结果
            result = self._values_ranges[last_range]
            # print("3:",last_range,result)
            return adjust_range(result)

        # 该范围没有保存过y方向范围
        # 从macd_data重新计算y方向范围
        macd_list = list(self.macd_data.values())[min_ix:max_ix + 1]
        ndarray = np.array(macd_list) 
        max_price = np.nanmax(ndarray)
        min_price = np.nanmin(ndarray)

        # 取得y方向范围,返回结果
        result = (min_price, max_price)
        self.last_range = last_range
        self._values_ranges[self.last_range] = result
        # print("4:",self.last_range,result)
        return adjust_range(result)


    def get_info_text(self, ix: int) -> str:
        # """"""
        if ix in self.macd_data:
            diff,dea,macd = self.macd_data[ix]
            words = [
                f"diff {diff:.3f}"," ",
                f"dea {dea:.3f}"," ",
                f"macd {macd:.3f}"
                ]
            text = "\n".join(words)
        else:
            text = "diff - \ndea - \nmacd -"

        return text



class NewChartWidget(ChartWidget):
    """"""
    MIN_BAR_COUNT = 100

    def __init__(self, parent: QtWidgets.QWidget = None):
        """"""
        super().__init__(parent)

        self.last_price_line: pg.InfiniteLine = None

    def add_last_price_line(self):
        """"""
        plot = list(self._plots.values())[0]
        color = (255, 255, 255)

        self.last_price_line = pg.InfiniteLine(
            angle=0,
            movable=False,
            label="{value:.1f}",
            pen=pg.mkPen(color, width=1),
            labelOpts={
                "color": color,
                "position": 1,
                "anchors": [(1, 1), (1, 1)]
            }
        )
        self.last_price_line.label.setFont(NORMAL_FONT)
        plot.addItem(self.last_price_line)

    def update_history(self, history: List[BarData]) -> None:
        """
        Update a list of bar data.
        """
        self._manager.update_history(history)

        for item in self._items.values():
            item.update_history(history)

        self._update_plot_limits()

        self.move_to_right()

        self.update_last_price_line(history[-1])

    def update_bar(self, bar: BarData) -> None:
        """
        Update single bar data.
        """
        self._manager.update_bar(bar)

        for item in self._items.values():
            item.update_bar(bar)

        self._update_plot_limits()

        if self._right_ix >= (self._manager.get_count() - self._bar_count / 2):
            self.move_to_right()

        self.update_last_price_line(bar)

    def update_last_price_line(self, bar: BarData) -> None:
        """"""
        if self.last_price_line:
            self.last_price_line.setValue(bar.close_price)


if __name__ == "__main__":
    app = create_qapp()

    # bars = database_manager.load_bar_data(
    #     "IF888",
    #     Exchange.CFFEX,
    #     interval=Interval.MINUTE,
    #     start=datetime(2019, 7, 1),
    #     end=datetime(2019, 7, 17)
    # )

    symbol = "rb2010"
    exchange = Exchange.SHFE
    interval=Interval.MINUTE
    start=datetime(2020, 6, 1)
    end=datetime(2020, 6, 30)    

    dynamic = False  # 是否动态演示
    n = 200          # 缓冲K线根数


    bars = database_manager.load_bar_data(
        symbol=symbol,
        exchange=exchange,
        interval=interval,
        start=start,
        end=end
    )

    widget = NewChartWidget()
    widget.setWindowTitle(f"K线图表——{symbol}.{exchange.value},{interval},{start}-{end}")
    widget.add_plot("candle", hide_x_axis=True)
    widget.add_plot("volume", maximum_height=150)
    widget.add_plot("rsi", maximum_height=150)
    widget.add_plot("macd", maximum_height=150)
    widget.add_item(CandleItem, "candle", "candle")
    widget.add_item(VolumeItem, "volume", "volume")

    widget.add_item(LineItem, "line", "candle")
    widget.add_item(SmaItem, "sma", "candle")
    widget.add_item(RsiItem, "rsi", "rsi")
    widget.add_item(MacdItem, "macd", "macd")
    widget.add_last_price_line()
    widget.add_cursor()

    if dynamic:
        history = bars[:n]      # 先取得最早的n根bar作为历史
        new_data = bars[n:]     # 其它留着演示
    else:
        history = bars[-n:]     # 先取得最新的n根bar作为历史
        new_data = []           # 演示的为空

    widget.update_history(history)

    def update_bar():
        if new_data:
            bar = new_data.pop(0)
            widget.update_bar(bar)

    timer = QtCore.QTimer()
    timer.timeout.connect(update_bar)
    if dynamic:
        timer.start(100)

    widget.show()
    app.exec_()


谈谈对app\cta_backtester\widget.py中CandleChartDialog的看法

1. 先看看CandleChartDialog的代码

class CandleChartDialog(QtWidgets.QDialog):
    """
    """

    def __init__(self):
        """"""
        super().__init__()

        self.dt_ix_map = {}
        self.updated = False
        self.init_ui()

    def init_ui(self):
        """"""
        self.setWindowTitle("回测K线图表")
        self.resize(1400, 800)

        # Create chart widget
        self.chart = ChartWidget()
        self.chart.add_plot("candle", hide_x_axis=True)
        self.chart.add_plot("volume", maximum_height=200)
        self.chart.add_item(CandleItem, "candle", "candle")
        self.chart.add_item(VolumeItem, "volume", "volume")
        self.chart.add_cursor()

        # Add scatter item for showing tradings
        self.trade_scatter = pg.ScatterPlotItem()
        candle_plot = self.chart.get_plot("candle")
        candle_plot.addItem(self.trade_scatter)

        # Set layout
        vbox = QtWidgets.QVBoxLayout()
        vbox.addWidget(self.chart)
        self.setLayout(vbox)

    def update_history(self, history: list):
        """"""
        self.updated = True
        self.chart.update_history(history)

        for ix, bar in enumerate(history):
            self.dt_ix_map[bar.datetime] = ix

    def update_trades(self, trades: list):
        """"""
        trade_data = []

        for trade in trades:
            ix = self.dt_ix_map[trade.datetime]

            scatter = {
                "pos": (ix, trade.price),
                "data": 1,
                "size": 14,
                "pen": pg.mkPen((255, 255, 255))
            }

            if trade.direction == Direction.LONG:
                scatter_symbol = "t1"   # Up arrow
            else:
                scatter_symbol = "t"    # Down arrow

            if trade.offset == Offset.OPEN:
                scatter_brush = pg.mkBrush((255, 255, 0))   # Yellow
            else:
                scatter_brush = pg.mkBrush((0, 0, 255))     # Blue

            scatter["symbol"] = scatter_symbol
            scatter["brush"] = scatter_brush

            trade_data.append(scatter)

        self.trade_scatter.setData(trade_data)

    def clear_data(self):
        """"""
        self.updated = False
        self.chart.clear_all()

        self.dt_ix_map.clear()
        self.trade_scatter.clear()

    def is_updated(self):
        """"""
        return self.updated

2. 再看看self.trade_scatter成员

2.1 一个特别的成员self.trade_scatter

它和其他的绘图部件不同,其他的都是ChartItem类型继承得到,只有它例外,它是一个ScatterPlotItem。

self.trade_scatter = pg.ScatterPlotItem()

2.2 问题:CandleChartDialog不可以在非回测环境下使用

看看update_trades() 的代码中

for trade in trades:
     ix = self.dt_ix_map[trade.datetime]   # 查找一个成交单(trade)是属于哪个BarData的索引

这里能够不出错,完全是因为在回测中,人为固定地把发出交易信号的那个bar的开始时间datetime赋值给了trade.datetime!
让我们来看看app\cta_strategy\backtesting.py中的class BacktestingEngine,它在bar模式的时候,是这样生成成交单的:

            trade = TradeData(
                symbol=order.symbol,
                exchange=order.exchange,
                orderid=order.orderid,
                tradeid=str(self.trade_count),
                direction=order.direction,
                offset=order.offset,
                price=trade_price,
                volume=order.volume,
                datetime=self.datetime,
                gateway_name=self.gateway_name,
            )

其中trade中的datetime=self.datetime,而self.datetime是这样赋值的:

    def run_backtesting(self):
        """"""
        if self.mode == BacktestingMode.BAR:
            func = self.new_bar
        else:
            func = self.new_tick

        self.strategy.on_init()

        # Use the first [days] of history data for initializing strategy
        day_count = 1
        ix = 0

        for ix, data in enumerate(self.history_data):
            if self.datetime and data.datetime.day != self.datetime.day:
                day_count += 1
                if day_count >= self.days:
                    break

            self.datetime = data.datetime      # 它是用self.history_data中代表bar的data.datetime赋值的!
            ... ...

可是我们知道trade.datetime是不可能总是恰好等于bar的开始时间datetime的,成交时间可能是一个bar形成期间的任何时间。
回到CandleChartDialog,由上面可知self.dt_ix_map的维护是update_history()维护的,它是由bar的开始时间datetime和其顺序ix构成的一个字典。

2.3 如果trade是实际的成交单,使用CandleChartDialog是一定会出错

原因是trade.datetime在self.dt_ix_map字典的键值中大概率是不存在的,因此ix = self.dt_ix_map[trade.datetime]语句会出错!而ScatterPlotItem是通过"pos": (ix, trade.price)来确定代表买卖的上下三角形来绘图的,因此CandleChartDialog是不可以简单地在显示实盘成交单的地方引用的。

3. 怎么修改既可显示回测成交单,又可以显示实盘成交单办?

思路:

3.1 增加如下成员:

     self.trades:Dict[int,Dict[str,TradeData]] = {}  # 其键值为成交发生bar的索引,内容为成交单字典

3.2 增加add_trade():

根据trade.datetime字段在self.dt_ix_map中所在位置的两个相邻时间dt0和dt1
满足:

def add_trade(self,trade:TradeData):
  # 这里使用reverse=True,是考虑到实盘成交往往发生在最新的bar里,可以加快搜索速度
  od = OrderedDict(sorted(self.dt_ix_map.items(),key = lambda t:t[0],reverse=True))
  idx = 0
  for dt,ix in od.items():
    if dt <= trade.datetime:
        idx = idx
        break
  # 注意:一个bar期间可能发生多个成交单
  if idx not in self.trades:
      self.trades[idx] = {trade.tradeid: trade}
  else:
      self.trades[idx][trade.tradeid] = trade

3.3 update_trades()这么修改就OK了。

    def update_trades(self, trades: list):
        """"""
        for trade in trades:
             # 寻找每个成交单对应的bar的索引,并且保证为字典
             self.add_trade(trade)

        trade_data = []
        for ix in self.trades:
            for tradeid,trade in self.trades[ix].items():
                scatter = {
                    "pos": (ix, trade.price),
                    "data": 1,
                    "size": 14,
                    "pen": pg.mkPen((255, 255, 255))
                }

                if trade.direction == Direction.LONG:
                    scatter_symbol = "t1"   # Up arrow
                else:
                    scatter_symbol = "t"    # Down arrow

                if trade.offset == Offset.OPEN:
                    scatter_brush = pg.mkBrush((255, 255, 0))   # Yellow
                else:
                    scatter_brush = pg.mkBrush((0, 0, 255))     # Blue

                scatter["symbol"] = scatter_symbol
                scatter["brush"] = scatter_brush

                trade_data.append(scatter)

        self.trade_scatter.setData(trade_data)

4. 不一样的成交单显示解决办法

参见:典型绘图部件及使用方法


统计

主题
4147
帖子
15185
已注册用户
15711
最新用户
在线用户
263
在线来宾用户
278
© 2015-2019 上海韦纳软件科技有限公司
备案服务号:沪ICP备18006526号-3