from __future__ import annotations

import math
import os
from typing import Dict, List, Union, Tuple, Optional, Any
import numpy as np
import pandas as pd

try:
    import torch
    from torch import Tensor as TorchTensor
except Exception:
    torch = None
    TorchTensor = None
class HEIScoring:
    """Score diet-quality components in the style of the Healthy Eating Index.

    Each entry of ``self.components`` maps a component name to:

    * ``max_score``    -- points awarded when the full-score standard is met;
    * ``max_standard`` -- threshold at (or beyond) which the full score is given;
    * ``min_standard`` -- threshold at (or beyond) which zero points are given.

    Adequacy components (more is better) score linearly from ``min_standard``
    up to ``max_standard``. Moderation components listed in
    ``self.reverse_components`` (less is better) score linearly from
    ``max_standard`` down to ``min_standard``.
    """

    def __init__(self):
        # max_standard: threshold for full score
        # min_standard: threshold for zero score
        self.components = {
            "Total Fruits": {"max_score": 5, "max_standard": 0.8, "min_standard": 0.0},
            "Whole Fruits": {"max_score": 5, "max_standard": 0.4, "min_standard": 0.0},
            "Total Vegetables": {"max_score": 5, "max_standard": 1.1, "min_standard": 0.0},
            "Greens and Beans": {"max_score": 5, "max_standard": 0.2, "min_standard": 0.0},
            "Whole Grains": {"max_score": 10, "max_standard": 1.5, "min_standard": 0.0},
            "Dairy": {"max_score": 10, "max_standard": 1.3, "min_standard": 0.0},
            "Total Protein Foods": {"max_score": 5, "max_standard": 2.5, "min_standard": 0.0},
            "Seafood and Plant Proteins": {"max_score": 5, "max_standard": 0.8, "min_standard": 0.0},
            "Fatty Acids": {"max_score": 10, "max_standard": 2.5, "min_standard": 1.2},
            "Refined Grains": {"max_score": 10, "max_standard": 1.8, "min_standard": 4.3},
            "Sodium": {"max_score": 10, "max_standard": 1.1, "min_standard": 2.0},
            "Added Sugars": {"max_score": 10, "max_standard": 6.5, "min_standard": 26.0},
            "Saturated Fats": {"max_score": 10, "max_standard": 8.0, "min_standard": 16.0},
        }

        # Moderation components: lower intake earns a higher score.
        self.reverse_components = {
            "Refined Grains",
            "Sodium",
            "Added Sugars",
            "Saturated Fats",
        }

    def _score_component(self, value: float, component_name: str) -> float:
        """Return the linearly interpolated score for one component.

        Args:
            value: the component's density/ratio, already expressed in the
                units of that component's standards.
            component_name: key into ``self.components``.

        Returns:
            A score in ``[0, max_score]``.

        Raises:
            KeyError: if ``component_name`` is not a known component.
        """
        comp = self.components[component_name]
        max_score = float(comp["max_score"])
        full_at = comp["max_standard"]
        zero_at = comp["min_standard"]

        if component_name in self.reverse_components:
            # Moderation: at or below full_at -> full score, at or above
            # zero_at -> 0, linear in between.
            if value <= full_at:
                return max_score
            if value >= zero_at:
                return 0.0
            score = max_score * (zero_at - value) / (zero_at - full_at)
        else:
            # Adequacy: interpolate from min_standard rather than from 0 so
            # components with a non-zero floor (Fatty Acids: 1.2 -> 0 points,
            # 2.5 -> full) are scored correctly. For floor-0 components this
            # is identical to value / max_standard.
            if value >= full_at:
                return max_score
            if value <= zero_at:
                return 0.0
            score = max_score * (value - zero_at) / (full_at - zero_at)
        return float(np.clip(score, 0.0, max_score))

    def calculate_from_totals(
        self,
        *,
        total_fruits_g: float,
        whole_fruits_g: float,
        total_vegetables_g: float,
        greens_and_beans_g: float,
        whole_grains_g: float,
        dairy_g: float,
        total_protein_foods_g: float,
        seafood_plant_proteins_g: float,
        refined_grains_g: float,
        sodium_g_total: float,
        added_sugars_g: float,
        saturated_fats_g: float,
        pufas_g: float,
        mufas_g: float,
        sfas_g: float,
        total_kcal: float,
    ) -> Dict[str, float]:
        """Calculate HEI scores from total nutrients.

        Food-group amounts and sodium are given in grams and converted to
        densities per 1000 kcal; added sugars and saturated fats are converted
        to percent of energy (4 and 9 kcal/g respectively); fatty acids use the
        ratio ``(PUFA + MUFA) / SFA``.

        Returns:
            A dict of per-component scores (rounded to 2 decimals) plus
            ``"Total_HEI_Score"`` (sum of the unrounded scores, rounded).

        Raises:
            ValueError: if ``total_kcal <= 0``.
        """
        if total_kcal <= 0:
            raise ValueError("Total energy (kcal) must be > 0.")

        scale = 1000.0 / total_kcal
        grams_per_cup = 142.0
        grams_per_oz = 28.35

        total_fruits_cup = (total_fruits_g * scale) / grams_per_cup
        whole_fruits_cup = (whole_fruits_g * scale) / grams_per_cup
        total_vegetables_cup = (total_vegetables_g * scale) / grams_per_cup
        greens_and_beans_cup = (greens_and_beans_g * scale) / grams_per_cup
        whole_grains_oz = (whole_grains_g * scale) / grams_per_oz
        total_protein_oz = (total_protein_foods_g * scale) / grams_per_oz
        seafood_plant_oz = (seafood_plant_proteins_g * scale) / grams_per_oz

        # NOTE(review): divides by grams-per-ounce although the Dairy standard
        # is expressed per cup; confirm the intended dairy conversion.
        dairy_cup = (dairy_g * scale) / grams_per_oz

        # Refined grains are scored against ounce-equivalent standards
        # (1.8 / 4.3), so convert like whole grains instead of using raw grams.
        refined_grains_oz = (refined_grains_g * scale) / grams_per_oz
        sodium_per_1000kcal_g = sodium_g_total * scale

        added_sugars_energy_percent = (added_sugars_g * 4.0 / total_kcal) * 100.0
        saturated_fats_energy_percent = (saturated_fats_g * 9.0 / total_kcal) * 100.0

        # No saturated fat -> ratio treated as infinite (full score).
        fatty_acids_ratio = (pufas_g + mufas_g) / sfas_g if sfas_g > 0 else float("inf")

        inputs = {
            "Total Fruits": total_fruits_cup,
            "Whole Fruits": whole_fruits_cup,
            "Total Vegetables": total_vegetables_cup,
            "Greens and Beans": greens_and_beans_cup,
            "Whole Grains": whole_grains_oz,
            "Dairy": dairy_cup,
            "Total Protein Foods": total_protein_oz,
            "Seafood and Plant Proteins": seafood_plant_oz,
            "Refined Grains": refined_grains_oz,
            "Sodium": sodium_per_1000kcal_g,
            "Added Sugars": added_sugars_energy_percent,
            "Saturated Fats": saturated_fats_energy_percent,
            "Fatty Acids": fatty_acids_ratio,
        }

        results = {}
        total_score = 0.0
        for name in self.components:
            score = self._score_component(inputs.get(name, 0.0), name)
            results[name] = round(score, 2)
            total_score += score
        results["Total_HEI_Score"] = round(total_score, 2)
        return results


