import os
import random
import argparse
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

import _init_paths
from models.dram import DRAMLN
from utils.traditional_metrics import score
from dataset.SALDL_dataset import get_datasets



# Enumerate CUDA devices in PCI bus order and restrict visibility to GPUs 0-3.
# NOTE(review): these are set after the module imports above, so they only
# affect CUDA if it is initialised later; nothing visible in this script
# uses a GPU directly — confirm whether the imported project modules do.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="0, 1, 2, 3",
)


def parser_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, matching the original behaviour;
            passing a list lets callers and tests parse without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace with ``dataset_name``, ``dataset_dir``, ``method``,
        ``seed`` and ``K``.
    """
    parser = argparse.ArgumentParser(description='Second Training')

    # data
    parser.add_argument('--dataset_name', help='dataset name', default='flickr',
                        choices=['flickr', 'twitter', 'raf', 'emotion6', 'fbp5500'])
    parser.add_argument('--dataset_dir', help='dir of all datasets', default='./data_SALDL')

    parser.add_argument('--method', help='methods', default='DRAM',
                        choices=['DRAM'])

    # random seed
    parser.add_argument('--seed', default=1, type=int,
                        help='seed for initializing training. ')

    # NOTE(review): K is parsed here but not read by main() in this file.
    parser.add_argument('--K', default=3, type=int,
                        help='K neighbour')

    return parser.parse_args(argv)


def get_args():
    """Return the parsed command-line arguments (delegates to ``parser_args``)."""
    return parser_args()


def same_seeds(seed):
    """Seed Python's ``random`` module, then NumPy's global RNG, with *seed*."""
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)


def select_instance(X, D):
    """Discard instances whose non-zero label degrees contain ties.

    A row ``D[i]`` is kept only when all of its non-zero entries are pairwise
    distinct (repeated zeros are allowed). Equality is exact, as in
    ``np.unique``.

    Args:
        X: Feature array indexed by instance along axis 0.
        D: Label-degree array indexed by instance along axis 0, row-aligned
            with ``X``.

    Returns:
        ``(X_kept, D_kept)``, the rows of ``X`` and ``D`` that have no ties.
        Both are empty along axis 0 if every row has a tie.
    """
    keep = []
    for i, d in enumerate(D):
        nonzero = d[d != 0]
        if np.unique(nonzero).size == nonzero.size:
            keep.append(i)
    # Explicit integer dtype: an empty list would otherwise become a float64
    # array, which NumPy rejects as an index (IndexError).
    keep = np.asarray(keep, dtype=np.intp)
    return X[keep], D[keep]


def main():
    """Fit DRAMLN on the tie-free labelled training split and print test metrics.

    Steps: parse CLI args, seed the RNGs, load the dataset splits, drop
    labelled training instances whose non-zero label degrees tie, turn each
    remaining label-degree row into an index ordering of its non-zero labels,
    fit the model, then score its predictions on the test split.
    """
    args = get_args()

    if args.seed is not None:
        same_seeds(args.seed)

    (train_label_data, train_label_label,
     train_unlabel_data, train_unlabel_label,
     val_data, val_label,
     test_data, test_label) = get_datasets(args)
    # Only the labelled training split and the test split are used below.
    print(train_label_data.shape)
    train_label_data, train_label_label = select_instance(train_label_data, train_label_label)
    print(train_label_data.shape)  # shape after discarding tied instances

    model = DRAMLN()
    # np.argsort is ascending, so skipping the first (number of zeros) indices
    # drops the zero-degree labels and leaves the non-zero labels ordered from
    # lowest to highest degree. Assumes degrees are non-negative so zeros sort
    # first; NOTE(review): confirm DRAMLN expects ascending order.
    train_label_label = [
        np.argsort(d)[np.count_nonzero(d == 0):].tolist()
        for d in train_label_label
    ]
    # presumably the number of labels (last axis of test_label) — confirm.
    model.fit(train_label_data, train_label_label, test_label.shape[-1])
    pred = model.predict(test_data)
    cheby, clark, can, kl, cosine, inter, spear, tau = score(test_label, pred)
    print('ours_metric:     can:{:.4f}, cheby:{:.4f}, clark:{:.4f}, cosine:{:.4f}, intersection:{:.4f}, KL:{:.4f}, spear:{:.4f}, tau:{:.4f}'.format(can, cheby, clark, cosine, inter, kl, spear, tau))


if __name__ == '__main__':
    main()


