#generation process
import torch
import numpy as np
import random
import argparse
from tqdm import tqdm

from torch.distributions import Categorical
import torch.nn.functional as F

from model import NN
from matrix_q import calculate_Q_bar, calculate_Q
# Cap the number of CPU threads PyTorch uses for intra-op parallelism at 12.
# This runs at import time, so it applies to every tensor op in this script.
torch.set_num_threads(12)

def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator, and
    PyTorch's CPU and CUDA generators with the same ``seed``, then asks
    cuDNN to use deterministic algorithms.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Host-side generators.
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch generators (CPU first, then every CUDA device).
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True

def build_init(k_list):
    """Build a single-row tensor of per-feature uniform probabilities.

    For every entry ``k`` of ``k_list`` a block of ``k`` values equal to
    ``1 / k`` is created; the blocks are concatenated along dim 1 in the
    order of ``k_list``.

    Args:
        k_list: Sequence (list or 1-D integer array) holding the number of
            categories of each feature.

    Returns:
        A float tensor of shape ``(1, sum(k_list))``.

    Raises:
        ValueError: If ``k_list`` is empty (there is nothing to build).
    """
    if len(k_list) == 0:
        raise ValueError("k_list must contain at least one feature")

    blocks = []
    for k in k_list:
        # int() so NumPy integer scalars behave exactly like Python ints.
        width = int(k)
        blocks.append(torch.ones(1, width) / width)
    # One concatenation instead of growing the tensor block by block.
    return torch.cat(blocks, 1)

def sample_at_feature_idx(distribution, feature_idx, k_list):
    """Sample one category for a single feature and return it one-hot encoded.

    The columns belonging to feature ``feature_idx`` are cut out of
    ``distribution`` (features are laid out back to back, feature ``j``
    occupying ``k_list[j]`` columns), used as the probabilities of a
    ``Categorical`` distribution, and one category is drawn per row.

    Args:
        distribution: 2-D tensor whose rows hold the concatenated
            per-feature category weights.
        feature_idx: Index of the feature to sample.
        k_list: Number of categories of every feature.

    Returns:
        Float tensor of shape ``(rows, k_list[feature_idx])`` with exactly
        one ``1.`` per row.
    """
    width = k_list[feature_idx]
    # Column offset of this feature = total width of all preceding features.
    start = sum(k_list[:feature_idx])
    feature_probs = distribution[:, start:start + width]
    drawn = Categorical(feature_probs).sample()
    return F.one_hot(drawn, num_classes=width).float()

def sample_from_distribution(distribution, k_list):
    """Sample every feature once and concatenate the one-hot results.

    Features are sampled in index order by ``sample_at_feature_idx`` so the
    sequence of random draws is the same as sampling them one after another.

    Args:
        distribution: 2-D tensor whose rows hold the concatenated
            per-feature category weights.
        k_list: Number of categories of every feature.

    Returns:
        Float tensor with the same number of rows as ``distribution`` and
        ``sum(k_list)`` columns, made of one one-hot block per feature.

    Raises:
        ValueError: If ``k_list`` is empty (there is nothing to sample).
    """
    if len(k_list) == 0:
        raise ValueError("k_list must contain at least one feature")

    per_feature = [
        sample_at_feature_idx(distribution, feature_idx, k_list)
        for feature_idx in range(len(k_list))
    ]
    # One concatenation instead of growing the tensor feature by feature.
    return torch.cat(per_feature, 1)

def calculate_xtQt(x_idx, k_list, Qt_list):
    """Multiply each feature block of ``x_idx`` by that feature's matrix.

    ``x_idx`` is split column-wise into consecutive blocks of widths
    ``k_list``; block ``j`` is matrix-multiplied with ``Qt_list[j]`` and the
    products are concatenated along dim 1 in feature order.

    Args:
        x_idx: 2-D tensor of shape ``(rows, sum(k_list))`` (presumably
            one-hot per feature block -- not required by this function).
        k_list: Number of categories (block width) of every feature.
        Qt_list: Indexable collection with one 2-D tensor per feature whose
            row count matches that feature's block width.

    Returns:
        Tensor with ``rows`` rows holding the concatenated products.

    Raises:
        ValueError: If ``k_list`` is empty (there is nothing to multiply).
    """
    if len(k_list) == 0:
        raise ValueError("k_list must contain at least one feature")

    products = []
    start = 0  # running column offset; avoids re-summing k_list per feature
    for feature_idx, k in enumerate(k_list):
        stop = start + int(k)
        products.append(torch.mm(x_idx[:, start:stop], Qt_list[feature_idx]))
        start = stop
    return torch.cat(products, 1)

    

