#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
app.py

Flask-based gradient attribution visualization application that provides a web interface 
to display the importance score of each token and highlight tokens with different 
importance using different colors.

Dependencies:
    - torch>=1.10
    - transformers>=4.15
    - flask>=2.0
    - pandas>=1.3
"""

import os
import json
import pandas as pd
from flask import Flask, request, render_template, jsonify
import torch
from gradient_attribution import compute_gradient_attribution, parse_args

app = Flask(__name__)

# Model used when a request does not name one, and the torch device that is
# handed to the attribution code (CUDA when available, otherwise CPU).
DEFAULT_MODEL = "bert-base-uncased"
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Cache of loaded models: model name -> (model, tokenizer), filled by get_model_tokenizer.
loaded_models = {}


def get_model_tokenizer(model_name):
    """Return a ``(model, tokenizer)`` pair for *model_name*, loading it on first use.

    Successfully loaded pairs are kept in the module-level ``loaded_models``
    cache and returned directly on later calls. If loading fails, the error is
    printed and ``(None, None)`` is returned; failures are not cached, so a
    subsequent call will try to load the model again.
    """
    # Only tuples are ever stored in the cache, so None unambiguously means "miss".
    cached = loaded_models.get(model_name)
    if cached is not None:
        return cached

    # Imported lazily so the heavy transformers import only happens when a model is needed.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    print(f"[Info] Loading model and tokenizer: {model_name}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
    except Exception as exc:
        print(f"[Error] Unable to load model or tokenizer: {exc}")
        return None, None

    pair = (model, tokenizer)
    loaded_models[model_name] = pair
    return pair


@app.route('/')
def index():
    """Serve the landing page by rendering the ``index.html`` template."""
    page = 'index.html'
    return render_template(page)


def _display_token(token):
    """Return *token* as it should be shown in the frontend.

    Only a leading ``##`` subword-continuation marker is removed; any other
    ``#`` characters (e.g. a literal hashtag token) are preserved.
    """
    return token[2:] if token.startswith('##') else token


def _normalize_importances(importances):
    """Min-max scale *importances* into the range [0, 1].

    Returns an empty list for empty input. When every score is identical the
    range is zero, so each entry gets the neutral value 0.5 instead of
    dividing by zero.
    """
    if not importances:
        return []
    low, high = min(importances), max(importances)
    if high <= low:
        return [0.5] * len(importances)
    span = high - low
    return [(score - low) / span for score in importances]


@app.route('/analyze', methods=['POST'])
def analyze_text():
    """Analyze text and return per-token importance data as JSON.

    Expects a JSON object body with:
        text  (str, required): the text to analyze.
        model (str, optional): model name; DEFAULT_MODEL is used when it is
            missing, null, or empty.
        label (int or integer string, optional): target label forwarded to
            compute_gradient_attribution; absent, null, or "" means None.

    Returns:
        200 with {"tokens", "original_text", "model", "target_label"}, where
        each "tokens" entry holds index, token, display_token, importance and
        normalized_importance (min-max scaled to [0, 1]).
        400 when the request body or a field is malformed.
        500 when the model cannot be loaded or attribution raises.
    """
    # silent=True turns a missing/unparseable JSON body into None so it can be
    # reported as a 400; a JSON body that is not an object is rejected as well.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    text = data.get('text', '')
    # `or` also covers an explicit null / "" model, which could never load.
    model_name = data.get('model') or DEFAULT_MODEL
    label = data.get('label')

    if not text:
        return jsonify({"error": "Please provide text content"}), 400
    if not isinstance(text, str):
        return jsonify({"error": "Text must be a string"}), 400
    if not isinstance(model_name, str):
        return jsonify({"error": "Model must be a string"}), 400

    # An absent/empty label means "no explicit target"; anything else must be an integer.
    if label is None or label == '':
        label = None
    else:
        # int() would silently truncate 1.5 -> 1; reject non-integral floats
        # (this also rejects inf/nan).
        if isinstance(label, float) and not label.is_integer():
            return jsonify({"error": "Label must be an integer"}), 400
        try:
            label = int(label)
        except (TypeError, ValueError, OverflowError):
            return jsonify({"error": "Label must be an integer"}), 400

    # Get model and tokenizer
    model, tokenizer = get_model_tokenizer(model_name)
    if model is None or tokenizer is None:
        return jsonify({"error": f"Unable to load model {model_name}"}), 500

    try:
        tokens, importances = compute_gradient_attribution(
            model=model,
            tokenizer=tokenizer,
            text=text,
            target_label=label,
            device=DEVICE
        )

        # Coerce to plain Python floats. The element type returned by
        # compute_gradient_attribution is not visible here; NumPy/tensor
        # scalars may not be JSON-serializable by jsonify, and plain floats
        # also make the emptiness check in normalization unambiguous.
        # NOTE(review): assumes a 1-D sequence of numeric scores.
        scores = [float(score) for score in importances]
        normalized = _normalize_importances(scores)

        result = [
            {
                "index": i,
                "token": token,
                "display_token": _display_token(token),
                "importance": score,
                "normalized_importance": norm_score,
            }
            for i, (token, score, norm_score) in enumerate(zip(tokens, scores, normalized))
        ]

        return jsonify({
            "tokens": result,
            "original_text": text,
            "model": model_name,
            "target_label": label
        })

    except Exception as e:
        # Top-level request boundary: log the full traceback server-side and
        # report the message to the client.
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500


@app.route('/available_models')
def get_available_models():
    """Return, as JSON, the models offered to the frontend.

    The response is ``{"models": [{"id": ..., "name": ...}, ...]}``: a fixed
    list of pre-trained model names, followed by a local ``pythia-2.8b`` entry
    when a path of that name exists relative to the current working directory.
    """
    # Fixed catalogue; edit here to offer different models.
    catalogue = [
        {"id": "bert-base-uncased", "name": "BERT Base (English)"},
        {"id": "bert-base-chinese", "name": "BERT Base (Chinese)"},
        {"id": "distilbert-base-uncased", "name": "DistilBERT Base (English)"},
        {"id": "roberta-base", "name": "RoBERTa Base"},
        {"id": "xlm-roberta-base", "name": "XLM-RoBERTa (Multilingual)"},
    ]

    # NOTE(review): get_model_tokenizer loads every model through
    # AutoModelForSequenceClassification; confirm this local checkpoint is
    # meant to be used that way.
    local_path = "pythia-2.8b"
    extras = (
        [{"id": local_path, "name": "Pythia 2.8B (Local)"}]
        if os.path.exists(local_path)
        else []
    )

    return jsonify({"models": catalogue + extras})


if __name__ == '__main__':
    # Ensure the template and static asset directories exist where Flask
    # actually looks for them (relative to the application's root path, not
    # the current working directory). os.makedirs creates missing parents, so
    # creating static/css and static/js also creates the static folder itself.
    os.makedirs(os.path.join(app.root_path, app.template_folder), exist_ok=True)
    for asset_subdir in ('css', 'js'):
        os.makedirs(os.path.join(app.static_folder, asset_subdir), exist_ok=True)

    # The Werkzeug interactive debugger permits arbitrary code execution from
    # the browser, so it must not be on by default while listening on all
    # interfaces. Opt in explicitly for local development with FLASK_DEBUG=1.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')

    # Start Flask application
    app.run(debug=debug, host='0.0.0.0', port=5000)
