import os
import numpy as np
import pandas as pd
import wandb
from collections import defaultdict

import src.dataset as dataset_utils
from ._base_task import BaseTask


class task_RNAGymDataset(BaseTask):
    def __init__(self, cfg, model):
        """Set the RNAGym RNA-type categories, then run the BaseTask initialiser.

        Args:
            cfg: Task configuration; forwarded unchanged to ``BaseTask.__init__``.
            model: Model object; forwarded unchanged to ``BaseTask.__init__``.
        """
        # RNA types for the RNAGym dataset. The summary tables built by this
        # class are grouped by these labels, in this order.
        self.rna_types = ["mRNA-splicing", "mRNA-coding", "tRNA", "Aptamer", "Ribozyme"]

        # NOTE(review): rna_types is assigned before the base initialiser runs,
        # presumably so it is already set if BaseTask.__init__ needs it -- confirm
        # before reordering.
        super().__init__(cfg, model)

    def calculate_rnagym_metrics(
        self, assay_scores: np.ndarray, model_scores: np.ndarray
    ) -> dict:
        """Calculate RNAGym-specific metrics: Spearman correlation, AUC, and MCC."""
        try:
            # Remove NaN values
            mask = ~(np.isnan(assay_scores) | np.isnan(model_scores))
            if not mask.any():
                return {
                    "spearman_abs": np.nan,
                    "auroc_abs": np.nan,
                    "auprc_abs": np.nan,
                    "mcc": np.nan,
                }

            assay_clean = assay_scores[mask]
            model_clean = model_scores[mask]

            if len(assay_clean) < 2:
                return {
                    "spearman_abs": np.nan,
                    "auroc_abs": np.nan,
                    "auprc_abs": np.nan,
                    "mcc": np.nan,
                }

            # Calculate metrics using existing dataset_utils functions
            spearman_corr = dataset_utils.calc_metrics(
                assay_clean, model_clean, "spearman_abs"
            )

            # For AUC and MCC, we need binary labels based on median split
            assay_binary = (assay_clean > np.median(assay_clean)).astype(int)
            model_binary = (model_clean > np.median(model_clean)).astype(int)

            auc = dataset_utils.calc_metrics(assay_binary, model_clean, "auroc_abs")
            auprc = dataset_utils.calc_metrics(assay_binary, model_clean, "ap_abs")
            mcc = dataset_utils.calc_metrics(assay_binary, model_binary, "mcc")

            return {
                "spearman_abs": spearman_corr,
                "auroc_abs": auc,
                "auprc_abs": auprc,
                "mcc": abs(mcc) if not np.isnan(mcc) else np.nan,
            }
        except Exception as e:
            print(f"Error calculating RNAGym metrics: {str(e)}")
            return {
                "spearman_abs": np.nan,
                "auroc_abs": np.nan,
                "auprc_abs": np.nan,
                "mcc": np.nan,
            }

    def get_mutation_depth(self, sequence: str, original_seq: str) -> str:
        """Determine if mutation is single or multiple based on sequence comparison."""
        try:
            # Count differences between sequences
            if len(sequence) != len(original_seq):
                return "multiple"  # Length changes indicate complex mutations

            differences = sum(c1 != c2 for c1, c2 in zip(sequence, original_seq))
            return "single" if differences == 1 else "multiple"
        except Exception:
            return "unknown"

    def calc_metrics(self, is_print=True):
        """Compute per-dataset metrics from the prediction CSVs in ``self.output_dir``.

        Overrides ``BaseTask.calc_metrics`` to add RNAGym-specific analysis.
        Every ``*.csv`` file in the output directory is treated as one
        dataset, except files whose name starts with ``rnagym_`` (the summary
        files written by ``_save_analysis_results``). Each CSV must provide a
        ``label`` column and one ``pred_all`` / ``pred_for`` / ``pred_rev``
        column per requested score mode.

        The score modes come from ``cfg.score_modes`` (a string or a sequence
        of strings); when it is missing, or when it contains ``"all"``, all
        three modes are evaluated.

        Args:
            is_print: Forwarded to ``dataset_utils.report_metrics``.

        Returns:
            A ``defaultdict`` mapping the dataset name (CSV file name without
            its extension) to a dict holding ``"{mode}_{metric}"`` for every
            metric in ``self.metrics`` and ``"{mode}_rnagym_{metric}"`` for
            the values returned by ``calculate_rnagym_metrics``.

        Raises:
            KeyError: If a requested score mode is not ``all``/``for``/``rev``
                or a CSV lacks a required column.
        """
        results = defaultdict(dict)
        score_map = {"all": "pred_all", "for": "pred_for", "rev": "pred_rev"}

        # Determine score modes to calculate; use defaults if not provided in cfg.
        score_modes_to_run = getattr(self.cfg, "score_modes", ["all", "for", "rev"])
        # Wrap a bare string first so the "all" check below is a membership
        # test on a list rather than a substring test on the string.
        if isinstance(score_modes_to_run, str):
            score_modes_to_run = [score_modes_to_run]
        if "all" in score_modes_to_run:
            score_modes_to_run = ["all", "for", "rev"]

        # Sorted so that reporting order and result order are reproducible;
        # os.listdir returns entries in arbitrary order.
        for file in sorted(os.listdir(self.output_dir)):
            if not file.endswith(".csv"):
                continue

            if file.startswith("rnagym_"):  # skip summary CSVs
                continue

            name = os.path.splitext(file)[0]  # "name.csv" -> "name"
            csv_path = os.path.join(self.output_dir, file)
            df = pd.read_csv(csv_path)
            _labels_all = df["label"].to_numpy(dtype=np.float32)

            for mode in score_modes_to_run:
                # Select the prediction column that belongs to this mode.
                pred_column = score_map[mode]
                _preds_all = df[pred_column].to_numpy(dtype=np.float32)

                # Standard metrics via dataset_utils.
                for metric in self.metrics:
                    results[name][f"{mode}_{metric}"] = dataset_utils.calc_metrics(
                        _labels_all, _preds_all, metric
                    )

                # RNAGym-specific metrics for the same mode.
                rnagym_metrics = self.calculate_rnagym_metrics(_labels_all, _preds_all)
                for rna_metric, value in rnagym_metrics.items():
                    results[name][f"{mode}_rnagym_{rna_metric}"] = value

            dataset_utils.report_metrics(
                name,
                results[name],
                self.cfg.use_wandb,
                is_print,
            )

        return results

    def calc_average_metrics(self, metrics, is_print=True):
        """Report the base-class averages, then break the metrics down by RNA type.

        Args:
            metrics: Mapping of dataset name to that dataset's metric dict, as
                produced by ``calc_metrics``.
            is_print: When true, print the per-RNA-type table; also forwarded
                to the base implementation.

        Returns:
            A DataFrame with one row per dataset holding ``name``,
            ``rna_type``, ``original_seq`` and every metric of that dataset.
        """
        # Overall averages are handled by the base implementation; its return
        # value is not used here.
        super().calc_average_metrics(metrics, is_print)

        # One flat record per dataset for the per-type breakdown.
        rows = []
        for name, dataset_metrics in metrics.items():
            rna_type = dataset_utils.get_rna_type_from_name(name)

            # Attach the original sequence when the task carries a dataset
            # exposing an ``origins`` lookup; any failure simply leaves it unset.
            try:
                source = getattr(self, "dataset", None)
                known = hasattr(source, "origins") and name in source.origins
                orig_seq = source.origins[name] if known else None
            except Exception:
                orig_seq = None

            rows.append(
                {
                    "name": name,
                    "rna_type": rna_type,
                    "original_seq": orig_seq,
                    **dataset_metrics,
                }
            )

        df = pd.DataFrame(rows)
        rna_type_metrics = self._calculate_type_averages(df)

        if is_print:
            banner = "=" * 50
            print("\n" + banner)
            print("RNAGym Multi-dimensional Analysis")
            print(banner)

            print("\nMetrics by RNA Type:")
            print(rna_type_metrics.to_string(index=False))

        # Mirror the per-type table to wandb when enabled.
        if self.cfg.use_wandb:
            self._log_rnagym_wandb_metrics(rna_type_metrics, df)

        # Persist both tables when an output directory is available.
        if hasattr(self, "output_dir") and os.path.exists(self.output_dir):
            self._save_analysis_results(df, rna_type_metrics, is_verbose=False)

        return df

    def _calculate_type_averages(self, df: pd.DataFrame) -> pd.DataFrame:
        """Calculate average metrics grouped by RNA type."""
        # Get RNAGym-specific metrics
        rnagym_cols = [col for col in df.columns if "rnagym_" in col]
        if not rnagym_cols:
            # Fallback to standard metrics
            metric_cols = [
                col for col in df.columns if any(m in col for m in self.metrics)
            ]
            metric_cols = [col for col in metric_cols if col.startswith("all_")]
        else:
            metric_cols = rnagym_cols

        if not metric_cols:
            return pd.DataFrame({"RNA_Type": self.rna_types})

        # Group by RNA type and calculate means
        type_results = []
        rna_type_means = {}  # Store means for each RNA type

        for rna_type in self.rna_types:
            subset = df[df["rna_type"] == rna_type]

            if len(subset) == 0:
                continue

            result = {"RNA_Type": rna_type}
            for col in metric_cols:
                if col in subset.columns:
                    values = subset[col].dropna()
                    if len(values) > 0:
                        result[col] = values.mean()
                        # Store for overall calculation
                        if col not in rna_type_means:
                            rna_type_means[col] = []
                        rna_type_means[col].append(result[col])
                    else:
                        result[col] = np.nan

            type_results.append(result)

        # Calculate 'All' as average of RNA type means (not all datasets)
        overall_result = {"RNA_Type": "All"}
        for col in metric_cols:
            if col in rna_type_means and len(rna_type_means[col]) > 0:
                overall_result[col] = np.mean(rna_type_means[col])
            else:
                overall_result[col] = np.nan
        type_results.append(overall_result)

        return pd.DataFrame(type_results)

    def _log_rnagym_wandb_metrics(
        self, rna_type_metrics: pd.DataFrame, detailed_df: pd.DataFrame
    ):
        """Log RNAGym-specific metrics to wandb."""
        try:
            # Log summary metrics by RNA type (excluding 'All' to avoid duplication)
            for _, row in rna_type_metrics.iterrows():
                rna_type = row["RNA_Type"]
                if rna_type == "All":
                    continue  # Skip 'All' here, handle it separately below

                rna_type_clean = rna_type.replace(" ", "_").replace("-", "_")

                # Log each metric for this RNA type
                wandb_data = {}
                for col in row.index:
                    if col != "RNA_Type" and not pd.isna(row[col]):
                        # Avoid double 'rnagym_' prefix if column already contains 'rnagym_'
                        if "rnagym_" in col:
                            metric_name = f"{rna_type_clean}_{col}"
                        else:
                            metric_name = f"rnagym_{rna_type_clean}_{col}"
                        wandb_data[metric_name] = row[col]

                if wandb_data:
                    wandb.log(wandb_data)

            # Log dataset count by RNA type
            rna_type_counts = detailed_df["rna_type"].value_counts().to_dict()
            for rna_type, count in rna_type_counts.items():
                rna_type_clean = rna_type.replace(" ", "_").replace("-", "_")
                wandb.log({f"rnagym_{rna_type_clean}_dataset_count": count})

            # Log overall RNAGym-specific metrics (all RNA types combined)
            all_row = rna_type_metrics[rna_type_metrics["RNA_Type"] == "All"]
            if not all_row.empty:
                wandb_data = {}
                for col in all_row.columns:
                    if col != "RNA_Type" and not pd.isna(all_row.iloc[0][col]):
                        # Avoid double 'rnagym_' prefix if column already contains 'rnagym_'
                        if "rnagym_" in col:
                            metric_name = f"overall_{col}"
                        else:
                            metric_name = f"rnagym_overall_{col}"
                        wandb_data[metric_name] = all_row.iloc[0][col]

                if wandb_data:
                    wandb.log(wandb_data)

        except Exception as e:
            print(f"Error logging RNAGym metrics to wandb: {e}")

    def _log_extended_statistics_wandb(self, results: dict):
        """Log extended statistics (with standard errors) to wandb."""
        try:
            rna_type_df = results.get("by_rna_type", pd.DataFrame())
            if not rna_type_df.empty:
                for _, row in rna_type_df.iterrows():
                    rna_type = row["RNA_Type"]
                    rna_type_clean = rna_type.replace(" ", "_").replace("-", "_")

                    # Log extended metrics with standard errors
                    wandb_data = {}
                    for col in row.index:
                        if col not in ["RNA_Type", "N_datasets"] and not pd.isna(
                            row[col]
                        ):
                            metric_name = f"rnagym_extended_{rna_type_clean}_{col}"
                            wandb_data[metric_name] = row[col]

                    if wandb_data:
                        wandb.log(wandb_data)

        except Exception as e:
            print(f"Error logging extended statistics to wandb: {e}")

    def _save_analysis_results(
        self,
        detailed_df: pd.DataFrame,
        summary_df: pd.DataFrame,
        is_verbose: bool = True,
    ):
        """Save analysis results to CSV files."""
        try:
            # Save detailed results (per dataset)
            detailed_path = os.path.join(self.output_dir, "rnagym_detailed_results.csv")
            detailed_df.to_csv(detailed_path, index=False)

            # Save summary results (by RNA type)
            summary_path = os.path.join(
                self.output_dir, "rnagym_summary_by_rna_type.csv"
            )
            summary_df.to_csv(summary_path, index=False)
            if is_verbose:
                print(f"Detailed results saved to: {detailed_path}")
                print(f"Summary results saved to: {summary_path}")

        except Exception as e:
            print(f"Error saving analysis results: {e}")

    def calc_extended_statistics(
        self, metrics, calculate_se: bool = False, n_iterations: int = 1000
    ):
        """Calculate extended statistics with optional bootstrap standard errors."""
        # Collect data for analysis
        dataset_info = []
        for name in metrics.keys():
            rna_type = dataset_utils.get_rna_type_from_name(name)

            # Get original sequence for mutation depth analysis
            try:
                if (
                    hasattr(self, "dataset")
                    and hasattr(self.dataset, "origins")
                    and name in self.dataset.origins
                ):
                    orig_seq = self.dataset.origins[name]
                else:
                    orig_seq = None
            except Exception:
                orig_seq = None

            dataset_info.append(
                {
                    "name": name,
                    "rna_type": rna_type,
                    "original_seq": orig_seq,
                    **metrics[name],
                }
            )

        df = pd.DataFrame(dataset_info)
        results = {}

        # Calculate metrics by RNA type
        rnagym_cols = [col for col in df.columns if "rnagym_" in col]
        if rnagym_cols:
            metric_cols = rnagym_cols
        else:
            metric_cols = [
                col for col in df.columns if any(m in col for m in self.metrics)
            ]
            metric_cols = [col for col in metric_cols if col.startswith("all_")]

        # RNA type analysis
        type_results = []
        for rna_type in self.rna_types + ["All"]:
            if rna_type == "All":
                subset = df
            else:
                subset = df[df["rna_type"] == rna_type]

            if len(subset) == 0:
                continue

            result = {"RNA_Type": rna_type, "N_datasets": len(subset)}

            for col in metric_cols:
                if col in subset.columns:
                    values = subset[col].dropna()
                    if len(values) > 0:
                        result[f"{col}_mean"] = values.mean()
                        result[f"{col}_std"] = values.std()

                        if calculate_se and len(values) > 1:
                            result[f"{col}_SE"] = self.bootstrap_se(
                                subset, col, n_iterations
                            )
                    else:
                        result[f"{col}_mean"] = np.nan
                        result[f"{col}_std"] = np.nan
                        if calculate_se:
                            result[f"{col}_SE"] = np.nan

            type_results.append(result)

        results["by_rna_type"] = pd.DataFrame(type_results)

        # Log extended statistics to wandb if enabled
        if self.cfg.use_wandb:
            self._log_extended_statistics_wandb(results)

        # If we have mutation depth information, add mutation depth analysis
        if hasattr(self, "dataset") and hasattr(self.dataset, "origins"):
            # This would require access to the actual sequences to determine mutation depth
            # For now, we'll create a placeholder for this functionality
            results["by_mutation_depth"] = pd.DataFrame(
                {"Analysis": ["Mutation depth analysis requires sequence data"]}
            )

        return results
