Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
igemm_swizzle.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/fragment.h"
31 
32 namespace cutlass {
33 namespace gemm {
34 
36 
37 template <typename GlobalIterator_>
38 struct IgemmSwizzle {
40  typedef GlobalIterator_ GlobalIterator;
42  typedef typename GlobalIterator::Fragment Fragment;
44  typedef typename GlobalIterator::FragmentShape FragmentShape;
45 
50 
53 
55  static_assert(FragmentShape::kH % 4 == 0 && ShapeCount<FragmentShape>::kWc % 4 == 0,
56  "Not multiple of 4");
57 
59  CUTLASS_DEVICE IgemmSwizzle() {}
60 
62  CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {
63  // Expose src/dst as int arrays.
64  int const* src_int = reinterpret_cast<int const*>(&src[0]);
65  int* dst_int = reinterpret_cast<int*>(&dst[0]);
66 
67  // Transpose the data.
68  for (int d = 0; d < FragmentShape::kD; ++d) {
69  for (int h = 0; h < FragmentShape::kH / 4; ++h) {
70  for (int w = 0; w < ShapeCount<FragmentShape>::kWc / 4; ++w) {
71  int const i0 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
72  (4 * h + 0) * (ShapeCount<FragmentShape>::kWc / 4) + w;
73  int const i1 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
74  (4 * h + 1) * (ShapeCount<FragmentShape>::kWc / 4) + w;
75  int const i2 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
76  (4 * h + 2) * (ShapeCount<FragmentShape>::kWc / 4) + w;
77  int const i3 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
78  (4 * h + 3) * (ShapeCount<FragmentShape>::kWc / 4) + w;
79 
80  int a0 = src_int[i0];
81  int a1 = src_int[i1];
82  int a2 = src_int[i2];
83  int a3 = src_int[i3];
84 
85  // // DEBUG.
86  // if (threadIdx.x == 0) {
87  // printf("a=0x%08x 0x%08x 0x%08x 0x%08x\n", a0, a1, a2, a3);
88  // }
89 
90  int b0, b1, b2, b3, c0;
91  asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(b0) : "r"(a0), "r"(a1));
92  asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(c0) : "r"(a2), "r"(a3));
93  asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(b0), "r"(c0));
94 
95  asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(b1) : "r"(a0), "r"(a1));
96  asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(c0) : "r"(a2), "r"(a3));
97  asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b1) : "r"(b1), "r"(c0));
98 
99  asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(b2) : "r"(a0), "r"(a1));
100  asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(c0) : "r"(a2), "r"(a3));
101  asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b2) : "r"(b2), "r"(c0));
102 
103  asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(b3) : "r"(a0), "r"(a1));
104  asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(c0) : "r"(a2), "r"(a3));
105  asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b3) : "r"(b3), "r"(c0));
106 
107  // // DEBUG.
108  // if (threadIdx.x == 0) {
109  // printf("b=0x%08x 0x%08x 0x%08x 0x%08x\n", b0, b1, b2, b3);
110  // }
111 
112  dst_int[i0] = b0;
113  dst_int[i1] = b1;
114  dst_int[i2] = b2;
115  dst_int[i3] = b3;
116  }
117  }
118  }
119  }
120 };
121 
123 
124 } // namespace gemm
125 } // namespace cutlass
Definition: convert.h:33
std::is_same (false specialization)
Definition: platform.h:420
GlobalIterator::FragmentShape FragmentShape
The shape of the source fragment.
Definition: igemm_swizzle.h:44
Definition: igemm_swizzle.h:38
GlobalIterator_ GlobalIterator
The global iterator.
Definition: igemm_swizzle.h:40
CUTLASS_DEVICE void transform(Fragment const &src, Fragment &dst)
Transform a fragment.
Definition: igemm_swizzle.h:62
Fragment OutputFragment
The destination fragment.
Definition: igemm_swizzle.h:49
#define static_assert(__e, __m)
Definition: platform.h:153
Fragment InputFragment
The source fragment.
Definition: igemm_swizzle.h:47
GlobalIterator::Fragment Fragment
The source fragment.
Definition: igemm_swizzle.h:42
CUTLASS_DEVICE IgemmSwizzle()
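Constructor (empty). Note: Doxygen attached the comment "The src/dst must be int8 fragments." here; it states a requirement on the swizzle's fragments, not on the constructor.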
Definition: igemm_swizzle.h:59
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
Compute derived counts of a Layout Concept-based class.
Definition: shape.h:79