#pragma once
#include <cstdint>
#include <cmath>
#include <memory>
#include <mutex>

#include <util/macros.h>
#include <containers/dense_matrix.h>
#include <inputs/dn_pefs/dn_lrm_pefs.h>

#include <gpu/contexts/device_context.h>
#include <gpu/containers/dense_matrix.h>
#include <gpu/containers/transfers.h>

#include <gpu/ops/dndn_matmul.h>
#include <gpu/ops/frobenius_product.h>
#include <gpu/ops/custom/ag_to_numerator.h>
#include <gpu/ops/custom/compute_w_ag.h>
#include <gpu/ops/custom/elwise_square.h>
#include <gpu/ops/custom/gradient_descent.h>
#include <gpu/ops/custom/hadamard_product.h>
#include <gpu/ops/custom/multiplicative_update.h>

#include "./config.h"


namespace npeff {
namespace factorizations {
namespace dn_lrm_factorization2 {


// Utility/helper stuff.
namespace internal {


// Plain holder that gathers every raw device-memory pointer used by a worker
// into a single place. Every pointer starts out null.
//
// NOTE(review): this struct never frees anything itself; confirm where the
// allocations stored here are released.
struct DeviceAllocPtrs {
    // ---- Scalars ----
    // Presumably a device-resident -1 constant (e.g. for BLAS-style scalar
    // arguments) — confirm against the code that fills it in.
    float* minus_1{};

    // Presumably device scalars holding trace terms used by the loss — confirm.
    float* d_tr_WW_HH{};
    float* d_tr_WHX{};

    // ---- Input ----
    float* d_A{};

    // ---- Parameters being learned ----
    float* d_W{};
    float* d_G{};

    // ---- Scratch space for intermediates ----
    // The suffix of each name encodes its size (presumably in float elements):
    //      n = number of examples within this partition
    //      r = rank of decomposition
    //      c = number of classes
    //      m = number of parameters

    // n * c * r each.
    float* d_ncr1{};
    float* d_ncr2{};

    // n * r.
    float* d_nr{};

    // r * r each.
    float* d_rr1{};
    float* d_rr2{};
    float* d_rr3{};
    float* d_rr4{};

    // r * m.
    float* d_rm{};
};

// Device matrix objects used by a single worker.
//
// A matrix that is modified in place without changing its shape may not get
// a separate entry here. The device memory backing different matrices can
// overlap, so some combinations of these matrices cannot be used at the same
// time.
struct DeviceMatrices {
    using DenseMatrixPtr = std::unique_ptr<gpu::DenseMatrix>;

    // ---- Input ----
    DenseMatrixPtr A{};

    // ---- Parameters being learned ----
    DenseMatrixPtr W{};
    DenseMatrixPtr G{};

    // ---- Intermediates shared by several steps ----
    DenseMatrixPtr WW{};

    DenseMatrixPtr GG{};
    DenseMatrixPtr HH{};

    DenseMatrixPtr AG{};

    // ---- Intermediates used only by the W-update step ----
    DenseMatrixPtr W_update_numerator{};
    DenseMatrixPtr W_update_denominator{};

    // ---- Intermediates used only by the G-update step ----
    DenseMatrixPtr WW_GG{};
    DenseMatrixPtr W_AG{};
    DenseMatrixPtr G_gradient{};
};


}  // internal


// Worker associated with a single GPU.
//
// Each worker owns one partition of the input examples (host_partition), the
// device allocations and matrix objects for that partition, and a
// DeviceContext built from `partition_index`, an NCCL communicator and the
// config's RNG seed. Data is exchanged between workers via the nccl_* methods.
//
// NOTE(review): d_ptrs holds raw device pointers and no destructor is declared
// here, so nothing in this header releases that memory — confirm where the
// allocations made by set_up_work() are freed.
class DeviceWorker {
    using DnLrmPefsPartitionPtr = std::unique_ptr<inputs::DnLrmPefsPartition>;

    // NOTE: C++ initializes members in THIS declaration order, regardless of
    // the order written in the constructor's initializer list.
    const FactorizationConfig config;
    const int64_t partition_index;
    const int64_t n_partitions;
    // const int64_t examples_offset;

    gpu::DeviceContext dctx;

    // Holds pointers to the device memory allocations. Simply
    // put here for convenience.
    internal::DeviceAllocPtrs d_ptrs;

    // Holds the device matrix objects.
    internal::DeviceMatrices d_ms;

    // Host-side input data for the partition handled by this worker.
    DnLrmPefsPartitionPtr host_partition;

public:

