import copy
from typing import Any, Callable, Dict, List, Optional, Type, Union, no_type_check
from triad import ParamDict, Schema
from triad.utils.assertion import assert_arg_not_none, assert_or_throw
from triad.utils.convert import get_caller_global_local_vars, to_function, to_instance
from triad.utils.hash import to_uuid
from fugue._utils.interfaceless import is_class_method, parse_output_schema_from_comment
from fugue._utils.registry import fugue_plugin
from fugue.dataframe import ArrayDataFrame, DataFrame, DataFrames, LocalDataFrame
from fugue.dataframe.function_wrapper import DataFrameFunctionWrapper
from fugue.exceptions import FugueInterfacelessError
from fugue.extensions.transformer.constants import OUTPUT_TRANSFORMER_DUMMY_SCHEMA
from fugue.extensions.transformer.transformer import CoTransformer, Transformer
from .._utils import (
load_namespace_extensions,
parse_validation_rules_from_comment,
to_validation_rules,
)
# String-alias lookup tables: parse_transformer / parse_output_transformer
# return the stored entry when given a string key found here. Nothing in this
# module populates them; presumably entries are registered elsewhere in the
# package (TODO confirm).
_TRANSFORMER_REGISTRY, _OUT_TRANSFORMER_REGISTRY = ParamDict(), ParamDict()
@fugue_plugin
def parse_transformer(obj: Any) -> Any:
    """Resolve ``obj`` into something convertible to a Fugue
    :class:`~fugue.extensions.transformer.transformer.Transformer`.

    This function is a ``fugue_plugin`` extension point: additional parsers
    can be attached with ``@parse_transformer.candidate(predicate)``. The
    default behavior looks a string up in the module-level transformer
    registry and otherwise hands ``obj`` back untouched.

    :param obj: the object to resolve, e.g. a registered alias string
    :return: the registry entry when ``obj`` is a registered string key,
        otherwise ``obj`` itself

    .. admonition:: Examples

        .. code-block:: python

            from fugue import Transformer, parse_transformer, FugueWorkflow
            from triad import to_uuid

            class My(Transformer):
                def __init__(self, x):
                    self.x = x

                ...

                def __uuid__(self) -> str:
                    return to_uuid(super().__uuid__(), self.x)

            @parse_transformer.candidate(
                lambda x: isinstance(x, str) and x.startswith("-*"))
            def _parse(obj):
                return My(obj)

            dag = FugueWorkflow()
            dag.df([[0]], "a:int").transform("-*abc")
            # == dag.df([[0]], "a:int").transform(My("-*abc"))

            dag.run()
    """
    if not isinstance(obj, str) or obj not in _TRANSFORMER_REGISTRY:
        return obj
    return _TRANSFORMER_REGISTRY[obj]
@fugue_plugin
def parse_output_transformer(obj: Any) -> Any:
    """Parse an object to another object that can be converted to a Fugue
    :class:`~fugue.extensions.transformer.transformer.OutputTransformer`.

    This function is a ``fugue_plugin`` extension point: additional parsers
    can be attached with ``@parse_output_transformer.candidate(predicate)``.
    By default a string that is a key of the module-level output transformer
    registry is replaced by its registered entry; anything else is returned
    unchanged.

    :param obj: the object to parse, e.g. a registered alias string
    :return: the registry entry when ``obj`` is a registered string key,
        otherwise ``obj`` itself

    .. admonition:: Examples

        .. code-block:: python

            from fugue import OutputTransformer, parse_output_transformer
            from fugue import FugueWorkflow
            from triad import to_uuid

            class My(OutputTransformer):
                def __init__(self, x):
                    self.x = x

                ...

                def __uuid__(self) -> str:
                    return to_uuid(super().__uuid__(), self.x)

            @parse_output_transformer.candidate(
                lambda x: isinstance(x, str) and x.startswith("-*"))
            def _parse(obj):
                return My(obj)

            dag = FugueWorkflow()
            dag.df([[0]], "a:int").out_transform("-*abc")
            # == dag.df([[0]], "a:int").out_transform(My("-*abc"))

            dag.run()
    """
    if isinstance(obj, str) and obj in _OUT_TRANSFORMER_REGISTRY:
        return _OUT_TRANSFORMER_REGISTRY[obj]
    return obj
