import os 
import torch, os
import numpy as np
import torch.nn as nn
import math, copy, pdb
import argparse
import torch.nn.functional as F
import math, copy, pdb
from torch.autograd import Variable
import numpy as np
from lookahead import Lookahead
def get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        # deepcopy so that every clone owns its own parameters
        clones.append(copy.deepcopy(module))
    return clones

class PositionalEncoder_fixed(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding followed by dropout.

    A ``(1, max_seq_len, lenWord)`` table is precomputed and registered as the
    buffer ``pe``. Even feature columns hold a sine term and odd columns a
    cosine term.

    Args:
        lenWord: feature (embedding) width of the input.
        max_seq_len: number of positions in the precomputed table.
        dropout: dropout probability applied to the output.
    """
    def __init__(self, lenWord = 64, max_seq_len = 200, dropout = 0.1):
        super().__init__()
        self.lenWord = lenWord
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_seq_len, lenWord)
        for pos in range(max_seq_len):
            for i in range(0, lenWord, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / lenWord)))
                # For an odd lenWord the last sine column has no cosine partner;
                # writing pe[pos, i + 1] there would index past the table.
                if i + 1 < lenWord:
                    pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / lenWord)))
        # Leading singleton dimension lets the table broadcast over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Scale ``x`` by ``sqrt(lenWord)``, add the first ``x.size(1)`` rows of the table, apply dropout."""
        x = x * math.sqrt(self.lenWord)
        seq_len = x.size(1)
        # Buffers never require grad, so no Variable wrapper is needed.
        return self.dropout(x + self.pe[:, :seq_len])

class PositionalEncoder(nn.Module):
    """Learned additive positional encoding.

    Holds a trainable ``(SeqLen, lenWord)`` table initialised uniformly in
    [0, 1) and adds its first ``x.size(1)`` rows to the input.

    Args:
        SeqLen: number of positions in the table (longest supported sequence).
        lenWord: feature width of the input.
    """
    def __init__(self, SeqLen = 51, lenWord = 64):
        super().__init__()
        self.lenWord = lenWord
        # Size the table from SeqLen instead of a hard-coded 51 so the
        # constructor argument is actually honoured (default is still 51).
        self.pe = torch.nn.Parameter(torch.empty(SeqLen, lenWord), requires_grad = True)
        nn.init.uniform_(self.pe, 0.0, 1.0)

    def forward(self, x):
        """Return ``x`` plus the positional rows for its sequence length (dimension 1)."""
        seq_len = x.size(1)
        return x + self.pe[:seq_len, :]


class Norm(nn.Module):
    """Normalisation over the last dimension with a learnable gain and bias.

    ``flag_TX`` is stored and ``BN_C`` is accepted, but neither takes part in
    the computation performed here.
    """
    def __init__(self, d_model, flag_TX, BN_C, eps = 1e-6):
        super().__init__()
        self.flag_TX = flag_TX
        self.size = d_model
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        """Return ``alpha * (x - mean) / (std + eps) + bias`` along the last axis."""
        centered = x - x.mean(dim=-1, keepdim=True)
        # eps is added to the standard deviation itself, not to the variance
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return (self.alpha * centered) / spread + self.bias

class FeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear."""
    def __init__(self, d_model, d_ff=2048, dropout = 0.1):
        super().__init__()
        # expand to d_ff, then project back to d_model
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply the network to the last dimension of ``x``."""
        hidden = F.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)

def attention(q, k, v, d_k, dropout=None):
    """Scaled dot-product attention without masking.

    Computes ``softmax(q @ k^T / sqrt(d_k))`` over the last axis, optionally
    passes the weights through ``dropout``, and returns the weights applied
    to ``v``.
    """
    logits = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)
    weights = F.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return weights @ v

