# -*- coding: utf-8 -*-

import numpy as np
import os
import pdb
import pandas as pd

# Root directory (a relative path, so resolved against the current working
# directory) under which each data set lives in its own sub-directory, e.g.
# os.path.join(data_dir, "coat").
data_dir = "./data"

def _read_int_rows(path):
    """Read a whitespace-separated text file into an ``int`` array, one row per line."""
    with open(path, "r") as f:
        rows = [line.split() for line in f]
    return np.array(rows).astype(int)


def _read_comma_int_rows(path):
    """Read a comma-separated text file into an ``int`` array, one row per line.

    The third field is parsed through ``float`` first so values written as
    e.g. ``"1.0"`` are accepted; the remaining fields are converted by
    ``astype(int)``.
    """
    rows = []
    with open(path, "r") as f:
        for line in f:
            fields = line.strip().split(',')
            fields[2] = int(float(fields[2]))
            rows.append(fields)
    return np.array(rows).astype(int)


def _split_last_column(arr):
    """Return ``(arr[:, :-1], arr[:, -1])``: all leading columns and the last one."""
    return arr[:, :-1], arr[:, -1]


def _dense_rank(values):
    """Map each value to the position of its distinct value in sorted order.

    Equal values share a rank and ranks are consecutive from 0, e.g.
    ``[7, 3, 7, 9] -> [1, 0, 1, 2]``.
    """
    return np.unique(values, return_inverse=True)[1]


def _load_coat(name):
    """Load the coat train/test matrices and print their fraction of entries > 0."""
    data_set_dir = os.path.join(data_dir, name)
    x_train = _read_int_rows(os.path.join(data_set_dir, "train.ascii"))
    x_test = _read_int_rows(os.path.join(data_set_dir, "test.ascii"))

    print("===>Load from {} data set<===".format(name))
    # Fraction of matrix cells holding a positive value.
    print("[train] rating ratio: {:.6f}".format((x_train > 0).mean()))
    print("[test]  rating ratio: {:.6f}".format((x_test > 0).mean()))
    return x_train, x_test


def _load_yahoo(name):
    """Load the yahoo train/test files and split off the last column as label."""
    data_set_dir = os.path.join(data_dir, name)
    # Each line: <user_id> <song id> <rating>
    x_train = _read_int_rows(
        os.path.join(data_set_dir, "ydata-ymusic-rating-study-v1_0-train.txt"))
    x_test = _read_int_rows(
        os.path.join(data_set_dir, "ydata-ymusic-rating-study-v1_0-test.txt"))

    print("===>Load from {} data set<===".format(name))
    print("[train] num data:", x_train.shape[0])
    print("[test]  num data:", x_test.shape[0])

    x_tr, y_tr = _split_last_column(x_train)
    x_te, y_te = _split_last_column(x_test)
    return x_tr, y_tr, x_te, y_te


def _load_kuai(name):
    """Load ``user.txt`` (train) / ``random.txt`` (test) and split off the last column."""
    data_set_dir = os.path.join(data_dir, name)
    x_train = _read_comma_int_rows(os.path.join(data_set_dir, "user.txt"))
    x_test = _read_comma_int_rows(os.path.join(data_set_dir, "random.txt"))

    print("===>Load from {} data set<===".format(name))
    print("[train] num data:", x_train.shape[0])
    print("[test] num data:", x_test.shape[0])

    x_tr, y_tr = _split_last_column(x_train)
    x_te, y_te = _split_last_column(x_test)
    return x_tr, y_tr, x_te, y_te


def _load_kuai_util(name):
    """Load the "Copy of" kuai files with columns 0 and 1 jointly re-indexed.

    Columns 0 and 1 (presumably user and item ids -- confirm against the data)
    are mapped to consecutive integers starting at 0, computed over train and
    test together so both splits share one index space. Returned arrays are
    float because the split flag added below is float.
    """
    data_set_dir = os.path.join(data_dir, name)
    train = pd.read_csv(os.path.join(data_set_dir, "Copy of user.txt"),
                        header=None).to_numpy()
    test = pd.read_csv(os.path.join(data_set_dir, "Copy of random.txt"),
                       header=None).to_numpy()

    # Append a split flag (1 = train, 0 = test) as the last column so both
    # splits can be re-indexed together and separated afterwards.
    rdf = np.r_[np.c_[train, np.ones(train.shape[0])],
                np.c_[test, np.zeros(test.shape[0])]]

    # The sorts only fix the order of the returned rows (ordered by
    # column 1); the dense ranks themselves do not depend on row order.
    # Fancy indexing returns a copy, so the in-place column writes are safe.
    rdf = rdf[np.argsort(rdf[:, 0])]
    rdf[:, 0] = _dense_rank(rdf[:, 0])
    rdf = rdf[np.argsort(rdf[:, 1])]
    rdf[:, 1] = _dense_rank(rdf[:, 1])

    split_flag = rdf[:, -1]
    is_train = split_flag == 1
    is_test = split_flag == 0
    return rdf[is_train, :2], rdf[is_train, 2], rdf[is_test, :2], rdf[is_test, 2]


def load_data(name="coat", util=False):
    """Load one of the supported rating data sets from ``data_dir``.

    Args:
        name: ``"coat"``, ``"yahoo"`` or ``"kuai"``.
        util: only consulted when ``name == "kuai"``. When truthy, load the
            ``"Copy of user.txt"`` / ``"Copy of random.txt"`` files with
            columns 0 and 1 re-indexed to consecutive integers; otherwise
            load ``user.txt`` / ``random.txt`` as-is.

    Returns:
        * ``"coat"``: ``(x_train, x_test)``, the 2-D int arrays read from
          ``train.ascii`` / ``test.ascii``.
        * ``"yahoo"`` and ``"kuai"`` (``util`` falsy):
          ``(x_train, y_train, x_test, y_test)`` where ``x`` holds every
          column except the last and ``y`` the last column (int arrays).
        * ``"kuai"`` with ``util`` truthy: ``(x_train, y_train, x_test,
          y_test)`` where ``x`` holds the re-indexed columns 0-1 and ``y``
          column 2 (float arrays).
        * any other name: ``None``, after printing a message.
    """
    if name == "coat":
        return _load_coat(name)
    if name == "yahoo":
        return _load_yahoo(name)
    if name == "kuai":
        return _load_kuai_util(name) if util else _load_kuai(name)

    print("Cant find the data set", name)
    return None


def rating_mat_to_sample(mat):
    """Turn a 2-D rating matrix into samples for its non-zero cells.

    Returns ``(x, y)``: each row of ``x`` is the ``[row, col]`` index of a
    non-zero entry, in the order produced by ``np.nonzero``, and ``y`` holds
    the matching entries of ``mat``.
    """
    nz_rows, nz_cols = np.nonzero(mat)
    index_pairs = np.column_stack((nz_rows, nz_cols))
    return index_pairs, mat[nz_rows, nz_cols]