#pragma once

#include <cstdint>
#include <cmath>
#include <memory>

#include <util/macros.h>
#include <containers/dense_matrix.h>

#include <gpu/contexts/device_context.h>
#include <gpu/containers/dense_matrix.h>
#include <gpu/containers/transfers.h>

#include <gpu/ops/dndn_matmul.h>
#include <gpu/ops/frobenius_product.h>
#include <gpu/ops/custom/ag_to_numerator.h>
#include <gpu/ops/custom/compute_w_ag.h>
#include <gpu/ops/custom/elwise_square.h>
#include <gpu/ops/custom/gradient_descent.h>
#include <gpu/ops/custom/hadamard_product.h>
#include <gpu/ops/custom/multiplicative_update.h>

#include "./config.h"


namespace npeff {
namespace factorizations {
namespace dn_lrm_factorization {


// Utility/helper stuff.
namespace internal {


// Plain struct that gathers all of the device memory allocation
// pointers in one place.
struct DeviceAllocPtrs {
    // Every pointer starts out null.

    // ---- Scalars ----
    float* minus_1{nullptr};

    float* d_tr_WW_HH{nullptr};
    float* d_tr_WHX{nullptr};

    // ---- Input matrix ----
    float* d_A{nullptr};

    // ---- Parameters to learn ----
    float* d_W{nullptr};
    float* d_G{nullptr};

    // ---- Buffers for intermediates ----
    //
    // The size of each buffer is encoded in its name using:
    //      n = number of examples
    //      r = rank of decomposition
    //      c = number of classes
    //      m = number of parameters within this partition

    // Size = n * c * r (each).
    float* d_ncr1{nullptr};
    float* d_ncr2{nullptr};

    // Size = n * r.
    float* d_nr{nullptr};

    // Size = r * r (each).
    float* d_rr1{nullptr};
    float* d_rr2{nullptr};
    float* d_rr3{nullptr};

    // Size = r * m.
    float* d_rm{nullptr};
};


// If a matrix is modified in-place without any change to its
// shape, there might not be a separate matrix entry for it here.
//
// Note that the device memory chunks associated with each
// matrix can overlap. Because of this, some combinations of
// these matrices cannot be used at the same time.
struct DeviceMatrices {
    using DenseMatrixPtr = std::unique_ptr<gpu::DenseMatrix>;

    // Every handle starts out empty (null).

    DenseMatrixPtr A{};

    DenseMatrixPtr W{};
    DenseMatrixPtr G{};

    // ---- Simple/common intermediates ----
    DenseMatrixPtr WW{};

    DenseMatrixPtr GG{};
    DenseMatrixPtr HH{};

    DenseMatrixPtr AG{};

    // ---- Intermediates specific to the W-step ----
    DenseMatrixPtr W_update_numerator{};
    DenseMatrixPtr W_update_denominator{};

    // ---- Intermediates specific to the G-step ----
    DenseMatrixPtr WW_GG{};
    DenseMatrixPtr W_AG{};
    DenseMatrixPtr G_gradient{};
};


}  // internal


// Worker associated with a single GPU.
class DeviceWorker {
    using DnDnMatMul = gpu::ops::DnDnMatMul;
    using FrobeniousInnerProduct = gpu::ops::FrobeniousInnerProduct;

    // NOTE: Non-static data members are initialized in the order they
    // are declared below, regardless of how the constructor's
    // initializer list is written. Keep the two in the same order.

    const FactorizationConfig config;
    const int64_t partition_index;
    const int64_t n_partitions;

    gpu::DeviceContext dctx;

    // Holds pointers to the device memory allocations. Simply
    // put here for convenience.
    internal::DeviceAllocPtrs d_ptrs;

    // Holds the device matrix objects.
    internal::DeviceMatrices d_ms;

    // NOTE: This will become a null pointer and the
    // associated matrix deleted once its data has
    // been moved to the GPU.
    std::unique_ptr<npeff::DenseMatrix<float>> host_matrix_partition;

public:

    // Takes ownership of the host-side partition of the input matrix and
    // builds the device context from (partition_index, comm,
    // config.rand_gen_seed).
    //
    // No device matrices are created here; see set_up_work().
    //
    // Parameters:
    //   host_matrix_partition: shape = [n_examples * n_classes, n_columns_in_partition].
    //   config: factorization configuration; a copy is stored.
    //   partition_index: index of this worker's partition. It is also the
    //       first argument handed to gpu::DeviceContext (presumably the
    //       device ordinal -- TODO confirm against DeviceContext).
    //   n_partitions: total number of partitions.
    //   comm: NCCL communicator forwarded to the device context.
    DeviceWorker(
        std::unique_ptr<npeff::DenseMatrix<float>> host_matrix_partition,
        FactorizationConfig config,
        int64_t partition_index,
        int64_t n_partitions,
        ncclComm_t comm
    ) :
        // Listed in declaration order, which is the order in which the
        // members are actually initialized. Inside the initializer
        // expressions the bare names refer to the constructor parameters
        // (they shadow the members of the same name).
        config(config),
        partition_index(partition_index),
        n_partitions(n_partitions),
        // Constructed in place rather than from a temporary DeviceContext,
        // so no copy/move of the context is ever required.
        dctx(partition_index, comm, config.rand_gen_seed),
        host_matrix_partition(std::move(host_matrix_partition))
    {}


    /////////////////////////////////////////////////////////////////
    // Functions implemented in this header file.

    // Returns the `device` field of this worker's device context.
    int64_t get_device() const { return dctx.device; }

    // Pointer to device memory.
    //
    // Dereferences d_ms.W, so the W device matrix must already exist
    // (presumably after set_up_work() -- TODO confirm); calling this
    // while d_ms.W is null is undefined behavior.
    float* get_W_data_ptr() const { return (float*) d_ms.W->data; }

    // Forwards to the device context's synchronize_stream().
    void synchronize_stream() { dctx.synchronize_stream(); }


    /////////////////////////////////////////////////////////////////
    // Initialization stuff.
    //
    // Everything from here on is only declared in this header; the
    // definitions live elsewhere.

    // Must be called once before doing anything else. This allocates memory
    // on the device and moves data to the GPU. Parameters will be randomly
    // initialized.
    // 
    // NOTE: The W matrices (and any other parameters shared across multiple
    // devices) should/must be made consistent AFTER this is called.
    void set_up_work();

    void nccl_broadcast_of_W(DeviceWorker& src_worker);


    /////////////////////////////////////////////////////////////////
    // General stuff.

    void compute_local_AG_GG_async();

    void nccl_all_reduce_AG_GG();


    /////////////////////////////////////////////////////////////////
    // W-update stuff.

    void update_local_W_after_all_reduces_async();


    /////////////////////////////////////////////////////////////////
    // G-update stuff.

    void update_local_G_after_all_reduces_async(float learning_rate_G);


    /////////////////////////////////////////////////////////////////
    // Loss computation stuff.

    void compute_loss_after_all_reduces_async();

    // NOTE: The actual loss will be a fixed constant plus what this
    // function returns.
    float read_loss_term_from_device();


    /////////////////////////////////////////////////////////////////
    // Reading data back to host stuff.

    // The host_write_location must be on the host.
    void read_W_from_gpu_async(float* host_write_location);

    // The host_write_location must be on the host.
    void read_G_from_gpu_async(float* host_write_location);

private:

    void write_scalars_to_gpu();

    void allocate_and_create_device_matrices();

    void initialize_device_matrices();

    // The host_write_location must be on the host.
    void read_matrix_from_gpu_async(gpu::DenseMatrix& matrix, float* host_write_location);
};


}  // dn_lrm_factorization
}  // factorizations
}  // npeff
