import difflib
import json
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from Tools.jsonl_utils import *

from flask import Flask, render_template_string, session, redirect, url_for, flash
from flask_bootstrap import Bootstrap
from flask_wtf import FlaskForm
from wtforms import TextAreaField
from json import load, dump
import argparse
import re
import shutil
from glob import glob
from collections import Counter

# Flask application object; the routes below register on it.
app = Flask(__name__)
# Key used to sign the session cookie (which stores ``current_index``) and the
# WTForms CSRF token. It can be overridden with the SECRET_KEY environment
# variable; the literal fallback keeps the previous behaviour for local runs.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'Thisisasecret!')
bootstrap = Bootstrap(app)


class DataForm(FlaskForm):
    """Annotation form declared without any static fields.

    ``index`` adds one ``TextAreaField`` per key of the currently loaded
    example at request time, so the field set mirrors that example's JSON.
    """


@app.before_request
def initialize_session():
    """Ensure the browser session has a ``current_index`` before any route runs.

    Registered with ``before_request`` rather than ``before_first_request``:
    the latter was removed in Flask 2.3, and it only fired once per process,
    so any later session lacked the key and ``index`` raised ``KeyError``.
    An existing value is kept; ``index`` resets out-of-range values to 1.
    """
    session.setdefault('current_index', 1)


# Jinja template for the annotation page. Rendered by ``index`` with ``form``,
# ``keys``, ``file_count`` and ``num_of_group`` in the context.
_INDEX_TEMPLATE = '''
    <!DOCTYPE html>
    <html>
        <head>
            <!-- CSS only -->
            <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css">
            <style>
            h1 {
                font-family: 'Georgia', serif;
                color: #333;
                text-align: center;
            }
            hr {
                border-top: 2px solid #ccc;
            }
            /* Added new styles here */
            .form-group.row.odd {
                background-color: #f8f9fa; 
                margin: 10px 0;
                padding: 20px;
                border-radius: 5px;
            }
            .form-group.row.even {
                background-color: #e9ecef;
                margin: 10px 0;
                padding: 20px;
                border-radius: 5px;
            }
        </style>
        </head>
        <body>
            <!-- Modal -->
            <div class="modal fade" id="myModal" tabindex="-1" role="dialog" aria-labelledby="myModalLabel" aria-hidden="true">
              <div class="modal-dialog" role="document">
                <div class="modal-content">
                  <div class="modal-header">
                    <h5 class="modal-title" id="myModalLabel">Notice</h5>
                    <button type="button" class="close" data-dismiss="modal" aria-label="Close">
                      <span aria-hidden="true">&times;</span>
                    </button>
                  </div>
                  <div class="modal-body">
                        {% with messages = get_flashed_messages() %}
                        {% if messages %}
                        {{ messages[0] }}
                        {% endif %}
                        {% endwith %}
                    </div>
                  <div class="modal-footer">
                    <button type="button" class="btn btn-secondary" data-dismiss="modal">Close</button>
                  </div>
                </div>
              </div>
            </div>
            <div class="container mt-5">
                <a href="#bottom">
                    <h1 class="mb-3"> Human Annotation for LEval </h1>
                    <hr class="mb-4">
                </a>
                <form method="POST">
                {{ form.csrf_token }}
                {% set counter = 0 %}
                {% for key in keys %}
                    {% set field = form[key] %}
                    {% if field.type != "HiddenField" and field.name != "csrf_token" and field.name != "input" %}
                        <div class="form-group row {{ 'odd' if loop.index0 // num_of_group % 2  == 0 else 'even' }}">
                            <label class="col-sm-2 col-form-label font-weight-bold text-uppercase">{{ field.label.text }}</label>
                            <div class="col-sm-10">
                                {{ field(class="form-control") }}
                            </div>
                        </div>
                    {% endif %}
                {% endfor %}
                <hr class="mb-4">
                <div class="d-flex justify-content-between align-items-end">
                    <div>
                        <button type="submit" formaction="{{ url_for('previous') }}" class="btn btn-primary btn-lg" {{ 'disabled' if session['current_index'] <= 1 }}>Previous</button>
                        <button type="submit" formaction="{{ url_for('next') }}" class="btn btn-primary btn-lg" {{ 'disabled' if session['current_index'] >= file_count }}>Next</button>
                    </div>
                    <button type="submit" class="btn btn-danger btn-lg">Submit</button>
                </div>
                <hr class="mb-4">
                {% set field = form['input'] %}
                <div class="form-group row" id="bottom">
                    <label class="col-sm-2 col-form-label font-weight-bold text-uppercase">{{ field.label.text }}</label>
                    <div class="col-sm-10">
                        {{ field(class="form-control-plaintext") }}
                    </div>
                </div>
            </form>
            </div>
            <hr class="mb-4">

            <!-- JS, Popper.js, and jQuery -->
            <script src="https://code.jquery.com/jquery-3.5.1.slim.min.js"></script>
            <script src="https://cdn.jsdelivr.net/npm/bootstrap@4.5.2/dist/js/bootstrap.bundle.min.js"></script>
            <script src="https://cdnjs.cloudflare.com/ajax/libs/autosize.js/4.0.2/autosize.min.js"></script>
            <script>
            $(document).ready(function(){
                autosize($('textarea'));
                {% with messages = get_flashed_messages() %}
                {% if messages %}
                $('#myModal').modal('show')
                {% endif %}
                {% endwith %}
            });
            </script>
        </body>
    </html>
    '''