class MultiHeadAttention(nn.Module):
    """Multi-head attention with bias-free q/k/v projections and a biased output projection.

    ``d_model`` is split into ``heads`` slices of width ``d_model // heads``.
    """
    def __init__(self, heads, d_model, dropout = 0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        # creation order (q, v, k) is kept so parameter ordering is unchanged
        self.q_linear = nn.Linear(d_model, d_model, bias = False)
        self.v_linear = nn.Linear(d_model, d_model, bias = False)
        self.k_linear = nn.Linear(d_model, d_model, bias = False)
        self.dropout = nn.Dropout(dropout)
        self.FC = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, decoding = 0):
        """Attend ``q`` over ``k``/``v``; the ``decoding`` argument is accepted but unused."""
        bs = q.size(0)

        def split_heads(projected):
            # (bs, seq, d_model) -> (bs, heads, seq, d_k)
            return projected.view(bs, -1, self.h, self.d_k).transpose(1, 2)

        k_heads = split_heads(self.k_linear(k))
        q_heads = split_heads(self.q_linear(q))
        v_heads = split_heads(self.v_linear(v))
        per_head = attention(q_heads, k_heads, v_heads, self.d_k, self.dropout)
        # merge the heads back into one d_model-wide feature axis
        merged = per_head.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        return self.FC(merged)

class EncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer: self-attention and feed-forward, each with a residual connection."""
    def __init__(self, d_model, heads, flag_TX, dropout=0.1):
        super().__init__()
        self.norm_1 = Norm(d_model, flag_TX, BN_C = 51)
        self.norm_2 = Norm(d_model, flag_TX, BN_C = 51)
        self.MulAttn = MultiHeadAttention(heads, d_model, dropout=dropout)
        # the feed-forward hidden width is four times the model width
        self.ffNet = FeedForward(d_model, d_ff = 4*d_model, dropout=dropout)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``x`` after the attention sub-layer and the feed-forward sub-layer."""
        normed = self.norm_1(x)
        attended = self.MulAttn(normed, normed, normed)
        x = x + self.dropout_1(attended)
        transformed = self.ffNet(self.norm_2(x))
        return x + self.dropout_2(transformed)

class Encoder(nn.Module):
    """Transformer encoder: input projection, learned positional encoding, ``N`` encoder layers, final norm.

    Args:
        seq_len: sequence length passed to the positional encoder.
        d_model: model (feature) width.
        N: number of stacked encoder layers.
        heads: number of attention heads per layer.
        dropout: dropout probability used inside the layers.
        inputSize: width of the raw input features; ``flag_TX`` is 1 when it
            equals 4 and 0 otherwise.
    """
    def __init__(self,seq_len, d_model, N, heads, dropout, inputSize=4):
        super(Encoder, self).__init__()
        if inputSize == 4:
            flag_TX = 1
        else:
            flag_TX = 0
        self.N = N
        self.pe=PositionalEncoder(SeqLen=seq_len, lenWord=d_model)
        self.FC = nn.Linear(inputSize, d_model, bias = True)
        self.dropout = nn.Dropout(dropout)
        self.layers = get_clones(EncoderLayer(d_model, heads, flag_TX, dropout), N)
        self.norm = Norm(d_model, flag_TX, BN_C = 51)

    def forward(self, src):
        """Project ``src`` to ``d_model``, add positions, run every layer, and normalise."""
        x = self.FC(src.float())
        x = self.pe(x)
        for i in range(self.N):
            x = self.layers[i](x)
        # The return must sit outside the loop: returning from inside it
        # applied only the first layer and yielded None when N == 0.
        return self.norm(x)

