import sqlite3
from datetime import datetime, timedelta
import requests
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Database configuration: path of the SQLite file used by every function below.
DB_FILE = 'stock_data.db'
def init_database():
    """Create the ``stock_daily`` table, rebuilding it if its schema is outdated.

    Opens the SQLite file named by ``DB_FILE``. When ``stock_daily`` already
    exists but lacks any of the required columns, the table is dropped
    (discarding its rows) and recreated. When it does not exist, it is
    created with ``(symbol, exchange, datetime)`` as the primary key.

    Raises:
        sqlite3.Error: If a statement fails. The connection is closed
            regardless.
    """
    required_columns = (
        'symbol', 'exchange', 'datetime',
        'open', 'high', 'low', 'close', 'volume',
    )

    conn = sqlite3.connect(DB_FILE)
    try:
        cursor = conn.cursor()

        # Does the table exist at all?
        cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='stock_daily'"
        )
        table_exists = cursor.fetchone() is not None

        if table_exists:
            # PRAGMA table_info yields one row per column; index 1 is its name.
            cursor.execute("PRAGMA table_info(stock_daily)")
            existing_columns = {row[1] for row in cursor.fetchall()}
            if not existing_columns.issuperset(required_columns):
                # Incomplete schema: drop the table so it is recreated below.
                print("检测到旧表结构,重建表中...")
                cursor.execute("DROP TABLE IF EXISTS stock_daily")
                conn.commit()
                table_exists = False

        if not table_exists:
            cursor.execute('''
                CREATE TABLE stock_daily (
                    symbol TEXT NOT NULL,
                    exchange TEXT NOT NULL,
                    datetime TEXT NOT NULL,
                    open REAL NOT NULL,
                    high REAL NOT NULL,
                    low REAL NOT NULL,
                    close REAL NOT NULL,
                    volume REAL NOT NULL,
                    PRIMARY KEY (symbol, exchange, datetime)
                )
            ''')
            print("创建新表成功")

        conn.commit()
    finally:
        # Release the file handle even when a statement above raised.
        conn.close()
def fetch_stock_data(symbol, exchange='SZ', start_date="2010-01-01", end_date=None):
    """Download daily bars for one stock from the Tencent Finance kline endpoint.

    The requested range is split into consecutive chunks of roughly one year
    (365 days after the chunk start, capped at ``end_date``) and each chunk
    is fetched with its own HTTP request. A chunk that fails is reported on
    stdout and skipped; the remaining chunks are still attempted.

    Args:
        symbol: Stock code, e.g. ``"600519"``.
        exchange: Exchange prefix; lower-cased for the request, stored as
            given in the returned records.
        start_date: First day to fetch, formatted ``YYYY-MM-DD``.
        end_date: Last day to fetch, formatted ``YYYY-MM-DD``. Defaults to
            today.

    Returns:
        A list of dicts with keys ``symbol``, ``exchange``, ``datetime``
        (a ``datetime``), ``open``, ``high``, ``low``, ``close`` and
        ``volume`` (floats). Empty when ``start_date`` is after ``end_date``
        or when every chunk failed.

    Raises:
        ValueError: If ``start_date`` or ``end_date`` is not ``YYYY-MM-DD``.
    """
    if end_date is None:
        end_date = datetime.now().strftime("%Y-%m-%d")

    exchange_code = exchange.lower()
    ticker = f"{exchange_code}{symbol}"
    bar_data = []
    current_date = datetime.strptime(start_date, "%Y-%m-%d")
    end_date = datetime.strptime(end_date, "%Y-%m-%d")

    print(f"开始获取 {exchange}{symbol} 数据,时间范围: {start_date} 至 {end_date.date()}")

    while current_date <= end_date:
        chunk_end = min(current_date + timedelta(days=365), end_date)
        start_str = current_date.strftime("%Y-%m-%d")
        end_str = chunk_end.strftime("%Y-%m-%d")
        print(f"获取数据: {start_str} 至 {end_str}")

        # NOTE(review): the trailing "2000" looks like a row limit and "qfq"
        # presumably selects forward-adjusted prices; confirm against the API.
        url = (
            "http://web.ifzq.gtimg.cn/appstock/app/fqkline/get"
            f"?param={ticker},day,{start_str},{end_str},2000,qfq"
        )
        print(url)

        try:
            response = requests.get(url, timeout=10)
            # Surface HTTP errors explicitly instead of as a JSON decode error.
            response.raise_for_status()
            data = response.json()

            payload = data.get('data')
            if not payload or ticker not in payload:
                print(f"获取数据失败,响应: {data}")
            else:
                series = payload[ticker]
                # NOTE(review): the endpoint appears to return bars under
                # "day" instead of "qfqday" when no adjusted series exists
                # (e.g. indices); confirm. A KeyError here is reported below.
                records = series.get("qfqday")
                if records is None:
                    records = series["day"]

                for record in records:
                    if len(record) < 6:
                        continue
                    # Row layout as consumed here:
                    # [date, open, close, high, low, volume, ...]
                    bar_data.append({
                        'symbol': symbol,
                        'exchange': exchange,
                        'datetime': datetime.strptime(record[0], "%Y-%m-%d"),
                        'open': float(record[1]),
                        'high': float(record[3]),
                        'low': float(record[4]),
                        'close': float(record[2]),
                        'volume': float(record[5]),
                    })
        except Exception as e:
            # Deliberate best effort: report the failed chunk and move on.
            print(f"获取数据出错: {e}")

        # Single advance point, reached on success and on failure alike.
        current_date = chunk_end + timedelta(days=1)

    print(f"共获取到 {len(bar_data)} 条数据")
    return bar_data