_REQUIRED_KEYS = [
    "total_fruits_100g",
    "whole_fruits_100g",
    "total_vegetables_100g",
    "greens_and_beans_100g",
    "whole_grains_100g",
    "dairy_100g",
    "total_protein_foods_100g",
    "seafood_plant_proteins_100g",
    "refined_grains_100g",
    "sodium_100g",
    "added_sugars_100g",
    "saturated_fats_100g",
    "pufas_100g",
    "mufas_100g",
    "sfas_100g",
    "kcal_per_100g",
]

class FoodEvaluator:
    """Food-table evaluator that computes HEI-style scores.

    Holds a table of per-100 g nutrient values (one record per food, keyed by
    ``_REQUIRED_KEYS``) and scores meals given as gram amounts of each food,
    combining ``HEIScoring`` component scores with a calorie score.
    """

    def __init__(self, food_csv: Optional[str] = "food.csv"):
        """Load the food table.

        Args:
            food_csv: path to a CSV with the ``_REQUIRED_KEYS`` columns and an
                optional ``food_name`` column. If falsy or the file does not
                exist, a seeded (42) random demo table of 20 foods is used.
        """
        self.hei = HEIScoring()
        self.food_names, self.food_table = self._load_foods(food_csv)

    # -------- data loading --------
    def _load_foods(self, food_csv: Optional[str]) -> Tuple[List[str], List[Dict[str, float]]]:
        """Return ``(food_names, food_table)`` from *food_csv* or demo data.

        Raises:
            ValueError: if the CSV exists but lacks any of ``_REQUIRED_KEYS``.
        """
        if food_csv and os.path.exists(food_csv):
            df = pd.read_csv(food_csv)
            missing = [k for k in _REQUIRED_KEYS if k not in df.columns]
            if missing:
                raise ValueError(f"[food.csv] missing required columns: {missing}")
            names = df["food_name"].tolist() if "food_name" in df.columns else [f"Food{i}" for i in range(len(df))]
            table = [{k: float(row[k]) for k in _REQUIRED_KEYS} for _, row in df.iterrows()]
            return names, table

        # Fallback: deterministic synthetic table. The order of the rng draws
        # below defines the generated values; do not reorder them.
        rng = np.random.default_rng(42)
        names = [f"DemoFood_{i}" for i in range(20)]
        table = []
        for _ in range(20):
            kcal = float(rng.uniform(80, 300))
            rec = {
                "total_fruits_100g": float(rng.uniform(0, 60)),
                "whole_fruits_100g": float(rng.uniform(0, 40)),
                "total_vegetables_100g": float(rng.uniform(0, 80)),
                "greens_and_beans_100g": float(rng.uniform(0, 30)),
                "whole_grains_100g": float(rng.uniform(0, 60)),
                "dairy_100g": float(rng.uniform(0, 50)),
                "total_protein_foods_100g": float(rng.uniform(0, 60)),
                "seafood_plant_proteins_100g": float(rng.uniform(0, 40)),
                "refined_grains_100g": float(rng.uniform(0, 60)),
                "sodium_100g": float(rng.uniform(50, 900)),
                "added_sugars_100g": float(rng.uniform(0, 20)),
                "saturated_fats_100g": float(rng.uniform(0, 20)),
                "pufas_100g": float(rng.uniform(0, 15)),
                "mufas_100g": float(rng.uniform(0, 15)),
                "sfas_100g": float(rng.uniform(0, 15)),
                "kcal_per_100g": kcal,
            }
            table.append(rec)
        return names, table

    # -------- utils --------
    def _to_numpy_2d(self, x: Union[List[float], np.ndarray, "TorchTensor"]) -> np.ndarray:
        """Convert *x* to a float64 array of shape ``(batch, n_foods)``.

        A 1-D input is treated as a single sample.

        Raises:
            TypeError: if *x* is not a list, ndarray or torch tensor.
            ValueError: if *x* is not 1-D or 2-D.
        """
        if torch is not None and isinstance(x, TorchTensor):
            # Cast like the other branches so every input path yields float64.
            x = x.detach().cpu().numpy().astype(np.float64, copy=False)
        elif isinstance(x, list):
            x = np.asarray(x, dtype=np.float64)
        elif isinstance(x, np.ndarray):
            x = x.astype(np.float64, copy=False)
        else:
            raise TypeError("grams_tensor must be one of list / np.ndarray / torch.Tensor")

        if x.ndim == 1:
            x = x[None, :]
        elif x.ndim != 2:
            raise ValueError(f"grams_tensor must be 1-D or 2-D, got {x.ndim}-D input.")
        return x

    # -------- core calc for one batch --------
    def _sum_nutrients_batch(self, grams_np: np.ndarray) -> Dict[str, np.ndarray]:
        """Calculate total nutrients per meal.

        Args:
            grams_np: ``(batch, n_foods)`` grams of each food.

        Returns:
            Dict of length-``batch`` arrays whose keys match the keyword
            arguments of ``HEIScoring.calculate_from_totals``. Sodium is
            divided by 1000 (per-100 g sodium values treated as milligrams).

        Raises:
            ValueError: if the food axis does not match the food table.
        """
        n_foods = grams_np.shape[1]
        if n_foods != len(self.food_table):
            raise ValueError(f"Input vector length {n_foods} does not match number of foods {len(self.food_table)}.")

        # Table values are per 100 g, so scale grams down by 100.
        grams_ratio = grams_np / 100.0

        def total_from_field(key: str) -> np.ndarray:
            # Column of this nutrient across all foods, then batch dot product.
            per100 = np.array([rec[key] for rec in self.food_table], dtype=np.float64)
            return grams_ratio @ per100

        total = {
            "total_fruits_g": total_from_field("total_fruits_100g"),
            "whole_fruits_g": total_from_field("whole_fruits_100g"),
            "total_vegetables_g": total_from_field("total_vegetables_100g"),
            "greens_and_beans_g": total_from_field("greens_and_beans_100g"),
            "whole_grains_g": total_from_field("whole_grains_100g"),
            "dairy_g": total_from_field("dairy_100g"),
            "total_protein_foods_g": total_from_field("total_protein_foods_100g"),
            "seafood_plant_proteins_g": total_from_field("seafood_plant_proteins_100g"),
            "refined_grains_g": total_from_field("refined_grains_100g"),
            "added_sugars_g": total_from_field("added_sugars_100g"),
            "saturated_fats_g": total_from_field("saturated_fats_100g"),
            "pufas_g": total_from_field("pufas_100g"),
            "mufas_g": total_from_field("mufas_100g"),
            "sfas_g": total_from_field("sfas_100g"),
            "total_kcal": total_from_field("kcal_per_100g"),
        }
        total["sodium_g_total"] = total_from_field("sodium_100g") / 1000.0
        return total

    @staticmethod
    def _calorie_score_kde_like(total_kcal: np.ndarray, mu: float = 1200.0, sigma: float = 150.0) -> np.ndarray:
        """Gaussian bump centred on *mu*, normalised so ``total_kcal == mu`` scores 100."""
        pdf = (1.0 / (sigma * math.sqrt(2.0 * math.pi))) * np.exp(-((total_kcal - mu) ** 2) / (2.0 * sigma**2))
        max_pdf = 1.0 / (sigma * math.sqrt(2.0 * math.pi))
        return (pdf / max_pdf) * 100.0

    def evaluate(self, grams_tensor: Union[List[float], np.ndarray, "TorchTensor"]) -> Union[Dict[str, float], List[Dict[str, float]]]:
        """Evaluate one or a batch of samples.

        Args:
            grams_tensor: grams of each food; 1-D for one meal or
                ``(batch, n_foods)``.

        Returns:
            For a batch of exactly one sample, a dict with the HEI component
            scores, ``Total_HEI_Score``, ``Calorie_Score``, ``Total_Score``
            (HEI + calorie) and ``Total_Kcal``; otherwise a list of such dicts.

        Raises:
            TypeError / ValueError: for unsupported input types or shapes.
            ValueError: from ``HEIScoring`` when a sample has no energy.
        """
        grams_np = self._to_numpy_2d(grams_tensor)
        totals = self._sum_nutrients_batch(grams_np)
        # The calorie score is elementwise, so compute it for the whole batch once.
        cal_scores = self._calorie_score_kde_like(totals["total_kcal"])

        out: List[Dict[str, float]] = []
        for i in range(grams_np.shape[0]):
            # totals' keys are exactly calculate_from_totals' keyword names.
            hei_dict = self.hei.calculate_from_totals(**{name: values[i] for name, values in totals.items()})
            cal_score = float(cal_scores[i])
            out.append({
                **hei_dict,
                "Calorie_Score": round(cal_score, 2),
                "Total_Score": round(hei_dict["Total_HEI_Score"] + cal_score, 2),
                "Total_Kcal": round(float(totals["total_kcal"][i]), 2),
            })

        return out[0] if len(out) == 1 else out

