Contents
  1. Code
  2. Output
  3. Regression result

Linear Regression

Code

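import matplotlib.pyplot as plt
import numpy as np

# Training data; the points lie exactly on y = 2x + 1
x1 = np.array([1.0, 2.0, 3.0, 4.0])
y1 = np.array([3.0, 5.0, 7.0, 9.0])

# Model parameters, initialised to zero
w = 0
b = 0

def forward(x):
    # Linear model: y_hat = w * x + b (uses the global w and b)
    return x * w + b

def loss(x, y):
    # Squared error for a single sample
    y_hat = forward(x)
    return (y - y_hat) * (y - y_hat)

def backward(x, w, b, y):
    # Gradients of the squared error with respect to w and b
    dw = -2 * (y - w * x - b) * x
    db = -2 * (y - w * x - b)
    return dw, db

epochs = 100
lr = 0.05
l_lists = []
w_lists = []
b_lists = []

for epoch in range(epochs):
    # Stochastic gradient descent: update after every sample
    for x, y in zip(x1, y1):
        dw, db = backward(x, w, b, y)
        l = loss(x, y)  # loss before this sample's update
        w = w - lr * dw
        b = b - lr * db

    # Record once per epoch: the last sample's loss and the current w, b
    l_lists.append(l)
    w_lists.append(w)
    b_lists.append(b)

    if epoch % 10 == 0:
        # Running means over all epochs recorded so far
        print(f"The epoch is {epoch}, mean_loss is {sum(l_lists)/len(l_lists)}")
        print(f"The epoch is {epoch}, mean_w is {sum(w_lists)/len(w_lists)}")
        print(f"The epoch is {epoch}, mean_b is {sum(b_lists)/len(b_lists)}")
        print("\n")

# Note: the "final" values printed here are means over all epochs, not the last w and b
print(f"The final mean_loss is {sum(l_lists)/len(l_lists)}")
print(f"The final w is {sum(w_lists)/len(w_lists)}")
print(f"The final b is {sum(b_lists)/len(b_lists)}")

# Predictions with the trained parameters
y11 = forward(x1)

plt.scatter(x1, y1, c='r', linewidth=2)  # training points in red
plt.plot(x1, y11, c='b', linewidth=2)    # fitted line in blue
plt.show()  # needed when run as a script rather than in a notebook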
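The gradients in backward follow from differentiating the per-sample squared error (y - w*x - b)^2 with respect to w and b. One way to gain confidence in them is a finite-difference check. The sketch below is not part of the original post: squared_error and numeric_grad are hypothetical helper names introduced here, and the test point is arbitrary.

```python
import numpy as np

def squared_error(x, y, w, b):
    # Per-sample loss used in the post: (y - (w*x + b))^2
    return (y - (w * x + b)) ** 2

def backward(x, w, b, y):
    # Analytic gradients, copied from the code above
    dw = -2 * (y - w * x - b) * x
    db = -2 * (y - w * x - b)
    return dw, db

def numeric_grad(x, y, w, b, eps=1e-6):
    # Central differences (hypothetical helper, for illustration only)
    dw = (squared_error(x, y, w + eps, b) - squared_error(x, y, w - eps, b)) / (2 * eps)
    db = (squared_error(x, y, w, b + eps) - squared_error(x, y, w, b - eps)) / (2 * eps)
    return dw, db

# Arbitrary test point: x=2, y=5, w=0.3, b=0.3
analytic = backward(2.0, 0.3, 0.3, 5.0)
numeric = numeric_grad(2.0, 5.0, 0.3, 0.3)
print("analytic:", analytic)
print("numeric: ", numeric)
print("match:", np.allclose(analytic, numeric, atol=1e-4))
```

If the two pairs agree to within a small tolerance, the hand-derived formulas are very likely correct.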

Output

The epoch is 0, mean_loss is 9.999999999988916e-07
The epoch is 0, mean_w is 1.9993999999999998
The epoch is 0, mean_b is 1.0031


The epoch is 10, mean_loss is 5.303886957252541e-07
The epoch is 10, mean_w is 1.9995744553804442
The epoch is 10, mean_b is 1.0021986472010385


The epoch is 20, mean_loss is 3.300322394851411e-07
The epoch is 20, mean_w is 1.9996845480237937
The epoch is 20, mean_b is 1.0016298352104005


The epoch is 30, mean_loss is 2.3159155039371252e-07
The epoch is 30, mean_w is 1.9997564500065015
The epoch is 30, mean_b is 1.0012583416330758


The epoch is 40, mean_loss is 1.7648131096737295e-07
The epoch is 40, mean_w is 1.9998051016879976
The epoch is 40, mean_b is 1.0010069746120145


The epoch is 50, mean_loss is 1.4212792446273977e-07
The epoch is 50, mean_w is 1.9998392010774284
The epoch is 50, mean_b is 1.0008307944332897


The epoch is 60, mean_loss is 1.1887581929833034e-07
The epoch is 60, mean_w is 1.9998639227367998
The epoch is 60, mean_b is 1.0007030658598695


The epoch is 70, mean_loss is 1.021420123654263e-07
The epoch is 70, mean_w is 1.9998824180130526
The epoch is 70, mean_b is 1.0006075069325624


The epoch is 80, mean_loss is 8.95337294833965e-08
The epoch is 80, mean_w is 1.9998966544004768
The epoch is 80, mean_b is 1.0005339522642032


The epoch is 90, mean_loss is 7.969522998516226e-08
The epoch is 90, mean_w is 1.999907892407271
The epoch is 90, mean_b is 1.0004758892290995


The final mean_loss is 7.252273243521972e-08
The final w is 1.9999161342703209
The final b is 1.0004333062700113
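As a sanity check on the values above, the same data can be fitted in closed form. This is a minimal sketch using NumPy's np.polyfit for an ordinary least-squares line, which is a different technique from the gradient descent in the post; the names w_ls and b_ls are introduced here for illustration.

```python
import numpy as np

x1 = np.array([1.0, 2.0, 3.0, 4.0])
y1 = np.array([3.0, 5.0, 7.0, 9.0])

# Degree-1 least-squares fit; coefficients come back highest power first,
# so the slope is first and the intercept second.
w_ls, b_ls = np.polyfit(x1, y1, deg=1)
print(f"least-squares w = {w_ls:.6f}, b = {b_ls:.6f}")
```

Because the data lie exactly on y = 2x + 1, the least-squares slope and intercept should be 2 and 1 up to floating-point error, which is the point the gradient-descent values of w and b approach.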

Regression result

[Figure tho3cD.jpg: the training points in red with the fitted line in blue]
