/**
File:		Core/InputOutput/NeMatlabIO.cpp

Author:		
Email:		
Site:       

Copyright (c) 2016 . All rights reserved.
*/

#include <NeCorePCH.h>
#include <Core/NeMatlabIO.h>
#include <cstring>

namespace NeuralEngine
{
	/*using namespace cv;
	using namespace std;*/

	/**
	 * Opens a file in binary mode for reading ("r") or for writing ("w").
	 *
	 * The file name is stored in filename_ whether or not the open succeeds.
	 *
	 * @param filename  path of the file to open
	 * @param mode      "r" to open for input, "w" to open for output
	 * @return true when the stream is not in a failed state after opening;
	 *         false when opening failed or when mode is neither "r" nor "w"
	 */
	bool MatlabIO::Open(std::string filename, std::string mode)
	{
		filename_ = filename;

		if (mode == "r") {
			fid_.open(filename.c_str(), std::fstream::in | std::fstream::binary);
		}
		else if (mode == "w") {
			fid_.open(filename.c_str(), std::fstream::out | std::fstream::binary);
		}
		else {
			// An unrecognised mode opens nothing; reporting success here would let
			// the caller go on to read from / write to a stream with no file behind it.
			return false;
		}
		return !fid_.fail();
	}

	/**
	 * Closes the underlying file stream.
	 *
	 * @return true when the stream is not in a failed state after closing
	 */
	bool MatlabIO::Close(void)
	{
		fid_.close();
		const bool stream_failed = fid_.fail();
		return !stream_failed;
	}

	/**
	 * Transposes a matrix, handling multi-channel input channel by channel.
	 *
	 * @param src  matrix to transpose
	 * @param dst  receives the transposed matrix
	 */
	void MatlabIO::TransposeMat(const cv::Mat& src, cv::Mat& dst)
	{
		// single-channel input can be transposed directly
		if (src.channels() <= 1) {
			transpose(src, dst);
			return;
		}

		// otherwise split into planes, transpose each plane in place, and merge back
		std::vector<cv::Mat> planes;
		split(src, planes);
		for (std::vector<cv::Mat>::iterator plane = planes.begin(); plane != planes.end(); ++plane) {
			transpose(*plane, *plane);
		}
		merge(planes, dst);
	}

	/**
	 * Reads the file header from the current stream position.
	 *
	 * Fills header_, subsys_, version_ and endian_, derives byte_swap_ from the
	 * endian indicator, forwards that setting to the stream, and sets
	 * bytes_read_ to 128.
	 */
	void MatlabIO::GetHeader(void)
	{
		// clear the text buffers, including the extra slot kept for the terminator
		for (unsigned int i = 0; i <= HEADER_LENGTH; ++i) header_[i] = '\0';
		for (unsigned int i = 0; i <= SUBSYS_LENGTH; ++i) subsys_[i] = '\0';
		for (unsigned int i = 0; i <= ENDIAN_LENGTH; ++i) endian_[i] = '\0';

		// pull the four header fields off the stream in file order
		fid_.read(header_, sizeof(char) * HEADER_LENGTH);
		fid_.read(subsys_, sizeof(char) * SUBSYS_LENGTH);
		fid_.read(reinterpret_cast<char *>(&version_), sizeof(int16_t));
		fid_.read(endian_, sizeof(char) * ENDIAN_LENGTH);

		// map the raw version word onto the version constants
		if (version_ == 0x0100) { version_ = VERSION_5; }
		if (version_ == 0x0200) { version_ = VERSION_73; }

		// "IM" means no swapping is needed, "MI" means the bytes must be swapped;
		// any other indicator leaves byte_swap_ at its previous value
		const bool no_swap_marker = (strcmp(endian_, "IM") == 0);
		const bool swap_marker = (strcmp(endian_, "MI") == 0);
		if (no_swap_marker) byte_swap_ = false;
		if (swap_marker) byte_swap_ = true;

		// tell the stream whether it has to swap bytes
		fid_.SetByteSwap(byte_swap_);

		bytes_read_ = 128;
	}

