function [uList,zDemoList,maskList,taskNum] = TrainBatchGen(batchSz,InputNoiseStd,taskNum)
%TRAINBATCHGEN Generate one batch of trials of a single pro/anti task.
%   [uList,zDemoList,maskList,taskNum] = TrainBatchGen(batchSz,InputNoiseStd,taskNum)
%   builds batchSz trials of one task and returns them padded to the length
%   Lmax of the longest trial in the batch (time is the third dimension).
%
%   Inputs
%     batchSz        number of trials in the batch.
%     InputNoiseStd  standard deviation of the Gaussian noise added to every
%                    element of uList. Defaults to 0 when omitted or empty.
%     taskNum        task index 1..6 (see below). When omitted or empty, a
%                    task is drawn at random, with tasks 3 and 4 twice as
%                    likely as the others. Any other value raises the error
%                    'TrainBatchGen:badTask'.
%
%   Tasks (odd = pro, even = anti; anti response targets are rotated by pi)
%     1,2  Delayed pro/anti : context, stimulus, response epochs
%     3,4  Memory pro/anti  : context, stimulus, memory, response epochs
%     5,6  React pro/anti   : context, response epochs
%
%   Outputs
%     uList      9 x batchSz x Lmax inputs:
%                  row 1     fixation input
%                  rows 2-3  stimulus, cos and sin of the stimulus angle
%                  rows 4-9  rule input; row 3+taskNum is set to 1
%     zDemoList  3 x batchSz x Lmax targets:
%                  row 1     fixation output, 0.8 before the response epoch
%                  rows 2-3  response, cos and sin of the target angle
%     maskList   3 x batchSz x Lmax loss weights. Within each trial:
%                  - the first 5 steps get weight 0;
%                  - the first 5 steps of the response epoch get weight 0;
%                  - the rest of the pre-response period gets weight 1;
%                  - the rest of the response epoch gets weight 5.
%                Row 1 is normalised to sum to 1/2, and rows 2 and 3 are
%                each half of row 1, so a trial's weights sum to 1.
%     taskNum    the task that was generated.
%
%   Steps after a trial's own length L(k) are padding. Their mask is 0, but
%   the rule input and response target are still filled there in every
%   task, as are the stimulus in tasks 1,2,5,6 and the fixation input in
%   tasks 5,6.

if nargin<2 || isempty(InputNoiseStd)
    InputNoiseStd=0;
end

if nargin<3 || isempty(taskNum)
    % make memory tasks (3,4) 2x more likely than the other tasks
    taskPool=repelem(1:6,[1 1 2 2 1 1]);
    taskNum=taskPool(randi(numel(taskPool)));
elseif ~isscalar(taskNum) || ~any(taskNum==1:6)
    error('TrainBatchGen:badTask', ...
        'taskNum must be an integer from 1 to 6.');
end

angStim=rand(1,batchSz)*2*pi;

% Sample epoch lengths per trial and derive the epoch boundaries:
%   goOn     first step of the response epoch
%   stimOff  last step of the stimulus input
%   fixOff   last step of the fixation input
% An Inf boundary means "on until the end of the padded batch".
% The randi calls keep the original draw order, so seeded runs reproduce.
switch taskNum
    case {1,2} % Delayed pro/anti: stimulus stays on through the response
        contextL=randi(21,1,batchSz)+14;
        stimL=randi(66,1,batchSz)+9;
        responseL=randi(21,1,batchSz)+14;
        goOn=contextL+stimL+1;
        stimOff=inf(1,batchSz);
        fixOff=goOn-1;
    case {3,4} % Memory pro/anti: stimulus is removed before the memory epoch
        contextL=randi(21,1,batchSz)+14;
        stimL=randi(71,1,batchSz)+9;
        memoryL=randi(71,1,batchSz)+9;
        responseL=randi(21,1,batchSz)+14;
        goOn=contextL+stimL+memoryL+1;
        stimOff=contextL+stimL;
        fixOff=goOn-1;
    case {5,6} % React pro/anti: fixation input never turns off
        contextL=randi(101,1,batchSz)+24;
        responseL=randi(71,1,batchSz)+14;
        goOn=contextL+1;
        stimOff=inf(1,batchSz);
        fixOff=inf(1,batchSz);
end

L=goOn-1+responseL;           % trial length of each trial
Lmax=max(L);
stimOn=contextL+1;
stimOff=min(stimOff,Lmax);
fixOff=min(fixOff,Lmax);

% anti tasks (even taskNum) respond opposite to the stimulus
respAng=angStim+pi*(mod(taskNum,2)==0);

uList=zeros(9,batchSz,Lmax);
zDemoList=zeros(3,batchSz,Lmax);
maskList=zeros(3,batchSz,Lmax);

%rule input (one-hot over rows 4..9, on for the whole batch)
uList(3+taskNum,:,:)=1;

for k=1:batchSz
    %fixation input
    uList(1,k,1:fixOff(k))=1;
    %stimulus input
    uList(2,k,stimOn(k):stimOff(k))=cos(angStim(k));
    uList(3,k,stimOn(k):stimOff(k))=sin(angStim(k));
    %fixation output
    zDemoList(1,k,1:goOn(k)-1)=.8;
    %response output (runs to the end of the padded batch)
    zDemoList(2,k,goOn(k):Lmax)=cos(respAng(k));
    zDemoList(3,k,goOn(k):Lmax)=sin(respAng(k));
    %mask
    maskList(:,k,:)=trialMask(Lmax,goOn(k),L(k));
end

uList=uList+InputNoiseStd*randn(size(uList));

end

function m = trialMask(Lmax,goOn,trialEnd)
%TRIALMASK Loss weights (3 x 1 x Lmax) for one trial.
%   goOn is the first response step and trialEnd the trial's last step.
w=zeros(1,Lmax);
w(6:goOn-1)=1;          % pre-response period, skipping the first 5 steps
w(goOn+5:trialEnd)=5;   % response epoch, skipping its first 5 steps
w=(1/2)*w/sum(w);       % fixation row carries half of the total weight
m=reshape([w; w/2; w/2],3,1,Lmax);
end