TinyMT の更新関数 tinymt32_next_state(tinymt32_t*) の逆関数に関する考察

TinyMT の更新関数 tinymt32_next_state(tinymt32_t*) の逆関数において、逆関数適用後の status[0] の最上位bitが不定値になる問題について調べてみた。


前置きが長いので「本題」で検索してヒットところから読めば充分だと思います。タブンネ
理論なんていいから最終的なプログラムだけ教えろ!って人は最終節(「まとめ」で検索してヒットしたあたり)だけ読めばいいと思います。メガタブンネ


準備1

非負整数 \(n\) および \(N\) ビットの非負整数 \(x\) に対して、次の関数 \(f_{R}\left(x\right),\,f_{L}\left(x\right)\) の逆関数 \(f_{R}^{-1}\left(x\right),\,f_{L}^{-1}\left(x\right)\) は次式で表される。

(*1)
\begin{eqnarray*}
f_{R}\left(x\right)&\overset{\mathrm{def}}{=}&x\oplus\left(x >> n\right)\\
f_{L}\left(x\right)&\overset{\mathrm{def}}{=}&x\oplus\left(x << n\right)
\end{eqnarray*}

\[\begin{array}[rlcl]\\
f_{R}^{-1}\left(x\right)
&\left(\;=\quad\displaystyle\bigoplus_{k=0}^{t}\left(x >> \left(k\cdot n\right)\right)\quad\right)
&=&\displaystyle\bigoplus_{k=0}^{T}\left(x >> \left(2^{k}\cdot n\right)\right)\\
f_{L}^{-1}\left(x\right)
&\left(\;=\quad\displaystyle\bigoplus_{k=0}^{t}\left(x << \left(k\cdot n\right)\right)\quad\right)
&=&\displaystyle\bigoplus_{k=0}^{T}\left(x << \left(2^{k}\cdot n\right)\right)
\end{array}\]


ただし \(t\) および \(T\) は以下の通り。
\begin{eqnarray*}
t&\overset{\mathrm{def}}{=}&\max\left(\left\lceil \frac{N}{\left|n\right|}\right\rceil-1,0\right)\\
T&\overset{\mathrm{def}}{=}&\max\left(\left\lceil\log_{2}\frac{N}{\left|n\right|}\right\rceil,0\right)
\end{eqnarray*}


参考:メルセンヌ・ツイスタのtemperingの逆関数に関する考察 - Plus Le Blog
参考:Xorshift の逆関数に関する考察/正逆双方向に計算できる Xorshift - Plus Le Blog

これをプログラムで書くと、以下のようになる( \(N=32\) の場合)。

// C#
private const int N = 32;
private readonly int n;
private uint f_R(uint x) {
	x ^= ( x >> n );
	return x;
}
private uint f_L(uint x) {
	x ^= ( x << n );
	return x;
}
private uint f_R_inv(uint x) {
	for( int k=n;k<N;k<<=1 ) {
		x ^= ( x >> k );
	}
	return x;
}
private uint f_L_inv(uint x) {
	for( int k=n;k<N;k<<=1 ) {
		x ^= ( x << k );
	}
	return x;
}

準備2

本記事のプログラムで使う定数および変数を以下に定めておく。

// C#
private const int BIT_SIZE = 32;        // 変数のビット長
private const int SH0 = 1;              // TINYMT32_SH0
private const int SH1 = 10;             // TINYMT32_SH1
private const uint MASK = 0x7fffffffu;  // TINYMT32_MASK
private const uint MAT1 = 0x8f7011eeu;  // tinymt32_t.mat1
private const uint MAT2 = 0xfc78ff1fu;  // tinymt32_t.mat2
private uint[] status = new uint[4];    // tinymt32_t.status

更新関数 Next()

まず、C++ で書かれた TinyMT の更新関数 tinymt32_next_state(tinymt32_t * random) を C# に書き換えてみる。C++ ニガテ……
(※元のソースはTiny Mersenne Twister (TinyMT)参照。)

// C#
// 更新関数:元のソースの tinymt32_next_state(tinymt32_t * random) に相当
public void Next() {
	uint y = status[3];
	uint x = ( status[0] & MASK ) ^ status[1] ^ status[2];
	x ^= ( x << SH0 );
	y ^= ( y >> SH0 ) ^ x;
	status[0] = status[1];
	status[1] = status[2];
	status[2] = x ^ ( y << SH1 );
	status[3] = y;
	if( ( y & 0x1u ) == 0x1u ) {
		status[1] ^= MAT1;
		status[2] ^= MAT2;
	}
	return;
}

