这个报错我遇见过好几次了,就想看看到底是个啥情况,然后有了这个帖子用于记录和分享。
先写我的结论:
函数报错,因为入参“均线窗口”的最小值是2,你如果传进去1之类的值那肯定就报错了。
同时,talib.SMA可以用talib.MA替代,这样参数“均线窗口”为1时,也能正常运行。具体的代码示例如下:
# 函数talib.SMA
talib.SMA(numpy.array([i * 0.1 for i in range(1, 9)]), 2)
# 与之等价的talib.MA的写法
talib.MA(numpy.array([i * 0.1 for i in range(1, 9)]), 2, talib.MA_Type.SMA)
下面是细碎的验证过程,不想看的可以直接关闭了。
入参“均线窗口”的最小值是2,是在C代码层面就限制了的。C代码的截图如下所示:
所以这个报错怨不得python包,那么,能做的就是,
要么接受参数最小值是2,
要么找一个参数最小值是1的替代函数(我个人有参数最小值是1的需要),
我偶然看到了 talib.MA,然后认为 matype=talib.MA_Type.SMA 时的 talib.MA 等价于 talib.SMA,
然后写了个脚本进行验证,脚本执行正常,没发现什么问题,脚本内容如下:
# !/usr/bin/env python3
# coding=utf8
"""
这个脚本不严谨地验证了 talib.SMA 和 (参数为 talib.MA_Type.SMA 时的)talib.MA 是等价的,
talib.SMA(numpy.array([i * 0.1 for i in range(1, 9)]), 2)
talib.MA( numpy.array([i * 0.1 for i in range(1, 9)]), 2, talib.MA_Type.SMA)
"""
import json
import numpy
import random
import talib
from typing import Any, Dict, List, Set, Tuple, Type, Optional, Union, Callable
def gen_price_list(base: float, count: int) -> List[float]:
"""
以 base 为基准, 每次大致在涨跌 10% 的范围内波动, 生成 count 个数据
"""
values: List[float] = [float(base)]
for _ in range(count - 1):
dn: float = values[-1] * (1 - 0.1)
up: float = values[-1] * (1 + 0.1)
value: float = round(random.uniform(a=dn, b=up), 2)
values.append(value)
return values
def check_once(base: float, count: int, timeperiod: int):
"""
验证 talib.SMA 和 (参数为 talib.MA_Type.SMA 时的)talib.MA 是等价的,
验证 prices 和参数为 1 时的 talib.MA 是等价的,
"""
base: float = round(base, 2)
values: List[float] = gen_price_list(base=base, count=count)
prices: numpy.ndarray = numpy.array(object=values)
smaRet: numpy.ndarray = talib.SMA(prices, timeperiod)
ma1Ret: numpy.ndarray = talib.MA(prices, timeperiod, talib.MA_Type.SMA)
ma2Ret: numpy.ndarray = talib.MA(prices, 1, talib.MA_Type.SMA)
if (ma1Ret.all() != smaRet.all()) or (ma2Ret.all() != prices.all()):
stat: dict = {"base": base, "count": count, "timeperiod": timeperiod, "values": values, }
raise RuntimeError(json.dumps(obj=stat))
return True
def check(total: int):
""""""
counter: int = 0
for _ in range(total):
base: float = round(random.uniform(a=1.0, b=1000.0), 2)
count: int = random.randint(1, 10000)
timeperiod: int = random.randint(2, count if 2 <= count else 2) # talib.SMA 的 timeperiod 的最小是 2
check_once(base=base, count=count, timeperiod=timeperiod)
counter += 1
if counter % 1000 == 0:
print(f"total={total}, counter={counter}, {round(counter/total*100,2)}%,")
if __name__ == "__main__":
check(total=10_000_000)