#ifndef HELPER_MATH_H
#define HELPER_MATH_H

#include <cuda_runtime.h>
#include "cuda_fp16.h"
#include <stdio.h>

// Overloads of operator* (vector * scalar)

// Scales a float2 by a scalar: returns (a.x * b, a.y * b).
__device__ __forceinline__ float2 operator*(float2 a, float b)
{
    // `a` is a by-value copy, so it can be scaled in place and handed back.
    a.x *= b;
    a.y *= b;
    return a;
}

// Scales a float4 by a scalar: returns (a.x * b, a.y * b, a.z * b, a.w * b).
__device__ __forceinline__ float4 operator*(float4 a, float b)
{
    // `a` is a by-value copy, so it can be scaled in place and handed back.
    a.x *= b;
    a.y *= b;
    a.z *= b;
    a.w *= b;
    return a;
}

// In-place max helpers (a = max(a, b)) and the half comparison operator

// In-place maximum: after the call `a` holds `b` unless `a` was strictly
// greater than `b`.
__device__ __forceinline__ void vmax(float &a, float b)
{
    // Same selection as `a = a > b ? a : b`: `b` is taken on ties and whenever
    // the comparison evaluates to false (e.g. a NaN operand).
    if (!(a > b)) {
        a = b;
    }
}


// In-place component-wise maximum of two float2 values; the result is stored
// in `a`.
__device__ __forceinline__ void vmax(float2 &a, float2 b)
{
    // Each component keeps its value only when it is strictly greater than
    // the matching component of `b`; otherwise it takes b's component.
    if (!(a.x > b.x)) { a.x = b.x; }
    if (!(a.y > b.y)) { a.y = b.y; }
}

// In-place component-wise maximum of two float4 values; the result is stored
// in `a`.
__device__ __forceinline__ void vmax(float4 &a, float4 b)
{
    // Each component keeps its value only when it is strictly greater than
    // the matching component of `b`; otherwise it takes b's component.
    if (!(a.x > b.x)) { a.x = b.x; }
    if (!(a.y > b.y)) { a.y = b.y; }
    if (!(a.z > b.z)) { a.z = b.z; }
    if (!(a.w > b.w)) { a.w = b.w; }
}

// Strict greater-than for two half values; forwards to the __hgt intrinsic.
// NOTE(review): cuda_fp16.h can provide its own half comparison operators
// unless __CUDA_NO_HALF_OPERATORS__ is defined — confirm the build defines it
// so this overload cannot become ambiguous.
__device__ __forceinline__ bool operator>(half a, half b)
{
    const bool a_is_greater = __hgt(a, b);
    return a_is_greater;
}

// Array helpers

// Fixed-length array of Len elements of type T with device-side helpers.
template <typename T, int Len>
struct Array {
    // Element storage; public so callers can index it directly.
    T data[Len];

    // Default construction assigns nothing to the elements.
    __device__ __forceinline__ Array(){}

    // Assigns `value` (a float, converted to T by the assignment) to every
    // element.
    __device__ __forceinline__ Array(float value){
        for (int slot = 0; slot < Len; slot++){
            data[slot] = value;
        }
    }

    // Assigns 0.0f to every element.
    __device__ __forceinline__ void reset(){
        for (int slot = 0; slot < Len; slot++){
            data[slot] = 0.0f;
        }
    }
};


// Binary functor that adds two Array<T, Len> values element by element.
//
// operator() is const-qualified: the functor holds no state, so it can also
// be invoked through a const object or a const reference.
template <typename T, int Len>
struct ArrayAddFunc {
  // Returns a new array whose i-th element is p1.data[i] + p2.data[i].
  // Neither input is modified.
  __device__ __forceinline__ Array<T, Len> operator()(const Array<T, Len>& p1,
                                                      const Array<T, Len>& p2) const {
    Array<T, Len> result;
    for (int i = 0; i < Len; ++i) {
      result.data[i] = p1.data[i] + p2.data[i];
    }
    return result;
  }
};

// Debug helper: prints `value` only when the caller-supplied ids satisfy
// blockid == 0 and threadid == 0; for any other ids it does nothing.
//
// Declared __forceinline__ like the other helpers in this header so that
// including the header from several translation units does not emit duplicate
// definitions of this device function when device code is linked.
__device__ __forceinline__ void print_val(int blockid, int threadid, float value){
    if (blockid != 0 || threadid != 0) return;
    // `value` is already a float; printf's varargs promotion handles the rest.
    printf("tid: %d, value is: %.8f\n", threadid, value);
}

#endif