vn.py Quant Community
By Traders, For Traders.
Super Moderator

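Published on the vn.py community WeChat official account [vnpy-community]

Original author: 庸木 | Published: 2020-02-25

The previous article, "vn.py Community Selection 18 - Good News for Long-Time Users: MongoDB Refactored into Separate Collections!", covered only the "write" half of data I/O: writing historical data into the refactored database.

This article fills in the missing "read" half and explains how to read data from the refactored database for strategy backtesting and optimization. It has two parts:

  • Changes to the read functions in the database adapter layer
  • Changes to how the CTA backtesting engine calls those read functions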

 

Database adapter layer changes

database.py

Locate the vn.py source directory. With VN Studio it should be C:\vnstudio\Lib\site-packages\vnpy. Go into vnpy\trader\database and open database.py.

Modify load_bar_data and load_tick_data by adding an optional parameter, collection_name, that names the specific collection to read from:
 

@abstractmethod
def load_bar_data(
    self,
    symbol: str,
    exchange: "Exchange",
    interval: "Interval",
    start: datetime,
    end: datetime,
    collection_name: str = None
    ) -> Sequence["BarData"]:

    pass

@abstractmethod
def load_tick_data(
    self,
    symbol: str,
    exchange: "Exchange",
    start: datetime,
    end: datetime,
    collection_name: str = None
    ) -> Sequence["TickData"]:

    pass

 

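database_mongo.py

Next, open database_mongo.py in the same directory and modify its load_bar_data and load_tick_data implementations in the same way.

The read logic after the change is:

  • If collection_name is not given, read from db_bar_data or db_tick_data;
  • If collection_name is given, call switch_collection and read from that collection.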
     
# Requires at the top of the file:
# from mongoengine.context_managers import switch_collection

def load_bar_data(
    self,
    symbol: str,
    exchange: Exchange,
    interval: Interval,
    start: datetime,
    end: datetime,
    collection_name: str = None,
) -> Sequence[BarData]:
    if collection_name is None:
        # No collection given: query the default db_bar_data collection
        s = DbBarData.objects(
            symbol=symbol,
            exchange=exchange.value,
            interval=interval.value,
            datetime__gte=start,
            datetime__lte=end,
        )
    else:
        # Build the query while DbBarData points at the named collection
        with switch_collection(DbBarData, collection_name):
            s = DbBarData.objects(
                symbol=symbol,
                exchange=exchange.value,
                interval=interval.value,
                datetime__gte=start,
                datetime__lte=end,
            )
    data = [db_bar.to_bar() for db_bar in s]
    return data

def load_tick_data(
    self,
    symbol: str,
    exchange: Exchange,
    start: datetime,
    end: datetime,
    collection_name: str = None,
) -> Sequence[TickData]:
    if collection_name is None:
        # No collection given: query the default db_tick_data collection
        s = DbTickData.objects(
            symbol=symbol,
            exchange=exchange.value,
            datetime__gte=start,
            datetime__lte=end,
        )
    else:
        # Build the query while DbTickData points at the named collection
        with switch_collection(DbTickData, collection_name):
            s = DbTickData.objects(
                symbol=symbol,
                exchange=exchange.value,
                datetime__gte=start,
                datetime__lte=end,
            )
    data = [db_tick.to_tick() for db_tick in s]
    return data
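
The dispatch rule above (default collection unless a name is given) can be tried out without a running MongoDB. The sketch below is only an in-memory stand-in, not mongoengine code: FakeStore, DEFAULT_COLLECTION and the row layout are hypothetical names made up for illustration.

```python
from datetime import datetime
from typing import Dict, List, Optional, Tuple

# Hypothetical in-memory stand-in for MongoDB: collection name -> rows.
# Each row is (symbol, datetime, close); real code would use DbBarData.
Row = Tuple[str, datetime, float]
DEFAULT_COLLECTION = "db_bar_data"


class FakeStore:
    def __init__(self) -> None:
        self.collections: Dict[str, List[Row]] = {}

    def save(self, row: Row, collection_name: Optional[str] = None) -> None:
        name = collection_name or DEFAULT_COLLECTION
        self.collections.setdefault(name, []).append(row)

    def load_bar_data(
        self,
        symbol: str,
        start: datetime,
        end: datetime,
        collection_name: Optional[str] = None,
    ) -> List[Row]:
        # No name given: read the default collection.
        # Name given: read only that collection.
        name = DEFAULT_COLLECTION if collection_name is None else collection_name
        rows = self.collections.get(name, [])
        return [r for r in rows if r[0] == symbol and start <= r[1] <= end]


store = FakeStore()
store.save(("XBTUSD", datetime(2018, 6, 1), 7500.0), collection_name="XBTUSD")
store.save(("IF88", datetime(2018, 6, 1), 3800.0))

start, end = datetime(2018, 1, 1), datetime(2019, 1, 1)
default_rows = store.load_bar_data("XBTUSD", start, end)             # default collection
switched_rows = store.load_bar_data("XBTUSD", start, end, "XBTUSD")  # named collection
```

Here default_rows is empty because the XBTUSD bar was saved only to the named collection, which is the situation a backtest runs into if collection_name is forgotten.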

 

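CTA backtesting engine changes

backtesting.py

Go into vnpy\app\cta_strategy and open backtesting.py. There are more functions to change this time, so they are grouped by purpose into two parts:

  • Strategy backtesting
  • Parameter optimization

Strategy backtesting

  • Add a collection_name parameter to load_bar_data and load_tick_data: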
     
@lru_cache(maxsize=999)
def load_bar_data(
    symbol: str,
    exchange: Exchange,
    interval: Interval,
    start: datetime,
    end: datetime,
    collection_name: str = None
):
    """"""
    return database_manager.load_bar_data(
        symbol, exchange, interval, start, end, collection_name
    )
@lru_cache(maxsize=999)
def load_tick_data(
    symbol: str,
    exchange: Exchange,
    start: datetime,
    end: datetime,
    collection_name: str = None
):
    """"""
    return database_manager.load_tick_data(
        symbol, exchange, start, end, collection_name
    )
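
Because both loaders are wrapped in lru_cache, collection_name becomes part of the cache key. The sketch below uses a stand-in loader (query_count and the returned tuple are illustrative, not vn.py code) to show which calls hit the cache in CPython.

```python
from functools import lru_cache
from typing import Optional

query_count = 0  # counts how often the "database" is actually hit


@lru_cache(maxsize=999)
def load_bar_data(symbol: str, collection_name: Optional[str] = None) -> tuple:
    """Stand-in loader; a real one would call database_manager.load_bar_data."""
    global query_count
    query_count += 1
    return (symbol, collection_name)


load_bar_data("XBTUSD", "XBTUSD")                  # miss: queries the database
load_bar_data("XBTUSD", "XBTUSD")                  # hit: served from the cache
load_bar_data("XBTUSD")                            # miss: default collection is another key
load_bar_data("XBTUSD", collection_name="XBTUSD")  # miss: keyword form is keyed separately
```

Data from different collections is therefore cached separately, and passing collection_name positionally every time, as load_data does, keeps repeated calls on the same cache entry.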

 

  • In BacktestingEngine.__init__, add a new attribute, collection_name, that holds the name of the MongoDB collection to use:
     
def __init__(self):
    """"""
    self.vt_symbol = ""
    self.symbol = ""
    self.exchange = None
    self.start = None
    self.end = None
    self.rate = 0
    self.slippage = 0
    self.size = 1
    self.pricetick = 0
    self.capital = 1_000_000
    self.mode = BacktestingMode.BAR
    self.inverse = False
    self.collection_name = None  # None means the default collection

 

  • Modify set_parameters: add a collection_name parameter and bind its value to the collection_name attribute:
     
def set_parameters(
        self,
        vt_symbol: str,
        interval: Interval,
        start: datetime,
        rate: float,
        slippage: float,
        size: float,
        pricetick: float,
        capital: int = 0,
        end: datetime = None,
        mode: BacktestingMode = BacktestingMode.BAR,
        inverse: bool = False,
        collection_name: str = None
    ):
    """"""

    self.mode = mode
    self.vt_symbol = vt_symbol
    self.interval = Interval(interval)
    self.rate = rate
    self.slippage = slippage
    self.size = size
    self.pricetick = pricetick
    self.start = start

    self.symbol, exchange_str = self.vt_symbol.split(".")
    self.exchange = Exchange(exchange_str)

    self.capital = capital
    self.end = end
    self.mode = mode
    self.inverse = inverse
    self.collection_name = collection_name

 

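  • Modify load_data so that it passes the collection_name attribute to the loaders. load_data contains a lot of logic, so only the changed part is shown below; the rest is replaced by ellipses (...):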
     
...
            if self.mode == BacktestingMode.BAR:
                data = load_bar_data(
                    self.symbol,
                    self.exchange,
                    self.interval,
                    start,
                    end,
                    self.collection_name,
                )
            else:
                data = load_tick_data(
                    self.symbol,
                    self.exchange,
                    start,
                    end,
                    self.collection_name,
                )
...

 

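Parameter optimization

In the parameter optimization code, the read location also has to change, because by default the optimization functions still read from the default tables/collections db_bar_data and db_tick_data.

  • First add a global variable, ga_collection_name, and assign it from the engine's collection_name attribute (this goes inside a BacktestingEngine method, since it reads self.collection_name):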
     
global ga_collection_name

ga_collection_name = self.collection_name

 

  • Modify _ga_optimize and optimize, adding a collection_name argument:
     
def optimize(
    target_name: str,
    strategy_class: CtaTemplate,
    setting: dict,
    vt_symbol: str,
    interval: Interval,
    start: datetime,
    rate: float,
    slippage: float,
    size: float,
    pricetick: float,
    capital: int,
    end: datetime,
    mode: BacktestingMode,
    inverse: bool,
    collection_name: str = None
):
    """
    Function for running in multiprocessing.pool
    """
    engine = BacktestingEngine()

    engine.set_parameters(
        vt_symbol=vt_symbol,
        interval=interval,
        start=start,
        rate=rate,
        slippage=slippage,
        size=size,
        pricetick=pricetick,
        capital=capital,
        end=end,
        mode=mode,
        inverse=inverse,
        collection_name=collection_name
    )

    engine.add_strategy(strategy_class, setting)
    engine.load_data()
    engine.run_backtesting()
    engine.calculate_result()
    statistics = engine.calculate_statistics(output=False)

    target_value = statistics[target_name]
    return (str(setting), target_value, statistics)
@lru_cache(maxsize=1000000)
def _ga_optimize(parameter_values: tuple):
    """"""
    setting = dict(parameter_values)

    result = optimize(
        ga_target_name,
        ga_strategy_class,
        setting,
        ga_vt_symbol,
        ga_interval,
        ga_start,
        ga_rate,
        ga_slippage,
        ga_size,
        ga_pricetick,
        ga_capital,
        ga_end,
        ga_mode,
        ga_inverse,
        ga_collection_name,
    )
    return (result[1],)
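
One thing to watch with this pattern: _ga_optimize is cached on parameter_values alone, while ga_collection_name is read from a module-level global. The sketch below is a stand-in (the result string and the atr_length setting are illustrative) showing that a cached result can outlive a change of collection unless the cache is cleared. Whether vn.py clears this cache between runs is not stated in the article, so treat this as a general Python caveat.

```python
from functools import lru_cache

ga_collection_name = None  # module-level global, as in the GA code path


@lru_cache(maxsize=1000000)
def _ga_optimize(parameter_values: tuple) -> tuple:
    """Stand-in: a real version would run a backtest on ga_collection_name."""
    setting = dict(parameter_values)
    # Pretend the result depends on both the setting and the collection.
    return (f"{ga_collection_name}:{setting['atr_length']}",)


params = (("atr_length", 22),)

ga_collection_name = "XBTUSD"
first = _ga_optimize(params)

ga_collection_name = "ETHUSD"
stale = _ga_optimize(params)   # cache hit: still the XBTUSD result

_ga_optimize.cache_clear()     # drop results tied to the old collection
fresh = _ga_optimize(params)
```

If several collections are optimized in the same Python session, calling _ga_optimize.cache_clear() before each run is one way to avoid mixing results.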

 

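Jupyter Notebook usage example

Finally, we can test the result in Jupyter Notebook. At the end of the previous article we had already imported the XBTUSD data into a new collection named XBTUSD.

To backtest with data read from this new collection, the change is just as simple: pass collection_name="XBTUSD" when calling set_parameters: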
 

engine = BacktestingEngine()
engine.set_parameters(
    vt_symbol="XBTUSD.BITMEX",
    interval="1h",
    start=datetime(2018, 1, 1),
    end=datetime(2019, 1, 1),
    rate=1/10000,
    slippage=0.5,
    size=10,
    pricetick=0.5,
    capital=1_000_000,
    collection_name = "XBTUSD"
)
engine.add_strategy(AtrRsiStrategy, {})

 
 


Super Moderator

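Parameter optimization: addendum

Besides the _ga_optimize and optimize functions mentioned in the post above, run_optimization also needs to be modified. The steps are:

  1. Go to C:\vnstudio\Lib\site-packages\vnpy\app\cta_strategy and find backtesting.py.
  2. In your code editor, press CTRL+F and search for run_optimization, which is a method of the BacktestingEngine class.
  3. In run_optimization, change the arguments passed to optimize when it is run with multiprocessing: add self.collection_name to them, as the screenshot shows:

[Screenshot: run_optimization with self.collection_name added to the arguments passed to optimize]

Why this is needed

During parameter optimization, optimize only runs a single parameter set through the backtesting engine. The run_optimization method of BacktestingEngine is what feeds every parameter combination into the engine across multiple processes, by calling optimize repeatedly. So once optimize has been changed, run_optimization must be changed to match.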
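
A rough sketch of the change described in step 3 follows. It assumes the workers are submitted through a pool's apply_async, which the post does not show. Engine and this reduced optimize are hypothetical stand-ins, and ThreadPool is used instead of multiprocessing.Pool only so the sketch runs anywhere, notebooks included; both expose the same apply_async call.

```python
from multiprocessing.pool import ThreadPool
from typing import Optional


def optimize(setting: dict, vt_symbol: str, collection_name: Optional[str] = None):
    """Stand-in worker: reports which collection it would read from."""
    source = collection_name or "db_bar_data"
    return (str(setting), source)


class Engine:
    """Hypothetical, heavily reduced stand-in for BacktestingEngine."""

    def __init__(self, vt_symbol: str, collection_name: Optional[str] = None):
        self.vt_symbol = vt_symbol
        self.collection_name = collection_name

    def run_optimization(self, settings: list) -> list:
        # ThreadPool mirrors the apply_async API of multiprocessing.Pool.
        with ThreadPool(2) as pool:
            pending = [
                pool.apply_async(
                    optimize,
                    # self.collection_name is the extra argument being added
                    (setting, self.vt_symbol, self.collection_name),
                )
                for setting in settings
            ]
            return [p.get() for p in pending]


engine = Engine("XBTUSD.BITMEX", collection_name="XBTUSD")
results = engine.run_optimization([{"atr_length": 20}, {"atr_length": 30}])
```

Without the extra argument, every worker would fall back to the default collection, even though the engine itself was configured with collection_name.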

Super Moderator

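Portfolio_strategy backtesting engine: addendum

To read and store data in separate collections with the backtesting module of vn.py's newly released Portfolio_strategy, make the following changes:

  1. In BacktestingEngine.__init__(self), define a new attribute: self.collection_names: Dict[str, str] = {}
  2. Modify BacktestingEngine.set_parameters, adding a new parameter collection_names: Dict[str, str] = None. The code: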
def set_parameters(
        self,
        vt_symbols: List[str],
        interval: Interval,
        start: datetime,
        rates: Dict[str, float],
        slippages: Dict[str, float],
        sizes: Dict[str, float],
        priceticks: Dict[str, float],
        capital: int = 0,
        end: datetime = None,
        collection_names: Dict[str, str] = None
    ) -> None:
        """"""
        self.vt_symbols = vt_symbols
        self.interval = interval

        self.rates = rates
        self.slippages = slippages
        self.sizes = sizes
        self.priceticks = priceticks

        self.start = start
        self.end = end
        self.capital = capital
        self.collection_names = collection_names or {}  # keep the {} default when omitted
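  3. Modify BacktestingEngine.load_data. The changed part is:

[Screenshot: changed part of BacktestingEngine.load_data]

  4. Modify the module-level load_bar_data function, which sits outside BacktestingEngine. The changed part is: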
@lru_cache(maxsize=999)
def load_bar_data(
    vt_symbol: str,
    interval: Interval,
    start: datetime,
    end: datetime,
    collection_name: str = None
):
    """"""
    symbol, exchange = extract_vt_symbol(vt_symbol)

    return database_manager.load_bar_data(
        symbol, exchange, interval, start, end, collection_name
    )
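
The screenshot for step 3 is not available here, so the following is only a guess at the kind of lookup load_data needs, not the author's code. It assumes collection_names is keyed by vt_symbol, which the post does not state; plan_loads is a hypothetical helper name.

```python
from typing import Dict, List, Optional, Tuple


def plan_loads(
    vt_symbols: List[str],
    collection_names: Optional[Dict[str, str]] = None,
) -> List[Tuple[str, Optional[str]]]:
    """Pair each vt_symbol with the collection to read (None = default)."""
    mapping = collection_names or {}  # tolerate the None default
    return [(vt_symbol, mapping.get(vt_symbol)) for vt_symbol in vt_symbols]


plan = plan_loads(
    ["XBTUSD.BITMEX", "ETHUSD.BITMEX"],
    {"XBTUSD.BITMEX": "XBTUSD"},
)
```

Each (vt_symbol, collection_name) pair could then be handed to load_bar_data, with None falling through to the default collection for symbols that have no entry.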

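Portfolio_strategy backtesting engine changes: Jupyter Notebook demo

Running a Portfolio_strategy backtest in Jupyter Notebook to verify the changes above:

[Screenshot: Portfolio_strategy backtest run in Jupyter Notebook]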

Member

After following these steps, running backtesting_demo.ipynb interactively in Jupyter Notebook from the examples directory raises an error:

2020-10-26 13:53:55.815079 开始加载历史数据

TypeError Traceback (most recent call last)

<ipython-input-44-06c0f59cbee0> in <module>
1 #%%
----> 2 engine.load_data()
3 engine.run_backtesting()
4 df = engine.calculate_result()
5 engine.calculate_statistics()

d:\vnstudio d\lib\site-packages\vnpy\app\cta_strategy\backtesting.py in load_data(self)
246 start,
247 end,
--> 248 collection_name
249 )
250 else:

d:\vnstudio d\lib\site-packages\vnpy\app\cta_strategy\backtesting.py in load_bar_data(symbol, exchange, interval, start, end, collection_name)
1291 ):
1292 """"""
-> 1293 return database_manager.load_bar_data(
1294 symbol, exchange, interval, start, end,collection_name
1295 )

TypeError: load_bar_data() takes 6 positional arguments but 7 were given

I'm using RQData data.

Member

I'd suggest checking your load_bar_data function.
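
The traceback says the method that was finally called accepts 6 positional arguments (self plus five) but received 7. In other words, the database_manager in use has a load_bar_data without the collection_name parameter. One possible cause, offered only as a guess, is that the active database driver is not the MongoDB one, so a manager class other than the one edited in database_mongo.py is handling the call. The sketch below reproduces the error with two stand-in classes (OldManager and NewManager are hypothetical) and shows a quick signature check.

```python
import inspect


class OldManager:
    """Stand-in for a database manager that was not updated."""

    def load_bar_data(self, symbol, exchange, interval, start, end):
        return []


class NewManager:
    """Stand-in for a manager that accepts the optional collection_name."""

    def load_bar_data(self, symbol, exchange, interval, start, end,
                      collection_name=None):
        return []


def accepts_collection_name(manager) -> bool:
    """True if manager.load_bar_data has a collection_name parameter."""
    params = inspect.signature(manager.load_bar_data).parameters
    return "collection_name" in params


try:
    # Six arguments plus self: the same call shape as in the traceback.
    OldManager().load_bar_data("IF88", "CFFEX", "1m", None, None, None)
    error_message = ""
except TypeError as exc:
    error_message = str(exc)
```

Running accepts_collection_name(database_manager) in the notebook would show whether the manager actually in use picked up the change.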
