import argparse
import json
import logging
from pathlib import Path
from typing import List, Tuple, Union

import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

from src.ebm_training.utils import get_cached_batch
from src.energy_model.config import EBMConfig
from src.energy_model.models import EnergyModel
from src.energy_model.utils.energy_network import semantic_sentence_split, normalize_sentences
from src.interpreter_model.config.interpreter_configs import InterpreterConfig
from src.interpreter_model.interpreter import InterpreterModel

# ------------------- Dataset for Inference -------------------
class InterpreterInferenceDataset(Dataset):
    """Dataset of (prompt, response, target index) triples read from a CSV file.

    The CSV must contain the columns ``review_text`` (exposed as the prompt)
    and ``formatted_prediction`` (exposed as the response). The column
    ``target_index`` is optional: when it is absent it is added to the loaded
    frame with the value ``-1`` for every row.
    """

    # Columns that are read unconditionally in ``__init__``.
    _REQUIRED_COLUMNS = ("review_text", "formatted_prediction")

    def __init__(self, csv_path: str):
        """Load ``csv_path`` and cache its columns as plain Python lists.

        Args:
            csv_path: Path of the CSV file to read.

        Raises:
            KeyError: If a required column is missing from the CSV.
        """
        self.df = pd.read_csv(csv_path)

        # Fail early with a message naming the file and every missing column,
        # instead of a bare KeyError on the first failing lookup.
        missing = [
            column for column in self._REQUIRED_COLUMNS
            if column not in self.df.columns
        ]
        if missing:
            raise KeyError(
                f"{csv_path} is missing required column(s): {missing}"
            )

        if "target_index" not in self.df.columns:
            self.df["target_index"] = -1

        self.prompts = self.df["review_text"].tolist()
        self.responses = self.df["formatted_prediction"].tolist()
        self.target_indices = self.df["target_index"].tolist()

    def __len__(self):
        """Return the number of rows loaded from the CSV."""
        return len(self.prompts)

    def __getitem__(self, idx):
        """Return the ``(prompt, response, target_index)`` triple at ``idx``."""
        return self.prompts[idx], self.responses[idx], self.target_indices[idx]

def collate_fn(batch):
    """Transpose a batch of triples into three parallel lists.

    Each item of ``batch`` is expected to be a
    ``(prompt, response, target_index)`` triple. The result is the tuple
    ``(prompts, responses, target_indices)``, each element being a list in
    batch order.
    """
    prompts, responses, target_indices = (
        list(column) for column in zip(*batch)
    )
    return prompts, responses, target_indices

# ------------------- Save importance scores -------------------

def save_explanations_csv_batch(
    prompts: List[str],
    responses: List[str],
    scores_list: List[Union[torch.Tensor, List[float]]],
    n_sentences: int,
    out_dir: Path,
    output_file: str = "inference_explanations.csv"
) -> Path:
    """Write one CSV row per example holding a JSON ranking of its sentences.

    For every ``(prompt, response, scores)`` triple the prompt is split with
    ``semantic_sentence_split`` and passed through ``normalize_sentences``
    together with ``n_sentences``. Sentences are paired with scores
    positionally; pairing stops at the shorter of the two sequences.
    Entries that are ``None``, equal to ``"..."`` or whitespace-only are
    dropped. The remaining sentences are ordered by descending score and
    numbered from 1.

    Args:
        prompts (List[str]): input texts
        responses (List[str]): target outputs, one per prompt
        scores_list (List[torch.Tensor or List]): per-example importance
            scores, either a tensor or a list of numbers
        n_sentences (int): forwarded to ``normalize_sentences``
        out_dir (Path): output directory, created if it does not exist
        output_file (str): name of the CSV file inside ``out_dir``

    Returns:
        Path: location of the written CSV, whose columns are
        ``review_text``, ``output_text`` and ``interpreter_analysis``
        (a JSON list of ``sentence`` / ``rank`` / ``importance_score``).
    """
    rows = []

    for prompt, response, raw_scores in zip(prompts, responses, scores_list):
        split = semantic_sentence_split(prompt)
        sentences = normalize_sentences(split, n_sentences)

        # Bring the scores into a plain Python list whatever their container.
        if not isinstance(raw_scores, list):
            scores = raw_scores.detach().cpu().tolist()
        else:
            scores = [float(value) for value in raw_scores]

        # Keep only real sentences; None, "..." and blank strings are
        # placeholders and carry no explanation value.
        kept = [
            (text, value)
            for text, value in zip(sentences, scores)
            if text is not None and text != "..." and text.strip() != ""
        ]

        # Highest score first; ties keep their original relative order.
        ranked = sorted(kept, key=lambda pair: pair[1], reverse=True)

        analysis = [
            {
                "sentence": text,
                "rank": position,
                "importance_score": float(value),
            }
            for position, (text, value) in enumerate(ranked, start=1)
        ]

        rows.append({
            "review_text": prompt,
            "output_text": response,
            "interpreter_analysis": json.dumps(analysis),
        })

    # Everything is written in one go after all rows have been built.
    out_dir.mkdir(parents=True, exist_ok=True)
    destination = out_dir / output_file
    pd.DataFrame(rows).to_csv(destination, index=False)
    return destination

