CHAPTER 06

面向对象编程(OOP)

理解类和对象的思想。PyTorch 的神经网络、scikit-learn 的模型,全部都是类。掌握 OOP,才能真正读懂和使用 AI 框架。

class —— 定义类

类是对现实事物的抽象模板。对象是类的具体实例。比如"神经网络"是一个类,而"某个具体训练好的模型"就是一个对象(实例)。类把数据(属性)和操作数据的方法(函数)封装在一起。

核心概念

类 vs 对象

类是蓝图,对象是房子。class Dog 定义了"狗"这个概念(有名字、会叫);my_dog = Dog("旺财") 则创建了一只具体的狗。AI 框架中,nn.Linear 是类,layer = nn.Linear(128, 64) 是对象。

PYTHON · 类的基础
class MLModel:
    """机器学习模型的基础类"""

    # __init__ 是构造方法,创建对象时自动调用
    def __init__(self, name, learning_rate=0.001):
        # self 是对象本身,属性用 self. 前缀
        self.name = name
        self.lr   = learning_rate
        self.trained = False
        self.history = []

    def train(self, data):
        """模拟训练过程"""
        print(f"[{self.name}] 开始训练,lr={self.lr}")
        self.trained = True
        self.history.append({"loss": 0.42})

    def predict(self, x):
        """进行预测"""
        if not self.trained:
            raise RuntimeError("请先调用 train() 训练模型")
        return x * 0.8  # 简化版预测

    def summary(self):
        """打印模型信息"""
        print(f"模型: {self.name} | 已训练: {self.trained} | 历史: {len(self.history)} 轮")

# 创建两个对象(独立的实例,互不干扰)
model_a = MLModel("ResNet50", learning_rate=0.01)
model_b = MLModel("VGG16")  # 使用默认 lr

model_a.train(data)
model_a.summary()  # 模型: ResNet50 | 已训练: True | 历史: 1 轮
model_b.summary()  # 模型: VGG16   | 已训练: False | 历史: 0 轮

继承 —— 复用和扩展

继承让一个类拥有另一个类的所有属性和方法,并可以在此基础上扩展。这是 AI 框架最常用的设计模式:PyTorch 的所有模型都继承自 nn.Module,scikit-learn 的所有估计器都继承自 BaseEstimator

PYTHON · 继承
class BaseModel:
    """所有模型的基类"""
    def __init__(self, name):
        self.name = name
        self.trained = False

    def save(self, path):
        print(f"保存模型到 {path}")

    def load(self, path):
        print(f"从 {path} 加载模型")
        self.trained = True

# ClassifierModel 继承 BaseModel
class ClassifierModel(BaseModel):
    """分类模型"""
    def __init__(self, name, num_classes):
        # super() 调用父类的 __init__
        super().__init__(name)
        self.num_classes = num_classes

    # 新增方法(父类没有的)
    def predict_class(self, x):
        return x % self.num_classes

class RegressorModel(BaseModel):
    """回归模型"""
    def __init__(self, name):
        super().__init__(name)

    def predict_value(self, x):
        return x * 1.5

clf = ClassifierModel("ResNet", 10)
clf.save("resnet.pt")         # 继承自 BaseModel
print(clf.predict_class(7))   # ClassifierModel 自己的方法
🤖

PyTorch 中的继承

在 PyTorch 中,你定义神经网络时必须继承 nn.Module,并实现 __init__forward 方法。这正是面向对象继承的真实应用。第 10 章会详细展示这个模式。

特殊方法(魔术方法)

Python 类有许多以双下划线开头和结尾的特殊方法,它们定义了对象的内置行为,比如打印、比较、计算长度等。理解这些方法,能让你自己写的类用起来和内置类型一样自然。

