import torch
from torch import nn
import numpy as np
import contextlib
import os
import argparse
from datetime import datetime
import logging
import time
import galois
from tqdm import tqdm
GF = galois.GF(2)
##########################################################################################
##########################################################################################
##########################################################################################

def Read_pc_matrixrix_alist(fileName):
    """Build a dense 0/1 parity-check matrix from an alist-style text file.

    The first line is parsed as ``<number of columns> <number of rows>``.
    Starting at the fifth line, each of the next ``columnNum`` lines lists the
    1-based row indices of the non-zero entries of one column; values that
    are not positive are skipped.

    Returns:
        An integer ``numpy`` array of shape ``(rowNum, columnNum)``.
    """
    first_column_line = 4  # the per-column index lists start on the fifth line
    with open(fileName, 'r') as file:
        lines = file.readlines()
    columnNum, rowNum = np.fromstring(lines[0].rstrip('\n'), dtype=int, sep=' ')
    H = np.zeros((rowNum, columnNum)).astype(int)
    for col_idx in range(columnNum):
        entries = np.fromstring(
            lines[first_column_line + col_idx].rstrip('\n'), dtype=int, sep=' ')
        # Keep only real (positive) indices and shift them to 0-based rows.
        row_idx = entries[entries > 0] - 1
        H[row_idx, col_idx] = 1
    return H

def GetPCM(code):
    """Load the parity-check matrix of ``code`` from the ``Codes_DB`` directory.

    The matrix is read from ``Codes_DB/<code_type>_N<n>_K<k>.alist`` when that
    file exists, otherwise from the ``.txt`` file with the same stem.

    Args:
        code: object exposing ``n``, ``k`` and ``code_type`` attributes.

    Returns:
        The parity-check matrix as a float ``numpy`` array.

    Raises:
        Exception: if neither file exists for the requested code.
    """
    n, k = code.n, code.k
    code_name = f'{code.code_type}_N{str(n)}_K{str(k)}'
    path_pc_mat = os.path.join('Codes_DB', code_name)
    # Check for the exact files instead of substring-matching directory
    # entries: a substring match also fires on other codes sharing the prefix
    # (e.g. "..._K4" inside "..._K40.txt") and then loads a missing path.
    if os.path.isfile(path_pc_mat + '.alist'):
        return Read_pc_matrixrix_alist(path_pc_mat + '.alist').astype(float)
    if os.path.isfile(path_pc_mat + '.txt'):
        return np.loadtxt(path_pc_mat + '.txt').astype(float)
    raise Exception(f'Wrong code ' + code_name)
##########################################################################################
##########################################################################################
##########################################################################################

def set_seed(seed=42):
    """Seed the ``torch``, ``torch.cuda`` and NumPy global random generators."""
    seeders = (torch.manual_seed, torch.cuda.manual_seed, np.random.seed)
    for apply_seed in seeders:
        apply_seed(seed)

def sign_to_bin(x):
    """Map sign values to bits: ``+1 -> 0`` and ``-1 -> 1`` (computes ``(1 - x) / 2``)."""
    flipped = 1 - x
    return 0.5 * flipped

def bin_to_sign(x):
    """Map bits to sign values: ``0 -> +1`` and ``1 -> -1`` (computes ``1 - 2x``)."""
    doubled = 2 * x
    return 1 - doubled

def EbN0_to_std(EbN0, rate):
    """Convert an Eb/N0 value (handled as decibels) and a code rate to a noise std.

    The value is first shifted by ``10 * log10(2 * rate)``, converted from
    decibels to a linear ratio, and the returned standard deviation is the
    square root of the reciprocal of that ratio.
    """
    snr_db = EbN0 + 10. * np.log10(2 * rate)
    snr_linear = 10. ** (snr_db / 10.)
    return np.sqrt(1. / snr_linear)
##########################################################################################
##########################################################################################
##########################################################################################
def clamp_atanh( x ):
    """Inverse hyperbolic tangent of ``x`` after clamping it to ``[-(1 - 1e-7), 1 - 1e-7]``.

    Evaluated as ``0.5 * (log(1 + x) - log(1 - x))`` on the clamped tensor.
    """
    limit = 1 - 1e-7
    bounded = torch.clamp(x, -limit, limit)
    log_plus = torch.log(1 + bounded)
    log_minus = torch.log(1 - bounded)
    return 0.5*(log_plus - log_minus)

