# Dataset API
# Loading CSV
# Column names of each CSV record, in file order. read_csv_line zips these
# names with the parsed values, uses 'fare_amount' as the label and drops 'key'.
CSV_COLUMNS = [
    'fare_amount',
    'pickuplon',
    'pickuplat',
    'dropofflon',
    'dropofflat',
    'passengers',
    'key',
]

# Per-column record_defaults for tf.decode_csv, positionally aligned with
# CSV_COLUMNS. Each default is used when a field is empty and also fixes that
# column's parsed dtype (float for the numeric columns, string for 'key').
# NOTE(review): pickuplat defaults to 40.0 while dropofflat defaults to 40.7 —
# confirm whether this asymmetry is intended.
DEFAULTS = [
    [0.0],      # fare_amount
    [-74.0],    # pickuplon
    [40.0],     # pickuplat
    [-74.0],    # dropofflon
    [40.7],     # dropofflat
    [1.0],      # passengers
    ['nokey'],  # key
]
# Input function: read_dataset (below) turns CSV files into a tf.data pipeline of (features, label) batches.
def read_csv_line(row):
    """Parse one CSV text line into a ``(features, label)`` pair.

    Args:
        row: A string tensor holding one line of CSV text.

    Returns:
        A tuple ``(features, label)``. ``features`` is a dict mapping every
        column in CSV_COLUMNS except 'fare_amount' and 'key' to its parsed
        tensor. ``label`` is the parsed 'fare_amount' tensor.
    """
    parsed = tf.decode_csv(row, record_defaults=DEFAULTS)
    named = dict(zip(CSV_COLUMNS, parsed))
    label = named.pop("fare_amount")
    # 'key' only identifies the record; it is not a model input.
    named.pop("key")
    return named, label
def read_dataset(filename, mode, batch_size=512):
    """Build a batched tf.data pipeline of ``(features, label)`` pairs from CSV files.

    Args:
        filename: File path or glob pattern handed to
            ``tf.data.Dataset.list_files`` (the file list is not shuffled).
        mode: A ``tf.estimator.ModeKeys`` value. In TRAIN mode, examples are
            shuffled and the data is repeated indefinitely. In any other mode,
            the data is read exactly once.
        batch_size: Number of examples per batch. Also sizes the training
            shuffle buffer (10 * batch_size).

    Returns:
        A ``tf.data.Dataset`` of batched ``(features, label)`` pairs as
        produced by ``read_csv_line``.
    """
    # NOTE(review): every line is parsed as a data record, so the CSV files
    # are assumed to have no header row — confirm.
    text_lines = tf.data.Dataset.list_files(filename, shuffle=False).flat_map(
        tf.data.TextLineDataset
    )
    examples = text_lines.map(read_csv_line)

    is_training = mode == tf.estimator.ModeKeys.TRAIN
    if is_training:
        # The fixed seed keeps the shuffle order reproducible across runs.
        examples = examples.shuffle(buffer_size=10 * batch_size, seed=666)

    # repeat(None) loops forever (the training loop decides when to stop);
    # repeat(1) makes evaluation a single pass over the data.
    epoch_count = None if is_training else 1
    return examples.repeat(epoch_count).batch(batch_size)
def get_train_input_fn():
    """Return the training dataset built from ``./taxi-train.csv``.

    The path is relative to the current working directory. The dataset is
    built in TRAIN mode: shuffled and repeated indefinitely.
    NOTE(review): despite the name, this returns a dataset rather than a
    function; presumably it is passed uncalled as an Estimator ``input_fn`` —
    confirm against callers.
    """
    train_mode = tf.estimator.ModeKeys.TRAIN
    return read_dataset('./taxi-train.csv', mode=train_mode)
def get_valid_input_fn():
    """Return the validation dataset built from ``./taxi-valid.csv``.

    The path is relative to the current working directory. The dataset is
    built in EVAL mode: no shuffling and a single pass over the data.
    NOTE(review): despite the name, this returns a dataset rather than a
    function; presumably it is passed uncalled as an Estimator ``input_fn`` —
    confirm against callers.
    """
    eval_mode = tf.estimator.ModeKeys.EVAL
    return read_dataset('./taxi-valid.csv', mode=eval_mode)
No matches...