    // Takes ownership of `host_partition`. Device buffers are allocated
    // later by set_up_work(), which must be called before anything else.
    DeviceWorker(
        DnLrmPefsPartitionPtr host_partition,
        FactorizationConfig config,
        int64_t partition_index,
        int64_t n_partitions,
        // int64_t examples_offset,
        ncclComm_t comm
    ) :
        // Initializers are listed in member-declaration order (the order in
        // which they actually run) to avoid -Wreorder. Within these
        // expressions the names refer to the constructor parameters, not to
        // the members they shadow.
        config(config),
        partition_index(partition_index),
        n_partitions(n_partitions),
        // examples_offset(examples_offset),
        dctx(partition_index, comm, config.rand_gen_seed),
        host_partition(std::move(host_partition))
    {}

    /////////////////////////////////////////////////////////////////
    // Functions implemented in this header file.

    int64_t get_device() const { return dctx.device; }

    // Pointer to device memory holding G's data. d_ms.G must already exist
    // (presumably created during set_up_work()).
    float* get_G_data_ptr() const { return (float*) d_ms.G->data; }

    void synchronize_stream() { dctx.synchronize_stream(); }


    /////////////////////////////////////////////////////////////////
    // Initialization stuff.

    // Must be called once before doing anything else. This allocates memory
    // on the device and moves data to the GPU. Parameters will be randomly
    // initialized.
    // 
    // NOTE: The G matrices (and any other parameters shared across multiple
    // devices) should/must be made consistent AFTER this is called.
    void set_up_work();

    // Presumably copies the host matrix `initial_G` into this worker's
    // device G — confirm against the implementation.
    void set_G_from_host(npeff::DenseMatrix<float>* initial_G);

    // Presumably broadcasts G from `src_worker`'s device to this worker over
    // NCCL — confirm against the implementation.
    void nccl_broadcast_of_G(DeviceWorker& src_worker);

    /////////////////////////////////////////////////////////////////
    // General stuff.

    // Computes these matrix products for the examples in this partition.
    void compute_partition_AG_async();
    void compute_partition_GG_async();

    void compute_local_WW_async();
    void nccl_all_reduce_WW();

    void nccl_all_reduce_G_gradient();

    /////////////////////////////////////////////////////////////////
    // W-update stuff.

    // Options for the W-update step, mostly controlling what to recompute.
    struct WUpdateOptions {
        bool recompute_partition_AG = true;
        bool recompute_partition_GG = true;
    };

    // Updates the portion of the coefficients matrix W located on this partition.
    void update_partition_W_async(WUpdateOptions opts);

    /////////////////////////////////////////////////////////////////
    // G-update stuff.

    // Options for the G-update step, mostly controlling what to recompute.
    struct GUpdateOptions {
        // Step size for the update of G. Callers are expected to set this;
        // it is zero-initialized only so that a forgotten assignment is
        // deterministic instead of reading an indeterminate value (UB).
        float learning_rate = 0.0f;
        bool recompute_WW = true;
        bool recompute_partition_AG = true;
        bool recompute_partition_GG = true;
    };

    // TODO: Add documentation to the methods below.
    void update_G_step_pre_all_reduces_async(GUpdateOptions opts);
    void update_G_step_all_reduces_async(GUpdateOptions opts);
    void update_G_step_post_all_reduces_async(GUpdateOptions opts);

    /////////////////////////////////////////////////////////////////
    // Loss computation stuff.

    // Options for the loss computation, mostly controlling what to recompute.
    struct LossComputationOptions {
        bool recompute_partition_AG = true;
        bool recompute_partition_GG = true;
    };

    // NOTE: This will invalidate the data stored at the WW matrix.
    void compute_partition_loss_async(LossComputationOptions opts);

    // Reports the loss restricted to this partition, ignoring the tr_XX term,
    // through `loss` (this function returns nothing). Presumably the value is
    // written or accumulated into `*loss` while holding `mtx`, so several
    // workers can share one output — confirm against the implementation.
    void read_partition_loss_from_device(std::mutex* mtx, float* loss);

    /////////////////////////////////////////////////////////////////
    // Reading data back to host stuff.

    // The host_write_location must be on the host.
    void read_W_from_gpu_async(float* host_write_location);

    // The host_write_location must be on the host.
    void read_G_from_gpu_async(float* host_write_location);

private:

    void write_scalars_to_gpu();
    void allocate_and_create_device_matrices();
    void initialize_device_matrices();

    // The host_write_location must be on the host.
    void read_matrix_from_gpu_async(gpu::DenseMatrix& matrix, float* host_write_location);
};


}  // dn_lrm_factorization2
}  // factorizations
}  // npeff
