"""
Database models for MPF Hyperparameter Tuning Dashboard
(Used by aggregate_cluster_results.py --import_to only.)
"""

import uuid
from datetime import datetime

from flask_sqlalchemy import SQLAlchemy

# Shared Flask-SQLAlchemy extension instance. The models below subclass
# db.Model, and init_db() binds it to a Flask app via db.init_app(app).
db = SQLAlchemy()


def generate_id():
    """Return a new random UUID4 rendered as its canonical 36-character string.

    Used as the ``default`` for the ``String(36)`` primary-key columns.
    """
    token = uuid.uuid4()
    return f"{token}"


class Experiment(db.Model):
    """
    Main experiment table storing hyperparameter tuning runs.

    Several columns hold JSON-encoded text. ``tags``, ``sub_job_ids`` and
    ``fitted_models_paths`` are decoded when serialized, while ``config`` and
    ``results`` are passed through as raw strings.
    """

    __tablename__ = "experiments"

    id = db.Column(db.String(36), primary_key=True, default=generate_id)
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text)
    status = db.Column(
        db.String(20), nullable=False, default="pending"
    )  # pending, running, completed, failed
    priority = db.Column(db.String(20), default="normal")  # low, normal, high
    tags = db.Column(db.Text)  # JSON string with tags array
    created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)
    started_at = db.Column(db.DateTime)
    completed_at = db.Column(db.DateTime)
    config = db.Column(db.Text)  # JSON string with all configuration
    results = db.Column(db.Text)  # JSON string with results
    error_log = db.Column(db.Text)
    optuna_storage_name = db.Column(db.String(255))
    job_id = db.Column(db.String(100))  # RQ job ID
    sub_job_ids = db.Column(db.Text)  # JSON string with list of sub-job IDs

    # Progress tracking fields
    progress = db.Column(db.Float, default=0.0)  # 0.0 to 1.0
    trials_completed = db.Column(db.Integer, default=0)
    total_trials = db.Column(
        db.Integer, default=0
    )  # Total optimization trials across all subjobs
    n_trials = db.Column(db.Integer, default=100)  # Keep for backward compatibility
    best_rmse = db.Column(db.Float)
    best_model = db.Column(db.String(100))

    # Task collection tracking
    is_task_collection_parent = db.Column(
        db.Boolean, default=False
    )  # Mark parent experiments from task collections

    # Dataset information
    n = db.Column(db.Integer)  # Number of samples/instances
    p = db.Column(db.Integer)  # Number of features/dimensions

    # Fitted model paths (JSON dict: model_name -> path)
    fitted_models_paths = db.Column(db.Text)  # JSON string with dict of model_name -> path

    @staticmethod
    def _iso(value):
        """Return ``value.isoformat()``, or ``None`` when the timestamp is unset."""
        return value.isoformat() if value else None

    @staticmethod
    def _json_or(text, default):
        """Decode a JSON text column, returning ``default`` when it is empty/NULL."""
        import json

        return json.loads(text) if text else default

    def _serialize(self, summary):
        """Build the serialized dict shared by ``to_dict`` and ``to_dict_summary``.

        Args:
            summary: When true, omit the large ``config``/``results`` fields,
                expose ``error_log`` only for failed runs, and append a
                ``has_results`` flag.

        Returns:
            A dict whose keys appear in the same order as the original
            hand-written serializers.
        """
        data = {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "status": self.status,
            "priority": self.priority,
            "tags": self._json_or(self.tags, []),
            "created_at": self._iso(self.created_at),
            "started_at": self._iso(self.started_at),
            "completed_at": self._iso(self.completed_at),
        }
        if summary:
            # Exclude large fields (config, results); error_log is only needed
            # for display when the experiment failed.
            data["error_log"] = self.error_log if self.status == "failed" else None
        else:
            data["config"] = self.config
            data["results"] = self.results
            data["error_log"] = self.error_log

        data.update(
            {
                "optuna_storage_name": self.optuna_storage_name,
                "job_id": self.job_id,
                "sub_job_ids": self._json_or(self.sub_job_ids, []),
                "progress": self.progress,
                "trials_completed": self.trials_completed,
                "total_trials": self.total_trials,
                "n_trials": self.n_trials,
                "best_rmse": self.best_rmse,
                "best_model": self.best_model,
                "is_task_collection_parent": self.is_task_collection_parent,
                "n": self.n,
                "p": self.p,
                "fitted_models_paths": self._json_or(self.fitted_models_paths, {}),
            }
        )

        if summary:
            # Lets list views conditionally show a results link without
            # shipping the results payload itself.
            data["has_results"] = bool(self.results)
        return data

    def to_dict(self):
        """Convert to a full dictionary for JSON serialization.

        Includes the raw ``config``/``results`` strings and ``error_log``.
        ``tags``/``sub_job_ids`` are decoded to lists (``[]`` when empty).
        ``fitted_models_paths`` is decoded to a dict (``{}`` when empty).
        Timestamps are ISO-8601 strings or ``None``.
        """
        return self._serialize(summary=False)

    def to_dict_summary(self):
        """Lightweight dictionary for list views - excludes large fields.

        Same as ``to_dict`` except that ``config`` and ``results`` are
        omitted, ``error_log`` is ``None`` unless ``status == "failed"``, and
        a boolean ``has_results`` key is appended.
        """
        return self._serialize(summary=True)


