Monads/Writer monad: Difference between revisions

Content added Content deleted
m (Added language identifier.)
m (→‎{{header|Python}}: Complete type hints)
Line 1,128: Line 1,128:
=={{header|Python}}==
=={{header|Python}}==


<syntaxhighlight lang="python">"""A Writer Monad. Requires Python >= 3.7 for type hints."""
<syntaxhighlight lang="python">
"""A Writer Monad. Requires Python >= 3.7 for type hints."""
from __future__ import annotations
from __future__ import annotations


Line 1,135: Line 1,136:
import os
import os


from typing import Any
from typing import Callable
from typing import Callable
from typing import Generic
from typing import Generic
Line 1,144: Line 1,144:


T = TypeVar("T")
T = TypeVar("T")
U = TypeVar("U")




Line 1,155: Line 1,156:
self.msgs = list(f"{msg}: {self.value}" for msg in msgs)
self.msgs = list(f"{msg}: {self.value}" for msg in msgs)


def bind(self, func: Callable[[T], Writer[Any]]) -> Writer[Any]:
def bind(self, func: Callable[[T], Writer[U]]) -> Writer[U]:
writer = func(self.value)
writer = func(self.value)
return Writer(writer, *self.msgs)
return Writer(writer, *self.msgs)


def __rshift__(self, func: Callable[[T], Writer[Any]]) -> Writer[Any]:
def __rshift__(self, func: Callable[[T], Writer[U]]) -> Writer[U]:
return self.bind(func)
return self.bind(func)


Line 1,169: Line 1,170:




def lift(func: Callable, msg: str) -> Callable[[Any], Writer[Any]]:
def lift(func: Callable[[T], U], msg: str) -> Callable[[T], Writer[U]]:
"""Return a writer monad version of the simple function `func`."""
"""Return a writer monad version of the simple function `func`."""


@functools.wraps(func)
@functools.wraps(func)
def wrapped(value):
def wrapped(value: T) -> Writer[U]:
return Writer(func(value), msg)
return Writer(func(value), msg)


Line 1,181: Line 1,182:
if __name__ == "__main__":
if __name__ == "__main__":
square_root = lift(math.sqrt, "square root")
square_root = lift(math.sqrt, "square root")

add_one = lift(lambda x: x + 1, "add one")
add_one: Callable[[Union[int, float]], Writer[Union[int, float]]] = lift(
half = lift(lambda x: x / 2, "div two")
lambda x: x + 1, "add one"
)

half: Callable[[Union[int, float]], Writer[float]] = lift(
lambda x: x / 2, "div two"
)


print(Writer(5, "initial") >> square_root >> add_one >> half)
print(Writer(5, "initial") >> square_root >> add_one >> half)