# -*- coding: utf-8 -*-
import random
from collections import defaultdict
import numpy as np
import pandas as pd
from logger import print_log

def process_money(money, seg):
    """Bucket a monetary amount.

    Amounts of at least ``seg`` are rounded half-up to the nearest multiple
    of ``seg`` (result is a float). Smaller amounts fall into fixed buckets:
    anything ``<= 499`` becomes 499, anything ``<= 999`` becomes 999, and
    whatever is left (between 999 and ``seg``) is returned unchanged.
    """
    if money >= seg:
        return (money + seg / 2) // seg * seg
    # Fixed low buckets, checked in ascending order.
    for ceiling in (499, 999):
        if money <= ceiling:
            return ceiling
    return money

def read_csv(fname):
    """Load an "@"-separated case table and return its columns as parallel lists.

    Returns:
        A 6-tuple ``(filenames, ahs, money, months, attrs, sentences)`` where
        ``money`` is each amount passed through ``process_money(..., 1000)`` and
        cast to float, ``months`` is cast to int, and each ``attrs`` entry is the
        comma-separated integer list with the element at index 6 removed (the
        "tool" column, per the original author's note).
    """
    table = pd.read_csv(fname, sep="@")
    filenames = table.filename.tolist()
    ahs = table.ah.tolist()
    money = [float(process_money(float(raw), 1000)) for raw in table.money.tolist()]
    months = [int(raw) for raw in table.months.tolist()]
    attrs = []
    for raw in table["attrs"].tolist():
        codes = [int(tok) for tok in raw.split(",")]
        # Drop position 6 (the tool column); shorter rows are kept as-is.
        attrs.append([code for pos, code in enumerate(codes) if pos != 6])
    sentences = table.sentence.tolist()
    return filenames, ahs, money, months, attrs, sentences

def read_summary_csv(fname):
    """Load an "@"-separated summary table and map each ``ah`` to its ``summary``.

    If an ``ah`` value appears more than once, the last row wins.
    """
    table = pd.read_csv(fname, sep="@")
    return dict(zip(table.ah.tolist(), table.summary.tolist()))

def get_data(fname, summary_fname):
    """Join the case table with its summaries into a list of per-row dicts.

    Each dict has keys "ah", "money", "month", "attr", "sentence" and
    "summary". The filename column from ``read_csv`` is not carried over.
    Raises KeyError if an ``ah`` has no entry in the summary file.
    """
    columns = read_csv(fname)
    summary_by_ah = read_summary_csv(summary_fname)
    return [
        {
            "ah": ah,
            "money": amount,
            "month": n_months,
            "attr": attr,
            "sentence": sentence,
            "summary": summary_by_ah[ah],
        }
        for _filename, ah, amount, n_months, attr, sentence in zip(*columns)
    ]

def get_level(attr, pret_lvl):
    """Assign a level to one attribute vector.

    ``attr`` is indexed positionally (presumably 0/1 flags from ``read_csv`` —
    confirm). Rules, in order:
      * sum of ``attr`` >= 3  -> the sum itself;
      * ``pret_lvl == 2``     -> 2;
      * otherwise ``pret_lvl`` must be 1 (asserted):
          - sum <= 1: 2 if ``attr[1] == 1`` else 1;
          - sum == 2: 1 if one of the index pairs (1,2), (1,4), (1,6), (0,5)
            are both 1, else the sum (2).
    """
    total = int(np.sum(attr))
    if total >= 3:
        return total
    if pret_lvl == 2:
        return 2
    assert pret_lvl == 1
    if total <= 1:
        return 2 if attr[1] == 1 else 1
    # Index pairs that, when both set, send a two-flag sample back to level 1.
    # Checked in this order, so indexing behaviour matches a sequential if-chain.
    level_one_pairs = ((1, 2), (1, 4), (1, 6), (0, 5))
    if any(attr[i] == attr[j] == 1 for i, j in level_one_pairs):
        return 1
    return total

def get_levels_for_data(data, pretrain_lvl, max_level=6):
    """Group sample ids (``ah``) into per-level sets.

    Each datapoint's level comes from ``get_level(datapoint["attr"], pretrain_lvl)``.
    A sample whose level equals ``pretrain_lvl`` goes only into the
    ``pretrain_lvl`` set. Any other sample is added to every level from its own
    level up to ``max_level``, so for levels above ``pretrain_lvl`` each set
    contains all sets below it (excluding the pretrain set).

    Args:
        data: list of dicts with at least "ah" and "attr" keys (as built by
            ``get_data``).
        pretrain_lvl: level passed through to ``get_level``.
        max_level: highest level bucket (default 6, the previous hard-coded value).

    Returns:
        ``defaultdict(set)`` mapping level -> set of ``ah`` ids.

    Logs the fraction of samples placed in the pretrain level (0 % for empty data).
    """
    level2ahs_dict = defaultdict(set)
    for datapoint in data:
        lvl = get_level(datapoint["attr"], pretrain_lvl)
        assert pretrain_lvl <= lvl <= max_level, (
            f"level {lvl} outside [{pretrain_lvl}, {max_level}]")
        ah = datapoint["ah"]
        if lvl == pretrain_lvl:
            level2ahs_dict[lvl].add(ah)
            continue
        for i in range(lvl, max_level + 1):
            level2ahs_dict[i].add(ah)

        # FIXME 100% supervise: to also put every sample in the pretrain set,
        # uncomment the next line.
        # level2ahs_dict[pretrain_lvl].add(ah)

    # Indexing also ensures the pretrain key exists in the returned dict.
    n_pretrain = len(level2ahs_dict[pretrain_lvl])
    n_total = len(data)
    # Guard against ZeroDivisionError when no data is supplied.
    pretrain_ratio = n_pretrain / n_total if n_total else 0.0

    print_log(f"pretrain ratio: {pretrain_ratio*100:.2f} %.", logger="current")
    return level2ahs_dict

def get_all_data(pretrain_level, summary_file="data/summary_600.csv",
                 train_file="data/train.csv", test_file="data/test.csv"):
    """Load the train and test splits and compute level sets for the train split.

    Args:
        pretrain_level: passed to ``get_levels_for_data`` / ``get_level``.
        summary_file: "@"-separated summary table shared by both splits
            ("data/summary_200.csv" was an alternative the author used before).
        train_file: "@"-separated training case table.
        test_file: "@"-separated test case table.

    Returns:
        ``(train_data, train_levels, test_data)``; ``train_levels`` maps
        level -> set of ``ah`` ids.
    """
    train_data = get_data(train_file, summary_file)
    test_data = get_data(test_file, summary_file)
    train_levels = get_levels_for_data(train_data, pretrain_level)
    return train_data, train_levels, test_data

if __name__ == '__main__':
    import sys

    # get_all_data has no default for pretrain_level, so one must be passed.
    # Optional first CLI argument overrides the default of 1.
    pretrain_level = int(sys.argv[1]) if len(sys.argv) > 1 else 1
    train_data, train_levels, test_data = get_all_data(pretrain_level)
    print(len(train_data), len(train_levels), len(test_data))