	/**
	 * Decodes the tag that precedes a data element.
	 *
	 * Two layouts are recognised:
	 *  - small format: the upper 16 bits of the first 32-bit word are non-zero and
	 *    hold the byte count, the lower 16 bits hold the type, and the payload
	 *    starts 4 bytes into the tag;
	 *  - regular format: the first word is the type, the second word is the byte
	 *    count, and the payload starts after the 8-byte tag.
	 *
	 * The words are interpreted in host byte order; no swapping is done here.
	 *
	 * @param data_type  receives the element type
	 * @param dbytes     receives the number of payload bytes
	 * @param wbytes     receives the whole number of bytes the element occupies,
	 *                   including the tag and any padding to an 8-byte boundary
	 *                   (compressed elements are not padded)
	 * @param data       pointer to the first byte of the tag
	 * @return pointer to the first payload byte
	 */
	const char* MatlabIO::ReadVariableTag(uint32_t &data_type, uint32_t &dbytes, uint32_t &wbytes, const char *data)
	{
		// Copy the word out byte-wise: `data` is a plain char buffer (sometimes a
		// stack array), so reading it through a uint32_t pointer would break the
		// aliasing rules and could be a misaligned access.
		uint32_t first_word = 0;
		std::memcpy(&first_word, data, sizeof(first_word));

		const uint32_t upper_half = first_word >> 16;
		if (upper_half != 0) {
			// small data format: type and size share the first word
			data_type = first_word & 0xFFFFu;
			dbytes = upper_half;
			wbytes = 8;
			return data + 4;
		}

		// regular format: the second word carries the payload size
		data_type = first_word;
		uint32_t second_word = 0;
		std::memcpy(&second_word, data + 4, sizeof(second_word));
		dbytes = second_word;

		if (data_type == MAT_COMPRESSED) {
			wbytes = 8 + dbytes;
		}
		else {
			// round the payload up to the next multiple of 8
			const uint32_t padding = (8 - (dbytes % 8)) % 8;
			wbytes = 8 + dbytes + padding;
		}
		return data + 8;
	}

	/**
	 * Builds a struct-array container from the payload of a struct-class matrix.
	 *
	 * The payload is read as: a field-name-length element, an element holding all
	 * field names packed at that fixed length, then one matrix element per field
	 * for each array entry.
	 *
	 * @param name  null-terminated variable name
	 * @param dims  array dimensions; their product is the number of struct entries
	 * @param real  payload bytes following the name element
	 * @return a container named `name` holding one vector of named fields per entry
	 */
	MatlabIOContainer MatlabIO::ConstructStruct(std::vector<char>& name, std::vector<uint32_t>& dims, std::vector<char>& real)
	{

		std::vector<std::vector<MatlabIOContainer> > array;
		const char* real_ptr = &(real[0]);

		// get the length reserved for each field name
		uint32_t length_type;
		uint32_t length_dbytes;
		uint32_t length_wbytes;
		const char* length_ptr = ReadVariableTag(length_type, length_dbytes, length_wbytes, real_ptr);
		// byte-wise copy avoids reading the char buffer through a uint32_t pointer
		uint32_t length = 0;
		std::memcpy(&length, length_ptr, sizeof(length));

		// get the total number of fields
		uint32_t nfields_type;
		uint32_t nfields_dbytes;
		uint32_t nfields_wbytes;
		const char* nfields_ptr = ReadVariableTag(nfields_type, nfields_dbytes, nfields_wbytes, real_ptr + length_wbytes);
		// a zero name length would divide by zero; treat it as "no fields" instead
		assert(length == 0 || (nfields_dbytes % length) == 0);
		const uint32_t nfields = (length != 0) ? (nfields_dbytes / length) : 0;

		// populate a vector of field names (each name is read up to its terminator)
		std::vector<std::string> field_names;
		for (uint32_t n = 0; n < nfields; ++n) {
			field_names.push_back(std::string(nfields_ptr + (n*length)));
		}

		// iterate through each of the entries and construct the matrices;
		// the entry count does not change inside the loop, so compute it once
		const uint32_t nentries = Product<uint32_t>(dims);
		const char* field_ptr = real_ptr + length_wbytes + nfields_wbytes;
		for (uint32_t m = 0; m < nentries; ++m) {
			std::vector<MatlabIOContainer> strct;
			for (uint32_t n = 0; n < nfields; ++n) {

				MatlabIOContainer field;
				uint32_t data_type;
				uint32_t dbytes;
				uint32_t wbytes;
				const char* data_ptr = ReadVariableTag(data_type, dbytes, wbytes, field_ptr);
				assert(data_type == MAT_MATRIX);
				field = CollateMatrixFields(data_type, dbytes, std::vector<char>(data_ptr, data_ptr + dbytes));
				field.SetName(field_names[n]);
				strct.push_back(field);
				// step over the whole element, tag and padding included
				field_ptr += wbytes;
			}
			array.push_back(strct);
		}
		return MatlabIOContainer(std::string(&(name[0])), array);
	}

