import os
import math
import json
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, Tuple, Optional, List

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Baseline names that PNASEnv loops over when evaluating the trained model.
PNAS_BASELINES = ["DDB", "CBF", "EFB"]
# PNAS_BASELINES = ["DDB"]


# ---------------------------
# Utilities
# ---------------------------
def _to_tensor(x, dtype=torch.float32, dev=device):
    """Build a new tensor from ``x`` with the given ``dtype`` on device ``dev``."""
    converted = torch.tensor(x, dtype=dtype, device=dev)
    return converted


def get_model_parameters(model: nn.Module) -> Dict[str, object]:
    """
    Snapshot the named parameters of ``model`` as plain Python values.

    Single-element parameters are stored as a ``float``; every other
    parameter is stored as a (possibly nested) list.
    """
    snapshot = {}
    for param_name, param in model.named_parameters():
        host_copy = param.detach().cpu()
        is_scalar = param.numel() == 1
        snapshot[param_name] = float(host_copy.item()) if is_scalar else host_copy.tolist()
    return snapshot


def flatten(lst):
    """
    Lazily yield the leaves of an arbitrarily nested list, left to right.

    Only ``list`` instances are descended into; any other item (tuples
    included) is yielded as-is.
    """
    # Stack of iterators: the top one is the list currently being walked.
    pending = [iter(lst)]
    while pending:
        for element in pending[-1]:
            if isinstance(element, list):
                # Descend; the outer iterator resumes once this one is exhausted.
                pending.append(iter(element))
                break
            yield element
        else:
            pending.pop()


def get_negative_parameters(param_dict: Dict[str, object]) -> Dict[str, object]:
    """
    Return the entries of ``param_dict`` that contain negative numbers.

    A negative scalar is reported unchanged. For a list value, the nested
    structure is flattened and the negative leaves are reported as a flat
    list; lists without negative leaves are omitted.
    """

    def _is_negative_number(value):
        return isinstance(value, (int, float)) and value < 0

    negatives = {}
    for key, value in param_dict.items():
        if isinstance(value, list):
            below_zero = [leaf for leaf in flatten(value) if _is_negative_number(leaf)]
            if below_zero:
                negatives[key] = below_zero
        elif _is_negative_number(value):
            negatives[key] = value
    return negatives


# ---------------------------
# Data loading
# ---------------------------
@dataclass
class PNASLocationPaths:
    """File-system layout of the data directory for one location."""

    # Directory that holds every data sub-folder for the location.
    location_root: str

    def _under(self, *parts: str) -> str:
        """Join ``parts`` onto ``location_root``."""
        return os.path.join(self.location_root, *parts)

    @property
    def epi_csv(self) -> str:
        """Path of the epidemiological time-series CSV."""
        return self._under("epi-data", "epi_data.csv")

    @property
    def mob_csv(self) -> str:
        """Path of the Google mobility report CSV."""
        return self._under("google-mobility-report", "google_mobility_data.csv")

    @property
    def pop_csv(self) -> str:
        """Path of the population CSV."""
        return self._under("population-data", "pop_data_Nk.csv")

    @property
    def hemi_csv(self) -> str:
        """Path of the hemisphere CSV."""
        return self._under("hemisphere", "hemisphere.csv")

    @property
    def cm_dir(self) -> str:
        """Directory that contains the contact-matrix ``.npz`` files."""
        return self._under("contact_matrix")

    def cm_npz(self, name: str) -> str:
        """Path of the contact-matrix archive called ``name`` (without extension)."""
        return os.path.join(self.cm_dir, f"{name}.npz")


def load_population_Nk(pop_csv: str) -> np.ndarray:
    """
    Read the population vector from the first data row of ``pop_csv``.

    Values are returned as float64 and floored at 1.0. The file is
    expected to hold one row with one column per age group.
    """
    first_row = pd.read_csv(pop_csv).iloc[0]
    counts = first_row.values.astype(np.float64)
    # Floor at one person per group.
    return np.maximum(counts, 1.0)


