import difflib
import json
import os.path

from flask import Flask, render_template_string, session, redirect, url_for, flash
from flask_bootstrap import Bootstrap
from flask_wtf import FlaskForm
from wtforms import TextAreaField
from json import load, dump
from glob import glob
import argparse
import re
import shutil
from jsonl_utils import read_jsonl

app = Flask(__name__)
# Key used to sign the session cookie (which stores the current sample index).
# Export FLASK_SECRET_KEY to override the built-in development fallback.
app.config['SECRET_KEY'] = os.environ.get('FLASK_SECRET_KEY', 'Thisisasecret!')
bootstrap = Bootstrap(app)


class DataForm(FlaskForm):
    """Base WTForms form for editing one annotation sample.

    Declares no fields of its own; fields are attached at request time, one
    ``TextAreaField`` per key of the sample being edited.
    """


@app.before_request
def initialize_session():
    """Ensure the session carries ``current_index`` (1-based position in ``json_files``).

    Runs before every request so that every browser session gets an index, not
    only the very first request the process serves. An existing index is kept;
    ``index()`` clamps it to the valid range.
    """
    session.setdefault('current_index', 1)


# Jinja template for the annotation page: numbered instruction/output fields are
# rendered in ``keys`` order, and the (read-only) ``input`` document is shown last,
# below the navigation buttons. Hoisted to module level so it is built only once.
_INDEX_TEMPLATE = '''
    <!DOCTYPE html>
    <html>
        <head>
            <!-- CSS only -->
            <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css">
            <style>
            h1 {
                font-family: 'Georgia', serif;
                color: #333;
                text-align: center;
            }
            hr {
                border-top: 2px solid #ccc;
            }
            /* Added new styles here */
            .form-group.row.odd {
                background-color: #f8f9fa; 
                margin: 10px 0;
                padding: 20px;
                border-radius: 5px;
            }
            .form-group.row.even {
                background-color: #e9ecef;
                margin: 10px 0;
                padding: 20px;
                border-radius: 5px;
            }
        </style>
        </head>
        <body>
            <!-- Modal -->
            <div class="modal fade" id="myModal" tabindex="-1" role="dialog" aria-labelledby="myModalLabel" aria-hidden="true">
              <div class="modal-dialog" role="document">
                <div class="modal-content">
                  <div class="modal-header">
                    <h5 class="modal-title" id="myModalLabel">Notice</h5>
                    <button type="button" class="close" data-dismiss="modal" aria-label="Close">
                      <span aria-hidden="true">&times;</span>
                    </button>
                  </div>
                  <div class="modal-body">
                        {% with messages = get_flashed_messages() %}
                        {% if messages %}
                        {{ messages[0] }}
                        {% endif %}
                        {% endwith %}
                    </div>
                  <div class="modal-footer">
                    <button type="button" class="btn btn-secondary" data-dismiss="modal">Close</button>
                  </div>
                </div>
              </div>
            </div>
            <div class="container mt-5">
                <a href="#bottom">
                    <h1 class="mb-3"> Human Annotation for LEval </h1>
                    <hr class="mb-4">
                </a>
                <form method="POST">
                {{ form.csrf_token }}
                {% for key in keys %}
                    {% set field = form[key] %}
                    {% if field.type != "HiddenField" and field.name != "csrf_token" and field.name != "input" %}
                        <div class="form-group row {{ 'odd' if loop.index0 // 2 % 2 == 0 else 'even' }}">
                            <label class="col-sm-2 col-form-label font-weight-bold text-uppercase">{{ field.label.text }}</label>
                            <div class="col-sm-10">
                                {{ field(class="form-control") }}
                            </div>
                        </div>
                    {% endif %}
                {% endfor %}
                <hr class="mb-4">
                <div class="d-flex justify-content-between align-items-end">
                    <div>
                        <button type="submit" formaction="{{ url_for('previous') }}" class="btn btn-primary btn-lg" {{ 'disabled' if session['current_index'] <= 1 }}>Previous</button>
                        <button type="submit" formaction="{{ url_for('next') }}" class="btn btn-primary btn-lg" {{ 'disabled' if session['current_index'] >= file_count }}>Next</button>
                    </div>
                    <button type="submit" class="btn btn-danger btn-lg">Submit</button>
                </div>
                <hr class="mb-4">
                {% set field = form['input'] %}
                <div class="form-group row" id="bottom">
                    <label class="col-sm-2 col-form-label font-weight-bold text-uppercase">{{ field.label.text }}</label>
                    <div class="col-sm-10">
                        {{ field(class="form-control-plaintext") }}
                    </div>
                </div>
            </form>
            </div>
            <hr class="mb-4">

            <!-- JS, Popper.js, and jQuery -->
            <script src="https://code.jquery.com/jquery-3.5.1.slim.min.js"></script>
            <script src="https://cdn.jsdelivr.net/npm/bootstrap@4.5.2/dist/js/bootstrap.bundle.min.js"></script>
            <script src="https://cdnjs.cloudflare.com/ajax/libs/autosize.js/4.0.2/autosize.min.js"></script>
            <script>
            $(document).ready(function(){
                autosize($('textarea'));
                {% with messages = get_flashed_messages() %}
                {% if messages %}
                $('#myModal').modal('show')
                {% endif %}
                {% endwith %}
            });
            </script>
        </body>
    </html>
    '''


