正逆双方向に計算できる SFMT

正逆双方向に計算できる SFMT のクラスを作ってみた(※ただし周期が \(2^{19937}-1\) のものに限定した)。
逆算部分の理論はメルセンヌ・ツイスタの tempering の逆関数TinyMT の更新関数の逆関数とほぼ変わらないので、詳細はそちらを参照。



速度を度外視して汎用性重視で作ったので高速化は各自で。
というか高速化プログラミング苦手だし、そもそも C# 使ってる時点で遅いし、みたいな。
それとおそらく SIMD-oriented な実装になってないので、その意味でも Fast じゃない。
結果が SFMT と同じになるだけ……

UInt128.cs

SFMT の更新の様子が分かりやすくなるように、 128ビットの符号なし整数型を使用する。

using System;
using System.Collections.Generic;
using System.Text;

// C#
namespace SFMTTest {
	internal struct UInt128 : IEquatable<UInt128> {
		#region constructor
		public UInt128(uint word3,uint word2,uint word1,uint word0) {
			this._w3 = word3;
			this._w2 = word2;
			this._w1 = word1;
			this._w0 = word0;
		}
		#endregion

		#region static field
		private const int WORD_SIZE = 32;
		#endregion

		#region static property
		public static UInt128 Zero { get { return new UInt128(0u,0u,0u,0u); } }
		public static UInt128 One { get { return new UInt128(0u,0u,0u,1u); } }
		#endregion

		#region field property
		// 第3ワード(最上位ワード)
		private readonly uint _w3;
		public uint Word3 { get { return _w3; } }
		// 第2ワード
		private readonly uint _w2;
		public uint Word2 { get { return _w2; } }
		// 第1ワード
		private readonly uint _w1;
		public uint Word1 { get { return _w1; } }
		// 第0ワード(最下位ワード)
		private readonly uint _w0;
		public uint Word0 { get { return _w0; } }
		#endregion
		#region property
		// 上位倍長ワード(上位64bit)
		public ulong DWordH { get { return ( (ulong)_w3 << 32 ) | _w2; } }
		// 下位倍長ワード(下位64bit)
		public ulong DWordL { get { return ( (ulong)_w1 << 32 ) | _w0; } }
		#endregion

		#region
		public bool IsZero {
			get {
				return _w0 == 0x0u
				    && _w1 == 0x0u
				    && _w2 == 0x0u
				    && _w3 == 0x0u;
			}
		}
		#endregion

		#region indexer
		public uint this[int n] {
			get {
				switch( n & 3 ) {
				case 0: return _w0;
				case 1: return _w1;
				case 2: return _w2;
				case 3: return _w3;
				default: throw new IndexOutOfRangeException();
				}
			}
		}
		#endregion

		#region property
		public string TextBin {
			get {
				return Convert.ToString(_w3,2).PadLeft(WORD_SIZE,'0')
				     + Convert.ToString(_w2,2).PadLeft(WORD_SIZE,'0')
				     + Convert.ToString(_w1,2).PadLeft(WORD_SIZE,'0')
				     + Convert.ToString(_w0,2).PadLeft(WORD_SIZE,'0');
			}
		}
		public string TextHex {
			get {
				return _w3.ToString("X08")
				     + _w2.ToString("X08")
				     + _w1.ToString("X08")
				     + _w0.ToString("X08");
			}
		}
		#endregion

