Source code for ada.models.network_factory
import torch
import ada.models.modules as amm
def get_module_fn(module_name):
    """Return the feature-extractor class registered under ``module_name``.

    Args:
        module_name: Registry key. One of ``"feedforward"``, ``"digits"``,
            ``"AlexNet"``, ``"ResNet18"``, ``"ResNet34"``, ``"ResNet50"``,
            ``"ResNet101"`` or ``"ResNet152"``.

    Returns:
        The matching attribute of ``ada.models.modules`` (not an instance;
        the caller is expected to call it).

    Raises:
        NotImplementedError: If ``module_name`` is not a registered key. The
            message includes the offending name.
    """
    # Map registry keys to attribute *names* rather than to the attributes
    # themselves, so that only the requested attribute is looked up on
    # ``amm`` (same laziness as an if/elif chain).
    attribute_by_name = {
        "feedforward": "FeatureExtractFF",
        "digits": "FeatureExtractorDigits",
        # NOTE(review): this attribute name departs from the ``<Arch>Feature``
        # pattern of the entries below -- confirm it exists in
        # ada.models.modules before relying on the "AlexNet" key.
        "AlexNet": "ResNet50AlexNetFeatureFeature",
        "ResNet18": "ResNet18Feature",
        "ResNet34": "ResNet34Feature",
        "ResNet50": "ResNet50Feature",
        "ResNet101": "ResNet101Feature",
        "ResNet152": "ResNet152Feature",
    }

    attribute = attribute_by_name.get(module_name)
    if attribute is None:
        # The f-prefix matters: without it the placeholder is printed verbatim
        # and the caller never learns which name was rejected.
        raise NotImplementedError(
            f"Unknown module {module_name}. Please define the function in modules.py and register its use here."
        )
    return getattr(amm, attribute)
class NetworkFactory:
    """Create task-classifier and critic modules from a configuration mapping.

    The configuration is indexed with the keys ``"task"`` and ``"critic"``.
    Each of those sections must provide a ``"name"`` entry selecting the
    architecture (``"feedforward"`` or ``"digits"``); the feed-forward
    variants additionally read an optional ``"hidden_size"`` entry.
    """

    def __init__(self, network_config):
        """Keep a reference to ``network_config``; sections are read on demand."""
        self._params = network_config

    def get_task_classifier(self, input_dim, n_classes):
        """Build the task classifier described by the ``"task"`` section.

        Args:
            input_dim: Size of the classifier input, forwarded to the module.
            n_classes: Number of output classes, forwarded to the module.

        Returns:
            An ``amm.FFSoftmaxClassifier`` for ``"feedforward"`` or an
            ``amm.DataClassifierDigits`` for ``"digits"``.

        Raises:
            KeyError: If the ``"task"`` section or its ``"name"`` is missing.
            NotImplementedError: If the configured name is not supported.
        """
        cfg = self._params["task"]
        name = cfg["name"]
        if name == "feedforward":
            return amm.FFSoftmaxClassifier(
                input_dim,
                n_classes=n_classes,
                name="h",
                # No hidden layers unless the configuration asks for them.
                hidden=cfg.get("hidden_size", ()),
            )
        if name == "digits":
            return amm.DataClassifierDigits(input_size=input_dim, n_class=n_classes)
        raise NotImplementedError(
            f"Unknown task classifier {name!r}. "
            "Supported names are 'feedforward' and 'digits'."
        )

    def get_critic_network(self, input_dim):
        """Build the critic network described by the ``"critic"`` section.

        Args:
            input_dim: Size of the critic input, forwarded to the module.

        Returns:
            A two-class ``amm.FFSoftmaxClassifier`` for ``"feedforward"`` or
            an ``amm.DomainClassifierDigits`` for ``"digits"``.

        Raises:
            KeyError: If the ``"critic"`` section or its ``"name"`` is missing.
            NotImplementedError: If the configured name is not supported.
        """
        cfg = self._params["critic"]
        name = cfg["name"]
        if name == "feedforward":
            return amm.FFSoftmaxClassifier(
                input_dim,
                n_classes=2,
                name="g",
                hidden=cfg.get("hidden_size", ()),
            )
        if name == "digits":
            return amm.DomainClassifierDigits(input_size=input_dim)
        raise NotImplementedError(
            f"Unknown critic network {name!r}. "
            "Supported names are 'feedforward' and 'digits'."
        )