これを、数式に変換しやすいような形式に加工する。具体的には、

  • 2文字以上の変数名・変数名を1文字に変更( SH0,SH1 を除く)
    • status[0], status[1], status[2], status[3] をそれぞれ p, q, r, s とおいた
  • 1行で複数の処理をしているところを分解(例:「 y ^= ( y >> SH0 ) ^ x; 」→「 y ^= ( y >> SH0 ); y ^= x; 」)
  • 既に使った変数への代入を廃止(例:「 x ^= ( x << SH0 ) 」→「 x2 = x1 ^ ( x1 << SH0 ) 」)
  • 分岐処理の条件判断部を関数化して三項演算子に変更(例:「 if( ( y & 0x1u ) == 0x1u ) { …… } 」→「 IsOdd(y) ? …… : ……; 」)

すると以下のようになる。

// C#
private const uint M  = MASK;
private const uint M1 = MAT1;
private const uint M2 = MAT2;
public static uint[] GetNext(uint p_t,uint q_t,uint r_t,uint s_t) {
	uint y1 = s_t;
	uint x1 = ( p_t & M ) ^ q_t ^ r_t;
	uint x2 = x1 ^ ( x1 << SH0 );
	uint y2 = y1 ^ ( y1 >> SH0 );
	uint y3 = y2 ^ x2;
	uint p_tPlus1 = q_t;
	uint q_dash  = r_t;
	uint r_dash  = x2 ^ ( y3 << SH1 );
	uint s_tPlus1 = y3;
	uint q_tPlus1 = IsOdd(y3)
	              ? q_dash ^ M1
	              : q_dash;
	uint r_tPlus1 = IsOdd(y3)
	              ? r_dash ^ M2
	              : r_dash;
	return new uint[]{ p_tPlus1, q_tPlus1, r_tPlus1, s_tPlus1, };
}
private static bool IsOdd(uint x) {
	return ( x & 0x1u ) == 0x1u;
}

これで更新関数を数式に起こす準備が整った。

数式化

初期状態から数えて \(t\) 回更新後の status を \(p_{t},\,q_{t},\,r_{t},\,s_{t}\) とおく。

(*2)
\begin{eqnarray*}
p_{t}&\overset{\mathrm{def}}{=}&\mathtt{status[0]}\\
q_{t}&\overset{\mathrm{def}}{=}&\mathtt{status[1]}\\
r_{t}&\overset{\mathrm{def}}{=}&\mathtt{status[2]}\\
s_{t}&\overset{\mathrm{def}}{=}&\mathtt{status[3]}
\end{eqnarray*}

このとき、 \(p_{t},\,q_{t},\,r_{t},\,s_{t}\) から \(p_{t+1},\,q_{t+1},\,r_{t+1},\,s_{t+1}\) を求める式は以下のように書き起こされる。

(*3)
\begin{eqnarray*}
y_{1}&=&s_{t}\\
x_{1}&=&\left(p_{t}\cap M\right)\oplus q_{t}\oplus r_{t}\\
x_{2}&=&x_{1}\oplus\left(x_{1} << \mathtt{SH0}\right)\\
y_{2}&=&y_{1}\oplus\left(y_{1} >> \mathtt{SH0}\right)\\
y_{3}&=&y_{2}\oplus x_{2}\\
p_{t+1}&=&q_{t}\\
q'&=&r_{t}\\
r'&=&x_{2}\oplus\left(y_{3} << \mathtt{SH1}\right)\\
s_{t+1}&=&y_{3}\\
q_{t+1}&=&
\begin{cases}
q'\oplus M_{1} & \left(y_{3} \text{が奇数のとき}\right)\\
q' & \left(y_{3} \text{が偶数のとき}\right)
\end{cases}\\
r_{t+1}&=&
\begin{cases}
r'\oplus M_{2} & \left(y_{3} \text{が奇数のとき}\right)\\
r' & \left(y_{3} \text{が偶数のとき}\right)
\end{cases}
\end{eqnarray*}

この11連連立方程式を \(p_{t},\,q_{t},\,r_{t},\,s_{t}\) について解けば、目的の逆関数が得られる。

連立方程式を解く