def main():
    """Generate categorical samples with a trained model and save them.

    Steps performed:
      1. Parse the command line and seed all random number generators.
      2. Build ``k_list`` (``--d`` features with ``--k`` categories each),
         construct ``NN`` and load its weights from ``--model_path``.
      3. Draw an initial one-hot sample per feature from a uniform
         distribution.
      4. For ``--total_steps`` iterations, walking ``t`` from ``total_steps``
         down to 1: feed the current sample to the model, combine its output
         with the matrices returned by ``calculate_Q`` / ``calculate_Q_bar``
         into per-feature category weights, and resample every feature.
         (Presumably this is the reverse-diffusion posterior weighted by the
         model's prediction of x0 -- confirm against ``model`` / ``matrix_q``.)
      5. Print the empirical category frequencies per feature and save the
         integer labels, shape ``(num_sample, d)``, to
         ``./gen_data_cat_<label>_ratio_<ratio>.pth``.

    Raises:
        NotImplementedError: If ``--label`` is neither 0 nor 1.
    """
    parser = argparse.ArgumentParser(description="Arg Parse for Diffusion Model privacy")
    parser.add_argument('--seed', dest = 'seed', type = str, default = '123', help = 'the random seed set in the experiments')
    parser.add_argument('--num_sample', dest = 'num_sample', type = int, default = 100000, help = 'the number of samples to generate')
    parser.add_argument('--model_path', dest = 'model_path', type = str, default = './model_cat_1_eph_12_ratio_0.05.pth', help = 'the path to save the model')
    parser.add_argument('--gpu', dest = 'gpu', type = int, default = 4, help = 'the GPU card number')
    parser.add_argument('--total_steps', dest = 'total_steps', type = int, default = 10, help = 'the total generations steps in the generation')
    parser.add_argument('--num_layer', dest = 'num_layer', type = int, default = 4, help = 'layer number of NNs')
    parser.add_argument('--emb_dim', dest = 'emb_dim', type = int, default = 256, help = 'dimension of fatures')
    parser.add_argument('--d', type=int, default=11, help='the dimension of features')
    parser.add_argument('--k', type=int, default=4, help='num of categories')
    parser.add_argument('--label', type=int, default=1)
    args = parser.parse_args()

    # Only used to build the output file name.
    ratio = 0.05

    # Set up the random seed.
    setup_seed(int(args.seed))

    # NOTE: the number of generated samples is fixed per label; any value
    # passed via --num_sample is overwritten here.
    if args.label == 0:
        args.num_sample = 23068
    elif args.label == 1:
        args.num_sample = 7650
    else:
        raise NotImplementedError

    # Set the device (falls back to CPU when CUDA is unavailable).
    device = torch.device('cuda:'+str(args.gpu) if torch.cuda.is_available() else 'cpu')

    # Every one of the d features has the same number of categories k.
    k_list = np.ones((args.d), dtype = int)
    k_list = k_list * args.k

    # Column layout of the concatenated one-hot encoding: feature j occupies
    # columns offsets[j]:offsets[j + 1]. Computed once instead of re-summing
    # k_list inside every loop.
    widths = [int(k) for k in k_list]
    offsets = [0]
    for width in widths:
        offsets.append(offsets[-1] + width)

    # Load the model. map_location lets a checkpoint saved on one device be
    # loaded on whichever device was selected above (including the CPU
    # fallback).
    # NOTE(review): if the checkpoint stores a whole module rather than a
    # state_dict, load ``model_parameters.state_dict()`` instead.
    model = NN(k_list, args.emb_dim, args.num_layer)
    model_parameters = torch.load(args.model_path, map_location=device)
    model.load_state_dict(model_parameters)
    model.to(device)
    # Generation is pure inference: put layers such as dropout/batch-norm
    # (if NN has any) into evaluation mode.
    model.eval()

    # Build uniform distribution noise to sample from, [args.num_sample, d*k].
    uniform_init = build_init(k_list)
    uniform_init = uniform_init.repeat(args.num_sample, 1)
    # Initial sample; bound before the loop so that --total_steps 0 still
    # reaches the statistics/saving code below.
    x_idx = sample_from_distribution(uniform_init, k_list)

    # No gradients are needed while sampling; this avoids building an
    # autograd graph for every step.
    with torch.no_grad():
        for step_idx in tqdm(range(args.total_steps)):
            # Step index passed to the matrix helpers: total_steps, ..., 1.
            t = args.total_steps - step_idx

            # Get the model output (x_idx itself stays on the CPU).
            phi_idx = model(x_idx.to(device)).cpu()

            # Calculate x_t @ Q_t and fetch the matrices for t-1 and t.
            Qt_list = calculate_Q(t, 'linear', k_list, args.total_steps)
            xtQt = calculate_xtQt(x_idx, k_list, Qt_list)
            Qt_1_bar_list = calculate_Q_bar(t - 1, 'linear', k_list, args.total_steps)
            Qt_bar_list = calculate_Q_bar(t, 'linear', k_list, args.total_steps)

            # For each feature, sum the contribution of every candidate x0.
            new_features = []
            for feature_idx, width in enumerate(widths):
                left, right = offsets[feature_idx], offsets[feature_idx + 1]
                xtQt_tmp = xtQt[:, left:right]
                Qt_1_bar = Qt_1_bar_list[feature_idx]
                Qt_bar = Qt_bar_list[feature_idx]
                x_idx_tmp = x_idx[:, left:right]

                prob_i = None
                for x0_idx in range(width):
                    # One-hot x0 candidate, repeated for every sample.
                    x0_i = torch.zeros((1, width)).float()
                    x0_i[0, x0_idx] = 1.
                    x0_i = x0_i.repeat(args.num_sample, 1)
                    # x0_i @ Qt-1_bar
                    x0_iQt_1_bar = torch.mm(x0_i, Qt_1_bar)
                    # (x0_i @ Qt_bar) . x_t, one scalar per sample -> (N, 1)
                    x0_iQt_barXt = torch.mm(x0_i, Qt_bar)
                    x0_iQt_barXt = torch.bmm(x0_iQt_barXt.unsqueeze(1), x_idx_tmp.unsqueeze(-1))
                    x0_iQt_barXt = x0_iQt_barXt.squeeze(-1)
                    # (x_t @ Q_t) * (x0_i @ Qt-1_bar), element-wise
                    xt_iQtx0_iQt_1_bar = xtQt_tmp * x0_iQt_1_bar
                    division = xt_iQtx0_iQt_1_bar / x0_iQt_barXt
                    # Model output for this candidate category, shape (N, 1).
                    phi_idx_i = phi_idx[:, left + x0_idx].reshape(-1, 1)
                    term = division * phi_idx_i
                    prob_i = term if prob_i is None else prob_i + term

                # prob_i now holds the category weights of this feature.
                sample = Categorical(prob_i).sample()
                new_features.append(F.one_hot(sample, num_classes = width).float())

            if step_idx%3 == 0:
                # Progress output: per-column counts of the sample fed to
                # the model in this step.
                print(torch.sum(x_idx, 0))
            # Now we have the new samples at t-1.
            x_idx = torch.cat(new_features, 1)

    # Report the empirical category frequencies of every feature.
    x_stat = torch.sum(x_idx, 0)
    for feature_idx in range(len(widths)):
        prob_tmp = x_stat[offsets[feature_idx]:offsets[feature_idx + 1]]
        prob_tmp = prob_tmp / torch.sum(prob_tmp)
        print('the probability distribution in'+str(feature_idx)+'-th feature:'+str(prob_tmp))

    # Convert every one-hot block back to an integer label column and save
    # the resulting (num_sample, d) tensor.
    label_columns = [
        torch.argmax(x_idx[:, offsets[i]:offsets[i + 1]], dim=1).reshape(-1, 1)
        for i in range(len(widths))
    ]
    labels = torch.cat(label_columns, 1)
    torch.save(labels, f'./gen_data_cat_{args.label}_ratio_{ratio}.pth')

# Run the generation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()