ebay_client/.venv_codegen/Lib/site-packages/more_itertools/recipes.py

1476 lines
41 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Imported from the recipes section of the itertools documentation.
All functions taken from the recipes section of the itertools library docs
[1]_.
Some backward-compatible usability improvements have been made.
.. [1] https://docs.python.org/3/library/itertools.html#itertools-recipes
"""
import random
from bisect import bisect_left, insort
from collections import deque
from contextlib import suppress
from functools import lru_cache, reduce
from heapq import heappush, heappushpop
from itertools import (
accumulate,
chain,
combinations,
compress,
count,
cycle,
filterfalse,
groupby,
islice,
pairwise as itertools_pairwise,
product,
repeat,
starmap,
takewhile,
tee,
zip_longest,
)
from math import prod, comb, isqrt, gcd
from operator import mul, getitem, index, is_, itemgetter, truediv
from random import randrange, sample, choice, shuffle
from sys import hexversion
# Names exported by a star-import of this module.
__all__ = [
    'all_equal',
    'batched',
    'before_and_after',
    'consume',
    'convolve',
    'dotproduct',
    'first_true',
    'factor',
    'flatten',
    'grouper',
    'is_prime',
    'iter_except',
    'iter_index',
    'loops',
    'matmul',
    'multinomial',
    'ncycles',
    'nth',
    'nth_combination',
    'padnone',
    'pad_none',
    'pairwise',
    'partition',
    'polynomial_eval',
    'polynomial_from_roots',
    'polynomial_derivative',
    'powerset',
    'prepend',
    'quantify',
    'reshape',
    'random_combination_with_replacement',
    'random_combination',
    'random_derangement',
    'random_permutation',
    'random_product',
    'repeatfunc',
    'roundrobin',
    'running_mean',
    'running_median',
    'sieve',
    'sliding_window',
    'subslices',
    'sum_of_squares',
    'tabulate',
    'tail',
    'take',
    'totient',
    'transpose',
    'triplewise',
    'unique',
    'unique_everseen',
    'unique_justseen',
]
# Unique module-private object.
# NOTE(review): presumably a "no value supplied" sentinel compared by
# identity; it is not referenced in the visible part of this module.
_marker = object()

# heapq max-heap functions are available for Python 3.14+.
# ``_max_heap_available`` records whether the import succeeded.
try:
    from heapq import heappush_max, heappushpop_max
except ImportError:  # pragma: no cover
    _max_heap_available = False
else:  # pragma: no cover
    _max_heap_available = True
def take(n, iterable):
    """Collect up to *n* leading items of *iterable* into a list.

        >>> take(3, range(10))
        [0, 1, 2]

    When *iterable* holds fewer than *n* items, every item is returned.

        >>> take(10, range(3))
        [0, 1, 2]

    """
    leading = islice(iterable, n)
    return list(leading)
def tabulate(function, start=0):
    """Return an iterator over ``function(start)``, ``function(start + 1)``,
    ``function(start + 2)``, and so on without end.

    *function* must accept a single integer argument. *start* defaults to 0
    and is incremented by one each time the iterator is advanced.

        >>> square = lambda x: x ** 2
        >>> iterator = tabulate(square, -3)
        >>> take(4, iterator)
        [9, 4, 1, 0]

    """
    arguments = count(start)
    return map(function, arguments)
def tail(n, iterable):
    """Return an iterator over the final *n* items of *iterable*.

        >>> t = tail(3, 'ABCDEFG')
        >>> list(t)
        ['E', 'F', 'G']

    """
    try:
        size = len(iterable)
    except TypeError:
        # No length available: stream everything through a bounded deque,
        # which retains only the last n items.
        last_items = deque(iterable, maxlen=n)
        return iter(last_items)
    # Length known: skip directly to the start of the tail.
    skip = max(0, size - n)
    return islice(iterable, skip, None)
def consume(iterator, n=None):
    """Advance *iterator* by *n* steps, or exhaust it when *n* is ``None``.

    Items are discarded rather than returned.

        >>> i = (x for x in range(10))
        >>> next(i)
        0
        >>> consume(i, 3)
        >>> next(i)
        4
        >>> consume(i)
        >>> next(i)
        Traceback (most recent call last):
          File "<stdin>", line 1, in <module>
        StopIteration

    When fewer than *n* items remain, the whole iterator is consumed.

        >>> i = (x for x in range(3))
        >>> consume(i, 5)
        >>> next(i)
        Traceback (most recent call last):
          File "<stdin>", line 1, in <module>
        StopIteration

    """
    # Both branches rely on helpers that iterate at C speed.
    if n is not None:
        # The empty slice [n:n] forces islice to skip n items.
        next(islice(iterator, n, n), None)
        return
    # A zero-length deque discards every item it is fed.
    deque(iterator, maxlen=0)
def nth(iterable, n, default=None):
    """Return the item at position *n* of *iterable*, or *default* when the
    iterable is too short.

        >>> l = range(10)
        >>> nth(l, 3)
        3
        >>> nth(l, 20, "zebra")
        'zebra'

    """
    from_nth = islice(iterable, n, None)
    return next(from_nth, default)
def all_equal(iterable, key=None):
    """Return ``True`` when every element equals every other element.

        >>> all_equal('aaaa')
        True
        >>> all_equal('aaab')
        False

    *key*, if given, is a one-argument function applied to each item before
    comparison:

        >>> all_equal('AaaA', key=str.casefold)
        True
        >>> all_equal([1, 2, 3], key=lambda x: x < 10)
        True

    """
    runs = groupby(iterable, key)
    # Discard the first run of equal items, if there is one.
    next(runs, None)
    # groupby yields tuples, never None, so None here means "no second run".
    return next(runs, None) is None
def quantify(iterable, pred=bool):
    """Return how many times the predicate is true.

        >>> quantify([True, False, True])
        2

    """
    return sum(pred(item) for item in iterable)
def pad_none(iterable):
    """Yield the elements of *iterable*, then ``None`` indefinitely.

        >>> take(5, pad_none(range(3)))
        [0, 1, 2, None, None]

    Useful for emulating the behavior of the built-in :func:`map` function.

    See also :func:`padded`.

    """
    endless_nones = repeat(None)
    return chain(iterable, endless_nones)


# Alternate spelling of pad_none.
padnone = pad_none
def ncycles(iterable, n):
    """Return the elements of *iterable* repeated *n* times in sequence.

        >>> list(ncycles(["a", "b"], 3))
        ['a', 'b', 'a', 'b', 'a', 'b']

    """
    # Materialize once so that one-shot iterators can be replayed.
    cached = tuple(iterable)
    return chain.from_iterable(repeat(cached, n))
def dotproduct(vec1, vec2):
    """Return the dot product of two iterables.

        >>> dotproduct([10, 15, 12], [0.65, 0.80, 1.25])
        33.5
        >>> 10 * 0.65 + 15 * 0.80 + 12 * 1.25
        33.5

    In Python 3.12 and later, use ``math.sumprod()`` instead.

    """
    return sum(a * b for a, b in zip(vec1, vec2))
# math.sumprod is available for Python 3.12+; when the import fails,
# fall back to the pure-Python dotproduct() defined in this module.
try:
    from math import sumprod as _sumprod
except ImportError:  # pragma: no cover
    _sumprod = dotproduct
def flatten(list_of_lists):
    """Return an iterator that removes one level of nesting from an
    iterable of iterables.

        >>> list(flatten([[0, 1], [2, 3]]))
        [0, 1, 2, 3]

    See also :func:`collapse`, which can flatten multiple levels of nesting.

    """
    flattened = chain.from_iterable(list_of_lists)
    return flattened
def repeatfunc(function, times=None, *args):
    """Call *function* with *args* over and over, returning an iterator of
    the results.

    When *times* is given, the iterator stops after that many calls:

        >>> from operator import add
        >>> times = 4
        >>> args = 3, 5
        >>> list(repeatfunc(add, times, *args))
        [8, 8, 8, 8]

    When *times* is ``None`` the iterator never terminates:

        >>> from random import randrange
        >>> times = None
        >>> args = 1, 11
        >>> take(6, repeatfunc(randrange, times, *args))  # doctest:+SKIP
        [2, 4, 8, 1, 8, 4]

    """
    argument_stream = repeat(args) if times is None else repeat(args, times)
    return starmap(function, argument_stream)
def pairwise(iterable):
"""
Wrapper for :func:`itertools.pairwise`.
.. warning::
This function is deprecated as of version 11.0.0. It will be removed in a future
major release.
"""
return itertools_pairwise(iterable)
def grouper(iterable, n, incomplete='fill', fillvalue=None):
"""Group elements from *iterable* into fixed-length groups of length *n*.
>>> list(grouper('ABCDEF', 3))
[('A', 'B', 'C'), ('D', 'E', 'F')]
The keyword arguments *incomplete* and *fillvalue* control what happens for
iterables whose length is not a multiple of *n*.
When *incomplete* is `'fill'`, the last group will contain instances of
*fillvalue*.
>>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
[('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
When *incomplete* is `'ignore'`, the last group will not be emitted.
>>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
[('A', 'B', 'C'), ('D', 'E', 'F')]
When *incomplete* is `'strict'`, a `ValueError` will be raised.
>>> iterator = grouper('ABCDEFG', 3, incomplete='strict')
>>> list(iterator) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError
"""
iterators = [iter(iterable)] * n
match incomplete:
case 'fill':
return zip_longest(*iterators, fillvalue=fillvalue)
case 'strict':
return zip(*iterators, strict=True)
case 'ignore':
return zip(*iterators)
case _:
raise ValueError('Expected fill, strict, or ignore')
def roundrobin(*iterables):
    """Yield one item from each iterable in turn, cycling until all of them
    are exhausted.

        >>> list(roundrobin('ABC', 'D', 'EF'))
        ['A', 'D', 'E', 'B', 'F', 'C']

    This function produces the same output as :func:`interleave_longest`, but
    may perform better for some inputs (in particular when the number of
    iterables is small).

    """
    # Algorithm credited to George Sakkis
    pending = map(iter, iterables)
    for still_active in reversed(range(1, len(iterables) + 1)):
        # Cycle over the iterators believed to be live. The inner map stops
        # as soon as one of them runs dry; the next pass then re-slices the
        # cycle to one fewer iterator.
        pending = cycle(islice(pending, still_active))
        yield from map(next, pending)
def partition(pred, iterable):
    """Split *iterable* lazily into a pair of iterators ``(falses, trues)``.

    The first yields the items for which ``pred(item)`` is false, the second
    the items for which it is true.

        >>> is_odd = lambda x: x % 2 != 0
        >>> iterable = range(10)
        >>> even_items, odd_items = partition(is_odd, iterable)
        >>> list(even_items), list(odd_items)
        ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])

    If *pred* is None, :func:`bool` is used.

        >>> iterable = [0, 1, False, True, '', ' ']
        >>> false_items, true_items = partition(None, iterable)
        >>> list(false_items), list(true_items)
        ([0, False, ''], [1, True, ' '])

    """
    if pred is None:
        pred = bool

    source = iter(iterable)
    rejected = deque()
    accepted = deque()

    def drain(own_queue):
        # Serve buffered items first; when the buffer is empty, pull ONE item
        # from the shared source and route it to whichever queue it belongs.
        while True:
            while own_queue:
                yield own_queue.popleft()
            try:
                item = next(source)
            except StopIteration:
                return
            destination = accepted if pred(item) else rejected
            destination.append(item)

    return drain(rejected), drain(accepted)
def powerset(iterable):
    """Yield every subset of *iterable* as a tuple, smallest subsets first.

        >>> list(powerset([1, 2, 3]))
        [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]

    :func:`powerset` will operate on iterables that aren't :class:`set`
    instances, so repeated elements in the input will produce repeated
    elements in the output.

        >>> seq = [1, 1, 0]
        >>> list(powerset(seq))
        [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]

    For a variant that efficiently yields actual :class:`set` instances, see
    :func:`powerset_of_sets`.

    """
    pool = list(iterable)
    sizes = range(len(pool) + 1)
    per_size = map(combinations, repeat(pool), sizes)
    return chain.from_iterable(per_size)
def unique_everseen(iterable, key=None):
    """Yield each distinct element once, in first-seen order, remembering
    every element ever seen.

        >>> list(unique_everseen('AAAABBBCCDAABBB'))
        ['A', 'B', 'C', 'D']
        >>> list(unique_everseen('ABBCcAD', str.casefold))
        ['A', 'B', 'C', 'D']

    Raises ``TypeError`` for unhashable items.

    Some unhashable objects can be converted to hashable objects
    using the *key* parameter:

    * For ``list`` objects, try ``key=tuple``.
    * For ``set`` objects, try ``key=frozenset``.
    * For ``dict`` objects, try ``key=lambda x: frozenset(x.items())``
      or in Python 3.15 and later, set ``key=frozendict``.

    Alternatively, consider the ``unique()`` itertool recipe. It sorts
    the data and then uses equality to eliminate duplicates. Hashability
    is not required.

    """
    seen = set()
    remember = seen.add
    if key is not None:
        for element in iterable:
            fingerprint = key(element)
            if fingerprint in seen:
                continue
            remember(fingerprint)
            yield element
        return
    # No key: let filterfalse run the membership test.
    already_seen = seen.__contains__
    for element in filterfalse(already_seen, iterable):
        remember(element)
        yield element
def unique_justseen(iterable, key=None):
    """Yield elements in order, collapsing runs of consecutive duplicates.

        >>> list(unique_justseen('AAAABBBCCDAABBB'))
        ['A', 'B', 'C', 'D', 'A', 'B']
        >>> list(unique_justseen('ABBCcAD', str.lower))
        ['A', 'B', 'C', 'A', 'D']

    """
    if key is not None:
        # Each group iterator's first item is the run's first original
        # element (before *key* was applied).
        runs = map(itemgetter(1), groupby(iterable, key))
        return map(next, runs)
    # Without a key, the group key is the element itself.
    return map(itemgetter(0), groupby(iterable))
def unique(iterable, key=None, reverse=False):
    """Yield the distinct elements of *iterable* in sorted order.

        >>> list(unique([[1, 2], [3, 4], [1, 2]]))
        [[1, 2], [3, 4]]

    *key* and *reverse* are passed to :func:`sorted`.

        >>> list(unique('ABBcCAD', str.casefold))
        ['A', 'B', 'c', 'D']
        >>> list(unique('ABBcCAD', str.casefold, reverse=True))
        ['D', 'c', 'B', 'A']

    The elements in *iterable* need not be hashable, but they must be
    comparable for sorting to work.

    """
    # Sorting places equal elements next to each other, so collapsing
    # adjacent duplicates removes all duplicates.
    return unique_justseen(
        sorted(iterable, key=key, reverse=reverse),
        key=key,
    )
def iter_except(function, exception, first=None):
    """Yield results of calling *function* repeatedly until *exception* is
    raised.

    Converts a call-until-exception interface to an iterator interface.
    Like ``iter(function, sentinel)``, but uses an exception instead of a
    sentinel to end the loop.

        >>> l = [0, 1, 2]
        >>> list(iter_except(l.pop, IndexError))
        [2, 1, 0]

    Multiple exceptions can be specified as a stopping condition:

        >>> l = [1, 2, 3, '...', 4, 5, 6]
        >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
        [7, 6, 5]
        >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
        [4, 3, 2]
        >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
        []

    When *first* is not ``None`` it is called once, and its result yielded,
    before the first call to *function*.

    """
    with suppress(exception):
        if first is not None:
            yield first()
        # Endless loop; it only ends through a raised exception.
        for _ in repeat(None):
            yield function()
def first_true(iterable, default=None, pred=None):
    """Return the first true value in *iterable*.

    If no true value is found, *default* is returned.

    If *pred* is not None, the first item for which ``pred(item)`` is true
    is returned instead.

        >>> first_true(range(10))
        1
        >>> first_true(range(10), pred=lambda x: x > 5)
        6
        >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
        'missing'

    """
    matches = filter(pred, iterable)
    return next(matches, default)
def random_product(*iterables, repeat=1):
    """Pick one random item from each of the input iterables.

        >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
        ('c', 3, 'Z')

    If *repeat* is provided as a keyword argument, that many items will be
    drawn from each iterable.

        >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
        ('a', 2, 'd', 3)

    This is equivalent to taking a random selection from
    ``itertools.product(*args, repeat=repeat)``.

    """
    pools = tuple(tuple(source) for source in iterables) * repeat
    return tuple(choice(pool) for pool in pools)
def random_permutation(iterable, r=None):
    """Return a random permutation of length *r* drawn from *iterable*.

    If *r* is not specified or is ``None``, then *r* defaults to the length
    of *iterable*.

        >>> random_permutation(range(5))  # doctest:+SKIP
        (3, 4, 0, 1, 2)

    This is equivalent to taking a random selection from
    ``itertools.permutations(iterable, r)``.

    """
    pool = tuple(iterable)
    if r is None:
        r = len(pool)
    return tuple(sample(pool, r))
def random_combination(iterable, r):
    """Return a random subsequence of length *r* drawn from *iterable*.

        >>> random_combination(range(5), 3)  # doctest:+SKIP
        (2, 3, 4)

    This is equivalent to taking a random selection from
    ``itertools.combinations(iterable, r)``.

    """
    pool = tuple(iterable)
    # Sample positions, then sort them so the result keeps input order.
    positions = sample(range(len(pool)), r)
    positions.sort()
    return tuple(map(pool.__getitem__, positions))
def random_combination_with_replacement(iterable, r):
    """Return a random subsequence of length *r* drawn from *iterable*,
    where an individual element may be chosen more than once.

        >>> random_combination_with_replacement(range(3), 5)  # doctest:+SKIP
        (0, 0, 1, 2, 2)

    This is equivalent to taking a random selection from
    ``itertools.combinations_with_replacement(iterable, r)``.

    """
    pool = tuple(iterable)
    size = len(pool)
    positions = [randrange(size) for _ in range(r)]
    positions.sort()
    return tuple(map(pool.__getitem__, positions))
def nth_combination(iterable, r, index):
    """Equivalent to ``list(combinations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`nth_combination` computes the subsequence at
    sort position *index* directly, without computing the previous
    subsequences.

        >>> nth_combination(range(5), 3, 5)
        (0, 3, 4)

    ``ValueError`` will be raised if *r* is negative.
    ``IndexError`` will be raised if the given *index* is invalid.
    A negative *index* counts from the end.

    """
    items = tuple(iterable)
    remaining = len(items)
    total = comb(remaining, r)

    if index < 0:
        index += total
    if index < 0 or index >= total:
        raise IndexError

    picked = []
    while r:
        # Number of combinations that start with the current candidate.
        total = total * r // remaining
        remaining -= 1
        r -= 1
        # Skip whole blocks of combinations until index falls inside one.
        while index >= total:
            index -= total
            total = total * (remaining - r) // remaining
            remaining -= 1
        picked.append(items[-1 - remaining])

    return tuple(picked)
def prepend(value, iterable):
    """Yield *value* first, then every element of *iterable*.

        >>> value = '0'
        >>> iterable = ['1', '2', '3']
        >>> list(prepend(value, iterable))
        ['0', '1', '2', '3']

    To prepend multiple values, see :func:`itertools.chain`
    or :func:`value_chain`.

    """
    head = (value,)
    return chain(head, iterable)
def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    For example, multiplying ``(x² -x - 20)`` by ``(x - 3)``
    gives ``(x³ -4x² -17x + 60)``.

        >>> list(convolve([1, -1, -20], [1, -3]))
        [1, -4, -17, 60]

    Examples of popular kinds of kernels:

    * The kernel ``[0.25, 0.25, 0.25, 0.25]`` computes a moving average.
      For image data, this blurs the image and reduces noise.
    * The kernel ``[1/2, 0, -1/2]`` estimates the first derivative of
      a function evaluated at evenly spaced inputs.
    * The kernel ``[1, -2, 1]`` estimates the second derivative of a
      function evaluated at evenly spaced inputs.

    Convolutions are mathematically commutative; however, the inputs are
    evaluated differently. The signal is consumed lazily and can be
    infinite. The kernel is fully consumed before the calculations begin.

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    References:

    * Article: https://betterexplained.com/articles/intuitive-convolution/
    * Video by 3Blue1Brown: https://www.youtube.com/watch?v=KuXjwB4LzSA

    """
    # This implementation comes from an older version of the itertools
    # documentation. While the newer implementation is a bit clearer,
    # this one was kept because the inlined window logic is faster
    # and it avoids an unnecessary deque-to-tuple conversion.
    flipped = tuple(reversed(tuple(kernel)))
    width = len(flipped)
    # Start with a window of zeros as wide as the kernel.
    window = deque(repeat(0, width), maxlen=width)
    # Trailing zeros flush the last signal values through the window.
    padded_signal = chain(signal, repeat(0, width - 1))
    for sample_value in padded_signal:
        window.append(sample_value)
        yield _sumprod(flipped, window)
def before_and_after(predicate, it):
    """A variant of :func:`takewhile` that keeps the rest of the iterator
    fully accessible.

        >>> it = iter('ABCdEfGhI')
        >>> all_upper, remainder = before_and_after(str.isupper, it)
        >>> ''.join(all_upper)
        'ABC'
        >>> ''.join(remainder)  # takewhile() would lose the 'd'
        'dEfGhI'

    Note that the first iterator must be fully consumed before the second
    iterator can generate valid results.

    """
    leading, remainder = tee(it)
    passing = takewhile(predicate, leading)
    # zip(remainder) yields non-empty (truthy) 1-tuples, so compress() passes
    # every item through while advancing remainder in step with the prefix.
    return compress(passing, zip(remainder)), remainder
def triplewise(iterable):
    """Return overlapping triples of consecutive items from *iterable*.

        >>> list(triplewise('ABCDE'))
        [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]

    """
    # This deviates from the itertools documentation recipe - see
    # https://github.com/more-itertools/more-itertools/issues/889
    first, second, third = tee(iterable, 3)
    # Offset the copies: third leads by two items, second by one.
    for copy in (third, third, second):
        next(copy, None)
    return zip(first, second, third)
def _sliding_window_islice(iterable, n):
# Fast path for small, non-zero values of n.
iterators = tee(iterable, n)
for i, iterator in enumerate(iterators):
next(islice(iterator, i, i), None)
return zip(*iterators)
def _sliding_window_deque(iterable, n):
# Normal path for other values of n.
iterator = iter(iterable)
window = deque(islice(iterator, n - 1), maxlen=n)
for x in iterator:
window.append(x)
yield tuple(window)
def sliding_window(iterable, n):
"""Return a sliding window of width *n* over *iterable*.
>>> list(sliding_window(range(6), 4))
[(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]
If *iterable* has fewer than *n* items, then nothing is yielded:
>>> list(sliding_window(range(3), 4))
[]
For a variant with more features, see :func:`windowed`.
"""
if n > 20:
return _sliding_window_deque(iterable, n)
elif n > 2:
return _sliding_window_islice(iterable, n)
elif n == 2:
return pairwise(iterable)
elif n == 1:
return zip(iterable)
else:
raise ValueError(f'n should be at least one, not {n}')
def subslices(iterable):
    """Return every contiguous non-empty subslice of *iterable* as a list.

        >>> list(subslices('ABC'))
        [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]

    This is similar to :func:`substrings`, but emits items in a different
    order.

    """
    seq = list(iterable)
    # Each pair of distinct cut points (start < stop) defines one subslice.
    cut_points = combinations(range(len(seq) + 1), 2)
    return map(seq.__getitem__, starmap(slice, cut_points))
def polynomial_from_roots(roots):
    """Compute a polynomial's coefficients from its roots.

        >>> roots = [5, -4, 3]            # (x - 5) * (x + 4) * (x - 3)
        >>> polynomial_from_roots(roots)  # x³ - 4 x² - 17 x + 60
        [1, -4, -17, 60]

    Note that polynomial coefficients are specified in descending power order.

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    """
    # This recipe differs from the one in itertools docs in that it
    # applies list() after each call to convolve(). This avoids
    # hitting stack limits with nested generators.
    def multiply_by_factor(poly, root):
        # Multiply the running polynomial by (x - root).
        return list(convolve(poly, (1, -root)))

    return reduce(multiply_by_factor, roots, [1])
def iter_index(iterable, value, start=0, stop=None):
    """Yield the index of each place in *iterable* that *value* occurs,
    beginning with index *start* and ending before index *stop*.

        >>> list(iter_index('AABCADEAF', 'A'))
        [0, 1, 4, 7]
        >>> list(iter_index('AABCADEAF', 'A', 1))  # start index is inclusive
        [1, 4, 7]
        >>> list(iter_index('AABCADEAF', 'A', 1, 7))  # stop index is not inclusive
        [1, 4]

    The behavior for non-scalar *values* matches the built-in Python types.

        >>> list(iter_index('ABCDABCD', 'AB'))
        [0, 4]
        >>> list(iter_index([0, 1, 2, 3, 0, 1, 2, 3], [0, 1]))
        []
        >>> list(iter_index([[0, 1], [2, 3], [0, 1], [2, 3]], [0, 1]))
        [0, 2]

    See :func:`locate` for a more general means of finding the indexes
    associated with particular values.

    """
    find = getattr(iterable, 'index', None)

    if find is not None:
        # Fast path for sequences: call .index() repeatedly, each search
        # starting just past the previous hit, until it raises ValueError.
        if stop is None:
            stop = len(iterable)
        position = start - 1
        with suppress(ValueError):
            while True:
                position = find(value, position + 1, stop)
                yield position
        return

    # Slow path for general iterables
    for position, element in enumerate(islice(iterable, start, stop), start):
        if element is value or element == value:
            yield position
def sieve(n):
    """Yield the primes less than n.

        >>> list(sieve(30))
        [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

    """
    # This implementation comes from an older version of the itertools
    # documentation. The newer implementation is easier to read but is
    # less lazy.
    if n > 2:
        yield 2
    scan_from = 3
    # One flag per integer below n (n rounded down to even): odd positions
    # start at 1 (candidate), even positions at 0.
    flags = bytearray((0, 1)) * (n // 2)
    limit = isqrt(n) + 1
    for p in iter_index(flags, 1, scan_from, stop=limit):
        square = p * p
        # Every candidate still flagged below p*p is prime.
        yield from iter_index(flags, 1, scan_from, square)
        # Clear the odd multiples of p starting at p*p.
        stride = p + p
        flags[square:n:stride] = bytes(len(range(square, n, stride)))
        scan_from = square
    yield from iter_index(flags, 1, scan_from)
def _batched(iterable, n, *, strict=False):  # pragma: no cover
    """Batch data into tuples of length *n*. If the number of items in
    *iterable* is not divisible by *n*:

    * The last batch will be shorter if *strict* is ``False``.
    * :exc:`ValueError` will be raised if *strict* is ``True``.

        >>> list(batched('ABCDEFG', 3))
        [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]

    :exc:`ValueError` is also raised if *n* is less than one.

    On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    while True:
        batch = tuple(islice(source, n))
        if not batch:
            # Source exhausted.
            return
        if strict and len(batch) != n:
            raise ValueError('batched(): incomplete batch')
        yield batch


# 0x30D00A2 is the hexversion of Python 3.13.0a2; from there on the
# public name delegates to itertools.batched.
if hexversion < 0x30D00A2:  # pragma: no cover
    batched = _batched
else:  # pragma: no cover
    from itertools import batched as itertools_batched

    def batched(iterable, n, *, strict=False):
        return itertools_batched(iterable, n, strict=strict)

    batched.__doc__ = _batched.__doc__
def transpose(matrix):
"""Swap the rows and columns of the input matrix.
>>> list(transpose([(1, 2, 3), (11, 22, 33)]))
[(1, 11), (2, 22), (3, 33)]
The caller should ensure that the dimensions of the input are compatible.
If the input is empty, no output will be produced.
"""
return zip(*matrix, strict=True)
def _is_scalar(value, stringlike=(str, bytes)):
"Scalars are bytes, strings, and non-iterables."
try:
iter(value)
except TypeError:
return True
return isinstance(value, stringlike)
def _flatten_tensor(tensor):
"Depth-first iterator over scalars in a tensor."
iterator = iter(tensor)
while True:
try:
value = next(iterator)
except StopIteration:
return iterator
iterator = chain((value,), iterator)
if _is_scalar(value):
return iterator
iterator = chain.from_iterable(iterator)
def reshape(matrix, shape):
    """Change the shape of a *matrix*.

    If *shape* is an integer, the matrix must be two dimensional
    and the shape is interpreted as the desired number of columns:

        >>> matrix = [(0, 1), (2, 3), (4, 5)]
        >>> cols = 3
        >>> list(reshape(matrix, cols))
        [(0, 1, 2), (3, 4, 5)]

    If *shape* is a tuple (or other iterable), the input matrix can have
    any number of dimensions. It will first be flattened and then rebuilt
    to the desired shape which can also be multidimensional:

        >>> matrix = [(0, 1), (2, 3), (4, 5)]    # Start with a 3 x 2 matrix

        >>> list(reshape(matrix, (2, 3)))        # Make a 2 x 3 matrix
        [(0, 1, 2), (3, 4, 5)]

        >>> list(reshape(matrix, (6,)))          # Make a vector of length six
        [0, 1, 2, 3, 4, 5]

        >>> list(reshape(matrix, (2, 1, 3, 1)))  # Make 2 x 1 x 3 x 1 tensor
        [(((0,), (1,), (2,)),), (((3,), (4,), (5,)),)]

    Each dimension is assumed to be uniform, either all arrays or all scalars.
    Flattening stops when the first value in a dimension is a scalar.
    Scalars are bytes, strings, and non-iterables.
    The reshape iterator stops when the requested shape is complete
    or when the input is exhausted, whichever comes first.

    """
    if isinstance(shape, int):
        flat = chain.from_iterable(matrix)
        return batched(flat, shape)

    first_dim, *inner_dims = shape
    stream = _flatten_tensor(matrix)
    # Rebuild from the innermost dimension outwards.
    for width in reversed(inner_dims):
        stream = batched(stream, width)
    return islice(stream, first_dim)
def matmul(m1, m2):
    """Multiply two matrices.

        >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
        [(49, 80), (41, 60)]

    The caller should ensure that the dimensions of the input matrices are
    compatible with each other.

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    """
    result_width = len(m2[0])
    # Pair every row of m1 with every column of m2, in row-major order.
    row_column_pairs = product(m1, transpose(m2))
    entries = starmap(_sumprod, row_column_pairs)
    return batched(entries, result_width)
def _factor_pollard(n):
# Return a factor of n using Pollard's rho algorithm.
# Efficient when n is odd and composite.
for b in range(1, n):
x = y = 2
d = 1
while d == 1:
x = (x * x + b) % n
y = (y * y + b) % n
y = (y * y + b) % n
d = gcd(x - y, n)
if d != n:
return d
raise ValueError('prime or under 5') # pragma: no cover
# Primes below 211, computed once at import time; factor() uses them for
# its trial-division pass.
_primes_below_211 = tuple(sieve(211))
def factor(n):
    """Yield the prime factors of n.

        >>> list(factor(360))
        [2, 2, 2, 3, 3, 5]

    Finds small factors with trial division. Larger factors are
    either verified as prime with ``is_prime`` or split into
    smaller factors with Pollard's rho algorithm.

    Nothing is yielded when *n* is less than 2.

    """
    # Corner case reduction
    if n < 2:
        return

    # Trial division reduction
    for prime in _primes_below_211:
        while n % prime == 0:
            yield prime
            n //= prime

    # Pollard's rho reduction
    large_primes = []
    pending = [n] if n > 1 else []
    while pending:
        candidate = pending.pop()
        # With all primes below 211 divided out, anything under 211**2
        # must itself be prime.
        if candidate < 211**2 or is_prime(candidate):
            large_primes.append(candidate)
            continue
        divisor = _factor_pollard(candidate)
        pending.extend((divisor, candidate // divisor))
    large_primes.sort()
    yield from large_primes
def polynomial_eval(coefficients, x):
    """Evaluate a polynomial at a specific value.

    Computes with better numeric stability than Horner's method.

    Evaluate ``x^3 - 4 * x^2 - 17 * x + 60`` at ``x = 2.5``:

        >>> coefficients = [1, -4, -17, 60]
        >>> x = 2.5
        >>> polynomial_eval(coefficients, x)
        8.125

    Note that polynomial coefficients are specified in descending power order.

    An empty *coefficients* sequence evaluates to a zero of the same type
    as *x*.

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    """
    term_count = len(coefficients)
    if not term_count:
        return type(x)(0)
    # Exponents run from the highest power down to zero.
    exponents = range(term_count - 1, -1, -1)
    powers = map(pow, repeat(x), exponents)
    return _sumprod(coefficients, powers)
def sum_of_squares(iterable):
    """Return the sum of the squares of the input values.

        >>> sum_of_squares([10, 20, 30])
        1400

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    """
    # Two independent copies of the input, multiplied element-wise.
    left, right = tee(iterable)
    return _sumprod(left, right)
def polynomial_derivative(coefficients):
    """Compute the first derivative of a polynomial.

    Evaluate the derivative of ``x³ - 4 x² - 17 x + 60``:

        >>> coefficients = [1, -4, -17, 60]
        >>> derivative_coefficients = polynomial_derivative(coefficients)
        >>> derivative_coefficients
        [3, -8, -17]

    Note that polynomial coefficients are specified in descending power order.

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    """
    term_count = len(coefficients)
    # Powers n-1, ..., 1; zip() stops there, dropping the constant term.
    powers = range(term_count - 1, 0, -1)
    return [coefficient * power for coefficient, power in zip(coefficients, powers)]
def totient(n):
    """Return the count of natural numbers up to *n* that are coprime with *n*.

    Euler's totient function φ(n) gives the number of totatives.
    Totatives are integers k in the range 1 ≤ k ≤ n such that gcd(n, k) = 1.

        >>> n = 9
        >>> totient(n)
        6

        >>> totatives = [x for x in range(1, n) if gcd(n, x) == 1]
        >>> totatives
        [1, 2, 4, 5, 7, 8]
        >>> len(totatives)
        6

    Reference: https://en.wikipedia.org/wiki/Euler%27s_totient_function

    """
    distinct_primes = set(factor(n))
    result = n
    for prime in distinct_primes:
        # Multiply by (1 - 1/prime) using integer arithmetic.
        result -= result // prime
    return result
# Miller-Rabin primality test: https://oeis.org/A014233
# Each entry is (limit, bases). is_prime() takes the first entry whose
# limit exceeds n and runs the strong-probable-prime check with those bases.
_perfect_tests = [
    (2047, (2,)),
    (9080191, (31, 73)),
    (4759123141, (2, 7, 61)),
    (1122004669633, (2, 13, 23, 1662803)),
    (2152302898747, (2, 3, 5, 7, 11)),
    (3474749660383, (2, 3, 5, 7, 11, 13)),
    (18446744073709551616, (2, 325, 9375, 28178, 450775, 9780504, 1795265022)),
    (
        3317044064679887385961981,
        (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41),
    ),
]
@lru_cache
def _shift_to_odd(n):
'Return s, d such that 2**s * d == n'
s = ((n - 1) ^ n).bit_length() - 1
d = n >> s
assert (1 << s) * d == n and d & 1 and s >= 0
return s, d
def _strong_probable_prime(n, base):
    "Return True if odd n > 2 passes the strong probable-prime test for base."
    assert (n > 2) and (n & 1) and (2 <= base < n)

    # Write n - 1 as 2**squarings * odd_part.
    squarings, odd_part = _shift_to_odd(n - 1)
    minus_one = n - 1

    x = pow(base, odd_part, n)
    if x == 1 or x == minus_one:
        return True

    # Square repeatedly, looking for -1 (mod n).
    for _ in range(squarings - 1):
        x = x * x % n
        if x == minus_one:
            return True

    return False
# Separate instance of Random() that doesn't share state
# with the default user instance of Random().
# is_prime() draws its random bases from it.
_private_randrange = random.Random().randrange
def is_prime(n):
    """Return ``True`` if *n* is prime and ``False`` otherwise.

    Basic examples:

        >>> is_prime(37)
        True
        >>> is_prime(3 * 13)
        False
        >>> is_prime(18_446_744_073_709_551_557)
        True

    Find the next prime over one billion:

        >>> next(filter(is_prime, count(10**9)))
        1000000007

    Generate random primes up to 200 bits and up to 60 decimal digits:

        >>> from random import seed, randrange, getrandbits
        >>> seed(18675309)

        >>> next(filter(is_prime, map(getrandbits, repeat(200))))
        893303929355758292373272075469392561129886005037663238028407

        >>> next(filter(is_prime, map(randrange, repeat(10**60))))
        269638077304026462407872868003560484232362454342414618963649

    This function is exact for values of *n* below 10**24. For larger inputs,
    the probabilistic Miller-Rabin primality test has a less than 1 in 2**128
    chance of a false positive.

    """
    # Tiny inputs: direct table lookup.
    if n < 17:
        return n in {2, 3, 5, 7, 11, 13}

    # Cheap rejection of multiples of the small primes.
    if not n & 1 or any(n % p == 0 for p in (3, 5, 7, 11, 13)):
        return False

    # Use the first deterministic base set whose limit exceeds n.
    bases = next(
        (witnesses for limit, witnesses in _perfect_tests if n < limit),
        None,
    )
    if bases is None:
        # Beyond every table limit: fall back to 64 random bases.
        bases = (_private_randrange(2, n - 1) for _ in range(64))

    return all(_strong_probable_prime(n, base) for base in bases)
def loops(n):
    """Return an iterable with *n* elements for efficient looping.
    Like ``range(n)`` but doesn't create integers.

        >>> i = 0
        >>> for _ in loops(5):
        ...     i += 1
        >>> i
        5

    """
    placeholders = repeat(None, n)
    return placeholders
def multinomial(*counts):
    """Number of distinct arrangements of a multiset.

    The expression ``multinomial(3, 4, 2)`` has several equivalent
    interpretations:

    * In the expansion of ``(a + b + c)⁹``, the coefficient of the
      ``a³b⁴c²`` term is 1260.

    * There are 1260 distinct ways to arrange 9 balls consisting of 3 reds, 4
      greens, and 2 blues.

    * There are 1260 unique ways to place 9 distinct objects into three bins
      with sizes 3, 4, and 2.

    The :func:`multinomial` function computes the length of
    :func:`distinct_permutations`. For example, there are 83,160 distinct
    anagrams of the word "abracadabra":

        >>> from more_itertools import distinct_permutations, ilen
        >>> ilen(distinct_permutations('abracadabra'))
        83160

    This can be computed directly from the letter counts, 5a 2b 2r 1c 1d:

        >>> from collections import Counter
        >>> list(Counter('abracadabra').values())
        [5, 2, 2, 1, 1]
        >>> multinomial(5, 2, 2, 1, 1)
        83160

    A binomial coefficient is a special case of multinomial where there are
    only two categories. For example, the number of ways to arrange 12 balls
    with 5 reds and 7 blues is ``multinomial(5, 7)`` or ``math.comb(12, 5)``.

    Likewise, factorial is a special case of multinomial where
    the multiplicities are all just 1 so that
    ``multinomial(1, 1, 1, 1, 1, 1, 1) == math.factorial(7)``.

    Reference: https://en.wikipedia.org/wiki/Multinomial_theorem

    """
    # Product of C(running_total, k): choose the positions for each
    # category among the slots accumulated so far.
    running_totals = accumulate(counts)
    return prod(comb(total, k) for total, k in zip(running_totals, counts))
def _running_median_minheap_and_maxheap(iterator):  # pragma: no cover
    """Non-windowed running_median() for Python 3.14+.

    Keeps the smaller half of the values seen so far in a max-heap and the
    larger half in a min-heap, and yields a median after every value read
    from *iterator*.
    """
    fetch = iterator.__next__
    lower = []  # max-heap holding the smaller half
    upper = []  # min-heap; same size as *lower* or one element smaller

    # Exhausting *iterator* at either read ends the generator quietly.
    with suppress(StopIteration):
        while True:
            # Odd count: pass the new value through *upper* so that the
            # smallest of the upper half lands in *lower*.
            promoted = heappushpop(upper, fetch())
            heappush_max(lower, promoted)
            yield lower[0]

            # Even count: pass the new value through *lower* so that the
            # largest of the lower half lands in *upper*.
            demoted = heappushpop_max(lower, fetch())
            heappush(upper, demoted)
            yield (lower[0] + upper[0]) / 2
def _running_median_minheap_only(iterator): # pragma: no cover
"Backport of non-windowed running_median() for Python 3.13 and prior."
read = iterator.__next__
lo = [] # max-heap (actually a minheap with negated values)
hi = [] # min-heap (same size as or one smaller than lo)
with suppress(StopIteration):
while True:
heappush(lo, -heappushpop(hi, read()))
yield -lo[0]
heappush(hi, -heappushpop(lo, -read()))
yield (hi[0] - lo[0]) / 2
def _running_median_windowed(iterator, maxlen):
"Yield median of values in a sliding window."
window = deque()
ordered = []
for x in iterator:
window.append(x)
insort(ordered, x)
if len(ordered) > maxlen:
i = bisect_left(ordered, window.popleft())
del ordered[i]
n = len(ordered)
m = n // 2
yield ordered[m] if n & 1 else (ordered[m - 1] + ordered[m]) / 2
def running_median(iterable, *, maxlen=None):
    """Yield the median of the values seen so far, or of a sliding window.

    Pass a positive integer as *maxlen* to limit the median to the most
    recent *maxlen* values.  Leaving it as *None* means the window is
    unbounded.

    For example:

        >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0]))
        [5.0, 7.0, 5.0, 7.0, 8.0, 8.5]
        >>> list(running_median([5.0, 9.0, 4.0, 12.0, 8.0, 9.0], maxlen=3))
        [5.0, 7.0, 5.0, 9.0, 8.0, 9.0]

    Works with orderable numeric types such as int, float, Decimal, and
    Fraction.  Complex numbers are not supported because they cannot be
    ordered.

    On Python 3.13 and earlier, the max-heap is emulated by storing negated
    values.  For Decimal inputs the negation applies context rounding, so
    results may differ slightly from those of statistics.median().

    A *maxlen* that is zero or negative raises ``ValueError``.

    """
    iterator = iter(iterable)

    if maxlen is None:
        # Unbounded window: choose the heap strategy this interpreter supports.
        if _max_heap_available:
            return _running_median_minheap_and_maxheap(iterator)  # pragma: no cover
        return _running_median_minheap_only(iterator)  # pragma: no cover

    maxlen = index(maxlen)
    if maxlen <= 0:
        raise ValueError('Window size should be positive')
    return _running_median_windowed(iterator, maxlen)
def _windowed_running_mean(iterator, n):
window = deque()
running_sum = 0
for value in iterator:
window.append(value)
running_sum += value
if len(window) > n:
running_sum -= window.popleft()
yield running_sum / len(window)
def running_mean(iterable, *, maxlen=None):
    """Cumulative mean of values seen so far or values in a sliding window.

    Set *maxlen* to a positive integer to specify the maximum size
    of the sliding window.  The default of *None* is equivalent to
    an unbounded window.

    For example:

        >>> list(running_mean([40, 30, 50, 46, 39, 44]))
        [40.0, 35.0, 40.0, 41.5, 41.0, 41.5]
        >>> list(running_mean([40, 30, 50, 46, 39, 44], maxlen=3))
        [40.0, 35.0, 40.0, 42.0, 45.0, 43.0]

    Supports numeric types such as int, float, complex, Decimal, and Fraction.

    No extra effort is made to reduce round-off errors for float inputs.
    So the results may be slightly different from `statistics.mean`.

    Raises ``TypeError`` if *maxlen* is neither *None* nor an integer, and
    ``ValueError`` if it is zero or negative.  Both checks happen when the
    function is called, not when the result is first iterated.

    """
    iterator = iter(iterable)

    if maxlen is None:
        return map(truediv, accumulate(iterator), count(1))

    # Normalise the window size the same way running_median() does.  Without
    # this, a fractional value such as 0.5 slips past the positivity check
    # and later empties the window, dividing by zero on the first item.
    maxlen = index(maxlen)
    if maxlen <= 0:
        raise ValueError('Window size should be positive')

    return _windowed_running_mean(iterator, maxlen)
def random_derangement(iterable):
    """Return a random derangement of elements in the iterable.

    A derangement is a reordering in which no element stays at its
    original position.  The result is returned as a tuple.

    Equivalent to but much faster than ``choice(list(derangements(iterable)))``.

    An empty input returns ``()``.  A single-element input has no
    derangement, so ``IndexError`` is raised.

    """
    seq = tuple(iterable)
    if len(seq) < 2:
        if not seq:
            return ()
        raise IndexError('No derangements to choose from')

    perm = list(range(len(seq)))
    start = tuple(perm)
    # Rejection sampling: keep shuffling until no index is at its own position.
    while True:
        shuffle(perm)
        # *perm* holds the very same int objects as *start*, merely
        # reordered, so an identity test is a valid fixed-point check.
        if not any(map(is_, start, perm)):
            return itemgetter(*perm)(seq)