def _sorted_field_keys(data):
    """Return ``data``'s keys in display order.

    Keys are ordered by the integer after their last ``" #"``. Keys whose last
    ``" #"``-separated part is not all digits (``input``, ``source``, ...) come
    after every indexed key, with ``input`` first among them. Ties keep the
    dict's insertion order because ``sorted`` is stable.
    """
    def sort_key(k):
        suffix = k.split(' #')[-1]
        return (int(suffix) if suffix.isdigit() else float('inf'), k != 'input')

    return sorted(data.keys(), key=sort_key)


def _evaled_path(example_path):
    """Map ``.../<i>_example.json`` to ``.../<i>_evaled.json``.

    Only the file name is rewritten, so a directory whose name contains
    ``_example`` is left untouched.
    """
    head, tail = os.path.split(example_path)
    return os.path.join(head, tail.replace("_example", "_evaled"))


def _finish_annotation():
    """Merge all evaluated files, write the result files and return the summary HTML.

    Writes one JSON line per ``*_evaled.json`` file (via ``convert_html_to_key``)
    to ``evaled_jsonl_path``, and the per-key score frequencies to
    ``<args.path>/results_from_<annotator_name>.jsonl``.
    """
    from html import escape

    evaled_files = sorted(glob(f'{jsons_dir}/*_evaled.json'))
    # convert_html_to_key accumulates into this global; start from zero so a
    # second summary (e.g. after a page reload) does not double-count scores.
    key_to_counter.clear()
    with open(evaled_jsonl_path, "w") as fw:
        for cfile in evaled_files:
            with open(cfile) as c_temp:
                evaled_sample = json.load(c_temp)
            print(json.dumps(convert_html_to_key(evaled_sample)), file=fw)

    key_to_avg = calc_avg()
    with open(os.path.join(args.path, f"results_from_{annotator_name}.jsonl"), "w") as f_key:
        f_key.write(json.dumps(key_to_counter))

    # Built as plain HTML (values escaped) instead of being passed to
    # render_template_string, so user/data text is never evaluated as Jinja.
    parts = [f"<h3> You have finished the human evaluation process! Scored files are saved to {escape(evaled_jsonl_path)} </h3>"]
    for key, counter in key_to_counter.items():
        sorted_counter = dict(sorted(counter.items(), key=lambda item: float(item[0])))
        parts.append(f"<h4> {escape(key)} [score: frequency]: {escape(str(sorted_counter))}  ====>  Average score: {key_to_avg[key]}</h4>")
    return "".join(parts)


