The Great Way Is Simple: You Can Build a Neural Network in SQL, Too
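I've been writing a lot of SQL lately, and a thought struck me: could SQL also train a simple neural network? I went looking on GitHub with that question in mind, and sure enough, someone had already done it. So in this post I'll share how to implement a neural network in pure SQL!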
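A quick aside: some of you may be wondering why someone who works on algorithms is writing SQL. Well... that's a long story. In short, you can never have too many skills!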
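Back to the topic. When modeling in SQL, we store the parameters in columns. From the input layer to the hidden layer, w_00, w_01, w_10, w_11 form the weight matrix W1 (the first index picks the input x1 or x2, the second the hidden unit) and b_0, b_1 form the bias vector B1. From the hidden layer to the output layer, w2_00, w2_01, w2_10, w2_11 form the weight matrix W2 and b2_0, b2_1 form the bias vector B2. Every row holds one training example (x1, x2, y) together with a copy of all parameters, and window aggregates such as SUM(...) OVER () add up the gradients over the whole table. One multi-level nested query therefore performs a forward pass (two ReLU hidden units, a softmax over two classes), backpropagation and a gradient-descent update; wrapping that query around itself repeatedly implements the whole training process.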
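The training query that follows packs a lot of bookkeeping into nested SELECTs. As a cross-check, here is a minimal plain-Python sketch of what one application of `add_iteration_sql` computes, using the same parameter names, learning rate 2.0 and L2 strength 1e-3. A Python loop over the rows stands in for `SUM(...) OVER ()`, and the returned loss is evaluated before the update, like the query's `loss` column. The helper names (`loss_and_grads`, `train_step`), the toy rows and the initial values are made up for illustration and are not code from the original article; the sketch also assumes the second hidden unit uses `b_1` consistently.

```python
import math

LR, REG = 2.0, 1e-3  # learning rate and L2 strength, the constants used in the SQL
PARAMS = ["w_00", "w_01", "w_10", "w_11", "b_0", "b_1",
          "w2_00", "w2_01", "w2_10", "w2_11", "b2_0", "b2_1"]
WEIGHTS = [k for k in PARAMS if k.startswith("w")]  # biases are not regularized


def loss_and_grads(rows, p):
    """Full-batch softmax loss and gradients of the 2-2-2 ReLU network.

    rows: list of (x1, x2, y) with y in {0, 1}; p: dict keyed by PARAMS.
    w_ij connects input i to hidden unit j; w2_ij connects hidden unit i to score j.
    """
    n = len(rows)
    g = {k: 0.0 for k in PARAMS}  # sums over all rows, like SUM(...) OVER ()
    data_loss = 0.0
    for x1, x2, y in rows:
        # forward: hidden layer with ReLU (columns d0, d1)
        d0 = max(x1 * p["w_00"] + x2 * p["w_10"] + p["b_0"], 0.0)
        d1 = max(x1 * p["w_01"] + x2 * p["w_11"] + p["b_1"], 0.0)
        # forward: scores and softmax probabilities (scores_*, probs_*)
        s0 = d0 * p["w2_00"] + d1 * p["w2_10"] + p["b2_0"]
        s1 = d0 * p["w2_01"] + d1 * p["w2_11"] + p["b2_1"]
        e0, e1 = math.exp(s0), math.exp(s1)
        probs = (e0 / (e0 + e1), e1 / (e0 + e1))
        data_loss -= math.log(probs[y])  # correct_logprobs
        # backward: dscores = (probs - one_hot(y)) / n
        ds0 = (probs[0] - (1.0 if y == 0 else 0.0)) / n
        ds1 = (probs[1] - (1.0 if y == 1 else 0.0)) / n
        for k, v in (("w2_00", d0 * ds0), ("w2_01", d0 * ds1), ("w2_10", d1 * ds0),
                     ("w2_11", d1 * ds1), ("b2_0", ds0), ("b2_1", ds1)):
            g[k] += v
        # backward through ReLU: no gradient where the hidden unit is inactive
        dh0 = (ds0 * p["w2_00"] + ds1 * p["w2_01"]) if d0 > 0.0 else 0.0
        dh1 = (ds0 * p["w2_10"] + ds1 * p["w2_11"]) if d1 > 0.0 else 0.0
        for k, v in (("w_00", x1 * dh0), ("w_01", x1 * dh1), ("w_10", x2 * dh0),
                     ("w_11", x2 * dh1), ("b_0", dh0), ("b_1", dh1)):
            g[k] += v
    for k in WEIGHTS:
        g[k] += REG * p[k]  # gradient of the L2 term 0.5 * REG * w^2
    loss = data_loss / n + 0.5 * REG * sum(p[k] ** 2 for k in WEIGHTS)
    return loss, g


def train_step(rows, p, lr=LR):
    """One application of the iteration template: (updated params, pre-update loss)."""
    loss, g = loss_and_grads(rows, p)
    return {k: p[k] - lr * g[k] for k in PARAMS}, loss


if __name__ == "__main__":
    rows = [(1.0, 0.2, 0), (0.8, 0.1, 0), (0.1, 0.9, 1), (0.2, 1.1, 1)]  # toy data
    params = dict(zip(PARAMS, [0.5, -0.5, -0.5, 0.4, 0.1, 0.1,
                               0.3, -0.1, -0.2, 0.2, 0.0, 0.0]))  # hand-picked start
    for i in range(5):
        params, loss = train_step(rows, params)
        print(f"iteration {i}: loss {loss:.4f}")
```

Each `(name, value)` pair in the backward pass corresponds to one `SUM(...) OVER ()` column of the query, so the sketch can be compared against the SQL term by term.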
Training code
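# One training iteration (forward pass, backpropagation, full-batch gradient step)
# as a SQL template; "{}" is filled with the inner table by generate_query below.
add_iteration_sql = """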
SELECT
x1,x2,y,
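-- gradient step: learning rate 2.0, L2 regularization 1e-3 (weights only);
-- the output names must match the input names so the next iteration can read them
w_00 - (2.0)*(dw_00+(1e-3)*w_00) AS w_00,
w_01 - (2.0)*(dw_01+(1e-3)*w_01) AS w_01,
w_10 - (2.0)*(dw_10+(1e-3)*w_10) AS w_10,
w_11 - (2.0)*(dw_11+(1e-3)*w_11) AS w_11,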
b_0 - (2.0)*db_0 AS b_0,
b_1 - (2.0)*db_1 AS b_1,
w2_00 - (2.0)*(dw2_00+(1e-3)*w2_00) AS w2_00,
w2_01 - (2.0)*(dw2_01+(1e-3)*w2_01) AS w2_01,
w2_10 - (2.0)*(dw2_10+(1e-3)*w2_10) AS w2_10,
w2_11 - (2.0)*(dw2_11+(1e-3)*w2_11) AS w2_11,
b2_0 - (2.0)*db2_0 AS b2_0,
b2_1 - (2.0)*db2_1 AS b2_1
FROM (
SELECT
*,
SUM(x1*dhidden_0) OVER () AS dw_00,
SUM(x1*dhidden_1) OVER () AS dw_01,
SUM(x2*dhidden_0) OVER () AS dw_10,
SUM(x2*dhidden_1) OVER () AS dw_11,
SUM(dhidden_0) OVER () AS db_0,
SUM(dhidden_1) OVER () AS db_1
FROM (
SELECT
*,
SUM(d0*dscores_0) OVER () AS dw2_00,
SUM(d0*dscores_1) OVER () AS dw2_01,
SUM(d1*dscores_0) OVER () AS dw2_10,
SUM(d1*dscores_1) OVER () AS dw2_11,
SUM(dscores_0) OVER () AS db2_0,
SUM(dscores_1) OVER () AS db2_1,
CASE
WHEN (d0) <= 0.0 THEN 0.0
ELSE (dscores_0*w2_00 + dscores_1*w2_01)
END AS dhidden_0,
CASE
WHEN (d1) <= 0.0 THEN 0.0
ELSE (dscores_0*w2_10 + dscores_1*w2_11)
END AS dhidden_1
FROM (
SELECT
*,
(CASE
WHEN y = 0 THEN (probs_0 - 1)/num_examples
ELSE probs_0/num_examples END) AS dscores_0,
(CASE
WHEN y = 1 THEN (probs_1 - 1)/num_examples
ELSE probs_1/num_examples END) AS dscores_1
FROM (
SELECT
*,
(sum_correct_logprobs/num_examples) + 1e-3*(0.5*(w_00*w_00 + w_01*w_01 + w_10*w_10 + w_11*w_11) + 0.5*(w2_00*w2_00 + w2_01*w2_01 + w2_10*w2_10 + w2_11*w2_11)) AS loss
FROM (
SELECT
*,
SUM(correct_logprobs) OVER () sum_correct_logprobs,
COUNT(1) OVER () num_examples
FROM (
SELECT
*,
(CASE
WHEN y = 0 THEN -1*LOG(probs_0)
ELSE -1*LOG(probs_1) END) AS correct_logprobs
FROM (
SELECT
*,
EXP(scores_0)/(EXP(scores_0) + EXP(scores_1)) AS probs_0,
EXP(scores_1)/(EXP(scores_0) + EXP(scores_1)) AS probs_1
FROM (
SELECT
*,
((d0*w2_00 + d1*w2_10) + b2_0) AS scores_0,
((d0*w2_01 + d1*w2_11) + b2_1) AS scores_1
FROM (
SELECT
*,
(CASE
WHEN ((x1*w_00 + x2*w_10) + b_0) > 0.0 THEN ((x1*w_00 + x2*w_10) + b_0)
ELSE 0.0 END) AS d0,
(CASE
WHEN ((x1*w_01 + x2*w_11) + b_1) > 0.0 THEN ((x1*w_01 + x2*w_11) + b_1)
ELSE 0.0 END) AS d1
FROM (
{}))))))))))"""
def generate_query(add_iteration_sql, root_table_sql, iterations):
    """
    Returns the SQL query for training the neural network.
    param add_iteration_sql: format string adding one iteration of forward pass and backpropagation
    param root_table_sql: inner SQL query producing a table with the training data and the initial values of the model parameters
    param iterations: number of training iterations to perform
    """
    # Each iteration wraps the previous query as the inner table of a new one,
    # so the final query is nested `iterations` levels deep.
    query = root_table_sql
    for _ in range(iterations):
        query = add_iteration_sql.format(query)
    return query


# root_table_sql is not shown in the article: it must be a query returning x1, x2, y
# plus the initial values of all 12 parameter columns (w_00 through b2_1).
print(generate_query(add_iteration_sql, root_table_sql, 10))
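The print call above needs `root_table_sql`, which the article does not show. The iteration template only assumes that its inner table exposes x1, x2, y and the 12 parameter columns, so a root query can simply append constant initial values to the training table. Below is a hypothetical helper, `make_root_table_sql`, sketching one way to build it; the table name `my_dataset.training_data` and the initialization (Gaussian weights with standard deviation 0.01, zero biases) are assumptions, not the original author's choices.

```python
import random

# Hypothetical helper: the article does not show how root_table_sql is built.
PARAMS = ["w_00", "w_01", "w_10", "w_11", "b_0", "b_1",
          "w2_00", "w2_01", "w2_10", "w2_11", "b2_0", "b2_1"]


def make_root_table_sql(table, seed=0, scale=0.01):
    """Return a SELECT that appends one constant column per parameter to the training table.

    Assumptions (not from the article): the table has columns x1, x2, y; weights start as
    small Gaussian noise with standard deviation `scale`; biases start at zero.
    """
    rng = random.Random(seed)  # seeded so the generated query is reproducible
    columns = ["x1", "x2", "y"]
    for name in PARAMS:
        value = 0.0 if name.startswith("b") else rng.gauss(0.0, scale)
        columns.append(f"{value:.6f} AS {name}")  # fixed-point literal, no exponent
    return "SELECT\n  " + ",\n  ".join(columns) + f"\nFROM {table}"


if __name__ == "__main__":
    root_table_sql = make_root_table_sql("my_dataset.training_data")  # hypothetical table
    print(root_table_sql)
```

Passing the returned string to `generate_query` as `root_table_sql` yields a complete training query. Random rather than all-zero weights matter here: with zero weights and zero biases, d0 and d1 would be 0 on every row, the `CASE WHEN (d0) <= 0.0 THEN 0.0` branches would zero the hidden-layer gradients, and the first layer would never train.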
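Prediction code
The query below runs the same forward pass on a table holding the trained parameters and reports the percentage of rows whose predicted class y_hat matches the label y.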
SELECT
(SUM(CASE
WHEN y_hat = y THEN 1
ELSE 0 END)/COUNT(1))*100.0 AS accuracy_perc
FROM (
SELECT
*,
(CASE
WHEN scores_0 > scores_1 THEN 0
ELSE 1 END) AS y_hat
FROM (
SELECT
*,
((d0*w2_00 + d1*w2_10) + b2_0) AS scores_0,
((d0*w2_01 + d1*w2_11) + b2_1) AS scores_1
FROM (
SELECT
*,
(CASE
WHEN ((x1*w_00 + x2*w_10) + b_0) > 0.0 THEN ((x1*w_00 + x2*w_10) + b_0)
ELSE 0.0 END) AS d0,
(CASE
WHEN ((x1*w_01 + x2*w_11) + b_1) > 0.0 THEN ((x1*w_01 + x2*w_11) + b_1)
ELSE 0.0 END) AS d1
FROM (
-- table with x1, x2, y and the trained parameter columns
SELECT
*
FROM
`dota.2009.example_table_for_sql2nn` ))))
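To try the prediction logic without BigQuery, here is a sketch that runs an equivalent query on SQLite through Python's built-in `sqlite3` module. The table name `trained_rows`, the helper `predict_accuracy` and the toy parameter values are hypothetical. One dialect difference matters: BigQuery's `/` returns FLOAT64 even for two INT64 operands, so the article's `(SUM(...)/COUNT(1))*100.0` works there, whereas SQLite truncates integer division (3/4 becomes 0). This port therefore multiplies by 100.0 first, and it tests the second hidden unit with `b_1`.

```python
import sqlite3

# SQLite port of the prediction query (the original targets BigQuery).
# Assumptions of this sketch: d1 is tested with b_1, the accuracy is written as
# 100.0 * SUM(...) / COUNT(1) to avoid integer division, and the input table is
# a hypothetical in-memory table named trained_rows.
PREDICT_SQL = """
SELECT 100.0 * SUM(CASE WHEN y_hat = y THEN 1 ELSE 0 END) / COUNT(1) AS accuracy_perc
FROM (
  SELECT *, CASE WHEN scores_0 > scores_1 THEN 0 ELSE 1 END AS y_hat
  FROM (
    SELECT *,
      ((d0*w2_00 + d1*w2_10) + b2_0) AS scores_0,
      ((d0*w2_01 + d1*w2_11) + b2_1) AS scores_1
    FROM (
      SELECT *,
        CASE WHEN ((x1*w_00 + x2*w_10) + b_0) > 0.0
             THEN ((x1*w_00 + x2*w_10) + b_0) ELSE 0.0 END AS d0,
        CASE WHEN ((x1*w_01 + x2*w_11) + b_1) > 0.0
             THEN ((x1*w_01 + x2*w_11) + b_1) ELSE 0.0 END AS d1
      FROM trained_rows)))
"""

PARAMS = ["w_00", "w_01", "w_10", "w_11", "b_0", "b_1",
          "w2_00", "w2_01", "w2_10", "w2_11", "b2_0", "b2_1"]


def predict_accuracy(rows, params, sql=PREDICT_SQL):
    """Store each (x1, x2, y) row with a copy of the parameters, then run the query."""
    conn = sqlite3.connect(":memory:")
    try:
        columns = ["x1", "x2", "y"] + PARAMS
        conn.execute(f"CREATE TABLE trained_rows ({', '.join(columns)})")
        placeholders = ", ".join("?" * len(columns))
        conn.executemany(
            f"INSERT INTO trained_rows VALUES ({placeholders})",
            [(x1, x2, y, *(params[k] for k in PARAMS)) for x1, x2, y in rows],
        )
        (accuracy,) = conn.execute(sql).fetchone()
        return accuracy
    finally:
        conn.close()


if __name__ == "__main__":
    # identity weights, zero biases: the network predicts the class of the larger input
    identity = dict(zip(PARAMS, [1.0, 0.0, 0.0, 1.0, 0.0, 0.0,
                                 1.0, 0.0, 0.0, 1.0, 0.0, 0.0]))
    rows = [(1.0, 0.0, 0), (0.0, 1.0, 1), (2.0, 1.0, 0), (1.0, 2.0, 0)]
    print(predict_accuracy(rows, identity))  # 75.0
```

With these toy values the last row, (1.0, 2.0) labeled 0, is the only misclassification, so 3 of 4 rows are correct.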
References
- https://towardsdatascience.com/deep-neural-network-implemented-in-pure-sql-over-bigquery-f3ed245814d3
- https://github.com/harisankarh/nn-sql-bq/blob/master/query_for_prediction.sql
This article was AMP-transcoded by Readfog; copyright belongs to the original author.
Source: https://mp.weixin.qq.com/s/dPmcS3juWPoec_uougorAg