		#region operator
		public static explicit operator UInt128(uint x) {
			return new UInt128(0u,0u,0u,x);
		}
		public static explicit operator UInt128(ulong x) {
			return new UInt128(0u,0u,(uint)( x>>32 ),(uint)x);
		}
		public static UInt128 operator&(UInt128 x,UInt128 y) {
			return new UInt128(
				x._w3 & y._w3,
				x._w2 & y._w2,
				x._w1 & y._w1,
				x._w0 & y._w0
			);
		}
		public static UInt128 operator^(UInt128 x,UInt128 y) {
			return new UInt128(
				x._w3 ^ y._w3,
				x._w2 ^ y._w2,
				x._w1 ^ y._w1,
				x._w0 ^ y._w0
			);
		}
		//public static UInt128 operator|(UInt128 x,UInt128 y) {
		//	return new UInt128(
		//		x._w3 | y._w3,
		//		x._w2 | y._w2,
		//		x._w1 | y._w1,
		//		x._w0 | y._w0
		//	);
		//}
		public static UInt128 operator<<(UInt128 x,int n) {
			n &= 0x7f;

			uint w3 = x._w3;
			uint w2 = x._w2;
			uint w1 = x._w1;
			uint w0 = x._w0;
			for( ;n>=WORD_SIZE;n-=WORD_SIZE ) {
				w3 = w2;
				w2 = w1;
				w1 = w0;
				w0 = 0x0u;
			}
			if( n > 0 ) {
				w3 = w3 << n | w2 >> ( WORD_SIZE-n );
				w2 = w2 << n | w1 >> ( WORD_SIZE-n );
				w1 = w1 << n | w0 >> ( WORD_SIZE-n );
				w0 = w0 << n;
			}
			return new UInt128(w3,w2,w1,w0);
		}
		public static UInt128 operator>>(UInt128 x,int n) {
			n &= 0x7f;

			uint w3 = x._w3;
			uint w2 = x._w2;
			uint w1 = x._w1;
			uint w0 = x._w0;
			for( ;n>=WORD_SIZE;n-=WORD_SIZE ) {
				w0 = w1;
				w1 = w2;
				w2 = w3;
				w3 = 0x0u;
			}
			if( n > 0 ) {
				w0 = w0 >> n | w1 << ( WORD_SIZE-n );
				w1 = w1 >> n | w2 << ( WORD_SIZE-n );
				w2 = w2 >> n | w3 << ( WORD_SIZE-n );
				w3 = w3 >> n;
			}
			return new UInt128(w3,w2,w1,w0);
		}
		#endregion

		#region override method
		public override string ToString() { return TextHex; }
		#endregion
		#region implement method
		public bool Equals(UInt128 other) {
			return this.Word0 == other.Word0
				&& this.Word1 == other.Word1
				&& this.Word2 == other.Word2
				&& this.Word3 == other.Word3;
		}
		#endregion

	}
}

InvertibleSFMT19937.cs

分かりやすさを重視して上記の UInt128 構造体を使っている。ので遅い。
実際に使うなら uint[4] で実装したほうがいいと思う。
ポケモンサンムーンの乱数計算にしか使わないのなら、(32ビットの乱数生成をしないので) ulong[2] のほうがなおいい。タブンネ

using System;
using System.Collections.Generic;
using System.Text;

// C#
namespace SFMTTest {
	public class InvertibleSFMT19937 {
		#region static constructor
		static InvertibleSFMT19937() {
			const uint MSK1 = 0xdfffffefu;      // SFMT_MSK1
			const uint MSK2 = 0xddfecb7fu;      // SFMT_MSK2
			const uint MSK3 = 0xbffaffffu;      // SFMT_MSK3
			const uint MSK4 = 0xbffffff6u;      // SFMT_MSK4
			const uint PARITY1 = 0x00000001u;   // SFMT_PARITY1
			const uint PARITY2 = 0x00000000u;   // SFMT_PARITY2
			const uint PARITY3 = 0x00000000u;   // SFMT_PARITY3
			const uint PARITY4 = 0x13c9e684u;   // SFMT_PARITY4
			const uint sr1Mask = 0xffffffffu >> SR1_BIT;
			const uint sl1Mask = 0xffffffffu << SL1_BIT;

			// initialize MSK_B
			MSK_B = new UInt128(                // MSK_B =
			                MSK4 & sr1Mask,     //     ( SFMT_MSK4 & ( 0xffffffffu >> SFMT_SR1 ) ) << 96 |
			                MSK3 & sr1Mask,     //     ( SFMT_MSK3 & ( 0xffffffffu >> SFMT_SR1 ) ) << 64 |
			                MSK2 & sr1Mask,     //     ( SFMT_MSK2 & ( 0xffffffffu >> SFMT_SR1 ) ) << 32 |
			                MSK1 & sr1Mask      //     ( SFMT_MSK1 & ( 0xffffffffu >> SFMT_SR1 ) )
			            );
			// initialize MSK_D
			MSK_D = new UInt128(                // MSK_D =
			                sl1Mask,            //     ( 0xffffffffu << SFMT_SL1 ) << 96 |
			                sl1Mask,            //     ( 0xffffffffu << SFMT_SL1 ) << 64 |
			                sl1Mask,            //     ( 0xffffffffu << SFMT_SL1 ) << 32 |
			                sl1Mask             //     ( 0xffffffffu << SFMT_SL1 )
			            );

			// initialize PARITY
			PARITY = new UInt128(               // PARITY =
			                PARITY4,            //     SFMT_PARITY4 << 96 |
			                PARITY3,            //     SFMT_PARITY3 << 64 |
			                PARITY2,            //     SFMT_PARITY2 << 32 |
			                PARITY1             //     SFMT_PARITY1
			            );
		}
		#endregion
		#region constructor
		public InvertibleSFMT19937(uint seed) {
			InitGenRand(seed);
		}
		#endregion

