Example: Dot Product with FMA
/* Overlays a packed 4-float SSE value with an array of its individual lanes. */
typedef union {
    __m128 f4;   // the packed vector
    float  f[4]; // the same storage viewed lane by lane
} v4f;
/*
 * Dot product of two float arrays, vectorised with SSE and FMA.
 *
 * a, b  input arrays holding at least n elements each; they are only read
 *       and need no particular alignment
 * n     number of elements to process; n <= 0 yields 0.0f
 *
 * Returns the sum of a[i] * b[i] for i in [0, n).
 */
float dot_product_fma(const float *a, const float *b, int n)
{
    if (n <= 0) {
        return 0.0f;
    }

    const int vec_end = n - (n % 4); // largest multiple of 4 not exceeding n
    __m128 cv = _mm_setzero_ps();    // four running partial sums, one per lane

    for (int i = 0; i < vec_end; i += 4) {
        // Unaligned loads: the caller's arrays are plain float pointers, and
        // _mm_load_ps would require every address to be 16-byte aligned.
        __m128 av = _mm_loadu_ps(&a[i]); // load 4 floats from a
        __m128 bv = _mm_loadu_ps(&b[i]); // load 4 floats from b
        cv = _mm_fmadd_ps(av, bv, cv);   // cv = av*bv + cv
    }

    // Spill the four lanes to memory and add them in lane order.
    float lanes[4];
    _mm_storeu_ps(lanes, cv);
    float sum = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    // Scalar tail for the 0..3 elements that do not fill a whole vector.
    for (int i = vec_end; i < n; i++) {
        sum += a[i] * b[i];
    }

    return sum;
}
/* Demo: prints the dot product of two 8-element vectors (204.000000). */
int main(void)
{
    // 16-byte alignment keeps the arrays valid for aligned SSE loads
    // (_mm_load_ps requires addresses that are multiples of 16).
    _Alignas(16) float a[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
    _Alignas(16) float b[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
    const int n = (int)(sizeof a / sizeof a[0]); // element count, here 8

    float res = dot_product_fma(a, b, n);
    printf("%f\n", res); // res=204.0
    return 0;
}