class TransLayerNorm(nn.Module):
    """Layer normalisation over the last dimension using the biased variance, with epsilon inside the square root."""
    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Return ``gamma * (x - mean) / sqrt(var + eps) + beta``."""
        mean = x.mean(-1, keepdim=True)
        variance = (x - mean).pow(2).mean(-1, keepdim=True)
        normalised = (x - mean) / torch.sqrt(variance + self.variance_epsilon)
        return self.gamma * normalised + self.beta
class TransEmbeddings(nn.Module):
    """Add learned position embeddings to the input, then apply layer norm and dropout.

    ``config`` must provide ``max_position_embeddings``, ``hidden_size``,
    ``layer_norm_eps`` and ``hidden_dropout_prob``.
    """
    def __init__(self, config):
        super().__init__()
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        """Return the normalised, dropped-out sum of ``input_ids`` and its position embeddings.

        Despite its name, ``input_ids`` is added directly to the embeddings,
        so it is used as a feature tensor rather than as indices.
        """
        leading_shape = input_ids.size()[:2]
        # positions 0..seq-1, repeated for every batch row
        positions = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
        positions = positions.unsqueeze(0).expand(leading_shape)
        summed = input_ids + self.position_embeddings(positions)
        return self.dropout(self.LayerNorm(summed))

class Encoder_T(nn.Module):
    """Transformer encoder that fuses a source stream with a feedback stream.

    The feedback tensor first has its last axis (size 3) collapsed to 1 by
    ``FC_fb_0``. Both streams are projected to ``d_model``, concatenated,
    mixed back to ``d_model`` by ``FC_in``, position-encoded, and passed
    through ``N`` encoder layers and a final norm.
    """
    def __init__(self,seq_len, d_model, N, heads, dropout, inputSize_0,inputSize_fb):
        super().__init__()
        flag_TX = 0
        self.N = N
        self.pe = PositionalEncoder(SeqLen=seq_len, lenWord=d_model)
        self.FC_src_0 = nn.Linear(inputSize_0, d_model, bias = True)
        self.FC_fb_0 = nn.Linear(3, 1, bias = True)
        self.FC_fb = nn.Linear(inputSize_fb, d_model, bias = True)
        self.FC_in = nn.Linear(d_model*2, d_model, bias = True)
        self.layers = get_clones(EncoderLayer(d_model, heads, flag_TX, dropout), N)
        self.norm = Norm(d_model, flag_TX, BN_C = 51)

    def forward(self, src_0,fb):
        """Encode ``src_0`` together with the feedback ``fb`` and return the normalised features."""
        source_feat = self.FC_src_0(src_0.float())
        # drop the trailing singleton axis left by the 3 -> 1 projection
        fb_collapsed = self.FC_fb_0(fb.float()).squeeze(-1)
        feedback_feat = self.FC_fb(fb_collapsed)
        fused = self.FC_in(torch.cat((source_feat, feedback_feat), dim=-1))
        x = self.pe(fused)
        for depth in range(self.N):
            x = self.layers[depth](x)
        return self.norm(x)

def args_parser():
    """Parse the command-line options of the training script and return the namespace."""
    # One row per option: (flag, type, default, help text or None).
    option_table = (
        ('--K', int, 9, "Length"),
        ('--m', int, 3, "block"),
        ('--rate', int, 3, "rate"),
        ('--state', int, 27, None),
        ('--action', int, 3, None),
        ('--critic_step', int, 20, None),
        ('--decoding_update_step', int, 2, None),
        ('--noise_variance', float, 20, None),
        ('--batchSize', int, 4096*2, "batch size"),
        ('--lr', float, 0.001, "lr"),
        ('--temp', float, 50, None),
        ('--print_iter', int, 5, None),
        ('--policy_weight', float, 0, None),
        ('--eval_timestep', int, 40, None),
        ('--heads_trx', int, 1, None),
        ('--d_k_trx', int, 32, None),
        ('--N_trx', int, 2, None),
        ('--dropout', float, 0.0, None),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return parser.parse_args()

class BERT_T(nn.Module):
    """Transmitter model: an ``Encoder_T`` followed by a linear head.

    The head outputs ``args.m * args.rate * args.state`` values per sequence
    position. ``dropout`` and ``tanh`` modules are created but not used in
    ``forward``.
    """
    def __init__(self, args, seq_len,d_model, N, heads, dropout):
        super().__init__()
        symbols_per_block = args.m * args.rate
        self.encoder = Encoder_T(seq_len, d_model, N, heads, dropout, args.m, symbols_per_block)
        self.out = nn.Linear(d_model, symbols_per_block * args.state)
        self.dropout = nn.Dropout(dropout)
        self.tanh = nn.Tanh()

    def forward(self, src,fb):
        """Return the head output for the encoded (``src``, ``fb``) pair."""
        return self.out(self.encoder(src, fb))

class BERT_R(nn.Module):
    """Receiver model: an ``Encoder`` followed by a linear head with ``2 ** args.m`` outputs per position.

    The encoder input width is ``args.m * input_rate + 1``. The ``dropout``
    module is created but not used in ``forward``.
    """
    def __init__(self,args,seq_len, input_rate, d_model, N, heads, dropout):
        super().__init__()
        encoder_input_width = args.m * input_rate + 1
        self.encoder = Encoder(seq_len, d_model, N, heads, dropout, inputSize = encoder_input_width)
        self.out = nn.Linear(d_model, 2**args.m)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Return the head output for the encoded ``src``."""
        return self.out(self.encoder(src))

