# Processor kinds for every task (`*` is a wildcard): GPU, OMP, CPU.
# NOTE(review): presumably an ordered preference list, first usable kind wins -- confirm against the mapper DSL docs.
Task * GPU,OMP,CPU;

# Memory kinds for region arguments. The first rule is all-wildcard and lists
# SOCKMEM, SYSMEM; the second rule names GPU in its third field and lists
# FBMEM, ZCMEM.
# NOTE(review): presumably the third field selects the processor kind and the
# GPU-specific rule overrides the wildcard rule for GPU tasks -- confirm.
Region * * * SOCKMEM,SYSMEM;
Region * * GPU FBMEM,ZCMEM;

# Syntax template: Layout taskname regionname memory AOS F_order;
# The rule below applies SOA and C_order to every task/region/memory.
Layout * * * SOA C_order; # further options left commented out: Exact; Align==128 Compact

# Instance limit of 1 for task_4.
InstanceLimit task_4 1; # controlled by command line in TacoMapper

# CollectMemory rule for task_4, with a wildcard in the second field.
# NOTE(review): presumably the wildcard means "all regions of task_4" -- confirm.
CollectMemory task_4 *;

# Block-style index computation for one dimension.
#   x    : index point; only x[dim1] is read
#   y    : index space; y.size[dim1] is the divisor
#   z    : machine space; z.size[dim2] is the multiplier
#   dim1 : dimension used for the point and the index space
#   dim2 : dimension used for the machine space
# Returns x[dim1] * z.size[dim2] / y.size[dim1].
def block_primitive(IPoint x, ISpace y, MSpace z, int dim1, int dim2) {
    # Keep the multiplication ahead of the division, as in the
    # left-to-right evaluation of the single-expression form.
    scaled = x[dim1] * z.size[dim2];
    return scaled / y.size[dim1];
}

# Cyclic-style index computation for one dimension.
#   x    : index point; only x[dim1] is read
#   y    : index space; unused, kept so the signature parallels block_primitive
#   z    : machine space; z.size[dim2] is the modulus
# Returns x[dim1] % z.size[dim2].
def cyclic_primitive(IPoint x, ISpace y, MSpace z, int dim1, int dim2) {
    extent = z.size[dim2];
    return x[dim1] % extent;
}

# Machine model built from the GPU processors, used by the mapping functions below.
m_2d = Machine(GPU); # 2-D: dim 0 = nodes, dim 1 = processors (nodes * processors)
# Result of merging dims 0 and 1 of m_2d.
# NOTE(review): presumably a flat 1-D view over all GPUs, as the name suggests -- confirm.
m_1d = m_2d.merge(0, 1);

# Place `task` on the same point of m_2d as its parent task.
def same_point(Task task) {
    # Coordinates of the parent's processor, expressed in m_2d.
    parent_coords = task.parent.processor(m_2d);
    return m_2d[*parent_coords];
}

# Pick a processor from m_1d by applying block_primitive to dimension 0 of the
# task's index point/space and dimension 0 of m_1d.
def block1d(Task task) {
    proc_idx = block_primitive(task.ipoint, task.ispace, m_1d, 0, 0);
    return m_1d[proc_idx];
}

# Linearize the 3-D index point, then distribute it cyclically over nodes
# first and over the processors of a node second.
def auto_cyclic_3d(Task task) {
    node_count = m_2d.size[0];
    proc_count = m_2d.size[1];

    # Strides for linearization (dimension 0 varies fastest).
    stride1 = task.ispace[0];
    stride2 = task.ispace[0] * task.ispace[1];

    # The task is 3-D, but task.ipoint[2] is always 0 because task.ispace[2] = 1;
    # the third term is kept so the formula stays general.
    flat = task.ipoint[0] + stride1 * task.ipoint[1] + stride2 * task.ipoint[2];

    node_idx = flat % node_count;                # cyclic over #nodes
    proc_idx = (flat / node_count) % proc_count; # then cyclic over #procs
    return m_2d[node_idx, proc_idx];
}

# Map a 3-D launch space onto a 6-D view of the machine: dims 0,1,2 come from
# splitting the node dimension, dims 3,4,5 from splitting the processor
# dimension. Block placement is used on the first three, cyclic placement on
# the last three.
def auto3d(Task task) {
    # Split the original dim 0 into dims 0,1,2 w.r.t. the task's index space.
    m_4d = m_2d.auto_split(0, task.ispace);

    # Subspace: the index space divided by all but the last dim of m_4d.
    subspace = task.ispace / m_4d[:-1];

    # Split the processor dim (previously 1, now 3) into dims 3,4,5 w.r.t. the subspace.
    m_6d = m_4d.auto_split(3, subspace);

    block_coords = tuple(block_primitive(task.ipoint, task.ispace, m_6d, i, i) for i in (0,1,2));
    cyclic_coords = tuple(cyclic_primitive(task.ipoint, task.ispace, m_6d, i, i + 3) for i in (0,1,2));
    return m_6d[*block_coords, *cyclic_coords];
}

# Launch-space parameters (rpoc, c) by node count, as used in the comments below:
#    2 nodes: rpoc=2, c=2
#    4 nodes: rpoc=4, c=1
#    8 nodes: rpoc=4, c=2
#   16 nodes: rpoc=8, c=1

SingleTaskMap task_4 same_point; # task_4 is placed via same_point
IndexTaskMap task_5 auto3d; # task_5 launch space: (rpoc, rpoc, c)
IndexTaskMap task_1,task_2, task_3 auto_cyclic_3d; # task_1, task_2, task_3 launch space: (rpoc, rpoc, 1); they have the PLACE_SHARD tag
IndexTaskMap init_cublas block1d; # init_cublas is placed via block1d