Composing Modules
Overview
A module can contain other modules as sub-modules. A sub-module is tracked automatically as soon as it is assigned to an attribute — no explicit registration call is needed. Its buffers and parameters appear in the parent's state dict under dot-separated keys, and it is visited by tree-walking Module methods such as named_modules(), state_dict(), and parameters().
1. Sub-Modules
Any Module assigned to self.<attr> is automatically registered as a sub-module:
import msgflux.nn as nn
class Normalizer(nn.Module):
    """Strip and lower-case text, then prepend a stored prefix."""

    def __init__(self):
        super().__init__()
        # A buffer (rather than a plain attribute) so the value is included
        # in state_dict() and can be restored with load_state_dict().
        self.register_buffer("prefix", "[normalized]")

    def forward(self, text: str) -> str:
        return f"{self.prefix} {text.strip().lower()}"
class Classifier(nn.Module):
    """Append a fixed label, held in a buffer, to the input text."""

    def __init__(self):
        super().__init__()
        self.register_buffer("label", "spam")

    def forward(self, text: str) -> str:
        return f"{text} → label={self.label}"
class Pipeline(nn.Module):
    """Run text through a Normalizer and then a Classifier."""

    def __init__(self):
        super().__init__()
        # Assigning a Module to an attribute registers it as a sub-module.
        self.normalizer = Normalizer()  # auto-registered
        self.classifier = Classifier()  # auto-registered

    def forward(self, text: str) -> str:
        text = self.normalizer(text)
        return self.classifier(text)
pipeline = Pipeline()
# Calling the module runs forward(): normalize first, then classify.
raw_text = " HELLO WORLD "
result = pipeline(raw_text)
print(result)
# [normalized] hello world → label=spam
2. State Dict with Nested Keys
Sub-module buffers and parameters appear in the state dict with dot-separated paths:
import msgflux.nn as nn
class Normalizer(nn.Module):
    """Strip and lower-case text, then prepend a stored prefix."""

    def __init__(self):
        super().__init__()
        # Registered as a buffer, so it appears in the state dict.
        self.register_buffer("prefix", "[normalized]")

    def forward(self, text: str) -> str:
        return f"{self.prefix} {text.strip().lower()}"
class Classifier(nn.Module):
    """Append a fixed label, held in a buffer, to the input text."""

    def __init__(self):
        super().__init__()
        self.register_buffer("label", "spam")

    def forward(self, text: str) -> str:
        return f"{text} → label={self.label}"
class Pipeline(nn.Module):
    """Run text through a Normalizer and then a Classifier."""

    def __init__(self):
        super().__init__()
        # The attribute names become the key prefixes in state_dict().
        self.normalizer = Normalizer()
        self.classifier = Classifier()

    def forward(self, text: str) -> str:
        text = self.normalizer(text)
        return self.classifier(text)
pipeline = Pipeline()
# Each sub-module's buffers are keyed as "<attribute>.<buffer>".
state = pipeline.state_dict()
print(state)
# {
#   "normalizer.prefix": "[normalized]",
#   "classifier.label": "spam"
# }
3. Inspecting the Module Tree
Use named_modules() to traverse the entire hierarchy:
import msgflux.nn as nn
class Normalizer(nn.Module):
    """Strip and lower-case text, then prepend a stored prefix."""

    def __init__(self):
        super().__init__()
        self.register_buffer("prefix", "[normalized]")

    def forward(self, text: str) -> str:
        return f"{self.prefix} {text.strip().lower()}"
class Classifier(nn.Module):
    """Append a fixed label, held in a buffer, to the input text."""

    def __init__(self):
        super().__init__()
        self.register_buffer("label", "spam")

    def forward(self, text: str) -> str:
        return f"{text} → label={self.label}"
class Pipeline(nn.Module):
    """Run text through a Normalizer and then a Classifier."""

    def __init__(self):
        super().__init__()
        # Both sub-modules are direct children of the pipeline.
        self.normalizer = Normalizer()
        self.classifier = Classifier()

    def forward(self, text: str) -> str:
        text = self.normalizer(text)
        return self.classifier(text)