class SubJob(db.Model):
    """
    Individual sub-jobs (model/fold combinations) for tracking detailed results.

    Each row references its parent experiment through ``experiment_id``.
    """

    __tablename__ = "sub_jobs"

    id = db.Column(db.String(36), primary_key=True, default=generate_id)
    experiment_id = db.Column(
        db.String(36), db.ForeignKey("experiments.id"), nullable=False, index=True
    )
    rq_job_id = db.Column(db.String(100))  # RQ job ID for reference
    model_name = db.Column(db.String(100))
    fold_index = db.Column(db.Integer)  # For nested CV, None for simple CV
    # Lifecycle: pending, running, completed, failed
    status = db.Column(db.String(20), default="pending")
    started_at = db.Column(db.DateTime)
    completed_at = db.Column(db.DateTime)

    # Results
    test_rmse = db.Column(db.Float)
    test_mse = db.Column(db.Float)
    best_cv_score = db.Column(db.Float)
    best_params = db.Column(db.Text)  # JSON string with best parameters
    fixed_params = db.Column(db.Text)  # JSON string with fixed parameters (not tuned)
    fit_time = db.Column(db.Float)  # Time taken to fit
    error_log = db.Column(db.Text)  # Error message if failed

    created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)

    def to_dict(self):
        """Serialize this sub-job into a JSON-friendly dict.

        Every column is emitted in declaration order. Timestamp columns become
        ISO-8601 strings (``None`` when unset). ``best_params`` and
        ``fixed_params`` are returned as the stored JSON text, undecoded.
        """
        field_order = (
            "id",
            "experiment_id",
            "rq_job_id",
            "model_name",
            "fold_index",
            "status",
            "started_at",
            "completed_at",
            "test_rmse",
            "test_mse",
            "best_cv_score",
            "best_params",
            "fixed_params",  # kept in the payload for backward compatibility
            "fit_time",
            "error_log",
            "created_at",
        )
        timestamp_fields = {"started_at", "completed_at", "created_at"}

        serialized = {}
        for field in field_order:
            value = getattr(self, field)
            if field in timestamp_fields:
                value = value.isoformat() if value else None
            serialized[field] = value
        return serialized


class ParameterConfig(db.Model):
    """
    Saved parameter configurations for reuse.
    """

    __tablename__ = "parameter_configs"

    id = db.Column(db.String(36), primary_key=True, default=generate_id)
    name = db.Column(db.String(255), nullable=False)
    model_type = db.Column(db.String(100))  # 'MPFRegressor', 'XGBRegressor', etc.
    parameters = db.Column(db.Text)  # JSON string with parameter definitions
    # Stored as an integer flag (1 = reusable template, 0 = not) rather than Boolean.
    is_template = db.Column(db.Integer, default=0)
    created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)

    def to_dict(self):
        """Serialize this saved configuration into a plain dict.

        ``parameters`` is returned as the stored JSON text (not decoded).
        ``created_at`` is an ISO-8601 string, or ``None`` when unset.
        """
        created_iso = None
        if self.created_at:
            created_iso = self.created_at.isoformat()

        return {
            "id": self.id,
            "name": self.name,
            "model_type": self.model_type,
            "parameters": self.parameters,
            "is_template": self.is_template,
            "created_at": created_iso,
        }


class DatasetConfig(db.Model):
    """
    Saved dataset configurations.
    """

    __tablename__ = "dataset_configs"

    id = db.Column(db.String(36), primary_key=True, default=generate_id)
    name = db.Column(db.String(255), nullable=False)
    # Dataset source kind: 'openml', 'openml_task', 'openml_batch', 'friedman'
    type = db.Column(db.String(50))
    config = db.Column(db.Text)  # JSON string with dataset config
    created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)

    def to_dict(self):
        """Serialize this dataset configuration into a plain dict.

        ``config`` is returned as the stored JSON text (not decoded).
        ``created_at`` is an ISO-8601 string, or ``None`` when unset.
        """
        snapshot = {
            "id": self.id,
            "name": self.name,
            "type": self.type,
            "config": self.config,
        }
        snapshot["created_at"] = (
            self.created_at.isoformat() if self.created_at else None
        )
        return snapshot


def init_db(app):
    """Bind the shared ``db`` extension to *app* and create the model tables.

    ``db.create_all()`` is run inside an application context pushed for
    *app*. A confirmation line is then printed to stdout.
    """
    db.init_app(app)
    context = app.app_context()
    with context:
        db.create_all()
        print("Database initialized successfully")