def save_to_database(data):
    """Insert bar records into ``stock_daily``, ignoring rows already stored.

    Each record is written with ``INSERT OR IGNORE``, so a row whose
    ``(symbol, exchange, datetime)`` key already exists is left untouched.
    A row that fails with ``sqlite3.Error`` is reported and counted
    separately, so the summary no longer labels it as "already existing".

    Args:
        data: Iterable of dicts with keys ``symbol``, ``exchange``,
            ``datetime`` (an object with ``strftime``), ``open``, ``high``,
            ``low``, ``close`` and ``volume``. Nothing is done when falsy.
    """
    if not data:
        print("没有数据需要保存")
        return

    insert_sql = '''
        INSERT OR IGNORE INTO stock_daily
        (symbol, exchange, datetime, open, high, low, close, volume)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    '''
    inserted = 0
    failed = 0

    conn = sqlite3.connect(DB_FILE)
    try:
        cursor = conn.cursor()
        for bar in data:
            try:
                cursor.execute(insert_sql, (
                    bar['symbol'],
                    bar['exchange'],
                    bar['datetime'].strftime("%Y-%m-%d"),
                    bar['open'],
                    bar['high'],
                    bar['low'],
                    bar['close'],
                    bar['volume'],
                ))
                # rowcount is 0 when OR IGNORE skipped a duplicate key.
                inserted += cursor.rowcount
            except sqlite3.Error as e:
                failed += 1
                print(f"保存数据出错: {e}")
        conn.commit()
    finally:
        # Close the connection even if a non-sqlite error aborted the loop.
        conn.close()

    print(f"成功保存 {inserted} 条新数据,跳过 {len(data) - inserted - failed} 条已存在数据")
    if failed:
        print(f"{failed} 条数据保存失败")
