import numpy as np
import random

def read_csv(path):
    """Read a simple comma-separated file into a list of rows.

    Each line is split on ',' and every field made only of decimal digits is
    converted to ``int``; all other fields are kept as strings. No quoting,
    header detection, or whitespace stripping is performed.

    Args:
        path: Path of the file to read.

    Returns:
        A list of rows, each a list of ``int``/``str`` fields. An empty file
        yields an empty list.
    """
    # ``with`` guarantees the handle is closed even if reading fails.
    with open(path, 'r') as fp:
        lines = fp.read().split('\n')

    # A trailing newline produces one empty final element; drop only that,
    # so a file whose last line lacks '\n' still keeps its last row.
    if lines and lines[-1] == '':
        lines.pop()

    def parse(row):
        # isdecimal (not isdigit): isdigit accepts e.g. '²', which int() rejects.
        return [int(e) if e.isdecimal() else e for e in row.split(',')]

    return [parse(row) for row in lines]

def load_data(path):
    """Load the dataset tables stored as CSV files under ``path``.

    Reads, in this order, ``question.csv``, ``user.csv``, ``skill.csv``,
    ``question_skill.csv`` and ``record.csv`` from the directory ``path``.

    Returns:
        A tuple ``(qst_num, usr_num, skl_num, qst_skl, records)`` where the
        first three are the row counts of the question, user and skill files,
        ``qst_skl`` is column 1 of ``question_skill.csv`` as an int64 array,
        and ``records`` is the whole ``record.csv`` as an int64 2-D array.
    """
    def _read(name):
        return read_csv('%s/%s.csv' % (path, name))

    # Row counts only; the tables themselves are not kept.
    qst_num = len(_read('question'))
    usr_num = len(_read('user'))
    skl_num = len(_read('skill'))

    # NOTE(review): column 1 is presumably the skill id of each question row,
    # with rows ordered by question id — confirm against the data files.
    qst_skl = np.array(_read('question_skill')).astype(np.int64)[:, 1]

    # NOTE(review): rows are unpacked downstream as (user, question, result);
    # confirm record.csv has exactly those three integer columns.
    records = np.array(_read('record')).astype(np.int64)

    return qst_num, usr_num, skl_num, qst_skl, records

def create_seqs(usr_num, records):
    """Group records into one list of steps per user.

    Args:
        usr_num: Number of users; user ids index the returned list.
        records: Iterable of 3-element records ``(user, question, result)``.

    Returns:
        A list of ``usr_num`` lists. Entry ``u`` holds ``[question, result, 1]``
        for every record of user ``u``, in the order the records appear.
    """
    per_user = []
    for _ in range(usr_num):
        per_user.append([])

    for record in records:
        user, question, result = record
        per_user[user].append([question, result, 1])

    return per_user

def format_seqs(seqs, seq_len):
    """Cut each sequence into fixed-length windows, zero-padding the tail.

    Each sequence is a list of 3-element steps. It is padded with ``[0, 0, 0]``
    steps up to the next multiple of ``seq_len`` and then split into
    consecutive windows of ``seq_len`` steps. Windows from all sequences are
    concatenated in input order; an empty sequence contributes no windows.

    Args:
        seqs: Iterable of sequences (lists or tuples of 3-element steps).
        seq_len: Number of steps per output window.

    Returns:
        A list of windows, each a list of ``seq_len`` ``[a, b, c]`` lists.
    """
    update_seqs = []
    for seq in seqs:
        rmd = len(seq) % seq_len
        pad = 0 if rmd == 0 else seq_len - rmd
        # list() lets tuple sequences be padded as well as lists.
        seq_arr = np.array(list(seq) + [[0] * 3] * pad).reshape(-1, seq_len, 3)
        # extend in place: rebuilding the list with '+' each time was O(n^2).
        update_seqs.extend(seq_arr.tolist())

    return update_seqs

def init_qst_clr(qst_num, clr_num):
    """Randomly assign each question to one of ``clr_num`` near-equal clusters.

    Question ids ``0..qst_num-1`` are shuffled with ``np.random.shuffle`` and
    cut into ``clr_num`` consecutive chunks; the first ``qst_num % clr_num``
    chunks get one extra question.

    Returns:
        An int array of length ``qst_num`` whose entry ``q`` is the cluster
        index assigned to question ``q``.
    """
    order = np.arange(qst_num)
    np.random.shuffle(order)

    base_size, extra = divmod(qst_num, clr_num)
    sizes = [base_size + 1 if grp < extra else base_size for grp in range(clr_num)]
    # bounds[g] .. bounds[g + 1] is the slice of `order` belonging to cluster g.
    bounds = np.cumsum([0] + sizes)

    labels = np.zeros(qst_num, dtype=int)
    for grp, (lo, hi) in enumerate(zip(bounds[:-1], bounds[1:])):
        labels[order[lo:hi]] = grp

    return labels

class CrossValidation:
    """K-fold iterator over fixed-length sequence windows.

    The input sequences are shuffled (on a copy) with ``random.shuffle`` and
    split into ``fold`` near-equal groups. Each group is windowed with
    ``format_seqs`` and stored as an int64 array of shape
    ``(num_windows, seq_len, 3)``. Iterating yields ``fold`` pairs
    ``(trn_seqs, evl_seqs)``: fold ``i`` is the evaluation set and the other
    folds are concatenated into the training set.

    The instance is its own iterator, so it can be iterated only once.
    """

    def __init__(self, seqs, fold, seq_len):
        seq_list = self.fold_seqs(seqs, fold)
        seq_list = [format_seqs(part, seq_len) for part in seq_list]
        # reshape keeps a fold with no windows as (0, seq_len, 3) instead of a
        # 1-D empty array, which np.concatenate would reject in __next__.
        seq_list = [
            np.array(part, dtype=np.int64).reshape(-1, seq_len, 3)
            for part in seq_list
        ]

        self.seq_list = seq_list
        self.fold = fold
        self.index = 0

    def __iter__(self):
        return self

    def __next__(self):
        """Return ``(trn_seqs, evl_seqs)`` for the next fold, or stop."""
        if self.index >= self.fold:
            raise StopIteration

        evl_seqs = self.seq_list[self.index]
        others = [self.seq_list[f] for f in range(self.fold) if f != self.index]
        # With fold == 1 there is nothing to train on; np.concatenate([]) would
        # raise, so return an empty array with the same trailing shape instead.
        trn_seqs = np.concatenate(others, axis=0) if others else evl_seqs[:0]
        self.index += 1
        return trn_seqs, evl_seqs

    def fold_seqs(self, seqs, fold):
        """Shuffle a copy of ``seqs`` and split it into ``fold`` near-equal parts.

        The first ``len(seqs) % fold`` parts receive one extra element. The
        caller's ``seqs`` is left unmodified.

        Returns:
            A list of ``fold`` lists whose concatenation is a permutation of
            ``seqs``.
        """
        seqs = list(seqs)  # shuffle a copy so the caller's list is untouched
        random.shuffle(seqs)

        avg_len, remainder = divmod(len(seqs), fold)

        seq_list = []
        start = 0
        for i in range(fold):
            end = start + avg_len + (1 if i < remainder else 0)
            seq_list.append(seqs[start:end])
            start = end

        return seq_list