面向对象编程(OOP)
理解类和对象的思想。PyTorch 的神经网络、scikit-learn 的模型,全部都是类。掌握 OOP,才能真正读懂和使用 AI 框架。
class —— 定义类
类是对现实事物的抽象模板。对象是类的具体实例。比如"神经网络"是一个类,而"某个具体训练好的模型"就是一个对象(实例)。类把数据(属性)和操作数据的方法(函数)封装在一起。
类 vs 对象
类是蓝图,对象是房子。class Dog 定义了"狗"这个概念(有名字、会叫);my_dog = Dog("旺财") 则创建了一只具体的狗。AI 框架中,nn.Linear 是类,layer = nn.Linear(128, 64) 是对象。
class MLModel:
"""机器学习模型的基础类"""
# __init__ 是构造方法,创建对象时自动调用
def __init__(self, name, learning_rate=0.001):
# self 是对象本身,属性用 self. 前缀
self.name = name
self.lr = learning_rate
self.trained = False
self.history = []
def train(self, data):
"""模拟训练过程"""
print(f"[{self.name}] 开始训练,lr={self.lr}")
self.trained = True
self.history.append({"loss": 0.42})
def predict(self, x):
"""进行预测"""
if not self.trained:
raise RuntimeError("请先调用 train() 训练模型")
return x * 0.8 # 简化版预测
def summary(self):
"""打印模型信息"""
print(f"模型: {self.name} | 已训练: {self.trained} | 历史: {len(self.history)} 轮")
# 创建两个对象(独立的实例,互不干扰)
model_a = MLModel("ResNet50", learning_rate=0.01)
model_b = MLModel("VGG16") # 使用默认 lr
model_a.train(data)
model_a.summary() # 模型: ResNet50 | 已训练: True | 历史: 1 轮
model_b.summary() # 模型: VGG16 | 已训练: False | 历史: 0 轮
继承 —— 复用和扩展
继承让一个类拥有另一个类的所有属性和方法,并可以在此基础上扩展。这是 AI 框架最常用的设计模式:PyTorch 的所有模型都继承自 nn.Module,scikit-learn 的所有估计器都继承自 BaseEstimator。
class BaseModel:
"""所有模型的基类"""
def __init__(self, name):
self.name = name
self.trained = False
def save(self, path):
print(f"保存模型到 {path}")
def load(self, path):
print(f"从 {path} 加载模型")
self.trained = True
# ClassifierModel 继承 BaseModel
class ClassifierModel(BaseModel):
"""分类模型"""
def __init__(self, name, num_classes):
# super() 调用父类的 __init__
super().__init__(name)
self.num_classes = num_classes
# 新增方法(父类没有的)
def predict_class(self, x):
return x % self.num_classes
class RegressorModel(BaseModel):
"""回归模型"""
def __init__(self, name):
super().__init__(name)
def predict_value(self, x):
return x * 1.5
clf = ClassifierModel("ResNet", 10)
clf.save("resnet.pt") # 继承自 BaseModel
print(clf.predict_class(7)) # ClassifierModel 自己的方法
PyTorch 中的继承
在 PyTorch 中,你定义神经网络时必须继承 nn.Module,并实现 __init__ 和 forward 方法。这正是面向对象继承的真实应用。第 10 章会详细展示这个模式。
特殊方法(魔术方法)
Python 类有许多以双下划线开头和结尾的特殊方法,它们定义了对象的内置行为,比如打印、比较、计算长度等。理解这些方法,能让你自己写的类用起来和内置类型一样自然。
class Dataset:
"""模仿 PyTorch Dataset 的简化版"""
def __init__(self, samples):
self.samples = samples
def __len__(self):
# 让 len(dataset) 正常工作
return len(self.samples)
def __getitem__(self, idx):
# 让 dataset[i] 下标访问正常工作
return self.samples[idx]
def __repr__(self):
# 控制 print(dataset) 的输出
return f"Dataset(size={len(self.samples)})"
def __contains__(self, item):
# 支持 item in dataset 语法
return item in self.samples
data = ["cat", "dog", "bird"]
ds = Dataset(data)
print(len(ds)) # 3 → 调用 __len__
print(ds[0]) # cat → 调用 __getitem__
print(ds) # Dataset(size=3) → 调用 __repr__
print("dog" in ds) # True → 调用 __contains__
# 甚至可以用 for 遍历(因为有 __len__ 和 __getitem__)
for sample in ds:
print(sample)
| 特殊方法 | 触发时机 | AI 框架中的例子 |
|---|---|---|
__init__ |
创建对象时 | model = MyNet() |
__len__ |
len(obj) |
len(dataset) |
__getitem__ |
obj[i] |
dataset[0] |
__repr__ |
print(obj) |
打印模型结构 |
__call__ |
obj(x) |
output = model(input) |
类属性、类方法与静态方法
class ModelRegistry:
"""模型注册中心(类属性被所有实例共享)"""
# 类属性:属于类本身,所有实例共享
_registry = {}
_count = 0
def __init__(self, name):
self.name = name
ModelRegistry._count += 1
ModelRegistry._registry[name] = self
@classmethod
def get_count(cls):
"""类方法:操作类本身,第一个参数是 cls"""
return cls._count
@classmethod
def get_model(cls, name):
return cls._registry.get(name)
@staticmethod
def is_valid_name(name):
"""静态方法:不需要 self 或 cls,只是工具函数"""
return isinstance(name, str) and len(name) > 0
m1 = ModelRegistry("ResNet50")
m2 = ModelRegistry("BERT")
print(ModelRegistry.get_count()) # 2
print(ModelRegistry.get_model("BERT").name) # BERT
print(ModelRegistry.is_valid_name("test")) # True
综合示例:PyTorch 风格的神经网络类
看一个完整的例子,感受 OOP 在 AI 框架中的真实形态:
# 模拟 PyTorch 的 nn.Module 接口
class Module:
"""简化版 nn.Module 基类"""
def __call__(self, x):
# __call__ 让对象像函数一样被调用
return self.forward(x)
def forward(self, x):
raise NotImplementedError("子类必须实现 forward 方法")
def parameters(self):
return []
# 继承 Module,定义自己的网络
class SimpleClassifier(Module):
"""二分类神经网络"""
def __init__(self, input_dim, hidden_dim=64):
self.input_dim = input_dim
self.hidden_dim = hidden_dim
# 实际 PyTorch 这里是 nn.Linear 等层
print(f"模型: {input_dim} → {hidden_dim} → 1")
def forward(self, x):
"""前向传播:定义数据如何流经网络"""
# 简化:实际是矩阵乘法 + 激活函数
hidden = x * 0.5 # 模拟线性层
activated = max(0, hidden) # 模拟 ReLU
output = activated * 0.3
return output
def __repr__(self):
return f"SimpleClassifier(input={self.input_dim}, hidden={self.hidden_dim})"
# 使用方式和真实 PyTorch 几乎一样!
model = SimpleClassifier(input_dim=784, hidden_dim=128)
print(model) # 调用 __repr__
# 像函数一样调用(调用 __call__,内部转发给 forward)
output = model(1.5)
print(f"输出: {output}")
OOP 是 AI 框架的基石
真实的 PyTorch 代码:class MyNet(nn.Module): def __init__(self): super().__init__() ... def forward(self, x): ...。看到了吗?这正是继承 + super() + __init__ + 方法覆盖的综合运用。学了本章,那段代码你已经完全能理解了。