Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
hgemm_swizzle.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <cuda_fp16.h>
32 #include "cutlass/fragment.h"
33 
34 namespace cutlass {
35 namespace gemm {
36 
38 
39 template <typename GlobalIterator_>
40 struct HgemmSwizzle {
42  typedef GlobalIterator_ GlobalIterator;
44  typedef typename GlobalIterator::Fragment Fragment;
46  typedef typename GlobalIterator::FragmentShape FragmentShape;
47 
52 
55 
57  static_assert(FragmentShape::kH == 2 && ShapeCount<FragmentShape>::kWc == 2, "Not multiple of 2");
58 
60  CUTLASS_DEVICE HgemmSwizzle() {}
61 
63  CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {
64  // Expose src/dst as int arrays.
65  int const* src_int = reinterpret_cast<int const*>(&src[0]);
66  int* dst_int = reinterpret_cast<int*>(&dst[0]);
67 
68  // Transpose the data.
69  for (int d = 0; d < FragmentShape::kD; ++d) {
70  // The indices to read two consecutive "rows".
71  int const i0 = 2 * d + 0;
72  int const i1 = 2 * d + 1;
73 
74  int a0 = src_int[i0];
75  int a1 = src_int[i1];
76 
77  int b0, b1;
78  asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(a0), "r"(a1));
79  asm volatile("prmt.b32 %0, %1, %2, 0x7632;" : "=r"(b1) : "r"(a0), "r"(a1));
80 
81  // The indices to store with "strides".
82  int const j0 = 0 * (ShapeCount<FragmentShape>::kDhw / 2) + d;
83  int const j1 = 1 * (ShapeCount<FragmentShape>::kDhw / 2) + d;
84 
85  dst_int[j0] = b0;
86  dst_int[j1] = b1;
87  }
88  }
89 };
90 
92 
93 } // namespace gemm
94 } // namespace cutlass
GlobalIterator_ GlobalIterator
The global iterator.
Definition: hgemm_swizzle.h:42
Definition: convert.h:33
std::is_same (false specialization)
Definition: platform.h:420
CUTLASS_DEVICE HgemmSwizzle()
The src/dst must be half fragments.
Definition: hgemm_swizzle.h:60
CUTLASS_DEVICE void transform(Fragment const &src, Fragment &dst)
Transform a fragment.
Definition: hgemm_swizzle.h:63
Fragment InputFragment
The input fragment.
Definition: hgemm_swizzle.h:49
Fragment OutputFragment
The output fragment.
Definition: hgemm_swizzle.h:51
#define static_assert(__e, __m)
Definition: platform.h:153
GlobalIterator::Fragment Fragment
The source fragment.
Definition: hgemm_swizzle.h:44
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
GlobalIterator::FragmentShape FragmentShape
The shape of the source fragment.
Definition: hgemm_swizzle.h:46
Compute derived counts of a Layout Concept based class.
Definition: shape.h:79
Definition: hgemm_swizzle.h:40