神经网络基础
2026/5/6大约 1 分钟
神经网络
PyTorch 提供了强大的工具来构建和训练神经网络。
神经网络在 PyTorch 中是通过 torch.nn 模块来实现的。
torch.nn 模块提供了各种网络层(如全连接层、卷积层等)、损失函数和优化器,让神经网络的构建和训练变得更加方便。
在 PyTorch 中,构建神经网络通常需要继承 nn.Module 类。
nn.Module 是所有神经网络模块的基类,你需要定义以下两个部分:
init():定义网络层。
forward():定义数据的前向传播过程。
简单的全连接神经网络(Fully Connected Network):
import torch
import torch.nn as nn
# 定义一个简单的神经网络模型
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
# 定义一个输入层到隐藏层的全连接层
self.fc1 = nn.Linear(2, 2) # 输入 2 个特征,输出 2 个特征
# 定义一个隐藏层到输出层的全连接层
self.fc2 = nn.Linear(2, 1) # 输入 2 个特征,输出 1 个预测值
def forward(self, x):
# 前向传播过程
x = torch.relu(self.fc1(x)) # 使用 ReLU 激活函数
x = self.fc2(x) # 输出层
return x
# 创建模型实例
model = SimpleNN()
# 打印模型
print(model)