PYTHON · 常用特殊方法
class Dataset:
    """模仿 PyTorch Dataset 的简化版"""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        # 让 len(dataset) 正常工作
        return len(self.samples)

    def __getitem__(self, idx):
        # 让 dataset[i] 下标访问正常工作
        return self.samples[idx]

    def __repr__(self):
        # 控制 print(dataset) 的输出
        return f"Dataset(size={len(self.samples)})"

    def __contains__(self, item):
        # 支持 item in dataset 语法
        return item in self.samples

data = ["cat", "dog", "bird"]
ds = Dataset(data)

print(len(ds))      # 3 → 调用 __len__
print(ds[0])        # cat → 调用 __getitem__
print(ds)           # Dataset(size=3) → 调用 __repr__
print("dog" in ds) # True → 调用 __contains__

# 甚至可以用 for 遍历(因为有 __len__ 和 __getitem__)
for sample in ds:
    print(sample)
特殊方法 触发时机 AI 框架中的例子
__init__ 创建对象时 model = MyNet()
__len__ len(obj) len(dataset)
__getitem__ obj[i] dataset[0]
__repr__ print(obj) 打印模型结构
__call__ obj(x) output = model(input)

类属性、类方法与静态方法

PYTHON · 类属性与静态方法
class ModelRegistry:
    """模型注册中心(类属性被所有实例共享)"""

    # 类属性:属于类本身,所有实例共享
    _registry = {}
    _count = 0

    def __init__(self, name):
        self.name = name
        ModelRegistry._count += 1
        ModelRegistry._registry[name] = self

    @classmethod
    def get_count(cls):
        """类方法:操作类本身,第一个参数是 cls"""
        return cls._count

    @classmethod
    def get_model(cls, name):
        return cls._registry.get(name)

    @staticmethod
    def is_valid_name(name):
        """静态方法:不需要 self 或 cls,只是工具函数"""
        return isinstance(name, str) and len(name) > 0

m1 = ModelRegistry("ResNet50")
m2 = ModelRegistry("BERT")

print(ModelRegistry.get_count())      # 2
print(ModelRegistry.get_model("BERT").name)  # BERT
print(ModelRegistry.is_valid_name("test"))   # True

综合示例:PyTorch 风格的神经网络类

看一个完整的例子,感受 OOP 在 AI 框架中的真实形态:

PYTHON · 模拟 PyTorch 模型定义
# 模拟 PyTorch 的 nn.Module 接口
class Module:
    """简化版 nn.Module 基类"""
    def __call__(self, x):
        # __call__ 让对象像函数一样被调用
        return self.forward(x)

    def forward(self, x):
        raise NotImplementedError("子类必须实现 forward 方法")

    def parameters(self):
        return []

# 继承 Module,定义自己的网络
class SimpleClassifier(Module):
    """二分类神经网络"""

    def __init__(self, input_dim, hidden_dim=64):
        self.input_dim  = input_dim
        self.hidden_dim = hidden_dim
        # 实际 PyTorch 这里是 nn.Linear 等层
        print(f"模型: {input_dim} → {hidden_dim} → 1")

    def forward(self, x):
        """前向传播:定义数据如何流经网络"""
        # 简化:实际是矩阵乘法 + 激活函数
        hidden  = x * 0.5    # 模拟线性层
        activated = max(0, hidden)  # 模拟 ReLU
        output  = activated * 0.3
        return output

    def __repr__(self):
        return f"SimpleClassifier(input={self.input_dim}, hidden={self.hidden_dim})"

# 使用方式和真实 PyTorch 几乎一样!
model = SimpleClassifier(input_dim=784, hidden_dim=128)
print(model)               # 调用 __repr__

# 像函数一样调用(调用 __call__,内部转发给 forward)
output = model(1.5)
print(f"输出: {output}")

OOP 是 AI 框架的基石

真实的 PyTorch 代码:class MyNet(nn.Module): def __init__(self): super().__init__() ... def forward(self, x): ...。看到了吗?这正是继承 + super() + __init__ + 方法覆盖的综合运用。学了本章,那段代码你已经完全能理解了。