def _field_sort_key(key):
    """Sort key for form fields: by the ``" #N"`` pair index, then ``input`` first.

    Keys without a numeric suffix (``input``, ``source``, ...) sort after every
    numbered pair; among those, ``input`` comes first. ``sorted`` is stable, so
    ``instruction #N`` / ``output #N`` keep their order in the sample dict.
    """
    suffix = key.split(' #')[-1]
    return (int(suffix) if suffix.isdecimal() else float('inf'), key != 'input')


def _write_cleaned_jsonl(cleaned_files, out_path):
    """Convert each ``*_cleaned.json`` file back to the jsonl schema; write one line per file.

    Files are written in numeric sample order (``2_cleaned.json`` before
    ``10_cleaned.json``), matching the ``{i}_example.json`` numbering.
    """
    def sample_order(path):
        stem = os.path.basename(path).split("_", 1)[0]
        return (0, int(stem), path) if stem.isdecimal() else (1, 0, path)

    with open(out_path, "w") as fw:
        for cfile in sorted(cleaned_files, key=sample_order):
            with open(cfile) as fr:
                cleaned_sample = json.load(fr)
            print(json.dumps(convert_html_to_key(cleaned_sample)), file=fw)


@app.route('/', methods=['GET', 'POST'])
def index():
    """Show the current pending sample, or save it when the form is submitted.

    GET renders ``json_files[current_index - 1]`` as an editable form. A valid
    POST (the Submit button) copies the edited fields into the sample, writes
    it as the matching ``*_cleaned.json`` file and removes it from the pending
    queue. Once at least ``num`` cleaned files exist, all of them are merged
    into ``cleaned_jsonl_path`` and a completion page is returned.

    Reads the module globals ``json_files``, ``jsons_dir``, ``num`` and
    ``cleaned_jsonl_path`` set in the ``__main__`` block.

    Raises:
        ValueError: if the submitted ``input`` text differs too much from the
            stored one (first 200 characters, similarity ratio < 0.8).
    """
    if not json_files:
        # Every *_example.json already has a *_cleaned.json: nothing left to edit.
        return render_template_string(
            "<h3> There are no remaining samples to annotate in {{ jsons_dir }} </h3>",
            jsons_dir=jsons_dir)

    current_index = session.get('current_index', 1)
    if not 1 <= current_index <= len(json_files):
        current_index = 1
    # Always store it: the template reads session['current_index'] directly.
    session['current_index'] = current_index
    current_file = json_files[current_index - 1]
    with open(current_file) as f:
        data = load(f)
    keys = sorted(data.keys(), key=_field_sort_key)

    # Build a fresh form class per request; assigning fields onto the shared
    # DataForm class would leak fields of previously viewed samples into every
    # later form.
    class _SampleForm(DataForm):
        pass

    for key in keys:
        setattr(_SampleForm, key, TextAreaField(key, default=data[key]))
    form = _SampleForm()

    if form.validate_on_submit():
        for key in data.keys():
            if key != "input":
                # Drop carriage returns (form submissions typically use CRLF line breaks).
                data[key] = form.data[key].replace("\r", '')
            else:
                # The input document is not meant to be edited; a large mismatch
                # presumably means the submitted form belongs to another sample
                # (e.g. a stale tab) — refuse to save rather than mix them up.
                ratio = difflib.SequenceMatcher(None, data[key][:200], form.data[key][:200]).ratio()
                if ratio < 0.8:
                    print(data[key][:200])
                    print(form.data[key][:200])
                    raise ValueError("Unexpected error! Please re-run the script (with mode = continue)")
        # Swap only the filename suffix so directory names containing "_example"
        # are left untouched.
        cleaned_json = re.sub(r'_example\.json$', '_cleaned.json', current_file)
        with open(cleaned_json, 'w') as f:
            f.write(json.dumps(data))
        flash(f'Updated successfully! Save to --> {cleaned_json}')
        json_files.pop(current_index - 1)
        session['current_index'] = max(1, min(len(json_files), current_index))
        cleaned_files = sorted(glob(f'{jsons_dir}/*_cleaned.json'))
        if len(cleaned_files) >= num:
            # The expected number of samples has been annotated: merge them all
            # into the cleaned jsonl file.
            _write_cleaned_jsonl(cleaned_files, cleaned_jsonl_path)
            return render_template_string(
                "<h3> You have finished the work and all {{ count }} samples have been saved to {{ path }} </h3>",
                count=len(cleaned_files), path=cleaned_jsonl_path)

        return redirect(url_for("index"))

    return render_template_string(_INDEX_TEMPLATE, form=form, keys=keys, file_count=len(json_files))