@app.route('/', methods=['GET', 'POST'])
def index():
    """Show the current pending example as a form, and save it on submit.

    GET renders ``json_files[current_index - 1]`` with one textarea per key.
    On a valid POST (the Submit button), every edited value except ``input``
    is copied back with ``\\r`` removed. ``input`` is only compared (first 200
    characters, similarity < 0.8 raises ``ValueError``) to catch saving onto
    the wrong example. The sample is written to its ``*_evaled.json`` file and
    dropped from ``json_files``. Once the number of ``*_evaled.json`` files
    reaches ``total_num``, or nothing is pending any more, the summary page is
    returned instead.
    """
    if not json_files:
        # Everything has been submitted (e.g. the summary page was reloaded);
        # there is no example to index, so show the summary again.
        return _finish_annotation()

    current = session.get('current_index', 1)
    if current > len(json_files) or current < 1:
        current = 1
    session['current_index'] = current

    example_path = json_files[current - 1]
    with open(example_path) as f:
        data = load(f)
    keys = _sorted_field_keys(data)

    # A fresh subclass per request, so fields/defaults from previously shown
    # files never accumulate on the shared DataForm class.
    class _FileForm(DataForm):
        pass

    for key in keys:
        setattr(_FileForm, key, TextAreaField(key, default=data[key]))

    form = _FileForm()
    if form.validate_on_submit():
        submitted = form.data  # property builds a new dict on each access
        for key in data:
            if key == "input":
                # check the data from website
                ratio = difflib.SequenceMatcher(None, data[key][:200], submitted[key][:200]).ratio()
                if ratio < 0.8:
                    print(data[key][:200])
                    print(submitted[key][:200])
                    raise ValueError("Unexpected error! Please re-run the script (with mode = continue)")
            else:
                data[key] = submitted[key].replace("\r", '')

        evaled_json = _evaled_path(example_path)
        with open(evaled_json, 'w') as f:
            f.write(json.dumps(data))
        flash(f'Updated successfully! Save to --> {evaled_json}')
        json_files.pop(current - 1)
        session['current_index'] = min(len(json_files), current)
        if len(glob(f'{jsons_dir}/*_evaled.json')) >= total_num:
            # The expected number of samples has been scored.
            return _finish_annotation()
        return redirect(url_for("index"))

    return render_template_string(_INDEX_TEMPLATE, form=form, keys=keys, file_count=len(json_files),
                                  num_of_group=num_of_group)


@app.route('/next', methods=['POST'])
def next():
    """Move to the next pending example (never past the last one) and reload ``index``.

    A session without ``current_index`` is treated as being on example 1
    instead of raising ``KeyError``.
    """
    current = session.get('current_index', 1)
    session['current_index'] = min(len(json_files), current + 1)
    return redirect(url_for('index'))


@app.route('/previous', methods=['POST'])
def previous():
    """Move back one example (never below 1) and reload ``index``.

    A session without ``current_index`` is treated as being on example 1
    instead of raising ``KeyError``.
    """
    session['current_index'] = max(1, session.get('current_index', 1) - 1)
    return redirect(url_for('index'))


def calc_avg(counters=None):
    """Return the frequency-weighted average score for every key.

    Args:
        counters: mapping of key -> {score string: frequency}, e.g. a
            ``Counter``. Defaults to the module-level ``key_to_counter``.

    Returns:
        dict of key -> average score (float). A key whose frequencies sum to
        zero maps to ``float('nan')`` instead of raising ``ZeroDivisionError``.
    """
    if counters is None:
        counters = key_to_counter
    key_to_avg = {}
    for key, counter in counters.items():
        total = sum(counter.values())
        weighted = sum(float(score) * freq for score, freq in counter.items())
        key_to_avg[key] = weighted / total if total else float("nan")
    return key_to_avg

def convert_key_to_html(samples):
    """Expand list-valued fields of every sample into numbered fields, in place.

    A list under ``name`` is replaced by ``"name #0"``, ``"name #1"``, ... .
    Entries of any list other than ``instructions`` / ``outputs`` get
    ``split_line`` appended (unless it already occurs in them), giving the
    annotator a place to type a score. Returns None.
    """
    for sample in samples:
        # Snapshot the names first: the dict is modified below.
        list_valued = [name for name, value in sample.items() if isinstance(value, list)]
        for name in list_valued:
            for idx, item in enumerate(sample[name]):
                needs_marker = name not in ("instructions", "outputs") and split_line not in item
                sample[f"{name} #{idx}"] = f"{item}{split_line}" if needs_marker else item
            del sample[name]


