# -*- coding: utf-8 -*-
import tensorflow as tf

# example_proto is a serialized example string; singCoil_parse_function(example_proto) decodes it according to a fixed feature specification.
def singCoil_parse_function(example_proto):
    """Decode one serialized single-coil example into (k, label).

    Every feature is stored as a variable-length list: the real and
    imaginary parts of ``k`` and ``label`` as float32 values, and their
    target shapes as int64 values (``k_shape`` / ``label_shape``).

    Args:
        example_proto: A scalar string tensor holding one serialized example.

    Returns:
        A ``(k, label)`` tuple of complex tensors (built with ``tf.complex``
        from float32 parts), each reshaped to the shape recorded in the
        example itself.
    """
    float_keys = ('k_real', 'k_imag', 'label_real', 'label_imag')
    shape_keys = ('k_shape', 'label_shape')

    spec = {key: tf.io.VarLenFeature(dtype=tf.float32) for key in float_keys}
    spec.update({key: tf.io.VarLenFeature(dtype=tf.int64) for key in shape_keys})

    # VarLenFeature yields SparseTensors; every entry needs densifying.
    sparse_features = tf.io.parse_single_example(example_proto, spec)
    dense = {key: tf.sparse.to_dense(value) for key, value in sparse_features.items()}

    kspace = tf.reshape(tf.complex(dense['k_real'], dense['k_imag']),
                        dense['k_shape'])
    target = tf.reshape(tf.complex(dense['label_real'], dense['label_imag']),
                        dense['label_shape'])
    return kspace, target


def parse_function(example_proto):
    """Decode one serialized multi-coil example into (k, label, csm).

    The real and imaginary parts of each tensor are stored as
    variable-length float32 lists. The target shapes are stored as
    fixed-length int64 vectors: rank 4 for ``k_shape``, and rank 3 for
    ``img_shape`` and ``csm_shape``.

    Args:
        example_proto: A scalar string tensor holding one serialized example.

    Returns:
        A ``(k, label, csm)`` tuple of complex tensors (built with
        ``tf.complex`` from float32 parts). ``k`` is reshaped with
        ``k_shape``, ``label`` with ``img_shape`` and ``csm`` with
        ``csm_shape``.
    """
    # (real key, imaginary key, shape key, length of the shape vector)
    layout = (('k_real', 'k_imag', 'k_shape', 4),
              ('label_real', 'label_imag', 'img_shape', 3),
              ('csm_real', 'csm_imag', 'csm_shape', 3))

    spec = {}
    for real_key, imag_key, shape_key, rank in layout:
        spec[real_key] = tf.io.VarLenFeature(dtype=tf.float32)
        spec[imag_key] = tf.io.VarLenFeature(dtype=tf.float32)
        spec[shape_key] = tf.io.FixedLenFeature(shape=(rank,), dtype=tf.int64)

    features = tf.io.parse_single_example(example_proto, spec)

    tensors = []
    for real_key, imag_key, shape_key, _ in layout:
        # Only the VarLen parts are sparse; the FixedLen shapes are already dense.
        real = tf.sparse.to_dense(features[real_key])
        imag = tf.sparse.to_dense(features[imag_key])
        tensors.append(tf.reshape(tf.complex(real, imag), features[shape_key]))

    # Disabled options carried over from the original code (NOT applied):
    #   k = tf.transpose(k, (1, 0, 2, 3))
    #   k = k / tf.cast(tf.reduce_max(tf.abs(k)), tf.complex64)
    #   csm = csm / tf.cast(tf.reduce_max(tf.abs(csm)), tf.complex64)

    k, label, csm = tensors
    return k, label, csm