import torch
import numpy as np
import random
import argparse
from pathlib import Path
import yaml
import wandb
from tqdm import tqdm
import os

from torch.utils.data import DataLoader
from torch.distributions import Categorical
import torch.nn.functional as F

from dataset.dataset import SyntheticDataset, RealDataset
from model import NN
from matrix_q import calculate_Q_bar

# Number of intra-op CPU threads torch may use for this process.
_NUM_TORCH_THREADS = 24
torch.set_num_threads(_NUM_TORCH_THREADS)

def setup_seed(seed):
    """Seed every random number generator the script draws from.

    Seeds Python's ``random`` module, NumPy's global RNG, the torch CPU
    generator and all CUDA generators, then asks cuDNN to use deterministic
    algorithms.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Only affects cuDNN-backed ops; cudnn.benchmark is left untouched here.
    torch.backends.cudnn.deterministic = True

def data_at_featureidx(args, k_list, data, feature_idx, matrix_Q):
    """Push one feature through its transition matrix and sample it.

    The columns of ``data`` that belong to feature ``feature_idx`` start after
    the widths of all preceding features in ``k_list``.  They are multiplied
    by ``matrix_Q`` and every resulting row is used as the probability vector
    of a ``Categorical`` distribution, from which ``args.samples_T``
    independent draws are taken.

    Args:
        args: namespace providing ``samples_T`` (number of draws, >= 1).
        k_list: number of categories of each feature, in column order.
        data: 2-D tensor with all feature columns side by side (presumably
            one-hot per feature -- TODO confirm against the dataset).
        feature_idx: index of the feature to process.
        matrix_Q: ``(k, k)`` matrix with ``k = k_list[feature_idx]``; the
            product with the feature columns must give valid probability rows.

    Returns:
        ``(new_feature, label)``, each of shape ``(samples_T * batch, k)``.
        Rows are grouped by draw: rows ``[s * batch, (s + 1) * batch)`` hold
        draw ``s``.  ``new_feature`` is the one-hot encoded sample and
        ``label`` is the input feature columns repeated once per draw.

    Raises:
        ValueError: if ``args.samples_T`` is smaller than 1.
    """
    samples_T = args.samples_T
    if samples_T < 1:
        raise ValueError(f'samples_T must be >= 1, got {samples_T}')

    num_classes = int(k_list[feature_idx])
    # Offset of this feature = total width of the features before it
    # (0 for the first feature, so no special case is needed).
    start = int(sum(k_list[:feature_idx]))
    label = data[:, start:start + num_classes]

    distribution = Categorical(torch.mm(label, matrix_Q))
    # One sample() call per draw keeps RNG consumption identical to drawing
    # the samples one after another; a single concatenation avoids
    # re-copying the growing tensor on every iteration.
    samples = [
        F.one_hot(distribution.sample(), num_classes=num_classes).float()
        for _ in range(samples_T)
    ]
    new_feature = torch.cat(samples, 0)
    # Repeat the label once per draw so it lines up row-by-row with new_feature.
    return new_feature, label.repeat(samples_T, 1)

def data_at_T(args, k_list, data, Q_list):
    """Build the sampled input for every feature at one diffusion step.

    Calls ``data_at_featureidx`` for each feature in column order, using
    ``Q_list[i]`` as the transition matrix of feature ``i``, and joins the
    per-feature results along the column axis.

    Args:
        args: namespace forwarded to ``data_at_featureidx``.
        k_list: number of categories of each feature.
        data: 2-D tensor holding all feature columns side by side.
        Q_list: one transition matrix per feature, indexed like ``k_list``.

    Returns:
        ``(new_feature, label)``, each with ``sum(k_list)`` columns; the row
        layout is whatever ``data_at_featureidx`` produces per feature.

    Raises:
        ValueError: if ``k_list`` is empty (there is nothing to concatenate).
    """
    num_features = len(k_list)
    if num_features == 0:
        raise ValueError('k_list must contain at least one feature')

    features, labels = [], []
    for feature_idx in range(num_features):
        feature, label = data_at_featureidx(
            args, k_list, data, feature_idx, Q_list[feature_idx])
        features.append(feature)
        labels.append(label)
    # A single concatenation instead of growing the tensors feature by feature.
    return torch.cat(features, 1), torch.cat(labels, 1)

# adult dataset variant (also drops column 0 of each row before matching)
# def map_indices_to_submatrices(data, index_list, label):
#     submatrices_index = np.where(data[:, -1] == label)[0]
#     submatrices = data[submatrices_index, 1:-1]
#     index_array = np.array(index_list)
#     indices_label = index_array[data[index_array, -1] == label].tolist()
#     indices = []
#     for row in data[indices_label, 1:-1]:
#         matches = (submatrices == row).all(axis=1)
#         if matches.any():
#             index = np.where(matches)[0][0]
#             indices.append(index)
#     return indices

# credit dataset
def map_indices_to_submatrices(data, index_list, label):
    """Map row indices of ``data`` to row positions inside one label's subset.

    The subset ("submatrix") is every row of ``data`` whose last column equals
    ``label``, with that last column dropped.  For each index in
    ``index_list`` whose row carries ``label``, the position of the first
    identical row within the subset is returned.  Indices whose row has a
    different label are skipped.

    Args:
        data: 2-D NumPy array; the last column is the label and the remaining
            columns are compared as features.
        index_list: row indices into ``data``.
        label: label value selecting the subset.

    Returns:
        List of positions into the subset, in the order of the matching
        entries of ``index_list``.  Empty when ``index_list`` is empty.
    """
    submatrices_index = np.where(data[:, -1] == label)[0]
    submatrices = data[submatrices_index, :-1]
    index_array = np.asarray(index_list)
    if index_array.size == 0:
        # np.array([]) is float64, and indexing ``data`` with a float array
        # raises IndexError, so handle the empty case explicitly.
        return []
    indices_label = index_array[data[index_array, -1] == label].tolist()
    indices = []
    for row in data[indices_label, :-1]:
        # Compare the row against every subset row; keep the first exact match.
        matches = (submatrices == row).all(axis=1)
        if matches.any():
            indices.append(np.where(matches)[0][0])
    return indices

def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="Arg Parse for Diffusion Model privacy")
    parser.add_argument('--seed', dest = 'seed', type = str, default = '123', help = 'the random seed set in the experiments')
    parser.add_argument('--num_layer', dest = 'num_layer', type = int, default = 4, help = 'layer number of NNs')
    parser.add_argument('--emb_dim', dest = 'emb_dim', type = int, default = 256, help = 'dimension of features')
    parser.add_argument('--outer_batch', dest = 'outer_batch', type = int, default = 16, help = 'the batch size of different configurations')
    parser.add_argument('--inner_batch', dest = 'inner_batch', type = int, default = 16, help = 'the batch that sample different timesteps for a batch')
    parser.add_argument('--gpu', dest = 'gpu', type = int, default = 4, help = 'the GPU card number')
    parser.add_argument('--total_steps', dest = 'total_steps', type = int, default = 10, help = 'the total diffusion steps in the training process')
    parser.add_argument('--epochs', dest = 'epochs', type = int, default = 15, help = 'epochs of training')
    parser.add_argument('--decay_f', dest = 'decay_f', type = str, default = 'linear', help = 'the way of decay function in diffusion')
    parser.add_argument('--samples_T', dest = 'samples_T', type = int, default = 30, help = 'number of different samples in time T')
    parser.add_argument('--lr', dest = 'lr', type = float, default = 1e-3, help = 'the learning rate')
    parser.add_argument('--save_path', dest = 'save_path', type = str, default = './xxx', help = 'the folder to save files')
    parser.add_argument('--d', type=int, default=11, help='the number of features')
    parser.add_argument('--k', type=int, default=4, help='the number of categories')
    parser.add_argument('--label', type=int, default=1)
    return parser.parse_args()


def _build_training_batch(args, k_list, data):
    """Stack ``args.inner_batch`` sampled copies of ``data`` at random steps.

    For each inner batch a diffusion step is drawn with
    ``random.randint(0, args.total_steps)`` (both ends inclusive). The
    per-feature matrices for that step come from ``calculate_Q_bar``, and
    ``data_at_T`` produces the sampled features and matching labels.

    Returns:
        ``(new_feature, label)`` concatenated along the row axis, grouped by
        inner batch in draw order.
    """
    features, labels = [], []
    for _ in range(args.inner_batch):
        step = random.randint(0, args.total_steps)
        Q_list = calculate_Q_bar(step, decay_f = args.decay_f, feature_dim_list = k_list, total_steps = args.total_steps)
        feature, label = data_at_T(args, k_list, data, Q_list)
        features.append(feature)
        labels.append(label)
    # One concatenation instead of growing the batch inner step by inner step.
    return torch.cat(features, 0), torch.cat(labels, 0)


def main():
    """Train ``NN`` on ``./loan_cat.csv.npz`` with sampled diffusion inputs.

    Each training batch holds ``inner_batch`` random diffusion steps, and each
    step holds ``samples_T`` draws of the outer batch. Loss is logged to
    wandb, and the weights are saved whenever the batch loss reaches a new
    minimum.
    """
    args = _parse_args()

    # Passing config to wandb.init records it on the run; rebinding
    # wandb.config to a plain dict after init does not.
    wandb.init(project="diffusion_synthetic", config={
        "learning_rate": args.lr,
        "epochs": args.epochs,
        "batch_size": args.outer_batch * args.inner_batch * args.samples_T,
    })

    ratio = str(0.05)
    setup_seed(int(args.seed))

    # All d features share the same number of categories k.
    k_list = np.full(args.d, args.k, dtype=int)
    dataset = RealDataset('./loan_cat.csv.npz', label=args.label)

    num_sample = len(dataset)
    print(num_sample)

    # Positional 80/10/10 split; valset/testset are not used in this loop yet.
    trainset = dataset[:int(0.8 * num_sample)]
    valset = dataset[int(0.8 * num_sample):int(0.9 * num_sample)]
    testset = dataset[int(0.9 * num_sample):]
    print(f'num of samples: {len(dataset)}, feature dim {args.d}, #cat {args.k}, label {args.label}')

    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    # makedirs also creates missing parents and tolerates an existing folder.
    os.makedirs(args.save_path, exist_ok=True)

    model = NN(k_list, args.emb_dim, args.num_layer)
    model.to(device)

    criterion = torch.nn.BCELoss(reduction = 'mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # BCELoss clamps its log terms at -100, so a mean loss can be exactly 100;
    # starting from inf guarantees the first batch is checkpointed.
    lowest_loss = float('inf')
    # A shuffling DataLoader reshuffles every time it is iterated, so one
    # instance serves all epochs.
    train_loader = DataLoader(trainset, args.outer_batch, shuffle = True)
    for epoch_idx in tqdm(range(args.epochs)):
        for data in train_loader:
            data = data.squeeze(1)
            new_feature, label = _build_training_batch(args, k_list, data)

            optimizer.zero_grad()
            new_feature = new_feature.to(device)
            label = label.to(device)
            model_predict = model(new_feature)

            loss = criterion(model_predict, label)
            loss_value = loss.item()
            wandb.log({"loss": loss_value})
            print('loss:' + str(loss_value))
            if loss_value < lowest_loss:
                lowest_loss = loss_value
                # Saved before optimizer.step(), i.e. the weights that produced
                # loss_value. os.path.join puts the file inside save_path
                # (plain concatenation wrote './xxxmodel_...' beside it).
                checkpoint = os.path.join(
                    args.save_path,
                    f'model_cat_{args.label}_eph_{epoch_idx}_ratio_{ratio}.pth')
                torch.save(model.state_dict(), checkpoint)
            loss.backward()
            optimizer.step()

    wandb.finish()
            
if __name__ == '__main__':
    # Start training only when the file is executed as a script.
    main()