Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
linear_scaling_device_ptr.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/cutlass.h"
33 
34 namespace cutlass {
35 
37 
38 namespace gemm {
39 
41 
45 template <typename Scalar_, typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_> >
46 struct LinearScalingDevicePtr : public LinearScaling<Scalar_, FragmentMultiplyAdd_> {
47 
50 
51  // The scalar.
52  typedef typename Base::Scalar Scalar;
53 
55  class Params {
56  private:
59 
62 
63  public:
64  //
65  // Methods
66  //
67 
68  // Constructor
70  Params() {}
71 
72  // Constructor
75  Scalar alpha,
76  Scalar beta
77  ):
78  alpha_(alpha),
79  beta_(beta) {}
80 
81  // Constructor
84  Scalar const *alpha_ptr,
85  Scalar const *beta_ptr
86  ):
87  alpha_(alpha_ptr),
88  beta_(alpha_ptr) {}
89 
92  Scalar alpha,
93  Scalar beta) {
94 
95  alpha_ = alpha;
96  beta_ = beta;
97 
98  return 0;
99  }
100 
103  Scalar const *alpha,
104  Scalar const *beta) {
105 
106  alpha_ = alpha;
107  beta_= beta;
108 
109  return 0;
110  }
111 
113  template <typename GemmDesc_>
114  CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
115 
116  alpha_ = desc.alpha;
117  beta_ = desc.beta;
118 
119  return 0;
120  }
121 
124  Scalar alpha() const {
125  return alpha_;
126  }
127 
130  Scalar beta() const {
131  return beta_;
132  }
133  };
134 
135  //
136  // Methods
137  //
138 
141  this->params.alpha = _params.alpha();
142  this->params.beta = _params.beta();
143  }
144 };
145 
147 
148 } // namespace gemm
149 } // namespace cutlass
CUTLASS_HOST_DEVICE int initialize(Scalar const *alpha, Scalar const *beta)
Initialize the parameters.
Definition: linear_scaling_device_ptr.h:102
The parameters.
Definition: linear_scaling_device_ptr.h:55
Definition: convert.h:33
CUTLASS_HOST_DEVICE Params(Scalar const *alpha_ptr, Scalar const *beta_ptr)
Definition: linear_scaling_device_ptr.h:83
Implements the BLAS linear scaling function alpha*AB + beta*C.
Implements the BLAS linear scaling function alpha*AB + beta*C.
CUTLASS_HOST_DEVICE int initialize(Scalar alpha, Scalar beta)
Initialize the parameters.
Definition: linear_scaling_device_ptr.h:91
Params params
Definition: linear_scaling.h:92
LinearScaling< Scalar_, FragmentMultiplyAdd_ > Base
Linear Scaling class used.
Definition: linear_scaling_device_ptr.h:49
CUTLASS_HOST_DEVICE Params()
Definition: linear_scaling_device_ptr.h:70
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE Params(Scalar alpha, Scalar beta)
Definition: linear_scaling_device_ptr.h:74
CUTLASS_HOST_DEVICE Scalar beta() const
Gets the beta scalar.
Definition: linear_scaling_device_ptr.h:130
CUTLASS_HOST_DEVICE LinearScalingDevicePtr(Params const &_params)
Ctor.
Definition: linear_scaling_device_ptr.h:140
CUTLASS_HOST_DEVICE Scalar alpha() const
Gets the alpha scalar.
Definition: linear_scaling_device_ptr.h:124
Definition: linear_scaling_device_ptr.h:46
Functor to compute linear combination of fragments.
Definition: linear_scaling.h:51
Scalar_ Scalar
Definition: linear_scaling.h:53
Base::Scalar Scalar
Definition: linear_scaling_device_ptr.h:52
Basic include for CUTLASS macros.
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc)
Initialize the parameters.
Definition: linear_scaling_device_ptr.h:114