《机器学习实战》第3版 第二章 端到端机器学习项目

第二章 端到端机器学习项目

环境要求:

  • python >= 3.7

  • sklearn >= 1.0.1

以下代码均基于 Google Colab 平台,建议使用 Jupyter Notebook 运行。

GitHub 地址:

欢迎 Star~

1. 准备工作

logging 配置,用于打印相关信息。

1
2
3
4
5
6
7
8
9
import logging
import importlib

# 使用 logging.info 打印信息,colab 需要 reload() 函数,否则无法打印
importlib.reload(logging)

logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.INFO)

logging.info("Test logging output...")

判断 Python 和 sklearn 版本是否符合要求:

1
2
3
4
5
import sys

# 判断 python 版本是否 >= 3.7
logging.info("Current Python Version: " + str(sys.version_info))
assert sys.version_info >= (3, 7)

运行结果:

1
2024-12-10 08:35:05,400 INFO Current Python Version: sys.version_info(major=3, minor=10, micro=12, releaselevel='final', serial=0)

判断 sklearn 版本是否符合要求:

1
2
3
4
5
6
from packaging import version   # pakage 包用于解析和比较包版本号
import sklearn

logging.info("Current sklearn Version: " + str(sklearn.__version__))
# 使用 package 包里的 version 类进行解析和比较版本号
assert version.parse(sklearn.__version__) >= version.parse("1.0.1")

运行结果:

1
2024-12-10 08:35:14,973 INFO Current sklearn Version: 1.5.2

2. 获取数据

从网址:https://github.com/ageron/data/raw/main/housing.tgz下载数据保存并解压到本地。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from pathlib import Path
import pandas as pd
import tarfile
import urllib.request

def load_housing_data():
"""
获取房价数据
"""
# 定义数据集压缩包路径
tarfile_path = Path("datasets/housing.tgz")
# 如果压缩文件不存在, 则执行下载和解压操作
if not tarfile_path.is_file():
# 创建保存数据集的目录"datasets"
Path("datasets").mkdir(parents=True, exist_ok=True)

url = "https://github.com/ageron/data/raw/main/housing.tgz"

logging.info("Downloading data from: " + url)

# 下载压缩文件到本地路径
urllib.request.urlretrieve(url, tarfile_path)

# 解压 tar.gz 文件
with tarfile.open(tarfile_path) as housing_tarfile:
# 文件解压到 "datasets" 目录下
housing_tarfile.extractall(path="datasets")

logging.info("Data downloaded and extracted.")

# 加载解压后的 CSV 文件为 Pandas DataFrame 并返回
return pd.read_csv(Path("datasets/housing/housing.csv"))


housing = load_housing_data()

运行结果:

1
2
2024-12-10 13:22:21,348 INFO Downloading data from: https://github.com/ageron/data/raw/main/housing.tgz
2024-12-10 13:22:21,573 INFO Data downloaded and extracted.

3. 查看数据

数据表表头为:

英文 翻译
longitude 经度
latitude 纬度
housing_median_age 住房年龄中位数
total_rooms 房间数
total_bedrooms 卧室数目
population 人口
households 家庭数
median_income 收入中位数
median_house_value 房屋中位价
ocean_proximity 海洋邻近度

查看数据表结构和内容:

1
2
# 查看数据表结构和内容
housing.head()

输出结果:

1
2
3
4
5
6
longitude	latitude	housing_median_age	total_rooms	total_bedrooms	population	households	median_income	median_house_value	ocean_proximity
0 -122.23 37.88 41.0 880.0 129.0 322.0 126.0 8.3252 452600.0 NEAR BAY
1 -122.22 37.86 21.0 7099.0 1106.0 2401.0 1138.0 8.3014 358500.0 NEAR BAY
2 -122.24 37.85 52.0 1467.0 190.0 496.0 177.0 7.2574 352100.0 NEAR BAY
3 -122.25 37.85 52.0 1274.0 235.0 558.0 219.0 5.6431 341300.0 NEAR BAY
4 -122.25 37.85 52.0 1627.0 280.0 565.0 259.0 3.8462 342200.0 NEAR BAY

查看 DataFrame 的概况与结构:

1
2
# 查看 DataFrame 的总行数和列数;每列名称、非空值数量、数据类型;内存使用情况
housing.info()

