import os

import numpy as np
import pandas as pd

from GenericTools.PlotTools.mpl_tools import load_plot_settings
from GenericTools.StayOrganizedTools.unzip import unzip_good_exps

# Apply the project's plotting/display settings; the helper's return value is
# what the rest of this script uses as `pd`.
pd = load_plot_settings(pd=pd)

# Which stage(s) of the script to run:
#   is_samples_for_table   - print a sample sentence per model family;
#   is_samples_for_survey  - write survey_*.xlsx / survey_spoiled_*.xlsx files;
#   is_samples_from_survey - summarise the ratings found in the survey files.
is_samples_for_table, is_samples_for_survey, is_samples_from_survey = False, False, True

# Source folder handed to unzip_good_exps (presumably an archive of good runs
# on the author's machine) — adjust when running elsewhere.
GEXPERIMENTS = r'D:\work\ariel_tests\good_experiments\2021-06-07--good-GW'

# Folder containing this script; experiments are extracted beneath it.
CDIR = os.path.dirname(os.path.realpath(__file__))
EXPERIMENTS = os.path.join(CDIR, 'experiments')

# Survey workbooks are written to / read from here; make sure it exists.
SURVEYSPATH = os.path.join(CDIR, 'experiments', 'surveys')
os.makedirs(SURVEYSPATH, exist_ok=True)

# Both sampling stages need the experiments' sampled sentences on disk, so
# extract only 'sampled_sentences.txt' from every 'embedding' experiment.
# Other members (e.g. 'png_content', 'train_model') could be added to
# unzip_what if ever needed.
# `ds` is presumably the list of extracted experiment folder names — it is
# used as a path component under EXPERIMENTS below; confirm with the helper.
if any((is_samples_for_table, is_samples_for_survey)):
    ds = unzip_good_exps(
        GEXPERIMENTS,
        EXPERIMENTS,
        exp_identifiers=['embedding'],
        except_folders=[],
        unzip_what=['sampled_sentences.txt'],
    )

# Substrings identifying each model family in an experiment folder name,
# mapped to a "sample already printed" flag used by the table section.
options_count = {'embedding_ae': 0, 'embedding_vae': 0, 'embedding_lmariel': 0, 'embedding_transformer': 0, }
if is_samples_for_table:
    # For each model family, print the first n_samples sentence(s) of the
    # first 512d experiment whose folder name contains the family's key.
    n_samples = 1

    # Filter into a separate name: rebinding the global `ds` here would also
    # drop the 16d experiments from the survey section, which reads `ds`
    # afterwards and labels '_16d_' runs explicitly.
    table_ds = [d for d in ds if '512d' in d]

    for d in table_ds:
        for k, count in options_count.items():
            if k in d and count == 0:
                print(k)
                options_count[k] = 1
                samples_path = os.path.join(EXPERIMENTS, d, 'text', 'sampled_sentences.txt')
                with open(samples_path, "r") as f:
                    for _ in range(n_samples):
                        line = f.readline()
                        # Keep the sentence up to and including the first '?'
                        # (the whole line, newline included, if there is none).
                        idx = line.index('?') + 1 if '?' in line else None
                        print(line[:idx])
                print('\n')