# ------------------- Inference -------------------
def main() -> None:
    """Command-line entry point: run the interpreter over a CSV and save scores.

    Steps, in order:
      1. Parse the command-line arguments.
      2. Load the CSV into an ``InterpreterInferenceDataset``.
      3. Build an ``EnergyModel`` from an ``EBMConfig`` and wrap it in an
         ``InterpreterModel``.
      4. Load the interpreter weights from ``--interpreter_ckpt`` (the
         checkpoint must contain a ``"model_state_dict"`` entry).
      5. Run the interpreter on every batch under ``torch.no_grad()``.
      6. Write all results to ``--output_dir`` with
         ``save_explanations_csv_batch``.
    """
    parser = argparse.ArgumentParser(description="Interpreter Inference")
    parser.add_argument("--csv", required=True, help="CSV file with prompts and responses")
    parser.add_argument(
        "--interpreter_ckpt",
        required=True,
        help="Path to trained interpreter checkpoint"
    )
    parser.add_argument(
        "--output_dir",
        default="interpreter_inference",
        help="Directory to save explanations"
    )
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--ebm_self_attention_layers", type=int, default=2)
    parser.add_argument("--ebm_cross_attention_layers", type=int, default=6)
    parser.add_argument("--interp_cross_attention_layers", type=int, default=6)
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of examples per inference batch"
    )

    args = parser.parse_args()

    device = torch.device(args.device)
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Load dataset; shuffle stays off so output rows follow the input order.
    dataset = InterpreterInferenceDataset(args.csv)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    # Build the EBM once, directly from the CLI-provided layer counts.
    ebm_config = EBMConfig(
        self_attention_n_layers=args.ebm_self_attention_layers,
        cross_attention_n_layers=args.ebm_cross_attention_layers
    )
    energy_model = EnergyModel(ebm_config).to(device)
    energy_model.eval()

    # Build the interpreter around the energy model.
    interp_config = InterpreterConfig(
        cross_attention_layers=args.interp_cross_attention_layers,
        softmax_type="normal",
    )

    n_sentences = ebm_config.n_sentences
    interpreter = InterpreterModel(
        energy_model=energy_model,
        config=interp_config,
        n_sentences=n_sentences
    ).to(device)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. Consider weights_only=True if the checkpoint is
    # known to contain nothing but tensors and plain containers.
    checkpoint = torch.load(args.interpreter_ckpt, map_location=device)
    interpreter.load_state_dict(checkpoint["model_state_dict"])
    interpreter.eval()

    # Run inference - accumulate all results
    all_prompts = []
    all_responses = []
    all_importance_scores = []

    with torch.no_grad():
        for prompts, responses, target_indices in dataloader:
            # Only the importance scores are used; the other two outputs
            # are intentionally discarded.
            importance_scores, _masked_texts, _concept_mask = interpreter.forward(
                input_texts=prompts,
                output_texts=responses,
                target_indices=target_indices,
            )

            # Move tensor scores off the inference device right away so the
            # accumulated results do not pin device memory for the whole run.
            if isinstance(importance_scores, torch.Tensor):
                importance_scores = importance_scores.detach().cpu()

            # Accumulate results from this batch
            all_prompts.extend(prompts)
            all_responses.extend(responses)
            all_importance_scores.extend(importance_scores)

    # Save all results once at the end
    logger.info("Saving %d inference results...", len(all_prompts))
    save_explanations_csv_batch(
        all_prompts, all_responses, all_importance_scores,
        n_sentences=interpreter.n_sentences,
        out_dir=Path(args.output_dir)
    )

# Run inference only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
