import csv
import numpy as np
import statistics
import sys
import torch

import guilib
import kaplan_meier
import neural_network
import torch_dataset
import util


def show_histogram(model, dataset, nn_param):
    """Plot the predicted distribution of the first few test samples.

    Debugging aid: runs ``neural_network.predict_lightning`` and passes each
    of (at most) the first five returned distributions to
    ``guilib.show_histogram``.

    Args:
        model: Model handed through to ``neural_network.predict_lightning``.
        dataset: Dataset object; ``dataset.x_train.shape[1]`` fixes the
            number of histogram boundaries.
        nn_param: Parameter object handed through to ``predict_lightning``.
    """
    # Integer boundaries 0 .. x_train.shape[1], inclusive.
    # NOTE(review): the boundary count follows the training-input width, not
    # the width of the predicted distribution -- confirm this is intended.
    boundaries = list(range(dataset.x_train.shape[1] + 1))
    print('Debug for test data')
    y_dist_test, _ = neural_network.predict_lightning(dataset, model, nn_param)
    # Clamp to the number of predictions so that small test sets do not
    # raise IndexError.
    n_shown = min(5, len(y_dist_test))
    for sample_idx in range(n_shown):
        guilib.show_histogram(y_dist_test[sample_idx], boundaries)

def show_KMcurve(model, dataset, nn_param):
    """Print the mean predicted distribution and its shape, then exit.

    Runs ``neural_network.predict_lightning``, averages the returned
    distributions over axis 0, prints the average followed by its shape and
    finally terminates the interpreter with ``sys.exit()``.

    NOTE(review): despite the name, nothing is plotted here and the process
    always exits -- this looks like an unfinished debugging stub; confirm.

    Args:
        model: Model handed through to ``neural_network.predict_lightning``.
        dataset: Dataset handed through to ``predict_lightning``.
        nn_param: Parameter object handed through to ``predict_lightning``.
    """
    prediction = neural_network.predict_lightning(dataset, model, nn_param)
    dist_per_sample, _labels = prediction
    averaged = np.mean(dist_per_sample, axis=0)
    print(averaged)
    print(averaged.shape)
    sys.exit()

def write_pred(y_pred, y_label, nn_param):
    """Write one CSV file per predicted distribution.

    Each file has ``n_bin`` rows and three comma-separated columns formatted
    with ``%.5f``:

    1. the left edge of the bin, from ``n_bin`` equal-width bins on [0, 1);
    2. the predicted value for that bin (row ``i`` of ``y_pred``);
    3. the sum of the predicted values from that bin to the last bin.

    Files are named ``test<i>_<y_label[i, 0] with 4 decimals>_<status>.csv``
    where ``<status>`` is ``uncensored`` if ``y_label[i, 1] > 0`` and
    ``censored`` otherwise.

    Args:
        y_pred: 2-D array with one row per sample and ``n_bin`` columns.
        y_label: 2-D array; column 0 goes into the file name, column 1 is
            the censoring indicator.
        nn_param: Mapping providing ``'pred_output_dir'`` (target directory)
            and ``'n_bin'`` (number of bins).
    """
    # Imported locally: the module header does not import ``os``, which made
    # the original ``os.path.join`` call raise NameError.
    import os

    output_dir = nn_param['pred_output_dir']
    n_bin = nn_param['n_bin']
    segments = np.linspace(0, 1, n_bin + 1)[:-1].reshape(-1, 1)
    # Upper-triangular ones: row j of (triu @ v) is sum(v[j:]).
    triu = np.triu(np.ones((n_bin, n_bin)))
    for i in range(y_pred.shape[0]):
        status = '_uncensored' if y_label[i, 1] > 0.0 else '_censored'
        filename = ('test%d_' % i) + '{:.4f}'.format(y_label[i, 0]) + status + '.csv'
        ft = y_pred[i, :].reshape(-1, 1)
        st = np.matmul(triu, y_pred[i, :]).reshape(-1, 1)
        output = np.concatenate([segments, ft, st], axis=1)
        np.savetxt(os.path.join(output_dir, filename), output,
                   delimiter=',', fmt='%.5f')