		#region static field
		private const int MEXP = 19937;             // SFMT_MEXP
		private const int N = MEXP / 128 + 1;       // SFMT_N
		//private const int N32 = N * 4;            // SFMT_N32
		//private const int N64 = N * 2;            // SFMT_N64
		private const int POS1 = 122;               // SFMT_POS1
		private const int SL1_BIT  = 18;            // SFMT_SL1   // shift-left  bit  size
		private const int SL2_BYTE = 1;             // SFMT_SL2   // shift-left  byte size
		private const int SR1_BIT  = 11;            // SFMT_SR1   // shift-right bit  size
		private const int SR2_BYTE = 1;             // SFMT_SR2   // shift-right byte size
		private const int SL2_BIT  = SL2_BYTE * 8;  //            // shift-left  bit  size
		private const int SR2_BIT  = SR2_BYTE * 8;  //            // shift-right bit  size
		private static readonly UInt128 MSK_B;      //            // initialize in static constructor
		private static readonly UInt128 MSK_D;      //            // initialize in static constructor
		private static readonly UInt128 PARITY;     //            // initialize in static constructor
		#endregion
		#region field
		private UInt128[] _state = new UInt128[N];  // 元のソースの sfmt_t.state に相当
		private int _index = 0;                     // 元のソースの idx に相当    ※ idx = ( _index % N ) * 4 + _pos
		private int _pos = 0;                       // 同上
		#endregion

		#region static method
		/// <summary>x を N で割った剰余(余り)を求める。</summary>
		/// <param name="x">剰余を求める値</param>
		/// <returns>x を N で割った剰余。 0 以上 N 未満の値を返す。</returns>
		private static int ModN(int x) {
			return x>=0 ? x % N : N - 1 - ( -1-x ) % N;
		}
		// 元のソースの do_recursion(w128_t *r, w128_t *a, w128_t *b, w128_t *c, w128_t *d) に相当
		private static UInt128 DoRecursion(UInt128 a,UInt128 b,UInt128 c,UInt128 d) {
			UInt128 x = a << SL2_BIT;
			UInt128 y = c >> SR2_BIT;
			UInt128 r = a ^ x ^ ( ( b >> SR1_BIT ) & MSK_B ) ^ y ^ ( ( d << SL1_BIT ) & MSK_D );
			return r;
		}
		// DoRecursion() の逆関数
		private static UInt128 DoRecursionInv(UInt128 r,UInt128 b,UInt128 c,UInt128 d) {
			UInt128 y = c >> SR2_BIT;
			UInt128 a = r ^ ( ( b >> SR1_BIT ) & MSK_B ) ^ y ^ ( ( d << SL1_BIT ) & MSK_D );
			for( int sh=SL2_BIT;sh<128;sh<<=1 ) {
				a ^= ( a << sh );
			}
			return a;
		}
		#endregion