class Critic_Net(nn.Module):
    """Critic network: the sum of a policy branch and a message/initial-state branch.

    Args:
        args: namespace providing ``K``, ``m``, ``state`` and ``rate``.
        hidden_dim: hidden width of both branches.
    """
    def __init__(self, args,hidden_dim):
        super(Critic_Net, self).__init__()
        self.block_num=args.K//args.m
        self.block_dim=args.m*args.state*args.rate
        self.liner_p = P_net_feature(args, hidden_dim)
        self.liner_ms = M_net_feature(args, hidden_dim)

    def forward(self, message,intial_state_for_critic, policy):
        """Return ``liner_p(policy) + liner_ms(first block with state, remaining blocks)``.

        Args:
            message: tensor of 0/1 bits indexed as ``[batch, block, bit]``; it
                is mapped to -1/+1 before use.
            intial_state_for_critic: NumPy array with the initial channel
                state, one row per batch element.
            policy: encoder output passed to the policy branch.
        """
        # Follow the device of `message` rather than forcing CUDA, so the
        # concatenation below can never mix devices.
        device = message.device
        init_state = torch.from_numpy(intial_state_for_critic).to(device).float().unsqueeze(1)
        signed_message = (2*message-1).float()
        message_block_0 = signed_message[:,0,:].unsqueeze(1)
        message_block_other = signed_message[:,1:,:]
        # only the first block is augmented with the initial state
        m_s_in_0 = torch.cat((message_block_0, init_state), dim=-1)
        predict = self.liner_p(policy) + self.liner_ms(m_s_in_0, message_block_other)
        return predict

class ChannelEnv(nn.Module):
    """Non-differentiable finite-state channel simulated with NumPy.

    A state ``s`` is decomposed as ``x = s // 9`` and ``y = s % 9``. At each
    step the action selected for the current state (one of -1, 0, +1 when
    ``num_acts`` is 3) is added to ``x`` and the result is clamped to
    ``[0, 2]``. ``y`` decreases by 3, or is redrawn uniformly from
    ``{6, 7, 8}`` once it is below 3.

    Args:
        args: run configuration (stored, not otherwise used here).
        num_states: number of channel states.
        num_acts: number of quantisation levels / actions.
    """
    def __init__(self,args, num_states, num_acts):
        super(ChannelEnv, self).__init__()
        self.args=args
        self.num_states = num_states
        self.num_acts = num_acts
        self.useful_indices = torch.tensor([3,4,8,12,13,14,15,17,22,23,24])
        # Same placement as before when CUDA is present; fall back to the CPU
        # instead of failing at construction time when it is not.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.optimal_policy = torch.zeros(27, 3, device=device)
        self.optimal_exp_reward = 0
        self.target_optimal_policy_pdf = torch.zeros(11, 3, device=device)

    def forward(self, policy,c_state):
        """Run the channel for one block.

        Args:
            policy: tensor indexed as ``[batch, state, step]`` with raw
                (pre-sigmoid) action scores.
            c_state: integer NumPy array of shape ``(batch, 1)`` with the
                current state of every batch element.

        Returns:
            ``(final_output, valid_fb, new_state)``: the visited states as a
            ``(batch, 1, steps + 1)`` tensor, the per-step
            ``(state, action, next_state)`` triples as a
            ``(batch, 1, steps, 3)`` tensor, and the final state as a NumPy
            array. Both tensors live on ``policy.device``.
        """
        device = policy.device
        levels = (torch.sigmoid(policy) * self.num_acts).floor().cpu().detach().numpy().astype(np.int64)
        # sigmoid can reach exactly 1.0, so clamp the level into [0, num_acts - 1]
        z_q = np.clip(levels, 0, self.num_acts - 1)
        batch_size, state_num, block_len = z_q.shape
        a_t = z_q - 1
        final_output = np.zeros((batch_size, block_len + 1), dtype=np.int64)
        final_output[:, 0] = c_state[:, 0]
        valid_fb = np.zeros((batch_size, block_len, 3), dtype=np.int64)
        rows = np.arange(batch_size)[:, None]
        # Defined up front so an empty block returns the unchanged state.
        new_state = c_state

        for step in range(block_len):
            act_this_step = a_t[:, :, step]
            # pick, per batch row, the action assigned to its current state
            selected_act = act_this_step[rows, c_state]
            y_t = c_state % 9
            x_t = c_state // 9
            x_t_next = np.clip(x_t + selected_act, 0, 2)
            y_t_next = np.where(y_t < 3, np.random.choice([6, 7, 8], size=(y_t.shape)), y_t - 3)
            new_state = 9 * x_t_next + y_t_next
            valid_fb[:, step, 0] = c_state[:, 0]
            valid_fb[:, step, 1] = selected_act[:, 0]
            valid_fb[:, step, 2] = new_state[:, 0]
            final_output[:, step + 1] = new_state.squeeze()
            c_state = new_state
        final_output = torch.from_numpy(final_output).to(device).unsqueeze(dim=1)
        valid_fb = torch.from_numpy(valid_fb).to(device).unsqueeze(dim=1)
        return final_output, valid_fb, new_state
            