def get_last_record_date(symbol, exchange):
    """Return the day after the newest stored bar for a stock.

    Despite the name, the result is the maximum stored ``datetime`` for the
    given ``symbol``/``exchange`` plus one day, i.e. the first date not yet
    covered by the database.

    Args:
        symbol: Stock code as stored in ``stock_daily.symbol``.
        exchange: Exchange as stored in ``stock_daily.exchange``.

    Returns:
        A ``datetime`` one day after the latest stored record, or ``None``
        when the stock has no rows.
    """
    conn = sqlite3.connect(DB_FILE)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            SELECT MAX(datetime)
            FROM stock_daily
            WHERE symbol=? AND exchange=?
        ''', (symbol, exchange))
        result = cursor.fetchone()
    finally:
        # Close the connection even when the query raises.
        conn.close()

    # MAX() over zero rows yields a single row containing NULL.
    if result and result[0]:
        return datetime.strptime(result[0], "%Y-%m-%d") + timedelta(days=1)
    return None
def load_stock_data(symbol, exchange):
    """Load all stored bars for one stock into a DataFrame.

    Args:
        symbol: Stock code as stored in ``stock_daily.symbol``.
        exchange: Exchange as stored in ``stock_daily.exchange``.

    Returns:
        A DataFrame with columns ``open``, ``high``, ``low``, ``close`` and
        ``volume``, ordered by and indexed on a parsed ``datetime`` column.
        When no rows match, the empty frame is returned as read, with
        ``datetime`` still a regular column.
    """
    query = '''
        SELECT datetime, open, high, low, close, volume
        FROM stock_daily
        WHERE symbol=? AND exchange=?
        ORDER BY datetime
    '''
    conn = sqlite3.connect(DB_FILE)
    try:
        df = pd.read_sql_query(query, conn, params=(symbol, exchange))
    finally:
        # Close the connection even when the query raises.
        conn.close()

    if not df.empty:
        df['datetime'] = pd.to_datetime(df['datetime'])
        df.set_index('datetime', inplace=True)
    return df
def calculate_technical_indicators(df, windows=(5, 20)):
    """Add simple moving averages of ``close`` to ``df``.

    For every window ``n`` a column ``ma{n}`` is written with the rolling
    mean of ``close``. ``min_periods=1`` is used, so the first rows average
    over however many bars are available instead of being NaN.

    Args:
        df: DataFrame with a ``close`` column. It is modified in place.
        windows: Window lengths to compute. The default produces ``ma5``
            and ``ma20``.

    Returns:
        The same ``df`` object. An empty frame is returned unchanged.
    """
    if df.empty:
        return df

    close = df['close']
    for window in windows:
        df[f'ma{window}'] = close.rolling(window, min_periods=1).mean()
    return df
def visualize_stock_data(df, symbol, exchange):
    """Show a candlestick chart with moving averages and a volume panel.

    The figure has two stacked rows sharing the x axis: the price
    candlesticks with the ``ma5``/``ma20`` overlays on top, the volume bars
    below. A moving-average overlay is skipped when its column is missing,
    so a frame without computed indicators still renders.

    Args:
        df: DataFrame with ``open``, ``high``, ``low``, ``close`` and
            ``volume`` columns. The title calls ``.date()`` on the last
            index value, so the index must hold datetime-like values.
        symbol: Stock code, used in the title.
        exchange: Exchange prefix, used in the title.

    Returns:
        None. Prints a notice and returns early when ``df`` is empty.
    """
    if df.empty:
        print("没有数据可可视化")
        return

    # Two rows: price (70 % of the height) above volume (30 %).
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.05,
        row_heights=[0.7, 0.3],
        specs=[[{"type": "scatter"}], [{"type": "bar"}]],
    )

    # Candlesticks: red for rising bars, green for falling bars.
    fig.add_trace(
        go.Candlestick(
            x=df.index,
            open=df['open'],
            high=df['high'],
            low=df['low'],
            close=df['close'],
            name='K线',
            increasing_line_color='red',
            decreasing_line_color='green',
        ),
        row=1, col=1,
    )

    # Moving-average overlays: (column, legend label, line style).
    ma_overlays = (
        ('ma5', '5日均线', dict(color='blue', width=1)),
        ('ma20', '20日均线', dict(color='orange', width=1.5)),
    )
    for column, label, line_style in ma_overlays:
        if column not in df.columns:
            # Indicator not computed for this frame; draw the rest anyway.
            continue
        fig.add_trace(
            go.Scatter(
                x=df.index,
                y=df[column],
                line=line_style,
                name=label,
            ),
            row=1, col=1,
        )

    # Volume bars in the lower panel.
    fig.add_trace(
        go.Bar(
            x=df.index,
            y=df['volume'],
            name='成交量',
            marker_color='gray',
            opacity=0.5,
        ),
        row=2, col=1,
    )

    fig.update_layout(
        title=f'{exchange}{symbol} 股票数据 (更新至 {df.index[-1].date()})',
        xaxis_rangeslider_visible=False,
        height=800,
        template='plotly_white',
        hovermode='x unified',
    )

    fig.update_yaxes(title_text="价格", row=1, col=1)
    fig.update_yaxes(title_text="成交量", row=2, col=1)

    fig.show()
def main():
    """Run the pipeline: initialise, download new bars, store, load, plot.

    Only dates after the newest stored record are requested; with an empty
    database the download starts at 2010-01-01.
    """
    # Kweichow Moutai (600519), listed on the Shanghai Stock Exchange.
    symbol = "600519"
    exchange = "sh"

    # 1. Make sure the table exists with the expected schema.
    print("初始化数据库...")
    init_database()

    # 2. Work out where the download should start. get_last_record_date
    #    returns the day AFTER the newest stored bar (or None).
    next_date = get_last_record_date(symbol, exchange)
    if next_date is None:
        start_date = "2010-01-01"
        last_label = '无记录'
    else:
        start_date = next_date.strftime("%Y-%m-%d")
        # Step back one day so the message shows the real last stored date.
        last_label = (next_date - timedelta(days=1)).date()
    print(f"最后记录日期: {last_label}")
    print(f"将从 {start_date} 开始获取数据")

    # 3. Download the missing bars.
    stock_data = fetch_stock_data(symbol, exchange, start_date=start_date)

    # 4. Persist them.
    save_to_database(stock_data)

    # 5. Reload the full history from the database.
    df = load_stock_data(symbol, exchange)
    if df.empty:
        print("没有获取到数据")
        return

    # 6. Compute the indicators.
    df = calculate_technical_indicators(df)

    # 7. Plot.
    visualize_stock_data(df, symbol, exchange)
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()