Python + Pytorch 指北
这篇博客主要记录自己在学习 python 过程中遇到的一些细节或坑。
类属性 vs 实例属性
起因是在看李沐的《动手学深度学习》,发现其中有这样的对于神经网络顺序模块的实现。
class MySequential(nn.Module):
def __init__(self, *args):
super().__init__()
for idx, module in enumerate(args):
self._modules[str(idx)] = module
def forward(self, x):
for block in self._modules.values():
x = block(x)
return x
说明中有提到 self._modules
是一个 OrderedDict, 当时感觉疑惑,因为正常的python语法是不支持直接这样定义有序字典的,查阅相关资料后发现原来 _modules
是在实例化一个对象时自动定义好的属性,是在执行父类的初始化的时候定义的。
我试图通过 print(nn.Module._modules)
来查看其内容,但程序报错,显示 nn.Module
不存在这样的属性,但是如果我们事先实例化一个 nn.Module
的对象 Model
,然后再去打印 Model._modules
就是可以的。
源码是这么写的:
class Module:
......
_modules: Dict[str, Optional['Module']]
call_super_init: bool = False
_compiled_call_impl: Optional[Callable] = None
def __init__(self, *args, **kwargs) -> None:
......
super().__setattr__("_modules", {})
......
这里定义在 __init__()
函数之前的是类的属性的声明,可以区分为两类:一类是注释类型 + 赋值;另一类只有注释类型而没有赋值,在 __init__()
函数中使用 self.xxx = xxx
或者 super().__setattr__()
进行赋值。
使用类属性声明的作用:
只是提前声明了「以后会有这些属性」 不赋初值(初值一般在 init() 里动态生成) 用于代码提示、类型检查(像用 MyPy 静态分析工具时)
换句话说,它们是为了:
✅ 让 IDE(如 VS Code、PyCharm)有智能提示
✅ 让开发者知道这些属性存在及其预期类型
✅ 对实际运行没影响(Python 解释器并不强制要求)
区别在于,声明的同时直接进行赋值的属性是类属性,在初始化时进行赋值的是实例属性。
类属性是能够直接通过类的属性查找到的(比如可以直接 print(nn.Module.call_super_init
),且是不同的实例共享的,而实例属性是在定义实例的时候赋予该实例的属性,仅属于该实例,不同实例之间不共享。
关于是否共享,下面有个例子可以说明:
class MyModule:
_modules = {} # 类属性
A = MyModule()
B = MyModule()
A._modules['1'] = 1
print(B._modules) # 会输出 {'1': 1}
A._modules = {'2'}
print(A._modules) # 会输出 {'2'}
print(B._modules) # 会输出 {'1': 1}
注意,当 Python 编译器看到 A._modules
时会优先查找实例中是否有同名属性,如果没有才会去类中寻找,代码 A._modules['1'] = 1
是对内容进行了修改,并不会创建实例属性。而 A._modules = {'2'}
则是修改了引用,此时会创建实例的属性。下面还有个例子可以帮助理解:
class cls:
clsattr = 1
def __init__(self):
self.objattr = 2
A = cls()
B = cls()
print(A.clsattr) # 1
A.clsattr = 3
print(A.clsattr) # 3
print(B.clsattr) # 1
C = cls()
print(C.clsattr) # 1
A.__class__.clsattr = 3
print(B.clsattr) # 3
也就是说如果不能通过引用的方式利用实例去修改类属性时,必须先拿到该实例的类才行 A.__class__.xxx = xxx
,如果直接赋值(如 A.clsattr = 3
)实际上只是在该实例上新建了一个属性(效果相当于 A.__dict__[clsattr] = 3
)。
pytorch的这个设计的好处在于:
- 模块的状态(如子模块、参数、缓冲区、hooks 等)都保存在实例上,而不是类上。
- 因为每个模型实例都有自己的结构,不能共享
_modules
字典。 - 类只是定义了行为(方法、接口),具体的数据和配置是在每个实例中分开存储的。
如果类直接挂 _modules
,那所有实例就会共享一份,很容易出错。这是pytorch的核心设计哦!
named_parameters vs state_dict
当我们需要查看模型的内部参数的时候,经常会用到这两个属性,其父类是 nn.Module
,然而这两个属性是有区别的:
- named_parameters
返回模型中所有可学习参数 nn.Parameter
,以 (name, parameter)
的形式迭代(通常是权重和偏置)。
常用场景:当需要把模型的参数传入到优化器中时常使用,如 optimizer = torch.optim.Adam(model.parameters())
这里传入的是
model.parameters()
是因为其不含有参数的名称name
- state_dict
返回模型中所有 参数 + 缓冲区(包括 nn.Parameter
和注册的 buffer,比如 BatchNorm 中的 running_mean、running_var),以 {name: tensor}
的形式组成的 OrderedDict
。
它是模型保存和加载的核心接口:
- 保存模型:
torch.save(model.state_dict(), 'model.pth')
- 加载模型:
model.load_state_dict(torch.load('model.pth'))
模型保存方法
当我们想要把训练好模型想要将其保存到本地,或是在训练过程中临时保存防止故障导致数据丢失时,通常有两种方式保存模型的参数。
torch.save(model, "xxx.pth")
这种方法是保存了模型的实例,其使用 Python 的 pickle 序列化机制 来保存整个对象,包括模型的类名、模块路径、方法、属性等,而不仅仅是模型的权重。
这意味着当我们想要加载模型时,必须提供完全一致的环境,包括:类名一致、模块路径一致、属性和结构一致,否则 pickle 在反序列化时找不到对应类,或者加载出错。
典型的问题如下:
- 改了类名:原来叫
class MyNet
,后来改成class MyNetwork
→ 加载失败。 - 改了文件位置:原来在
model.py
,后来移到models/model.py
→ 加载失败。 - 改了属性:原来
self.hidden = nn.Linear(...)
,后来换成self.hidden_layer = nn.Linear(...)
→ 加载成功但属性不对应,可能会出 bug。
torch.save(model.state_dict, "xxx.pth")
这种方法只保存模型的参数字典(OrderedDict),并不包含模型的架构。
优点:
- 文件更小、更简单
- 加载更灵活(只要你有同样架构的实例就能加载)
- 在不同代码版本/环境中更稳健(尤其跨机器、跨版本)
加载方式如下:
model = MyModel(*args)
model.load_state_dict(torch.load("xxx.pth"))
两种保存方法的异同点总结:
对比项 | 保存整个模型 (torch.save(model) ) |
保存 state_dict (torch.save(model.state_dict()) ) |
---|---|---|
保存内容 | 模型架构 + 参数 | 仅参数(权重字典) |
文件大小 | 更大 | 更小 |
加载依赖 | 需要完全相同的类定义和代码环境 | 只需提供同样架构的实例 |
跨版本兼容性 | 差(代码变化可能导致无法加载) | 好(只要匹配架构就行) |
推荐使用场景 | 快速实验、临时保存 | 生产环境、发布模型、长期保存 |
.contiguous()
这玩意讲起来还挺复杂,所以我直接引一篇知乎博客,感觉写得很详细了。
简单来讲就是让一个张量的元信息(行、列读取规则)和数据在内存中的实际存储位置相一致。
Leave a comment