import logging
from typing import Any, Dict, Iterable, List, Optional, Union
import duckdb
from duckdb import DuckDBPyConnection, DuckDBPyRelation
from triad import SerializableRLock
from triad.utils.assertion import assert_or_throw
from triad.utils.schema import quote_name
from fugue import (
ArrowDataFrame,
ExecutionEngine,
MapEngine,
NativeExecutionEngine,
PandasMapEngine,
SQLEngine,
)
from fugue.collections.partition import PartitionSpec, parse_presort_exp
from fugue.collections.sql import StructuredRawSQL, TempTableName
from fugue.dataframe import DataFrame, DataFrames, LocalBoundedDataFrame
from fugue.dataframe.utils import get_join_schemas
from ._io import DuckDBIO
from ._utils import (
encode_column_name,
encode_column_names,
encode_schema_names,
encode_value_to_expr,
)
from .dataframe import DuckDataFrame, _duck_as_arrow
_FUGUE_DUCKDB_PRAGMA_CONFIG_PREFIX = "fugue.duckdb.pragma."
_FUGUE_DUCKDB_EXTENSIONS = "fugue.duckdb.extensions"
class DuckDBEngine(SQLEngine):
"""DuckDB SQL backend implementation.
    :param execution_engine: the execution engine this SQL engine will run on
"""
@property
def dialect(self) -> Optional[str]:
return "duckdb"
def select(self, dfs: DataFrames, statement: StructuredRawSQL) -> DataFrame:
if isinstance(self.execution_engine, DuckExecutionEngine):
return self._duck_select(dfs, statement)
else:
_dfs, _sql = self.encode(dfs, statement)
return self._other_select(_dfs, _sql)
def table_exists(self, table: str) -> bool:
return self._get_table(table) is not None
def save_table(
self,
df: DataFrame,
table: str,
mode: str = "overwrite",
partition_spec: Optional[PartitionSpec] = None,
**kwargs: Any,
) -> None:
if isinstance(self.execution_engine, DuckExecutionEngine):
con = self.execution_engine.connection
tdf: DuckDataFrame = _to_duck_df(
self.execution_engine, df, create_view=False # type: ignore
)
et = self._get_table(table)
if et is not None:
if mode == "overwrite":
tp = "VIEW" if et["table_type"] == "VIEW" else "TABLE"
tn = encode_column_name(et["table_name"])
con.query(f"DROP {tp} {tn}")
else:
                    raise ValueError(f"{table} already exists")
tdf.native.create(table)
else: # pragma: no cover
raise NotImplementedError(
"save_table can only be used with DuckExecutionEngine"
)
def load_table(self, table: str, **kwargs: Any) -> DataFrame:
if isinstance(self.execution_engine, DuckExecutionEngine):
return DuckDataFrame(self.execution_engine.connection.table(table))
else: # pragma: no cover
raise NotImplementedError(
"load_table can only be used with DuckExecutionEngine"
)
@property
def is_distributed(self) -> bool:
return False
def _duck_select(self, dfs: DataFrames, statement: StructuredRawSQL) -> DataFrame:
name_map: Dict[str, str] = {}
for k, v in dfs.items():
tdf: DuckDataFrame = _to_duck_df(
self.execution_engine, v, create_view=True # type: ignore
)
name_map[k] = tdf.alias
query = statement.construct(name_map, dialect=self.dialect, log=self.log)
result = self.execution_engine.connection.query(query) # type: ignore
return DuckDataFrame(result)
def _other_select(self, dfs: DataFrames, statement: str) -> DataFrame:
conn = duckdb.connect()
try:
for k, v in dfs.items():
duckdb.from_arrow(v.as_arrow(), connection=conn).create_view(k)
return ArrowDataFrame(_duck_as_arrow(conn.execute(statement)))
finally:
conn.close()
def _get_table(self, table: str) -> Optional[Dict[str, Any]]:
if isinstance(self.execution_engine, DuckExecutionEngine):
            # TODO: this is oversimplified
con = self.execution_engine.connection
qt = quote_name(table, "'")
if not qt.startswith("'"):
qt = "'" + qt + "'"
tables = (
con.query(
"SELECT table_catalog,table_schema,table_name,table_type"
f" FROM information_schema.tables WHERE table_name={qt}"
)
.to_df()
.to_dict("records")
)
return None if len(tables) == 0 else tables[0]
else: # pragma: no cover
raise NotImplementedError(
"table_exists can only be used with DuckExecutionEngine"
)
class DuckExecutionEngine(ExecutionEngine):
"""The execution engine using DuckDB.
    Please read |ExecutionEngineTutorial| to understand this important Fugue concept.

    :param conf: |ParamsLikeObject|, read |FugueConfig| to learn Fugue-specific options
:param connection: DuckDB connection
"""
def __init__(
self, conf: Any = None, connection: Optional[DuckDBPyConnection] = None
):
super().__init__(conf)
self._native_engine = NativeExecutionEngine(conf)
self._con = connection or duckdb.connect()
self._external_con = connection is not None
self._context_lock = SerializableRLock()
self._registered_dfs: Dict[str, DuckDataFrame] = {}
try:
            for pg in list(self._get_pragmas()):  # validate all pragmas before executing any
self._con.execute(pg)
for ext in self.conf.get(_FUGUE_DUCKDB_EXTENSIONS, "").split(","):
_ext = ext.strip()
if _ext != "":
self._con.install_extension(_ext)
self._con.load_extension(_ext)
except Exception:
self.stop()
raise
def _get_pragmas(self) -> Iterable[str]:
for k, v in self.conf.items():
if k.startswith(_FUGUE_DUCKDB_PRAGMA_CONFIG_PREFIX):
name = k[len(_FUGUE_DUCKDB_PRAGMA_CONFIG_PREFIX) :]
assert_or_throw(
name.isidentifier(), ValueError(f"{name} is not a valid pragma key")
)
value = encode_value_to_expr(v)
yield f"PRAGMA {name}={value};"
def stop_engine(self) -> None:
if not self._external_con:
self._con.close()
def __repr__(self) -> str:
return "DuckExecutionEngine"
@property
def is_distributed(self) -> bool:
return False
@property
def connection(self) -> DuckDBPyConnection:
return self._con
@property
def log(self) -> logging.Logger:
return self._native_engine.log
def create_default_sql_engine(self) -> SQLEngine:
return DuckDBEngine(self)
def create_default_map_engine(self) -> MapEngine:
return PandasMapEngine(self._native_engine)
def get_current_parallelism(self) -> int:
return 1
def to_df(self, df: Any, schema: Any = None) -> DataFrame:
return _to_duck_df(self, df, schema=schema)
def repartition(
self, df: DataFrame, partition_spec: PartitionSpec
) -> DataFrame: # pragma: no cover
self.log.warning("%s doesn't respect repartition", self)
return df
def broadcast(self, df: DataFrame) -> DataFrame:
return df
def persist(
self,
df: DataFrame,
lazy: bool = False,
**kwargs: Any,
) -> DataFrame:
        # TODO: we should create a DuckDB table, but that has bugs
        # as of duckdb 0.3.1, so it can't be used yet
if isinstance(df, DuckDataFrame):
# materialize
res: DataFrame = ArrowDataFrame(df.as_arrow())
else:
res = self.to_df(df)
res.reset_metadata(df.metadata)
return res
def join(
self,
df1: DataFrame,
df2: DataFrame,
how: str,
on: Optional[List[str]] = None,
) -> DataFrame:
key_schema, output_schema = get_join_schemas(df1, df2, how=how, on=on)
t1, t2, t3 = (
TempTableName(),
TempTableName(),
TempTableName(),
)
on_fields = " AND ".join(
f"{t1}.{encode_column_name(k)}={t2}.{encode_column_name(k)}"
for k in key_schema
)
join_type = self._how_to_join(how)
if how.lower() == "cross":
select_fields = ",".join(
f"{t1}.{encode_column_name(k)}"
if k in df1.schema
else f"{t2}.{encode_column_name(k)}"
for k in output_schema.names
)
sql = f"SELECT {select_fields} FROM {t1} {join_type} {t2}"
elif how.lower() == "right_outer":
select_fields = ",".join(
f"{t2}.{encode_column_name(k)}"
if k in df2.schema
else f"{t1}.{encode_column_name(k)}"
for k in output_schema.names
)
sql = (
f"SELECT {select_fields} FROM {t2} LEFT OUTER JOIN {t1} ON {on_fields}"
)
elif how.lower() == "full_outer":
select_fields = ",".join(
f"COALESCE({t1}.{encode_column_name(k)},{t2}.{encode_column_name(k)}) "
f"AS {encode_column_name(k)}"
if k in key_schema
else encode_column_name(k)
for k in output_schema.names
)
sql = f"SELECT {select_fields} FROM {t1} {join_type} {t2} ON {on_fields}"
elif how.lower() in ["semi", "left_semi"]:
keys = ",".join(encode_schema_names(key_schema))
on_fields = " AND ".join(
f"{t1}.{encode_column_name(k)}={t3}.{encode_column_name(k)}"
for k in key_schema
)
sql = (
f"SELECT {t1}.* FROM {t1} INNER JOIN (SELECT DISTINCT {keys} "
f"FROM {t2}) AS {t3} ON {on_fields}"
)
elif how.lower() in ["anti", "left_anti"]:
keys = ",".join(encode_schema_names(key_schema))
on_fields = " AND ".join(
f"{t1}.{encode_column_name(k)}={t3}.{encode_column_name(k)}"
for k in key_schema
)
sql = (
f"SELECT {t1}.* FROM {t1} LEFT OUTER JOIN "
f"(SELECT DISTINCT {keys}, 1 AS __contain__ FROM {t2}) AS {t3} "
f"ON {on_fields} WHERE {t3}.__contain__ IS NULL"
)
else:
select_fields = ",".join(
f"{t1}.{encode_column_name(k)}"
if k in df1.schema
else f"{t2}.{encode_column_name(k)}"
for k in output_schema.names
)
sql = f"SELECT {select_fields} FROM {t1} {join_type} {t2} ON {on_fields}"
return self._sql(sql, {t1.key: df1, t2.key: df2})
    def _how_to_join(self, how: str) -> str:
return how.upper().replace("_", " ") + " JOIN"
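    # e.g. "inner" -> "INNER JOIN", "left_outer" -> "LEFT OUTER JOIN", and
    # "cross" -> "CROSS JOIN" (used without an ON clause in the branch above)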
def union(self, df1: DataFrame, df2: DataFrame, distinct: bool = True) -> DataFrame:
assert_or_throw(
df1.schema == df2.schema, ValueError(f"{df1.schema} != {df2.schema}")
)
if distinct:
t1, t2 = TempTableName(), TempTableName()
sql = f"SELECT * FROM {t1} UNION SELECT * FROM {t2}"
return self._sql(sql, {t1.key: df1, t2.key: df2})
return DuckDataFrame(
_to_duck_df(self, df1).native.union(_to_duck_df(self, df2).native)
)
def subtract(
self, df1: DataFrame, df2: DataFrame, distinct: bool = True
) -> DataFrame: # pragma: no cover
if distinct:
t1, t2 = TempTableName(), TempTableName()
sql = f"SELECT * FROM {t1} EXCEPT SELECT * FROM {t2}"
return self._sql(sql, {t1.key: df1, t2.key: df2})
return DuckDataFrame(
_to_duck_df(self, df1).native.except_(_to_duck_df(self, df2).native)
)
def intersect(
self, df1: DataFrame, df2: DataFrame, distinct: bool = True
) -> DataFrame:
if distinct:
t1, t2 = TempTableName(), TempTableName()
sql = f"SELECT * FROM {t1} INTERSECT DISTINCT SELECT * FROM {t2}"
return self._sql(sql, {t1.key: df1, t2.key: df2})
raise NotImplementedError(
"DuckDB doesn't have consist behavior on INTERSECT ALL,"
" so Fugue doesn't support it"
)
def distinct(self, df: DataFrame) -> DataFrame:
rel = _to_duck_df(self, df).native.distinct()
return DuckDataFrame(rel)
def dropna(
self,
df: DataFrame,
how: str = "any",
        thresh: Optional[int] = None,
        subset: Optional[List[str]] = None,
) -> DataFrame:
schema = df.schema
if subset is not None:
schema = schema.extract(subset)
if how == "all":
thr = 0
elif how == "any":
thr = thresh or len(schema)
else: # pragma: no cover
raise ValueError(f"{how} is not one of any and all")
cw = [
f"CASE WHEN {encode_column_name(f)} IS NULL THEN 0 ELSE 1 END"
for f in schema.names
]
expr = " + ".join(cw) + f" >= {thr}"
return DuckDataFrame(_to_duck_df(self, df).native.filter(expr))
    def fillna(
        self, df: DataFrame, value: Any, subset: Optional[List[str]] = None
    ) -> DataFrame:
def _build_value_dict(names: List[str]) -> Dict[str, str]:
if not isinstance(value, dict):
v = encode_value_to_expr(value)
return {n: v for n in names}
else:
return {n: encode_value_to_expr(value[n]) for n in names}
names = df.columns
if isinstance(value, dict):
# subset should be ignored
names = list(value.keys())
elif subset is not None:
names = list(df.schema.extract(subset).names)
vd = _build_value_dict(names)
assert_or_throw(
all(v != "NULL" for v in vd.values()),
ValueError("fillna value can not be None or contain None"),
)
cols = [
f"COALESCE({encode_column_name(f)}, {vd[f]}) AS {encode_column_name(f)}"
if f in names
else encode_column_name(f)
for f in df.columns
]
return DuckDataFrame(_to_duck_df(self, df).native.project(", ".join(cols)))
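    # For example, fillna(df, {"a": 0, "b": "x"}) fills only columns a and b
    # (subset is ignored when value is a dict), while fillna(df, 0, subset=["a"])
    # fills nulls in column a only.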
def sample(
self,
df: DataFrame,
n: Optional[int] = None,
frac: Optional[float] = None,
replace: bool = False,
seed: Optional[int] = None,
) -> DataFrame:
assert_or_throw(
(n is None and frac is not None and frac >= 0.0)
or (frac is None and n is not None and n >= 0),
ValueError(
f"one and only one of n and frac should be non-negative, {n}, {frac}"
),
)
tb = TempTableName()
if frac is not None:
sql = f"SELECT * FROM {tb} USING SAMPLE bernoulli({frac*100} PERCENT)"
else:
sql = f"SELECT * FROM {tb} USING SAMPLE reservoir({n} ROWS)"
if seed is not None:
sql += f" REPEATABLE ({seed})"
return self._sql(sql, {tb.key: df})
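    # For example (hedged), sample(df, frac=0.5) renders
    # "USING SAMPLE bernoulli(50.0 PERCENT)", while sample(df, n=100, seed=0)
    # renders "USING SAMPLE reservoir(100 ROWS) REPEATABLE (0)".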
def take(
self,
df: DataFrame,
n: int,
presort: str,
na_position: str = "last",
partition_spec: Optional[PartitionSpec] = None,
) -> DataFrame:
partition_spec = partition_spec or PartitionSpec()
assert_or_throw(
isinstance(n, int),
ValueError("n needs to be an integer"),
)
if presort is not None and presort != "":
_presort = parse_presort_exp(presort)
else:
_presort = partition_spec.presort
tb = TempTableName()
if len(_presort) == 0:
if len(partition_spec.partition_by) == 0:
return DuckDataFrame(_to_duck_df(self, df).native.limit(n))
cols = ", ".join(encode_schema_names(df.schema))
pcols = ", ".join(encode_column_names(partition_spec.partition_by))
sql = (
f"SELECT *, ROW_NUMBER() OVER (PARTITION BY {pcols}) "
f"AS __fugue_take_param FROM {tb}"
)
sql = f"SELECT {cols} FROM ({sql}) WHERE __fugue_take_param<={n}"
return self._sql(sql, {tb.key: df})
sorts: List[str] = []
for k, v in _presort.items():
s = encode_column_name(k)
if not v:
s += " DESC"
s += " NULLS FIRST" if na_position == "first" else " NULLS LAST"
sorts.append(s)
sort_expr = "ORDER BY " + ", ".join(sorts)
if len(partition_spec.partition_by) == 0:
sql = f"SELECT * FROM {tb} {sort_expr} LIMIT {n}"
return self._sql(sql, {tb.key: df})
cols = ", ".join(encode_schema_names(df.schema))
pcols = ", ".join(encode_column_names(partition_spec.partition_by))
sql = (
f"SELECT *, ROW_NUMBER() OVER (PARTITION BY {pcols} {sort_expr}) "
f"AS __fugue_take_param FROM {tb}"
)
sql = f"SELECT {cols} FROM ({sql}) WHERE __fugue_take_param<={n}"
return self._sql(sql, {tb.key: df})
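    # For example (hedged; the PartitionSpec usage is an assumption), take(df, 1,
    # presort="b desc", partition_spec=PartitionSpec(by=["a"])) keeps, for each
    # value of a, the single row with the largest b via the ROW_NUMBER window.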
def load_df(
self,
path: Union[str, List[str]],
format_hint: Any = None,
columns: Any = None,
**kwargs: Any,
) -> LocalBoundedDataFrame:
dio = DuckDBIO(self.connection)
return dio.load_df(path, format_hint, columns, **kwargs)
def save_df(
self,
df: DataFrame,
path: str,
format_hint: Any = None,
mode: str = "overwrite",
partition_spec: Optional[PartitionSpec] = None,
force_single: bool = False,
**kwargs: Any,
) -> None:
partition_spec = partition_spec or PartitionSpec()
if not partition_spec.empty and not force_single:
kwargs["partition_cols"] = partition_spec.partition_by
dio = DuckDBIO(self.connection)
dio.save_df(_to_duck_df(self, df), path, format_hint, mode, **kwargs)
def convert_yield_dataframe(self, df: DataFrame, as_local: bool) -> DataFrame:
if as_local:
return df.as_local()
        # self._ctx_count <= 1 means it is either not in a context or in the top
        # level context where the engine was constructed
return df.as_local() if self._ctx_count <= 1 and not self._external_con else df
def _sql(self, sql: str, dfs: Dict[str, DataFrame]) -> DuckDataFrame:
with self._context_lock:
df = self.sql_engine.select(
DataFrames(dfs), StructuredRawSQL.from_expr(sql, dialect=None)
)
return DuckDataFrame(df.native) # type: ignore
def _to_duck_df(
engine: DuckExecutionEngine, df: Any, schema: Any = None, create_view: bool = False
) -> DuckDataFrame:
def _gen_duck() -> DuckDataFrame:
if isinstance(df, DuckDBPyRelation):
assert_or_throw(
schema is None,
ValueError("schema must be None when df is a DuckDBPyRelation"),
)
return DuckDataFrame(df)
if isinstance(df, DataFrame):
assert_or_throw(
schema is None,
ValueError("schema must be None when df is a DataFrame"),
)
if isinstance(df, DuckDataFrame):
return df
rdf = DuckDataFrame(
duckdb.from_arrow(df.as_arrow(), connection=engine.connection)
)
rdf.reset_metadata(df.metadata if df.has_metadata else None)
return rdf
tdf = ArrowDataFrame(df, schema)
return DuckDataFrame(
duckdb.from_arrow(tdf.native, connection=engine.connection)
)
res = _gen_duck()
if create_view:
with engine._context_lock:
if res.alias not in engine._registered_dfs:
res.native.create_view(res.alias, replace=True)
                # must hold a reference to the df so its id will not be reused
engine._registered_dfs[res.alias] = res
return res
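# An end-to-end usage sketch (hedged; the pandas import and data are illustrative):
#
#   import pandas as pd
#
#   engine = DuckExecutionEngine({"fugue.duckdb.pragma.threads": 2})
#   df = engine.to_df(pd.DataFrame({"a": [1, 2, 2]}))  # -> DuckDataFrame
#   res = engine.persist(engine.distinct(df))          # materialize the result
#   print(res.as_pandas())
#   engine.stop()  # closes the connection because it was created by the engine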