pyspark.pandas.extensions.register_dataframe_accessor

pyspark.pandas.extensions.register_dataframe_accessor(name)

Register a custom accessor on DataFrame objects.

Parameters
name : str

Name used when calling the accessor after it’s registered.

Returns
callable

A class decorator.

See also

register_series_accessor

Register a custom accessor on Series objects

register_index_accessor

Register a custom accessor on Index objects

Notes

When accessed, your accessor is initialized with the pandas-on-Spark object the user is interacting with, so the accessor’s __init__ method should always take that object as its argument. See the examples for the __init__ signature.
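To make that initialization step concrete, here is a minimal, dependency-free sketch of how a decorator like this can attach an accessor to a class. It follows the same general idea as the cached-descriptor approach pandas uses, but it is not pandas-on-Spark’s source: CachedAccessor, register_accessor, and ToyFrame are hypothetical names chosen for illustration, and ToyFrame is a plain dict-backed stand-in for a real DataFrame.

```python
import warnings


class CachedAccessor:
    """Hypothetical descriptor that builds the accessor on first access."""

    def __init__(self, name, accessor):
        self._name = name
        self._accessor = accessor

    def __get__(self, obj, cls):
        if obj is None:
            # Accessed on the class itself (e.g. ToyFrame.geo): return the accessor class.
            return self._accessor
        # The object being accessed is passed straight into the accessor's __init__.
        accessor_obj = self._accessor(obj)
        # Cache the instance on the object so later lookups bypass this descriptor.
        object.__setattr__(obj, self._name, accessor_obj)
        return accessor_obj


def register_accessor(name, target_cls):
    """Hypothetical helper: return a class decorator that attaches an accessor."""
    def decorator(accessor):
        if hasattr(target_cls, name):
            # Warn instead of failing when an existing attribute is shadowed.
            warnings.warn(f"{name!r} overrides an existing attribute", UserWarning)
        setattr(target_cls, name, CachedAccessor(name, accessor))
        return accessor  # the class itself is returned unchanged
    return decorator


class ToyFrame:
    """Hypothetical stand-in for a DataFrame: a dict of column lists."""

    def __init__(self, data):
        self.data = data


@register_accessor("geo", ToyFrame)
class GeoAccessor:
    def __init__(self, obj):
        self._obj = obj  # the ToyFrame the user called .geo on

    @property
    def center(self):
        lon = self._obj.data["longitude"]
        lat = self._obj.data["latitude"]
        return (sum(lon) / len(lon), sum(lat) / len(lat))


frame = ToyFrame({"longitude": [0.0, 10.0], "latitude": [0.0, 20.0]})
print(frame.geo.center)        # (5.0, 10.0)
print(frame.geo is frame.geo)  # True: the accessor instance is cached per object
```

In this sketch the accessor is created lazily and stored on each object, so repeated attribute lookups reuse one accessor instance instead of constructing a new one every time.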

In the pandas API, if data passed to your accessor has an incorrect dtype, it’s recommended to raise an AttributeError for consistency. In pandas-on-Spark, ValueError is more commonly used to signal that a value’s data type is unexpected for a given method or function.

Ultimately, you can structure this however you like, but pandas-on-Spark would likely do something like this:

>>> ps.Series(['a', 'b']).dt
Traceback (most recent call last):
    ...
ValueError: Cannot call DatetimeMethods on type StringType()
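The same convention can be applied in a custom accessor by validating the data type in __init__. The sketch below is plain Python with no Spark dependency; ToySeries, DatetimeLikeAccessor, the dtl attribute, and the "timestamp"/"string" type names are all hypothetical stand-ins, not pandas-on-Spark objects.

```python
class ToySeries:
    """Hypothetical stand-in for a Series: values plus a dtype name."""

    def __init__(self, values, dtype):
        self.values = values
        self.dtype = dtype


class DatetimeLikeAccessor:
    """Hypothetical accessor that only accepts timestamp-typed data."""

    def __init__(self, obj):
        if obj.dtype != "timestamp":
            # pandas-on-Spark style: ValueError for an unexpected data type.
            raise ValueError(f"Cannot call DatetimeLikeAccessor on type {obj.dtype}")
        self._obj = obj


# Attach as a read-only property so ToySeries(...).dtl constructs the accessor.
ToySeries.dtl = property(DatetimeLikeAccessor)

try:
    ToySeries(["a", "b"], dtype="string").dtl
except ValueError as exc:
    print(exc)  # Cannot call DatetimeLikeAccessor on type string
```

One practical consequence of the choice: Python’s hasattr() only suppresses AttributeError, so under the ValueError convention hasattr(series, "dtl") raises the error instead of quietly returning False.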

Examples

In your library code:

from pyspark.pandas.extensions import register_dataframe_accessor

@register_dataframe_accessor("geo")
class GeoAccessor:

    def __init__(self, pandas_on_spark_obj):
        self._obj = pandas_on_spark_obj
        # other constructor logic

    @property
    def center(self):
        # return the geographic center point of this DataFrame
        lat = self._obj.latitude
        lon = self._obj.longitude
        return (float(lon.mean()), float(lat.mean()))

    def plot(self):
        # plot this DataFrame's data on a map
        pass

Then, in an IPython session:

>>> import numpy as np
>>> import pyspark.pandas as ps
>>> # Import the accessor first if it is defined in another module.
>>> # from my_ext_lib import GeoAccessor
>>> psdf = ps.DataFrame({"longitude": np.linspace(0, 10),
...                      "latitude": np.linspace(0, 20)})
>>> psdf.geo.center
(5.0, 10.0)

>>> psdf.geo.plot()