	/**
	 * Builds a cell-array container from the payload of a cell-class matrix.
	 *
	 * The payload is read as a sequence of matrix elements, one per cell.
	 *
	 * @param name  null-terminated variable name
	 * @param dims  array dimensions; their product is the number of cells
	 * @param real  payload bytes following the name element
	 * @return a container named `name` holding the decoded cells in file order
	 */
	MatlabIOContainer MatlabIO::ConstructCell(std::vector<char>& name, std::vector<uint32_t>& dims, std::vector<char>& real)
	{

		std::vector<MatlabIOContainer> cell;

		// the cell count does not change inside the loop, so compute it once
		const uint32_t ncells = Product<uint32_t>(dims);

		// data() instead of &real[0]: the payload may be empty (a cell array with
		// no cells), and subscripting an empty vector is out of range
		const char* field_ptr = real.data();
		for (uint32_t n = 0; n < ncells; ++n) {
			MatlabIOContainer field;
			uint32_t data_type;
			uint32_t dbytes;
			uint32_t wbytes;
			const char* data_ptr = ReadVariableTag(data_type, dbytes, wbytes, field_ptr);
			assert(data_type == MAT_MATRIX);
			field = CollateMatrixFields(data_type, dbytes, std::vector<char>(data_ptr, data_ptr + dbytes));
			cell.push_back(field);
			// step over the whole element, tag and padding included
			field_ptr += wbytes;
		}
		return MatlabIOContainer(std::string(&(name[0])), cell);
	}

	/**
	 * Placeholder for sparse-matrix decoding.
	 *
	 * None of the arguments are inspected; a default-constructed container is
	 * returned.
	 */
	MatlabIOContainer MatlabIO::ConstructSparse(std::vector<char>&, std::vector<uint32_t>&, std::vector<char>&, std::vector<char>&)
	{
		// sparse matrices are not decoded here
		return MatlabIOContainer();
	}

	
	/**
	 * Builds a string container from character data.
	 *
	 * A terminator is appended to `real` (the caller's vector is modified), and
	 * the text is taken up to the first null byte.
	 *
	 * @param name  null-terminated variable name
	 * @param real  character bytes of the string
	 * @return a container named `name` holding the text as a std::string
	 */
	MatlabIOContainer MatlabIO::ConstructString(std::vector<char>& name, std::vector<uint32_t>&, std::vector<char>& real)
	{
		// guarantee termination before treating the bytes as a C string
		real.push_back('\0');

		std::string variable_name(&name[0]);
		std::string text(&real[0]);
		return MatlabIOContainer(variable_name, text);
	}

