收手吧,華強!我用機器學習幫你挑西瓜

在機器學習領域,有一個很有名氣的西瓜 -- 周志華老師的《機器學習》,很多同學選擇這本書入門,都曾有被西瓜支配的恐懼。我寫文章的時候也特別喜歡用西瓜數據集,以它爲例手算 + 可視化講解過 XGBoost,自認非常通俗易懂。

最近我介紹了決策樹的可視化,還有可以快速實現機器學習 web 應用的神器——streamlit。今天我們就把它們結合起來,用機器學習幫華強挑西瓜!僅供娛樂,希望大家可以學到一些新姿勢。

項目已發佈,歡迎大家試玩
https://share.streamlit.io/tjxj/watermelon-prediction/main/app.py

老規矩,先看效果圖(GIF 刷新有點慢,請耐心等待)

使用方法

第一步,左側先選擇西瓜外觀

第二步,選擇決策樹的模型參數

第三步,看結果

如果是好瓜,頁面彈出的就是笑眯眯的圖片~

如果是壞瓜,頁面顯示的是後果很嚴重的圖片

實現方式


注:篇幅原因,僅貼出核心代碼

完整代碼我放到了網頁裏,需要可以 copy 走

data.py

主要是原始數據的處理,inputData 方法實現輸入外觀變量值的標籤編碼。

def inputData():
    st.sidebar.subheader("請選擇西瓜外觀:sunglasses:")
    color = st.sidebar.selectbox("色澤"("青綠""烏黑""淺白"))
    root = st.sidebar.selectbox("根蒂"("蜷縮""稍蜷""硬挺"))
    knocks = st.sidebar.selectbox("敲擊"("濁響""沉悶""清脆"))
    texture = st.sidebar.selectbox("紋理"("清晰""稍糊""模糊"))
    navel = st.sidebar.selectbox("臍部"("凹陷""稍凹""平坦"))
    touch = st.sidebar.selectbox("觸感"("硬滑""軟粘"))
    input = [[color, root, knocks, texture, navel, touch]]
    features = ["color""root""knocks""texture""navel""touch"]
    np.array(input).reshape(1, 6)
    df_input = pd.DataFrame(input, columns=features, index=None)

    for feature in features[0:6]:
        le = joblib.load("./models/" + feature + "_LabelEncoder.model")
        df_input[feature] = le.transform(df_input[feature])

    return df_input

訓練模型

這一塊很簡單,就不多解釋了。注:數據量太小就不整交叉驗證了

def dt_param_selector():
    st.sidebar.subheader("請選擇模型參數:sunglasses:")
    criterion = st.sidebar.selectbox("criterion"["gini""entropy"])
    max_depth = st.sidebar.number_input("max_depth", 1, 50, 5, 1)
    min_samples_split = st.sidebar.number_input(
        "min_samples_split", 1, 20, 2, 1)
    max_features = st.sidebar.selectbox(
        "max_features"[None, "auto""sqrt""log2"])

    params = {
        "criterion": criterion,
        "max_depth": max_depth,
        "min_samples_split": min_samples_split,
        "max_features": max_features,
    }

    model = DecisionTreeClassifier(**params)
    df = dataPreprocessing()
    X, y = df[df.columns[:-1]], df["label"]
    model.fit(X, y)
    return model
def predictor():
    df_input = inputData()
    model = dt_param_selector()
    y_pred = model.predict(df_input)
    if y_pred == 1:
        goodwatermelon = Image.open("./pics/good.png")
        st.image(goodwatermelon,width=705,use_column_width= True)
        st.markdown("<center>🍉🍉🍉這瓜甚甜,買一個🍉🍉🍉</center>"unsafe_allow_html=True)
    else:
        file_ = open("./pics/bad2.gif""rb")
        contents = file_.read()
        data_url = base64.b64encode(contents).decode("utf-8")
        file_.close()

        st.markdown(
            f'<img src="data:image/gif;base64,{data_url}" width="100%">',
            unsafe_allow_html=True,
        )
        st.markdown('<center>🔪🔪🔪這瓜不甜,買不得🔪🔪🔪</center>'unsafe_allow_html=True)
    return y_pred,model

決策樹可視化

決策樹可視化和插入網頁我用decisionTreeVizsvg_write實現,可惜目前僅本地模式正常,發佈後報錯,尚未解決。

def decisionTreeViz():
    df,le = getDataSet()
    X, y = df[df.columns[:-1]], df["label"]
    clf = joblib.load('..\watermelonClassifier.pkl')
    viz = dtreeviz(clf, 
                X, 
                y,
                orientation="LR",
                target_name='label',
                feature_names=df.columns[:-1],
                class_names=["good","bad"]
                )  
                
    return viz

def svg_write(svg, center=True):
    """
    Disable center to left-margin align like other objects.
    """
    # Encode as base 64
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")

    # Add some CSS on top
    css_justify = "center" if center else "left"
    css = f'<p style="text-align:center; display: flex; justify-content: {css_justify};">'
    html = f'{css}<img src="data:image/svg+xml;base64,{b64}"/>'

    # Write the HTML
    st.write(html, unsafe_allow_html=True)

streamlit

過程就不說了,就把調用的 streamlit API 列一下吧

st.title
st.write
st.code
st.table
st.markdown
st.sidebar
st.expander
st.code
st.image
st.pyplot

以上 API 具體用途大家可以查一查https://docs.streamlit.io/library/api-reference

TODO

以上問題,如有興趣,歡迎貢獻代碼。

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/SomzKMeFQ8GKkvtcdI24vg