@app.route('/next', methods=['POST'])
def next():
    """Move to the next pending sample (clamped to the last one) and redisplay.

    Only the session index changes; edits in the current form are not saved.
    """
    current = session.get('current_index', 1)
    # Never let the index drop below 1, even when the pending queue is empty.
    session['current_index'] = max(1, min(len(json_files), current + 1))
    return redirect(url_for('index'))


@app.route('/previous', methods=['POST'])
def previous():
    """Move to the previous pending sample (never below 1) and redisplay.

    Only the session index changes; edits in the current form are not saved.
    """
    session['current_index'] = max(1, session.get('current_index', 1) - 1)
    return redirect(url_for('index'))


def convert_key_to_html(samples, new_pairs_num=None):
    """Flatten each sample's parallel ``instructions``/``outputs`` lists into per-pair keys, in place.

    For every sample, ``instructions[i]`` and ``outputs[i]`` become the keys
    ``"instruction #i"`` and ``"output #i"`` (one form field each in the UI), and
    the two list keys are removed. If the lists differ in length, the shorter one
    is padded with ``"pad"`` so that no entry is dropped and indices stay
    aligned. Then ``new_pairs_num`` empty pairs are appended after the existing
    ones. A missing or ``null`` ``instructions``/``outputs`` value is treated as
    an empty list.

    Args:
        samples: list of sample dicts; each is mutated in place.
        new_pairs_num: number of blank pairs to append per sample. Defaults to
            the ``--new_pairs_num`` CLI argument (module global ``args``).
    """
    if new_pairs_num is None:
        new_pairs_num = args.new_pairs_num
    for sample in samples:
        instructions = sample.pop("instructions", None)
        outputs = sample.pop("outputs", None)
        if instructions is None or outputs is None:
            print(f"the instructions and outputs are blank, we will add {new_pairs_num} new pairs")
        instructions = list(instructions or [])
        outputs = list(outputs or [])

        # Pad the shorter list so zip() below keeps every entry of the longer one.
        n_pairs = max(len(instructions), len(outputs))
        instructions += ["pad"] * (n_pairs - len(instructions))
        outputs += ["pad"] * (n_pairs - len(outputs))
        for i, (ins, out) in enumerate(zip(instructions, outputs)):
            sample[f"instruction #{i}"] = ins
            sample[f"output #{i}"] = out

        # Blank pairs continue the same index sequence for both fields.
        for i in range(n_pairs, n_pairs + new_pairs_num):
            sample[f"instruction #{i}"] = ""
            sample[f"output #{i}"] = ""


def convert_html_to_key(cleaned_sample):
    """Regroup ``"instruction #i"`` / ``"output #i"`` keys into parallel lists.

    Pair keys are collected by their integer index ``i`` and emitted in
    ascending index order as the ``instructions`` and ``outputs`` lists. A pair
    is dropped only when *both* of its fields are empty, so the two lists always
    have equal length and stay aligned; filtering each list separately would
    shift outputs against their instructions whenever only one side of a pair
    is blank. Keys of any other shape are ignored; ``input`` and ``source`` are
    copied through unchanged.

    Returns:
        dict with keys ``instructions``, ``outputs``, ``input``, ``source``.

    Raises:
        KeyError: if ``cleaned_sample`` lacks ``input`` or ``source``.
    """
    pairs = {}  # pair index -> [instruction, output]
    for data_key, data_value in cleaned_sample.items():
        prefix, sep, suffix = data_key.partition(" #")
        if not sep or not suffix.isdecimal() or prefix not in ("instruction", "output"):
            continue
        slot = pairs.setdefault(int(suffix), ["", ""])
        slot[0 if prefix == "instruction" else 1] = data_value

    instructions, outputs = [], []
    for idx in sorted(pairs):
        inst, out = pairs[idx]
        if not inst and not out:
            continue  # fully blank pair, e.g. an unused "new pair" slot
        instructions.append(inst)
        outputs.append(out)

    return {
        "instructions": instructions,
        "outputs": outputs,
        "input": cleaned_sample["input"],
        "source": cleaned_sample["source"],
    }