	/**
	 * Decodes the payload of a matrix element into a container.
	 *
	 * The payload is read as: a 16-byte preamble (byte 8 is the class of the
	 * encapsulated data, bit 3 of byte 9 marks complex data), a dimensions
	 * element, a name element, and then class-specific content. Cell and struct
	 * classes are handed to ConstructCell / ConstructStruct; numeric, char and
	 * sparse classes read a real part and, when complex, an imaginary part.
	 *
	 * The two unnamed parameters are not used.
	 *
	 * @param data  payload bytes of the matrix element (without its own tag)
	 * @return the decoded variable; a default-constructed container for the
	 *         object class and for unrecognised classes
	 */
	MatlabIOContainer MatlabIO::CollateMatrixFields(uint32_t, uint32_t, std::vector<char> data)
	{
		// get the flags
		bool complx = data[9] & (1 << 3);

		// get the type of the encapsulated data
		char enc_data_type = data[8];
		// the preamble size is 16 bytes
		uint32_t pre_wbytes = 16;

		// get the dimensions
		uint32_t dim_type;
		uint32_t dim_dbytes;
		uint32_t dim_wbytes;
		const char* dim_data = ReadVariableTag(dim_type, dim_dbytes, dim_wbytes, &(data[pre_wbytes]));
		std::vector<uint32_t> dims(reinterpret_cast<const uint32_t *>(dim_data), reinterpret_cast<const uint32_t *>(dim_data + dim_dbytes));

		// get the variable name (stored without a terminator, so add one)
		uint32_t name_type;
		uint32_t name_dbytes;
		uint32_t name_wbytes;
		const char* name_data = ReadVariableTag(name_type, name_dbytes, name_wbytes, &(data[pre_wbytes + dim_wbytes]));
		std::vector<char> name(name_data, name_data + name_dbytes);
		name.push_back('\0');

		// offset of the first byte after the preamble, dimensions and name
		const uint32_t content_offset = pre_wbytes + dim_wbytes + name_wbytes;

		// cell arrays and structs carry nested elements rather than a flat
		// real/imaginary pair, so hand the remaining bytes over and bail out now
		if (enc_data_type == MAT_CELL_CLASS) {
			std::vector<char> real(data.begin() + content_offset, data.end());
			return ConstructCell(name, dims, real);
		}
		else if (enc_data_type == MAT_STRUCT_CLASS) {
			std::vector<char> real(data.begin() + content_offset, data.end());
			return ConstructStruct(name, dims, real);
		}

		// get the real data
		uint32_t real_type;
		uint32_t real_dbytes;
		uint32_t real_wbytes;
		const char* real_data = ReadVariableTag(real_type, real_dbytes, real_wbytes, &(data[content_offset]));
		std::vector<char> real(real_data, real_data + real_dbytes);

		std::vector<char> imag;
		if (complx) {
			// get the imaginary data, which follows the (padded) real element
			uint32_t imag_type;
			uint32_t imag_dbytes;
			uint32_t imag_wbytes;
			const char* imag_data = ReadVariableTag(imag_type, imag_dbytes, imag_wbytes, &(data[content_offset + real_wbytes]));
			assert(imag_type == real_type);
			// copy exactly imag_dbytes bytes; the end of the range must be fixed
			// before copying, otherwise the copy would never terminate
			imag.assign(imag_data, imag_data + imag_dbytes);
		}

		// construct whatever object we happened to get
		MatlabIOContainer variable;
		switch (enc_data_type)
		{
			// numeric types
		case MAT_INT8_CLASS:      variable = ConstructMatrix<int8_t>(name, dims, real, imag, real_type); break;
		case MAT_UINT8_CLASS:     variable = ConstructMatrix<uint8_t>(name, dims, real, imag, real_type); break;
		case MAT_INT16_CLASS:     variable = ConstructMatrix<int16_t>(name, dims, real, imag, real_type); break;
		case MAT_UINT16_CLASS:    variable = ConstructMatrix<uint16_t>(name, dims, real, imag, real_type); break;
		case MAT_INT32_CLASS:     variable = ConstructMatrix<int32_t>(name, dims, real, imag, real_type); break;
		case MAT_UINT32_CLASS:    variable = ConstructMatrix<uint32_t>(name, dims, real, imag, real_type); break;
		case MAT_FLOAT_CLASS:     variable = ConstructMatrix<float>(name, dims, real, imag, real_type); break;
		case MAT_DOUBLE_CLASS:    variable = ConstructMatrix<double>(name, dims, real, imag, real_type); break;
		case MAT_INT64_CLASS:     variable = ConstructMatrix<int64_t>(name, dims, real, imag, real_type); break;
		case MAT_UINT64_CLASS:    variable = ConstructMatrix<uint64_t>(name, dims, real, imag, real_type); break;
			// character data
		case MAT_CHAR_CLASS:      variable = ConstructString(name, dims, real); break;
			// sparse types
		case MAT_SPARSE_CLASS:    variable = ConstructSparse(name, dims, real, imag); break;
			// non-handled types: the default-constructed container is returned
		case MAT_OBJECT_CLASS:	  break;
		default: 				  break;
		}
		return variable;
	}

	std::vector<char> MatlabIO::UncompressVariable(uint32_t& data_type, uint32_t& dbytes, uint32_t& wbytes, const std::vector<char> &data)
	{
		// setup the inflation parameters
		char buf[8];
		z_stream infstream;
		infstream.zalloc = Z_NULL;
		infstream.zfree = Z_NULL;
		infstream.opaque = Z_NULL;
		int ok = inflateInit(&infstream);
		if (ok != Z_OK) { std::cerr << "Unable to inflate variable" << std::endl; exit(-5); }

		// inflate the variable header
		infstream.avail_in = (uInt)data.size();
		infstream.next_in = (unsigned char *)&(data[0]);
		infstream.avail_out = 8;
		infstream.next_out = (unsigned char *)&buf;
		ok = inflate(&infstream, Z_NO_FLUSH);
		if (ok != Z_OK) { std::cerr << "Unable to inflate variable" << std::endl; exit(-5); }

		// get the headers
		ReadVariableTag(data_type, dbytes, wbytes, buf);

		// inflate the remainder of the variable, now that we know its size
		char *udata_tmp = new char[dbytes];
		infstream.avail_out = dbytes;
		infstream.next_out = (unsigned char *)udata_tmp;
		inflate(&infstream, Z_FINISH);
		inflateEnd(&infstream);

		// convert to a vector
		std::vector<char> udata(udata_tmp, udata_tmp + dbytes);
		delete[] udata_tmp;
		return udata;

	}