if is_samples_for_survey:
    # Build n_surveys pairs of Excel workbooks from the sampled sentences of
    # every matching experiment in `ds`:
    #   survey_{i}.xlsx         - one sentence per experiment plus a rating column;
    #   survey_spoiled_{i}.xlsx - the same rows in the same order, plus the
    #                             experiment ('origin') each sentence came from.
    # The analysis section pairs the two files row by row.
    n_surveys = 20
    n_samples = 200
    # Sampling survey indices without replacement needs n_surveys <= n_samples.
    if n_surveys > n_samples:
        raise ValueError('n_samples ({}) must be >= n_surveys ({})'.format(n_samples, n_surveys))

    data = {}
    for i, d in enumerate(ds):
        for k in options_count:
            if k not in d:
                continue
            name = '{}_{}'.format(k, '16' if '_16d_' in d else '512')
            samples_path = os.path.join(EXPERIMENTS, d, 'text', 'sampled_sentences.txt')
            sentences = []
            with open(samples_path, "r") as f:
                for _ in range(n_samples):
                    # readline() returns '' once the file is exhausted, so a
                    # short file contributes empty sentences instead of failing.
                    line = f.readline()
                    # Keep the text up to and including the first '?'.
                    idx = line.index('?') + 1 if '?' in line else None
                    sentences.append(line[:idx].replace('\n', ' '))
            # The 'i:' prefix keeps keys unique when several experiments share a name.
            data['{}:{}'.format(i, name)] = sentences

    # One sentence index per survey; every survey shows that index for all experiments.
    random_indices = np.random.choice(n_samples, n_surveys, replace=False)
    # Origin labels do not depend on the survey, so compute them once.
    sample_origin = np.array([key.split(':')[1] for key in data])
    for i, idx in enumerate(random_indices):
        samples = np.array([v[idx] for v in data.values()])

        # Placeholder ratings drawn from the 1-5 scale named in the column
        # header (randint's upper bound is exclusive). They let the analysis
        # section run on generated files.
        # NOTE(review): clear this column before handing surveys to respondents.
        df = pd.DataFrame({
            'sentence': samples,
            'interpretability (1-5)': np.random.randint(1, 6, size=len(samples)),
        })
        print(df.head())
        df.to_excel(os.path.join(SURVEYSPATH, 'survey_{}.xlsx'.format(i)))

        df = pd.DataFrame({
            'origin': sample_origin,
            'sentence': samples,
            'interpretability (1-5)': 0,
        })
        print(df.head())
        print(df.shape)
        df.to_excel(os.path.join(SURVEYSPATH, 'survey_spoiled_{}.xlsx'.format(i)))

if is_samples_from_survey:
    # Summarise interpretability ratings per origin: each survey_{i}.xlsx
    # (ratings in its last column) is paired with survey_spoiled_{i}.xlsx
    # ('origin' column), and rows are matched by position.
    # To analyse collected human replies instead of the generated files, use
    # human_surveys_path = os.path.join(CDIR, 'good_experiments', 'replies').
    human_surveys_path = SURVEYSPATH

    # Index workbooks by the token after the last '_' in the file name, so each
    # survey is paired with its own spoiled copy. os.listdir() returns entries
    # in arbitrary order, so zipping two listings positionally can mismatch
    # them. Excel lock files ('~$...') and non-.xlsx files are skipped.
    # NOTE(review): assumes reply files keep the survey_{i} / survey_spoiled_{i} naming.
    unspoiled_files = {}
    spoiled_files = {}
    for filename in os.listdir(human_surveys_path):
        stem, ext = os.path.splitext(filename)
        if ext.lower() != '.xlsx' or filename.startswith('~$'):
            continue
        survey_id = stem.rsplit('_', 1)[-1]
        if 'spoiled' in filename:
            spoiled_files[survey_id] = filename
        else:
            unspoiled_files[survey_id] = filename

    unpaired = sorted(set(spoiled_files) ^ set(unspoiled_files))
    if unpaired:
        print('WARNING: ignoring surveys without a matching pair: {}'.format(unpaired))

    names = []
    values = []
    for survey_id in sorted(set(spoiled_files) & set(unspoiled_files)):
        spoiled_name = spoiled_files[survey_id]
        unspoiled_name = unspoiled_files[survey_id]
        spoiled = pd.read_excel(os.path.join(human_surveys_path, spoiled_name), engine='openpyxl')
        unspoiled = pd.read_excel(os.path.join(human_surveys_path, unspoiled_name), engine='openpyxl')
        # Rows are matched by position, so both files must have the same length.
        if len(spoiled) != len(unspoiled):
            raise ValueError('{} has {} rows but {} has {}'.format(
                spoiled_name, len(spoiled), unspoiled_name, len(unspoiled)))
        names.extend(spoiled['origin'].tolist())
        values.extend(unspoiled.iloc[:, -1].tolist())

    names = np.array(names)
    # Blank (unanswered) cells are read as NaN; the nan-aware stats skip them.
    values = np.array(values, dtype=float)

    for n in np.unique(names):
        print(n)
        instances = names == n
        print('    {} +- {}'.format(np.nanmean(values[instances]).round(2), np.nanstd(values[instances]).round(2)))
