from typing import Any, Dict, Iterable, List, Optional
import pandas as pd
import pyarrow as pa
import pyspark.sql as ps
from pyspark.sql.functions import col
from triad import SerializableRLock
from triad.collections.schema import SchemaError
from triad.utils.assertion import assert_or_throw
from fugue.dataframe import (
ArrowDataFrame,
DataFrame,
IterableDataFrame,
LocalBoundedDataFrame,
)
from fugue.dataframe.utils import pa_table_as_array, pa_table_as_dicts
from fugue.exceptions import FugueDataFrameOperationError
from fugue.plugins import (
as_array,
as_array_iterable,
as_arrow,
as_dict_iterable,
as_dicts,
as_local_bounded,
as_pandas,
count,
drop_columns,
get_column_names,
get_num_partitions,
head,
is_bounded,
is_df,
is_empty,
is_local,
rename,
select_columns,
)
from ._utils.convert import (
to_arrow,
to_cast_expression,
to_pandas,
to_schema,
to_type_safe_input,
)
from ._utils.misc import is_spark_connect, is_spark_dataframe
class SparkDataFrame(DataFrame):
"""DataFrame that wraps Spark DataFrame. Please also read
|DataFrameTutorial| to understand this Fugue concept
:param df: :class:`spark:pyspark.sql.DataFrame`
:param schema: |SchemaLikeObject| or :class:`spark:pyspark.sql.types.StructType`,
defaults to None.
.. note::
* You should use :meth:`fugue_spark.execution_engine.SparkExecutionEngine.to_df`
instead of construction it by yourself.
* If ``schema`` is set, then there will be type cast on the Spark DataFrame if
the schema is different.
"""
def __init__(self, df: Any = None, schema: Any = None): # noqa: C901
self._lock = SerializableRLock()
if is_spark_dataframe(df):
if schema is not None:
schema = to_schema(schema).assert_not_empty()
has_cast, expr = to_cast_expression(df, schema, True)
if has_cast:
df = df.selectExpr(*expr) # type: ignore
else:
schema = to_schema(df).assert_not_empty()
self._native = df
super().__init__(schema)
else: # pragma: no cover
assert_or_throw(schema is not None, SchemaError("schema is None"))
schema = to_schema(schema).assert_not_empty()
raise ValueError(f"{df} is incompatible with SparkDataFrame")
@property
def alias(self) -> str:
return "_" + str(id(self.native))
@property
def native(self) -> ps.DataFrame:
"""The wrapped Spark DataFrame
:rtype: :class:`spark:pyspark.sql.DataFrame`
"""
return self._native
def native_as_df(self) -> ps.DataFrame:
return self._native
@property
def is_local(self) -> bool:
return False
@property
def is_bounded(self) -> bool:
return True
def as_local_bounded(self) -> LocalBoundedDataFrame:
res = ArrowDataFrame(self.as_arrow())
if self.has_metadata:
res.reset_metadata(self.metadata)
return res
@property
def num_partitions(self) -> int:
return _spark_num_partitions(self.native)
@property
def empty(self) -> bool:
return self._first is None
def peek_array(self) -> List[Any]:
self.assert_not_empty()
return self._first # type: ignore
def count(self) -> int:
with self._lock:
if "_df_count" not in self.__dict__:
self._df_count = self.native.count()
return self._df_count
def _drop_cols(self, cols: List[str]) -> DataFrame:
cols = (self.schema - cols).names
return self._select_cols(cols)
def _select_cols(self, cols: List[Any]) -> DataFrame:
schema = self.schema.extract(cols)
return SparkDataFrame(self.native[schema.names])
def as_pandas(self) -> pd.DataFrame:
return _spark_df_as_pandas(self.native)
def as_arrow(self, type_safe: bool = False) -> pa.Table:
return _spark_df_as_arrow(self.native)
def rename(self, columns: Dict[str, str]) -> DataFrame:
try:
self.schema.rename(columns)
except Exception as e:
raise FugueDataFrameOperationError from e
return SparkDataFrame(_rename_spark_dataframe(self.native, columns))
def alter_columns(self, columns: Any) -> DataFrame:
new_schema = self.schema.alter(columns)
if new_schema == self.schema:
return self
return SparkDataFrame(self.native, new_schema)
def as_array(
self, columns: Optional[List[str]] = None, type_safe: bool = False
) -> List[Any]:
return _spark_as_array(self.native, columns=columns, type_safe=type_safe)
def as_array_iterable(
self, columns: Optional[List[str]] = None, type_safe: bool = False
) -> Iterable[Any]:
yield from _spark_as_array_iterable(
self.native, columns=columns, type_safe=type_safe
)
def as_dicts(self, columns: Optional[List[str]] = None) -> List[Dict[str, Any]]:
return _spark_as_dicts(self.native, columns=columns)
def as_dict_iterable(
self, columns: Optional[List[str]] = None
) -> Iterable[Dict[str, Any]]:
yield from _spark_as_dict_iterable(self.native, columns=columns)
def head(
self, n: int, columns: Optional[List[str]] = None
) -> LocalBoundedDataFrame:
sdf = self._select_columns(columns)
return SparkDataFrame(
sdf.native.limit(n), sdf.schema
).as_local() # type: ignore
@property
def _first(self) -> Optional[List[Any]]:
with self._lock:
if "_first_row" not in self.__dict__:
self._first_row = self.native.first()
if self._first_row is not None:
self._first_row = list(self._first_row) # type: ignore
return self._first_row # type: ignore
def _select_columns(self, columns: Optional[List[str]]) -> "SparkDataFrame":
if columns is None:
return self
return SparkDataFrame(self.native.select(*columns))
@is_df.candidate(lambda df: is_spark_dataframe(df))
def _spark_is_df(df: ps.DataFrame) -> bool:
return True
@as_arrow.candidate(lambda df: isinstance(df, ps.DataFrame))
def _spark_df_as_arrow(df: ps.DataFrame) -> pa.Table:
return to_arrow(df)
@as_pandas.candidate(lambda df: isinstance(df, ps.DataFrame))
def _spark_df_as_pandas(df: ps.DataFrame) -> pd.DataFrame:
return to_pandas(df)
@get_num_partitions.candidate(lambda df: is_spark_dataframe(df))
def _spark_num_partitions(df: ps.DataFrame) -> int:
return df.rdd.getNumPartitions()
@count.candidate(lambda df: is_spark_dataframe(df))
def _spark_df_count(df: ps.DataFrame) -> int:
return df.count()
@is_bounded.candidate(lambda df: is_spark_dataframe(df))
def _spark_df_is_bounded(df: ps.DataFrame) -> bool:
return True
@is_empty.candidate(lambda df: is_spark_dataframe(df))
def _spark_df_is_empty(df: ps.DataFrame) -> bool:
return df.first() is None
@is_local.candidate(lambda df: is_spark_dataframe(df))
def _spark_df_is_local(df: ps.DataFrame) -> bool:
return False
@as_local_bounded.candidate(lambda df: is_spark_dataframe(df))
def _spark_df_as_local(df: ps.DataFrame) -> pd.DataFrame:
return to_pandas(df)
@get_column_names.candidate(lambda df: is_spark_dataframe(df))
def _get_spark_df_columns(df: ps.DataFrame) -> List[Any]:
return df.columns
@rename.candidate(lambda df, *args, **kwargs: is_spark_dataframe(df))
def _rename_spark_df(
df: ps.DataFrame, columns: Dict[str, Any], as_fugue: bool = False
) -> ps.DataFrame:
if len(columns) == 0:
return df
_assert_no_missing(df, columns.keys())
return _adjust_df(_rename_spark_dataframe(df, columns), as_fugue=as_fugue)
@drop_columns.candidate(lambda df, *args, **kwargs: is_spark_dataframe(df))
def _drop_spark_df_columns(
df: ps.DataFrame, columns: List[str], as_fugue: bool = False
) -> Any:
cols = [x for x in df.columns if x not in columns]
if len(cols) == 0:
raise FugueDataFrameOperationError("cannot drop all columns")
if len(cols) + len(columns) != len(df.columns):
_assert_no_missing(df, columns)
return _adjust_df(df[cols], as_fugue=as_fugue)
@select_columns.candidate(lambda df, *args, **kwargs: is_spark_dataframe(df))
def _select_spark_df_columns(
df: ps.DataFrame, columns: List[Any], as_fugue: bool = False
) -> Any:
if len(columns) == 0:
raise FugueDataFrameOperationError("must select at least one column")
_assert_no_missing(df, columns)
return _adjust_df(df[columns], as_fugue=as_fugue)
@head.candidate(lambda df, *args, **kwargs: is_spark_dataframe(df))
def _spark_df_head(
df: ps.DataFrame,
n: int,
columns: Optional[List[str]] = None,
as_fugue: bool = False,
) -> Any:
if columns is not None:
df = df[columns]
res = df.limit(n)
return SparkDataFrame(res).as_local() if as_fugue else to_pandas(res)
@as_array.candidate(lambda df, *args, **kwargs: is_spark_dataframe(df))
def _spark_as_array(
df: ps.DataFrame, columns: Optional[List[str]] = None, type_safe: bool = False
) -> List[Any]:
assert_or_throw(columns is None or len(columns) > 0, ValueError("empty columns"))
_df = df if columns is None or len(columns) == 0 else df[columns]
return pa_table_as_array(to_arrow(_df), columns)
@as_array_iterable.candidate(lambda df, *args, **kwargs: is_spark_dataframe(df))
def _spark_as_array_iterable(
df: ps.DataFrame, columns: Optional[List[str]] = None, type_safe: bool = False
) -> Iterable[Any]:
if is_spark_connect(df): # pragma: no cover
yield from _spark_as_array(df, columns, type_safe=type_safe)
else:
assert_or_throw(
columns is None or len(columns) > 0, ValueError("empty columns")
)
_df = df if columns is None or len(columns) == 0 else df[columns]
if not type_safe:
for row in to_type_safe_input(
_df.rdd.toLocalIterator(), to_schema(_df.schema)
):
yield list(row)
else:
tdf = IterableDataFrame(
_spark_as_array_iterable(_df, type_safe=False), to_schema(_df.schema)
)
yield from tdf.as_array_iterable(type_safe=True)
@as_dicts.candidate(lambda df, *args, **kwargs: is_spark_dataframe(df))
def _spark_as_dicts(
df: ps.DataFrame, columns: Optional[List[str]] = None, type_safe: bool = False
) -> List[Dict[str, Any]]:
assert_or_throw(columns is None or len(columns) > 0, ValueError("empty columns"))
_df = df if columns is None or len(columns) == 0 else df[columns]
return pa_table_as_dicts(to_arrow(_df), columns)
@as_dict_iterable.candidate(lambda df, *args, **kwargs: is_spark_dataframe(df))
def _spark_as_dict_iterable(
df: ps.DataFrame, columns: Optional[List[str]] = None, type_safe: bool = False
) -> Iterable[Dict[str, Any]]:
assert_or_throw(columns is None or len(columns) > 0, ValueError("empty columns"))
_df = df if columns is None or len(columns) == 0 else df[columns]
cols = list(_df.columns)
for row in _spark_as_array_iterable(_df, type_safe=type_safe):
yield dict(zip(cols, row))
def _rename_spark_dataframe(df: ps.DataFrame, names: Dict[str, Any]) -> ps.DataFrame:
cols: List[ps.Column] = []
for f in df.schema:
c = col(f.name)
if f.name in names:
c = c.alias(names[f.name])
cols.append(c)
return df.select(cols)
def _assert_no_missing(df: ps.DataFrame, columns: Iterable[Any]) -> None:
missing = set(columns) - set(df.columns)
if len(missing) > 0:
raise FugueDataFrameOperationError("found nonexistent columns: {missing}")
def _adjust_df(res: ps.DataFrame, as_fugue: bool):
return res if not as_fugue else SparkDataFrame(res)
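# A minimal sketch of how the candidate functions above are typically reached
# (assumptions for illustration: Fugue's functional API ``fugue.api`` exposes
# these dispatchers, and ``spark`` is a live SparkSession):
#
#     import fugue.api as fa
#     sdf = spark.createDataFrame([(0, "a"), (1, "b")], "a long, b string")
#     fa.get_column_names(sdf)        # -> ["a", "b"] via _get_spark_df_columns
#     fa.rename(sdf, {"a": "x"})      # Spark DataFrame via _rename_spark_df
#     fa.head(sdf, 1, as_fugue=True)  # local Fugue DataFrame via _spark_df_head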