		#region method / init
		// 元のソースの sfmt_init_gen_rand(sfmt_t * sfmt, uint32_t seed) に相当
		public void InitGenRand(uint seed) {
			uint[] arr = new uint[5];
			arr[0] = seed;
			for( uint j=0;j<N;j++ ) {
				for( uint i=1;i<=4;i++ ) {
					arr[i] = 1812433253u * ( arr[i-1] ^ ( arr[i-1]>>30 ) ) + i + j * 4;
				}
				_state[j] = new UInt128(arr[3],arr[2],arr[1],arr[0]);
				arr[0] = arr[4];
			}
			PeriodCertification();

			_index = 0;
			_pos = 0;

			// テーブル全体を更新
			for( int i=0;i<N;i++ ) {
				UpdateNext();
			}
			return;
		}
		// 元のソースの period_certification(sfmt_t * sfmt) に相当
		private void PeriodCertification() {
			// 立っているビット数の偶奇を数える
			UInt128 inner = _state[0] & PARITY;
			for( int i=64;i>0;i>>=1 ) {
				inner ^= inner >> i;
			}
			// 奇数個ならOK
			if( ( inner.Word0 & 1u ) == 1u ) {
				return;
			}

			// 偶数個だとNG → 修正
			UInt128 work = UInt128.One;
			for( int i=0;i<128;i++ ) {
				if( !( work & PARITY ).IsZero ) {
					_state[0] ^= work;
					return;
				}
				work <<= 1;
			}
			return;
		}
		#endregion
		#region method / update
		// 元のソースの sfmt_gen_rand_all(sfmt_t * sfmt) 中の for ループの中身に相当
		private void UpdateNext() {
			_state[ModN(_index)] = DoRecursion(
			                            _state[ModN(_index)],
			                            _state[ModN(_index+POS1)],
			                            _state[ModN(_index+N-2)],
			                            _state[ModN(_index+N-1)]
			                        );
			_index++;
			_pos = 0;
			return;
		}
		// UpdateNext() の逆関数
		private void UpdatePrev() {
			_index--;
			_pos = 3;
			_state[ModN(_index)] = DoRecursionInv(
			                            _state[ModN(_index)],
			                            _state[ModN(_index+POS1)],
			                            _state[ModN(_index+N-2)],
			                            _state[ModN(_index+N-1)]
			                        );
			return;
		}
		#endregion
		#region method / get rand
		// 元のソースの sfmt_genrand_uint32(sfmt_t * sfmt) に相当
		public uint NextUInt32() {
			uint r = _state[ModN(_index)][_pos++];
			if( _pos >= 4 ) {
				UpdateNext();
			}
			return r;
		}
		// 元のソースの sfmt_genrand_uint64(sfmt_t * sfmt) に相当
		public ulong NextUInt64() {
			ulong next;
			switch( _pos & 3 ) {
			case 0:
				next = _state[ModN(_index)].DWordL;
				_pos = 2;
				return next;
			case 2:
				next = _state[ModN(_index)].DWordH;
				UpdateNext();
				return next;
			default:
				// 元のソースの「 assert(sfmt->idx % 2 == 0); 」の部分に相当
				throw new Exception("_posが奇数です");
			}
		}
		// NextUInt32() の逆関数
		public uint PrevUInt32() {
			_pos--;
			if( _pos < 0 ) {
				UpdatePrev();
			}
			uint r = _state[ModN(_index)][_pos];
			return r;
		}
		// NextUInt64() の逆関数
		public ulong PrevUInt64() {
			ulong prev;
			switch( _pos & 3 ) {
			case 0:
				UpdatePrev();
				_pos = 2;
				prev = _state[ModN(_index)].DWordH;
				return prev;
			case 2:
				_pos = 0;
				prev = _state[ModN(_index)].DWordL;
				return prev;
			default:
				// 元のソースの「 assert(sfmt->idx % 2 == 0); 」の部分に相当
				throw new Exception("_posが奇数です");
			}
		}
		#endregion

	}
}