Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
linear_scaling.h
Go to the documentation of this file.
1 
2 /***************************************************************************************************
3  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without modification, are permitted
6  * provided that the following conditions are met:
7  * * Redistributions of source code must retain the above copyright notice, this list of
8  * conditions and the following disclaimer.
9  * * Redistributions in binary form must reproduce the above copyright notice, this list of
10  * conditions and the following disclaimer in the documentation and/or other materials
11  * provided with the distribution.
12  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
13  * to endorse or promote products derived from this software without specific prior written
14  * permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
17  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
18  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
19  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
20  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
21  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
22  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  *
25  **************************************************************************************************/
29 #pragma once
30 
32 
33 namespace cutlass {
34 namespace gemm {
35 
37 
/// Generic zero test: compares `x` against a value-constructed zero of the same type
/// using the type's own operator==.
template <typename T>
CUTLASS_DEVICE bool is_zero(T x) {
  T const zero = T(0);
  bool const equals_zero = (x == zero);
  return equals_zero;
}
42 
43 #if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
/// Zero test for `half`, performed on the raw 16-bit storage of the value rather than
/// through a floating-point comparison.
///
/// The sign bit (bit 15) is masked off before comparing so that both +0 (0x0000) and
/// -0 (0x8000) are reported as zero. This keeps the overload consistent with the generic
/// is_zero<T>() template, whose `x == T(0)` comparison is true for negative zero on
/// float/double.
CUTLASS_DEVICE bool is_zero(half x) {
  // 0x7fff keeps exponent + mantissa and drops the sign bit.
  return (reinterpret_cast<int16_t&>(x) & int16_t(0x7fff)) == int16_t(0);
}
45 #endif
46 
48 
/// Functor to compute a linear combination of fragments.
///
/// The two-operand form scales the accumulators by `params.alpha`; the three-operand form
/// additionally folds in a source fragment scaled by `params.beta`. All arithmetic is
/// delegated to FragmentMultiplyAdd_ (its `multiply` / `multiply_add` members), so the exact
/// element-wise semantics are those of that adapter.
template <typename Scalar_, typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_> >
struct LinearScaling {
  // The scalar.
  typedef Scalar_ Scalar;
  // The accumulator type.
  typedef typename FragmentMultiplyAdd_::ScalarAccum ScalarAccum;
  // The adapter.
  typedef FragmentMultiplyAdd_ FragmentMultiplyAdd;

  /// The parameters.
  struct Params {
    /// The alpha/beta scaling params.
    Scalar alpha, beta;

    //
    // Methods
    //

    // Constructor. Both scalars default to 0.
    CUTLASS_HOST_DEVICE
    Params(Scalar _alpha = 0, Scalar _beta = 0) : alpha(_alpha), beta(_beta) {}

    /// Initialize the parameters from explicit scalars. Always returns 0.
    CUTLASS_HOST_DEVICE int initialize(Scalar _alpha, Scalar _beta) {
      alpha = _alpha;
      beta = _beta;
      return 0;
    }

    /// Initialize the parameters from any descriptor exposing `alpha` and `beta` members.
    /// Always returns 0.
    template <typename GemmDesc_>
    CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
      alpha = desc.alpha;
      beta = desc.beta;
      return 0;
    }
  };

  //
  // Data members
  //

  /// The alpha/beta parameters read by source_required() and the evaluate() overloads.
  Params params;

  //
  // Methods
  //

  /// Ctor. Leaves `params` default-constructed (alpha = beta = 0).
  CUTLASS_DEVICE LinearScaling() { }

  /// Ctor. Copies the given parameters.
  CUTLASS_DEVICE LinearScaling(Params const& _params) : params(_params) {}

  /// Returns true when beta is non-zero, i.e. when the three-operand evaluate() (which
  /// consumes a source fragment) is needed instead of the two-operand one.
  CUTLASS_DEVICE
  bool source_required() const {
    return !is_zero(params.beta);
  }

  /// Evaluate the functor: scales `accum` by alpha into `output` via the adapter's multiply().
  template <typename FragmentA_, typename FragmentB_>
  CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_& output) {
    FragmentMultiplyAdd mad;
    mad.multiply(params.alpha, accum, output);
  }

  /// Evaluate the functor, without using fragment in the API.
  ///
  /// `accum` and `output` must each reference at least `size` elements. The data is staged
  /// through local Fragment objects, the fragment overload is invoked, and the result is
  /// copied back to `output`.
  template <typename ScalarAccum, typename ScalarOutput, int size>
  CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput *output) {
    Fragment<ScalarAccum, size> FragAccum;
    Fragment<ScalarOutput, size> FragOutput;
#pragma unroll
    for (int i = 0; i < size; i++) {
      FragAccum[i] = accum[i];
      // NOTE(review): output is copied in before being recomputed; presumably redundant,
      // kept to preserve the original behavior.
      FragOutput[i] = output[i];
    }
    evaluate(FragAccum, FragOutput);
#pragma unroll
    for (int i = 0; i < size; i++) {
      output[i] = FragOutput[i];
    }
  }

  /// Evaluate the functor with a source operand: `old` is first scaled by beta into a
  /// temporary, which is then combined with alpha and `accum` through the adapter's
  /// multiply_add() to produce `output`.
  template <typename FragmentA_, typename FragmentB_>
  CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_ const& old, FragmentB_& output) {
    FragmentMultiplyAdd mad;
    FragmentB_ tmp;
    mad.multiply(params.beta, old, tmp);
    mad.multiply_add(params.alpha, accum, tmp, output);
  }

  /// Evaluate the functor with a source operand, without using fragment in the API.
  ///
  /// `accum`, `old` and `output` must each reference at least `size` elements. The data is
  /// staged through local Fragment objects, the fragment overload is invoked, and the result
  /// is copied back to `output`.
  template <typename ScalarAccum, typename ScalarOutput, int size>
  CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput const *old, ScalarOutput *output) {
    Fragment<ScalarAccum, size> FragAccum;
    Fragment<ScalarOutput, size> FragOutput;
    Fragment<ScalarOutput, size> FragOld;
#pragma unroll
    for (int i = 0; i < size; i++) {
      FragAccum[i] = accum[i];
      // NOTE(review): output is copied in before being recomputed; presumably redundant,
      // kept to preserve the original behavior.
      FragOutput[i] = output[i];
      FragOld[i] = old[i];
    }
    evaluate(FragAccum, FragOld, FragOutput);
#pragma unroll
    for (int i = 0; i < size; i++) {
      output[i] = FragOutput[i];
    }
  }
};
165 
167 
168 } // namespace gemm
169 } // namespace cutlass
CUTLASS_HOST_DEVICE int initialize(Scalar _alpha, Scalar _beta)
Initialize the parameters.
Definition: linear_scaling.h:73
Definition: convert.h:33
Scalar alpha
The alpha/beta scaling params.
Definition: linear_scaling.h:62
CUTLASS_DEVICE bool source_required() const
Definition: linear_scaling.h:108
CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput *output)
Evaluate the functor, without using fragment in the API.
Definition: linear_scaling.h:122
CUTLASS_DEVICE void evaluate(FragmentA_ const &accum, FragmentB_ const &old, FragmentB_ &output)
Evaluate the functor.
Definition: linear_scaling.h:139
CUTLASS_DEVICE void evaluate(FragmentA_ const &accum, FragmentB_ &output)
Evaluate the functor.
Definition: linear_scaling.h:114
Scalar beta
Definition: linear_scaling.h:62
A template defining Fragment Concept.
Definition: fragment.h:99
Params params
Definition: linear_scaling.h:92
FragmentMultiplyAdd_::ScalarAccum ScalarAccum
Definition: linear_scaling.h:55
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc)
Initialize the parameters.
Definition: linear_scaling.h:81
Defines multiply-add operations on fragments within a thread.
FragmentMultiplyAdd_ FragmentMultiplyAdd
Definition: linear_scaling.h:57
CUTLASS_DEVICE LinearScaling()
Ctor.
Definition: linear_scaling.h:99
CUTLASS_DEVICE bool is_zero(T x)
Definition: linear_scaling.h:39
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_DEVICE LinearScaling(Params const &_params)
Ctor.
Definition: linear_scaling.h:102
The parameters.
Definition: linear_scaling.h:60
Functor to compute linear combination of fragments.
Definition: linear_scaling.h:51
Scalar_ Scalar
Definition: linear_scaling.h:53
CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput const *old, ScalarOutput *output)
Evaluate the functor, without using fragment in the API.
Definition: linear_scaling.h:148
CUTLASS_HOST_DEVICE Params(Scalar _alpha=0, Scalar _beta=0)
Definition: linear_scaling.h:70