250x250
Notice
Recent Posts
Recent Comments
Link
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
Tags
- linux
- 프로세스
- Windows 10
- Windows10
- 코딩
- C언어
- C++
- 알고리즘
- 운영체제
- 백준알고리즘
- 리눅스
- 쉘
- 턱걸이
- 학습
- error
- 공부
- shell
- 백준
- 프로그래밍
- CV
- TensorFlow
- 텐서플로우
- c
- Computer Vision
- 영상처리
- OpenCV
- 시스템프로그래밍
- 딥러닝
- python
- 회귀
Archives
- Today
- Total
줘이리의 인생적기
[tensorflow 11] 선형회귀02 본문
728x90
tensorflow 10번째 게시물에서 수식과 최소제곱법을 이용하여 회귀선을 도출했었습니다.
하지만 텐서플로우에서는 이러한 어려운 수학 수식없이 회귀선을 구할 수 있습니다.
[tensorflow 10] 선형회귀01 에서 사용했던 자료들을 이용해보겠습니다.
먼저, tf.Variable를 통해 랜덤 값을 뽑아내고 tf.reduce_mean을 통해 잔차의 제곱의 평균을 뽑아내고 tf.optimizers를 통해 잔차의 제곱의 평균을 최소화하려고 합니다.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import random
school_age_population = [16.4, 14.7, 17.2, 16.6, 17.1, 20.3, 17.2, 15.6, 17.4, 18.7, 16.9, 17.1, 15.8, 19.3, 16.6, 15.1, 17.9]
elderly_population = [11.4, 13.2, 11.1, 17.6, 16.3, 9.2, 15.2, 18.4, 18.5, 11.6, 13.7, 9.7, 21.5, 12.0, 14.5, 15.8, 14.0]
# a, b 랜덤값 초기화
a = tf.Variable(random.random())
b = tf.Variable(random.random())
#optimizer 설정
optimizer = tf.keras.optimizers.Adam(lr=0.1)
# 잔차의 제곱평균 구하기
def compute_loss():
y = a * school_age_population + b
loss = tf.reduce_mean((elderly_population - y) ** 2)
return loss
for i in range(2000):#2000번 반복
# loss 최소화
optimizer.minimize(compute_loss, var_list=[a,b])
if i % 200 == 1:
print(i, 'a:', a.numpy(), 'b:', b.numpy(), 'loss:', compute_loss().numpy())
#회귀선
line_x = np.arange(min(school_age_population), max(school_age_population), 0.1)
line_y = a * line_x + b
# 그래프
plt.plot(line_x, line_y, 'r-')
plt.plot(school_age_population, elderly_population, 'ko')
plt.xlabel('school_age_population(%)')
plt.ylabel('elderly_population(%)')
plt.show()
그래프가 잘 나온듯 싶습니다.
compute_loss라는 함수에서 loss를 구하고, SGD대신 Adam 최적화 함수를 사용하였습니다.
2000번의 학습을 통해 a, b는 적절한 값이 되었습니다.
'공부 > tensorflow' 카테고리의 다른 글
[tensorflow 13] 딥러닝 네트워크 회귀 (0) | 2021.04.22 |
---|---|
[tensorflow 12] 다항회귀 (0) | 2021.04.21 |
[tensorflow 10] 선형회귀01 (0) | 2021.04.14 |
[tensorflow 09] XOR 네트워크 시각화 (0) | 2021.04.13 |
[tensorflow 08] matplotlib.pyplot 시각화 (0) | 2021.04.12 |