scalarstop.model_template
A class that builds models with hyperparameters.
A ModelTemplate is a description of how to create a
compiled machine learning model and the hyperparameters that the
model depends on.
Subclass ModelTemplate to describe the architecture
of your model and the hyperparameters that are used to construct
the model (and the model’s optimizer).
Then, pass an instance of your ModelTemplate subclass
to a subclass of Model to train an
instance of a machine learning model created from your
ModelTemplate.
>>> import tensorflow as tf
>>> import scalarstop as sp
>>>
>>> class small_dense_10_way_classifier_v1(sp.ModelTemplate):
...     @sp.dataclass
...     class Hyperparams(sp.HyperparamsType):
...         hidden_units: int
...         optimizer: str = "adam"
...
...     def new_model(self):
...         model = tf.keras.Sequential(
...             layers=[
...                 tf.keras.layers.Flatten(input_shape=(28, 28)),
...                 tf.keras.layers.Dense(
...                     units=self.hyperparams.hidden_units,
...                     activation="relu",
...                 ),
...                 tf.keras.layers.Dense(units=10)
...             ],
...             name=self.name,
...         )
...         model.compile(
...             optimizer=self.hyperparams.optimizer,
...             loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
...             metrics=["accuracy"],
...         )
...         return model
>>> model_template = small_dense_10_way_classifier_v1(hyperparams=dict(hidden_units=20))
>>> model_template.name
'small_dense_10_way_classifier_v1-zc9r3do1baeeffafanjnjmou'
Module Contents
Classes
ModelTemplate: Describes machine learning model architectures and hyperparameters. Used to generate new machine learning model objects that are passed into Model objects.
- class ModelTemplate(*, hyperparams: Optional[Union[Mapping[str, Any], scalarstop.hyperparams.HyperparamsType]] = None, **kwargs)
Bases: scalarstop._single_namespace.SingleNamespace
Describes machine learning model architectures and hyperparameters. Used to generate new machine learning model objects that are passed into Model objects.
- Parameters
hyperparams – The hyperparameters to initialize this class with.
- Hyperparams: Type[scalarstop.hyperparams.HyperparamsType]
- new_model(self) → Any
Create a new compiled model with the current hyperparameters.
When you override this method, make sure to create a new model object every single time this function is called.
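The requirement to build a fresh model on every call can be illustrated with a plain-Python sketch. The `FakeModel` and `TemplateSketch` classes below are hypothetical stand-ins, not part of scalarstop or Keras; they only show that each call should hand back a distinct object, so that training one model does not silently change another.

```python
class FakeModel:
    """Hypothetical stand-in for a compiled model (not a scalarstop or Keras class)."""

    def __init__(self, hidden_units):
        self.hidden_units = hidden_units
        self.weights = [0.0] * hidden_units


class TemplateSketch:
    """Hypothetical stand-in for a ModelTemplate subclass."""

    def __init__(self, hidden_units):
        self.hidden_units = hidden_units

    def new_model(self):
        # Build and return a brand-new object on every call.
        # Do not cache a model on self and hand it out again.
        return FakeModel(self.hidden_units)


template = TemplateSketch(hidden_units=3)
first = template.new_model()
second = template.new_model()

# Mutating one model leaves the other untouched.
first.weights[0] = 1.0
```

If `new_model()` cached and reused a single object instead, `first` and `second` would share weights, and the mutation above would leak from one to the other.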
- classmethod calculate_name(cls, *, hyperparams: Optional[Union[Mapping[str, Any], scalarstop.hyperparams.HyperparamsType]] = None) → str
Calculate the hashed name of this object, given the hyperparameters.
This classmethod can be used to calculate what an object's name would be without actually having to call __init__().
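The idea of deriving a name from hyperparameters alone, without constructing an instance, can be sketched in plain Python. The `sketch_calculate_name` helper below is hypothetical, and its SHA-256 scheme is an assumption made for illustration; it is not scalarstop's actual hashing algorithm and will not reproduce the names scalarstop generates.

```python
import hashlib
import json


def sketch_calculate_name(class_name, hyperparams):
    """Hypothetical helper: derive a deterministic name from hyperparameters.

    This is NOT scalarstop's hashing scheme. It only illustrates that a
    name can be computed from the class name and hyperparameters alone.
    """
    # Serialize with sorted keys so that key order does not change the hash.
    payload = json.dumps(hyperparams, sort_keys=True).encode("utf-8")
    digest = hashlib.sha256(payload).hexdigest()[:12]
    return f"{class_name}-{digest}"


name_a = sketch_calculate_name(
    "small_dense_10_way_classifier_v1",
    {"hidden_units": 20, "optimizer": "adam"},
)
name_b = sketch_calculate_name(
    "small_dense_10_way_classifier_v1",
    {"optimizer": "adam", "hidden_units": 20},
)
name_c = sketch_calculate_name(
    "small_dense_10_way_classifier_v1",
    {"hidden_units": 30, "optimizer": "adam"},
)
```

In this sketch, `name_a` and `name_b` are identical because they describe the same hyperparameters, while `name_c` differs because `hidden_units` changed.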
- property hyperparams(self) → scalarstop.hyperparams.HyperparamsType
Returns a HyperparamsType instance containing hyperparameters.
- property hyperparams_flat(self) → Dict[str, Any]
Returns a Python dictionary of “flattened” hyperparameters.
AppendDataBlob objects modify a “parent” DataBlob, nesting the parent’s Hyperparams within the AppendDataBlob’s own Hyperparams.
This makes it hard to look up a given hyperparams key. A value at parent_datablob.hyperparams.a is stored at child_datablob.hyperparams.parent.hyperparams.a.
This hyperparams_flat property provides all nested hyperparams keys as a flat Python dictionary. If a child AppendDataBlob has a hyperparameter key that conflicts with the parent, the child’s value will overwrite the parent’s value.
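The flattening behavior described above can be sketched with nested dictionaries. The `flatten_hyperparams` helper is hypothetical and is not scalarstop's implementation; it assumes the parent's hyperparameters sit under a `parent` key, which in turn holds a `hyperparams` key, mirroring the child_datablob.hyperparams.parent.hyperparams.a path.

```python
def flatten_hyperparams(hyperparams):
    """Hypothetical helper: flatten nested hyperparameters into one dict.

    Assumes the parent's hyperparameters are stored at
    hyperparams["parent"]["hyperparams"]. Not scalarstop's own code.
    """
    flat = {}
    parent = hyperparams.get("parent")
    if parent is not None:
        # Collect the parent's keys first so the child can overwrite them.
        flat.update(flatten_hyperparams(parent["hyperparams"]))
    for key, value in hyperparams.items():
        if key != "parent":
            flat[key] = value
    return flat


child = {
    "parent": {"hyperparams": {"a": 1, "b": 2}},
    "b": 20,  # conflicts with the parent's "b"
    "c": 3,
}
flat = flatten_hyperparams(child)
```

Here the parent's `a` becomes reachable as `flat["a"]`, and the conflicting key `b` takes the child's value.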