def write_test_scores(scores, dataset_name, logger_loss, list_name=None):
    """Write per-run test scores to a CSV file and return their statistics.

    The file ``test_scores_<dataset_name>_<logger_loss.model_name>.csv`` is
    created in the current working directory. It has one row per score name
    (sorted), holding the values in the order they appear in ``scores``.

    When ``list_name`` is ``None`` the value columns are headed
    ``cv0, cv1, ...`` and two extra columns, ``mean`` and ``std``, are
    appended. Otherwise the value columns are headed by ``list_name`` and no
    mean/std columns are written.

    Args:
        scores: Iterable of mappings ``{score_name: value}``, one per run.
        dataset_name: Used in the output file name.
        logger_loss: Object whose ``model_name`` attribute is used in the
            output file name.
        list_name: Optional list of column headers for the value columns.

    Returns:
        Tuple ``(keys, means, stdevs)`` of parallel lists, ordered by sorted
        score name. The sample standard deviation is undefined for a single
        value, so ``stdevs`` holds ``float('nan')`` for any score with fewer
        than two values (``statistics.stdev`` would raise in that case).
    """
    # Collect the values of each score across all runs.
    summary = {}
    for score in scores:
        for key, value in score.items():
            summary.setdefault(key, []).append(value)
    # At least one cv column is always emitted, even with no scores.
    num_cv = max((len(values) for values in summary.values()), default=1)

    filename = 'test_scores_%s_%s.csv' % (dataset_name, logger_loss.model_name)
    print('Writing %s' % filename)

    # Mean/std columns are only written for the default cv layout.
    write_mean_std = list_name is None

    header = ['score']
    if list_name is None:
        header.extend('cv%d' % i for i in range(num_cv))
    else:
        header.extend(list_name)
    if write_mean_std:
        header.extend(['mean', 'std'])

    keys = []
    means = []
    stdevs = []
    with open(filename, 'w', newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)

        for key in sorted(summary):
            values = summary[key]
            m = statistics.mean(values)
            # statistics.stdev needs at least two data points.
            s = statistics.stdev(values) if len(values) > 1 else float('nan')
            row = [key]
            row.extend(values)
            if write_mean_std:
                row.extend([m, s])
            writer.writerow(row)
            keys.append(key)
            means.append(m)
            stdevs.append(s)
    return keys, means, stdevs

def write_combined_summary(scores, dataset_name, logger_loss, list_name):
    """Write the mean/stdev of every score for several settings to one CSV.

    The file ``test_score_summary_<dataset_name>_<logger_loss.model_name>.csv``
    is created in the current working directory. The header is ``score``
    followed by ``<name>_mean, <name>_stdev`` for each entry of
    ``list_name``; each following row holds one score name and the matching
    mean/stdev pairs.

    Args:
        scores: Sequence with one ``(keys, means, stdevs)`` triple of
            parallel lists per entry of ``list_name``. The score names are
            taken from the first triple.
        dataset_name: Used in the output file name.
        logger_loss: Object whose ``model_name`` attribute is used in the
            output file name.
        list_name: Names of the settings, in the same order as ``scores``.
    """
    filename = 'test_score_summary_%s_%s.csv' % (dataset_name, logger_loss.model_name)
    print('Writing %s' % filename)
    with open(filename, 'w', newline="") as f:
        writer = csv.writer(f)

        # write header
        header = ['score']
        for name in list_name:
            header.append(name + '_mean')
            header.append(name + '_stdev')
        writer.writerow(header)

        # write contents; an empty key list simply yields a header-only file
        for i, key in enumerate(scores[0][0]):
            row = [key]
            for j in range(len(list_name)):
                row.append(scores[j][1][i])  # mean
                row.append(scores[j][2][i])  # stdev
            writer.writerow(row)