输出结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 20640 entries, 0 to 20639
Data columns (total 10 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 longitude 20640 non-null float64
1 latitude 20640 non-null float64
2 housing_median_age 20640 non-null float64
3 total_rooms 20640 non-null float64
4 total_bedrooms 20433 non-null float64
5 population 20640 non-null float64
6 households 20640 non-null float64
7 median_income 20640 non-null float64
8 median_house_value 20640 non-null float64
9 ocean_proximity 20640 non-null object
dtypes: float64(9), object(1)
memory usage: 1.6+ MB

分类统计 ocean_proximity 字段的值

1
2
# 分类统计 ocean_proximity 字段的值
housing["ocean_proximity"].value_counts()

运行结果:

1
2
3
4
5
6
7
	count ocean_proximity	
<1H OCEAN 9136
INLAND 6551
NEAR OCEAN 2658
NEAR BAY 2290
ISLAND 5
dtype: int64

生成 DataFrame 的列的统计摘要信息

统计指标

  • count:非空值数量。
  • mean:均值(所有值的平均数)。
  • std:标准差(衡量数据分布的离散程度)。
  • minmax:最小值和最大值。
  • 四分位数(25%、50%、75%):
    • 25%:第 1 四分位数(数据中 25% 的值小于该值)。
    • 50%:中位数(第 2 四分位数,等于数据中间值)。
    • 75%:第 3 四分位数(数据中 75% 的值小于该值)。
1
2
# 对数值列进行统计分析
housing.describe()

运行结果:

1
2
3
4
5
6
7
8
9
	longitude	latitude	housing_median_age	total_rooms	total_bedrooms	population	households	median_income	median_house_value
count 20640.000000 20640.000000 20640.000000 20640.000000 20433.000000 20640.000000 20640.000000 20640.000000 20640.000000
mean -119.569704 35.631861 28.639486 2635.763081 537.870553 1425.476744 499.539680 3.870671 206855.816909
std 2.003532 2.135952 12.585558 2181.615252 421.385070 1132.462122 382.329753 1.899822 115395.615874
min -124.350000 32.540000 1.000000 2.000000 1.000000 3.000000 1.000000 0.499900 14999.000000
25% -121.800000 33.930000 18.000000 1447.750000 296.000000 787.000000 280.000000 2.563400 119600.000000
50% -118.490000 34.260000 29.000000 2127.000000 435.000000 1166.000000 409.000000 3.534800 179700.000000
75% -118.010000 37.710000 37.000000 3148.000000 647.000000 1725.000000 605.000000 4.743250 264725.000000
max -114.310000 41.950000 52.000000 39320.000000 6445.000000 35682.000000 6082.000000 15.000100 500001.000000

生成每个数值属性的直方图

图片路径以及图片保存函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 图片路径
IMAGES_PATH = Path() / "images" / "end_to_end_prroject"
IMAGES_PATH.mkdir(parents=True, exist_ok=True)

def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
"""
保存高分辨率图片图片
@param fig_id: 图片名称
@param tight_layout: 是否使用紧凑布局
@param fig_extension: 图片格式
@param resolution: 图片分辨率
"""

# 创建保存路径:通过 IMAGES_PATH 和文件名及扩展名组合路径
path = IMAGES_PATH / f"{fig_id}.{fig_extension}"

# 设置是否需要紧凑布局
if tight_layout:
plt.tight_layout()

# 保存指定格式的图片到本地
plt.savefig(path, format=fig_extension, dpi=resolution)

保存图片:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import matplotlib.pyplot as plt

# 设置默认字体大小
plt.rc("font", size=14)
plt.rc("axes", labelsize=14, titlesize=14)
plt.rc("legend", fontsize=14)
plt.rc("xtick", labelsize=10)
plt.rc("ytick", labelsize=10)

# 为 housing 数据框中的每一列数值数据绘制直方图
housing.hist(bins=50, figsize=(12, 8))
# 保存图片
save_fig("attribute_histogram_plots")
plt.show()

4. 创建数据集

创建数据集函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import numpy as np

# 创建数据集
def shuffle_and_split_data(data, test_ratio):
"""
随机打乱数据并按比例拆分为训练集和测试集
@param data: 数据集
@param test_ratio: 测试集比例
@return: 训练集和测试集
"""

# 随机打乱数据索引
shuffled_indices = np.random.permutation(len(data))

# 计算测试集大小
test_set_size = int(len(data) * test_ratio)

# 拆分索引为测试集和训练集
test_indices = shuffled_indices[:test_set_size]
train_indices = shuffled_indices[test_set_size:]

# 根据索引返回训练集和测试集
return data.iloc[train_indices], data.iloc[test_indices]

获取训练集和测试集:

1
2
3
# 获取训练集和测试集, 测试集比例为 0.2
train_set, test_set = shuffle_and_split_data(housing, 0.2)
logging.info(f"Train Set Size: {len(train_set)}, Test Set Size: {len(test_set)}")

运行结果:

1
2024-12-10 13:53:31,519 INFO Train Set Size: 16512, Test Set Size: 4128

为了保证每次运行 notebook 的结果保持不变,需要设置随机种子:

1
2
# 设置随机种子保证每次运行 notebook 的结果保持不变
np.random.seed(42)

相关辅助函数

判断 id 是否属于测试集函数以及根据唯一 id 标识符拆分数据为训练集和测试集函数。

1
2
3
4
5
6
7
8
9
10
11
12
from zlib import crc32

def is_id_in_test_set(identifier, test_ratio):
"""
判断数据是否在测试集中
@param identifier: 数据标识符
@param test_ratio: 测试集比例
@return: 如果属于测试集,返回 True;否则返回 False
"""
# 2**32 是 CRC32 哈希值的上限,故 test_ratio * 2**32 给出了测试集的哈希值阈值
return crc32(np.int64(identifier)) < test_ratio * 2**32


《机器学习实战》第3版 第二章 端到端机器学习项目
https://excelius.xyz/《机器学习实战》第3版-第二章-端到端机器学习项目/
作者
Excelius
发布于
2024年12月10日
许可协议