def convert_html_to_key(evaled_sample, counters=None, separator=None):
    """Rebuild list fields from an annotated sample and tally its scores.

    Inverse of ``convert_key_to_html``: each ``"<name> #<i>"`` entry is stored
    at position ``i`` of ``new_sample[name]``. For every entry that contains
    ``separator``, the text after its first occurrence is parsed as a float
    score and counted (as ``str(score)``) in ``counters[name]``.

    Args:
        evaled_sample: dict loaded from a ``*_evaled.json`` file. It must
            contain ``evaluation``, ``input`` and ``source``.
        counters: dict of name -> Counter, updated in place. Defaults to the
            module-level ``key_to_counter``.
        separator: score marker. Defaults to the module-level ``split_line``.

    Returns:
        A new dict with one list per ``<name>`` (first-seen order), followed
        by ``evaluation``, ``input`` and ``source``. Every list has length
        max index + 1; positions with no entry hold ``""``.

    Raises:
        ValueError: if the text after ``separator`` is not a number, e.g. the
            annotator left the score empty.
        KeyError: if ``evaluation``, ``input`` or ``source`` is missing.
    """
    if counters is None:
        counters = key_to_counter
    if separator is None:
        separator = split_line

    # Collect "<name> #<i>" entries. rpartition keeps names that themselves
    # contain " #" intact; keys without a numeric suffix are not list items.
    indexed = []
    for data_key, data_value in evaled_sample.items():
        name, sep, suffix = data_key.rpartition(" #")
        if sep and suffix.isdecimal():
            indexed.append((name, int(suffix), data_value))

    # Size lists from the actual indices instead of guessing from the key count.
    max_num = max((idx for _, idx, _ in indexed), default=-1) + 1
    new_sample = {}
    for name, idx, value in indexed:
        if name not in new_sample:
            new_sample[name] = [""] * max_num
        new_sample[name][idx] = value
        if separator in value:
            raw_score = value.split(separator)[1]
            try:
                score = float(raw_score)
            except ValueError as err:
                raise ValueError(f"Invalid or missing score {raw_score!r} for '{name} #{idx}'") from err
            counters.setdefault(name, Counter())[str(score)] += 1

    new_sample["evaluation"] = evaled_sample["evaluation"]
    new_sample["input"] = evaled_sample["input"]
    new_sample["source"] = evaled_sample["source"]
    return new_sample


def ground_with_gold(gold_samples, pred_files):
    """Attach each model's predictions to the matching gold samples, in place.

    For every prediction file, the first key of its first record containing
    ``"_pred"`` names the model. Each gold sample then gets (or extends) a
    list under that key, aligned with ``zip(sample["instructions"],
    sample["outputs"])`` and matched on the ``(query, gt)`` pair.

    Args:
        gold_samples: records with parallel ``instructions``/``outputs`` lists.
        pred_files: paths read with ``read_jsonl``; each record must have
            ``query``, ``gt`` and the model's ``*_pred`` key.

    Returns:
        ``gold_samples`` itself (mutated).

    Raises:
        ValueError: if a prediction file is empty or has no ``*_pred`` key.
        KeyError: if a gold (instruction, output) pair is missing from a file.
    """
    for pred_file in pred_files:
        pred_samples = read_jsonl(pred_file)
        if not pred_samples:
            raise ValueError(f"{pred_file} contains no predictions")
        # find pred key (a plain loop: the module's ``next`` is a Flask route)
        pred_key = None
        for key in pred_samples[0]:
            if "_pred" in key:
                pred_key = key
                break
        print(pred_key)
        if pred_key is None:
            raise ValueError(f"No '*_pred' key found in {pred_file}")

        # Rebuilt for every file: a shared map let a query missing from a later
        # file silently reuse an earlier model's prediction. A tuple key avoids
        # the collisions that concatenating query + gt allowed.
        key2pred = {(sample["query"], sample["gt"]): sample[pred_key] for sample in pred_samples}
        for sample in gold_samples:
            for inst, out in zip(sample["instructions"], sample["outputs"]):
                sample.setdefault(pred_key, []).append(key2pred[(inst, out)])
    return gold_samples