……方程式を解くと言っても、今回の目的はあくまで逆関数をプログラムで書くことなので、厳密に解く必要はない。
つまり \(x_{1}\) などの途中計算用の変数を消去する必要はない。
というより、消去してしまうと逆に再度プログラムに落とすときにめんどうになる。

これに注意して(*3)を変形すると、次のようになる。
(並び順を概ね下から順にして一部の変数を移項しただけに近いけど……)

(*4)
\begin{eqnarray*}
y_{3}&=&s_{t+1}\\
q'&=&
\begin{cases}
q_{t+1}\oplus M_{1} & \left(y_{3} \text{が奇数のとき}\right)\\
q_{t+1} & \left(y_{3} \text{が偶数のとき}\right)
\end{cases}\\
r'&=&
\begin{cases}
r_{t+1}\oplus M_{2} & \left(y_{3} \text{が奇数のとき}\right)\\
r_{t+1} & \left(y_{3} \text{が偶数のとき}\right)
\end{cases}\\
x_{2}&=&r'\oplus\left(y_{3} << \mathtt{SH1}\right)\\
y_{2}&=&y_{3}\oplus x_{2}\\
x_{1}&=&\bigoplus_{k=0}^{T}\left(x_{2} << \left(2^{k}\cdot\mathtt{SH0}\right)\right)\\
y_{1}&=&\bigoplus_{k=0}^{T}\left(y_{2} >> \left(2^{k}\cdot\mathtt{SH0}\right)\right)\\
s_{t}&=&y_{1}\\
r_{t}&=&q'\\
q_{t}&=&p_{t+1}\\
p_{t}\cap M&=&x_{1}\oplus q_{t}\oplus r_{t}
\end{eqnarray*}


ただし、 \(T\overset{\mathrm{def}}{=}\max\left(\left\lceil\log_{2}\frac{N}{\left|\mathtt{SH0}\right|}\right\rceil,0\right)\) である。

※唐突に出てきた2つの \(\displaystyle\bigoplus_{k=0}^{T}\) については冒頭の準備1を参照。

これをそのままプログラムに落とすと、次のようになるような気がする

// C#
// GetNext() の逆関数<em>の成り損ない</em>
public static uint[] GetNext_inv(uint p_tPlus1,uint q_tPlus1,uint r_tPlus1,uint s_tPlus1) {
	uint y3 = s_tPlus1;
	uint q_dash = IsOdd(y3)
	            ? q_tPlus1 ^ M1
	            : q_tPlus1;
	uint r_dash = IsOdd(y3)
	            ? r_tPlus1 ^ M2
	            : r_tPlus1;
	uint x2 = r_dash ^ ( y3 << SH1 );
	uint y2 = y3 ^ x2;
	uint x1;
	{
		x1 = x2;
		for( int sh=SH0;sh<BIT_SIZE;sh<<=1 ) {
			x1 ^= ( x1 << sh );
		}
	}
	uint y1;
	{
		y1 = y2;
		for( int sh=SH0;sh<BIT_SIZE;sh<<=1 ) {
			y1 ^= ( y1 >> sh );
		}
	}
	uint s_t = y1;
	uint r_t = q_dash;
	uint q_t = p_tPlus1;
	uint p_t = x1 ^ q_t ^ r_t;		// ←???
	return new uint[]{ p_t, q_t, r_t, s_t, };
}

……最後の「???」の行が怪しいが、 \(p_{t+1},\,q_{t+1},\,r_{t+1},\,s_{t+1}\) からこの逆関数で得た \(p_{t},\,q_{t},\,r_{t},\,s_{t}\) に対して再び更新関数を適用すると \(p_{t+1},\,q_{t+1},\,r_{t+1},\,s_{t+1}\) が得られるので、とりあえずこのまま進めることにする。

これでようやく、この記事の本題に入れる。

\(p_{t}\) の最上位ビットについて

さてここで、 \(p_{t}\) の最上位ビットに注目してこれまでの流れを見てみよう。

まず、更新関数 GetNext() の中で \(p_{t}\) を参照しているのは次の1箇所のみ。

// C#
public static uint[] GetNext(uint p_t,uint q_t,uint r_t,uint s_t) {
	// ……
	uint x1 = ( p_t & M ) ^ q_t ^ r_t;
	// ……
	return new uint[]{ p_tPlus1, q_tPlus1, r_tPlus1, s_tPlus1, };
}

この部分を数式で表したのが(*3)で、適当に移項したのが(*4)だった。

(*3)再掲
\[
x_{1}=\left(p_{t}\cap M\right)\oplus q_{t}\oplus r_{t}
\]