class P_net(nn.Module):
    """Three stacked linear layers mapping a flattened policy to one scalar per batch row.

    There is no non-linearity between the layers.
    """
    def __init__(self, args,hidden_dim):
        super().__init__()
        flat_width = args.state*args.K*args.rate
        self.FC_1 = nn.Linear(flat_width, hidden_dim, bias = True)
        self.FC_2 = nn.Linear(hidden_dim, hidden_dim, bias = True)
        self.FC_3 = nn.Linear(hidden_dim, 1, bias = True)

    def forward(self, p):
        """Flatten ``p`` to ``(batch, -1)`` and apply the three layers in order."""
        flat = p.view(p.shape[0], -1)
        return self.FC_3(self.FC_2(self.FC_1(flat)))

class P_net_feature(nn.Module):
    """Three stacked linear layers mapping a per-block policy vector to ``2 ** args.m`` outputs.

    There is no non-linearity between the layers.
    """
    def __init__(self, args,hidden_dim):
        super().__init__()
        policy_width = args.state*args.rate*args.m
        self.FC_1 = nn.Linear(policy_width, hidden_dim, bias = True)
        self.FC_2 = nn.Linear(hidden_dim, hidden_dim, bias = True)
        self.FC_3 = nn.Linear(hidden_dim, 2**args.m, bias = True)

    def forward(self, p):
        """Apply the three layers in order to the last dimension of ``p``."""
        hidden = self.FC_1(p)
        hidden = self.FC_2(hidden)
        return self.FC_3(hidden)

class M_net_feature(nn.Module):
    """Linear feature network for message blocks.

    The first block, which carries one extra input feature, has its own
    input layer. Its features are concatenated with those of the remaining
    blocks along dimension 1 before two shared linear layers produce
    ``2 ** args.m`` outputs per block.
    """
    def __init__(self, args,hidden_dim):
        super().__init__()
        self.FC_1_s0 = nn.Linear(args.m+1, hidden_dim, bias = True)
        self.FC_1 = nn.Linear(args.m, hidden_dim, bias = True)
        self.FC_2 = nn.Linear(hidden_dim, hidden_dim, bias = True)
        self.FC_3 = nn.Linear(hidden_dim, 2**args.m, bias = True)

    def forward(self, m_0,m):
        """Return the per-block outputs for the first block ``m_0`` and the other blocks ``m``."""
        first_feat = self.FC_1_s0(m_0)
        other_feat = self.FC_1(m)
        stacked = torch.cat((first_feat, other_feat), dim=1)
        return self.FC_3(self.FC_2(stacked))