pipeline = Pipeline()
# named_modules() yields (name, module) pairs for the whole tree; the root
# itself comes first, under the empty name ''.
for name, module in pipeline.named_modules():
    # !r quotes the name; :20 pads it to 20 columns so the arrows line up.
    print(f"{name!r:20} → {type(module).__name__}")
# ''                   → Pipeline
# 'normalizer'         → Normalizer
# 'classifier'         → Classifier
Use named_children() to iterate over direct children only (one level deep):
# Only the immediate sub-modules are yielded; the root is not included.
for name, module in pipeline.named_children():
    print(f"{name}: {type(module).__name__}")
# normalizer: Normalizer
# classifier: Classifier
4. Deep Nesting
Sub-modules can be nested to any depth. State dict keys reflect the full path:
import msgflux.nn as nn
class Formatter(nn.Module):
    """Wrap text in an HTML-style element whose tag name is stored in a buffer."""

    def __init__(self, tag: str):
        super().__init__()
        self.register_buffer("tag", tag)

    def forward(self, text: str) -> str:
        return f"<{self.tag}>{text}</{self.tag}>"
class InnerPipeline(nn.Module):
    """Apply bold and then italic: "x" -> "<i><b>x</b></i>"."""

    def __init__(self):
        super().__init__()
        self.bold = Formatter("b")
        self.italic = Formatter("i")

    def forward(self, text: str) -> str:
        # bold runs first, so its tags end up innermost.
        return self.italic(self.bold(text))
class OuterPipeline(nn.Module):
    """Wrap the InnerPipeline's output in a <div>; the tree is two levels deep."""

    def __init__(self):
        super().__init__()
        # `inner` has its own sub-modules, so its keys gain a second level.
        self.inner = InnerPipeline()
        self.wrapper = Formatter("div")

    def forward(self, text: str) -> str:
        return self.wrapper(self.inner(text))
outer = OuterPipeline()

# Calling the top-level module runs the whole nested chain.
html = outer("hello")
print(html)
# <div><i><b>hello</b></i></div>

# Each key spells out the attribute path from the root down to the buffer.
state = outer.state_dict()
print(state)
# {
#   "inner.bold.tag": "b",
#   "inner.italic.tag": "i",
#   "wrapper.tag": "div"
# }
5. Loading State
Use load_state_dict() to restore a saved configuration across the whole tree:
import msgflux.nn as nn
class Formatter(nn.Module):
    """Wrap text in an HTML-style element whose tag name is stored in a buffer."""

    def __init__(self, tag: str):
        super().__init__()
        # Kept as a buffer so load_state_dict() can restore it later.
        self.register_buffer("tag", tag)

    def forward(self, text: str) -> str:
        return f"<{self.tag}>{text}</{self.tag}>"
class InnerPipeline(nn.Module):
    """Apply bold and then italic: "x" -> "<i><b>x</b></i>"."""

    def __init__(self):
        super().__init__()
        self.bold = Formatter("b")
        self.italic = Formatter("i")

    def forward(self, text: str) -> str:
        # bold runs first, so its tags end up innermost.
        return self.italic(self.bold(text))
class OuterPipeline(nn.Module):
    """Wrap the InnerPipeline's output in a <div>; the tree is two levels deep."""

    def __init__(self):
        super().__init__()
        self.inner = InnerPipeline()
        self.wrapper = Formatter("div")

    def forward(self, text: str) -> str:
        return self.wrapper(self.inner(text))
outer = OuterPipeline()
# Snapshot every buffer in the tree (keys such as "inner.bold.tag").
saved = outer.state_dict()
# Change a deeply nested buffer directly through the attribute path.
outer.inner.bold.tag = "strong"
print(outer("hello"))
# <div><i><strong>hello</strong></i></div>
# Restore the snapshot; the nested "b" tag comes back.
outer.load_state_dict(saved)
print(outer("hello"))
# <div><i><b>hello</b></i></div>
6. Container Modules
For common patterns, msgFlux provides dedicated containers — see their individual pages:
| Container | When to use |
|---|---|
| nn.Sequential | Fixed chain: the output of step N feeds step N+1 |
| nn.ModuleList | Ordered collection, iterated manually in forward |
| nn.ModuleDict | Named collection, selected by key at runtime |