LSTM 黃金分析

      在〈LSTM 黃金分析〉中尚無留言

本篇使用 LSTM 模型分析及預測台銀黃金存摺價格。LSTM 範例的資料取自本站資料庫的「台銀黃金」資料表;文末「平均日線」一節則改用 yfinance 下載國際黃金期貨(GC=F)價格。

模型訓練

模型訓練如下

#pip install pandas tensorflow  scikit-learn
# Train an LSTM on the 賣出 (sell) price of the 台銀黃金 table and save it as
# gold_{days}.keras.
#
# Each sample is a window of `days` consecutive min-max-scaled prices; the
# target is the scaled price `horizon` rows after the last day of the window.
import pandas as pd
from keras import Sequential
# Public import path (keras.src.* is Keras' private implementation package).
from keras.layers import Dense, SimpleRNN, Dropout, LSTM
from sklearn.preprocessing import MinMaxScaler

from G import G

# Window length (number of past rows per sample). The prediction script loads
# f"gold_{days}.keras", so its `days` must equal the value used here.
days = 1
# Rows between the last day of a window and the target (1 = the next row),
# matching the prediction script, which plots each output one row later.
horizon = 1
# Fraction of samples (oldest first) used for training.
split = 0.8

conn, cursor = G.connect()
cursor.execute("select * from 台銀黃金 order by 日期")
rs = cursor.fetchall()
columns = [d[0] for d in cursor.description]
df = pd.DataFrame(rs, columns=columns)

# NOTE(review): the scaler is fit on the whole history and is not saved; the
# prediction script refits its own scaler, so the two scales only agree if the
# min/max of 賣出 has not changed since training.
scaler = MinMaxScaler()
df["scaler"] = scaler.fit_transform(df[["賣出"]])

# Row i holds scaled prices i .. i+days-1 in columns 0 .. days-1. The target
# is aligned by index; the last windows have no target (NaN) and are dropped.
data = pd.DataFrame(
    [df["scaler"].iloc[i:i + days].tolist() for i in range(len(df) - days + 1)]
)
data["target"] = df["scaler"].shift(-(days + horizon - 1))
data = data.dropna()

# Chronological split: shuffling before splitting would put future prices in
# the training set and make the validation loss meaningless for a time series.
# model.fit still shuffles the training batches (its default shuffle=True).
mid = int(len(data) * split)
# LSTM layers expect 3-D input: (samples, timesteps=days, features=1).
train_x = data.iloc[:mid, 0:days].to_numpy(dtype="float32").reshape(-1, days, 1)
train_y = data.iloc[:mid, -1].to_numpy(dtype="float32")
test_x = data.iloc[mid:, 0:days].to_numpy(dtype="float32").reshape(-1, days, 1)
test_y = data.iloc[mid:, -1].to_numpy(dtype="float32")

model = Sequential()
model.add(
    LSTM(
        input_shape=(days, 1),
        units=256,
        unroll=False
    )
)
model.add(Dropout(0.2))
model.add(Dense(1))
model.compile(
    loss="mse",
    optimizer="adam",
    # "accuracy" is meaningless for a continuous regression target.
    metrics=["mae"]
)
model.fit(
    train_x, train_y,
    batch_size=200,
    epochs=200,
    verbose=1,
    validation_data=(test_x, test_y)
)
model.save(f"gold_{days}.keras")

todo

預測

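預測明日黃金價格。注意:此處的 days 必須與訓練模型時使用的 days 相同,因為程式會載入對應的 gold_{days}.keras 模型檔(上方訓練範例設定 days=1,只會產生 gold_1.keras)。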

from datetime import timedelta
import keras
from sklearn.preprocessing import MinMaxScaler
import plotly.graph_objects as go
from G import G
import pandas as pd

# Predict the next row's 賣出 (sell) price of 台銀黃金 with a saved LSTM and
# plot it against the actual prices.

# Window length; must equal the `days` used in training, because the model
# file loaded below is gold_{days}.keras.
days = 5
# Only the most recent -period rows are plotted.
period = -300

conn, cursor = G.connect()
cursor.execute("select * from 台銀黃金 order by 日期")
rs = cursor.fetchall()
columns = [d[0] for d in cursor.description]
df = pd.DataFrame(rs, columns=columns)
scaler = MinMaxScaler()
df["scaler"] = scaler.fit_transform(df[["賣出"]])

# Row j holds scaled prices j .. j+days-1 (columns 0 .. days-1); its date and
# sale are those of the window's last day, i.e. df row j+days-1.
data = pd.DataFrame(
    [df["scaler"].iloc[i - days:i].tolist() for i in range(days, len(df) + 1)]
)
data["date"] = df["日期"].shift(-days + 1)
data["sale"] = df["賣出"].shift(-days + 1)

model = keras.models.load_model(f'gold_{days}.keras')
# LSTM layers expect 3-D input: (samples, timesteps=days, features=1).
gold = data.iloc[:, 0:days].to_numpy(dtype="float32").reshape(-1, days, 1)
predict = model.predict(gold)
# inverse_transform returns shape (n, 1); flatten it to fill one column.
data["predict"] = scaler.inverse_transform(predict).ravel()

# Append a row for the day after the last known date. Its actual price is
# unknown (NaN, so it is not drawn as a 0 on the chart).
date = data.iloc[-1]["date"]
date = date + timedelta(days=1)
data.loc[len(df)] = [0] * days + [date, float("nan"), 0]
# Move each prediction down one row so it sits on the date it forecasts; the
# appended row receives the forecast made from the latest window.
data["predict"] = data["predict"].shift(1)
# Only the first row lacks a prediction; keep the appended row despite its
# NaN sale.
data = data.dropna(subset=["predict"])

shown = data.iloc[period:]
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=shown["date"],
        y=shown["sale"],
        mode='lines',
        name='實際價格',
        line=dict(color='red', width=2)
    )
)
fig.add_trace(
    go.Scatter(
        x=shown["date"],
        y=shown["predict"],
        mode='lines',
        name='預測價格',
        line=dict(color='green', width=2)
    )
)

fig.update_layout(
    dragmode="pan",
    title_text="台灣黃金存摺歷史價格",
    xaxis=go.layout.XAxis(
        rangeselector=dict(
            buttons=list([
                dict(count=1,
                     label="1 month",
                     step="month",
                     stepmode="backward"),
                dict(count=6,
                     label="6 month",
                     step="month",
                     stepmode="backward"),
                dict(count=1,
                     label="1 year",
                     step="year",
                     stepmode="backward"),
                dict(count=1,
                     label="1 day",
                     step="day",
                     stepmode="todate"),
                dict(step="all")
            ])
        ),
        rangeslider=dict(
            visible=True
        ),
        type="date"
    ),
    yaxis=dict(fixedrange=False)
)
fig.show()

todo

平均日線

底下以 5 日及 10 日平均線作為特徵,用線性迴歸預測次日收盤價,並繪出黃金期貨走勢圖及預測值。請先安裝如下套件:

pip install plotly yfinance scikit-learn pandas

完整代碼如下

from datetime import datetime, timedelta
import pandas as pd
from dateutil.relativedelta import relativedelta
from sklearn.linear_model import LinearRegression

display = pd.options.display
display.max_columns = None
display.max_rows = None
display.width = None
display.max_colwidth = None
import yfinance as yf
import plotly.graph_objects as go

# Fit a linear regression from the 5- and 10-day moving averages of day t to
# the close of day t+1, then plot actual closes and the predictions.
#
# Ticker symbols: 大盤 (TAIEX) = ^TWII, 黃金期貨 (gold futures) = GC=F
stock = 'GC=F'
df = yf.download(stock, start='2023-10-01', end='2024-05-12')
# Recent yfinance versions return (Price, Ticker) MultiIndex columns even for
# a single ticker; flatten them so df['Close'] is a Series.
if isinstance(df.columns, pd.MultiIndex):
    df.columns = df.columns.get_level_values(0)

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=df.index,
        y=df['Close'].values,
        mode='lines',
        name='實際價格',
        line=dict(color='royalblue', width=2)
    )
)

ma1 = 5   # 5-day moving-average window
ma2 = 10  # 10-day moving-average window
df = df.dropna()
df['s1'] = df['Close'].rolling(window=ma1).mean()
df['s2'] = df['Close'].rolling(window=ma2).mean()
df = df.dropna()

# Explicit copy: this frame is modified below, and mutating a column slice of
# df triggers pandas' SettingWithCopy behaviour.
train = df[['Close', 's1', 's2']].copy()
train['next_day_price'] = train['Close'].shift(-1)
train = train.dropna()
x_train = train[['s1', 's2']]
y_train = train['next_day_price']
model = LinearRegression()
model.fit(x_train, y_train)
df['predict_price'] = model.predict(df[['s1', 's2']])

# Each prediction made on day t is for the next trading day. Append a row one
# calendar day after the last date and shift the predictions down one row so
# every value is plotted on (approximately) the date it forecasts.
pred = df[['predict_price']].copy()
next_day = pred.index[-1] + timedelta(days=1)
pred.loc[next_day] = [0]
pred['predict_price'] = pred['predict_price'].shift(1)
print(pred)
fig.add_trace(
    go.Scatter(
        x=pred.index,
        y=pred['predict_price'].values,
        mode='lines',
        name='AI預測',
        line=dict(color='orange', width=1)
    )
)
# Anchor the initial 6-month view to the data rather than to today's date:
# the download end date is fixed, so a now()-based window eventually shows
# no data at all.
last_day = pd.Timestamp(pred.index[-1]).to_pydatetime()
xrange = [(last_day - relativedelta(months=6)).strftime("%Y-%m-%d"), last_day.strftime("%Y-%m-%d")]
yrange = [df['Close'].tail(180).min(), df['Close'].tail(180).max()]
fig.update_layout(
    dragmode="pan",
    xaxis=go.layout.XAxis(
        range=xrange,
        rangeselector=dict(
            buttons=list([
                dict(count=1,
                     label="1 month",
                     step="month",
                     stepmode="backward"),
                dict(count=6,
                     label="6 month",
                     step="month",
                     stepmode="backward"),
                dict(count=1,
                     label="1 year",
                     step="year",
                     stepmode="backward"),
                dict(count=1,
                     label="1 day",
                     step="day",
                     stepmode="todate"),
                dict(step="all")
            ])
        ),
        rangeslider=dict(
            visible=True
        ),
        type="date"
    ),
    yaxis=dict(
        fixedrange=False,
        range=yrange
    )
)
fig.show()

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *