なんか考えてることとか

変な人が主にプログラミング関連で考えていることをまとめる。

Python 3.11で末尾再帰が書けるようになる

  • 2022/10/28
    • factorial関数の定義に誤りがあったので修正

3.10以前のPythonでは、再帰関数に対して「末尾呼び出し最適化」など何も対策をしていないために、普通に再帰関数を定義すると、問題が発生してしまっていた。

しかしPython 3.11でその問題が解消された再帰関数が書けるようになった。今日は再帰関数の問題点と、Python 3.11における再帰関数の定義の手法について書いていこうと思う。

1. 再帰関数とは

プログラミング言語において「再帰関数」とは、関数内部で自身を呼び出すような仕組みを持つ関数である。たとえばPythonだと以下のように定義する。

def rec(x: int):
    if x <= 0:
        print('end of recursion.')
    else:
        print(x)
        return rec(x - 1)

まず最初に再帰関数の特徴として見てもらいたいのが、自身の呼び出しを戻り値として返している点である。

        return rec(x - 1)

これにより0 < xであったとき、rec(x - 1)を呼び出す。呼び出した関数内でまた0 < xであったとき、rec(x - 1)を呼び出す。さらに呼び出した関数内でまた0 < xであったとき、・・・を繰り返していき、関数呼び出しの際にx1引くという処理を繰り返しているため、最終的にはx \leq 0のときに関数を終了するという処理が可能となる。print()を除くと、以下のように計算を行っていると考えることができる(ただし endは関数終了を意味する)。

 rec(n) \\
→ \textrm{if} \, n \leq 0 \, \textrm{then} \, end \, \textrm{else} \, rec(n - 1) \\
→ rec(n - 1) \\
→ \textrm{if} \, (n - 1) \leq 0 \, \textrm{then} \, end \, \textrm{else} \, rec((n - 1) - 1) \\
→ rec((n - 1) - 1) \\
→ \, ... \\
→ rec(0) \\
→ \textrm{if} \, 0 \leq 0 \, \textrm{then} \, end \, \textrm{else} \, rec(0 - 1) \\
→ end

例として、rec(10)を実行してみる。

>>> rec(10)
10
9
8
7
6
5
4
3
2
1
end of recursion.

賢明な方ならわかると思うが、再帰while文やfor文と同様に繰り返し処理として表現することが可能で、さらに再帰という操作自体は副作用*1がない
副作用を極力避ける純粋関数型言語では、そもそもwhile, forが言語機能として提供されておらず、繰り返し処理は基本再帰で記述する*2

2. 再帰関数の問題点

このようにメリットもあるため、積極的に使っていこうと思うかもしれないが、実はPython 3.10以前においては、むしろ再帰を使うことは避けるべきであった。その最たる理由としては、メモリ消費の問題である。

基本的に関数は、呼び出すたびにスタックメモリを確保し、関数が終了すると開放するようになっている。ということは、再帰関数では関数を終了するまで、自身の呼び出しによりどんどんスタックメモリが確保されていくのである。

メモリにも種類があり、スタックメモリはその一種であるが、スタックメモリは一時的に使っていくという側面が強いためにその容量はほかのメモリと比べて少ないのである。そのため、再帰関数はスタックメモリが枯渇するという問題が容易に発生しうる。ちなみに、スタックメモリが枯渇した際にスタックメモリを確保することで発生する問題をスタックオーバーフローと言う。

3. 他言語での再帰関数による問題への対策

先ほど、Python 3.10以前では再帰関数は定義できるがスタックオーバーフローの問題があるため、使用は避けるべきであると書いた。では、スタックオーバーフローを起こさないような再帰関数を定義することは可能なのだろうか?

他言語であれば、それは可能である。

たとえば、Haskellなどは、評価戦略として遅延評価(厳密には必要呼びと呼ばれている評価戦略)と言われる、Pythonとは異なる評価戦略を採用している。詳細はこちらのサイトに任せるが、遅延評価を採用したことで、あまり意識することなくスタックオーバーフローの起こらない再帰関数を定義できるのである*3

