Dataset API

Loading CSV

# Column names of each CSV record, in file order. 'fare_amount' becomes the
# label and 'key' is dropped from the features when a line is parsed.
CSV_COLUMNS = [
    'fare_amount',
    'pickuplon',
    'pickuplat',
    'dropofflon',
    'dropofflat',
    'passengers',
    'key',
]

# record_defaults for tf.decode_csv, positionally aligned with CSV_COLUMNS.
# Each default fills an empty field, and its Python type determines the
# parsed column's type (float columns, plus a string 'key' column).
# NOTE(review): the dropofflat default (40.7) differs from the pickuplat
# default (40.0) -- confirm the asymmetry is intentional.
DEFAULTS = [
    [0.0],      # fare_amount
    [-74.0],    # pickuplon
    [40.0],     # pickuplat
    [-74.0],    # dropofflon
    [40.7],     # dropofflat
    [1.0],      # passengers
    ['nokey'],  # key
]

def read_csv_line(row):
    """Parse one line of CSV text into a ``(features, label)`` pair.

    Args:
        row: A string tensor holding one CSV line whose fields are laid out
            as in ``CSV_COLUMNS`` (``read_dataset`` maps this over the lines
            of a ``tf.data.TextLineDataset``).

    Returns:
        A tuple ``(features, label)``. ``features`` is a dict mapping every
        name in ``CSV_COLUMNS`` except ``'key'`` and ``'fare_amount'`` to its
        parsed tensor. ``label`` is the parsed ``'fare_amount'`` tensor.
    """
    # DEFAULTS supplies per-column fill values and fixes each column's type.
    columns = tf.decode_csv(row, record_defaults=DEFAULTS)
    features = dict(zip(CSV_COLUMNS, columns))
    # 'key' is not passed on as a feature (presumably just a row identifier).
    del features["key"]
    # The fare is the prediction target, so move it out of the feature dict.
    label = features.pop("fare_amount")
    return features, label
  
def read_dataset(filename, mode, batch_size=512):
    """Build a batched ``tf.data`` pipeline of ``(features, label)`` pairs.

    Args:
        filename: Path or glob pattern of the CSV file(s) to read.
        mode: A ``tf.estimator.ModeKeys`` value. In TRAIN mode the records
            are shuffled and repeated indefinitely. In any other mode they
            are read exactly once, unshuffled.
        batch_size: Number of records per batch. It also sizes the shuffle
            buffer in TRAIN mode.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``(features, label)``.
    """
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    # List matching files deterministically (no file-level shuffle), expand
    # each file into its text lines, then parse every line.
    filenames = tf.data.Dataset.list_files(filename, shuffle=False)
    lines = filenames.flat_map(tf.data.TextLineDataset)
    records = lines.map(read_csv_line)

    if not is_training:
        # A single epoch, so evaluation reaches end-of-input and stops.
        return records.repeat(1).batch(batch_size)

    # Shuffle within a buffer of ten batches' worth of records, using a fixed
    # seed. repeat(None) then cycles through the data without end.
    shuffled = records.shuffle(seed=666, buffer_size=10 * batch_size)
    return shuffled.repeat(None).batch(batch_size)
  
def get_train_input_fn():
    """Return the training dataset built from ``./taxi-train.csv``.

    The dataset is built in TRAIN mode, so records are shuffled and repeat
    indefinitely. Despite the name, this returns the dataset itself;
    presumably the function object is what gets passed as an Estimator
    ``input_fn`` -- confirm against the caller.
    """
    train_mode = tf.estimator.ModeKeys.TRAIN
    return read_dataset('./taxi-train.csv', mode=train_mode)

def get_valid_input_fn():
    """Return the validation dataset built from ``./taxi-valid.csv``.

    The dataset is built in EVAL mode, so records are read once, unshuffled.
    Despite the name, this returns the dataset itself; presumably the
    function object is what gets passed as an Estimator ``input_fn`` --
    confirm against the caller.
    """
    eval_mode = tf.estimator.ModeKeys.EVAL
    return read_dataset('./taxi-valid.csv', mode=eval_mode)
No matches...