def merge_cleaned_files():
    """Merge every ``*_cleaned.json`` in ``jsons_dir`` into one jsonl file, then exit.

    Reads the module globals ``jsons_dir`` and ``args`` set in the ``__main__``
    block. The output path is ``args.path`` with ``.jsonl`` replaced by
    ``.cleaned.jsonl``; when the input does not end in ``.jsonl``, the suffix is
    appended instead so the input file is never overwritten. Files are merged
    in numeric sample order (``2_cleaned.json`` before ``10_cleaned.json``).

    Raises:
        SystemExit: always, with status 0, after writing the file.
    """
    def sample_order(path):
        stem = os.path.basename(path).split("_", 1)[0]
        return (0, int(stem), path) if stem.isdecimal() else (1, 0, path)

    cleaned_files = sorted(glob(f'{jsons_dir}/*_cleaned.json'), key=sample_order)
    cleaned_jsonl_path = re.sub(r'(\.jsonl)$', '.cleaned.jsonl', args.path)
    if cleaned_jsonl_path == args.path:
        cleaned_jsonl_path = args.path + '.cleaned.jsonl'
    with open(cleaned_jsonl_path, "w") as fw:
        for cfile in cleaned_files:
            print("processing", cfile)
            with open(cfile) as fr:
                cleaned_sample = json.load(fr)
            print(json.dumps(convert_html_to_key(cleaned_sample)), file=fw)
    raise SystemExit(0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', default="LEval-data/summary/qmsum.jsonl")
    parser.add_argument('--anno_num', type=int, default=-1)
    parser.add_argument('--mode', choices=["begin", "continue"], default="continue",
                        help="Warning: begin the annotate process, this will delete your previous result")
    parser.add_argument('--new_pairs_num', type=int, default=0, help="the number of new pairs you want to add")
    parser.add_argument('--merge_cleaned_files', action="store_true",
                        help="merge the *_cleaned.json to a new jsonl file")
    args = parser.parse_args()
    samples = read_jsonl(args.path)
    # Number of cleaned samples that completes the job. A negative value (the
    # default) or one larger than the dataset means "all samples"; without the
    # upper clamp the completion check in index() could never be satisfied.
    num = args.anno_num
    if num < 0 or num > len(samples):
        num = len(samples)
    convert_key_to_html(samples)

    # Per-sample working files live in a "jsons" directory next to the input jsonl.
    jsons_dir = os.path.join(os.path.dirname(args.path), "jsons")
    os.makedirs(jsons_dir, exist_ok=True)
    if args.merge_cleaned_files:
        # Never run the "begin" wipe below when only merging.
        args.mode = "continue"
    if args.mode == "begin":
        input("this will remove your previous annotation! press enter to confirm... ")
        shutil.rmtree(jsons_dir)
        os.makedirs(jsons_dir, exist_ok=True)
        for i, sample in enumerate(samples):
            with open(f'{jsons_dir}/{i}_example.json', "w") as f:
                f.write(json.dumps(sample))

    # Pending samples: *_example.json files without a matching *_cleaned.json.
    # Only the filename suffix is swapped, so directory names containing
    # "_example" are left untouched.
    json_files = [
        jfile for jfile in sorted(glob(f'{jsons_dir}/*_example.json'))
        if not os.path.exists(re.sub(r'_example\.json$', '_cleaned.json', jfile))
    ]
    cleaned_files = sorted(glob(f'{jsons_dir}/*_cleaned.json'))
    cleaned_jsonl_path = re.sub(r'(\.jsonl)$', '.cleaned.jsonl', args.path)
    if cleaned_jsonl_path == args.path:
        # Input does not end in ".jsonl": append rather than overwrite the input file.
        cleaned_jsonl_path = args.path + '.cleaned.jsonl'
    current_num = len(cleaned_files)

    if args.merge_cleaned_files:
        merge_cleaned_files()  # writes the merged jsonl and exits
    if current_num >= num:
        # The expected number of cleaned samples already exists.
        print(f"You already have {current_num} cleaned files: {cleaned_files} !! ")
        print("If you want to restart the annotation process, please re-run with --mode begin")
        raise SystemExit(0)
    if not json_files:
        raise SystemExit(f"No pending *_example.json files in {jsons_dir}; re-run with --mode begin")

    # With debug=True Flask enables the Werkzeug reloader by default, which
    # re-executes this whole script in a child process (repeating the
    # "--mode begin" prompt and directory wipe). Keep the debugger, drop the reloader.
    app.run(debug=True, port=5000, use_reloader=False)