import os
import random
import argparse
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

import _init_paths
from models.methods import DF_LDL, LDLF, LDL_LRR, LDL_SCL
from utils.traditional_metrics import score
from dataset.SALDL_dataset import get_datasets


# Make CUDA enumerate GPUs in PCI bus order and expose devices 0-3 to this
# process. Nothing in this file touches a GPU directly; presumably this is
# for the model implementations imported from models.methods — confirm.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0, 1, 2, 3",
})


def _positive_int(text):
    """Argparse ``type`` converter: parse *text* as an integer >= 1.

    Raises:
        ValueError: if *text* is not an integer literal.
        argparse.ArgumentTypeError: if the parsed value is below 1.
        (argparse turns both into a usage error / SystemExit.)
    """
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(
            'must be a positive integer, got {}'.format(value))
    return value


def parser_args(argv=None):
    """Parse the command-line options for one training/evaluation run.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses the real command line (``sys.argv[1:]``), which
            is the original behaviour; passing a list is useful for tests and
            programmatic calls.

    Returns:
        argparse.Namespace with ``dataset_name``, ``dataset_dir``, ``method``,
        ``seed`` and ``K``.
    """
    parser = argparse.ArgumentParser(description='Second Training')

    # data
    parser.add_argument('--dataset_name', help='dataset name', default='flickr',
                        choices=['flickr', 'twitter', 'raf', 'emotion6', 'fbp5500'])
    parser.add_argument('--dataset_dir', help='dir of all datasets', default='./data_SALDL')

    parser.add_argument('--method', help='methods', default='LDL_LRR',
                        choices=['LDL_LRR', 'DF_LDL', 'LDLF', 'SA_LDL', 'LDL_SCL'])

    # random seed
    parser.add_argument('--seed', default=1, type=int,
                        help='seed for initializing training. ')

    # Number of neighbours; in main() it is used for the SA_LDL (KNN) method.
    # Validated here so a non-positive value fails at parse time instead of
    # deep inside model fitting.
    parser.add_argument('--K', default=3, type=_positive_int,
                        help='K neighbour')

    return parser.parse_args(argv)


def get_args():
    """Return the parsed command-line arguments (thin alias of ``parser_args``)."""
    return parser_args()


def same_seeds(seed):
    """Seed Python's ``random`` module and NumPy's global RNG with *seed*.

    The two generators are seeded in that order. No other RNGs are touched.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)


def _build_model(method, k):
    """Instantiate the estimator registered for *method*.

    Args:
        method: One of 'LDL_LRR', 'LDLF', 'DF_LDL', 'LDL_SCL', 'SA_LDL'.
        k: Number of neighbours; only used by 'SA_LDL'.

    Raises:
        ValueError: if *method* has no registered factory.
    """
    factories = {
        'LDL_LRR': LDL_LRR,
        'LDLF': LDLF,
        'DF_LDL': DF_LDL,
        'LDL_SCL': LDL_SCL,
        # SA_LDL is evaluated here as a K-nearest-neighbour regressor using
        # L1 distance with brute-force search.
        'SA_LDL': lambda: KNeighborsRegressor(n_neighbors=k, metric='l1', algorithm='brute'),
    }
    try:
        factory = factories[method]
    except KeyError:
        raise ValueError('unknown method {!r}; expected one of {}'.format(
            method, sorted(factories))) from None
    return factory()


def main():
    """Train the selected model and print its metrics on the test split.

    Reads options via ``get_args``, seeds the RNGs, loads the data splits
    with ``get_datasets``, fits the chosen model on the labelled training
    split, predicts on the test split and prints the values from ``score``.

    Raises:
        ValueError: if ``args.method`` is not a known method. argparse
            ``choices`` normally prevents this. Previously, an unknown
            method surfaced as an UnboundLocalError on ``model``.
    """
    args = get_args()

    if args.seed is not None:
        same_seeds(args.seed)

    # Only the labelled training split and the test split are used below;
    # the unlabelled and validation splits are loaded but not used here.
    (train_label_data, train_label_label,
     train_unlabel_data, train_unlabel_label,
     val_data, val_label,
     test_data, test_label) = get_datasets(args)

    model = _build_model(args.method, args.K)
    model.fit(train_label_data, train_label_label)
    pred = model.predict(test_data)

    # NOTE(review): this unpacking order must match the return order of
    # utils.traditional_metrics.score, which is not visible here — confirm.
    cheby, clark, can, kl, cosine, inter, spear, tau = score(test_label, pred)
    print('ours_metric:     can:{:.4f}, cheby:{:.4f}, clark:{:.4f}, cosine:{:.4f}. intersection:{:.4f}, KL:{:.4f}, spear:{:.4f}, tau:{:.4f}'.format(can, cheby, clark, cosine, inter, kl, spear, tau))


# Run the experiment only when executed as a script, not when imported.
if __name__ == '__main__':
    main()