class Binarization(torch.autograd.Function):
    """Sign-like binarisation with a clipped pass-through gradient.

    Forward returns ``+1`` where the input is ``>= 0`` and ``-1`` where it is
    ``< 0``. Backward passes the incoming gradient unchanged where
    ``|input| <= 1`` and zeroes it elsewhere.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        non_negative = (input >= 0) * 1.
        negative = (input < 0) * 1.
        return (non_negative - negative).float()

    @staticmethod
    def backward(ctx, grad_output):
        saved_input, = ctx.saved_tensors
        # Gradient is let through only inside the [-1, 1] band of the input.
        pass_mask = saved_input.abs() <= 1
        return grad_output * pass_mask
##########################################################################################
##########################################################################################
##########################################################################################   
##########################################################################################
##########################################################################################
##########################################################################################

class BP_Network(nn.Module):
    """Sum-product belief-propagation decoder with a learnable parity-check matrix.

    The real-valued parameter ``H_matrix`` is initialised to
    ``H_init * bin_to_sign(code.H)``. On every forward pass it is turned into
    a 0/1 matrix by ``get_bin_H`` (``Binarization`` followed by
    ``sign_to_bin``), so negative parameter entries act as ones of the
    parity-check matrix and non-negative entries act as zeros.

    Args:
        code: object whose ``H`` attribute is the initial 0/1 parity-check
            matrix (a tensor).
        args: namespace providing ``H_init``, ``num_iters``, ``eps_scenario``
            and ``H_eps``.
    """
    def __init__(self, code, args):
        super(BP_Network, self).__init__()
        self.H_init = args.H_init
        self.num_iters = args.num_iters
        self.bin = Binarization.apply
        self.H_matrix = nn.Parameter(self.H_init*bin_to_sign(code.H).float())
        self.args = args
        self.G_matrix = None
        # Selects the method ``forward_<bp_method>`` used by ``forward``.
        self.bp_method = 'sum_prod'
    #######################################
    #######################################
    #######################################
    def forward(self, y):
        """Decode ``y`` (assumed to hold LLRs) with the configured BP variant."""
        # Plain attribute lookup instead of eval(): same dispatch, no
        # evaluation of a constructed source string.
        return getattr(self, f'forward_{self.bp_method}')(y)
    #######################################
    #######################################
    #######################################    
    def forward_sum_prod(self, y):
        """Run ``num_iters`` sum-product iterations and return the output LLRs.

        Assumes ``y`` is ``(batch, n)`` with ``n`` the number of columns of
        the parity-check matrix -- TODO confirm against callers. The result
        has the shape of ``y``; with ``num_iters <= 0`` the input LLRs are
        returned unchanged.
        """
        C_bp = y.unsqueeze(-1)
        R_bp = None
        H_mat = self.get_bin_H()
        if self.training:
            # enhanced gradients/local minima slight escape
            if self.args.eps_scenario == 0:
                # constant offset on every entry
                H_mat = H_mat+self.args.H_eps
            elif self.args.eps_scenario == 1:
                # random +/- H_eps on every entry
                H_mat = H_mat+bin_to_sign(torch.randint(0,2,H_mat.shape)).to(H_mat.device)*self.args.H_eps
            elif self.args.eps_scenario == 2:
                # random +/- H_eps on the entries below 1 only
                H_mat[H_mat<1] = H_mat[H_mat<1]+bin_to_sign(torch.randint(0,2,H_mat[H_mat<1].shape)).to(H_mat.device)*self.args.H_eps
            elif self.args.eps_scenario == 3:
                # constant offset on the entries below 1 only
                H_mat[H_mat<1] = H_mat[H_mat<1] + self.args.H_eps
            else:
                raise Exception(f'Wrong scenario {self.args.eps_scenario}')            
        ###
        for bp_iter in range(self.num_iters):
            if R_bp is not None:
                # Total belief minus the message received on the same edge.
                Q_bp = C_bp + torch.sum(R_bp*H_mat.T.unsqueeze(0),-1).unsqueeze(-1) - R_bp
            else:
                Q_bp = C_bp
            Q_bp = torch.tanh(0.5*Q_bp)
            tmp = Q_bp*H_mat.T.unsqueeze(0)
            # Put the multiplicative identity (1) on non-edges so they do not
            # affect the product below.
            tmp = tmp +(1-H_mat).T.unsqueeze(0)
            tmp = torch.prod(tmp,1)
            # Divide the full product by each edge's own factor; the 1e-7
            # avoids a division by exactly zero.
            tmp = tmp.unsqueeze(1)/(Q_bp+1e-7)
            tmp = tmp*H_mat.T.unsqueeze(0)
            R_bp = 2*clamp_atanh(tmp)
        ###
        if R_bp is None:
            # No iteration ran (num_iters <= 0): there are no check messages
            # to add, so the estimate is the input itself.
            return C_bp.squeeze(-1)
        # squeeze(-1) drops only the helper axis added above; a bare squeeze()
        # would also drop a batch or code axis of size one.
        x_pred = C_bp.squeeze(-1)+torch.sum(R_bp*H_mat.T.unsqueeze(0),-1)
        return x_pred
    #######################################
    #######################################
    #######################################
    def get_bin_H(self):
        """Return the current 0/1 parity-check matrix derived from ``H_matrix``."""
        return sign_to_bin(self.bin(self.H_matrix))

#####################################################################################################################
#####################################################################################################################
#####################################################################################################################
#####################################################################################################################
#####################################################################################################################
#####################################################################################################################
class Legacy_BP_Network_Trainable(nn.Module):
    """Wraps a ``BP_Network`` and optimises its parity-check matrix by line search.

    NOTE: apart from ``forward``, the methods reach the decoder through
    ``self.model.module``, i.e. they expect ``self.model`` to have been
    wrapped (e.g. in ``torch.nn.DataParallel``) after construction. They also
    read the module-level names ``device``, ``code``, ``std_train``,
    ``std_test`` and ``EbNo_range_test`` defined by the script section of
    this file.
    """
    def __init__(self, code, args):
        super(Legacy_BP_Network_Trainable, self).__init__()
        self.model = BP_Network(code,args)
        self.args = args
        
    def forward(self, y):
        """Decode ``y`` with the wrapped network."""
        return self.model(y)

    def perform_H_optimization(self):
        """Iteratively update the parity-check matrix with a gradient line search.

        Each outer iteration generates training batches, computes the
        gradient of the loss with respect to ``H_matrix``, evaluates up to 50
        candidate step sizes along the negative gradient, keeps the step with
        the lowest frame error rate, re-quantises ``H_matrix`` to
        ``+/-H_init`` and runs ``test()``. The loop stops when no step helps,
        when the binary matrix did not change, or when the test metric stops
        improving; the matrix with the best test metric is restored at the
        end. The model is saved under ``self.args.path`` at the start and
        after every accepted iteration.

        Returns:
            A list with, per outer iteration, the array of the selected
            metric over the evaluated step sizes.
        """
        metric = 'fer'
        self.eval()
        self.model.eval()
        bp_net = self.model.module
        orig_H_bin = bp_net.get_bin_H().clone()
        # Save this instance using its own args rather than the script-level
        # globals ``model`` / ``args``.
        torch.save(self, os.path.join(self.args.path, f'model_iter_{0}'))
        t = time.time()
        # Best test metric seen so far. It has to be updated on every
        # improvement, otherwise the early-stopping branch below can never
        # trigger and the final restore of ``best_H`` is a no-op.
        best_test_metric = float('inf')
        best_H = None
        metric_all =[]
        for kk in range(self.args.opt_num_iters):
            ###
            all_data = []
            for _ in tqdm(range(self.args.inner_opt_niter),desc='Generating data', leave=True):
                y, stds, x_gt, h, burst = self.get_data_train()
                all_data.append([y.cpu(), stds.cpu(), x_gt.cpu(), h, burst])
            ###
            curr_H_bin = bp_net.get_bin_H().clone()
            curr_H = bp_net.H_matrix.clone()
            curr_rank = np.linalg.matrix_rank(GF(curr_H_bin.int().cpu().numpy()))
            loss_, ber_, fer_ = self.loss_fun(compute_grad=True,all_data=all_data)
            vec_grad = bp_net.H_matrix.grad.detach().clone()*1+0
            self.model.zero_grad(set_to_none=True)
            ###
            # Candidate steps: for every entry, the positive step at which
            # curr_H - step*grad crosses zero (i.e. where that entry changes
            # sign), plus the zero step, sorted in increasing order.
            lr_search_vec = (curr_H/(vec_grad+1e-10))
            lr_search_vec = (lr_search_vec[lr_search_vec>0].view(-1)).abs()
            lr_search_vec = torch.cat([lr_search_vec,torch.zeros(1).to(lr_search_vec.device)])
            lr_search_vec = torch.sort(lr_search_vec)[0]
            # Push every step slightly past its crossing point while keeping
            # it below the next candidate whenever possible.
            for ii in range(len(lr_search_vec)-1):
                fac = 0.01 #initial 1% increase
                while lr_search_vec[ii]*(1+fac) >= lr_search_vec[ii+1]:
                    fac = fac*0.9
                    if fac <1e-6:
                        break
                lr_search_vec[ii] *= (1+fac)
            ###
            loss_arr, ber_arr, fer_arr = [],[], []
            logging.info(f'Iter kk={kk}: Initial loss={loss_:.5e}, Initial BER={ber_:.5e}, Initial FER={fer_}, H Sparsity={(curr_H_bin).mean()*100:.2f}, ||grad H||={torch.norm(vec_grad):.3e}, LS size={len(lr_search_vec)}')
            ###
            lr_search_vec = lr_search_vec[:50]
            for step_idx, step_size in enumerate(lr_search_vec):
                with torch.no_grad():
                    bp_net.H_matrix.data = curr_H - vec_grad*step_size
                #
                curr_H_bin_iter = bp_net.get_bin_H().clone()
                curr_rank_iter = np.linalg.matrix_rank(GF(curr_H_bin_iter.int().cpu().numpy()))
                #
                loss, ber, fer = self.loss_fun(all_data=all_data)
                # Candidates that reduce the GF(2) rank are disqualified.
                if curr_rank_iter < curr_rank:
                    loss = ber = fer = np.inf
                loss_arr.append(loss)
                ber_arr.append(ber)
                fer_arr.append(fer)
                with torch.no_grad():
                    bp_net.H_matrix.data = curr_H.data.clone()
                #
                logging.info(f'Iter kk={kk}, step_idx={step_idx}/{len(lr_search_vec)}, loss={loss:.5e}, BER={ber:.5e}, FER={fer:.5e}, |H-H0|={torch.sum((curr_H_bin_iter-curr_H_bin).abs())} ,H Sparsity={(curr_H_bin_iter).mean()*100:.2f} curr lambda={lr_search_vec[step_idx]:.5e}')

            loss_arr = np.array(loss_arr)
            ber_arr = np.array(ber_arr)
            fer_arr = np.array(fer_arr)
            metric_arr = {'loss': loss_arr, 'ber': ber_arr, 'fer': fer_arr}[metric]
            metric_all.append(metric_arr)
            metric_arr[np.isnan(metric_arr)] = 1e8
            idx = np.argmin(metric_arr)
            with torch.no_grad():
                # Apply the best step, then re-quantise the parameter to +/-H_init.
                bp_net.H_matrix.data = curr_H - vec_grad*lr_search_vec[idx]
                bp_net.H_matrix.data[bp_net.H_matrix.data>0] = bp_net.H_init
                bp_net.H_matrix.data[bp_net.H_matrix.data<0] = -bp_net.H_init
            ###
            if idx == 0: # no improvement
                logging.info(f'Iter kk={kk}: No improvement. Exiting Optimization \n')
                break
            else:
                if not torch.norm(curr_H_bin-bp_net.get_bin_H(),p=1) > 0: # sanity check just assert
                    logging.info(f'Iter kk={kk}: No improvement. Exiting Optimization \n')
                    break
                logging.info(f'Iter kk={kk}: Final loss={loss_arr[idx]:.5e}, Final BER={ber_arr[idx]:.5e}, Final FER={fer_arr[idx]}, Best idx={idx}, |H0-H|={(orig_H_bin-bp_net.get_bin_H()).abs().sum()}, lambda={lr_search_vec[idx]:.3e}, H Sparsity={(bp_net.get_bin_H()).mean()*100:.2f}, time={time.time()-t:.3e}s \n')
                ###
                torch.save(self, os.path.join(self.args.path, f'model_iter_{kk+1}'))
                ber_test, fer_test = self.test()
                test_metric = {'ber': ber_test, 'fer': fer_test}[metric]
                if test_metric < best_test_metric:
                    best_test_metric = test_metric
                    best_H = bp_net.H_matrix.data.clone()
                else:
                    logging.info(f"Iter kk={kk}: Worse testing performance {test_metric} vs {best_test_metric}. Exiting Optimization \n")
                    break
        if best_H is not None:
            with torch.no_grad():
                bp_net.H_matrix.data = best_H
        return metric_all
    
    def get_data_train(self):
        """Generate one training batch of noisy all-zero codewords.

        ``self.args.train_data_scenario`` selects the sampling:
        0 - plain Gaussian noise with a std drawn per sample from ``std_train``;
        1 - only samples whose hard decision has a non-zero syndrome;
        2 - only samples the current decoder gets wrong;
        3 - only samples satisfying both 1 and 2.

        Returns:
            ``(y, stds, x_gt, None, None)`` with the three tensors on the CPU.

        Raises:
            Exception: for an unknown scenario.
        """
        scenario = self.args.train_data_scenario
        if scenario == 0:
            stds = std_train[torch.randperm(self.args.bs)%len(std_train)]
            noise = (torch.randn(self.args.bs,code.n)*stds.unsqueeze(-1)).to(device)
            x_gt = noise*0
            x = bin_to_sign(x_gt)
            y = x+noise
        elif scenario == 1: #no zero syndrome
            y, stds, x_gt = self._sample_filtered(
                lambda y_, std_, x_gt_: self._has_nonzero_syndrome(y_))
        elif scenario == 2: #only faulty
            y, stds, x_gt = self._sample_filtered(self._is_frame_error)
        elif scenario == 3: #no zero syndrome and only faulty
            y, stds, x_gt = self._sample_filtered(
                lambda y_, std_, x_gt_: self._has_nonzero_syndrome(y_) & self._is_frame_error(y_, std_, x_gt_))
        else:
            raise Exception(f'Wrong scenario {self.args.train_data_scenario}')
        return y.cpu(),stds.cpu(),x_gt.cpu(), None, None

    def _has_nonzero_syndrome(self, y):
        """Boolean mask of the rows of ``y`` whose hard-decision syndrome is non-zero."""
        return self.get_syndromes(y).sum(-1)>0

    def _is_frame_error(self, y, std_curr, x_gt):
        """Boolean mask of the rows of ``y`` that the current decoder decodes wrongly."""
        # The decoder takes a single LLR input, computed here the same way as
        # in loss_fun() and test(); the previous two-argument call on raw
        # channel values raised a TypeError.
        llr = 2*y/(std_curr**2)
        x_pred = self.model(llr.to(device))
        return torch.any(sign_to_bin(torch.sign(x_pred)) != x_gt, dim=1)

    def _sample_filtered(self, keep_fn):
        """Rejection-sample noisy all-zero codewords that satisfy ``keep_fn``.

        For every std in ``std_train`` a batch of ``bs`` samples is drawn and
        the samples rejected by ``keep_fn(y, std, x_gt)`` are replaced by
        accepted samples from fresh batches until at least
        ``bs // len(std_train)`` samples are accepted. Switches the model to
        eval mode. NOTE: the loop does not terminate if ``keep_fn`` never
        accepts enough samples.

        Returns:
            ``(y, stds, x_gt)`` for the concatenated accepted samples.
        """
        per_std = self.args.bs//len(std_train)
        y_all_std = []
        stds_all = []
        x_gt = (torch.zeros(self.args.bs,code.n)).to(device)
        x = bin_to_sign(x_gt)
        with torch.no_grad():
            self.eval()
            self.model.eval()
            for std_curr in std_train:
                y_all = x + (torch.randn(self.args.bs,code.n)*std_curr).to(device)
                keep = keep_fn(y_all, std_curr, x_gt)
                while keep.float().sum() < per_std:
                    idx = torch.where(~keep)[0]
                    y = x + (torch.randn(self.args.bs,code.n)*std_curr).to(device)
                    idx_n = torch.where(keep_fn(y, std_curr, x_gt))[0]
                    # Pair as many rejected slots with newly accepted samples
                    # as both lists allow.
                    idx_n = idx_n[:len(idx)]
                    idx = idx[:len(idx_n)]
                    y_all[idx] = y[idx_n]
                    keep = keep_fn(y_all, std_curr, x_gt)

                y_all_std.append(y_all[torch.where(keep)[0]][:per_std])
                stds_all.append(y_all_std[-1][:,0]*0+std_curr)
            #
            y = torch.cat(y_all_std,0)
            stds = torch.cat(stds_all,0)
            x_gt = x_gt[:len(y)]
        return y, stds, x_gt
    
    def get_syndromes(self, y):
        """Syndromes (mod 2) of the hard decisions of ``y`` under the current binary H."""
        return (sign_to_bin(torch.sign(y))@self.model.module.get_bin_H().T)%2

    def loss_fun(self, compute_grad=False, all_data=None):
        """Average loss, BER and FER of the decoder over the batches in ``all_data``.

        Args:
            compute_grad: when true the model is put in train mode and the
                gradient of the averaged loss is accumulated via
                ``backward()``; otherwise the model is evaluated in eval mode
                under ``torch.no_grad()``.
            all_data: sequence of ``[y, stds, x_gt, h, burst]`` batches; the
                first ``self.args.inner_opt_niter`` entries are used.

        Returns:
            ``(loss, ber, fer)`` as Python floats.
        """
        self.zero_grad(set_to_none=True)
        self.model.zero_grad(set_to_none=True)
        if compute_grad:
            self.model.train()
            self.model.module.train()
        else:
            self.model.eval()
            self.model.module.eval()
        loss = ber = fer = 0
        n_iters = self.args.inner_opt_niter
        with torch.no_grad() if not compute_grad else contextlib.nullcontext():
            for ii in tqdm(range(n_iters),desc='Loss Fun', leave=False):
                y, stds, x_gt, h_gt, bursting_noise_gt = all_data[ii]
                x_gt = x_gt.to(device)
                llr = 2*y/(stds.unsqueeze(-1)**2)
                x_pred = self.model(llr.to(device))
                # The output is negated so that a negative value is scored as bit 1.
                loss_ = nn.functional.binary_cross_entropy_with_logits(-x_pred, x_gt)/n_iters
                #
                loss += loss_
                ber += torch.mean((sign_to_bin(torch.sign(x_pred)) != x_gt).float())/n_iters
                fer += torch.any(sign_to_bin(torch.sign(x_pred)) != x_gt, dim=1).float().mean()/n_iters
                if compute_grad:
                    loss_.backward()

        return loss.item(), ber.item(), fer.item()
    
######################################################################
######################################################################
    
    def test(self,verbose=True):
        """Monte-Carlo estimate of BER/FER on the all-zero codeword for every test std.

        For each entry of ``std_test`` batches of ``self.args.test_bs`` samples
        are decoded until ``self.args.test_niter`` samples were processed (or,
        when it is not positive, until more than 100 frame errors and more
        than 1e6 samples were seen, capped at 1e8 samples). After testing,
        ``num_iters`` of the decoder is reset to ``self.args.num_iters``.

        Returns:
            ``(ber_all, fer_all)``: the per-std rates summed after dividing
            each by ``sum(std_test)``.
        """
        self.eval()
        self.model.eval()
        n_iter = self.args.test_niter if self.args.test_niter > 0 else int(1e8)
        ber_all = fer_all = 0
        with torch.no_grad():
            for ii in range(len(std_test)):
                ber = fer = iter_c = 0
                while iter_c < n_iter:
                    y = 1+torch.randn(self.args.test_bs,code.n).to(device)*std_test[ii]
                    x_gt = y.to(device)*0

                    llr = 2*y/(y*0+(std_test[ii]**2))
                    x_pred = self.model(llr.to(device))
                    
                    ber += (sign_to_bin(torch.sign(x_pred)) != x_gt).float().mean(-1).sum().item()
                    fer += torch.any(sign_to_bin(torch.sign(x_pred)) != x_gt, dim=1).float().sum().item()
                    iter_c += x_pred.size(0)
                    if fer > 100 and iter_c > 1e6:
                        break
                ber = ber / iter_c
                fer = fer / iter_c
                if verbose:
                    logging.info(f'TEST: EbNo {EbNo_range_test[ii]}: BER={ber:.3e}, -ln(BER)={-np.log(ber):.3e} --- FER={fer:.3e}, -ln(FER)={-np.log(fer):.3e}')
                ber_all += ber/sum(std_test)
                fer_all += fer/sum(std_test)
        if verbose:
            logging.info('####################\n')
        # Reset through this instance rather than the script-level globals.
        self.model.module.num_iters = self.args.num_iters
        return ber_all, fer_all
##########################################################################################
##########################################################################################
##########################################################################################
##########################################################################################
##########################################################################################
##########################################################################################
##########################################################################################
##########################################################################################
##########################################################################################
if __name__ == '__main__':
    # Script entry point: parse the options, evaluate the baseline decoder,
    # optimise the parity-check matrix and evaluate again.
    # NOTE: the names defined here (args, device, code, std_train, std_test,
    # EbNo_range_test, model) are read as module-level globals by the classes
    # above, so they must stay at module level.
    parser = argparse.ArgumentParser(description='Pytorch Optimized Codes for BP')
    # Sys args
    parser.add_argument('--gpus', type=str, default='0,1,2,3,4,5,6,7', help='gpus ids')
    # Code args
    parser.add_argument('--code_type', type=str, default='LDPC')
    parser.add_argument('--code_k', type=int, default=48)
    parser.add_argument('--code_n', type=int, default=96)
    # BP args
    parser.add_argument('--H_init', type=float, default=0.1)
    parser.add_argument('--num_iters', type=int, default=5)
    parser.add_argument('--H_eps', type=float, default=0)
    # Optimization args
    parser.add_argument('--train_ebno_start', type=int, default=4)
    parser.add_argument('--train_ebno_end', type=int, default=7)
    parser.add_argument('--opt_num_iters', type=int, default=25)
    parser.add_argument('--bs', type=int, default=32768)
    parser.add_argument('--inner_opt_niter', type=int, default=150)
    # Testing args
    parser.add_argument('--test_bs', type=int, default=32768)
    parser.add_argument('--test_niter', type=int, default=-1)
    parser.add_argument('--test_ebn0_start', type=int, default=3)
    parser.add_argument('--test_ebn0_end', type=int, default=7)
    # Scenarios
    parser.add_argument('--eps_scenario', type=int, default=3)
    parser.add_argument('--train_data_scenario', type=int, default=0)
    args = parser.parse_args()
    ###
    set_seed()
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ###
    ###
    channel = 'AWGN' 
    # Results/<channel>/<code type>/Code_n_<n>/Code_k_<k>/<timestamp>
    model_dir = os.path.join('Results',
                            f'{channel}',
                            args.code_type , 
                            'Code_n_' + str(args.code_n),
                            'Code_k_' + str(args.code_k), 
                            datetime.now().strftime("%d_%m_%Y_%H_%M_%S"))
    os.makedirs(model_dir, exist_ok=True)
    args.path = model_dir
    # Log both to a file inside the run directory and to the console.
    handlers = [logging.FileHandler(os.path.join(model_dir, 'logging.txt'))]
    handlers += [logging.StreamHandler()]
    logging.basicConfig(level=logging.INFO, format='%(message)s',handlers=handlers)
    logging.info(f"Path to model/logs: {model_dir}")
    logging.info(f'Args: {args}')
    ###
    ###
    ###
    class Code():
        pass
    code = Code()
    code.k = args.code_k
    code.n = args.code_n
    code.code_type = args.code_type
    ###
    EbNo_range_train = range(args.train_ebno_start, args.train_ebno_end+1)
    EbNo_range_test = range(args.test_ebn0_start, args.test_ebn0_end+1)
    std_test = [EbN0_to_std(ii, code.k / code.n) for ii in EbNo_range_test]
    std_train = torch.tensor([EbN0_to_std(ii, code.k / code.n) for ii in EbNo_range_train]).float()
    ###
    code.H = torch.from_numpy(GetPCM(code)).float()
    logging.info(f'Rank of H: {np.linalg.matrix_rank(GF(code.H.int().cpu().numpy()))}')
    ###
    model = Legacy_BP_Network_Trainable(code,args=args)
    device_ids = list(range(len(args.gpus.split(','))))
    if torch.cuda.is_available():
        # Never hand DataParallel more device ids than there are visible
        # GPUs: the default --gpus lists eight ids, which fails on machines
        # with fewer devices.
        device_ids = device_ids[:torch.cuda.device_count()]
    model.model = torch.nn.DataParallel(model.model, device_ids = device_ids).to(device)
    ###

    logging.info(f'Baseline Performance:')
    for bp_iters in (5, 15):
        model.model.module.num_iters = bp_iters
        logging.info(f'BP L={model.model.module.num_iters}:')
        model.test()
    
    ############################
    
    model.model.module.num_iters = args.num_iters
    logging.info(f'PCM Optimization:')
    metric_all = model.perform_H_optimization()
    torch.save(model, os.path.join(args.path, 'final_model'))

    ############################
    
    logging.info(f'New Performance:')
    for bp_iters in (5, 15):
        model.model.module.num_iters = bp_iters
        logging.info(f'BP L={model.model.module.num_iters}:')
        model.test()