class _FuncAsTransformer(Transformer):
    """Adapter exposing a plain Python function as a Fugue ``Transformer``.

    Instances are created by :meth:`from_func`. It wraps the function in a
    :class:`DataFrameFunctionWrapper` and stores the following on the
    instance:

    - the output schema expression
    - the validation rules
    - the callback flags
    - the format hint
    """

    def validate_on_compile(self) -> None:
        # Run the base-class checks, then make sure a required callback exists.
        super().validate_on_compile()
        _validate_callback(self)

    def get_output_schema(self, df: DataFrame) -> Any:
        return self._parse_schema(self._output_schema_arg, df)  # type: ignore

    def get_format_hint(self) -> Optional[str]:
        return self._format_hint  # type: ignore

    @property
    def validation_rules(self) -> Dict[str, Any]:
        return self._validation_rules  # type: ignore

    @no_type_check
    def transform(self, df: LocalDataFrame) -> LocalDataFrame:
        # The wrapped function gets the dataframe first, followed by the
        # callback when the function declares one (see _get_callback).
        inputs = [df, *_get_callback(self)]
        return self._wrapper.run(
            inputs,
            self.params,
            ignore_unknown=False,
            output_schema=self.output_schema,
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Calling the transformer object calls the wrapped function directly.
        return self._wrapper(*args, **kwargs)  # type: ignore

    @no_type_check
    def __uuid__(self) -> str:
        wrapper_id = self._wrapper.__uuid__()
        return to_uuid(wrapper_id, self._output_schema_arg)

    def _parse_schema(self, obj: Any, df: DataFrame) -> Schema:
        """Turn the stored schema argument into the concrete output schema.

        A callable is invoked as ``obj(df, **params)``. A string, or a list of
        strings, is applied as a transformation of the input schema.
        """
        if callable(obj):
            return obj(df, **self.params)
        if isinstance(obj, str):
            obj = [obj]  # normalize so both forms share the list code path
        if isinstance(obj, list):
            return df.schema.transform(*obj)
        raise NotImplementedError  # pragma: no cover

    @staticmethod
    def from_func(
        func: Callable, schema: Any, validation_rules: Dict[str, Any]
    ) -> "_FuncAsTransformer":
        """Build a transformer around ``func``.

        :param func: the function to wrap; its signature must match the input
            and output code patterns passed to ``DataFrameFunctionWrapper``
        :param schema: output schema expression; when ``None`` it is read from
            ``func``'s comment
        :param validation_rules: validation rules; NOTE this dict is updated
            in place with rules parsed from ``func``'s comment
        :return: the configured transformer
        :raises: if no schema can be determined, or if ``func`` does not fit
            the wrapper's signature patterns
        """
        if schema is None:
            schema = parse_output_schema_from_comment(func)
        if isinstance(schema, Schema):  # to be less strict on determinism
            schema = str(schema)
        validation_rules.update(parse_validation_rules_from_comment(func))
        assert_arg_not_none(schema, "schema")
        result = _FuncAsTransformer()
        wrapper = DataFrameFunctionWrapper(  # type: ignore
            func, "^[lspqr][fF]?x*z?$", "^[lspqr]$"
        )
        code = wrapper.input_code
        result._wrapper = wrapper  # type: ignore
        result._output_schema_arg = schema  # type: ignore
        result._validation_rules = validation_rules  # type: ignore
        # "f"/"F" in the input code marks a callback parameter; uppercase "F"
        # means the callback is mandatory.
        result._uses_callback = "f" in code.lower()  # type: ignore
        result._requires_callback = "F" in code  # type: ignore
        result._format_hint = wrapper.get_format_hint()  # type: ignore
        return result
class _FuncAsOutputTransformer(_FuncAsTransformer):
    """Adapter exposing a plain Python function as an output transformer.

    The wrapped function is run only for its effect. Its return value is not
    used, and ``transform`` always returns an empty dataframe with
    ``OUTPUT_TRANSFORMER_DUMMY_SCHEMA``.

    ``validate_on_compile`` and ``get_format_hint`` are inherited from
    ``_FuncAsTransformer``. Its ``validate_on_compile`` already runs
    ``_validate_callback``, so re-declaring it here would only repeat the same
    check.
    """

    def get_output_schema(self, df: DataFrame) -> Any:
        return OUTPUT_TRANSFORMER_DUMMY_SCHEMA

    @no_type_check
    def transform(self, df: LocalDataFrame) -> LocalDataFrame:
        args = [df] + _get_callback(self)
        # output=False: the function's result is not converted or returned
        self._wrapper.run(args, self.params, ignore_unknown=False, output=False)
        return ArrayDataFrame([], OUTPUT_TRANSFORMER_DUMMY_SCHEMA)

    @staticmethod
    def from_func(
        func: Callable, schema: Any, validation_rules: Dict[str, Any]
    ) -> "_FuncAsOutputTransformer":
        """Build an output transformer around ``func``.

        :param func: the function to wrap; its output code may also be ``n``
            (in addition to what a regular transformer accepts)
        :param schema: must be ``None``; output transformers have no schema
        :param validation_rules: validation rules; NOTE this dict is updated
            in place with rules parsed from ``func``'s comment
        :return: the configured output transformer
        """
        assert_or_throw(schema is None, "schema must be None for output transformers")
        validation_rules.update(parse_validation_rules_from_comment(func))
        tr = _FuncAsOutputTransformer()
        tr._wrapper = DataFrameFunctionWrapper(  # type: ignore
            func, "^[lspqr][fF]?x*z?$", "^[lspnqr]$"
        )
        tr._output_schema_arg = None  # type: ignore
        tr._validation_rules = validation_rules  # type: ignore
        tr._uses_callback = "f" in tr._wrapper.input_code.lower()  # type: ignore
        tr._requires_callback = "F" in tr._wrapper.input_code  # type: ignore
        tr._format_hint = tr._wrapper.get_format_hint()  # type: ignore
        return tr
class _FuncAsCoTransformer(CoTransformer):
    """Adapter exposing a plain Python function as a Fugue ``CoTransformer``.

    The wrapped function takes either a single ``DataFrames`` argument (input
    code starting with ``c``) or one argument per dataframe. Validation rules
    are not supported.
    """

    def validate_on_compile(self) -> None:
        super().validate_on_compile()
        _validate_callback(self)

    def get_output_schema(self, dfs: DataFrames) -> Any:
        return self._parse_schema(self._output_schema_arg, dfs)  # type: ignore

    def get_format_hint(self) -> Optional[str]:
        return self._format_hint  # type: ignore

    @property
    def validation_rules(self) -> Dict[str, Any]:
        # from_func always stores a plain empty dict, because validation rules
        # are rejected for cotransformers. The annotation therefore matches
        # _FuncAsTransformer instead of claiming a ParamDict.
        return self._validation_rules  # type: ignore

    @no_type_check
    def transform(self, dfs: DataFrames) -> LocalDataFrame:
        cb = _get_callback(self)
        if self._dfs_input:  # function has a DataFrames input
            args, params = [dfs] + cb, self.params
        elif not dfs.has_key:  # input has no keys: pass dataframes positionally
            args, params = list(dfs.values()) + cb, self.params
        else:
            # Input has keys: pass dataframes as named params. Entries in
            # self.params override dataframes that share the same name.
            params = dict(dfs)
            params.update(self.params)
            args = list(cb)
        return self._wrapper.run(
            args, params, ignore_unknown=False, output_schema=self.output_schema
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self._wrapper(*args, **kwargs)  # type: ignore

    @no_type_check
    def __uuid__(self) -> str:
        return to_uuid(
            self._wrapper.__uuid__(), self._output_schema_arg, self._dfs_input
        )

    def _parse_schema(self, obj: Any, dfs: DataFrames) -> Schema:
        """Build the output schema from the stored schema argument.

        The argument is handled by type:

        - a callable is invoked as ``obj(dfs, **params)``
        - a list is parsed element by element and the results are concatenated
        - anything else is passed to ``Schema``
        """
        if callable(obj):
            return obj(dfs, **self.params)
        if isinstance(obj, str):
            return Schema(obj)
        if isinstance(obj, list):
            s = Schema()
            for x in obj:
                s += self._parse_schema(x, dfs)
            return s
        return Schema(obj)

    @staticmethod
    def from_func(
        func: Callable, schema: Any, validation_rules: Dict[str, Any]
    ) -> "_FuncAsCoTransformer":
        """Build a cotransformer around ``func``.

        :param func: the function to wrap
        :param schema: output schema; when ``None`` it is read from ``func``'s
            comment. ``*`` is not allowed in a string schema.
        :param validation_rules: must be empty
        :return: the configured cotransformer
        :raises NotImplementedError: if ``validation_rules`` is not empty
        :raises FugueInterfacelessError: if a string schema contains ``*``
        """
        assert_or_throw(
            len(validation_rules) == 0,
            NotImplementedError("CoTransformer does not support validation rules"),
        )
        if schema is None:
            schema = parse_output_schema_from_comment(func)
        if isinstance(schema, Schema):  # to be less strict on determinism
            schema = str(schema)
        if isinstance(schema, str):
            assert_or_throw(
                "*" not in schema,
                FugueInterfacelessError(
                    "* can't be used on cotransformer output schema"
                ),
            )
        assert_arg_not_none(schema, "schema")
        tr = _FuncAsCoTransformer()
        tr._wrapper = DataFrameFunctionWrapper(  # type: ignore
            func, "^(c|[lspq]+)[fF]?x*z?$", "^[lspqr]$"
        )
        tr._dfs_input = tr._wrapper.input_code[0] == "c"  # type: ignore
        tr._output_schema_arg = schema  # type: ignore
        tr._validation_rules = {}  # type: ignore
        tr._uses_callback = "f" in tr._wrapper.input_code.lower()  # type: ignore
        tr._requires_callback = "F" in tr._wrapper.input_code  # type: ignore
        tr._format_hint = tr._wrapper.get_format_hint()  # type: ignore
        return tr
class _FuncAsOutputCoTransformer(_FuncAsCoTransformer):
    """Adapter exposing a plain Python function as an output cotransformer.

    The wrapped function is run only for its effect. Its return value is not
    used, and ``transform`` always returns an empty dataframe with
    ``OUTPUT_TRANSFORMER_DUMMY_SCHEMA``.

    ``validate_on_compile`` and ``get_format_hint`` are inherited from
    ``_FuncAsCoTransformer``. Its ``validate_on_compile`` already runs
    ``_validate_callback``, so re-declaring it here would only repeat the same
    check.
    """

    def get_output_schema(self, dfs: DataFrames) -> Any:
        return OUTPUT_TRANSFORMER_DUMMY_SCHEMA

    @no_type_check
    def transform(self, dfs: DataFrames) -> LocalDataFrame:
        cb = _get_callback(self)
        if self._dfs_input:  # function has a DataFrames input
            args, params = [dfs] + cb, self.params
        elif not dfs.has_key:  # input has no keys: pass dataframes positionally
            args, params = list(dfs.values()) + cb, self.params
        else:
            # Input has keys: pass dataframes as named params. Entries in
            # self.params override dataframes that share the same name.
            params = dict(dfs)
            params.update(self.params)
            args = list(cb)
        # output=False: the function's result is not converted or returned
        self._wrapper.run(args, params, ignore_unknown=False, output=False)
        return ArrayDataFrame([], OUTPUT_TRANSFORMER_DUMMY_SCHEMA)

    @staticmethod
    def from_func(
        func: Callable, schema: Any, validation_rules: Dict[str, Any]
    ) -> "_FuncAsOutputCoTransformer":
        """Build an output cotransformer around ``func``.

        :param func: the function to wrap; its output code may also be ``n``
        :param schema: must be ``None``
        :param validation_rules: must be empty
        :return: the configured output cotransformer
        :raises NotImplementedError: if ``validation_rules`` is not empty
        """
        assert_or_throw(schema is None, "schema must be None for output cotransformers")
        assert_or_throw(
            len(validation_rules) == 0,
            NotImplementedError("CoTransformer does not support validation rules"),
        )
        tr = _FuncAsOutputCoTransformer()
        tr._wrapper = DataFrameFunctionWrapper(  # type: ignore
            func, "^(c|[lspq]+)[fF]?x*z?$", "^[lspnqr]$"
        )
        tr._dfs_input = tr._wrapper.input_code[0] == "c"  # type: ignore
        tr._output_schema_arg = None  # type: ignore
        tr._validation_rules = {}  # type: ignore
        tr._uses_callback = "f" in tr._wrapper.input_code.lower()  # type: ignore
        tr._requires_callback = "F" in tr._wrapper.input_code  # type: ignore
        tr._format_hint = tr._wrapper.get_format_hint()  # type: ignore
        return tr
def _to_transformer(
    obj: Any,
    schema: Any = None,
    global_vars: Optional[Dict[str, Any]] = None,
    local_vars: Optional[Dict[str, Any]] = None,
    validation_rules: Optional[Dict[str, Any]] = None,
) -> Union[Transformer, CoTransformer]:
    """Convert ``obj`` into a ``Transformer`` or ``CoTransformer``.

    Namespace extensions for ``obj`` are loaded first. ``obj`` is then run
    through :func:`parse_transformer` and converted by
    ``_to_general_transformer``. Plain functions are wrapped as
    ``_FuncAsTransformer`` or ``_FuncAsCoTransformer``.

    :param obj: a transformer/cotransformer instance, class, function, or a
        string resolvable to one of those
    :param schema: output schema used when wrapping a plain function
    :param global_vars: globals used to resolve string expressions
    :param local_vars: locals used to resolve string expressions
    :param validation_rules: optional validation rules for function wrappers
    :return: the converted transformer
    """
    # Keep this call directly in this function body. Judging by its name the
    # helper inspects the caller's stack frame (TODO confirm), so moving it
    # into a nested scope could change which frame it resolves.
    global_vars, local_vars = get_caller_global_local_vars(global_vars, local_vars)
    load_namespace_extensions(obj)
    candidate = parse_transformer(obj)
    return _to_general_transformer(
        candidate,
        schema,
        global_vars,
        local_vars,
        validation_rules,
        func_transformer_type=_FuncAsTransformer,
        func_cotransformer_type=_FuncAsCoTransformer,
    )
def _to_output_transformer(
    obj: Any,
    global_vars: Optional[Dict[str, Any]] = None,
    local_vars: Optional[Dict[str, Any]] = None,
    validation_rules: Optional[Dict[str, Any]] = None,
) -> Union[Transformer, CoTransformer]:
    """Convert ``obj`` into an output transformer or output cotransformer.

    This mirrors ``_to_transformer`` with three differences:

    - ``obj`` is parsed by :func:`parse_output_transformer`
    - no output schema is accepted
    - plain functions are wrapped as ``_FuncAsOutputTransformer`` or
      ``_FuncAsOutputCoTransformer``

    :param obj: an instance, class, function, or a string resolvable to one
    :param global_vars: globals used to resolve string expressions
    :param local_vars: locals used to resolve string expressions
    :param validation_rules: optional validation rules for function wrappers
    :return: the converted output transformer
    """
    # Keep this call directly in this function body. Judging by its name the
    # helper inspects the caller's stack frame (TODO confirm).
    global_vars, local_vars = get_caller_global_local_vars(global_vars, local_vars)
    load_namespace_extensions(obj)
    candidate = parse_output_transformer(obj)
    return _to_general_transformer(
        candidate,
        None,
        global_vars,
        local_vars,
        validation_rules,
        func_transformer_type=_FuncAsOutputTransformer,
        func_cotransformer_type=_FuncAsOutputCoTransformer,
    )
def _to_general_transformer(  # noqa: C901
    obj: Any,
    schema: Any,
    global_vars: Optional[Dict[str, Any]],
    local_vars: Optional[Dict[str, Any]],
    validation_rules: Optional[Dict[str, Any]],
    func_transformer_type: Type,
    func_cotransformer_type: Type,
) -> Union[Transformer, CoTransformer]:
    """Convert ``obj`` into a ``Transformer`` or ``CoTransformer``.

    Strategies are tried in this order; the first that succeeds wins:

    1. ``obj`` as a ``Transformer`` instance/class (via ``to_instance``)
    2. ``obj`` as a ``CoTransformer`` instance/class
    3. ``obj`` as a function (via ``to_function``). The result is either
       already a ``Transformer``, or is wrapped with
       ``func_transformer_type.from_func``.
    4. the same function treated as a ``CoTransformer``, or wrapped with
       ``func_cotransformer_type.from_func``

    Instances are returned as shallow copies (``copy.copy``). The same
    ``validation_rules`` dict is passed to both ``from_func`` attempts.

    :param obj: the object to convert
    :param schema: output schema forwarded to ``from_func``
    :param global_vars: globals used to resolve string expressions
    :param local_vars: locals used to resolve string expressions
    :param validation_rules: validation rules; ``None`` means no rules
    :param func_transformer_type: wrapper class used for single-dataframe
        functions
    :param func_cotransformer_type: wrapper class used for multi-dataframe
        functions
    :return: the converted transformer
    :raises FugueInterfacelessError: if every strategy fails. The exception
        from the last failed attempt is attached and chained as the cause.
    """
    # Called directly in this body: the helper may inspect the call stack.
    global_vars, local_vars = get_caller_global_local_vars(global_vars, local_vars)
    if validation_rules is None:
        validation_rules = {}
    exp: Optional[Exception] = None

    # 1 & 2: obj is (or resolves to) a transformer/cotransformer instance
    for base in (Transformer, CoTransformer):
        try:
            return copy.copy(
                to_instance(obj, base, global_vars=global_vars, local_vars=local_vars)
            )
        except Exception as e:
            exp = e

    # 3 & 4: obj is (or resolves to) a function. Resolve it once; if that
    # fails, no function-based strategy can succeed.
    try:
        f = to_function(obj, global_vars=global_vars, local_vars=local_vars)
    except Exception as e:
        raise FugueInterfacelessError(f"{obj} is not a valid transformer", e) from e

    attempts = (
        (Transformer, func_transformer_type),
        (CoTransformer, func_cotransformer_type),
    )
    for base, func_type in attempts:
        try:
            # string expression of a function with decorator
            if isinstance(f, base):
                return copy.copy(f)
            # function without decorator
            return func_type.from_func(f, schema, validation_rules=validation_rules)
        except Exception as e:
            exp = e
    raise FugueInterfacelessError(f"{obj} is not a valid transformer", exp) from exp
def _validate_callback(ctx: Any) -> None:
if ctx._requires_callback:
assert_or_throw(
ctx.has_callback,
FugueInterfacelessError(f"Callback is required but not provided: {ctx}"),
)
def _get_callback(ctx: Any) -> List[Any]:
uses_callback = ctx._uses_callback
requires_callback = ctx._requires_callback
if not uses_callback:
return []
if requires_callback:
assert_or_throw(
ctx.has_callback,
FugueInterfacelessError(f"Callback is required but not provided: {ctx}"),
)
return [ctx.callback]
return [ctx.callback if ctx.has_callback else None]