def merge_evaled_files(output_path=None):
    """Merge every ``*_evaled.json`` in ``jsons_dir`` into one JSONL file, then exit.

    Each file is converted with ``convert_html_to_key`` (which also tallies
    scores into ``key_to_counter``) and written as one JSON line.

    Args:
        output_path: destination file. Defaults to the module-level
            ``evaled_jsonl_path``, the same file the web UI writes when
            annotation finishes. (``--path`` is a directory, so deriving the
            name from it with a ``.jsonl`` suffix substitution produced the
            directory itself and ``open(..., "w")`` failed.)

    Terminates the process with ``sys.exit(0)``.
    """
    if output_path is None:
        output_path = evaled_jsonl_path
    evaled_files = sorted(glob(f'{jsons_dir}/*_evaled.json'))
    with open(output_path, "w") as fw:
        for cfile in evaled_files:
            print("processing", cfile)
            with open(cfile) as c_temp:
                evaled_sample = json.load(c_temp)
            print(json.dumps(convert_html_to_key(evaled_sample)), file=fw)
    sys.exit(0)


if __name__ == "__main__":
    # Appended to each model output; the annotator types a 1~5 score after it.
    split_line = "\n ---------------------- Your Score (1~5) ----------------------\n"
    key_to_counter = {}  # name -> Counter of scores, filled by convert_html_to_key

    parser = argparse.ArgumentParser()
    parser.add_argument('--path', default="Predictions/human_eval")
    parser.add_argument('--gold_path', default="Predictions/human_eval/claude.gpt4.ref.jsonl",
                        help="jsonl file with the reference samples (instructions/outputs)")
    parser.add_argument('--mode', choices=["begin", "continue"], default="continue",
                        help="Warning: begin the annotate process, this will delete your previous result")
    parser.add_argument('--merge_evaled_files', action="store_true",
                        help="merge the *_evaled.json files into a new jsonl file")
    args = parser.parse_args()

    # Checked after parsing so that --help works even without the gold file.
    gold_path = args.gold_path
    if not os.path.exists(gold_path):
        sys.exit(f"Gold file not found: {gold_path}")

    # Sorted so the prediction columns come out in the same order on every run.
    pred_files = sorted(glob(f"{args.path}/*.pred.jsonl"))
    num_of_group = len(pred_files) + 4  # instructions, outputs , gpt4 outputs, claude outputs
    samples = read_jsonl(gold_path)
    samples = ground_with_gold(samples, pred_files)
    total_num = len(samples)
    convert_key_to_html(samples)

    jsons_dir = os.path.join(args.path, "jsons")
    os.makedirs(jsons_dir, exist_ok=True)
    if args.merge_evaled_files:
        args.mode = "continue"
    if args.mode == "begin":
        input("this will remove your previous annotation! press enter to confirm... ")
        shutil.rmtree(jsons_dir)
        os.makedirs(jsons_dir, exist_ok=True)
        for i, sample in enumerate(samples):
            with open(f'{jsons_dir}/{i}_example.json', "w") as f:
                f.write(json.dumps(sample))

    # NOTE: sorted() is lexicographic, so "10_example.json" precedes
    # "2_example.json"; annotation order is not the gold sample order.
    # Only the file name is rewritten when looking for the *_evaled twin, so a
    # directory whose name contains "_example" is handled correctly.
    json_files = [
        jfile for jfile in sorted(glob(f'{jsons_dir}/*_example.json'))
        if not os.path.exists(os.path.join(os.path.dirname(jfile),
                                           os.path.basename(jfile).replace("_example", "_evaled")))
    ]

    annotator_name = input("Please type your name (e.g., Annotator.1):  ")

    evaled_jsonl_path = os.path.join(args.path, f"files_from_{annotator_name}.jsonl")
    ## call merge evaled files (exits the process)
    if args.merge_evaled_files:
        merge_evaled_files()
    if not json_files:
        sys.exit("No pending *_example.json files to annotate. Run with --mode begin to start, "
                 "or --merge_evaled_files to export finished annotations.")

    app.run(debug=False, port=5000)
