Counting when it really counts
A seasoned data engineer is deliberate about using expensive operations such as counts. Misuse is fairly common: running a count just to check whether a query returns any rows, or issuing a separate count query to figure out how many records an insert-select statement will load.
Frankly, doing a select count just to check that a query returns rows is inefficient, and many systems I’ve come across are littered with examples of it. I won’t dwell on that too much here.
What I want to discuss is how to get row counts after running dataframe.write statements in Spark.
The Old Traditions
Let me start with how it used to be done with traditional databases. Across database systems, the method for determining the number of inserted rows differs. In MySQL, the ROW_COUNT() function returns the count of rows affected by the preceding INSERT statement. PostgreSQL offers a RETURNING clause within the INSERT statement itself, which can provide a count or return the inserted data. In SQL Server, the @@ROWCOUNT system variable holds the number of rows impacted by the most recent statement. Snowflake provides a SQLROWCOUNT variable for this purpose, specifically for DML statements within its scripting environment.
For Amazon Redshift — not really a traditional database, but a fork of good old Postgres — obtaining an accurate insert count typically involves querying the stl_insert system table for the query ID of the recent INSERT operation, which is more complex because of the distributed nature of the database. These variations mean developers must use the correct function or variable for their specific database to get an immediate and accurate count of inserted rows.
Simply put, there used to be return values or special variables that gave you a way to count.
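To make the “old traditions” concrete, here is a minimal sketch using Python’s DB-API, where the cursor exposes the affected-row count directly after a DML statement — no extra count query needed. The in-memory SQLite database and the `staging` table are just stand-ins for illustration.

```python
import sqlite3

# Hypothetical staging table in an in-memory database
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE staging (id INTEGER)")

# After a DML statement, DB-API cursors report the affected-row
# count on cursor.rowcount -- the "special variable" of its day
cur = conn.execute("INSERT INTO staging (id) VALUES (1), (2), (3)")
print(cur.rowcount)  # 3
```

The same pattern — execute, then read a counter the engine maintained for you — is what ROW_COUNT(), @@ROWCOUNT, and SQLROWCOUNT all provide in their respective dialects.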
Breaking from tradition
Redshift was already hinting at the counting difficulties that come with distributed databases.
In my early days using Databricks, I tried this:
results = df.write \
    .format("csv") \
    .option("header", "true") \
    .mode("overwrite") \
    .save(output_path)

print(results)
What a surprise to see this back:
> None
A common workaround is to call df.count() either before or after the write. This is fine when the data isn’t changing between the two operations. However, it means the dataframe’s lineage gets executed twice, even though the count pass is optimized down to just counting.
I explored this a bit further and found that accumulators can be used.
sc = spark.sparkContext
rows_written_acc = sc.accumulator(0)

def increment_counter_and_return(row):
    rows_written_acc.add(1)
    return row

# The write has to go through the mapped RDD; writing the original
# df would never touch the accumulator
df_counted = spark.createDataFrame(
    df.rdd.map(increment_counter_and_return), df.schema)

df_counted.write \
    .format("csv") \
    .option("header", "true") \
    .mode("overwrite") \
    .save(output_path)

print(rows_written_acc.value)
This code works, BUT only when the cluster’s isolation mode is set to Single User. That ruled out Unity Catalog, so I had to rule the approach out instead. Back to the drawing board.
Observing counts
My search for a better way to count (or not count) led me to the DataFrame.observe function. It was introduced to PySpark in Spark 3.3, and it computes the aggregates you define as a side effect of a dataframe operation. This was exactly what I needed. I ended up with this:
from pyspark.sql import Observation
from pyspark.sql.functions import lit, count

observation = Observation("write-metrics")

df.observe(observation, count(lit("1")).alias("count")) \
    .write \
    .format("csv") \
    .option("header", "true") \
    .mode("overwrite") \
    .save("test.csv")

print(observation.get)
This returned a dict of the aggregates I put on the observation:
> {'count': 4425435}
More can be added to the observe call — aggregates like the max of a timestamp column, or a sum of values. This is basically all I need for now.
It is important to note, though, that the Observation class only works on batch queries. Streaming needs a slightly different approach, but that’s a topic for a future blog.
The catch
This approach worked well for 99.9% of our use cases. However, the other 0.1% of our processes got None back from observation.get. I’m still getting to the bottom of it, but it seems to happen when the cluster is overloaded with too much work. I eventually added a fallback around observation.get that recounts when the metrics come back empty, because I really don’t want to write a big file again. Wish me luck on the investigation — if I ever get to the reasons, I’ll put an addendum on this article.
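For the curious, the defensive check is roughly shaped like this — a sketch, not my exact production code; the function name and the "count" alias are assumptions matching the earlier example:

```python
def rows_written(observation, df):
    """Return the observed row count, falling back to a recount
    when the observation came back empty (the rare 0.1% case)."""
    metrics = observation.get
    if metrics and metrics.get("count") is not None:
        return metrics["count"]
    # Fallback: a second pass over the data, but crucially no second write
    return df.count()
```

The point of the fallback is that re-counting is cheap relative to re-writing the file: the write already succeeded, only the metric went missing.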