Use the TaskContext PySpark API to get context information while running a Batch Unity Catalog Python UDF or PySpark UDF.
For example, context information such as the user's identity and cluster tags can be used to verify the user's identity when accessing external services.
Requirements
TaskContext is supported on Databricks Runtime versions 16.3 and above.
TaskContext is supported on the following UDF types:
- Batch Unity Catalog Python UDFs
- PySpark UDFs
Use TaskContext to get context information
The following sections show TaskContext examples for PySpark UDFs and Batch Unity Catalog Python UDFs.
PySpark UDF
The following PySpark UDF example returns the user's context:
from pyspark.sql.functions import udf

@udf
def log_context():
    import json
    from pyspark.taskcontext import TaskContext

    tc = TaskContext.get()

    # Returns the current user executing the UDF
    session_user = tc.getLocalProperty("user")

    # Returns the Spark job group ID for the current UDF
    job_group_id = tc.getLocalProperty("spark.jobGroup.id")

    # Returns cluster tags
    tags = dict(
        item.values()
        for item in json.loads(
            tc.getLocalProperty("spark.databricks.clusterUsageTags.clusterAllTags") or "[]"
        )
    )

    # Returns current version details
    current_version = {
        "dbr_version": tc.getLocalProperty("spark.databricks.clusterUsageTags.sparkVersion"),
        "dbsql_version": tc.getLocalProperty("spark.databricks.clusterUsageTags.dbsqlVersion"),
    }

    return {
        "user": session_user,
        "job_group_id": job_group_id,
        "tags": tags,
        "current_version": current_version,
    }
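To try the UDF, you can apply it to a single-row DataFrame. This is a minimal sketch, assuming an active Spark session named spark; the column alias is arbitrary:

# Apply the zero-argument UDF to one row to inspect the context values
df = spark.range(1).select(log_context().alias("context"))
df.show(truncate=False)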
Batch Unity Catalog Python UDF
The following Batch Unity Catalog Python UDF example gets the user's identity to call an AWS Lambda function using a service credential:
%sql
CREATE OR REPLACE FUNCTION main.test.call_lambda_func(data STRING, debug BOOLEAN)
RETURNS STRING
LANGUAGE PYTHON
PARAMETER STYLE PANDAS
HANDLER 'batchhandler'
CREDENTIALS (
`batch-udf-service-creds-example-cred` DEFAULT
)
AS $$
import boto3
import json
import pandas as pd
import base64
from pyspark.taskcontext import TaskContext
def batchhandler(it):
    # Automatically picks up the DEFAULT credential:
    session = boto3.Session()
    client = session.client("lambda", region_name="us-west-2")

    # Propagate TaskContext information to the Lambda client context:
    user_ctx = {"custom": {"user": TaskContext.get().getLocalProperty("user")}}

    for vals, is_debug in it:
        payload = json.dumps({"values": vals.to_list(), "is_debug": bool(is_debug[0])})

        res = client.invoke(
            FunctionName="HashValuesFunction",
            InvocationType="RequestResponse",
            ClientContext=base64.b64encode(json.dumps(user_ctx).encode("utf-8")).decode("utf-8"),
            Payload=payload,
        )

        response_payload = json.loads(res["Payload"].read().decode("utf-8"))
        if "errorMessage" in response_payload:
            raise Exception(str(response_payload))

        yield pd.Series(response_payload["values"])
$$;
Call the UDF after it is registered:
SELECT main.test.call_lambda_func(data, false)
FROM VALUES
('abc'),
('def')
AS t(data)
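You can also invoke the registered function from PySpark instead of a SQL cell. This is a minimal sketch, assuming an active Spark session named spark and the function registered as above:

# Call the Unity Catalog function through spark.sql
result_df = spark.sql(
    "SELECT main.test.call_lambda_func(data, false) AS result "
    "FROM VALUES ('abc'), ('def') AS t(data)"
)
result_df.show()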
TaskContext properties
The TaskContext.getLocalProperty() method supports the following property keys:
| Property key | Description | Example usage |
|---|---|---|
| user | The user currently executing the UDF | tc.getLocalProperty("user") -> "alice" |
| spark.jobGroup.id | The Spark job group ID associated with the current UDF | tc.getLocalProperty("spark.jobGroup.id") -> "jobGroup-92318" |
| spark.databricks.clusterUsageTags.clusterAllTags | Cluster metadata tags as key-value pairs, formatted as a string representation of a JSON dictionary | tc.getLocalProperty("spark.databricks.clusterUsageTags.clusterAllTags") -> [{"Department": "Finance"}] |
| spark.databricks.clusterUsageTags.region | The region where the workspace resides | tc.getLocalProperty("spark.databricks.clusterUsageTags.region") -> "us-west-2" |
| accountId | The Databricks account ID for the running context | tc.getLocalProperty("accountId") -> "1234567890123456" |
| orgId | The workspace ID (not available on DBSQL) | tc.getLocalProperty("orgId") -> "987654321" |
| spark.databricks.clusterUsageTags.sparkVersion | The Databricks Runtime version for the cluster (non-DBSQL environments) | tc.getLocalProperty("spark.databricks.clusterUsageTags.sparkVersion") -> "16.3" |
| spark.databricks.clusterUsageTags.dbsqlVersion | The DBSQL version (DBSQL environments) | tc.getLocalProperty("spark.databricks.clusterUsageTags.dbsqlVersion") -> "2024.35" |
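Some keys are environment-specific: orgId is not available on DBSQL, sparkVersion is only set on non-DBSQL clusters, and dbsqlVersion is only set on DBSQL. For such keys, getLocalProperty returns None. The following is a minimal sketch of defensive access inside a UDF; the get_property helper name is hypothetical:

from pyspark.taskcontext import TaskContext

def get_property(key, default=None):
    # Hypothetical helper: getLocalProperty returns None for keys
    # that are unset in the current environment
    value = TaskContext.get().getLocalProperty(key)
    return value if value is not None else default

# Example: orgId is unset on DBSQL, so fall back to a placeholder
workspace_id = get_property("orgId", default="unknown")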