How to use tf.contrib.estimator.forward_features


I'm trying to use forward_features to get instance keys for cloudml, but I always get errors that I'm not sure how to fix. The preprocessing section that uses tf.Transform is a modification of <a href="https://github.com/GoogleCloudPlatform/cloudml-samples/tree/master/reddit_tft" rel="nofollow">https://github.com/GoogleCloudPlatform/cloudml-samples/tree/master/reddit_tft</a> where the instance key is a string and everything else is a bunch of floats.

```python
import tensorflow as tf
from tensorflow.contrib.factorization import KMeansClustering
from tensorflow_transform.saved import input_fn_maker


def gzip_reader_fn():
    return tf.TFRecordReader(options=tf.python_io.TFRecordOptions(
        compression_type=tf.python_io.TFRecordCompressionType.GZIP))


def get_transformed_reader_input_fn(transformed_metadata,
                                    transformed_data_paths,
                                    batch_size,
                                    mode):
    """Wrap the get input features function to provide the runtime arguments."""
    return input_fn_maker.build_training_input_fn(
        metadata=transformed_metadata,
        file_pattern=(
            transformed_data_paths[0] if len(transformed_data_paths) == 1
            else transformed_data_paths),
        training_batch_size=batch_size,
        label_keys=[],
        #feature_keys=FEATURE_COLUMNS,
        #key_feature_name='example_id',
        reader=gzip_reader_fn,
        reader_num_threads=4,
        queue_capacity=batch_size * 2,
        randomize_input=(mode != tf.contrib.learn.ModeKeys.EVAL),
        num_epochs=(1 if mode == tf.contrib.learn.ModeKeys.EVAL else None))


estimator = KMeansClustering(
    num_clusters=8,
    initial_clusters=KMeansClustering.KMEANS_PLUS_PLUS_INIT,
    kmeans_plus_plus_num_retries=32,
    relative_tolerance=0.0001)
estimator = tf.contrib.estimator.forward_features(estimator, 'example_id')

train_input_fn = get_transformed_reader_input_fn(
    transformed_metadata, args.train_data_paths, args.batch_size,
    tf.contrib.learn.ModeKeys.TRAIN)
estimator.train(input_fn=train_input_fn)
```

If I pass in the key column alongside the training features, I get the error `Tensors in list passed to 'values' of 'ConcatV2' Op have types [float32, float32, string, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32] that don't all match`. However, if I don't pass in the instance keys during training, I get a ValueError saying that the key doesn't exist in the features. Also, if I change the key column name in the forward_features call from 'example_id' to some random name that isn't a column, I still get the former error instead of the latter. Can anyone help me make sense of this?


Please check the following:

(1) forward_features only works with tf.estimator. Make sure that you are not using a tf.contrib.learn Estimator. (Update: you are using a class that inherits from tf.estimator.Estimator.)

(2) Make sure your input function reads in the key column; that is, the key column has to be part of your input dataset.

(3) With tf.Transform, #2 means that the transform metadata has to reflect the schema of the key. The error message you are seeing suggests a type mismatch, for example the schema specifying the key as a float when it is actually a string, or something similar.

(4) Make sure the key column is NOT used by your model, so do not create a FeatureColumn from it. The model should simply pass the key that the input_fn reads through to the predictions.

(5) If you don't see the key in the output, see if this workaround helps you:

<a href="https://github.com/GoogleCloudPlatform/training-data-analyst/blob/master/courses/machine_learning/deepdive/07_structured/babyweight/trainer/model.py#L132" rel="nofollow">https://github.com/GoogleCloudPlatform/training-data-analyst/blob/master/courses/machine_learning/deepdive/07_structured/babyweight/trainer/model.py#L132</a>

Essentially, forward_features changes the in-memory graph but not the exported serving signature. My workaround fixes this.
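To make points (4) and (5) more concrete, here is a small framework-free sketch in plain Python. It does not use TensorFlow, and every name in it is hypothetical: `Spec` stands in for `tf.estimator.EstimatorSpec` (reduced to just `predictions` and `export_outputs`), `toy_model_fn` stands in for your model, `forward_key` mimics what forward_features is described as doing above, and `sync_export_outputs` mimics the idea behind the linked workaround, simplified to rebuild every export output from the predictions. It models only the behavior this answer describes: the key is copied into the predictions, while the export outputs that define the serving signature stay unchanged until they are rebuilt.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.estimator.EstimatorSpec; NOT the TensorFlow class.
Spec = namedtuple("Spec", ["predictions", "export_outputs"])


def toy_model_fn(features):
    """Hypothetical model: sums the numeric columns and ignores the key.

    Skipping 'example_id' here plays the role of point (4): the key
    column must not feed into the model itself.
    """
    score = sum(v for k, v in features.items() if k != "example_id")
    predictions = {"score": score}
    # The serving signature is built from the predictions at this point.
    return Spec(predictions, {"serving_default": dict(predictions)})


def forward_key(model_fn, key):
    """Copy features[key] into the predictions (toy analogue of forward_features)."""
    def wrapped(features):
        if key not in features:
            raise ValueError("key %r not found in features" % key)
        spec = model_fn(features)
        predictions = dict(spec.predictions)
        predictions[key] = features[key]
        # Only predictions change; export_outputs still lack the key.
        return spec._replace(predictions=predictions)
    return wrapped


def sync_export_outputs(model_fn):
    """Rebuild every export output from the (forwarded) predictions."""
    def wrapped(features):
        spec = model_fn(features)
        outputs = {name: dict(spec.predictions) for name in spec.export_outputs}
        return spec._replace(export_outputs=outputs)
    return wrapped
```

With `features = {"example_id": "row-17", "x": 1.5, "y": 2.0}`, `forward_key(toy_model_fn, "example_id")(features)` has `example_id` in its `predictions` but not in `export_outputs["serving_default"]`; wrapping that with `sync_export_outputs` puts `example_id` in the export output as well. In real TensorFlow the equivalent rebuild would operate on `EstimatorSpec` objects and `PredictOutput` export outputs, as in the linked code.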