また、末尾呼び出し最適化をしてくれるコンパイラは、末尾再帰*4を書いた場合に限りスタックオーバーフローをしないような関数に変換してくれる。これにより、末尾再帰を意識して書くことで、スタックオーバーフローの発生しない再帰関数の定義が可能となる。

4. Python 3.11での再帰関数

Python 3.11から、スタックオーバーフローが発生しない再帰関数の定義ができるようになった

なぜかと言うと、Python 3.11では「CPython高速化計画」の一環で「関数呼び出しのインライン化」といった処理速度を高速にするためのCPythonソースコードの最適化が行われたためである。つまり、再帰関数に対する対策と言うよりかは、最適化の副産物である。

4.1. 再帰関数の例

「関数呼び出しのインライン化」により、たとえば階乗を求める関数factorial再帰を使うことで以下のように定義することが可能となった。

def factorial(n: int) -> int:
    '''
    階乗を求める関数
    '''
    def _factorial(n: int, acc: int) -> int:
        if n <= 0:
            return acc
        else:
            return _factorial(n - 1, n * acc)
    
    return _factorial(n, 1)

ここで、Pythonでは3.10以前からスタックオーバーフローを回避するために最大再帰回数に上限が設けられており、その上限を超えるとRecursionErrorが発生する*5。そのため極端に回数の多い再帰関数を実行する場合、あらかじめsysモジュールのsetrecursionlimit()に大きい整数値を渡しておくべきだろう。

>>> import sys
>>> # たとえば1億回弱再帰するなら1億に設定する
>>> sys.setrecursionlimit(100_000_000)

ではfactorial(100)を実行してみる。

>>> factorial(100)
93326215443944152681699238856266700490715968264381621468592963895217599993229915608941463976156518286253697920827223758251185210916864000000000000000000000000

Haskellfoldl/foldr関数も以下のようにいとも簡単に定義できる(簡単のために、型ヒントはlist型に絞っているが)。ただしpop()メソッドを使った副作用が必要である。理由については4.2.にて後述する。

from typing import TypeVar
from collections.abc import Callable

T = TypeVar('T')
U = TypeVar('U')

def foldl(f: Callable[[U, T], U], acc: U, xs: list[T]) -> U:
    '''
    foldl関数
    '''
    if xs == []:
        return acc
    else:
        return foldl(f, f(acc, xs.pop(0)), xs)

def foldr(f: Callable[[T, U], U], acc: U, xs: list[T]) -> U:
    '''
    foldr関数
    '''
    if xs == []:
        return acc
    else:
        return foldr(f, f(xs.pop(), acc), xs)

foldl/foldr関数は、リストの全要素を用いて新しい値を生成する処理を一般化したものである。たとえばfoldl関数でsum関数を、foldr関数でmap関数を定義してみる。

from typing import TypeVar
from collections.abc import Callable
from operator import add


def sum_(xs: list[int | float]) -> int | float:
    '''
    sum関数
    '''
    return foldl(add, 0, xs)


T = TypeVar('T')
U = TypeVar('U')

def map_(f: Callable[[T], U], xs: list[T]) -> list[U]:
    '''
    map関数
    '''
    return foldr(lambda x, acc: (acc.insert(0, f(x)), acc)[1], [], xs)

>>> sum_([i+1 for i in range(10)])
55
>>> map_(lambda x: f'{x}', [i+1 for i in range(5)])
['1', '2', '3', '4', '5']

4.2. 注意点

このように、Python 3.11で今まで作れなかった「スタックオーバーフローにならない再帰関数」が定義できるようになったのは、偶然であるにしろ、かなり魅力的な要素だと言える。

しかし、できるようになったとは言っても、依然として「スタックオーバーフローが発生する再帰関数」を定義することは容易にできてしまう。たとえば、factorial関数の定義を見ていて、以下のようにすればもっと簡単に定義できると思わなかっただろうか。

