正逆双方向に計算できる SFMT
正逆双方向に計算できる SFMT のクラスを作ってみた(※ただし周期が \(2^{19937}-1\) のものに限定した)。
逆算部分の理論はメルセンヌ・ツイスタの tempering の逆関数や TinyMT の更新関数の逆関数とほぼ変わらないので、詳細はそちらを参照。
速度を度外視して汎用性重視で作ったので高速化は各自で。
というか高速化プログラミング苦手だし、そもそも C# 使ってる時点で遅いし、みたいな。
それとおそらく SIMD-oriented な実装になってないので、その意味でも Fast じゃない。
結果が SFMT と同じになるだけ……
UInt128.cs
SFMT の更新の様子が分かりやすくなるように、 128ビットの符号なし整数型を使用する。
using System; using System.Collections.Generic; using System.Text; // C# namespace SFMTTest { internal struct UInt128 : IEquatable<UInt128> { #region constructor public UInt128(uint word3,uint word2,uint word1,uint word0) { this._w3 = word3; this._w2 = word2; this._w1 = word1; this._w0 = word0; } #endregion #region static field private const int WORD_SIZE = 32; #endregion #region static property public static UInt128 Zero { get { return new UInt128(0u,0u,0u,0u); } } public static UInt128 One { get { return new UInt128(0u,0u,0u,1u); } } #endregion #region field property // 第3ワード(最上位ワード) private readonly uint _w3; public uint Word3 { get { return _w3; } } // 第2ワード private readonly uint _w2; public uint Word2 { get { return _w2; } } // 第1ワード private readonly uint _w1; public uint Word1 { get { return _w1; } } // 第0ワード(最下位ワード) private readonly uint _w0; public uint Word0 { get { return _w0; } } #endregion #region property // 上位倍長ワード(上位64bit) public ulong DWordH { get { return ( (ulong)_w3 << 32 ) | _w2; } } // 下位倍長ワード(下位64bit) public ulong DWordL { get { return ( (ulong)_w1 << 32 ) | _w0; } } #endregion #region public bool IsZero { get { return _w0 == 0x0u && _w1 == 0x0u && _w2 == 0x0u && _w3 == 0x0u; } } #endregion #region indexer public uint this[int n] { get { switch( n & 3 ) { case 0: return _w0; case 1: return _w1; case 2: return _w2; case 3: return _w3; default: throw new IndexOutOfRangeException(); } } } #endregion #region property public string TextBin { get { return Convert.ToString(_w3,2).PadLeft(WORD_SIZE,'0') + Convert.ToString(_w2,2).PadLeft(WORD_SIZE,'0') + Convert.ToString(_w1,2).PadLeft(WORD_SIZE,'0') + Convert.ToString(_w0,2).PadLeft(WORD_SIZE,'0'); } } public string TextHex { get { return _w3.ToString("X08") + _w2.ToString("X08") + _w1.ToString("X08") + _w0.ToString("X08"); } } #endregion #region operator public static explicit operator UInt128(uint x) { return new UInt128(0u,0u,0u,x); } public static explicit operator UInt128(ulong x) { return new UInt128(0u,0u,(uint)( x>>32 ),(uint)x); } public static UInt128 operator&(UInt128 x,UInt128 y) { return new UInt128( x._w3 & y._w3, x._w2 & y._w2, x._w1 & y._w1, x._w0 & y._w0 ); } public static UInt128 operator^(UInt128 x,UInt128 y) { return new UInt128( x._w3 ^ y._w3, x._w2 ^ y._w2, x._w1 ^ y._w1, x._w0 ^ y._w0 ); } //public static UInt128 operator|(UInt128 x,UInt128 y) { // return new UInt128( // x._w3 | y._w3, // x._w2 | y._w2, // x._w1 | y._w1, // x._w0 | y._w0 // ); //} public static UInt128 operator<<(UInt128 x,int n) { n &= 0x7f; uint w3 = x._w3; uint w2 = x._w2; uint w1 = x._w1; uint w0 = x._w0; for( ;n>=WORD_SIZE;n-=WORD_SIZE ) { w3 = w2; w2 = w1; w1 = w0; w0 = 0x0u; } if( n > 0 ) { w3 = w3 << n | w2 >> ( WORD_SIZE-n ); w2 = w2 << n | w1 >> ( WORD_SIZE-n ); w1 = w1 << n | w0 >> ( WORD_SIZE-n ); w0 = w0 << n; } return new UInt128(w3,w2,w1,w0); } public static UInt128 operator>>(UInt128 x,int n) { n &= 0x7f; uint w3 = x._w3; uint w2 = x._w2; uint w1 = x._w1; uint w0 = x._w0; for( ;n>=WORD_SIZE;n-=WORD_SIZE ) { w0 = w1; w1 = w2; w2 = w3; w3 = 0x0u; } if( n > 0 ) { w0 = w0 >> n | w1 << ( WORD_SIZE-n ); w1 = w1 >> n | w2 << ( WORD_SIZE-n ); w2 = w2 >> n | w3 << ( WORD_SIZE-n ); w3 = w3 >> n; } return new UInt128(w3,w2,w1,w0); } #endregion #region override method public override string ToString() { return TextHex; } #endregion #region implement method public bool Equals(UInt128 other) { return this.Word0 == other.Word0 && this.Word1 == other.Word1 && this.Word2 == other.Word2 && this.Word3 == other.Word3; } #endregion } }
InvertibleSFMT19937.cs
分かりやすさを重視して上記の UInt128 構造体を使っている。ので遅い。
実際に使うなら uint[4] で実装したほうがいいと思う。
ポケモンサンムーンの乱数計算にしか使わないのなら、(32ビットの乱数生成をしないので) ulong[2] のほうがなおいい。タブンネ。
using System; using System.Collections.Generic; using System.Text; // C# namespace SFMTTest { public class InvertibleSFMT19937 { #region static constructor static InvertibleSFMT19937() { const uint MSK1 = 0xdfffffefu; // SFMT_MSK1 const uint MSK2 = 0xddfecb7fu; // SFMT_MSK2 const uint MSK3 = 0xbffaffffu; // SFMT_MSK3 const uint MSK4 = 0xbffffff6u; // SFMT_MSK4 const uint PARITY1 = 0x00000001u; // SFMT_PARITY1 const uint PARITY2 = 0x00000000u; // SFMT_PARITY2 const uint PARITY3 = 0x00000000u; // SFMT_PARITY3 const uint PARITY4 = 0x13c9e684u; // SFMT_PARITY4 const uint sr1Mask = 0xffffffffu >> SR1_BIT; const uint sl1Mask = 0xffffffffu << SL1_BIT; // initialize MSK_B MSK_B = new UInt128( // MSK_B = MSK4 & sr1Mask, // ( SFMT_MSK4 & ( 0xffffffffu >> SFMT_SR1 ) ) << 96 | MSK3 & sr1Mask, // ( SFMT_MSK3 & ( 0xffffffffu >> SFMT_SR1 ) ) << 64 | MSK2 & sr1Mask, // ( SFMT_MSK2 & ( 0xffffffffu >> SFMT_SR1 ) ) << 32 | MSK1 & sr1Mask // ( SFMT_MSK1 & ( 0xffffffffu >> SFMT_SR1 ) ) ); // initialize MSK_D MSK_D = new UInt128( // MSK_D = sl1Mask, // ( 0xffffffffu << SFMT_SL1 ) << 96 | sl1Mask, // ( 0xffffffffu << SFMT_SL1 ) << 64 | sl1Mask, // ( 0xffffffffu << SFMT_SL1 ) << 32 | sl1Mask // ( 0xffffffffu << SFMT_SL1 ) ); // initialize PARITY PARITY = new UInt128( // PARITY = PARITY4, // SFMT_PARITY4 << 96 | PARITY3, // SFMT_PARITY3 << 64 | PARITY2, // SFMT_PARITY2 << 32 | PARITY1 // SFMT_PARITY1 ); } #endregion #region constructor public InvertibleSFMT19937(uint seed) { InitGenRand(seed); } #endregion #region static field private const int MEXP = 19937; // SFMT_MEXP private const int N = MEXP / 128 + 1; // SFMT_N //private const int N32 = N * 4; // SFMT_N32 //private const int N64 = N * 2; // SFMT_N64 private const int POS1 = 122; // SFMT_POS1 private const int SL1_BIT = 18; // SFMT_SL1 // shift-left bit size private const int SL2_BYTE = 1; // SFMT_SL2 // shift-left byte size private const int SR1_BIT = 11; // SFMT_SR1 // shift-right bit size private const int SR2_BYTE = 1; // SFMT_SR2 // shift-right byte size private const int SL2_BIT = SL2_BYTE * 8; // // shift-left bit size private const int SR2_BIT = SR2_BYTE * 8; // // shift-right bit size private static readonly UInt128 MSK_B; // // initialize in static constructor private static readonly UInt128 MSK_D; // // initialize in static constructor private static readonly UInt128 PARITY; // // initialize in static constructor #endregion #region field private UInt128[] _state = new UInt128[N]; // 元のソースの sfmt_t.state に相当 private int _index = 0; // 元のソースの idx に相当 ※ idx = ( _index % N ) * 4 + _pos private int _pos = 0; // 同上 #endregion #region static method /// <summary>x を N で割った剰余(余り)を求める。</summary> /// <param name="x">剰余を求める値</param> /// <returns>x を N で割った剰余。 0 以上 N 未満の値を返す。</returns> private static int ModN(int x) { return x>=0 ? x % N : N - 1 - ( -1-x ) % N; } // 元のソースの do_recursion(w128_t *r, w128_t *a, w128_t *b, w128_t *c, w128_t *d) に相当 private static UInt128 DoRecursion(UInt128 a,UInt128 b,UInt128 c,UInt128 d) { UInt128 x = a << SL2_BIT; UInt128 y = c >> SR2_BIT; UInt128 r = a ^ x ^ ( ( b >> SR1_BIT ) & MSK_B ) ^ y ^ ( ( d << SL1_BIT ) & MSK_D ); return r; } // DoRecursion() の逆関数 private static UInt128 DoRecursionInv(UInt128 r,UInt128 b,UInt128 c,UInt128 d) { UInt128 y = c >> SR2_BIT; UInt128 a = r ^ ( ( b >> SR1_BIT ) & MSK_B ) ^ y ^ ( ( d << SL1_BIT ) & MSK_D ); for( int sh=SL2_BIT;sh<128;sh<<=1 ) { a ^= ( a << sh ); } return a; } #endregion #region method / init // 元のソースの sfmt_init_gen_rand(sfmt_t * sfmt, uint32_t seed) に相当 public void InitGenRand(uint seed) { uint[] arr = new uint[5]; arr[0] = seed; for( uint j=0;j<N;j++ ) { for( uint i=1;i<=4;i++ ) { arr[i] = 1812433253u * ( arr[i-1] ^ ( arr[i-1]>>30 ) ) + i + j * 4; } _state[j] = new UInt128(arr[3],arr[2],arr[1],arr[0]); arr[0] = arr[4]; } PeriodCertification(); _index = 0; _pos = 0; // テーブル全体を更新 for( int i=0;i<N;i++ ) { UpdateNext(); } return; } // 元のソースの period_certification(sfmt_t * sfmt) に相当 private void PeriodCertification() { // 立っているビット数の偶奇を数える UInt128 inner = _state[0] & PARITY; for( int i=64;i>0;i>>=1 ) { inner ^= inner >> i; } // 奇数個ならOK if( ( inner.Word0 & 1u ) == 1u ) { return; } // 偶数個だとNG → 修正 UInt128 work = UInt128.One; for( int i=0;i<128;i++ ) { if( !( work & PARITY ).IsZero ) { _state[0] ^= work; return; } work <<= 1; } return; } #endregion #region method / update // 元のソースの sfmt_gen_rand_all(sfmt_t * sfmt) 中の for ループの中身に相当 private void UpdateNext() { _state[ModN(_index)] = DoRecursion( _state[ModN(_index)], _state[ModN(_index+POS1)], _state[ModN(_index+N-2)], _state[ModN(_index+N-1)] ); _index++; _pos = 0; return; } // UpdateNext() の逆関数 private void UpdatePrev() { _index--; _pos = 3; _state[ModN(_index)] = DoRecursionInv( _state[ModN(_index)], _state[ModN(_index+POS1)], _state[ModN(_index+N-2)], _state[ModN(_index+N-1)] ); return; } #endregion #region method / get rand // 元のソースの sfmt_genrand_uint32(sfmt_t * sfmt) に相当 public uint NextUInt32() { uint r = _state[ModN(_index)][_pos++]; if( _pos >= 4 ) { UpdateNext(); } return r; } // 元のソースの sfmt_genrand_uint64(sfmt_t * sfmt) に相当 public ulong NextUInt64() { ulong next; switch( _pos & 3 ) { case 0: next = _state[ModN(_index)].DWordL; _pos = 2; return next; case 2: next = _state[ModN(_index)].DWordH; UpdateNext(); return next; default: // 元のソースの「 assert(sfmt->idx % 2 == 0); 」の部分に相当 throw new Exception("_posが奇数です"); } } // NextUInt32() の逆関数 public uint PrevUInt32() { _pos--; if( _pos < 0 ) { UpdatePrev(); } uint r = _state[ModN(_index)][_pos]; return r; } // NextUInt64() の逆関数 public ulong PrevUInt64() { ulong prev; switch( _pos & 3 ) { case 0: UpdatePrev(); _pos = 2; prev = _state[ModN(_index)].DWordH; return prev; case 2: _pos = 0; prev = _state[ModN(_index)].DWordL; return prev; default: // 元のソースの「 assert(sfmt->idx % 2 == 0); 」の部分に相当 throw new Exception("_posが奇数です"); } } #endregion } }