from typing import List, Optional
from paretoset import paretoset

import numpy as np
import pandas as pd
import ast
import re

from reagent import Reagent


def create_reagents(filename: str, num_to_select: Optional[int] = None, num_props=1) -> List[Reagent]:
    """
    Creates a list of Reagents from a file
    :param filename: a smiles file containing the reagents; every non-blank line must hold
    a SMILES string followed by a reagent name, separated by whitespace
    :param num_to_select: For dev purposes; the number of molecules to return
    :param num_props: number of properties being optimized
    :return: List of Reagents
    :raises ValueError: if a non-blank line does not contain exactly two whitespace-separated fields
    """
    reagent_list = []
    with open(filename, 'r') as f:
        # Iterate the file object directly so the whole file is not held in memory at once.
        for line_number, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline at the end of the file) carry no reagent.
                continue
            if len(fields) != 2:
                raise ValueError(
                    f"{filename}:{line_number}: expected '<smiles> <name>', got {line.rstrip()!r}"
                )
            smiles, reagent_name = fields
            reagent = Reagent(reagent_name=reagent_name, smiles=smiles, num_props=num_props)
            reagent_list.append(reagent)
    if num_to_select is not None and len(reagent_list) > num_to_select:
        reagent_list = reagent_list[:num_to_select]
    return reagent_list


def read_reagents(reagent_file_list, num_to_select: Optional[int], num_props) -> List[List[Reagent]]:
    """
    Read the reagents SMILES files
    :param reagent_file_list: a list of filenames containing reagents for the reaction. Each file contains smiles
    strings for a single component of the reaction.
    :param num_to_select: select how many reagents to read from each file, mostly a development function
    :param num_props: number of properties being optimized
    :return: a list with one entry per input file, in the same order as reagent_file_list; each entry is the
    List of Reagents created from that file
    """
    return [
        create_reagents(filename=reagent_filename, num_to_select=num_to_select, num_props=num_props)
        for reagent_filename in reagent_file_list
    ]


def get_pareto_indices(scores):
    """
    Return the row positions of scores that paretoset flags as Pareto-optimal
    :param scores: 2-D score table; scores.shape[1] is used as the number of objectives
    :return: array of row indices where the paretoset mask is true
    """
    # Every objective column is maximised; distinct=False is forwarded to paretoset unchanged.
    n_objectives = scores.shape[1]
    senses = ["max"] * n_objectives
    mask = paretoset(scores, sense=senses, distinct=False)
    selected = np.where(mask)
    return selected[0]


def parse_score(s):
    """
    Parse a serialized score vector such as "[0.1, 0.2]" or "[0.1 0.2]" into a float array
    :param s: value to parse; anything that is not None or a float NaN is converted with str(),
    surrounding square brackets are stripped, and commas and/or whitespace separate the values
    :return: 1-D numpy float array; an empty array when s is None, a float NaN, or holds no values
    :raises ValueError: if any value is not a valid float literal
    """
    if s is None:
        return np.array([])
    # Missing CSV cells arrive from pandas as float NaN.
    if isinstance(s, float) and np.isnan(s):
        return np.array([])
    text = str(s).strip().strip('[]')
    # Treat commas like whitespace, then split on any run of whitespace.
    tokens = text.replace(',', ' ').split()
    if not tokens:
        return np.array([])
    # Convert token by token so that one bad value fails loudly instead of
    # silently yielding a shorter array than the input described.
    try:
        return np.array([float(token) for token in tokens], dtype=float)
    except ValueError as err:
        raise ValueError(f"cannot parse score {s!r}: {err}") from err


def process_large_csv(input_csv, score_column='score', chunksize=100_000):
    """
    Compute the Pareto-optimal rows of a CSV file by reading it in chunks
    :param input_csv: CSV path or buffer, forwarded to pandas.read_csv
    :param score_column: name of the column holding the serialized score vectors
    :param chunksize: number of rows read per chunk
    :return: DataFrame of the Pareto-optimal rows with an added 'score_array' column holding the
    parsed scores; an empty DataFrame when the input has no data rows
    """
    pareto_candidates = []
    empty_result = None
    reader = pd.read_csv(input_csv, chunksize=chunksize)
    for chunk in reader:
        chunk['score_array'] = chunk[score_column].apply(parse_score)
        if len(chunk) == 0:
            # np.vstack cannot stack zero arrays; keep the empty chunk so that the
            # column layout can still be returned when the file has no data rows.
            empty_result = chunk
            continue
        scores = np.vstack(chunk['score_array'])
        pareto_idx = get_pareto_indices(scores)
        # Only rows that survive within their own chunk can be globally optimal.
        pareto_candidates.append(chunk.iloc[pareto_idx])
    if not pareto_candidates:
        # pd.concat raises on an empty list, so return an empty frame instead.
        if empty_result is not None:
            return empty_result
        return pd.DataFrame(columns=[score_column, 'score_array'])
    # Concatenate all chunk-level pareto sets and compute global Pareto
    all_candidates = pd.concat(pareto_candidates, ignore_index=True)
    scores = np.vstack(all_candidates['score_array'])
    pareto_idx = get_pareto_indices(scores)
    pareto_df = all_candidates.iloc[pareto_idx]
    return pareto_df


def csv_to_df(input_csv, score_column='score', chunksize=100_000):
    """
    Load a CSV file in chunks and attach the parsed scores to every row
    :param input_csv: CSV path or buffer, forwarded to pandas.read_csv
    :param score_column: name of the column holding the serialized score vectors
    :param chunksize: number of rows read per chunk
    :return: a single DataFrame with all rows, a fresh 0..n-1 index and an added 'score_array' column
    """
    def _attach_scores(frame):
        # Parse the serialized score of each row into the 'score_array' column.
        frame['score_array'] = frame[score_column].apply(parse_score)
        return frame

    chunk_iter = pd.read_csv(input_csv, chunksize=chunksize)
    # Stitch the processed chunks back together into one frame.
    return pd.concat([_attach_scores(piece) for piece in chunk_iter], ignore_index=True)


def bernoulli_metric(rec, opt):
    """
    Exact-match indicator between the recommended set and the optimal set.
    :param rec: recommended set
    :param opt: optimal set
    :return: 1 when rec == opt, otherwise 0
    """
    is_exact_match = rec == opt
    return int(is_exact_match)


def jaccard_metric(rec, opt):
    """
    Jaccard similarity of the recommended set and the optimal set: |rec & opt| / |rec | opt|.
    :param rec: recommended set
    :param opt: optimal set
    :return: similarity as a float; 0.0 when both sets are empty
    """
    n_shared = len(rec.intersection(opt))
    n_total = len(rec.union(opt))
    # An empty union would divide by zero; it is reported as 0.0.
    return n_shared / n_total if n_total != 0 else 0.0


def misclassification_metric(rec, opt):
    """
    Fraction of the recommended reagents that do not appear in the optimal set.
    :param rec: recommended reagents
    :param opt: optimal reagents
    :return: value between 0 and 1; 0.0 when nothing was recommended
    """
    n_wrong = 0
    for reagent in rec:
        if reagent not in opt:
            n_wrong += 1
    n_recommended = len(rec)
    if n_recommended == 0:
        return 0.0
    return n_wrong / n_recommended
