diff --git a/pystencils/include/philox_rand.h b/pystencils/include/philox_rand.h index eca71a200fe10ed2fa0f5b2cf8f6aa11930c8acc..0571eb39d114c5ea29c974e9630744d0bbcc1540 100644 --- a/pystencils/include/philox_rand.h +++ b/pystencils/include/philox_rand.h @@ -1,5 +1,4 @@ -#include <cstdint> - +#ifndef __OPENCL_VERSION__ #if defined(__SSE2__) || defined(_MSC_VER) #include <emmintrin.h> // SSE2 #endif @@ -33,12 +32,15 @@ #ifdef __riscv_v #include <riscv_vector.h> #endif +#endif -#ifndef __CUDA_ARCH__ +#ifdef __CUDA_ARCH__ +#define QUALIFIERS static __forceinline__ __device__ +#elif defined(__OPENCL_VERSION__) +#define QUALIFIERS static inline +#else #define QUALIFIERS inline #include "myintrin.h" -#else -#define QUALIFIERS static __forceinline__ __device__ #endif #define PHILOX_W32_0 (0x9E3779B9) @@ -48,8 +50,15 @@ #define TWOPOW53_INV_DOUBLE (1.1102230246251565e-16) #define TWOPOW32_INV_FLOAT (2.3283064e-10f) +#ifdef __OPENCL_VERSION__ +#include "opencl_stdint.h" +typedef uint32_t uint32; +typedef uint64_t uint64; +#else +#include <cstdint> typedef std::uint32_t uint32; typedef std::uint64_t uint64; +#endif #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0 typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS))); @@ -67,6 +76,9 @@ QUALIFIERS uint32 mulhilo32(uint32 a, uint32 b, uint32* hip) #if defined(__powerpc__) && (!defined(__clang__) || defined(__xlC__)) *hip = __mulhwu(a,b); return a*b; +#elif defined(__OPENCL_VERSION__) + *hip = mul_hi(a,b); + return a*b; #else uint64 product = ((uint64)a) * ((uint64)b); *hip = product >> 32; @@ -106,7 +118,12 @@ QUALIFIERS double _uniform_double_hq(uint32 x, uint32 y) QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, - uint32 key0, uint32 key1, double & rnd1, double & rnd2) + uint32 key0, uint32 key1, +#ifdef __OPENCL_VERSION__ + double * rnd1, double * rnd2) +#else + double & rnd1, double & rnd2) +#endif { uint32 key[2] = {key0, key1}; uint32 ctr[4] = {ctr0, ctr1, ctr2, ctr3}; @@ -121,14 +138,23 @@ QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9 _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 +#ifdef __OPENCL_VERSION__ + *rnd1 = _uniform_double_hq(ctr[0], ctr[1]); + *rnd2 = _uniform_double_hq(ctr[2], ctr[3]); +#else rnd1 = _uniform_double_hq(ctr[0], ctr[1]); rnd2 = _uniform_double_hq(ctr[2], ctr[3]); +#endif } QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3, uint32 key0, uint32 key1, +#ifdef __OPENCL_VERSION__ + float * rnd1, float * rnd2, float * rnd3, float * rnd4) +#else float & rnd1, float & rnd2, float & rnd3, float & rnd4) +#endif { uint32 key[2] = {key0, key1}; uint32 ctr[4] = {ctr0, ctr1, ctr2, ctr3}; @@ -143,13 +169,20 @@ QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3 _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 9 _philox4x32bumpkey(key); _philox4x32round(ctr, key); // 10 +#ifdef __OPENCL_VERSION__ + *rnd1 = ctr[0] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); + *rnd2 = ctr[1] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); + *rnd3 = ctr[2] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); + *rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); +#else rnd1 = ctr[0] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); rnd2 = ctr[1] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); rnd3 = ctr[2] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f); +#endif } -#ifndef __CUDA_ARCH__ +#if !defined(__CUDA_ARCH__) && !defined(__OPENCL_VERSION__) #if defined(__SSE4_1__) || defined(_MSC_VER) QUALIFIERS void _philox4x32round(__m128i* ctr, __m128i* key) { @@ -318,13 +351,13 @@ QUALIFIERS void _philox4x32round(__vector unsigned int* ctr, __vector unsigned i { #ifndef _ARCH_PWR8 __vector unsigned int lo0 = vec_mul(ctr[0], vec_splats(PHILOX_M4x32_0)); - __vector unsigned int lo1 = vec_mul(ctr[2], vec_splats(PHILOX_M4x32_1)); __vector unsigned int hi0 = vec_mulhuw(ctr[0], vec_splats(PHILOX_M4x32_0)); + __vector unsigned int lo1 = vec_mul(ctr[2], vec_splats(PHILOX_M4x32_1)); __vector unsigned int hi1 = vec_mulhuw(ctr[2], vec_splats(PHILOX_M4x32_1)); #elif defined(_ARCH_PWR10) __vector unsigned int lo0 = vec_mul(ctr[0], vec_splats(PHILOX_M4x32_0)); - __vector unsigned int lo1 = vec_mul(ctr[2], vec_splats(PHILOX_M4x32_1)); __vector unsigned int hi0 = vec_mulh(ctr[0], vec_splats(PHILOX_M4x32_0)); + __vector unsigned int lo1 = vec_mul(ctr[2], vec_splats(PHILOX_M4x32_1)); __vector unsigned int hi1 = vec_mulh(ctr[2], vec_splats(PHILOX_M4x32_1)); #else __vector unsigned int lohi0a = (__vector unsigned int) vec_mule(ctr[0], vec_splats(PHILOX_M4x32_0)); @@ -675,8 +708,8 @@ QUALIFIERS void philox_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32 QUALIFIERS void _philox4x32round(svuint32x4_t & ctr, svuint32x2_t & key) { svuint32_t lo0 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0)); - svuint32_t lo1 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 2), svdup_u32(PHILOX_M4x32_1)); svuint32_t hi0 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0)); + svuint32_t lo1 = svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 2), svdup_u32(PHILOX_M4x32_1)); svuint32_t hi1 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 2), svdup_u32(PHILOX_M4x32_1)); ctr = svset4_u32(ctr, 0, sveor_u32_x(svptrue_b32(), sveor_u32_x(svptrue_b32(), hi1, svget4_u32(ctr, 1)), svget2_u32(key, 0))); @@ -827,8 +860,8 @@ QUALIFIERS void _philox4x32round(vuint32m1_t & ctr0, vuint32m1_t & ctr1, vuint32 vuint32m1_t key0, vuint32m1_t key1) { vuint32m1_t lo0 = vmul_vv_u32m1(ctr0, vmv_v_x_u32m1(PHILOX_M4x32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1()); - vuint32m1_t lo1 = vmul_vv_u32m1(ctr2, vmv_v_x_u32m1(PHILOX_M4x32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1()); vuint32m1_t hi0 = vmulhu_vv_u32m1(ctr0, vmv_v_x_u32m1(PHILOX_M4x32_0, vsetvlmax_e32m1()), vsetvlmax_e32m1()); + vuint32m1_t lo1 = vmul_vv_u32m1(ctr2, vmv_v_x_u32m1(PHILOX_M4x32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1()); vuint32m1_t hi1 = vmulhu_vv_u32m1(ctr2, vmv_v_x_u32m1(PHILOX_M4x32_1, vsetvlmax_e32m1()), vsetvlmax_e32m1()); ctr0 = vxor_vv_u32m1(vxor_vv_u32m1(hi1, ctr1, vsetvlmax_e32m1()), key0, vsetvlmax_e32m1()); diff --git a/pystencils/rng.py b/pystencils/rng.py index f5a970e96c8539487515afa53705cf6cab280961..ec1d6214797a6078d40a8e26320d9c160e6a727f 100644 --- a/pystencils/rng.py +++ b/pystencils/rng.py @@ -53,8 +53,9 @@ class RNGBase(CustomCodeNode): else: code += f"{vector_instruction_set[r.dtype.base_name] if vector_instruction_set else r.dtype} " + \ f"{r.name};\n" - code += (self._name + "(" + ", ".join([print_arg(a) for a in self.args] - + [r.name for r in self.result_symbols]) + ");\n") + args = [print_arg(a) for a in self.args] + \ + [('&' if dialect == 'opencl' else '') + r.name for r in self.result_symbols] + code += (self._name + "(" + ", ".join(args) + ");\n") return code def __repr__(self): diff --git a/pystencils_tests/test_random.py b/pystencils_tests/test_random.py index 18ff23b7bb7b7dd4062957a07b84dd6dd04bbf70..d3686a7320bcbd12aad92e7791c4db48d91a6b78 100644 --- a/pystencils_tests/test_random.py +++ b/pystencils_tests/test_random.py @@ -21,12 +21,16 @@ if get_compiler_config()['os'] == 'windows': instruction_sets.remove('avx512') -@pytest.mark.parametrize('target,rng', (('cpu', 'philox'), ('cpu', 'aesni'), ('gpu', 'philox'))) +@pytest.mark.parametrize('target,rng', (('cpu', 'philox'), ('cpu', 'aesni'), ('gpu', 'philox'), ('opencl', 'philox'))) @pytest.mark.parametrize('precision', ('float', 'double')) @pytest.mark.parametrize('dtype', ('float', 'double')) def test_rng(target, rng, precision, dtype, t=124, offsets=(0, 0), keys=(0, 0), offset_values=None): if target == 'gpu': pytest.importorskip('pycuda') + if target == 'opencl': + pytest.importorskip('pyopencl') + from pystencils.opencl.opencljit import init_globally + init_globally() if instruction_sets and set(['neon', 'sve', 'vsx', 'rvv']).intersection(instruction_sets) and rng == 'aesni': pytest.xfail('AES not yet implemented for this architecture') if rng == 'aesni' and len(keys) == 2: