scalarstop.model_template
A class that builds models with hyperparameters.

A ModelTemplate is a description of how to create a compiled machine learning model and the hyperparameters that the model depends on. Subclass ModelTemplate to describe the architecture of your model and the hyperparameters that are used to construct the model (and the model’s optimizer). Then, pass an instance of your ModelTemplate subclass to a subclass of Model to train an instance of a machine learning model created from your ModelTemplate.
>>> import tensorflow as tf
>>> import scalarstop as sp
>>>
>>> class small_dense_10_way_classifier_v1(sp.ModelTemplate):
...     @sp.dataclass
...     class Hyperparams(sp.HyperparamsType):
...         hidden_units: int
...         optimizer: str = "adam"
...
...     def new_model(self):
...         model = tf.keras.Sequential(
...             layers=[
...                 tf.keras.layers.Flatten(input_shape=(28, 28)),
...                 tf.keras.layers.Dense(
...                     units=self.hyperparams.hidden_units,
...                     activation="relu",
...                 ),
...                 tf.keras.layers.Dense(units=10),
...             ],
...             name=self.name,
...         )
...         model.compile(
...             optimizer=self.hyperparams.optimizer,
...             loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
...             metrics=["accuracy"],
...         )
...         return model
>>> model_template = small_dense_10_way_classifier_v1(hyperparams=dict(hidden_units=20))
>>> model_template.name
'small_dense_10_way_classifier_v1-zc9r3do1baeeffafanjnjmou'
Module Contents
Classes
ModelTemplate: Describes machine learning model architectures and hyperparameters. Used to generate new machine learning model objects that are passed into Model objects.
- class ModelTemplate(*, hyperparams: Optional[Union[Mapping[str, Any], scalarstop.hyperparams.HyperparamsType]] = None, **kwargs)
Bases: scalarstop._single_namespace.SingleNamespace
Describes machine learning model architectures and hyperparameters. Used to generate new machine learning model objects that are passed into Model objects.
Parameters:
hyperparams – The hyperparameters to initialize this class with, given either as a mapping of hyperparameter names to values or as a HyperparamsType instance.
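The constructor accepts hyperparams either as a mapping (such as the dict passed in the example above) or as an instance of the Hyperparams class. Below is a minimal plain-Python sketch of how such input could be normalized into one dataclass instance. TemplateSketch is a hypothetical name, standard-library dataclasses stand in for sp.dataclass and sp.HyperparamsType, and scalarstop’s actual conversion logic may differ.

```python
import dataclasses
from typing import Any, Mapping, Optional, Union


class TemplateSketch:
    """Hypothetical stand-in for a ModelTemplate subclass; not scalarstop code."""

    # Plain dataclasses stand in for sp.dataclass / sp.HyperparamsType here.
    @dataclasses.dataclass
    class Hyperparams:
        hidden_units: int
        optimizer: str = "adam"

    def __init__(
        self,
        *,
        hyperparams: Optional[Union[Mapping[str, Any], "TemplateSketch.Hyperparams"]] = None,
    ) -> None:
        if isinstance(hyperparams, self.Hyperparams):
            # Already a Hyperparams instance: use it as-is.
            self._hyperparams = hyperparams
        else:
            # A mapping (or None): unpack it into the dataclass. Omitted fields
            # take their defaults; unknown or missing required keys raise TypeError.
            self._hyperparams = self.Hyperparams(**dict(hyperparams or {}))

    @property
    def hyperparams(self) -> "TemplateSketch.Hyperparams":
        # Always a Hyperparams instance, regardless of how it was passed in.
        return self._hyperparams
```

With this design, TemplateSketch(hyperparams=dict(hidden_units=20)).hyperparams.optimizer falls back to the dataclass default "adam", so both input forms end up as the same typed object.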
- Hyperparams: Type[scalarstop.hyperparams.HyperparamsType]
- new_model(self) -> Any
Create a new compiled model with the current hyperparameters.
When you override this method, make sure it creates and returns a new model object every time it is called, rather than returning a previously built instance.
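The rule about creating a new model object on every call matters because models carry mutable training state. The sketch below illustrates the difference in plain Python; FakeModel, GoodTemplate, CachingTemplate, and train are all hypothetical stand-ins for a compiled Keras model, two template subclasses, and a training loop.

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class FakeModel:
    """Hypothetical stand-in for a compiled model with mutable training state."""

    hidden_units: int
    weights: List[float] = dataclasses.field(default_factory=list)


class GoodTemplate:
    """Follows the rule: every call to new_model() builds a fresh object."""

    def __init__(self, hidden_units: int) -> None:
        self.hidden_units = hidden_units

    def new_model(self) -> FakeModel:
        return FakeModel(hidden_units=self.hidden_units)


class CachingTemplate(GoodTemplate):
    """Breaks the rule: new_model() keeps returning one cached object."""

    def __init__(self, hidden_units: int) -> None:
        super().__init__(hidden_units)
        self._cached = FakeModel(hidden_units=hidden_units)

    def new_model(self) -> FakeModel:
        return self._cached


def train(model: FakeModel) -> None:
    """Hypothetical 'training' step that mutates the model in place."""
    model.weights.append(0.5)
```

After training a model from CachingTemplate, the next call to new_model() hands back that same, already trained object, so a supposedly fresh run silently starts from old weights. GoodTemplate avoids this because each call constructs a new object.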
- classmethod calculate_name(cls, *, hyperparams: Optional[Union[Mapping[str, Any], scalarstop.hyperparams.HyperparamsType]] = None) -> str
Calculate the hashed name of this object, given the hyperparameters.
This classmethod can be used to calculate what an object’s name would be without actually having to call __init__().
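Because the name depends only on the class and its hyperparameters, it can be computed without creating an instance. The following is an illustrative sketch of one way to derive such a deterministic name. calculate_name_sketch and its base-36 encoding are hypothetical: scalarstop’s real hashing scheme is not described here, so this will not reproduce names such as 'small_dense_10_way_classifier_v1-zc9r3do1baeeffafanjnjmou'.

```python
import hashlib
import json
from typing import Any, Mapping

# Hypothetical lowercase alphabet; scalarstop's real encoding may differ.
_ALPHABET = "abcdefghijklmnopqrstuvwxyz0123456789"


def calculate_name_sketch(
    class_name: str, hyperparams: Mapping[str, Any], length: int = 24
) -> str:
    """Derive a deterministic name from a class name and its hyperparameters."""
    # Serialize with sorted keys so that key order does not change the hash.
    payload = json.dumps(dict(hyperparams), sort_keys=True, default=str).encode("utf-8")
    digest = int.from_bytes(hashlib.sha256(payload).digest(), "big")
    # Re-encode the digest in base 36 and keep a fixed number of characters.
    chars = []
    for _ in range(length):
        digest, remainder = divmod(digest, len(_ALPHABET))
        chars.append(_ALPHABET[remainder])
    return f"{class_name}-{''.join(chars)}"
```

This sketch hashes the mapping exactly as given. If dict(hidden_units=20) and dict(hidden_units=20, optimizer="adam") should map to the same name, fill in the Hyperparams defaults before hashing.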
- property hyperparams(self) -> scalarstop.hyperparams.HyperparamsType
Returns a HyperparamsType instance containing hyperparameters.
- property hyperparams_flat(self) -> Dict[str, Any]
Returns a Python dictionary of “flattened” hyperparameters.
AppendDataBlob objects modify a “parent” DataBlob, nesting the parent’s Hyperparams within the AppendDataBlob’s own Hyperparams. This makes it hard to look up a given hyperparams key: a value at parent_datablob.hyperparams.a is stored at child_datablob.hyperparams.parent.hyperparams.a.
This hyperparams_flat property provides all nested hyperparams keys as a flat Python dictionary. If a child AppendDataBlob has a hyperparameter key that conflicts with the parent, the child’s value overwrites the parent’s value.
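The flattening rule can be sketched with plain dictionaries. This is a minimal illustration, not scalarstop’s implementation. flatten_hyperparams is a hypothetical helper, and it assumes that a parent’s hyperparameters sit at hyperparams["parent"]["hyperparams"], mirroring the child.hyperparams.parent.hyperparams.a access path. It also assumes the "parent" key itself is dropped from the result.

```python
from typing import Any, Dict, Mapping


def flatten_hyperparams(hyperparams: Mapping[str, Any]) -> Dict[str, Any]:
    """Flatten nested hyperparams; child keys override parent keys.

    Hypothetical layout: a parent's hyperparams are nested under
    hyperparams["parent"]["hyperparams"].
    """
    flat: Dict[str, Any] = {}
    parent = hyperparams.get("parent")
    if isinstance(parent, Mapping) and "hyperparams" in parent:
        # Recurse first, so ancestors' values are written before the child's.
        flat.update(flatten_hyperparams(parent["hyperparams"]))
    for key, value in hyperparams.items():
        if key != "parent":
            # Child values are written last, so they overwrite parent values.
            flat[key] = value
    return flat
```

For example, a child {"b": 20, "c": 3} whose parent has {"a": 2, "b": 2} flattens to {"a": 2, "b": 20, "c": 3}. The recursion handles chains of several AppendDataBlob levels, with the nearest descendant winning each conflict.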