from dataclasses import dataclass, field

def _here(*paths: str) -> str:
    """Get path relative to this file."""
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), *paths)

DEFAULT_FOOD_CSV_PATH = _here("food.csv")

# Single source of truth for the sandwich decision variables: each entry is
# (parameter name, (lower, upper) bound). Values are gram amounts, since the
# vector built in PARAM_NAMES order is passed to FoodEvaluator.evaluate.
# NOTE(review): FoodEvaluator matches foods by position only, so this order
# presumably has to match the row order of food.csv -- confirm.
_SANDWICH_INGREDIENTS: List[Tuple[str, Tuple[float, float]]] = [
    ("multigrain_bread", (0.0, 140.0)),
    ("whole_wheat_bread", (0.0, 140.0)),
    ("sourdough_bread", (0.0, 140.0)),
    ("chicken_protein", (0.0, 100.0)),
    ("tuna_protein", (0.0, 100.0)),
    ("tofu_protein", (0.0, 80.0)),
    ("hummus_protein", (0.0, 70.0)),
    ("egg_protein", (0.0, 80.0)),
    ("low_fat_cheese_dairy", (0.0, 20.0)),
    ("cheddar_cheese", (0.0, 20.0)),
    ("swiss_cheese_dairy", (0.0, 20.0)),
    ("collards", (0.0, 80.0)),
    ("cabbage", (0.0, 80.0)),
    ("onion_vegetables", (0.0, 80.0)),
    ("tomato_vegetables", (0.0, 80.0)),
    ("mayo_sauce", (0.0, 15.0)),
    ("olive_oil", (0.0, 20.0)),
    ("apples", (0.0, 100.0)),
    ("orange", (0.0, 100.0)),
    ("banana", (0.0, 100.0)),
]

