VeighNa量化社区
你的开源社区量化交易平台
Member
avatar
加入于:
帖子: 9
声望: 0

这个报错我遇见过好几次了,就想看看到底是个啥情况,然后有了这个帖子用于记录和分享。

先写我的结论:
函数报错,因为入参“均线窗口”的最小值是2,你如果传进去1之类的值那肯定就报错了。
同时,talib.SMA可以用talib.MA替代,这样参数“均线窗口”为1时,也能正常运行。具体的代码示例如下:

# 函数talib.SMA
talib.SMA(numpy.array([i * 0.1 for i in range(1, 9)]), 2)
# 与之等价的talib.MA的写法
talib.MA(numpy.array([i * 0.1 for i in range(1, 9)]), 2, talib.MA_Type.SMA)

下面是细碎的验证过程,不想看的可以直接关闭了。

入参“均线窗口”的最小值是2,是在C代码层面就限制了的。C代码的截图如下所示:
description
所以这个报错怨不得python包,那么,能做的就是,
要么接受参数最小值是2,
要么找一个参数最小值是1的替代函数(我个人有参数最小值是1的需要),

我偶然看到了 talib.MA,然后认为 matype=talib.MA_Type.SMA 时的 talib.MA 等价于 talib.SMA,
然后写了个脚本进行验证,脚本执行正常,没发现什么问题,脚本内容如下:

# !/usr/bin/env python3
# coding=utf8
"""
这个脚本不严谨地验证了 talib.SMA 和 (参数为 talib.MA_Type.SMA 时的)talib.MA 是等价的,
talib.SMA(numpy.array([i * 0.1 for i in range(1, 9)]), 2)
talib.MA( numpy.array([i * 0.1 for i in range(1, 9)]), 2, talib.MA_Type.SMA)
"""
import json
import numpy
import random
import talib
from typing import Any, Dict, List, Set, Tuple, Type, Optional, Union, Callable


def gen_price_list(base: float, count: int) -> List[float]:
    """
    以 base 为基准, 每次大致在涨跌 10% 的范围内波动, 生成 count 个数据
    """
    values: List[float] = [float(base)]
    for _ in range(count - 1):
        dn: float = values[-1] * (1 - 0.1)
        up: float = values[-1] * (1 + 0.1)
        value: float = round(random.uniform(a=dn, b=up), 2)
        values.append(value)
    return values


def check_once(base: float, count: int, timeperiod: int):
    """
    验证 talib.SMA 和 (参数为 talib.MA_Type.SMA 时的)talib.MA 是等价的,
    验证 prices 和参数为 1 时的 talib.MA 是等价的,
    """
    base: float = round(base, 2)
    values: List[float] = gen_price_list(base=base, count=count)
    prices: numpy.ndarray = numpy.array(object=values)
    smaRet: numpy.ndarray = talib.SMA(prices, timeperiod)
    ma1Ret: numpy.ndarray = talib.MA(prices, timeperiod, talib.MA_Type.SMA)
    ma2Ret: numpy.ndarray = talib.MA(prices, 1, talib.MA_Type.SMA)
    if (ma1Ret.all() != smaRet.all()) or (ma2Ret.all() != prices.all()):
        stat: dict = {"base": base, "count": count, "timeperiod": timeperiod, "values": values, }
        raise RuntimeError(json.dumps(obj=stat))
    return True


def check(total: int):
    """"""
    counter: int = 0
    for _ in range(total):
        base: float = round(random.uniform(a=1.0, b=1000.0), 2)
        count: int = random.randint(1, 10000)
        timeperiod: int = random.randint(2, count if 2 <= count else 2)  # talib.SMA 的 timeperiod 的最小是 2
        check_once(base=base, count=count, timeperiod=timeperiod)
        counter += 1
        if counter % 1000 == 0:
            print(f"total={total}, counter={counter}, {round(counter/total*100,2)}%,")


if __name__ == "__main__":
    check(total=10_000_000)
Member
avatar
加入于:
帖子: 1472
声望: 105

感谢分享!

© 2015-2022 上海韦纳软件科技有限公司
备案服务号:沪ICP备18006526号

沪公网安备 31011502017034号

【用户协议】
【隐私政策】
【免责条款】