def load_hemisphere_code(hemi_csv: str) -> int:
    """
    Return the integer in the ``hemisphere_code`` column of the first row.

    The file is expected to provide the columns ``hemisphere_code`` and
    ``hemisphere``; only the former is read.
    """
    first_row = pd.read_csv(hemi_csv).iloc[0]
    return int(first_row["hemisphere_code"])


def load_epi(epi_csv: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Load the epidemiological series from ``epi_csv``, ordered by date.

    Returns ``(dates, new_deaths, new_cases)``. Missing deaths/cases are
    replaced by 0.0 and both series are returned as float arrays.
    """
    frame = pd.read_csv(epi_csv)
    frame["date"] = pd.to_datetime(frame["date"])
    frame = frame.sort_values("date")

    def _filled(column):
        # Treat missing observations as zero.
        return frame[column].fillna(0.0).astype(float).values

    deaths = _filled("new_deaths")
    cases = _filled("new_cases")
    return frame["date"].values, deaths, cases


def _read_first_npz_array(npz_path: str) -> np.ndarray:
    """Return the first array stored in ``npz_path`` as float64, closing the archive."""
    # np.load on an .npz keeps the file open until the NpzFile is closed,
    # so read inside a context manager and materialise the array first.
    with np.load(npz_path) as archive:
        return archive[archive.files[0]].astype(np.float64)


def load_contact_matrix(cm_dir: str) -> np.ndarray:
    """
    Load the contact matrix stored under ``cm_dir``.

    Prefers ``contact_matrix.npz`` when present; otherwise sums whichever of
    ``home``, ``work``, ``school`` and ``community`` ``.npz`` files exist.
    Each archive is assumed to hold the matrix as its first array (commonly
    under the key ``arr_0``).

    Raises:
        FileNotFoundError: if none of the expected archives exist.
    """
    agg_path = os.path.join(cm_dir, "contact_matrix.npz")
    if os.path.exists(agg_path):
        return _read_first_npz_array(agg_path)

    # Missing setting-specific files are skipped rather than treated as errors.
    candidates = (os.path.join(cm_dir, f"{name}.npz") for name in ("home", "work", "school", "community"))
    parts = [_read_first_npz_array(p) for p in candidates if os.path.exists(p)]

    if not parts:
        raise FileNotFoundError(f"No contact matrices found in {cm_dir}")
    return np.sum(parts, axis=0)


def compute_contact_reduction(mob_csv: str, dates: np.ndarray) -> np.ndarray:
    """
    Contact-reduction factor per date: r(t) = (1 + mean(mobility_cols)/100)^2.

    Five Google mobility categories are averaged (parks is excluded). Gaps
    in the mobility data are forward-filled; anything still missing,
    including dates before the first mobility record, counts as 0 percent
    change. The result is aligned to ``dates``.

    Raises:
        ValueError: if any of the five mobility columns is absent.
    """
    wanted = [
        "retail_and_recreation_percent_change_from_baseline",
        "grocery_and_pharmacy_percent_change_from_baseline",
        "transit_stations_percent_change_from_baseline",
        "workplaces_percent_change_from_baseline",
        "residential_percent_change_from_baseline",
    ]

    table = pd.read_csv(mob_csv)
    table["date"] = pd.to_datetime(table["date"])
    table = table.sort_values("date").set_index("date")

    absent = [column for column in wanted if column not in table.columns]
    if absent:
        raise ValueError(f"Missing mobility columns: {absent}")

    # Fill gaps on the mobility calendar first, then again after aligning
    # to the requested dates.
    signal = table[wanted].ffill().fillna(0.0)
    aligned = signal.reindex(pd.to_datetime(dates)).ffill().fillna(0.0)

    mean_pct_change = aligned.mean(axis=1).values.astype(np.float64)
    return (1.0 + mean_pct_change / 100.0) ** 2

def load_ifr(constants_path: str = None) -> np.ndarray:
    """
    Load the age-specific infection fatality ratios as a 1-D float64 array.

    Sources, by ``constants_path``:
      * ``None``      -> ``IFR_10age`` from ``libs.covasim.constants``
      * ``*.json``    -> the JSON document, converted to an array
      * anything else -> read as CSV; the first data row is used

    Whatever the source, the result is flattened to one dimension.
    """
    if constants_path is not None:
        if constants_path.endswith(".json"):
            with open(constants_path, "r") as handle:
                values = np.asarray(json.load(handle), dtype=np.float64)
        else:
            values = pd.read_csv(constants_path).iloc[0].values.astype(np.float64)
    else:
        # Imported lazily so the project constants module is only needed
        # when no explicit file is supplied (adjust the path if needed).
        from libs.covasim.constants import IFR_10age
        values = np.asarray(IFR_10age, dtype=np.float64)

    return values if values.ndim == 1 else values.reshape(-1)

def apply_seasonality(day: datetime, seasonality_min: float, hemisphere_code: int, seasonality_max: float = 1.0) -> float:
    """
    Seasonal multiplier for ``day`` (copy of the utils.apply_seasonality logic).

    ``hemisphere_code`` indexes ``[north, tropical, south]`` (0, 1, 2):
    north is a sinusoid peaking on 15 January, south one peaking on 15 July,
    and tropical is constant 1.0. The sinusoids range between
    ``seasonality_min / seasonality_max`` and 1.
    """
    ratio = seasonality_min / seasonality_max

    def _annual_wave(peak_day):
        # Cosine-shaped yearly cycle that equals 1 on ``peak_day``.
        offset_days = (day - peak_day).days
        return 0.5 * ((1 - ratio) * np.sin(2 * np.pi / 365 * offset_days + 0.5 * np.pi) + 1 + ratio)

    by_code = [
        _annual_wave(datetime(day.year, 1, 15)),  # north
        1.0,                                      # tropical
        _annual_wave(datetime(day.year, 7, 15)),  # south
    ]
    return by_code[hemisphere_code]

def compute_seasonality_vector(dates: np.ndarray, hemisphere_code: int, seasonality_min: float) -> np.ndarray:
    """Evaluate ``apply_seasonality`` for every entry of ``dates`` and return a float64 array."""
    factors = []
    for stamp in pd.to_datetime(dates):
        factors.append(apply_seasonality(stamp.to_pydatetime(), seasonality_min, hemisphere_code))
    return np.array(factors, dtype=np.float64)

def compute_lambda_max_C_hat(C: np.ndarray, Nk: np.ndarray) -> float:
    """
    Largest real part among the eigenvalues of the population-weighted
    contact matrix (the PNAS ``get_beta`` spectral step):

      C_hat[i, j] = (Nk[i] / Nk[j]) * C[i, j]
      lambda_max  = max(real(eigvals(C_hat)))
    """
    population = Nk.astype(np.float64)
    contacts = C.astype(np.float64)

    # The small epsilon keeps the ratio finite if a group has zero population.
    ratio = population[:, np.newaxis] / (population[np.newaxis, :] + 1e-12)
    spectrum = np.linalg.eigvals(ratio * contacts)
    return float(np.max(np.real(spectrum)))

def make_data_bundle(base_dir: str, location: str, seasonality_min: float = 0.75) -> Dict[str, torch.Tensor]:
    """
    Assemble every input the simulator needs for ``location``.

    Data are read from ``<base_dir>/<location>``. The returned dict holds
    tensors on the module ``device`` for all keys except ``"dates"``, which
    stays the date array returned by ``load_epi``. ``y_deaths`` and
    ``y_cases`` are the raw series from the epi CSV; they are not divided
    by the total population.
    """
    paths = PNASLocationPaths(location_root=os.path.join(base_dir, location))

    dates, deaths, cases = load_epi(paths.epi_csv)
    age_pop = load_population_Nk(paths.pop_csv)
    hemi_code = load_hemisphere_code(paths.hemi_csv)  # should be 0/1/2 per PNAS utils
    contacts = load_contact_matrix(paths.cm_dir)

    mobility_factor = compute_contact_reduction(paths.mob_csv, dates)
    season_factor = compute_seasonality_vector(dates, hemi_code, seasonality_min)
    lam = compute_lambda_max_C_hat(contacts, age_pop)
    ifr = load_ifr()

    total_pop = float(np.sum(age_pop))

    def _scalar(value):
        # Zero-dimensional tensor on the shared device.
        return torch.tensor(value, device=device)

    return {
        "dates": dates,
        "C": _to_tensor(contacts),
        "Nk": _to_tensor(age_pop),
        "hemisphere_code": _scalar(int(hemi_code)),
        "r_mobility": _to_tensor(mobility_factor),
        "seasonality": _to_tensor(season_factor),
        "lambda_max": _scalar(lam),
        "ifr": _to_tensor(ifr),
        "y_deaths": _to_tensor(deaths),
        "y_cases": _to_tensor(cases),
        "Ntot": _scalar(total_pop),
    }

class CalibrationLoss(nn.Module):
    """
    Loss used to calibrate epidemic models against observed deaths.

    ``loss_type`` selects the error measure: ``"mse"``, ``"mae"``,
    ``"rmse"`` or ``"log_mse"`` (MSE between ``log1p`` of the inputs).
    The result is scaled by ``weights["deaths"]``. Case inputs are accepted
    by ``forward`` but do not contribute to the returned value.
    """

    def __init__(
        self,
        loss_type: str = "mse",
        weights: Dict[str, float] = None
    ):
        super().__init__()
        self.loss_type = loss_type
        self.weights = weights or {"deaths": 1.0, "cases": 0.5}

    def forward(
        self,
        pred_deaths: torch.Tensor,
        true_deaths: torch.Tensor,
        pred_cases: torch.Tensor = None,
        true_cases: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Return the weighted deaths loss.

        Args:
            pred_deaths: Predicted daily deaths [T]
            true_deaths: Observed daily deaths [T]
            pred_cases: Predicted daily cases [T] (optional, unused)
            true_cases: Observed daily cases [T] (optional, unused)

        Raises:
            ValueError: if ``loss_type`` is not one of the supported names.
        """
        kind = self.loss_type

        if kind == "log_mse":
            error = torch.log1p(pred_deaths) - torch.log1p(true_deaths)
        elif kind in ("mse", "mae", "rmse"):
            error = pred_deaths - true_deaths
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")

        if kind == "mae":
            deaths_term = torch.mean(torch.abs(error))
        else:
            deaths_term = torch.mean(error ** 2)
            if kind == "rmse":
                deaths_term = torch.sqrt(deaths_term)

        return 0.0 + self.weights.get("deaths", 1) * deaths_term
	
def split_bundle(bundle: Dict[str, torch.Tensor], train_frac=0.98, val_frac=0.01):
    """
    Split a data bundle chronologically into (train, val, test) bundles.

    The length ``T`` is taken from ``bundle["y_deaths"]``. Train and
    validation each aim for at least 10 steps, but are capped so that the
    later splits are not squeezed out on short series. Every time-indexed
    entry that is present (``y_deaths``, ``y_cases``, ``r_mobility``,
    ``seasonality``, ``dates``) is sliced to the split's window and copied;
    all other entries are shared by reference with ``bundle``.

    Args:
        bundle: data bundle; must contain ``"y_deaths"``.
        train_frac: fraction of ``T`` targeted for training.
        val_frac: fraction of ``T`` targeted for validation.

    Returns:
        Tuple ``(train, val, test)`` of new dicts; ``bundle`` is not modified.
    """
    T = int(bundle["y_deaths"].shape[0])
    n_train = max(10, int(train_frac * T))
    n_val = max(10, int(val_frac * T))
    # Cap the sizes for short series and never let them go negative.
    n_train = max(0, min(n_train, T - 2))
    n_val = max(0, min(n_val, T - n_train - 1))

    idx_train = slice(0, n_train)
    idx_val = slice(n_train, n_train + n_val)
    idx_test = slice(n_train + n_val, T)

    # All per-time-step series must be windowed together; otherwise the
    # covariates of the val/test bundles would start at t=0 while their
    # targets start later.
    time_keys = ("y_deaths", "y_cases", "r_mobility", "seasonality", "dates")

    def sub(idx):
        out = dict(bundle)
        for key in time_keys:
            if key not in bundle:
                continue
            window = bundle[key][idx]
            # Tensors are cloned, arrays copied, so splits never alias the source.
            out[key] = window.clone() if torch.is_tensor(window) else window.copy()
        return out

    return sub(idx_train), sub(idx_val), sub(idx_test)


	
def load_data_withAggregation(
    location: str,
    base_dir: str = "/project/biocomplexity/hht9zt/Public-Health-Agent/data",
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    """
    Build the data bundle for ``location`` and split it chronologically.

    Args:
        location: sub-directory of ``base_dir`` holding the location's data.
        base_dir: root data directory; defaults to the project's shared path.

    Returns:
        Tuple ``(train, val, test)`` of bundles. The validation bundle is
        deliberately replaced by the training bundle, so ``val is train``.
    """
    bundle = make_data_bundle(base_dir, location)
    train_b, _, test_b = split_bundle(bundle)
    # Validation reuses the training bundle; the chronological validation
    # slice produced by split_bundle is discarded.
    val_b = train_b
    return train_b, val_b, test_b
	

def load_data(location):
    """
    Load the training bundle for ``location`` aggregated to weekly resolution.

    Deaths and cases are summed per 7-day block; mobility and seasonality
    are averaged per block. Trailing days that do not fill a whole week are
    dropped. When the bundle carries ``"dates"``, it is reduced to the first
    date of each aggregated week so it lines up with the weekly targets.

    Returns:
        The same bundle three times, used as (train, val, test).
    """
    bundle, _, _ = load_data_withAggregation(location)

    def _whole_weeks(x):
        # Drop the trailing partial week so the series reshapes to (-1, 7).
        return x[: (len(x) // 7) * 7]

    def weekly_sum(x):
        return _whole_weeks(x).reshape(-1, 7).sum(axis=1)

    def weekly_mean(x):
        return _whole_weeks(x).reshape(-1, 7).mean(axis=1)

    # Number of complete weeks in the targets, taken before they are aggregated.
    n_weeks = len(bundle["y_deaths"]) // 7

    bundle["y_deaths"] = weekly_sum(bundle["y_deaths"])
    bundle["y_cases"] = weekly_sum(bundle["y_cases"])
    bundle["r_mobility"] = weekly_mean(bundle["r_mobility"])
    bundle["seasonality"] = weekly_mean(bundle["seasonality"])

    if "dates" in bundle:
        # Keep one date per week (the week's first day); otherwise consumers
        # would pair weekly values with the first N daily dates.
        bundle["dates"] = bundle["dates"][: n_weeks * 7 : 7]

    return bundle, bundle, bundle



import matplotlib.pyplot as plt

def plot_pred_vs_gt(
    dates,
    y_true,
    y_pred,
    title,
    ylabel,
    out_path,
):
    """
    Plot ground truth against prediction and save the figure to ``out_path``.

    Args:
        dates: array-like x-axis values (datetimes or an index).
        y_true, y_pred: 1D numpy arrays or tensors; tensors are detached and
            moved to the CPU before plotting.
        title: figure title.
        ylabel: y-axis label.
        out_path: destination image path.
    """
    if torch.is_tensor(y_true):
        y_true = y_true.detach().cpu().numpy()
    if torch.is_tensor(y_pred):
        y_pred = y_pred.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        ax.plot(dates, y_true, label="GT", linewidth=2)
        ax.plot(dates, y_pred, label="Prediction", linestyle="--")
        ax.set_xlabel("Time")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        fig.tight_layout()
        fig.savefig(out_path, dpi=200)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


# ---------------------------
# PNASEnv: training + evaluation
# ---------------------------
class PNASEnv:
    """
    EPITWIN-compatible environment for PNAS behavioral baselines.

    Simulator contract:
      StateDifferential(T, data_bundle, baseline=...)
        -> forward() returns (pred_deaths[T], pred_cases[T])

    Plots and a JSON summary of every run are written under ``save_dir``.
    """

    def __init__(self, save_dir: str = "./pnas_runs"):
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

    def evaluate_simulator_code_wrapper(self, StateDifferential, train_data, val_data, test_data, config=None, logger=None, env_name=''):
        """
        Fit and evaluate ``StateDifferential`` with the configured optimizer.

        Only ``config.run.optimizer == 'pytorch'`` is supported.

        Returns:
            ``(train_loss, val_loss, optimized_parameters, loss_per_dim_dict,
            test_loss, sc_output)`` where ``loss_per_dim_dict`` wraps the
            scalar training loss under the key ``'infected'``.

        Raises:
            ValueError: if the configured optimizer is not ``'pytorch'``.
        """
        if config.run.optimizer != 'pytorch':
            raise ValueError(f"Unsupported optimizer: {config.run.optimizer!r}")

        train_loss, val_loss, optimized_parameters, loss_per_dim, test_loss, sc_output = (
            self.evaluate_simulator_code_using_pytorch(
                StateDifferential, train_data, val_data, test_data,
                config=config, logger=logger, env_name=env_name,
            )
        )
        print(loss_per_dim)
        # Built unconditionally so the return below is defined for every env_name.
        loss_per_dim_dict = {'infected': loss_per_dim}
        return train_loss, val_loss, optimized_parameters, loss_per_dim_dict, test_loss, sc_output

    def evaluate_simulator_code_using_pytorch(
        self,
        StateDifferential,
        train_b,
        val_b,
        test_b,
        config=None,
        logger=None,
        env_name=''
    ):
        """
        Train ``StateDifferential`` with Adam on ``train_b`` and evaluate it
        under every baseline in ``PNAS_BASELINES``.

        Reads ``config.run.pytorch_as_optimizer.epochs``,
        ``config.run.pytorch_as_optimizer.log_interval`` and
        ``config.run.optimization.patience``. Validation runs every
        ``log_interval`` epochs; the best-validation weights are restored
        before evaluation. Plots and a JSON summary go to ``self.save_dir``.

        Returns:
            ``(train_loss, best_val_loss, parameters, train_loss, test_loss,
            results)``. ``train_loss`` is the last training loss computed,
            ``parameters`` are those of the trained model, ``test_loss`` is
            the loss of the last baseline evaluated and ``results`` maps each
            baseline name to its metrics.
        """
        location = 'new_york'
        baseline_train = "DDB"
        optimize_params = True
        epochs = config.run.pytorch_as_optimizer.epochs
        grad_clip = 10.0
        patience = config.run.optimization.patience
        T_train = int(train_b["y_deaths"].shape[0])

        model = StateDifferential(T_train, train_b, baseline=baseline_train).to(device)
        model.train()

        opt = torch.optim.Adam(
            model.parameters(),
            lr=5e-3,          # IMPORTANT: higher than default
            weight_decay=1e-4,
        )

        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt,
            mode="min",
            factor=0.5,
            patience=20,
        )

        loss_fn = CalibrationLoss()

        def compute_loss(m, b):
            # Keyword arguments keep predictions and observations paired the
            # way CalibrationLoss.forward declares them.
            predD, predC = m()
            yD, yC = b["y_deaths"], b["y_cases"]
            return loss_fn(
                pred_deaths=predD[: int(yD.shape[0])],
                true_deaths=yD,
                pred_cases=predC[: int(yC.shape[0])],
                true_cases=yC,
            )

        def run_eval(m, b):
            m.eval()
            with torch.no_grad():
                value = compute_loss(m, b).item()
            m.train()
            return value

        best = None
        best_val = float("inf")
        bad = 0
        prev_loss = 0.0
        last_train_loss = None
        if optimize_params:
            for ep in range(epochs):
                opt.zero_grad(set_to_none=True)
                L = compute_loss(model, train_b)
                # Plain float so the previous epoch's graph is not kept alive.
                loss_value = float(L.item())
                last_train_loss = loss_value
                # Stop when the loss stalls exactly or diverges to NaN.
                if loss_value == prev_loss or math.isnan(loss_value):
                    break
                prev_loss = loss_value
                L.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
                opt.step()
                sch.step(loss_value)

                if ep % config.run.pytorch_as_optimizer.log_interval == 0:
                    valL = run_eval(model, val_b)
                    print(f"[{location} | {baseline_train}] ep={ep:04d} train={loss_value:.6f} val={valL:.6f}")

                    if valL < best_val:
                        best_val = valL
                        best = deepcopy(model.state_dict())
                        bad = 0
                    else:
                        bad += 1
                        if bad >= patience:
                            print(f"Early stop at ep={ep} (best val={best_val:.6f})")
                            break

        if best is not None:
            model.load_state_dict(best)
        model.eval()

        if last_train_loss is None:
            # No epoch ran; report the loss of the untrained model instead.
            with torch.no_grad():
                last_train_loss = float(compute_loss(model, train_b).item())

        # One timestamp for all artifacts of this run.
        ts = datetime.now().strftime("%Y%m%d-%H%M%S")

        with torch.no_grad():
            predD_tr, predC_tr = model()
        yD_tr = train_b["y_deaths"]
        yC_tr = train_b["y_cases"]

        Ttr = int(yD_tr.shape[0])
        predD_tr = predD_tr[:Ttr]
        predC_tr = predC_tr[:Ttr]

        dates_tr = train_b["dates"][:Ttr] if "dates" in train_b else np.arange(Ttr)

        plot_pred_vs_gt(
            dates_tr,
            yD_tr,
            predD_tr,
            title=f"{location} | Train Deaths",
            ylabel="Deaths (normalized)",
            out_path=os.path.join(self.save_dir, f"{location}_{ts}_train_deaths.png"),
        )

        plot_pred_vs_gt(
            dates_tr,
            yC_tr,
            predC_tr,
            title=f"{location} | Train Cases",
            ylabel="Cases (normalized)",
            out_path=os.path.join(self.save_dir, f"{location}_{ts}_train_cases.png"),
        )

        # Evaluate the trained weights under every baseline.
        # NOTE(review): this evaluation runs on train_b; test_b is accepted
        # but never used. Confirm whether that is intended.
        results = {}
        test_loss = float("nan")
        for bl in PNAS_BASELINES:
            # Fresh simulator per baseline, loaded with the trained weights.
            test_model = StateDifferential(T_train, train_b, baseline=bl).to(device)
            test_model.load_state_dict(model.state_dict(), strict=False)
            test_model.eval()

            with torch.no_grad():
                predD, predC = test_model()
                yD, yC = train_b["y_deaths"], train_b["y_cases"]
                horizon = int(yD.shape[0])
                predD = predD[:horizon]
                predC = predC[:horizon]
                test_loss = loss_fn(
                    pred_deaths=predD,
                    true_deaths=yD,
                    pred_cases=predC,
                    true_cases=yC,
                ).item()
                dates_te = train_b["dates"][:horizon] if "dates" in train_b else np.arange(horizon)

                plot_pred_vs_gt(
                    dates_te,
                    yD,
                    predD,
                    title=f"{location} | Test Deaths | {bl}",
                    ylabel="Deaths (normalized)",
                    out_path=os.path.join(self.save_dir, f"{location}_{ts}_test_deaths_{bl}.png"),
                )

                plot_pred_vs_gt(
                    dates_te,
                    yC,
                    predC,
                    title=f"{location} | Test Cases | {bl}",
                    ylabel="Cases (normalized)",
                    out_path=os.path.join(self.save_dir, f"{location}_{ts}_test_cases_{bl}.png"),
                )

                # simple summary metrics
                peakD = float(predD.max().item())
                peakD_t = int(predD.argmax().item())
                peakC = float(predC.max().item())
                peakC_t = int(predC.argmax().item())

                neg = get_negative_parameters(get_model_parameters(test_model))

            results[bl] = {
                "test_loss": test_loss,
                "peak_deaths": peakD,
                "peak_deaths_day": peakD_t,
                "peak_cases": peakC,
                "peak_cases_day": peakC_t,
                "negative_parameters": neg,
            }

        # Save artifacts
        trained_parameters = get_model_parameters(model)
        out = {
            "location": location,
            "trained_baseline": baseline_train,
            "best_val_loss": best_val,
            "results": results,
            "parameters": trained_parameters,
        }
        with open(os.path.join(self.save_dir, f"{location}_{ts}_summary.json"), "w", encoding="utf-8") as f:
            json.dump(out, f, indent=2)

        # train_loss, val_loss, optimized_parameters, loss_per_dim, test_loss, sc_output
        return last_train_loss, best_val, trained_parameters, last_train_loss, test_loss, results