	/**
	 * Interprets the payload of a top-level data element.
	 *
	 * Compressed elements are inflated and interpreted again; matrix elements are
	 * decoded. Any other type yields a default-constructed container.
	 *
	 * @param data_type  element type taken from the tag
	 * @param nbytes     payload size taken from the tag
	 * @param data       payload bytes
	 * @return the decoded variable
	 */
	MatlabIOContainer MatlabIO::ReadVariable(uint32_t data_type, uint32_t nbytes, const std::vector<char> &data)
	{
		MatlabIOContainer result;

		if (data_type == MAT_COMPRESSED) {
			// inflate first, then interpret whatever element was inside
			uint32_t inner_type;
			uint32_t inner_dbytes;
			uint32_t inner_wbytes;
			std::vector<char> inflated = UncompressVariable(inner_type, inner_dbytes, inner_wbytes, data);
			result = ReadVariable(inner_type, inner_dbytes, inflated);
		}
		else if (data_type == MAT_MATRIX) {
			// deserialize the matrix
			result = CollateMatrixFields(data_type, nbytes, data);
		}

		return result;
	}

	/**
	 * Reads the next data element from the file and decodes it.
	 *
	 * Reads the 8-byte tag, then the payload it announces, then (for elements
	 * that are not compressed) advances the read position to the next 8-byte
	 * boundary before handing the payload to ReadVariable.
	 *
	 * @return the decoded variable
	 */
	MatlabIOContainer MatlabIO::ReadBlock(void)
	{
		// get the data type and number of bytes consumed by this variable.
		// The buffer is zeroed so that a failed read decodes as an empty element
		// of type 0 rather than as whatever happened to be on the stack.
		uint32_t data_type;
		uint32_t dbytes;
		uint32_t wbytes;
		char buf[8] = { 0 };
		fid_.Read(buf, sizeof(char) * 8);
		ReadVariableTag(data_type, dbytes, wbytes, buf);

		// read the binary data block directly into the vector
		std::vector<char> data(dbytes);
		if (!data.empty()) {
			fid_.read(data.data(), sizeof(char)*dbytes);
		}

		// move the seek head position to the next 64-bit boundary
		// (but only if the data is uncompressed)
		if (data_type != MAT_COMPRESSED) {
			const std::streampos head_pos = fid_.tellg();
			const std::streamoff head_off = head_pos;
			// tellg reports failure as a negative position; nothing to align then
			if (head_off >= 0) {
				// distance to the boundary, not the distance already past it
				const std::streamoff padding = (8 - (head_off % 8)) % 8;
				fid_.seekg(padding, std::fstream::cur);
			}
		}

		// now read the variable contained in the block
		return ReadVariable(data_type, dbytes, data);
	}

	std::vector<MatlabIOContainer> MatlabIO::Read(void) 
	{

		// allocate the output
		std::vector<MatlabIOContainer> variables;

		// read the header information
		GetHeader();

		// get all of the variables
		while (HasVariable()) {

			MatlabIOContainer variable;
			variable = ReadBlock();
			variables.push_back(variable);
		}
		return variables;
	}

	void MatlabIO::Whos(std::vector<MatlabIOContainer> variables) const
	{

		// get the longest filename
		unsigned int flmax = 0;
		for (uint n = 0; n < (uint)variables.size(); ++n) if (variables[n].Name().length() > flmax) flmax = variables[n].Name().length();

		printf("-------------------------\n");
		printf("File: %s\n", filename_.c_str());
		printf("%s\n", header_);
		printf("Variables:\n");
		for (unsigned int n = 0; n < variables.size(); ++n) {
			printf("%*s:  %s\n", flmax, variables[n].Name().c_str(), variables[n].Type().c_str());
		}
		printf("-------------------------\n");
		fflush(stdout);
	}
}