/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
  \brief CUTLASS layout visualization example
*/

#pragma once

#include <algorithm>
#include <stdexcept>
#include <vector>

#include "cutlass/coord.h"
#include "cutlass/util/reference/host/tensor_foreach.h"

#include "register_layout.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Permits copying dynamic vectors into static-length vectors
template <typename TensorCoord, int Rank>
struct vector_to_coord {
    vector_to_coord(TensorCoord& coord, std::vector<int> const& vec) {
        coord[Rank - 1] = vec.at(Rank - 1);

        if (Rank > 1) {
            vector_to_coord<TensorCoord, Rank - 1>(coord, vec);
        }
    }
};

/// Permits copying dynamic vectors into static-length vectors
template <typename TensorCoord>
struct vector_to_coord<TensorCoord, 1> {
    vector_to_coord(TensorCoord& coord, std::vector<int> const& vec) {
        coord[0] = vec.at(0);
    }
};

/// Permits copying dynamic vectors into static-length vectors
template <typename TensorCoord>
struct vector_to_coord<TensorCoord, 0> {
    vector_to_coord(TensorCoord& coord, std::vector<int> const& vec) {}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
std::ostream& operator<<(std::ostream& out, std::vector<T> const& vec) {
    auto it = vec.begin();
    if (it != vec.end()) {
        out << *it;
        for (++it; it != vec.end(); ++it) {
            out << ", " << *it;
        }
    }
    return out;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Permits copying static-length vectors into dynamic vectors
template <typename TensorCoord, int Rank>
struct coord_to_vector {
    coord_to_vector(std::vector<int>& vec, TensorCoord const& coord) {
        vec.at(Rank - 1) = coord[Rank - 1];
        coord_to_vector<TensorCoord, Rank - 1>(vec, coord);
    }
};

/// Permits copying static-length vectors into dynamic vectors
template <typename TensorCoord>
struct coord_to_vector<TensorCoord, 1> {
    coord_to_vector(std::vector<int>& vec, TensorCoord const& coord) {
        vec.at(0) = coord[0];
    }
};

/// Permits copying static-length vectors into dynamic vectors
template <typename TensorCoord>
struct coord_to_vector<TensorCoord, 0> {
    coord_to_vector(std::vector<int>& vec, TensorCoord const& coord) {}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure representing an element in source memory
struct Element {
    std::vector<int> coord;  ///< logical coordinate of element (as vector)
    int offset;              ///< linear offset from source memory
    int color;               ///< enables coloring each element to indicate

    /// Default ctor
    inline Element() : offset(-1), color(0) {}

    /// Construct from logical coordinate and initial offset
    inline Element(std::vector<int> const& coord_, int offset_, int color_ = 0)
            : coord(coord_), offset(offset_), color(color_) {}

    /// Returns true if element is in a defined state
    inline bool valid() const { return offset >= 0; }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Visualizes memory layouts by constructing a 'shape'
template <typename Layout_>
class VisualizeLayout : public VisualizeLayoutBase {
public:
    using Layout = Layout_;
    using TensorCoord = typename Layout::TensorCoord;
    using Stride = typename Layout::Stride;

public:
    Options options;
    Layout layout;
    TensorCoord extent;
    std::vector<Element> elements;

public:
    /// Initializes the problem space
    VisualizeLayout() {}

    /// visualization method
    bool visualize(Options const& options_) {
        options = options_;

        if (options.extent.size() != TensorCoord::kRank) {
            std::cerr << "--extent must have rank " << TensorCoord::kRank
                      << " (given: " << options.extent.size() << ")"
                      << std::endl;

            return false;
        }

        vector_to_coord<TensorCoord, TensorCoord::kRank>(extent,
                                                         options.extent);

        // Construct the layout for a packed tensor
        if (options.stride.empty()) {
            layout = Layout::packed(extent);
        } else if (options.stride.size() != Stride::kRank) {
            std::cerr << "--stride must have rank " << Stride::kRank
                      << " (given: " << options.stride.size() << ")"
                      << std::endl;

            return false;
        } else {
            // Stride from
            Stride stride;
            vector_to_coord<Stride, Stride::kRank>(stride, options.stride);

            layout = Layout(stride);
        }

        // Resize elements, setting elements to 'undefined' state
        elements.resize(layout.capacity(extent));

        // enumerate points in tensor space and assign
        cutlass::reference::host::TensorForEachLambda(
                extent, [&](TensorCoord coord) {
                    std::vector<int> coord_vec(TensorCoord::kRank, 0);
                    coord_to_vector<TensorCoord, TensorCoord::kRank>(coord_vec,
                                                                     coord);

                    int offset = int(layout(coord));

                    if (offset >= int(elements.size())) {
                        std::cerr << "Layout error - " << coord_vec
                                  << " is out of range (computed offset: "
                                  << offset << ", capacity: " << elements.size()
                                  << std::endl;

                        throw std::out_of_range(
                                "(TensorForEach) layout error - coordinate out "
                                "of range");
                    }

                    elements.at(offset) = Element(coord_vec, offset);
                });

        return true;
    }

    /// Verifies the layout satisfies vectorization requirements
    bool verify(bool verbose, std::ostream& out) { return true; }

private:
    /// returns a pair (is_vectorizable, one_changing_rank) to determine if a
    /// vector exists (consecutive logical coordinates or uniformly invalid)
    /// at the given location.
    std::pair<bool, int> _is_vectorizable(int i) const {
        // (all elements are invalid) or
        // (all elements are valid AND
        //  exactly one rank is changing AND
        //  elements are consecutive)

        // Don't need vectorization.
        if (options.vectorize <= 2)
            return std::make_pair(false, -1);

        // Boundary check.
        if (i > elements.size() ||
            (i + options.vectorize - 1) > elements.size())
            return std::make_pair(false, -1);

        // Check if either all elements are valid or invalid.
        bool all_elements_invalid = std::all_of(
                elements.begin() + i, elements.begin() + i + options.vectorize,
                [](Element const& e) { return !e.valid(); });

        bool all_elements_valid = std::all_of(
                elements.begin() + i, elements.begin() + i + options.vectorize,
                [](Element const& e) { return e.valid(); });

        if (!all_elements_invalid && !all_elements_valid)
            return std::make_pair(false, -1);

        // From here, it is vectorizable.
        if (all_elements_invalid)
            return std::make_pair(true, -1);

        // Check if only exactly one rank is changing.
        int one_changing_rank = -1;
        for (int j = 0; j < options.vectorize; ++j) {
            for (int r = 0; r < TensorCoord::kRank; ++r) {
                if (elements.at(i + j).coord.at(r) !=
                    elements.at(i).coord.at(r)) {
                    if (one_changing_rank == -1) {
                        one_changing_rank = r;
                    } else if (one_changing_rank != r) {
                        return std::make_pair(false, -1);
                    }
                }
            }
        }

        return std::make_pair(true, one_changing_rank);
    }

    /// Prints a vector of elements
    void _print_vector(std::ostream& out, int i, int one_changing_rank) {
        Element const& base_element = elements.at(i);
        if (base_element.valid()) {
            out << "(";
            for (int r = 0; r < TensorCoord::kRank; ++r) {
                if (r) {
                    out << ", ";
                }

                if (r == one_changing_rank) {
                    out << base_element.coord.at(r) << ".."
                        << (base_element.coord.at(r) + options.vectorize - 1);
                } else {
                    out << base_element.coord.at(r);
                }
            }
            out << ")";
        } else {
            out << " ";
        }
    }

    /// Prints a single element
    void _print_element(std::ostream& out, int k) {
        Element const& element = elements.at(k);
        if (element.valid()) {
            out << "(";
            for (int v = 0; v < TensorCoord::kRank; ++v) {
                out << (v ? ", " : "") << element.coord.at(v);
            }
            out << ")";
        } else {
            out << " ";
        }
    }

public:
    /// Pretty-prints the layout to the console
    void print_csv(std::ostream& out, char delim = '|', char new_line = '\n') {
        int row = -1;

        for (int i = 0; i < int(elements.size()); i += options.vectorize) {
            if (i % options.output_shape.at(0)) {
                out << delim;
            } else {
                if (row >= 0) {
                    out << new_line;
                }
                ++row;
                if (row == options.output_shape.at(1)) {
                    out << new_line;
                    row = 0;
                }
            }

            auto is_vector = _is_vectorizable(i);

            if (is_vector.first) {
                _print_vector(out, i,
                              is_vector.second);  // print a vector starting at
                                                  // element i
            } else {
                for (int j = 0; j < options.vectorize;
                     ++j) {  // print individual elements [i..i+j)
                    _print_element(out, i + j);
                }
            }
        }

        out << new_line << std::flush;
    }

    /// Help message
    virtual std::ostream& print_help(std::ostream& out) {
        out << "TensorCoord rank " << TensorCoord::kRank
            << ", Stride rank: " << Stride::kRank;
        return out;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////
