面向对象编程(OOP)
理解类和对象的思想。PyTorch 的神经网络、scikit-learn 的模型,全部都是类。掌握 OOP,才能真正读懂和使用 AI 框架。
面向对象核心术语
-
类(Class)
对象的蓝图或模板,定义了一类对象共有的属性和方法。例如
nn.Linear是一个类,定义了"全连接层应该有权重、偏置和前向传播方法"。 -
对象/实例(Object / Instance)
类的具体化。
layer = nn.Linear(128, 64)创建了一个nn.Linear的实例,它有自己独立的权重参数。同一个类可以创建无数个互不干扰的实例。 -
封装(Encapsulation)
将数据(属性)和操作数据的方法捆绑在一起,隐藏内部实现细节。你调用
model.train()不需要知道内部如何更新参数,这就是封装。 -
继承(Inheritance)
子类自动获得父类的所有属性和方法,并可以扩展或重写它们。PyTorch 的所有模型都继承
nn.Module,获得参数管理、设备移动、训练/评估模式切换等能力。 -
多态(Polymorphism)
不同类的对象可以用相同的接口调用。scikit-learn 的所有模型都有
fit()和predict(),可以用同一套代码测试不同算法,这就是多态。 -
self 参数
类方法的第一个参数,代表调用该方法的实例本身。Python 不自动传入,必须显式写出(虽然约定名为 self,但本质上只是一个普通参数名)。
self.attr访问实例属性,cls.attr访问类属性。
class —— 定义类
类是对现实事物的抽象模板。对象是类的具体实例。比如"神经网络"是一个类,而"某个具体训练好的模型"就是一个对象(实例)。类把数据(属性)和操作数据的方法(函数)封装在一起。
类 vs 对象
类是蓝图,对象是房子。class Dog 定义了"狗"这个概念(有名字、会叫);my_dog = Dog("旺财") 则创建了一只具体的狗。AI 框架中,nn.Linear 是类,layer = nn.Linear(128, 64) 是对象。
class MLModel:
"""机器学习模型的基础类"""
# __init__ 是构造方法,创建对象时自动调用
def __init__(self, name, learning_rate=0.001):
# self 是对象本身,属性用 self. 前缀
self.name = name
self.lr = learning_rate
self.trained = False
self.history = []
def train(self, data):
"""模拟训练过程"""
print(f"[{self.name}] 开始训练,lr={self.lr}")
self.trained = True
self.history.append({"loss": 0.42})
def predict(self, x):
"""进行预测"""
if not self.trained:
raise RuntimeError("请先调用 train() 训练模型")
return x * 0.8 # 简化版预测
def summary(self):
"""打印模型信息"""
print(f"模型: {self.name} | 已训练: {self.trained} | 历史: {len(self.history)} 轮")
# 创建两个对象(独立的实例,互不干扰)
model_a = MLModel("ResNet50", learning_rate=0.01)
model_b = MLModel("VGG16") # 使用默认 lr
model_a.train(data)
model_a.summary() # 模型: ResNet50 | 已训练: True | 历史: 1 轮
model_b.summary() # 模型: VGG16 | 已训练: False | 历史: 0 轮
继承 —— 复用和扩展
继承让一个类拥有另一个类的所有属性和方法,并可以在此基础上扩展。这是 AI 框架最常用的设计模式:PyTorch 的所有模型都继承自 nn.Module,scikit-learn 的所有估计器都继承自 BaseEstimator。
class BaseModel:
"""所有模型的基类"""
def __init__(self, name):
self.name = name
self.trained = False
def save(self, path):
print(f"保存模型到 {path}")
def load(self, path):
print(f"从 {path} 加载模型")
self.trained = True
# ClassifierModel 继承 BaseModel
class ClassifierModel(BaseModel):
"""分类模型"""
def __init__(self, name, num_classes):
# super() 调用父类的 __init__
super().__init__(name)
self.num_classes = num_classes
# 新增方法(父类没有的)
def predict_class(self, x):
return x % self.num_classes
class RegressorModel(BaseModel):
"""回归模型"""
def __init__(self, name):
super().__init__(name)
def predict_value(self, x):
return x * 1.5
clf = ClassifierModel("ResNet", 10)
clf.save("resnet.pt") # 继承自 BaseModel
print(clf.predict_class(7)) # ClassifierModel 自己的方法
PyTorch 中的继承
在 PyTorch 中,你定义神经网络时必须继承 nn.Module,并实现 __init__ 和 forward 方法。这正是面向对象继承的真实应用。第 10 章会详细展示这个模式。
特殊方法(魔术方法)
Python 类有许多以双下划线开头和结尾的特殊方法,它们定义了对象的内置行为,比如打印、比较、计算长度等。理解这些方法,能让你自己写的类用起来和内置类型一样自然。
class Dataset:
"""模仿 PyTorch Dataset 的简化版"""
def __init__(self, samples):
self.samples = samples
def __len__(self):
# 让 len(dataset) 正常工作
return len(self.samples)
def __getitem__(self, idx):
# 让 dataset[i] 下标访问正常工作
return self.samples[idx]
def __repr__(self):
# 控制 print(dataset) 的输出
return f"Dataset(size={len(self.samples)})"
def __contains__(self, item):
# 支持 item in dataset 语法
return item in self.samples
data = ["cat", "dog", "bird"]
ds = Dataset(data)
print(len(ds)) # 3 → 调用 __len__
print(ds[0]) # cat → 调用 __getitem__
print(ds) # Dataset(size=3) → 调用 __repr__
print("dog" in ds) # True → 调用 __contains__
# 甚至可以用 for 遍历(因为有 __len__ 和 __getitem__)
for sample in ds:
print(sample)
| 特殊方法 | 触发时机 | AI 框架中的例子 |
|---|---|---|
__init__ |
创建对象时 | model = MyNet() |
__len__ |
len(obj) |
len(dataset) |
__getitem__ |
obj[i] |
dataset[0] |
__repr__ |
print(obj) |
打印模型结构 |
__call__ |
obj(x) |
output = model(input) |
类属性、类方法与静态方法
class ModelRegistry:
"""模型注册中心(类属性被所有实例共享)"""
# 类属性:属于类本身,所有实例共享
_registry = {}
_count = 0
def __init__(self, name):
self.name = name
ModelRegistry._count += 1
ModelRegistry._registry[name] = self
@classmethod
def get_count(cls):
"""类方法:操作类本身,第一个参数是 cls"""
return cls._count
@classmethod
def get_model(cls, name):
return cls._registry.get(name)
@staticmethod
def is_valid_name(name):
"""静态方法:不需要 self 或 cls,只是工具函数"""
return isinstance(name, str) and len(name) > 0
m1 = ModelRegistry("ResNet50")
m2 = ModelRegistry("BERT")
print(ModelRegistry.get_count()) # 2
print(ModelRegistry.get_model("BERT").name) # BERT
print(ModelRegistry.is_valid_name("test")) # True
综合示例:PyTorch 风格的神经网络类
看一个完整的例子,感受 OOP 在 AI 框架中的真实形态:
# 模拟 PyTorch 的 nn.Module 接口
class Module:
"""简化版 nn.Module 基类"""
def __call__(self, x):
# __call__ 让对象像函数一样被调用
return self.forward(x)
def forward(self, x):
raise NotImplementedError("子类必须实现 forward 方法")
def parameters(self):
return []
# 继承 Module,定义自己的网络
class SimpleClassifier(Module):
"""二分类神经网络"""
def __init__(self, input_dim, hidden_dim=64):
self.input_dim = input_dim
self.hidden_dim = hidden_dim
# 实际 PyTorch 这里是 nn.Linear 等层
print(f"模型: {input_dim} → {hidden_dim} → 1")
def forward(self, x):
"""前向传播:定义数据如何流经网络"""
# 简化:实际是矩阵乘法 + 激活函数
hidden = x * 0.5 # 模拟线性层
activated = max(0, hidden) # 模拟 ReLU
output = activated * 0.3
return output
def __repr__(self):
return f"SimpleClassifier(input={self.input_dim}, hidden={self.hidden_dim})"
# 使用方式和真实 PyTorch 几乎一样!
model = SimpleClassifier(input_dim=784, hidden_dim=128)
print(model) # 调用 __repr__
# 像函数一样调用(调用 __call__,内部转发给 forward)
output = model(1.5)
print(f"输出: {output}")
OOP 是 AI 框架的基石
真实的 PyTorch 代码:class MyNet(nn.Module): def __init__(self): super().__init__() ... def forward(self, x): ...。看到了吗?这正是继承 + super() + __init__ + 方法覆盖的综合运用。学了本章,那段代码你已经完全能理解了。
属性访问控制:@property 与描述符
Python 没有真正的私有属性,但提供了 @property 装饰器实现受控的属性访问,以及更底层的描述符协议。
class NeuralNetwork:
"""演示 @property 的神经网络类"""
def __init__(self, learning_rate: float):
# 约定:单下划线前缀表示"内部使用"
self._learning_rate = learning_rate
self._weights = []
@property
def learning_rate(self):
"""访问学习率(getter)"""
return self._learning_rate
@learning_rate.setter
def learning_rate(self, value):
"""设置学习率(setter)—— 加入验证逻辑"""
if not (1e-8 < value < 1.0):
raise ValueError(f"学习率必须在 (1e-8, 1.0) 之间,得到 {value}")
self._learning_rate = value
print(f"学习率已更新为 {value}")
@property
def num_parameters(self):
"""计算属性(只读):参数数量"""
return sum(len(w) for w in self._weights)
net = NeuralNetwork(0.001)
print(net.learning_rate) # 0.001(调用 getter)
net.learning_rate = 0.01 # 学习率已更新为 0.01(调用 setter)
# net.learning_rate = 2.0 # ValueError:超出范围
Python 没有真正的私有属性
单下划线 _name 只是"约定为内部使用",仍然可以从外部访问。双下划线 __name 会触发"名称改写"(Name Mangling),变成 _ClassName__name,使外部直接访问更困难,但仍非真正私有。AI 框架中通常用单下划线约定,不依赖名称改写。
本章小结
面向对象编程(OOP)将数据和方法封装在类中。核心概念:类定义模板,__init__ 初始化实例,self 引用实例自身;继承用 class Child(Parent),子类可覆盖父类方法;特殊方法(双下划线)定义内置行为,__len__/__getitem__/__call__ 在 PyTorch Dataset 和模型中大量使用;@property 实现受控属性访问,加入验证逻辑。PyTorch 中定义模型必须继承 nn.Module 并实现 forward 方法,这是本章最重要的实际应用。