(*4)再掲
\[
p_{t}\cap M=x_{1}\oplus q_{t}\oplus r_{t}
\]

これを暫定的にプログラムに落とすと次のようになるような気がしたのだった。

// C#
public static uint[] GetNext_inv(uint p_tPlus1,uint q_tPlus1,uint r_tPlus1,uint s_tPlus1) {
	// ……
	uint p_t = x1 ^ q_t ^ r_t;		// ←???
	// ……
	return new uint[]{ p_t, q_t, r_t, s_t, };
}


この関数は先述の通り確かに更新関数 Next() の逆関数の働きをしているように見えるが、上の「???」の行は次の(*5)を表しているにすぎず、(*4)を表せてはいない。

(*5)
\[
p_{t}=x_{1}\oplus q_{t}\oplus r_{t}
\]

(*4)再掲
\[
p_{t}\cap M=x_{1}\oplus q_{t}\oplus r_{t}
\]

そのため情報が欠落し、この更新逆関数 GetNext_inv() には2回以上連続で使うと元の値を復元できない可能性があるという欠陥がある。

この原因は更新関数 GetNext() 内で \(M=\mathtt{0x7fffffff}\) によるマスク( \(p_{t}\cap M\) )をしており、 \(p_{t}\) の最上位ビットの情報が失われるため。
\(p_{t}\) を参照している箇所は他にないので、 \(p_{t}\) の最上位ビットを復元することはどうやってもできないような気がする

……が、ここで あきらめる なんて 10000こうねん はやいんだよ!▼

\(p_{t}\) の最上位ビットの復元

問題の本質は(*4)をプログラムに落としたときに最上位ビットの情報が失われたことにある。

(*4)再掲
\[
p_{t}\cap M=x_{1}\oplus q_{t}\oplus r_{t}
\]

そもそも(*4)は次の2つの情報を持っている。

(*6)

  • \(x_{1}\oplus q_{t}\oplus r_{t}\) の最上位ビットは \(0\) に等しい
  • \(x_{1}\oplus q_{t}\oplus r_{t}\) の下位31ビットは \(p_{t}\) の下位31ビットに等しい

このうち先のプログラムでは(*6)の第2の情報しか使っていないので、1ビット分の情報(=\(p_t\) の最上位ビット)が欠落する。
逆に言えば、第1の情報を使えば最後の1ビットを復元できるかもしれないということである。タブンネ

そこで、2回連続で更新逆関数 GetNext_inv() を使用したときに不定値1ビットがどう影響するかを見てみよう。

// C#
// 再掲:GetNext() の逆関数<em>の成り損ない</em>
public static uint[] GetNext_inv(uint p_tPlus1,uint q_tPlus1,uint r_tPlus1,uint s_tPlus1) {
	// (前略) ←※ここで p_tPlus1 は未使用
	uint s_t = y1;
	uint r_t = q_dash;
	uint q_t = p_tPlus1;        // ※2
	uint p_t = x1 ^ q_t ^ r_t;  // ※1
	return new uint[]{ p_t, q_t, r_t, s_t, };
}

まずこの逆関数の1回目の適用により、※1の行で \(p_{t}\) の最上位ビットが不定値となる。
1回目の \(p_{t}\) =2回目の \(p_{t+1}\) なので、逆関数の2回目の適用時には \(p_{t+1}\) の最上位ビットが初めから不定値となっている。
すると※2の行により、 \(q_{t}\) の最上位ビットが不定値になる。
逆関数を2回以上適用すると値が復元できない(場合がある)理由はここにある。

しかし(*6)の第1の情報から、(2回目の) \(x_{1}\oplus q_{t}\oplus r_{t}\) は0に等しいはずである。

(*6)再掲

  • \(x_{1}\oplus q_{t}\oplus r_{t}\) の最上位ビットは \(0\) に等しい
  • \(x_{1}\oplus q_{t}\oplus r_{t}\) の下位31ビットは \(p_{t}\) の下位31ビットに等しい

なので、(2回目の) ※2の後で \(x_{1}\oplus q_{t}\oplus r_{t}\) の最上位ビットをチェックし、もしも0に等しくなければ(2回目の) \(q_{t}\) の最上位ビットは誤りであると分かる。
誤りと分かったならビット反転すれば正しい値を得られる。

このことを利用すれば、 \(q_{t}\) の最上位ビットの誤りを防ぐことができる。