# Derived views; both stay aligned with _SANDWICH_INGREDIENTS by construction.
PARAM_NAMES = [name for name, _ in _SANDWICH_INGREDIENTS]
DEFAULT_BOUNDS = [bounds for _, bounds in _SANDWICH_INGREDIENTS]

@dataclass
class SandwichEvaluator:
    """Sandwich task evaluator wrapping FoodEvaluator.

    Maps the sandwich decision variables (``PARAM_NAMES``, gram amounts) onto
    a ``FoodEvaluator`` and returns its single-sample score dict.

    Attributes:
        food_csv: path of the food table handed to ``FoodEvaluator``.
        bounds: per-parameter ``(lower, upper)`` bounds aligned with
            ``PARAM_NAMES``; a per-instance copy of ``DEFAULT_BOUNDS`` so that
            mutating it does not affect other instances.
    """
    food_csv: Optional[str] = DEFAULT_FOOD_CSV_PATH
    bounds: List[Tuple[float, float]] = field(init=False)

    def __post_init__(self):
        # Not a dataclass field: built from food_csv once __init__ has run.
        self.food_evaluator = FoodEvaluator(self.food_csv)

        self.bounds = DEFAULT_BOUNDS.copy()

    def evaluate_from_dict(self, params: Dict[str, Any]) -> Dict[str, float]:
        """Evaluate from a ``{param_name: grams}`` mapping.

        Names missing from *params* default to 0.0; keys that are not in
        ``PARAM_NAMES`` are ignored.
        """
        x_list = [float(params.get(name, 0.0)) for name in PARAM_NAMES]
        return self.evaluate_from_list(x_list)

    def evaluate_from_list(self, x: List[Any]) -> Dict[str, float]:
        """Evaluate from a parameter list ordered like ``PARAM_NAMES``.

        Raises:
            ValueError: if ``len(x) != len(PARAM_NAMES)`` or the wrapped
                evaluator returns something other than a single result.
        """
        if len(x) != len(PARAM_NAMES):
            raise ValueError(f"Expected {len(PARAM_NAMES)} parameters, got {len(x)}")

        x_array = np.array([float(v) for v in x], dtype=np.float64)

        result = self.food_evaluator.evaluate(x_array)

        # A single 1-D sample yields a dict; the list branch is defensive.
        if isinstance(result, dict):
            return result
        elif isinstance(result, list) and len(result) == 1:
            return result[0]
        else:
            raise ValueError(f"Unexpected result type from FoodEvaluator: {type(result)}")

    def evaluate(self, x: Any) -> Dict[str, float]:
        """Generic interface accepting a list/tuple, a 1-D ndarray, or a dict.

        Raises:
            TypeError: for any other input type (including non-1-D arrays).
        """
        if isinstance(x, (list, tuple)):
            return self.evaluate_from_list(list(x))
        if isinstance(x, np.ndarray) and x.ndim == 1:
            # Optimizers typically hand over numpy vectors.
            return self.evaluate_from_list(x.tolist())
        if isinstance(x, dict):
            return self.evaluate_from_dict(x)
        raise TypeError(f"Unsupported input type {type(x)}; must be list or dict.")