from prepare_data import generate_dataloader
import argparse
import os, sys
import os.path as osp
import torch
import network
from scipy.io import loadmat
from utils import *
from train import train
import pandas as pd
import numpy as np
import copy

# Per-dataset settings: (number of classes, CSV of the labeled client,
# CSV of the unlabeled client).
DATASET_CONFIGS = {
    'office-home': (65,
                    '../dataset/officehome/RealWorld.csv',
                    '../dataset/officehome/Product.csv'),
    'imageCLEF': (12,
                  '../dataset/imageCLEF/c.csv',
                  '../dataset/imageCLEF/p.csv'),
}


def split_client_data(ldata, udata, ratio):
    """Split the two client matrices into training and evaluation arrays.

    Both inputs are 2-D arrays whose last column is the label and whose
    remaining columns are the features. The first ``int(ratio * len(ldata))``
    rows of ``ldata`` become the labeled client's training set; the remaining
    rows are held out as the source test set. All rows of ``udata`` are used
    for the unlabeled client. Every returned array is converted to float32.

    Args:
        ldata: labeled-client matrix (callers shuffle it beforehand).
        udata: unlabeled-client matrix.
        ratio: fraction of ``ldata`` rows given to the labeled client.

    Returns:
        dict with keys ``limgs``/``llabels`` (labeled training split),
        ``uimgs``/``ulabels`` (unlabeled client), ``src_imgs``/``src_labels``
        (held-out labeled rows) and ``all_imgs``/``all_labels`` (held-out
        labeled rows followed by the unlabeled rows).
    """
    # Compute the boundary once so every slice uses the same index.
    split = int(ratio * ldata.shape[0])

    limgs = np.float32(ldata[:split, :-1])
    llabels = np.float32(ldata[:split, -1])
    src_imgs = np.float32(ldata[split:, :-1])
    src_labels = np.float32(ldata[split:, -1])
    uimgs = np.float32(udata[:, :-1])
    ulabels = np.float32(udata[:, -1])

    return {
        'limgs': limgs,
        'llabels': llabels,
        'uimgs': uimgs,
        'ulabels': ulabels,
        'src_imgs': src_imgs,
        'src_labels': src_labels,
        'all_imgs': np.concatenate((src_imgs, uimgs), axis=0),
        'all_labels': np.concatenate((src_labels, ulabels), axis=0),
    }


def build_parser():
    """Create the command-line parser for the training script."""
    parser = argparse.ArgumentParser(description='PrivacyDA')
    # data path
    parser.add_argument('--lclient_root', type=str, default='../dataset/')
    parser.add_argument('--uclient_root', type=str, default='../dataset/')
    parser.add_argument('--server_model', type=str, default='../ckpts/server/model_')
    parser.add_argument('--out_path', type=str, default='../ckpts/result/')

    # network parameters
    parser.add_argument('--in_dim', type=int, default=2048)
    parser.add_argument('--n_hidden_1', type=int, default=512)
    parser.add_argument('--n_hidden_2', type=int, default=128)
    parser.add_argument('--num_class', type=int, default=65)

    # training parameters
    parser.add_argument('--dset', type=str, default='imageCLEF')
    parser.add_argument('--train_batch', type=int, default=128)
    parser.add_argument('--test_batch', type=int, default=256)
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--rounds', type=int, default=200)
    parser.add_argument('--max_iter', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-2)
    parser.add_argument('--mean', type=float, default=0.0)
    parser.add_argument('--std', type=float, default=1.0)
    parser.add_argument('--distance', type=str, default='cosine', choices=["euclidean", "cosine"])
    parser.add_argument('--threshold', type=int, default=0)
    parser.add_argument('--epsilon', type=float, default=1e-5)

    # loss parameters
    parser.add_argument('--alpha', type=float, default=0.01)
    parser.add_argument('--beta', type=float, default=0.1)
    parser.add_argument('--ratio', type=float, default=0.5)
    return parser


def main():
    """Parse arguments and run ten independent training repetitions.

    Each repetition reloads and reshuffles the labeled data, builds fresh
    networks and data loaders, and writes its log to its own file inside
    ``--out_path``.
    """
    parser = build_parser()
    args = parser.parse_args()

    if args.dset not in DATASET_CONFIGS:
        # Previously an unknown dataset only surfaced later as a NameError.
        parser.error('unsupported --dset %r (expected one of: %s)'
                     % (args.dset, ', '.join(sorted(DATASET_CONFIGS))))

    # Portable replacement for shelling out to `mkdir -p`.
    os.makedirs(args.out_path, exist_ok=True)

    args.num_class, args.lclient_root, args.uclient_root = DATASET_CONFIGS[args.dset]

    # Keep the user-supplied prefix so the checkpoint path is rebuilt from
    # scratch each run instead of accumulating "<dset>.pt" suffixes.
    server_model_prefix = args.server_model

    for times in range(10):
        log_path = osp.join(args.out_path, 'log_' + args.dset + str(times) + '_mask_C2P.txt')
        # The context manager guarantees the per-run log is flushed and closed
        # even if training raises.
        with open(log_path, 'w') as log_file:
            args.out_file = log_file

            ldata = pd.read_csv(args.lclient_root, header=None, index_col=None).values
            udata = pd.read_csv(args.uclient_root, header=None, index_col=None).values

            # Shuffle rows so each repetition gets a different labeled split.
            np.random.shuffle(ldata)
            parts = split_client_data(ldata, udata, args.ratio)

            args.num_src = parts['src_imgs'].shape[0]
            args.num_tgt = parts['uimgs'].shape[0]

            args.server_model = server_model_prefix + args.dset + '.pt'

            model = network.BasicNet(args.in_dim, args.n_hidden_1, args.n_hidden_2, args.num_class).cuda()
            lmodelB = network.BasicNet(args.in_dim, args.n_hidden_1, args.n_hidden_2, args.num_class).cuda()
            lmodelS = network.MaskSNet(args.n_hidden_2).cuda()
            lmodelD = network.Decoder(2 * args.n_hidden_2, args.n_hidden_1, args.in_dim).cuda()
            umodelB = network.BasicNet(args.in_dim, args.n_hidden_1, args.n_hidden_2, args.num_class).cuda()
            umodelS = network.MaskSNet(args.n_hidden_2).cuda()
            umodelD = network.Decoder(2 * args.n_hidden_2, args.n_hidden_1, args.in_dim).cuda()

            dataloader = {
                'lclient': generate_dataloader(parts['limgs'], parts['llabels'],
                                               args.train_batch, True, True),
                'uclient': generate_dataloader(parts['uimgs'], parts['ulabels'],
                                               args.train_batch, True, True),
                'source_test': generate_dataloader(parts['src_imgs'], parts['src_labels'],
                                                   args.test_batch, False, False),
                'target_test': generate_dataloader(parts['uimgs'], parts['ulabels'],
                                                   args.test_batch, False, False),
                'all_test': generate_dataloader(parts['all_imgs'], parts['all_labels'],
                                                args.test_batch, False, False),
            }

            # NOTE(review): the original passed lmodelB twice and never used
            # umodelB; this assumes the seventh argument is the unlabeled
            # client's base network -- confirm against train()'s signature.
            train(args, dataloader, model, lmodelB, lmodelS, lmodelD, umodelB, umodelS, umodelD)


if __name__ == '__main__':
    main()