// C#
public static uint[] GetNext_inv(uint p_tPlus1,uint q_tPlus1,uint r_tPlus1,uint s_tPlus1) {
	// (前略)
	uint s_t = y1;
	uint r_t = q_dash;
	uint q_t = p_tPlus1;
	if( ( ( x1 ^ q_t ^ r_t ) & 0x80000000u ) == 0x80000000u ) {  // new!
		q_t ^= 0x80000000u;                                      // new!
	}                                                            // new!
	uint p_t = x1 ^ q_t ^ r_t;
	return new uint[]{ p_t, q_t, r_t, s_t, };
}

これだけでは \(p_{t}\) の最上位ビットの誤りを防ぐことはできないが、もともと使われていないビットなのでこれ以上の修正は不要だろう。タブンネ
ちなみに逆関数2回目の \(q_{t}\) =2回目の \(p_{t+1}\) =1回目の \(p_{t}\) なので、2回目の \(q_{t}\) の最上位ビットの正誤判定を1回目の時点で行えば(1回目の) \(p_{t}\) の最上位ビットまで正確に求められる。しかし、更新逆関数を続けて呼ぶ場合には同じ計算を2度行うことになるので無駄だし、そもそも更新関数でも使われないビットなので求める価値もない。

ところで、追加した処理によって(2回目の) \(q_{t}\) の誤りを修正できることは分かったが、1回目の \(q_{t}\) の値が誤りになってしまうのではと不安に思うかもしれない。
しかしそれはありえない。
1回目の逆関数適用ということは更新関数 Next() の直後なので、入力された4変数 \(p_{t+1},\,q_{t+1},\,r_{t+1},\,s_{t+1}\) に不定値を含まない。
そのため、 \(x_{1}\oplus q_{t}\oplus r_{t}\) の最上位ビットは \(p_{t}\cap M\) の最上位ビット(すなわち0)に等しい。

// C#
// 再掲:更新関数
public static uint[] GetNext(uint p_t,uint q_t,uint r_t,uint s_t) {
	// ……
	uint x1 = ( p_t & M ) ^ q_t ^ r_t;
	// ……
}

したがって1回目の逆関数では追加した if 文を通ることはないので、その影響で \(q_{t}\) の値が誤りになることはないと言える。


以下、まとめ。

更新逆関数 Prev()

以上を踏まえると、更新関数 Next() およびその逆関数 Prev() は次のように書ける。
(更新逆関数 Prev() は無駄を省くなど改良してあるが、根幹は同じである。)

// C#
// field
private const uint MSB = 0x80000000u;   // 最上位ビット
private const uint LSB = 0x00000001u;   // 最下位ビット
private const int BIT_SIZE = 32;        // 変数のビット長
private const int SH0 = 1;              // TINYMT32_SH0
private const int SH1 = 10;             // TINYMT32_SH1
private const uint MASK = 0x7fffffffu;  // TINYMT32_MASK
private const uint MAT1 = 0x8f7011eeu;  // tinymt32_t.mat1
private const uint MAT2 = 0xfc78ff1fu;  // tinymt32_t.mat2
private uint[] status = new uint[4];    // tinymt32_t.status

// method
// 更新関数:元のソースの tinymt32_next_state(tinymt32_t * random) に相当
public void Next() {
	uint y = status[3];
	uint x = ( status[0] & MASK ) ^ status[1] ^ status[2];
	x ^= ( x << SH0 );
	y ^= ( y >> SH0 ) ^ x;
	status[0] = status[1];
	status[1] = status[2];
	status[2] = x ^ ( y << SH1 );
	status[3] = y;
	if( ( y & LSB ) == LSB ) {
		status[1] ^= MAT1;
		status[2] ^= MAT2;
	}
	return;
}
// 更新逆関数:Next() の逆関数
public void Prev() {
	uint y = status[3];
	if( ( y & LSB ) == LSB ) {
		status[1] ^= MAT1;
		status[2] ^= MAT2;
	}
	uint x = status[2] ^ ( y << SH1 );
	y ^= x;
	for( int sh=SH0;sh<BIT_SIZE;sh<<=1 ) {
		x ^= ( x << sh );
		y ^= ( y >> sh );
	}
	status[3] = y;
	status[2] = status[1];
	status[1] = status[0];
	status[0] = x ^ status[1] ^ status[2];
	if( ( status[0] & MSB ) == MSB ) {
		status[1] ^= MSB;
	}
	return;
}