class SysModel(nn.Module):
    """End-to-end system: transmitter (``Tmodel``), channel (``Channel``) and receiver (``Rmodel``).

    Args:
        args: namespace providing ``K``, ``m``, ``rate``, ``state``,
            ``action``, ``d_model_trx``, ``N_trx``, ``heads_trx`` and
            ``dropout``.
    """
    def __init__(self, args):
        super(SysModel, self).__init__()
        self.args = args
        self.m=args.m
        self.Channel = ChannelEnv(args,num_states=args.state, num_acts=args.action)
        self.Tmodel = BERT_T(args,self.args.K//self.m,args.d_model_trx, args.N_trx, args.heads_trx, args.dropout)
        self.Rmodel = BERT_R(args,self.args.K//self.m,args.rate, args.d_model_trx, args.N_trx+2, args.heads_trx, args.dropout)
        self.target_policy=self.Channel.target_optimal_policy_pdf
        self.exp_optimal_reward=self.Channel.optimal_exp_reward

    def forward(self, bVec,intial_state,codes=None,critic_flag=0):
        """Send the message blocks through the channel and decode them.

        Args:
            bVec: 3-D tensor of message bits indexed as ``[batch, block, bit]``.
            intial_state: NumPy array with the initial channel state per batch row.
            codes: precomputed transmitter outputs; only read when
                ``critic_flag == 1``.
            critic_flag: when 1, skip the transmitter and replay ``codes``
                through the channel; otherwise generate the codes block by
                block with ``Tmodel`` using the accumulated channel feedback.

        Returns:
            ``(decSeq, coding_output_policy)``: the receiver output and the
            per-block transmitter outputs (``codes`` itself when
            ``critic_flag == 1``).
        """
        batch_size,bits_num,_=bVec.shape
        # Allocate on the device of the input instead of forcing CUDA.
        device = bVec.device
        num_blocks = self.args.K//self.m
        symbols_per_block = self.args.m*self.args.rate
        channel_output = torch.zeros(batch_size, num_blocks, symbols_per_block+1, device=device)
        cur_state = intial_state

        if critic_flag==1:
            coding_output_policy = codes
            for t_index in range(num_blocks):
                channel_in = (coding_output_policy[:,t_index,:].unsqueeze(1)).view(batch_size, self.args.state, symbols_per_block)
                channel_out, _, cur_state = self.Channel(channel_in, cur_state)
                channel_output[:,t_index,:] = channel_out[:,0,:]
            decSeq = self.Rmodel(channel_output)
            return decSeq, coding_output_policy

        # Feedback seen by the transmitter before anything was sent: all zeros
        # except the initial state in the last (symbol, field) slot.
        channel_fb = torch.zeros(batch_size, 1, symbols_per_block, 3, device=device)
        channel_fb[:,:,-1,-1] = torch.from_numpy(intial_state).to(device)
        coding_output_policy = torch.zeros(batch_size, num_blocks, symbols_per_block*self.args.state, device=device)
        for idx in range(num_blocks):
            # the transmitter sees blocks 0..idx as -1/+1 plus all feedback so far
            src_m = 2*bVec[:,:(idx+1),:]-1
            coding_output = self.Tmodel(src_m, channel_fb)
            # only the newest position is transmitted in this round
            coding_output_each_iter = coding_output[:,-1,:].unsqueeze(1)
            channel_in_each_iter = coding_output_each_iter.view(batch_size, self.args.state, symbols_per_block)
            channel_out_each_iter, channel_fb_each_iter, cur_state = self.Channel(channel_in_each_iter, cur_state)
            channel_fb = torch.cat([channel_fb, channel_fb_each_iter], dim=1)
            coding_output_policy[:,idx,:] = coding_output_each_iter[:,0,:]
            channel_output[:,idx,:] = channel_out_each_iter[:,0,:]
        decSeq = self.Rmodel(channel_output)
        return decSeq,coding_output_policy

def bin2dec(bVec, m=None):
    """Convert a bit tensor into one integer per ``m``-bit block (most significant bit first).

    Args:
        bVec: 3-D tensor of 0/1 values whose elements per batch row are a
            multiple of ``m``.
        m: bits per block. When omitted, the module-level ``args.m`` set by
            the script entry point is used.

    Returns:
        Tensor of shape ``(batch, blocks)`` on the same device as ``bVec``.
    """
    if m is None:
        # Original behaviour: rely on the global `args` created under __main__.
        m = args.m
    batch_size,_,_=bVec.shape
    block_in = bVec.view(batch_size, -1, m).cpu().numpy()
    # weights 2^(m-1), ..., 2, 1 so the first bit of a block is the MSB
    weights = 2 ** np.arange(m)[::-1]
    out_block = block_in.dot(weights)
    # Return on the input's device rather than forcing CUDA.
    return torch.from_numpy(out_block).to(bVec.device)

def dec2bin(block_Vec, m=None):
    """Expand block indices into their ``m``-bit binary representation (most significant bit first).

    Args:
        block_Vec: 3-D integer tensor indexed as ``[batch, block, 0]``; only
            element 0 of the last axis is read. Values are expected in
            ``[0, 2 ** m)``.
        m: bits per block. When omitted, the module-level ``args.m`` set by
            the script entry point is used.

    Returns:
        Integer tensor of shape ``(batch, blocks * m, 1)`` on the same device
        as ``block_Vec``.
    """
    if m is None:
        # Original behaviour: rely on the global `args` created under __main__.
        m = args.m
    batch = block_Vec.shape[0]
    values = block_Vec[:, :, 0].long()
    # Shift amounts m-1, ..., 0 extract the bits from MSB to LSB in one
    # vectorised step, replacing the per-element string conversion.
    shifts = torch.arange(m - 1, -1, -1, device=block_Vec.device)
    bits = (values.unsqueeze(-1) >> shifts) & 1
    return bits.reshape(batch, -1, 1)


def train_model( args,model,critic_model):
    """Alternately train the critic, the encoder (through the critic) and the decoder.

    Every iteration draws a fresh random message batch. On every
    ``args.decoding_update_step``-th iteration the critic is fitted for
    ``args.critic_step`` steps to the receiver output obtained from
    (noise-perturbed) encoder codes, and the encoder is then updated with the
    cross-entropy of the critic's prediction. The decoder is updated on every
    iteration with the true decoding cross-entropy. The bit error rate is
    printed periodically.

    The loop runs for ``10000 * 10000`` iterations and does not save
    checkpoints.

    Args:
        args: parsed options; ``args.device`` must also be set.
        model: the ``SysModel`` holding ``Tmodel`` and ``Rmodel``.
        critic_model: the critic network.
    """
    print("-->-->-->-->-->-->-->-->-->--> start training ...")
    print_index=25
    print('paramater:')
    print(args)
    loss_cross_entropy=torch.nn.CrossEntropyLoss()
    optimizer_critic = torch.optim.Adam(critic_model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
    critic_loss=torch.nn.MSELoss()
    lr_encoder=args.lr/5
    optimizer_encoder_base = torch.optim.Adam(model.Tmodel.parameters(), lr=lr_encoder,betas=(0.9, 0.98), eps=1e-9)
    optimizer_decoder_base = torch.optim.Adam(model.Rmodel.parameters(), lr=args.lr,betas=(0.9, 0.98), eps=1e-9)
    optimizer_encoder = Lookahead(optimizer_encoder_base, alpha=0.5,k=6)
    optimizer_decoder = Lookahead(optimizer_decoder_base,alpha=0.5, k=6)
    print('adopt lookahead optimizer')
    print('lr for encoder: ',lr_encoder,'; lr for the decoder: ',args.lr)
    decoder_op_step=0
    loss_decoder_all=0
    numBatch = 10000 * 10000
    best_acc=0

    for batch_idx in range(numBatch):
        bVec = torch.randint(0,2,(args.batchSize, args.K, 1)).cuda()
        block_gt=bin2dec(bVec).unsqueeze(-1)
        batch_size,bits_num,_=bVec.shape
        in_block=(bVec.view(batch_size,-1,args.m))
        model.eval()
        critic_model.train()
        intial_state_for_critic = np.random.choice([6,7,8,15,16,17,24,25,26], size=(batch_size,1))
        pre_ori, codes_generated_this_batch = model(in_block,intial_state_for_critic,codes=None,critic_flag=0)
        codes_generated=codes_generated_this_batch.detach()
        if (batch_idx%args.decoding_update_step)==0:
            # ---- critic update: fit the critic to the receiver output ----
            loss_critic_all=0
            for inner_step in range(args.critic_step):
                optimizer_critic.zero_grad()
                if inner_step==0:
                    code_update=codes_generated
                else:
                    # perturb the codes with Gaussian noise of variance 1/noise_variance
                    noise_stddev=torch.full((codes_generated.shape),1/np.sqrt(args.noise_variance)).float()
                    noise_mean=torch.zeros_like(noise_stddev).float()
                    noise=torch.normal(mean=noise_mean,std=noise_stddev).float().cuda()
                    code_update=noise+codes_generated
                preds,_= model(in_block,intial_state_for_critic,code_update,critic_flag=1)
                output_environment=preds
                predicted_output=critic_model(in_block,intial_state_for_critic,code_update)
                loss_critic=critic_loss(F.softmax(predicted_output,dim=-1),F.softmax(output_environment,dim=-1))
                loss_critic.backward()
                optimizer_critic.step()
                # detach so the running total does not keep every graph alive
                loss_critic_all=loss_critic_all+loss_critic.detach()
            if args.critic_step > 0:
                print('critic loss =', (round(loss_critic_all.item()/args.critic_step,6)),' at step:',batch_idx)

            # ---- encoder update: back-propagate through the critic ----
            model.train()
            critic_model.eval()
            preds_1, codes_out_1 = model(in_block,intial_state_for_critic,codes=None,critic_flag=0)
            predicted_critic=critic_model(in_block,intial_state_for_critic,codes_out_1)

            ys_en = block_gt.long().contiguous().view(-1)
            predicted_loss = loss_cross_entropy(predicted_critic.contiguous().view(-1, predicted_critic.size(-1)), ys_en.to(args.device))

            encoder_loss=predicted_loss
            optimizer_encoder.zero_grad()
            encoder_loss.backward()
            optimizer_encoder.step()

        # ---- decoder update: true decoding loss, every iteration ----
        model.train()
        critic_model.eval()
        preds_2,_= model(in_block,intial_state_for_critic,codes=None,critic_flag=0)
        ys_2 = block_gt.long().contiguous().view(-1)
        loss_decoding = loss_cross_entropy(preds_2.contiguous().view(-1, preds_2.size(-1)), ys_2.to(args.device))
        optimizer_decoder.zero_grad()
        loss_decoding.backward()
        optimizer_decoder.step()
        decoder_op_step=decoder_op_step+1
        loss_decoder_all=loss_decoder_all+loss_decoding.detach()
        if np.mod(batch_idx+1, args.decoding_update_step) == 0:
            # decoder_op_step already counts every accumulated loss, so it is
            # the correct divisor (the former "+1" under-reported the mean).
            print('Train the decoder with decoding loss:',(round(loss_decoder_all.item()/decoder_op_step,6)), ' at step:',batch_idx)
            decoder_op_step=0
            loss_decoder_all=0

        if np.mod(batch_idx, print_index*args.decoding_update_step) == 0:
            # ---- periodic evaluation of the bit error rate ----
            with torch.no_grad():
                model.eval()
                critic_model.eval()
                acc_all=0
                intial_state_for_test = np.random.choice([6,7,8,15,16,17,24,25,26], size=(batch_size,1))
                current_state=intial_state_for_test
                for i in range(args.eval_timestep):
                    bVec_test = torch.randint(0,2,(args.batchSize, args.K, 1)).cuda()
                    in_block_test=(bVec_test.view(batch_size,-1,args.m))
                    preds_eval,_ = model(in_block_test,current_state,codes=None,critic_flag=0)
                    _, decoded_blocks = (F.softmax(preds_eval,dim=-1)).max(dim=-1)
                    bit_predt=(dec2bin((decoded_blocks.unsqueeze(dim=-1)))).long().contiguous().view(-1)
                    bits_gt = bVec_test.long().contiguous().view(-1)
                    # tensor reduction; the builtin sum() would loop over
                    # every element in Python
                    accurate_bit_all=(bit_predt==bits_gt.to(args.device)).sum()
                    acc_once=accurate_bit_all/(len(bits_gt))
                    acc_all=acc_once+acc_all

                succRate = acc_all/args.eval_timestep
                torch.set_printoptions(precision=8)
                print('************','acc at step:',batch_idx,', BER Performance: ', (1-succRate).item(),'************')
                torch.set_printoptions(precision=4)
                if best_acc<succRate:
                    best_acc=succRate
                    # NOTE(review): the message announces a save, but no checkpoint is written here.
                    print('*************************, Save the model at step:',batch_idx,', with best BER Performance: ',(1-succRate).item(),'*************************')

if __name__ == '__main__':
    # Script entry point: parse the options, derive the transformer width,
    # build both networks on the GPU and start training.
    train_flag = 1
    args = args_parser()
    args.device = 'cuda'
    # model width = number of heads x per-head width
    args.d_model_trx = args.heads_trx * args.d_k_trx
    model = SysModel(args)
    model = model.to(args.device)
    critic_model = Critic_Net(args, hidden_dim=32)
    critic_model = critic_model.cuda()
    train_model(args, model, critic_model)
    