def factorial(n: int) -> int:
    '''
    階乗を求める関数
    '''
    if n <= 0:
        return 1
    else:
        return n * factorial(n - 1)

このような再帰関数は定義してはならない。なぜなら、このように定義した再帰関数はスタックオーバーフローが発生するからである。以下のように戻り値で使われているn * factorial(n - 1)という式はスタックメモリを使っており、factorial()0が渡されるまでその計算が終わることがないために、どんどんスタックメモリを確保してしまうのである。

        return n * factorial(n - 1)

これを回避する方法は、戻り値を再帰呼び出しだけにすることである。そのためには、factorial関数を呼び出した際、以下のように、別の形の関数に変換してやる必要がある。

    def _factorial(n: int, acc: int) -> int:
        if n <= 0:
            return acc
        else:
            return _factorial(n - 1, n * acc)

この関数で戻り値としているのは、関数の再帰呼び出しだけである。この戻り値が再帰呼び出しだけとなっている再帰を「末尾再帰」と呼ぶ。あとは以下のように定義した関数内関数の呼び出しを戻り値にするだけで、所望の再帰関数を定義できる。

    return _factorial(n, 1)

また、listを用いる際も注意が必要である。たとえばHaskellなどでは、foldr関数は以下のように副作用がなくとも定義できるし、(新しいリスト生成する場合に限るが)スタックオーバーフローは発生しない。

foldr' :: (a -> b -> b) -> b -> [a] -> b
foldr' _ acc [] = acc
foldr' f acc (x:xs) = f x $ foldr' f acc xs

これをPython的に(末尾再帰にすることを意識したうえで)定義すると、以下のようになる。

from typing import TypeVar
from collections.abc import Callable

T = TypeVar('T')
U = TypeVar('U')

def foldr(f: Callable[[T, U], U], acc: U, xs: list[T]) -> U:
    '''
    foldr関数
    '''
    if xs == []:
        return acc
    else:
        return foldr(f, f(xs[-1], acc), xs[:-1])

これで問題ないと思われるかもしれないが、この試みは失敗する。なぜならlistのスライシングはlistの一部をコピーしたに過ぎないからである。つまり、新しいlistの生成を繰り返すことによりスタックメモリを確保し続けていくので、最終的にスタックオーバーフローを引き起こしてしまうのである。

これを回避するためには、残念ながら以下のように副作用を用いるしかない。

        return foldr(f, f(xs.pop(), acc), xs)

これはlistの結合の場合も同様である。

# これはNG
foldr(lambda x, acc: [f'{x}'] + acc, [], [1, 2, 3, 4, 5])

# これはOK
foldr(lambda x, acc: (acc.insert(0, f'{x}'), acc)[1], [], [1, 2, 3, 4, 5])

5. 終わりに

このように、Python 3.11ではCPythonコードの最適化の結果、「スタックオーバーフローの発生しない再帰関数」を定義できるようになった。
気を付けなければならないことも多少あるものの、Python 3.10でmatch文が追加されたこともあり、着実に関数型プログラミングとしてのPythonは進化してきていることは間違いないだろう。

これは最適化の副産物であるため、これ以上の進化は望めないかもしれないが、それでも、これ以上の関数型プログラミングとしてのPythonの進化には期待したい。

*1:式の評価以外の要因で値が変化していくこと

*2:ただしHaskellのようにモナドを使った「関数」として標準で用意されていることもある

*3:ただし、一部では評価を遅延させるのではなく、即時に評価させるような工夫がなければ別の要因によりスタックオーバーフローが発生する場合もある

*4:戻り値が関数の呼び出し「のみ」になっている再帰である

*5:ここで、「スタックオーバーフローが起こらない再帰関数」が定義できるようになったので、この制限はいらないと感じるかもしれない。これに関しては4.2.にて後述する