
import os, re, json, math, random, time
from dataclasses import dataclass
from typing import Dict, List, Any, Tuple, Optional
from datetime import datetime
import numpy as np

import torch

@dataclass
class CFG:
    """Run-wide configuration shared by every phase (instantiated once as ``cfg``).

    Field order defines the positional ``__init__`` signature generated by
    ``@dataclass``; keep it stable.
    """

    # Directory holding the score/date/news input files.
    DATA_DIR: str = "."
    # JSONL outputs for the chronological train/val/test splits (Phase 5).
    TRAIN_OUT: str = "train.txt"
    VAL_OUT: str   = "val.txt"
    TEST_OUT: str  = "test.txt"

    # K: number of history points per series in each sliding window.
    # H: horizon; used in the sample-window loop bound (Phase 5).
    K: int = 8
    H: int = 1

    # Hugging Face id used for both the tokenizer and the causal LM.
    MODEL_NAME: str = "gpt2"
    # Training hyperparameters; not referenced in this part of the file.
    BATCH_SIZE: int = 32
    EPOCHS: int = 20
    LR: float = 2e-4
    WEIGHT_DECAY: float = 0.01
    MAX_LEN: int = 1024

    # Not referenced in this part of the file.
    DTW_BAND: int = 2
    PAIRS_PER_BATCH: int = 32

    # Cap on articles per window (sentiment + prompt) and on news-block characters
    # before " ...[TRUNCATED]" is appended.
    MAX_NEWS_ARTICLES_PER_WINDOW: int = 3
    NEWS_MAX_CHARS: int = 2500

    # Logging/printing knobs; not referenced in this part of the file.
    PRINT_EVERY_STEPS: int = 50
    PRINT_ALL_TEST_SAMPLES: bool = True
    PRINT_SOME_TRAIN_SAMPLES: int = 2

    # Seed for python `random`, numpy and torch RNGs.
    SEED: int = 1337
    # Evaluated once, when the class body runs.
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

# Single configuration instance used by all phases below.
cfg = CFG()

print("[CFG]", cfg)

# Input files expected inside cfg.DATA_DIR.
PATH_GOV = os.path.join(cfg.DATA_DIR, "PaperReady_g_scores.txt")
PATH_ENV = os.path.join(cfg.DATA_DIR, "PaperReady_e_scores.txt")
PATH_SOC = os.path.join(cfg.DATA_DIR, "PaperReady_s_scores.txt")
PATH_ESG = os.path.join(cfg.DATA_DIR, "esg_risk_ratings_1.txt")
PATH_DATES = os.path.join(cfg.DATA_DIR, "PaperReady_esg_dates.txt")
# Returns cache: Phase 3 loads it when present, otherwise builds and writes it.
PATH_RET_OUT = os.path.join(cfg.DATA_DIR, "PaperReady_ret.txt")

# Fail fast when a score/date file is missing (the returns cache is optional).
required = [PATH_GOV, PATH_ENV, PATH_SOC, PATH_ESG, PATH_DATES]
for p in required:
    if not os.path.exists(p):
        raise FileNotFoundError(f"[MISSING REQUIRED FILE] {p}")

print("[OK] Found all required score/date files.")

# Seed every RNG used by the pipeline for reproducibility.
random.seed(cfg.SEED)
np.random.seed(cfg.SEED)
torch.manual_seed(cfg.SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(cfg.SEED)

def now_str():
    """Return the current local time as ``YYYY-MM-DD HH:MM:SS``."""
    current = datetime.now()
    return f"{current:%Y-%m-%d %H:%M:%S}"

def log(msg: str):
    """Print *msg* prefixed by a bracketed local timestamp."""
    stamp = now_str()
    print(f"[{stamp}] {msg}")

# Setup (config, input-file check, RNG seeding) is complete.
log("Phase 0 ready.")

"""# Phase 1"""


def parse_ddmmyyyy(s: str) -> datetime:
    """Parse a ``DD-MM-YYYY`` string into a naive ``datetime``.

    Raises:
        ValueError: if *s* does not match the format.
    """
    day_month_year = "%d-%m-%Y"
    return datetime.strptime(s, day_month_year)

import ast
from typing import Dict, List

def load_esg_dates(path_dates: str) -> Dict[str, List[str]]:
    """Load the ticker -> date-list mapping from a python dict literal file.

    The file holds a single literal such as ``{ 'A': [...], 'AAPL': [...] }``
    and is parsed with ``ast.literal_eval`` (no code execution).

    Raises:
        ValueError: if the literal is not a dict, or an entry is not ``str -> list``.
    """
    with open(path_dates, "r", encoding="utf-8") as fh:
        parsed = ast.literal_eval(fh.read().strip())

    if not isinstance(parsed, dict):
        raise ValueError("PaperReady_esg_dates.txt must be a dict literal like { 'A': [...], ... }")

    for key, value in parsed.items():
        if isinstance(key, str) and isinstance(value, list):
            continue
        raise ValueError(f"Bad entry for {key}: expected str -> list[str]")
    return parsed


def parse_series_file_as_records(path: str) -> List[Dict[str, str]]:
    """Read a python-like list of ``{'text': ...}`` records from *path*.

    Expected file shape::

        [
            {'text': 'Company: AAPL ENV: 74.00 ... <EOS>'},
            ...
        ]

    The file is first parsed with ``ast.literal_eval``, which understands
    both quote styles and escape sequences. Python's ``repr`` switches to
    double quotes for strings containing an apostrophe, so one file can mix
    styles. If the file is not a valid literal, a regex scan accepting
    single- or double-quoted entries (in file order) is used instead.

    Returns:
        list of ``{"text": str}`` dicts in file order; empty if nothing matched.
    """
    with open(path, "r", encoding="utf-8") as f:
        raw = f.read()

    try:
        parsed = ast.literal_eval(raw.strip())
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        parsed = None

    if isinstance(parsed, (list, tuple)):
        recs = [
            {"text": str(item["text"])}
            for item in parsed
            if isinstance(item, dict) and "text" in item
        ]
        if recs:
            return recs

    # Fallback for files that are not a clean literal: match either quote style.
    pattern = re.compile(r"""\{\s*['"]text['"]\s*:\s*(?:'([^']*)'|"([^"]*)")\s*\}""")
    return [
        {"text": m.group(1) if m.group(1) is not None else m.group(2)}
        for m in pattern.finditer(raw)
    ]

def parse_company_series_from_records(records: List[Dict[str,str]], series_tag: str) -> Dict[str, List[float]]:
    """Extract ``ticker -> [values]`` for one series from text records.

    Each record's ``text`` looks like ``"Company: AAPL ENV: 74.00 74.00 ... <EOS>"``.
    The ticker follows ``Company:``; the values are every whitespace-separated
    token after ``<series_tag>:`` (to end of line) that parses as a float.
    ``<EOS>``/``</EOS>`` markers and non-numeric tokens are ignored.

    Args:
        records: dicts with a ``"text"`` key.
        series_tag: "ESG", "ENV", "SOC" or "GOV" (matched literally).

    Returns:
        mapping ticker -> list of floats. Records without a ticker, without the
        tag, or with no numeric values are skipped; a later record for the same
        ticker overwrites an earlier one.
    """
    out: Dict[str, List[float]] = {}
    ticker_re = re.compile(r"Company:\s*([A-Z0-9\.\-]+)\s+")
    # re.escape keeps the tag literal even if it ever contains metacharacters.
    series_re = re.compile(rf"\b{re.escape(series_tag)}\s*:\s*(.*)")

    for r in records:
        t = r["text"]

        m = ticker_re.search(t)
        if not m:
            continue
        ticker = m.group(1).strip()

        m2 = series_re.search(t)
        if not m2:
            continue
        tail = m2.group(1).replace("<EOS>", " ").replace("</EOS>", " ")

        vals = []
        for p in tail.split():
            try:
                vals.append(float(p))
            except ValueError:
                # Non-numeric token (e.g. a stray label): skip it.
                continue

        if vals:
            out[ticker] = vals
    return out


# Ticker -> list of "DD-MM-YYYY" strings. Its sorted keys define the ticker
# universe (`core`) iterated by later phases.
esg_dates = load_esg_dates(PATH_DATES)

core = sorted(esg_dates.keys())
# NOTE: the example logs assume at least one ticker with a non-empty date list.
log(f"[DATES] tickers={len(core)}  example={core[:5]}")
log(f"[DATES] example {core[0]} n_dates={len(esg_dates[core[0]])} first={esg_dates[core[0]][0]} last={esg_dates[core[0]][-1]}")

# Raw {"text": ...} records for each score file.
gov_records = parse_series_file_as_records(PATH_GOV)
env_records = parse_series_file_as_records(PATH_ENV)
soc_records = parse_series_file_as_records(PATH_SOC)
esg_records = parse_series_file_as_records(PATH_ESG)

# Ticker -> list of float scores, one map per series.
gov_series = parse_company_series_from_records(gov_records, "GOV")
env_series = parse_company_series_from_records(env_records, "ENV")
soc_series = parse_company_series_from_records(soc_records, "SOC")
esg_series = parse_company_series_from_records(esg_records, "ESG")

log(f"[SERIES] ESG={len(esg_series)} ENV={len(env_series)} SOC={len(soc_series)} GOV={len(gov_series)}")

def align_length_safe(ticker: str, dates_list: List[str], series_map: Dict[str, List[float]], name: str) -> List[float]:
    """Return *ticker*'s series from *series_map*, cut to at most ``len(dates_list)`` points.

    - Ticker absent from *series_map*: logs a warning and returns ``[]``.
    - Length differs from the date list: logs a warning and returns the series
      trimmed to the shorter of the two lengths.
    """
    if ticker in series_map:
        values = series_map[ticker]
        n_dates = len(dates_list)
        keep = min(len(values), n_dates)
        if len(values) != n_dates:
            log(f"[WARN] length mismatch {ticker} {name}: series={len(values)} dates={n_dates} -> trim to {keep}")
        return values[:keep]

    log(f"[WARN] {name} missing for {ticker} -> will skip where needed")
    return []

# Per-ticker aligned view. Every available series is cut to a common length L:
# the shortest of the date list and each non-empty series. Missing series stay [].
aligned = {}
for t in core:
    dlist = esg_dates[t]
    esg = align_length_safe(t, dlist, esg_series, "ESG")
    env = align_length_safe(t, dlist, env_series, "ENV")
    soc = align_length_safe(t, dlist, soc_series, "SOC")
    gov = align_length_safe(t, dlist, gov_series, "GOV")

    # Empty (missing) series do not constrain L.
    lengths = [len(dlist)]
    for arr in [esg, env, soc, gov]:
        if len(arr) > 0:
            lengths.append(len(arr))
    L = min(lengths)

    aligned[t] = {
        "dates": dlist[:L],
        "ESG": esg[:L] if len(esg)>0 else [],
        "ENV": env[:L] if len(env)>0 else [],
        "SOC": soc[:L] if len(soc)>0 else [],
        "GOV": gov[:L] if len(gov)>0 else [],
        "L": L,
    }

log(f"[ALIGN] built aligned dict for {len(aligned)} tickers.")
ex = core[0]
log(f"[ALIGN EX] {ex} L={aligned[ex]['L']}  ESG?={len(aligned[ex]['ESG'])} ENV?={len(aligned[ex]['ENV'])} SOC?={len(aligned[ex]['SOC'])} GOV?={len(aligned[ex]['GOV'])}")
log("Phase 1 ready.")

"""# Phase 2"""


import glob
from transformers import pipeline

def load_news_file_xmlish(path: str) -> List[Dict[str, Any]]:
    """Parse an XML-like news file into article dicts, sorted by date.

    Expected structure (tolerant, regex-based; tags are case-insensitive)::

        <company name="AAPL">
        <news>
          <date>22-01-2016</date>
          <data>...</data>
          <source>Newsroom</source>
        </news>
        ...

    Articles without a ``<date>`` or ``<data>``, with an unparseable
    ``DD-MM-YYYY`` date, or with empty text are skipped. Whitespace runs in
    text and source are collapsed to single spaces. Undecodable bytes are ignored.

    Returns:
        list of ``{"date": datetime, "text": str, "source": str}`` sorted by date.
    """
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        txt = f.read()

    flags = re.DOTALL | re.IGNORECASE
    block_re = re.compile(r"<news>(.*?)</news>", flags)
    date_re = re.compile(r"<date>(.*?)</date>", flags)
    data_re = re.compile(r"<data>(.*?)</data>", flags)
    source_re = re.compile(r"<source>(.*?)</source>", flags)
    ws_re = re.compile(r"\s+")

    items = []
    for b in block_re.findall(txt):
        d = date_re.search(b)
        t = data_re.search(b)
        if not d or not t:
            continue

        text = ws_re.sub(" ", t.group(1).strip())
        if not text:
            continue

        try:
            dt = datetime.strptime(d.group(1).strip(), "%d-%m-%Y")
        except ValueError:
            # Unparseable date: skip this article rather than fail the whole file.
            continue

        s = source_re.search(b)
        source = ws_re.sub(" ", s.group(1).strip()) if s else ""

        items.append({"date": dt, "text": text, "source": source})

    items.sort(key=lambda x: x["date"])
    return items

def load_all_news_from_folder(data_dir: str, tickers: List[str]) -> Dict[str, List[Dict[str, Any]]]:
    """Load ``<data_dir>/<TICKER>.txt`` news for each ticker.

    A ticker whose file is missing, or fails to parse (a warning is logged),
    maps to ``[]``. Only missing files count as "missing" in the summary line.
    """
    news_db: Dict[str, List[Dict[str, Any]]] = {}
    n_missing = 0
    for ticker in tickers:
        news_path = os.path.join(data_dir, f"{ticker}.txt")
        if os.path.exists(news_path):
            try:
                news_db[ticker] = load_news_file_xmlish(news_path)
            except Exception as e:
                log(f"[WARN] failed to parse news for {ticker}: {e}")
                news_db[ticker] = []
        else:
            news_db[ticker] = []
            n_missing += 1
    log(f"[NEWS] loaded for {len(tickers)-n_missing}/{len(tickers)} tickers (missing {n_missing})")
    return news_db

# Ticker -> date-sorted list of {"date", "text", "source"} article dicts.
news_db = load_all_news_from_folder(cfg.DATA_DIR, core)

# Smoke check: show the first ticker (among the first 50) that has local news.
ex_with_news = None
for t in core[:50]:
    if len(news_db.get(t, [])) > 0:
        ex_with_news = t
        break
if ex_with_news:
    log(f"[NEWS EX] {ex_with_news} n_news={len(news_db[ex_with_news])} first={news_db[ex_with_news][0]['date'].strftime('%d-%m-%Y')}")
    log(f"[NEWS EX TEXT] {news_db[ex_with_news][0]['text'][:120]}...")
else:
    log("[NEWS] No local news found in first 50 tickers (that's okay if files not present).")

def fetch_news_from_api(ticker: str, start_dt: datetime, end_dt: datetime) -> List[Dict[str, Any]]:
    """Fallback news source used when a ticker has no local news.

    Placeholder: always returns no articles. A real implementation must return
    ``[{"date": datetime, "text": str, "source": str}, ...]``.
    """
    del ticker, start_dt, end_dt  # unused until an API call is wired in
    articles: List[Dict[str, Any]] = []
    return articles

# Sentiment classifier used to produce one SENTI token per news article.
# device=0 selects the first CUDA device when cfg.DEVICE == "cuda"; -1 means CPU.
SENTIMENT_MODEL = "cardiffnlp/twitter-roberta-base-sentiment-latest"
log(f"[SENTI] loading RoBERTa sentiment model: {SENTIMENT_MODEL}")
sent_pipe = pipeline("sentiment-analysis", model=SENTIMENT_MODEL, device=0 if cfg.DEVICE=="cuda" else -1)

def score_to_senti_token(label: str, score: float) -> str:
    """Map a sentiment-model label to ``<SENTI_10>``, ``<SENTI_0>`` or ``<SENTI_-10>``.

    Accepts named labels (``positive`` / ``neutral`` / ``negative``, any case)
    and index-style labels (``LABEL_2`` positive, ``LABEL_1`` neutral,
    ``LABEL_0`` negative). Anything else, including ``None``, maps to neutral.
    *score* is accepted for interface compatibility but not used.
    """
    lab = (label or "").strip().lower()
    if "pos" in lab or lab.endswith("2"):
        return "<SENTI_10>"
    if "neg" in lab or lab.endswith("0"):
        return "<SENTI_-10>"
    # Same spelling as series_value_to_token("SENTI", 0), so neutral sentiment
    # maps to a single vocabulary token.
    return "<SENTI_0>"

def sentiment_tokens_for_news_items(items: List[Dict[str, Any]], max_items: int) -> List[str]:
    """Score the first *max_items* articles' full text and return one SENTI token each.

    Uses the module-level ``sent_pipe`` (with truncation to the model's max
    length) and ``score_to_senti_token``. Empty/None *items* returns ``[]``.
    """
    if not items:
        return []
    capped = items[:max_items]
    predictions = sent_pipe([article["text"] for article in capped], truncation=True)
    return [
        score_to_senti_token(pred.get("label",""), float(pred.get("score",0.0)))
        for _, pred in zip(capped, predictions)
    ]

def get_news_for_window(ticker: str, start_dt: datetime, end_dt: datetime) -> Tuple[str, List[str]]:
    """Collect news for *ticker* dated within ``[start_dt, end_dt]`` (inclusive).

    Local ``news_db`` articles are used; only when the ticker has none locally is
    ``fetch_news_from_api`` consulted. At most ``cfg.MAX_NEWS_ARTICLES_PER_WINDOW``
    articles are used.

    Returns:
        news_text: one ``- (DD-MM-YYYY) text [src=...]`` line per article, cut to
            ``cfg.NEWS_MAX_CHARS`` characters plus a " ...[TRUNCATED]" marker.
        senti_tokens: one SENTI token per article, scored on the full text
            before any character truncation.
    """
    candidates = news_db.get(ticker, [])
    if not candidates:
        candidates = fetch_news_from_api(ticker, start_dt, end_dt)

    window_items = [article for article in candidates if start_dt <= article["date"] <= end_dt]
    senti_toks = sentiment_tokens_for_news_items(window_items, cfg.MAX_NEWS_ARTICLES_PER_WINDOW)

    def _format_line(article):
        stamp = article["date"].strftime("%d-%m-%Y")
        src = article.get("source","")
        suffix = f" [src={src}]" if src else ""
        return f"- ({stamp}) {article['text']}{suffix}"

    news_text = "\n".join(
        _format_line(article) for article in window_items[:cfg.MAX_NEWS_ARTICLES_PER_WINDOW]
    ).strip()

    limit = cfg.NEWS_MAX_CHARS
    if len(news_text) > limit:
        news_text = news_text[:limit] + " ...[TRUNCATED]"
    return news_text, senti_toks

# Smoke test: news + sentiment for the first ticker over its first window
# (first date through date index min(len-1, K)).
t0 = core[0]
dlist0 = aligned[t0]["dates"]
if len(dlist0) >= 2:
    start_dt = parse_ddmmyyyy(dlist0[0])
    end_dt   = parse_ddmmyyyy(dlist0[min(len(dlist0)-1, cfg.K)])
    txt, toks = get_news_for_window(t0, start_dt, end_dt)
    log(f"[NEWS WINDOW TEST] ticker={t0}  range={dlist0[0]}..{dlist0[min(len(dlist0)-1, cfg.K)]}")
    log(f"[NEWS WINDOW TEST] senti_tokens={toks[:20]} (n={len(toks)})")
    if txt:
        log("[NEWS WINDOW TEST] news_text preview:")
        print(txt[:600])
    else:
        log("[NEWS WINDOW TEST] no news in this window (OK).")

log("Phase 2 ready.")

"""# Phaser 3"""


import pandas as pd

import yfinance as yf

def datestr_month_key(ddmmyyyy: str) -> str:
    """Turn a ``DD-MM-YYYY`` date string into its ``MM-YYYY`` month key."""
    parsed_day = parse_ddmmyyyy(ddmmyyyy)
    return f"{parsed_day:%m-%Y}"

def get_monthly_prices(ticker: str, start: str, end: str):
    """
    Download monthly prices for *ticker* via yfinance and key them by month.

    Uses "Adj Close" when available, otherwise "Close". Handles both flat and
    MultiIndex column layouts ((field, ticker) or (ticker, field)).
    Non-numeric values are dropped.

    Returns:
        pd.Series of prices indexed by "MM-YYYY" strings (last price seen for
        each month), or an empty float Series when no usable data comes back.

    NOTE(review): `groupby(out.index)` sorts its keys by default, so the
    returned index is in lexicographic "MM-YYYY" order (01-2016, 01-2017,
    02-2016, ...), not chronological order. Consumers must not rely on
    positional order.
    """
    df = yf.download(ticker, start=start, end=end, interval="1mo", auto_adjust=False, progress=False)
    if df is None or df.empty:
        return pd.Series(dtype=float)

    def _col(name: str):
        # Return the column for `name` under either column layout, or None.
        if name in df.columns:
            return df[name]
        if isinstance(df.columns, pd.MultiIndex):
            for key in [(name, ticker), (ticker, name)]:
                if key in df.columns:
                    return df[key]
        return None

    s = _col("Adj Close")
    if s is None:
        s = _col("Close")
    if s is None:
        return pd.Series(dtype=float)

    # If a DataFrame came back (e.g. a top-level MultiIndex lookup), squeeze it to a Series.
    if isinstance(s, pd.DataFrame):
        s = s.squeeze("columns")

    s = pd.to_numeric(s, errors="coerce").dropna()
    if s.empty:
        return pd.Series(dtype=float)

    s.index = pd.to_datetime(s.index)
    out = pd.Series(s.values, index=[d.strftime("%m-%Y") for d in s.index])
    # Collapse duplicate month keys, keeping the last price for each month.
    out = out.groupby(out.index).last()
    return out

def compute_monthly_returns(price_by_month: pd.Series):
    """Month-over-month simple returns from prices keyed by ``MM-YYYY``.

    ``ret[m] = price[m] / price[previous month in the series] - 1``. The
    earliest month gets NaN, as does any month whose previous price is zero
    or non-finite (instead of raising ZeroDivisionError).

    Months are processed in chronological order. ``get_monthly_prices`` builds
    its index with ``groupby``, which sorts ``MM-YYYY`` strings
    lexicographically (``01-2016, 01-2017, 02-2016``); iterating that order
    directly would pair unrelated months. If any index label is not
    ``MM-YYYY``, the given order is used unchanged.

    Returns:
        pd.Series of float returns indexed by month key, in chronological
        order; an empty Series for None or empty input.
    """
    if price_by_month is None or len(price_by_month) == 0:
        return pd.Series(dtype=float)

    pairs = list(price_by_month.items())
    try:
        pairs = sorted(pairs, key=lambda kv: datetime.strptime(str(kv[0]), "%m-%Y"))
    except ValueError:
        pass  # labels are not "MM-YYYY": keep the caller's order

    rets = {}
    prev = None
    for m, price in pairs:
        p = float(price)
        if prev is None or prev == 0.0 or not np.isfinite(prev):
            rets[m] = np.nan
        else:
            rets[m] = (p / prev) - 1.0
        prev = p
    return pd.Series(rets)

def align_and_fill_returns(ticker: str, date_list: List[str]) -> List[float]:
    """Return one monthly return per date in *date_list* (by its calendar month).

    Prices are downloaded for the span of the dates padded by 40 days on each
    side. Months without a finite return are forward-filled, then back-filled;
    if nothing is available at all, every value becomes 0.0 (with a warning).
    """
    parsed = [parse_ddmmyyyy(d) for d in date_list]
    pad = pd.Timedelta(days=40)
    start = (min(parsed) - pad).strftime("%Y-%m-%d")
    end   = (max(parsed) + pad).strftime("%Y-%m-%d")

    rets = compute_monthly_returns(get_monthly_prices(ticker, start, end))

    raw = []
    for d in date_list:
        key = datestr_month_key(d)
        value = np.nan if rets is None else float(rets.get(key, np.nan))
        raw.append(value if np.isfinite(value) else np.nan)

    filled = pd.Series(raw, dtype=float).ffill().bfill()

    if filled.isna().any():
        log(f"[WARN] all-missing returns for {ticker} -> setting to 0.0")
        filled = filled.fillna(0.0)

    assert np.isfinite(filled.values).all(), f"RET still NaN/Inf for {ticker}"
    return filled.astype(float).tolist()

def load_ret_file(path: str) -> Dict[str, List[float]]:
    """Read a returns cache written as ``TICKER<TAB>v1 v2 v3 ...`` per line.

    Lines that do not split into exactly two tab-separated fields (blank,
    malformed, or with no values) are skipped. Tickers are whitespace-stripped.
    """
    out = {}
    with open(path, "r", encoding="utf-8") as fh:
        for raw_line in fh:
            fields = raw_line.strip().split("\t")
            if len(fields) != 2:
                continue
            ticker, payload = fields
            out[ticker.strip()] = [float(tok) for tok in payload.split()]
    return out

def save_ret_file(path: str, ret_map: Dict[str, List[float]], tickers: List[str]):
    """Write returns as ``TICKER<TAB>v1 v2 ...`` lines (6 decimals), in *tickers* order.

    Tickers missing from *ret_map* are skipped. The log line reports the number
    of lines actually written, which can be smaller than ``len(ret_map)``.
    """
    lines = [
        t + "\t" + " ".join(f"{x:.6f}" for x in ret_map[t]) + "\n"
        for t in tickers
        if t in ret_map
    ]
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(lines)
    log(f"[SAVED] {path} lines={len(lines)}")

# Monthly returns per ticker aligned to its ESG dates. PATH_RET_OUT is a cache:
# reuse it when present, otherwise download via yfinance and write it.
if os.path.exists(PATH_RET_OUT):
    log(f"[RET] found existing returns file -> loading: {PATH_RET_OUT}")
    returns_series = load_ret_file(PATH_RET_OUT)
else:
    log(f"[RET] returns file missing -> generating via yfinance: {PATH_RET_OUT}")
    returns_series = {}
    for t in core:
        try:
            returns_series[t] = align_and_fill_returns(t, aligned[t]["dates"])
            log(f"[RET OK] {t} n={len(returns_series[t])} first={returns_series[t][0]:.6f} last={returns_series[t][-1]:.6f}")
        except Exception as e:
            # Any failure for this ticker falls back to an all-zero series.
            log(f"[RET FAIL] {t} err={e}")
            L = aligned[t]["L"]
            returns_series[t] = [0.0] * L
            log(f"[RET FALLBACK] {t} -> zeros length {L}")

    save_ret_file(PATH_RET_OUT, returns_series, core)

# Force each ticker's returns to exactly aligned[t]["L"] points:
# missing/empty -> zeros; too long -> trim; too short -> pad with the last value.
for t in core:
    L = aligned[t]["L"]
    arr = returns_series.get(t, [])
    if len(arr) == 0:
        returns_series[t] = [0.0] * L
        log(f"[RET WARN] {t} no returns -> zeros length {L}")
    elif len(arr) != L:
        log(f"[RET WARN] length mismatch {t}: ret={len(arr)} alignedL={L} -> trim/pad")
        if len(arr) > L:
            returns_series[t] = arr[:L]
        else:
            returns_series[t] = arr + [arr[-1]] * (L - len(arr))

# Attach the (ratio, not percent) returns to the aligned per-ticker view.
for t in core:
    aligned[t]["RET"] = returns_series[t][:aligned[t]["L"]]

ex = core[0]
log(f"[RET EX] {ex} L={aligned[ex]['L']}  RET first5={[float(x) for x in aligned[ex]['RET'][:5]]}")

log("Phase 3 ready.")

"""# Phase 4"""


def q2(x: float) -> str:
    """Format *x* with exactly two decimals, e.g. ``'13.40'``, ``'-1.23'``, ``'0.00'``.

    NaN and +/-inf are rendered as ``'0.00'``.
    """
    value = x if np.isfinite(x) else 0.0
    return format(value, ".2f")

def clip(x: float, lo: float, hi: float) -> float:
    """Clamp *x* to ``[lo, hi]`` (lower bound applied first) and return a float."""
    at_least_lo = max(x, lo)
    return float(hi if at_least_lo > hi else at_least_lo)

def ret_to_pct_2dec(r: float) -> float:
    """Convert a return ratio to a percentage clipped to [-100, 100], rounded to 2 decimals.

    Examples: ``0.0123 -> 1.23``, ``-0.20 -> -20.0``, ``5.0 -> 100.0``.
    Non-finite input is treated as 0.
    """
    ratio = float(r) if np.isfinite(r) else 0.0
    pct = float(min(max(100.0 * ratio, -100.0), 100.0))
    # pct is always finite here, so the plain 2-decimal format is enough.
    return float(f"{pct:.2f}")

# Canonical series order; identical to the local `series_keys` block layout in
# generate_blockwise_series_embedding. Not otherwise referenced in the visible code.
SERIES_KEYS = ["RET", "SOC", "GOV", "ESG", "ESGFO", "ESGSO", "ENV", "SENTI"]

def series_value_to_token(series: str, v: float) -> str:
    """Render a numeric series value as its domain token.

    - ESG/ENV/SOC/GOV/ESGFO/ESGSO: clipped to [0, 100], two decimals (``<ESG_61.38>``).
    - RET: ratio converted to a clipped percentage, two decimals (``<RET_1.23>``).
    - SENTI: sign of ``int(v)`` scaled to -10/0/10 (``<SENTI_-10>``).

    Raises:
        ValueError: for any other series name.
    """
    if series == "RET":
        return f"<RET_{q2(ret_to_pct_2dec(float(v)))}>"
    if series == "SENTI":
        whole = int(v)
        sign = (whole > 0) - (whole < 0)
        return f"<SENTI_{sign * 10}>"
    if series in ("ESG", "ENV", "SOC", "GOV", "ESGFO", "ESGSO"):
        return f"<{series}_{q2(clip(float(v), 0.0, 100.0))}>"
    raise ValueError(f"Unknown series: {series}")

def token_to_float(tok: str) -> float:
    """Return the number inside a numeric domain token such as ``<ESG_61.38>`` or ``<RET_-1.23>``.

    Raises:
        ValueError: for anything that is not ``<UPPERCASE_number>`` (e.g. ``<COMPANY>``).
    """
    parsed = re.match(r"^<([A-Z]+)_(-?\d+(?:\.\d+)?)>$", tok)
    if parsed is None:
        raise ValueError(f"Not numeric token: {tok}")
    _prefix, number = parsed.groups()
    return float(number)

def token_prefix(tok: str) -> str:
    """Return the uppercase prefix of a ``<PREFIX_...`` token, or ``""`` if there is none."""
    found = re.match(r"^<([A-Z]+)_", tok)
    return "" if found is None else found.group(1)

# Facet markers introducing each prompt section. build_prompt indexes this list
# positionally (0=NEWS, 1=SENTI, 2=RET, 3..6=ESG/ENV/SOC/GOV), so keep the order.
FACET_TOKENS = [
    "<FACET_NEWS>",
    "<FACET_SENTI>",
    "<FACET_RET>",
    "<FACET_ESG>",
    "<FACET_ENV>",
    "<FACET_SOC>",
    "<FACET_GOV>",
]

# Structural markers, used positionally by build_prompt:
# 0=company line, 1=target-series line, 2=target slot.
STRUCT_TOKENS = [
    "<COMPANY>",
    "<TARGET_SERIES>",
    "<TARGET>",
]

# Prefixes whose <PREFIX_number> tokens are treated as domain tokens (Phase 6).
DOMAIN_PREFIXES = ["ESG","ENV","SOC","GOV","RET","SENTI","ESGFO","ESGSO"]

log("[TOKENS] Phase 4 token helpers ready.")
log(f"[TOKENS] FACET_TOKENS={FACET_TOKENS}")
log(f"[TOKENS] STRUCT_TOKENS={STRUCT_TOKENS}")

# Sanity prints of the token format.
print(series_value_to_token("ESG", 61.38), series_value_to_token("ENV", 0.87))
print(series_value_to_token("RET", 0.0123), series_value_to_token("RET", -0.2))
print(series_value_to_token("SENTI", -10), series_value_to_token("SENTI", 0), series_value_to_token("SENTI", 10))

log("Phase 4 ready.")

"""# Phase 5"""

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Tokenizer for the base model; used to truncate news by token count in prompts.
log(f"[MODEL] Loading base model: {cfg.MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL_NAME)


from collections import defaultdict

# Series that get a prediction sample per window (when present for the ticker).
TARGET_SERIES_LIST = ["ESG", "ENV", "SOC", "GOV"]

def ensure_space_end(s: str) -> str:
    """Append a space to *s* unless it is empty or already ends with whitespace."""
    if s and not s[-1].isspace():
        return s + " "
    return s

def truncate_news_to_tokens(news_text: str, max_tokens: int) -> str:
    """Limit *news_text* to its first *max_tokens* tokenizer tokens.

    Empty/None input is returned unchanged. Text already within the budget is
    returned as-is; otherwise the truncated token ids are decoded back to text.
    """
    if not news_text:
        return news_text
    token_ids = tokenizer(news_text, add_special_tokens=False)["input_ids"]
    if len(token_ids) > max_tokens:
        return tokenizer.decode(token_ids[:max_tokens])
    return news_text


def build_prompt(
    ticker: str,
    start_dt: datetime,
    end_dt: datetime,
    hist_tokens_map: Dict[str, List[str]],
    senti_tokens: List[str],
    news_text: str,
    target_series: str,
    max_news_tokens: int = 50,
) -> str:
    """Build the unified prompt for one sample; every facet is always present.

    Lines, in order: ``<COMPANY> ticker``, ``<FACET_NEWS>``, the news text
    (truncated to *max_news_tokens* tokenizer tokens, or "(no news)"),
    ``<FACET_SENTI>`` + sentiment tokens (or "(none)"), then the RET, ESG,
    ENV, SOC and GOV history facets, ``<TARGET_SERIES> target_series`` and
    a bare ``<TARGET>`` line.

    Args:
        ticker: company ticker.
        start_dt, end_dt: window bounds; not used in the prompt text.
        hist_tokens_map: must provide "RET", "ESG", "ENV", "SOC" and "GOV" token lists.
        senti_tokens: one sentiment token per article.
        news_text: pre-formatted news block ("" when there is none).
        target_series: name of the series being predicted.
        max_news_tokens: token budget for the news block (default 50).

    Returns:
        the newline-joined prompt, ending with a trailing newline.
    """
    lines = [f"{STRUCT_TOKENS[0]} {ticker}", FACET_TOKENS[0]]

    if news_text:
        lines.append(truncate_news_to_tokens(news_text, max_news_tokens))
    else:
        lines.append("(no news)")

    lines.append(f"{FACET_TOKENS[1]} " + (" ".join(senti_tokens) if senti_tokens else "(none)"))
    lines.append(f"{FACET_TOKENS[2]} " + " ".join(hist_tokens_map["RET"]))

    # FACET_TOKENS[3:7] are the ESG/ENV/SOC/GOV facets, in that order.
    for facet, ser in zip(FACET_TOKENS[3:7], ["ESG", "ENV", "SOC", "GOV"]):
        lines.append(f"{facet} " + " ".join(hist_tokens_map[ser]))

    lines.append(f"{STRUCT_TOKENS[1]} {target_series}")
    lines.append(STRUCT_TOKENS[2])

    return "\n".join(lines) + "\n"

def safe_series_present(aligned_t: Dict[str, Any], series: str) -> bool:
    """True when ``aligned_t[series]`` exists, is a list, and is non-empty."""
    if series not in aligned_t:
        return False
    values = aligned_t[series]
    return isinstance(values, list) and len(values) > 0

def build_samples_for_ticker(ticker: str) -> List[Dict[str, Any]]:
    """
    Create all sliding-window samples for *ticker*, one per (window, target series).

    A window covers the K history points ``[t_idx - K, t_idx)``. Its label is
    the point ``cfg.H`` steps after the window's last point, i.e. index
    ``t_idx + cfg.H - 1``. Every window whose label lies inside the series is
    used, so with H=1 the final point of the series is a label too.

    Target series missing for this ticker are skipped. Missing history series
    are rendered as ``<SER_0.00>`` tokens. A missing RET series is replaced by
    zeros, written back into ``aligned[ticker]``.

    Raises:
        ValueError: if cfg.H < 1 (the label would fall inside the history).

    Returns:
        list of sample dicts (ticker, t_index, target_series, text,
        target_token, hist_vals, traj_vals); empty if the series is too short.
    """
    out = []
    A = aligned[ticker]
    L = A["L"]
    horizon = cfg.H
    if horizon < 1:
        raise ValueError(f"cfg.H must be >= 1, got {horizon}")
    # Need K history points plus one label `horizon` steps after them.
    if L < cfg.K + horizon:
        log(f"[SKIP] {ticker} length L={L} < K+H={cfg.K + horizon}")
        return out

    # Trajectory values cover at most the first 60% of the series (but at least K points).
    cut60 = max(cfg.K, int(0.6 * L))
    dt_list = [parse_ddmmyyyy(d) for d in A["dates"]]

    if not safe_series_present(A, "RET"):
        log(f"[WARN] {ticker} missing RET -> will fill with zeros")
        A["RET"] = [0.0] * L

    hist_series = ["ESG", "ENV", "SOC", "GOV"]
    present = {ser: safe_series_present(A, ser) for ser in hist_series}

    for t_idx in range(cfg.K, L - horizon + 1):
        w0 = t_idx - cfg.K
        y_idx = t_idx + horizon - 1  # label index; never inside the history window

        start_dt = dt_list[w0]
        end_dt = dt_list[t_idx - 1]

        news_text, senti_tokens = get_news_for_window(ticker, start_dt, end_dt)

        ret_hist = A["RET"][w0:t_idx]
        hist_tokens_map = {"RET": [series_value_to_token("RET", x) for x in ret_hist]}
        for ser in hist_series:
            if present[ser]:
                hist_tokens_map[ser] = [series_value_to_token(ser, x) for x in A[ser][w0:t_idx]]
            else:
                hist_tokens_map[ser] = [f"<{ser}_0.00>"] * cfg.K

        traj_end = min(cut60, t_idx)

        for target_ser in TARGET_SERIES_LIST:
            if not safe_series_present(A, target_ser):
                continue

            y_tok = series_value_to_token(target_ser, A[target_ser][y_idx])

            prompt = build_prompt(
                ticker=ticker,
                start_dt=start_dt,
                end_dt=end_dt,
                hist_tokens_map=hist_tokens_map,
                senti_tokens=senti_tokens,
                news_text=news_text,
                target_series=target_ser
            )

            out.append({
                "ticker": ticker,
                "t_index": t_idx,
                "target_series": target_ser,
                "text": prompt,
                "target_token": y_tok,

                "hist_vals": {
                    "RET": [ret_to_pct_2dec(x) for x in ret_hist],
                    "ESG": A["ESG"][w0:t_idx] if present["ESG"] else [],
                    "ENV": A["ENV"][w0:t_idx] if present["ENV"] else [],
                    "SOC": A["SOC"][w0:t_idx] if present["SOC"] else [],
                    "GOV": A["GOV"][w0:t_idx] if present["GOV"] else [],
                    "SENTI": [token_to_float(tok) for tok in senti_tokens] if senti_tokens else [],
                },
                "traj_vals": {
                    "RET": [ret_to_pct_2dec(x) for x in A["RET"][:traj_end]],
                    "ESG": A["ESG"][:traj_end] if present["ESG"] else [],
                    "ENV": A["ENV"][:traj_end] if present["ENV"] else [],
                    "SOC": A["SOC"][:traj_end] if present["SOC"] else [],
                    "GOV": A["GOV"][:traj_end] if present["GOV"] else [],
                },
            })

    return out

# Build samples for every ticker (tickers too short to window are skipped).
all_samples = []
for t in core:
    s = build_samples_for_ticker(t)
    if len(s) == 0:
        continue
    all_samples.extend(s)
    log(f"[SAMPLES] {t} -> {len(s)} samples")

log(f"[SAMPLES] TOTAL samples across all tickers = {len(all_samples)}")

for i in range(min(2, len(all_samples))):
    log(f"[SAMPLE PREVIEW {i}] ticker={all_samples[i]['ticker']} series={all_samples[i]['target_series']} t_index={all_samples[i]['t_index']}")
    print(all_samples[i]["text"][:800])
    print("TARGET_TOKEN:", all_samples[i]["target_token"])
    print("-"*60)

by_ticker = defaultdict(list)
for s in all_samples:
    by_ticker[s["ticker"]].append(s)

# Per-ticker chronological split: first 60% train, next 20% val, rest test.
# Tickers with fewer than 5 samples go entirely to train.
# NOTE(review): the sort is stable and each t_index has one sample per target
# series, so a split boundary can fall between target series of the same window.
train_samples, val_samples, test_samples = [], [], []
for t, lst in by_ticker.items():
    lst.sort(key=lambda x: x["t_index"])
    n = len(lst)
    if n < 5:
        train_samples.extend(lst)
        continue
    n_train = int(0.6 * n)
    n_val   = int(0.2 * n)
    train_samples.extend(lst[:n_train])
    val_samples.extend(lst[n_train:n_train+n_val])
    test_samples.extend(lst[n_train+n_val:])

log(f"[SPLIT] train={len(train_samples)} val={len(val_samples)} test={len(test_samples)}")

def save_jsonl(path: str, rows: List[Dict[str, Any]]):
    """Write *rows* to *path* as JSON Lines (UTF-8, non-ASCII kept), then log the count."""
    with open(path, "w", encoding="utf-8") as out_file:
        out_file.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)
    log(f"[SAVED] {path} lines={len(rows)}")

def load_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read a JSON Lines file, skipping blank (whitespace-only) lines."""
    with open(path, "r", encoding="utf-8") as in_file:
        stripped_lines = (raw.strip() for raw in in_file)
        return [json.loads(text) for text in stripped_lines if text]

# Persist the splits; Phase 6 reloads them from these files.
save_jsonl(cfg.TRAIN_OUT, train_samples)
save_jsonl(cfg.VAL_OUT,   val_samples)
save_jsonl(cfg.TEST_OUT,  test_samples)

# Round-trip check on the test split.
tmp = load_jsonl(cfg.TEST_OUT)
log(f"[CACHE CHECK] loaded {len(tmp)} test rows from {cfg.TEST_OUT}")

log("Phase 5 ready.")

"""# Phase 6"""





# The tokenizer may define no pad token (the case for GPT-2); reuse EOS for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

lm = AutoModelForCausalLM.from_pretrained(cfg.MODEL_NAME)
lm.to(cfg.DEVICE)

log(f"[MODEL] Loaded. n_embd={lm.config.n_embd} (expect 768 for gpt2)")

def extract_domain_tokens_from_samples(samples: List[Dict[str, Any]]) -> List[str]:
    """Collect the unique numeric domain tokens used across *samples*.

    Scans each sample's ``text`` plus ``target_token`` for ``<PREFIX_number>``
    tokens whose prefix is in ``DOMAIN_PREFIXES`` (this includes SENTI tokens).
    Each sample's ``target_token`` is also added whenever it starts with ``"<"``.

    Returns:
        the tokens, sorted.
    """
    numeric_tok = re.compile(r"<([A-Z]+)_-?\d+(?:\.\d+)?>")
    found = set()
    for sample in samples:
        haystack = sample["text"] + " " + sample["target_token"]
        found.update(
            hit.group(0)
            for hit in numeric_tok.finditer(haystack)
            if token_prefix(hit.group(0)) in DOMAIN_PREFIXES
        )
        target = sample.get("target_token","")
        if target.startswith("<"):
            found.add(target)
    return sorted(found)

# Reload the cached splits so the vocabulary reflects exactly what was saved.
train_samples = load_jsonl(cfg.TRAIN_OUT)
val_samples   = load_jsonl(cfg.VAL_OUT)
test_samples  = load_jsonl(cfg.TEST_OUT)

domain_tokens = extract_domain_tokens_from_samples(train_samples + val_samples + test_samples)

log(f"[TOKENS] domain_tokens from cached data: {len(domain_tokens)}")
log(f"[TOKENS] example: {domain_tokens[:10]}")

# Register struct/facet markers (deduplicated, order kept) and every domain token
# as single tokens.
# NOTE(review): they are registered as *special* tokens, so decoding with
# skip_special_tokens=True would drop them; confirm that is intended.
special_to_add = list(dict.fromkeys(STRUCT_TOKENS + FACET_TOKENS))
num_added = tokenizer.add_special_tokens({"additional_special_tokens": special_to_add + domain_tokens})

log(f"[TOKENS] added {num_added} new tokens to tokenizer (struct+facet+domain)")

# Grow the embedding matrix to cover the new token ids.
lm.resize_token_embeddings(len(tokenizer))
log(f"[TOKENS] tokenizer vocab_size now = {len(tokenizer)} ; model resized.")

def generate_blockwise_series_embedding(token: str, dim: int = 768, scale: float = 1.0) -> np.ndarray:
    """Deterministic embedding for a numeric domain token ``<PREFIX_value>``.

    The vector is split into 8 blocks of ``dim // 8`` entries, one per series,
    in the order RET, SOC, GOV, ESG, ESGFO, ESGSO, ENV, SENTI. For dim=768
    each block has 96 entries (RET 0..95, SOC 96..191, ..., SENTI 672..767).
    Only the token's own block is non-zero. It holds
    ``sin((x + value/100) * pi) * scale + x**2 * (value/100) * 0.5`` over
    ``x = linspace(0, 1, block)``, standardized to zero mean and unit std.
    Entries beyond ``8 * (dim // 8)`` stay zero.

    Returns:
        float32 array of shape (dim,).

    Raises:
        ValueError: if the token is not ``<PREFIX_number>`` or PREFIX is not one of the 8 series.
    """
    parsed = re.match(r"^<([A-Z]+)_(-?\d+(?:\.\d+)?)>$", token)
    if not parsed:
        raise ValueError(f"Invalid token: {token}")
    prefix, raw_value = parsed.groups()
    value = float(raw_value)

    layout = ("RET", "SOC", "GOV", "ESG", "ESGFO", "ESGSO", "ENV", "SENTI")
    width = dim // len(layout)

    if prefix not in layout:
        raise ValueError(f"Unknown series_prefix {prefix} in {token}")

    offset = layout.index(prefix) * width

    shift = value / 100.0
    grid = np.linspace(0, 1, width)
    curve = np.sin((grid + shift) * np.pi) * scale + (grid ** 2) * shift * 0.5
    curve = (curve - curve.mean()) / (curve.std() + 1e-8)

    embedding = np.zeros(dim, dtype=np.float32)
    embedding[offset:offset + width] = curve
    return embedding

# Overwrite the input-embedding rows of every numeric domain token with its
# deterministic blockwise vector. `.weight.data` is written directly (no autograd).
W = lm.get_input_embeddings().weight.data
applied = 0
skipped = 0

for tok in domain_tokens:
    pref = token_prefix(tok)
    if pref not in ["RET","SOC","GOV","ESG","ENV","SENTI","ESGFO","ESGSO"]:
        skipped += 1
        continue
    tid = tokenizer.convert_tokens_to_ids(tok)
    # A token absent from the vocab maps to the unk id; writing there would
    # clobber an existing embedding row, so skip it instead.
    if tid is None or (tid == tokenizer.unk_token_id and tok != tokenizer.unk_token):
        skipped += 1
        log(f"[BLOCKWISE WARN] {tok} not in tokenizer vocab -> skipped")
        continue
    try:
        vec = generate_blockwise_series_embedding(tok, dim=lm.config.n_embd, scale=1.0)
        W[tid] = torch.tensor(vec, device=W.device, dtype=W.dtype)
        applied += 1
    except Exception as e:
        # Best effort: keep going, but report why this token was skipped.
        skipped += 1
        log(f"[BLOCKWISE WARN] {tok} skipped: {e}")

log(f"[BLOCKWISE] applied embeddings to {applied} tokens; skipped {skipped}")

for tok in domain_tokens[:5]:
    tid = tokenizer.convert_tokens_to_ids(tok)
    v = W[tid].detach().cpu()
    print("[BLOCKWISE CHECK]", tok, "id=", tid, "norm=", float(v.norm().item()))

from peft import LoraConfig, get_peft_model, TaskType

# LoRA adapters on GPT-2's attention (c_attn, c_proj) and MLP (c_fc, c_proj)
# projections. GPT-2 implements these as Conv1D, which stores weights as
# (fan_in, fan_out); PEFT expects fan_in_fan_out=True for such layers.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,

    r=16,

    lora_alpha=32,

    lora_dropout=0.05,

    target_modules=["c_attn", "c_proj", "c_fc"],

    fan_in_fan_out=True,

    bias="none",
)


lm = get_peft_model(lm, lora_config)

# With frozen embeddings and gradient checkpointing, inputs must require grad
# so gradients can flow back into the checkpointed blocks.
lm.enable_input_require_grads()

lm.gradient_checkpointing_enable()
lm.config.use_cache = False  # KV cache is not used during checkpointed training

# NOTE(review): get_peft_model freezes the base weights, including the embedding
# rows of the newly added struct/facet/domain tokens, so they are not trained.
# If they should be, consider PEFT's modules_to_save / trainable-token options.

lm.to(cfg.DEVICE)
lm.train()

lm.print_trainable_parameters()

log("[LoRA] Attached LoRA adapters successfully.")
log("Phase 6 ready.")

"""# Target Aware Bias - 6.5"""


import contextlib
import torch
import numpy as np

from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

# Strength of the additive key bias used by build_attn_bias. CFG does not declare
# this field, so it is attached here unless already set.
if not hasattr(cfg, "ATTN_BIAS_STRENGTH"):
    cfg.ATTN_BIAS_STRENGTH = 0.75

# Module-global bias read by the patched GPT2Attention.forward and swapped by
# attn_bias_context. None means "no bias". Initialized here so neither reader
# hits a NameError before the first assignment.
if "_GLOBAL_ATTN_BIAS" not in globals():
    _GLOBAL_ATTN_BIAS = None


@contextlib.contextmanager
def attn_bias_context(bias_tensor: torch.Tensor):
    """Temporarily install *bias_tensor* as the module-global attention bias.

    The patched GPT2Attention.forward adds ``_GLOBAL_ATTN_BIAS`` to its
    attention mask. The previous value is restored on exit, even if the body
    raises, so contexts nest.

    bias_tensor: additive bias over KEY positions, shape [B, 1, 1, T] as built by
    build_attn_bias (broadcast over query positions).
    """
    global _GLOBAL_ATTN_BIAS
    # .get(): tolerate use before the global was initialized (treated as None).
    old = globals().get("_GLOBAL_ATTN_BIAS")
    _GLOBAL_ATTN_BIAS = bias_tensor
    try:
        yield
    finally:
        _GLOBAL_ATTN_BIAS = old

@torch.no_grad()
def build_attn_bias(input_ids: torch.Tensor, target_series_list):
    """
    Return an additive attention bias of shape [B,1,1,T] (float32, on input_ids.device).

    For row i with target series `ser`, cfg.ATTN_BIAS_STRENGTH is added at every KEY
    position holding one of allowed_ids[ser] (the series' numeric tokens). The same
    amount is added again at the position of the series' facet token, facet_by_series[ser].

    A series with no allowed_ids entry (e.g. "NEWS", which has a facet token but no numeric
    vocabulary) or no facet entry gets no bias from that source; it does not raise KeyError.
    """
    B, T = input_ids.shape
    bias = torch.zeros((B, 1, 1, T), device=input_ids.device, dtype=torch.float32)
    strength = float(cfg.ATTN_BIAS_STRENGTH)

    for i in range(B):
        ser = target_series_list[i]
        row = input_ids[i]

        ids = allowed_ids.get(ser)
        if ids is not None and ids.numel() > 0:
            # torch.isin needs both operands on the same device.
            mask_series = torch.isin(row, ids.to(row.device))
            bias[i, 0, 0, mask_series] += strength

        fid = facet_by_series.get(ser)
        if fid is not None:
            mask_facet = (row == int(fid))
            bias[i, 0, 0, mask_facet] += strength

    return bias

def patch_gpt2_attention_for_bias():
    """
    Monkey-patch GPT2Attention.forward to add _GLOBAL_ATTN_BIAS to attention_mask if present.
    - attention_mask inside GPT2 is additive (0 for allowed, very negative for masked)
    - Adding +bias increases attention weight for those keys.
    - The wrapper binds the call against the ORIGINAL forward's signature and replaces only
      `attention_mask`. This works whether the mask is passed positionally or by keyword,
      and it never injects keyword names (e.g. layer_past / use_cache) that the installed
      transformers version may not accept.
    - The bias slot is read with globals().get, so an uninitialised global means "no bias"
      rather than NameError.
    NOTE(review): attention subclasses that override forward (e.g. SDPA / flash variants in
    some transformers releases) are not affected by this patch; confirm eager attention is used.
    NOTE: with gradient checkpointing the forward is re-run during backward(), so backward()
    must also run inside attn_bias_context for the recomputation to see the same bias.
    Idempotent: a second call only logs and returns.
    """
    import functools
    import inspect

    if getattr(GPT2Attention, "_bias_patched", False):
        log("[ATTN_BIAS] GPT2Attention already patched. Skipping.")
        return

    orig_forward = GPT2Attention.forward
    GPT2Attention._orig_forward = orig_forward
    orig_sig = inspect.signature(orig_forward)

    @functools.wraps(orig_forward)
    def new_forward(self, *args, **kwargs):
        bias = globals().get("_GLOBAL_ATTN_BIAS")
        if bias is None:
            return orig_forward(self, *args, **kwargs)

        bound = orig_sig.bind(self, *args, **kwargs)
        attention_mask = bound.arguments.get("attention_mask")
        if attention_mask is None:
            return orig_forward(self, *args, **kwargs)

        bound.arguments["attention_mask"] = attention_mask + bias.to(
            device=attention_mask.device, dtype=attention_mask.dtype
        )
        return orig_forward(*bound.args, **bound.kwargs)

    GPT2Attention.forward = new_forward
    GPT2Attention._bias_patched = True
    log("[ATTN_BIAS] Patched GPT2Attention.forward successfully (SAFE).")

patch_gpt2_attention_for_bias()

"""# Phase 7"""


from torch.utils.data import Dataset, DataLoader

# Encode up to the model's context window. Truncate from the LEFT so the end of each
# prompt (where the target token is appended) survives, and pad on the right.
cfg.MAX_LEN = int(getattr(lm.config, "n_positions", cfg.MAX_LEN))
tokenizer.truncation_side, tokenizer.padding_side = "left", "right"
log(
    f"[ENC] MAX_LEN={cfg.MAX_LEN} truncation_side={tokenizer.truncation_side} "
    f"padding_side={tokenizer.padding_side}"
)

def build_allowed_ids_for_series(prefix: str) -> torch.Tensor:
    """
    Return the sorted vocabulary ids of all numeric tokens shaped like <PREFIX_number>
    (an optional minus sign, digits, and an optional decimal part) as a LongTensor on cfg.DEVICE.

    The prefix is regex-escaped so a prefix containing metacharacters cannot widen or
    break the pattern. Logs a warning and returns an empty tensor when nothing matches.
    """
    pattern = re.compile(rf"^<{re.escape(prefix)}_-?\d+(?:\.\d+)?>$")
    ids = sorted(i for tok, i in tokenizer.get_vocab().items() if pattern.match(tok))
    if not ids:
        log(f"[WARN] no allowed ids found for series={prefix}")
    return torch.tensor(ids, dtype=torch.long, device=cfg.DEVICE)

# Per-series vocabulary ids of the numeric value tokens the model may emit.
allowed_ids = {
    series: build_allowed_ids_for_series(series)
    for series in ("ESG", "ENV", "SOC", "GOV", "RET", "SENTI")
}
for k, v in allowed_ids.items():
    n_ids = int(v.numel())
    log(f"[ALLOWED] {k}: n_ids={n_ids}")

# Vocabulary id of every facet marker token, then the marker id for each series name.
facet_id = {tok: tokenizer.convert_tokens_to_ids(tok) for tok in FACET_TOKENS}

# Hugging Face convert_tokens_to_ids() returns the unk id (no error) for a token missing
# from the vocabulary. Warn, so facet positions are never silently matched against unk.
_facet_unk_id = getattr(tokenizer, "unk_token_id", None)
for _ftok, _fid in facet_id.items():
    if _fid is None or (_facet_unk_id is not None and _fid == _facet_unk_id):
        log(f"[WARN] facet token {_ftok} is missing from the tokenizer vocabulary (id={_fid})")

facet_by_series = {
    "NEWS":  facet_id["<FACET_NEWS>"],
    "SENTI": facet_id["<FACET_SENTI>"],
    "RET":   facet_id["<FACET_RET>"],
    "ESG":   facet_id["<FACET_ESG>"],
    "ENV":   facet_id["<FACET_ENV>"],
    "SOC":   facet_id["<FACET_SOC>"],
    "GOV":   facet_id["<FACET_GOV>"],
}
log(f"[FACET] facet ids: {facet_by_series}")

# Expected order of facet sections in a prompt; compute_target_span_mask ends a span
# at the marker of the next series in this list.
FACET_ORDER = ["NEWS", "SENTI", "RET", "ESG", "ENV", "SOC", "GOV"]

def find_single_pos(input_ids_1d: torch.Tensor, token_id: int) -> int:
    """Return the index of the first occurrence of token_id in a 1-D id tensor, or -1."""
    hits = torch.nonzero(input_ids_1d == token_id, as_tuple=True)[0]
    return int(hits[0].item()) if hits.numel() else -1

def compute_target_span_mask(input_ids_1d: torch.Tensor, target_series: str) -> torch.Tensor:
    """
    Build a float32 CPU mask [T] marking the tokens of the target series' facet section.

    The section runs from just after the first <FACET_target> marker up to, but excluding,
    the first marker of the series that follows `target_series` in FACET_ORDER. If that next
    marker is absent, the section extends to the end of the sequence. The mask is all zeros
    when the series is not in FACET_ORDER or its own marker does not occur.
    """
    length = input_ids_1d.size(0)
    span = torch.zeros(length, dtype=torch.float32)

    if target_series not in FACET_ORDER:
        return span

    open_pos = find_single_pos(input_ids_1d, facet_by_series[target_series])
    if open_pos < 0:
        return span

    order_idx = FACET_ORDER.index(target_series)
    close_pos = length
    if order_idx < len(FACET_ORDER) - 1:
        follower = FACET_ORDER[order_idx + 1]
        next_pos = find_single_pos(input_ids_1d, facet_by_series[follower])
        if next_pos >= 0:
            close_pos = next_pos

    first = open_pos + 1
    if first < close_pos:
        span[first:close_pos] = 1.0
    return span

def _ensure_sep(text: str) -> str:
    if len(text) == 0:
        return text
    return text if text[-1].isspace() else (text + " ")

def encode_samples(samples, tag=""):
    """
    Tokenize "<text> <target_token>" strings and build the supervision tensors.

    Each text is joined to its target token with _ensure_sep, tokenized without special
    tokens, padded, and truncated to cfg.MAX_LEN using the tokenizer's configured
    truncation side. Exactly one position per row is supervised: the last non-padding
    token, which should be the target token. Locating it via attention_mask.sum() - 1
    assumes right padding (tokenizer.padding_side == "right").

    Returns:
        input_ids        [N,T]
        attention_mask   [N,T]
        labels           [N,T], -100 everywhere except the last real token
        facet_pos        [N] long, position of the target series' facet token (-1 if absent)
        target_span_mask [N,T] float32, from compute_target_span_mask

    Every row is checked for a target-token mismatch, e.g. a target token that was split
    into several sub-tokens, which would make the label supervise the wrong token. The
    first rows are logged in detail and a summary warning reports the total count.
    """
    texts = [_ensure_sep(s["text"]) + s["target_token"] for s in samples]

    enc = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=cfg.MAX_LEN,
        add_special_tokens=False,
    )

    input_ids = enc["input_ids"]
    attn = enc["attention_mask"]
    num = input_ids.size(0)

    labels = torch.full_like(input_ids, -100)
    last_pos = attn.sum(dim=1) - 1

    facet_pos = torch.full((num,), -1, dtype=torch.long)
    target_span_mask = torch.zeros((num, input_ids.size(1)), dtype=torch.float32)

    show = min(3, num)
    n_mismatch = 0
    for i in range(num):
        lp = int(last_pos[i].item())
        labels[i, lp] = input_ids[i, lp]

        ser = samples[i]["target_series"]
        fid = facet_by_series[ser]
        fp = find_single_pos(input_ids[i], fid)
        facet_pos[i] = fp

        target_span_mask[i] = compute_target_span_mask(input_ids[i], ser)

        tid = int(input_ids[i, lp].item())
        tok = tokenizer.convert_ids_to_tokens(tid)
        mismatch = tok != samples[i]["target_token"]
        n_mismatch += int(mismatch)

        if i < show:
            log(f"[ENC{tag}] sample {i} ticker={samples[i]['ticker']} series={ser} t_index={samples[i]['t_index']}")
            log(f"[ENC{tag}] last_pos={lp} target_tok={samples[i]['target_token']}")
            log(f"[ENC{tag}] facet_pos({ser})={fp}")
            log(f"[ENC{tag}] target_span_mask sum={int(target_span_mask[i].sum().item())}")
            if mismatch:
                log(f"[ENC{tag}][WARN] target token mismatch: expected {samples[i]['target_token']} got {tok}")

    if n_mismatch:
        log(f"[ENC{tag}][WARN] {n_mismatch}/{num} samples end in a token different from their target_token")

    assert (labels != -100).sum().item() == input_ids.size(0), "Expected exactly 1 supervised token per sample"
    return input_ids, attn, labels, facet_pos, target_span_mask

train_ids, train_attn, train_labels, train_fpos, train_tmask = encode_samples(train_samples, tag=" TRAIN")
val_ids,   val_attn,   val_labels,   val_fpos,   val_tmask   = encode_samples(val_samples,   tag=" VAL")
test_ids,  test_attn,  test_labels,  test_fpos,  test_tmask  = encode_samples(test_samples,  tag=" TEST")

class LMDataset(Dataset):
    """Map-style dataset pairing pre-encoded tensors (row-aligned) with per-sample metadata dicts."""

    # Metadata keys copied verbatim from each sample dict, in output order.
    _SAMPLE_KEYS = ("target_series", "ticker", "t_index", "hist_vals", "target_token", "text", "traj_vals")

    def __init__(self, samples, input_ids, attn, labels, facet_pos, target_span_mask):
        self.samples = samples
        self.input_ids = input_ids
        self.attn = attn
        self.labels = labels
        self.facet_pos = facet_pos
        self.target_span_mask = target_span_mask

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        record = self.samples[idx]
        item = {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attn[idx],
            "labels": self.labels[idx],
            "facet_pos": self.facet_pos[idx],
            "target_span_mask": self.target_span_mask[idx],
        }
        item.update((key, record[key]) for key in self._SAMPLE_KEYS)
        return item

def collate(batch):
    """Stack tensor fields onto cfg.DEVICE and gather metadata fields into per-batch lists."""
    tensor_keys = ("input_ids", "attention_mask", "labels", "facet_pos", "target_span_mask")
    list_keys = ("target_series", "ticker", "t_index", "hist_vals", "target_token", "text", "traj_vals")

    merged = {key: torch.stack([item[key] for item in batch]).to(cfg.DEVICE) for key in tensor_keys}
    merged.update({key: [item[key] for item in batch] for key in list_keys})
    return merged

# Only the training loader shuffles; val/test keep sample order for reproducible evaluation.
_loader_kw = dict(batch_size=cfg.BATCH_SIZE, collate_fn=collate)
train_dl = DataLoader(
    LMDataset(train_samples, train_ids, train_attn, train_labels, train_fpos, train_tmask),
    shuffle=True, **_loader_kw,
)
val_dl = DataLoader(
    LMDataset(val_samples, val_ids, val_attn, val_labels, val_fpos, val_tmask),
    shuffle=False, **_loader_kw,
)
test_dl = DataLoader(
    LMDataset(test_samples, test_ids, test_attn, test_labels, test_fpos, test_tmask),
    shuffle=False, **_loader_kw,
)

log("[OK] Phase 7 dataloaders ready (with target_span_mask).")

"""# Phase 8"""

def split_facet_z(z: torch.Tensor):
    """Split the last dim into (first H//2, remaining) halves: the geometry part and the prediction part."""
    half = z.size(-1) // 2
    return z[..., :half], z[..., half:]

def extract_facet_Z(hidden_states_last: torch.Tensor, facet_pos: torch.Tensor):
    """
    Gather the hidden state at each row's facet position.

    hidden_states_last: [B,T,H]; facet_pos: [B] integer positions (out-of-range means missing).
    Returns (Z [B,H] with zero rows for missing facets, valid [B] bool).
    """
    batch, seq_len, width = hidden_states_last.shape
    Z = hidden_states_last.new_zeros((batch, width))
    valid = torch.zeros((batch,), device=hidden_states_last.device, dtype=torch.bool)

    for row in range(batch):
        p = int(facet_pos[row].item())
        if not (0 <= p < seq_len):
            continue
        Z[row] = hidden_states_last[row, p, :]
        valid[row] = True
    return Z, valid


import time
import numpy as np
import torch
import torch.nn.functional as F
from typing import Any, Dict, List

def dtw_sakoe_chiba(x: List[float], y: List[float], band: int) -> float:
    """
    Squared-error DTW distance between x and y restricted to a Sakoe-Chiba band.

    The band is widened to at least |len(x) - len(y)|. With a narrower band the end cell
    (n, m) is unreachable and the result would collapse to the 1e6 cap, which looks like
    a huge distance instead of a real one.

    Best-effort sentinels: returns 0.0 for None/empty inputs, non-numeric values, or a
    non-finite result. Finite results are capped at 1e6.
    """
    if x is None or y is None:
        return 0.0
    n, m = len(x), len(y)
    if n == 0 or m == 0:
        return 0.0

    try:
        x = [float(v) for v in x]
        y = [float(v) for v in y]
    except Exception:
        return 0.0

    band = max(int(band), abs(n - m))

    INF = 1e18
    dp = np.full((n + 1, m + 1), INF, dtype=np.float64)
    dp[0, 0] = 0.0

    for i in range(1, n + 1):
        j_start = max(1, i - band)
        j_end   = min(m, i + band)
        for j in range(j_start, j_end + 1):
            d = x[i - 1] - y[j - 1]
            cost = d * d
            dp[i, j] = cost + min(dp[i - 1, j], dp[i, j - 1], dp[i - 1, j - 1])

    val = float(dp[n, m])
    if not np.isfinite(val):
        return 0.0
    return float(min(val, 1e6))


def forecast_ce_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """
    Next-token cross-entropy for a causal LM.

    logits[:, t] is the distribution for token t+1, so logits are shifted left against
    labels (the same alignment as Hugging Face's causal-LM loss). Without the shift, each
    logit is scored against the token at its own position, which the model has already
    seen. Positions labelled -100 are ignored.
    """
    V = logits.size(-1)
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(shift_logits.view(-1, V), shift_labels.view(-1), ignore_index=-100)


def expected_value_from_logits(
    logits_last: torch.Tensor,
    allowed_ids_1d: torch.Tensor,
    id_to_value: Dict[int, float]
) -> torch.Tensor:
    """
    Expected numeric value under the softmax restricted to a set of token ids.

    logits_last: [N,V]; allowed_ids_1d: [K] token ids; id_to_value maps each id to its number.
    Returns [N]: sum_k softmax(logits_last[:, ids])_k * value(ids[k]). Differentiable w.r.t. logits.
    """
    restricted = torch.index_select(logits_last, 1, allowed_ids_1d)
    weights = restricted.softmax(dim=1)
    table = [id_to_value[int(t)] for t in allowed_ids_1d.detach().cpu().tolist()]
    values = torch.as_tensor(table, dtype=weights.dtype, device=weights.device)
    return torch.sum(weights * values[None, :], dim=1)


def build_id_to_value(allowed_ids_1d: torch.Tensor) -> Dict[int, float]:
    """
    Map each token id in allowed_ids_1d to the number its token encodes (via token_to_float).

    Ids whose token cannot be parsed are still mapped to 0.0, so the mapping covers every
    allowed id. Such ids pull expected values toward 0, so a warning lists how many there
    were instead of hiding them.
    """
    out = {}
    unparsed = []
    for tid in allowed_ids_1d.detach().cpu().tolist():
        tok = tokenizer.convert_ids_to_tokens(int(tid))
        try:
            out[int(tid)] = token_to_float(tok)
        except Exception:
            out[int(tid)] = 0.0
            unparsed.append(tok)
    if unparsed:
        log(f"[WARN] build_id_to_value: {len(unparsed)} token(s) could not be parsed and map to 0.0, e.g. {unparsed[:5]}")
    return out


# Per-series lookup from value-token id to the number it encodes ({} when a series has no value tokens).
id_to_val = {
    series: (build_id_to_value(allowed_ids[series]) if allowed_ids[series].numel() > 0 else {})
    for series in ("ESG", "ENV", "SOC", "GOV", "RET", "SENTI")
}


def slope_loss_from_batch(logits: torch.Tensor, labels: torch.Tensor, batch: Dict[str, Any]) -> torch.Tensor:
    """
    MSE between predicted and true one-step changes (y - y_prev) at the supervised token.

    - The prediction is the expected value over allowed_ids[series], read from the logits
      that PREDICT the supervised token. In a causal LM that is position p-1, not p.
    - Predictions stay tensors so the loss back-propagates into the model. Converting them
      with .item() would make the loss a constant with no gradient.
    - y_prev is the last history value for the series. Without history it falls back to
      y_true, i.e. a zero target slope.
    Assumes exactly one supervised token per row (labels != -100), in batch order.
    """
    B, T, V = logits.shape
    device = logits.device
    pos = (labels != -100).nonzero(as_tuple=False)
    pred_pos = (pos[:, 1] - 1).clamp(min=0)
    logits_last = logits[pos[:, 0], pred_pos, :]

    series_list = batch["target_series"]
    hist_vals_list = batch["hist_vals"]

    pred_vals, true_vals, prev_vals = [], [], []
    for i in range(B):
        ser = series_list[i]
        tid = int(labels[pos[i, 0], pos[i, 1]].item())
        true_tok = tokenizer.convert_ids_to_tokens(tid)

        try:
            y_true = float(token_to_float(true_tok))
        except Exception:
            y_true = 0.0

        hv = hist_vals_list[i].get(ser, [])
        y_prev = float(hv[-1]) if (hv is not None and len(hv) > 0) else float(y_true)

        if allowed_ids[ser].numel() == 0:
            y_pred = torch.tensor(y_prev, device=device, dtype=torch.float32)
        else:
            y_pred = expected_value_from_logits(
                logits_last[i:i+1], allowed_ids[ser], id_to_val[ser]
            )[0].float()

        pred_vals.append(y_pred)
        true_vals.append(y_true)
        prev_vals.append(y_prev)

    if not pred_vals:
        return logits.sum() * 0.0

    pred_t = torch.stack(pred_vals).to(device)
    true_t = torch.tensor(true_vals, device=device, dtype=torch.float32)
    prev_t = torch.tensor(prev_vals, device=device, dtype=torch.float32)

    return F.mse_loss(pred_t - prev_t, true_t - prev_t)


def safe_zscore(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Standardize x over all elements using the population std; eps keeps constant inputs finite."""
    centered = x - x.mean()
    return centered / torch.sqrt(centered.pow(2).mean() + eps)



from typing import Dict, Any, List
import numpy as np
import torch
import torch.nn.functional as F

def geom_dtw_loss_from_batch(hidden_states_last: torch.Tensor, batch: Dict[str, Any]) -> torch.Tensor:
    """
    Ranking loss:
      if DTW(a,b) < DTW(a,c) then dist(z_a,z_b) + margin < dist(z_a,z_c)

    z is the L2-normalised first half of the facet hidden state (split_facet_z), and dist
    is cosine distance (1 - dot product). DTW is computed on the z-normalised last cfg.K
    trajectory values. Triplets are drawn within each target series, using an RNG seeded
    from cfg.SEED + cfg.GLOBAL_STEP. Ties (|DTW diff| < 1e-6) and short trajectories are skipped.

    Best-effort: any failure returns a zero loss. The exception is logged so a broken batch
    schema cannot silently disable this term.
    """
    device = hidden_states_last.device
    try:
        B, T, Hdim = hidden_states_last.shape
        series_list = batch["target_series"]
        traj_vals_list = batch["traj_vals"]

        Z_facet, valid = extract_facet_Z(hidden_states_last, batch["facet_pos"])
        Z_geom, _ = split_facet_z(Z_facet)
        Z_geom = F.normalize(Z_geom.float(), p=2, dim=-1)

        idx_by_ser: Dict[str, List[int]] = {}
        for i in range(B):
            if bool(valid[i].item()):
                idx_by_ser.setdefault(series_list[i], []).append(i)

        if len(idx_by_ser) == 0:
            return torch.tensor(0.0, device=device)

        step_seed = int(getattr(cfg, "GLOBAL_STEP", 0))
        rng = np.random.default_rng(cfg.SEED + step_seed)

        margin = float(getattr(cfg, "GEOM_MARGIN", 0.15))
        triplets_per_ser = int(getattr(cfg, "GEOM_TRIPLETS_PER_SER", 128))

        losses = []
        for ser, idxs in idx_by_ser.items():
            if len(idxs) < 3:
                continue

            # Never draw more triplets than C(len(idxs), 3).
            n_trip = min(triplets_per_ser, (len(idxs) * (len(idxs) - 1) * (len(idxs) - 2)) // 6)

            for _ in range(n_trip):
                a, b, c = rng.choice(idxs, size=3, replace=False)

                xa = np.array((traj_vals_list[a].get(ser, []) or [])[-cfg.K:], dtype=float)
                xb = np.array((traj_vals_list[b].get(ser, []) or [])[-cfg.K:], dtype=float)
                xc = np.array((traj_vals_list[c].get(ser, []) or [])[-cfg.K:], dtype=float)
                if xa.size < 3 or xb.size < 3 or xc.size < 3:
                    continue

                xa = _norm_seq(xa).tolist()
                xb = _norm_seq(xb).tolist()
                xc = _norm_seq(xc).tolist()

                dab = float(dtw_sakoe_chiba(xa, xb, band=cfg.DTW_BAND))
                dac = float(dtw_sakoe_chiba(xa, xc, band=cfg.DTW_BAND))
                if abs(dab - dac) < 1e-6:
                    continue

                zab = 1.0 - torch.sum(Z_geom[a] * Z_geom[b])
                zac = 1.0 - torch.sum(Z_geom[a] * Z_geom[c])
                if not (torch.isfinite(zab) and torch.isfinite(zac)):
                    continue

                if dab < dac:
                    losses.append(F.relu(zab - zac + margin))
                else:
                    losses.append(F.relu(zac - zab + margin))

        if len(losses) == 0:
            return torch.tensor(0.0, device=device)

        return torch.stack(losses).mean()

    except Exception as e:
        log(f"[GEOM][WARN] geom_dtw_loss_from_batch failed; returning 0. {type(e).__name__}: {e}")
        return torch.tensor(0.0, device=device)


log("Phase 8 ready (NEW): CE + geomDTW with hard NaN guards (slope available but OFF in Phase 9).")

"""# Phase 9"""

@torch.no_grad()
def debug_facet_pos_one_batch(batch, n=4):
    """Print the facet position, its attention-mask value and its token for the first n rows of a batch."""
    ids, fpos, am = batch["input_ids"], batch["facet_pos"], batch["attention_mask"]
    seq_len = ids.size(1)
    for i in range(min(n, ids.size(0))):
        p = int(fpos[i].item())
        if 0 <= p < seq_len:
            tok = tokenizer.convert_ids_to_tokens(int(ids[i, p].item()))
            attn_flag = int(am[i, p].item())
        else:
            tok, attn_flag = "OUT_OF_RANGE", -1
        print(f"[FACETPOS] i={i} pos={p} attn={attn_flag} tok={tok}")

import os, gc, torch
# NOTE(review): PYTORCH_CUDA_ALLOC_CONF is read when the CUDA caching allocator starts up.
# The model was already moved to cfg.DEVICE earlier, so on GPU this setting likely has no
# effect here; set it before any CUDA work to rely on it.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
gc.collect()
# memory_summary() initialises CUDA and fails on CPU-only runs, which cfg.DEVICE supports.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary(device=None, abbreviated=True))
else:
    print("[MEM] CUDA not available; skipping memory summary.")

trainable = [(n,p) for n,p in lm.named_parameters() if p.requires_grad]
print("TRAINABLE PARAMS:", len(trainable))
print("TRAINABLE ELEMENTS:", sum(p.numel() for _,p in trainable))

for n,_ in trainable[:20]:
    print("  ", n)

assert len(trainable) > 0, "No trainable params! LoRA might not be enabled / model frozen incorrectly."

"""# Warm up for Head"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW

torch.manual_seed(getattr(cfg, "SEED", 1234))
torch.cuda.manual_seed_all(getattr(cfg, "SEED", 1234))
# The head reads the second half of the facet hidden state (split_facet_z), so its input
# width follows the model (n_embd // 2) instead of a hard-coded 384, which only fits GPT-2 small.
value_head = nn.Linear(lm.config.n_embd // 2, 1).to(cfg.DEVICE)

def reset_value_head(m):
    """Re-initialise any nn.Linear module in place (used via Module.apply)."""
    if isinstance(m, nn.Linear):
        m.reset_parameters()

value_head.apply(reset_value_head)
print("[HEAD-WARMUP] value_head reset_parameters() applied")

"""# New warm up"""



# Mark every LM parameter trainable for the EV warm-up below.
# NOTE(review): this also unfreezes the GPT-2 base weights, not only the LoRA adapters,
# so the warm-up fine-tunes the full model. Confirm this is intended.
lm.requires_grad_(True)


import torch
import torch.nn.functional as F
from torch.optim import AdamW

# EV warm-up hyper-parameters; values already present on cfg take precedence.
for _name, _default in (
    ("EV_WARMUP_EPOCHS", 2),
    ("EV_WARMUP_LR", 3e-4),
    ("EV_WARMUP_CLIP", 1.0),
    ("EV_WARMUP_STEPS", 999999),  # effectively uncapped
):
    if not hasattr(cfg, _name):
        setattr(cfg, _name, _default)

def get_supervised_positions(labels: torch.Tensor) -> torch.Tensor:
    """
    Return the [B,2] (row, col) coordinates of the single non -100 label in each row, sorted by row.

    Raises AssertionError unless every row has exactly one supervised position.
    """
    n_rows = labels.size(0)
    hits = torch.nonzero(labels != -100, as_tuple=False)
    assert hits.size(0) == n_rows, f"Expected exactly 1 supervised token per sample; got {hits.size(0)} for B={n_rows}."
    hits = hits[hits[:, 0].argsort()]
    assert torch.all(hits[:, 0].cpu() == torch.arange(n_rows)), "Supervised positions not aligned to batch order."
    return hits

def value_mse_loss_from_batch(logits: torch.Tensor, labels: torch.Tensor, batch) -> torch.Tensor:
    """
    Differentiable MSE between the expected numeric value of the predicted token and the
    numeric value of the supervised target token.

    logits: [B,T,V]
    labels: [B,T] with exactly one non -100 per row
    Uses expected value over allowed_ids[series] from the logits that PREDICT the supervised
    token. In a causal LM, logits[:, t] is the distribution for token t+1, so a target at
    position p is read from logits[:, p-1], matching the label shift in Hugging Face's
    causal-LM loss. Reading logits[:, p] would score the token AFTER the target while the
    target itself is already in the input.
    Series with no allowed ids contribute zero error (prediction set to the true value).
    """
    B, T, V = logits.shape
    pos = get_supervised_positions(labels)
    # Clamp covers a degenerate target at position 0, which has no preceding logits.
    pred_pos = (pos[:, 1] - 1).clamp(min=0)
    logits_last = logits[pos[:, 0], pred_pos, :]
    series_list = batch["target_series"]

    preds = []
    trues = []

    for i in range(B):
        ser = series_list[i]
        ids = allowed_ids[ser]

        tid = int(labels[pos[i, 0], pos[i, 1]].item())
        tok = tokenizer.convert_ids_to_tokens(tid)
        try:
            y_true = float(token_to_float(tok))
        except Exception:
            y_true = 0.0
        trues.append(y_true)

        if ids.numel() == 0:
            preds.append(torch.tensor([y_true], device=logits.device, dtype=torch.float32))
        else:
            y_pred = expected_value_from_logits(logits_last[i:i+1], ids, id_to_val[ser])
            preds.append(y_pred.to(dtype=torch.float32))

    y_pred_t = torch.cat(preds, dim=0).to(device=logits.device, dtype=torch.float32)
    y_true_t = torch.tensor(trues, device=logits.device, dtype=torch.float32)
    return F.mse_loss(y_pred_t, y_true_t)

# Freeze the scalar value head (if one exists) for the EV warm-up; only LM parameters are
# optimised here. NOTE(review): nothing in this cell un-freezes the head again.
if globals().get("value_head") is not None:
    value_head.requires_grad_(False)
    value_head.eval()

trainable_lm_params = [param for param in lm.parameters() if param.requires_grad]
assert trainable_lm_params, "No trainable LM params found. Did you enable LoRA and set requires_grad=True?"

lm.train()
ev_opt = AdamW(trainable_lm_params, lr=float(cfg.EV_WARMUP_LR))

print(
    "[EV-WARMUP] "
    f"epochs={cfg.EV_WARMUP_EPOCHS} lr={cfg.EV_WARMUP_LR} "
    f"clip={cfg.EV_WARMUP_CLIP} steps_cap={cfg.EV_WARMUP_STEPS} trainable_lm_params={len(trainable_lm_params)}"
)

# EV warm-up: train the LM so the expected value over each series' numeric tokens matches
# the target value (value_mse_loss_from_batch). Non-finite losses are skipped.
global_step_ev = 0
for ep in range(1, int(cfg.EV_WARMUP_EPOCHS) + 1):
    total = 0.0
    n = 0

    for batch in train_dl:
        global_step_ev += 1
        if global_step_ev > int(cfg.EV_WARMUP_STEPS):
            break

        ev_opt.zero_grad(set_to_none=True)

        bias = build_attn_bias(batch["input_ids"], batch["target_series"])
        # backward() must run INSIDE the bias context. Gradient checkpointing is enabled, so
        # each block's forward is re-run during backward and reads _GLOBAL_ATTN_BIAS again.
        # Outside the context that recomputation would run without the bias, and the
        # gradients would not match the forward pass that produced the loss.
        with attn_bias_context(bias):
            out = lm(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
                output_hidden_states=False,
            )

            loss = value_mse_loss_from_batch(out.logits, batch["labels"], batch)

            if not torch.isfinite(loss):
                continue

            loss.backward()

        torch.nn.utils.clip_grad_norm_(trainable_lm_params, float(cfg.EV_WARMUP_CLIP))
        ev_opt.step()

        total += float(loss.detach().cpu().item())
        n += 1

        if (n % 50) == 0 or n == 1:
            print(f"[EV-WARMUP][TRAIN] ep={ep} step={n} ev_mse={total/max(1,n):.4f}")

    print(f"[EV-WARMUP][TRAIN] DONE ep={ep} avg_ev_mse={total/max(1,n):.4f} steps={n}")

print("[EV-WARMUP] complete. Now run Phase 9.")

def get_supervised_positions(labels: torch.Tensor) -> torch.Tensor:
    """
    Return the [B,2] (row, col) coordinates of the single non -100 label in each row, sorted by row.

    Raises AssertionError unless every row has exactly one supervised position.
    """
    n_rows = labels.size(0)
    hits = torch.nonzero(labels != -100, as_tuple=False)
    assert hits.size(0) == n_rows, f"Expected exactly 1 supervised token per sample; got {hits.size(0)} for B={n_rows}."
    hits = hits[hits[:, 0].argsort()]
    assert torch.all(hits[:, 0].cpu() == torch.arange(n_rows)), "Supervised positions not aligned to batch order."
    return hits

@torch.no_grad()
def expected_numeric_from_labels_position(logits: torch.Tensor, labels: torch.Tensor, series_list):
    """
    Per-row (predicted expected value, true value) lists for the supervised token.

    The prediction comes from the logits that PREDICT the supervised token: position p-1
    in a causal LM, clamped at 0. Using position p would let the model read the answer,
    because the target token is part of the input. Rows whose series has no allowed ids
    get pred = true. Unparseable target tokens count as 0.0.
    """
    B, T, V = logits.shape
    pos = get_supervised_positions(labels)
    yhat, ytrue = [], []

    for i in range(B):
        tpos = int(pos[i, 1].item())
        ser = series_list[i]

        tid = int(labels[i, tpos].item())
        tok = tokenizer.convert_ids_to_tokens(tid)
        try:
            true_val = float(token_to_float(tok))
        except Exception:
            true_val = 0.0

        ids = allowed_ids[ser]
        if ids.numel() == 0:
            pred_val = true_val
        else:
            logits_last = logits[i:i+1, max(tpos - 1, 0), :]
            pred_val = float(expected_value_from_logits(logits_last, ids, id_to_val[ser])[0].item())

        yhat.append(pred_val)
        ytrue.append(true_val)

    return yhat, ytrue

def per_series_mse_numpy(yhat, ytrue, series_list):
    """MSE of yhat vs ytrue per ESG/ENV/SOC/GOV series; series with no rows are omitted."""
    result = {}
    for ser in ("ESG", "ENV", "SOC", "GOV"):
        sq_errs = [(yhat[k] - ytrue[k]) ** 2 for k, s in enumerate(series_list) if s == ser]
        if sq_errs:
            result[ser] = float(np.mean(sq_errs))
    return result

def value_mse_loss_from_batch(logits: torch.Tensor, labels: torch.Tensor, batch) -> torch.Tensor:
    """
    Differentiable MSE between the expected numeric value of the predicted token and the
    numeric value of the supervised target token.

    logits: [B,T,V]; labels: [B,T] with exactly one non -100 per row.
    The expected value over allowed_ids[series] is taken from the logits that PREDICT the
    supervised token. In a causal LM, logits[:, t] is the distribution for token t+1, so a
    target at position p is read from logits[:, p-1], matching Hugging Face's label shift.
    Reading logits[:, p] would leak the target, which is part of the input.
    Series with no allowed ids contribute zero error (prediction set to the true value).
    """
    B, T, V = logits.shape
    pos = get_supervised_positions(labels)
    # Clamp covers a degenerate target at position 0, which has no preceding logits.
    pred_pos = (pos[:, 1] - 1).clamp(min=0)
    logits_last = logits[pos[:, 0], pred_pos, :]
    series_list = batch["target_series"]

    preds = []
    trues = []
    for i in range(B):
        ser = series_list[i]
        ids = allowed_ids[ser]

        tid = int(labels[pos[i, 0], pos[i, 1]].item())
        tok = tokenizer.convert_ids_to_tokens(tid)
        try:
            y_true = float(token_to_float(tok))
        except Exception:
            y_true = 0.0
        trues.append(y_true)

        if ids.numel() == 0:
            preds.append(torch.tensor([y_true], device=logits.device, dtype=torch.float32))
        else:
            y_pred = expected_value_from_logits(logits_last[i:i+1], ids, id_to_val[ser])
            preds.append(y_pred.to(dtype=torch.float32))

    y_pred_t = torch.cat(preds, dim=0).to(device=logits.device, dtype=torch.float32)
    y_true_t = torch.tensor(trues, device=logits.device, dtype=torch.float32)
    return F.mse_loss(y_pred_t, y_true_t)

def regression_mse_from_batch(out, batch) -> torch.Tensor:
    """
    MSE of value_head applied to the second half of the last-layer facet hidden state,
    against the numeric value of each row's supervised target token.

    Requires `out.hidden_states` (the model must be called with output_hidden_states=True).
    Rows whose facet token was not found (valid == False) are excluded. Those rows would
    otherwise regress from an all-zero feature vector and only push the head's bias around.
    If no row is valid, returns a zero that is still connected to the graph.
    """
    hs = out.hidden_states[-1]
    Z, valid = extract_facet_Z(hs, batch["facet_pos"])
    _, Z_pred = split_facet_z(Z)
    y_pred = value_head(Z_pred).squeeze(-1).float()

    labels = batch["labels"]
    pos = get_supervised_positions(labels)
    y_true = []
    for i in range(labels.size(0)):
        tid = int(labels[pos[i, 0], pos[i, 1]].item())
        tok = tokenizer.convert_ids_to_tokens(tid)
        try:
            y_true.append(float(token_to_float(tok)))
        except Exception:
            y_true.append(0.0)
    y_true = torch.tensor(y_true, device=y_pred.device, dtype=torch.float32)

    valid = valid.to(y_pred.device)
    if not bool(valid.any()):
        return y_pred.sum() * 0.0
    return F.mse_loss(y_pred[valid], y_true[valid])

@torch.no_grad()
def _dtw_dist(i, j, ser, traj_vals_list):
    """Banded DTW between the z-normalised full `ser` trajectories of rows i and j (0.0 if either is empty)."""
    seq_i = traj_vals_list[i].get(ser, []) or []
    seq_j = traj_vals_list[j].get(ser, []) or []
    if min(len(seq_i), len(seq_j)) == 0:
        return 0.0
    norm_i = _norm_seq(np.array(seq_i, dtype=float)).tolist()
    norm_j = _norm_seq(np.array(seq_j, dtype=float)).tolist()
    return float(dtw_sakoe_chiba(norm_i, norm_j, band=cfg.DTW_BAND))

def facet_triplet_loss(hidden_states_last: torch.Tensor, batch, global_step: int) -> torch.Tensor:
    """
    Triplet ranking loss that aligns facet-embedding geometry with trajectory similarity.

    For random triplets (a, b, c) of rows sharing a target series: if DTW(a,b) < DTW(a,c)
    on the z-normalised last cfg.K trajectory values, require
    ||z_a - z_b|| + cfg.TRIPLET_MARGIN <= ||z_a - z_c|| (hinge), and symmetrically otherwise.
    z is the L2-normalised first half of the last-layer hidden state at the facet token.

    The RNG is seeded from cfg.SEED + 12345 + global_step, so sampling is reproducible per step.
    Returns the mean hinge over contributing triplets, or 0.0 if there are none.
    """
    B, T, Hdim = hidden_states_last.shape
    series_list = batch["target_series"]
    traj_vals_list = batch["traj_vals"]

    Z, valid = extract_facet_Z(hidden_states_last, batch["facet_pos"])
    Z_geom, _ = split_facet_z(Z)
    Z_geom = F.normalize(Z_geom.float(), p=2, dim=-1)

    # Group rows with a located facet token by their target series.
    idx_by_ser = {}
    for i in range(B):
        if bool(valid[i].item()):
            idx_by_ser.setdefault(series_list[i], []).append(i)

    rng = np.random.default_rng(cfg.SEED + 12345 + int(global_step))
    margin = float(cfg.TRIPLET_MARGIN)

    losses = []
    for ser, idxs in idx_by_ser.items():
        if len(idxs) < 3:
            continue

        # At least cfg.TRIPLET_PER_SER draws even for small groups, so the same triplet may repeat.
        n_trip = max(int(cfg.TRIPLET_PER_SER), min(64, len(idxs) * 2))

        for _ in range(n_trip):
            a, b, c = rng.choice(idxs, size=3, replace=False)

            xa = np.array((traj_vals_list[a].get(ser, []) or [])[-cfg.K:], dtype=float)
            xb = np.array((traj_vals_list[b].get(ser, []) or [])[-cfg.K:], dtype=float)
            xc = np.array((traj_vals_list[c].get(ser, []) or [])[-cfg.K:], dtype=float)
            if xa.size < 3 or xb.size < 3 or xc.size < 3:
                continue

            xa = _norm_seq(xa).tolist()
            xb = _norm_seq(xb).tolist()
            xc = _norm_seq(xc).tolist()

            dab = float(dtw_sakoe_chiba(xa, xb, band=cfg.DTW_BAND))
            dac = float(dtw_sakoe_chiba(xa, xc, band=cfg.DTW_BAND))
            # Skip ties: there is no ordering to enforce.
            if abs(dab - dac) < 1e-6:
                continue

            # Euclidean distance between unit vectors (range [0, 2]).
            zab = torch.norm(Z_geom[a] - Z_geom[b], p=2)
            zac = torch.norm(Z_geom[a] - Z_geom[c], p=2)
            if not (torch.isfinite(zab) and torch.isfinite(zac)):
                continue

            if dab < dac:
                losses.append(F.relu(zab - zac + margin))
            else:
                losses.append(F.relu(zac - zab + margin))

    if len(losses) == 0:
        return torch.tensor(0.0, device=hidden_states_last.device)
    return torch.stack(losses).mean()

@torch.no_grad()
def val_ev_mse_by_series():
    """
    Evaluate expected-value MSE on val_dl, separately for ESG/ENV/SOC/GOV.

    Puts lm in eval mode and does not restore train mode. A series with no validation rows
    reports float("inf").
    """
    lm.eval()
    tracked = ["ESG", "ENV", "SOC", "GOV"]
    truths = {s: [] for s in tracked}
    preds = {s: [] for s in tracked}

    for batch in val_dl:
        series = batch["target_series"]
        bias = build_attn_bias(batch["input_ids"], series)
        with attn_bias_context(bias):
            out = lm(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
                output_hidden_states=False,
            )

        yhat, ytrue = expected_numeric_from_labels_position(out.logits, batch["labels"], series)
        for ser, pred, truth in zip(series, yhat, ytrue):
            if ser in truths:
                truths[ser].append(float(truth))
                preds[ser].append(float(pred))

    result = {}
    for ser in tracked:
        if not truths[ser]:
            result[ser] = float("inf")
            continue
        err = np.asarray(preds[ser], dtype=float) - np.asarray(truths[ser], dtype=float)
        result[ser] = float(np.mean(err ** 2))
    return result

# Report per-series validation EV-MSE after the warm-up, then switch back to training mode
# with every LM parameter trainable.
mse_by_ser = val_ev_mse_by_series()
print("[EV-WARMUP][VAL] EV_MSE_by_series:", mse_by_ser, "worst=", max(mse_by_ser.values()))
lm.train()
lm.requires_grad_(True)

def _norm_seq(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    x = np.asarray(x, dtype=float)
    if x.size < 2:
        return x
    m = x.mean()
    s = x.std()
    if s < eps:
        return x - m
    return (x - m) / (s + eps)


import gc
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.amp import autocast, GradScaler
from scipy.stats import spearmanr
from typing import Dict, Any, List

# Fallback definitions, used only if this cell runs without the earlier Phase 8 cell.
if "split_facet_z" not in globals():
    def split_facet_z(z: torch.Tensor):
        """Split the last dim into (first H//2, remaining) halves."""
        half = z.size(-1) // 2
        return z[..., :half], z[..., half:]

if "extract_facet_Z" not in globals():
    def extract_facet_Z(hidden_states_last: torch.Tensor, facet_pos: torch.Tensor):
        """Gather hidden states at facet positions; returns (Z [B,H] zero-filled, valid [B] bool)."""
        batch, seq_len, width = hidden_states_last.shape
        Z = hidden_states_last.new_zeros((batch, width))
        valid = torch.zeros((batch,), device=hidden_states_last.device, dtype=torch.bool)
        for row in range(batch):
            p = int(facet_pos[row].item())
            if not (0 <= p < seq_len):
                continue
            Z[row] = hidden_states_last[row, p, :]
            valid[row] = True
        return Z, valid

# Phase 9 hyper-parameter defaults; values already present on cfg take precedence.
# NOTE: CFG already declares PAIRS_PER_BATCH (=32), so the 1024 default below never applies.
for _name, _default in (
    ("MSE_GATE", 4.5),
    ("SPEARMAN_GATE", 0.85),
    ("GEOM_PROBE_MAX_PER_SER", 600),
    ("GEOM_PROBE_PAIRS", 2000),
    ("GEOM_PROBE_EVERY_EPOCH", 1),
    ("TRAIN_SPEARMAN_EVERY", getattr(cfg, "PRINT_EVERY_STEPS", 50)),
    ("TRAIN_SPEARMAN_PAIRS", 256),
    ("TRAIN_SPEARMAN_MAX_PER_SER", 64),
    ("PAIRS_PER_BATCH", 1024),
    ("GEOM_MARGIN", 0.05),
    ("NUM_WARMUP_STEPS", 1200),
    ("NUM_WARMUP_SCALE", 1.0),
    ("TRIPLET_PER_SER", 32),
    ("TRIPLET_MARGIN", 0.2),
):
    if not hasattr(cfg, _name):
        setattr(cfg, _name, _default)

# Mixed precision is off; a disabled GradScaler simply passes calls through.
USE_AMP = False
AMP_DTYPE = torch.float16
scaler = GradScaler("cuda", enabled=USE_AMP)

# Loss weights. NOTE(review): the *_SPEARMAN set presumably replaces the *_BASE set in
# Phase 9 once the MSE/Spearman gates are met. Confirm against the training loop.
L_CE_BASE      = 0.8
L_VALTOK_BASE  = 3.0
L_REG_BASE     = 0.0
L_GEOM_BASE    = 5.0
L_TRIP_BASE    = 5.0

L_VALTOK_SPEARMAN = 1.0
L_GEOM_SPEARMAN   = 10.0
L_TRIP_SPEARMAN   = 8.0
L_CE_SPEARMAN     = 0.6

lm.gradient_checkpointing_enable()
lm.config.use_cache = False

if "value_head" not in globals() or value_head is None:
    print("**CREATED NEW HEAD***")
    value_head = nn.Linear(lm.config.n_embd // 2, 1).to(cfg.DEVICE)
else:
    print("**REUSED EXISTING HEAD")
    value_head = value_head.to(cfg.DEVICE)

# A reused head may have been frozen for the EV warm-up (requires_grad=False, eval()).
# Re-enable it; otherwise its parameters go to the optimizer below but never get gradients.
value_head.requires_grad_(True)
value_head.train()

trainable_params = [p for p in lm.parameters() if p.requires_grad]
opt = AdamW(trainable_params + list(value_head.parameters()),
            lr=cfg.LR, weight_decay=cfg.WEIGHT_DECAY)

def get_supervised_positions(labels: torch.Tensor) -> torch.Tensor:
    """Return the (row, column) index of the single supervised token in each row.

    Args:
        labels: 2-D label tensor ``[B, T]`` in which ``-100`` marks ignored
            positions. Every row must contain exactly one other entry.

    Returns:
        Long tensor of shape ``[B, 2]``. Row ``i`` is ``[i, t_i]``, where ``t_i``
        is the supervised position of sample ``i``, in batch order.

    Raises:
        ValueError: if any row has zero or more than one supervised token. This
            uses an explicit check rather than ``assert`` so it still fires
            under ``python -O``.
    """
    supervised = labels != -100
    B = labels.size(0)
    # Validate per row: a row with 0 tokens and another with 2 would otherwise
    # pass a total-count check while being misaligned.
    per_row = supervised.sum(dim=1)
    if not bool(torch.all(per_row == 1)):
        bad_rows = torch.nonzero(per_row != 1, as_tuple=False).flatten().tolist()
        raise ValueError(
            f"Expected exactly 1 supervised token per sample; rows {bad_rows} "
            f"violate this (B={B})."
        )
    pos = supervised.nonzero(as_tuple=False)
    # Each row contributes exactly one entry, so sorting by row yields batch order.
    pos = pos[pos[:, 0].argsort()]
    return pos

@torch.no_grad()
def expected_numeric_from_labels_position(logits: torch.Tensor, labels: torch.Tensor, series_list):
    """Decode target and predicted numeric values at each supervised position.

    For each batch row, the label token at its supervised position is parsed with
    ``token_to_float``; 0.0 is used if parsing fails. The prediction is the
    expected value over the series' allowed value tokens
    (``expected_value_from_logits``) taken from the logits at that same position.
    When a series has no allowed tokens, the prediction simply echoes the target.

    Returns:
        ``(yhat, ytrue)``: two Python lists of floats, one entry per batch row.
    """
    batch_size, _, _ = logits.shape
    positions = get_supervised_positions(labels)
    preds = []
    targets = []
    for row in range(batch_size):
        col = int(positions[row, 1].item())
        series = series_list[row]
        target_tok = tokenizer.convert_ids_to_tokens(int(labels[row, col].item()))
        try:
            target = float(token_to_float(target_tok))
        except Exception:
            target = 0.0

        value_ids = allowed_ids[series]
        if value_ids.numel() > 0:
            row_logits = logits[row:row + 1, col, :]
            pred = float(expected_value_from_logits(row_logits, value_ids, id_to_val[series])[0].item())
        else:
            # No candidate value tokens for this series: echo the target.
            pred = target

        preds.append(pred)
        targets.append(target)
    return preds, targets

def per_series_mse_numpy(yhat, ytrue, series_list):
    """Mean squared error per ESG facet.

    Args:
        yhat, ytrue: indexable sequences of predicted and true values.
        series_list: series tag for each position ("ESG", "ENV", "SOC", "GOV").

    Returns:
        Dict mapping each of the four known tags that occurs in ``series_list``
        to its MSE. Tags with no samples, and unknown tags, are omitted.
    """
    result = {}
    for series_name in ("ESG", "ENV", "SOC", "GOV"):
        squared = [
            (yhat[k] - ytrue[k]) ** 2
            for k, tag in enumerate(series_list)
            if tag == series_name
        ]
        if squared:
            result[series_name] = float(np.mean(squared))
    return result

def value_mse_loss_from_batch(logits: torch.Tensor, labels: torch.Tensor, batch) -> torch.Tensor:
    """Differentiable MSE between expected-value predictions and parsed targets.

    For every row, the logits at the supervised position yield a prediction via
    ``expected_value_from_logits`` over that row's series-specific allowed ids.
    The target is the label token parsed with ``token_to_float`` (0.0 on failure).
    Rows whose series has no allowed ids use the target itself as the prediction,
    so they contribute zero error and no gradient.

    Returns:
        Scalar float32 MSE tensor on ``logits.device``.
    """
    batch_size, _, _ = logits.shape
    positions = get_supervised_positions(labels)
    rows, cols = positions[:, 0], positions[:, 1]
    supervised_logits = logits[rows, cols, :]
    series_list = batch["target_series"]
    dev = logits.device

    pred_chunks = []
    target_vals = []
    for k in range(batch_size):
        series = series_list[k]
        value_ids = allowed_ids[series]

        target_tok = tokenizer.convert_ids_to_tokens(int(labels[rows[k], cols[k]].item()))
        try:
            target = float(token_to_float(target_tok))
        except Exception:
            target = 0.0
        target_vals.append(target)

        if value_ids.numel() == 0:
            pred_chunks.append(torch.tensor([target], device=dev, dtype=torch.float32))
            continue
        expected = expected_value_from_logits(supervised_logits[k:k + 1], value_ids, id_to_val[series])
        pred_chunks.append(expected.to(dtype=torch.float32))

    y_pred = torch.cat(pred_chunks, dim=0).to(device=dev, dtype=torch.float32)
    y_true = torch.tensor(target_vals, device=dev, dtype=torch.float32)
    return F.mse_loss(y_pred, y_true)

def regression_mse_from_batch(out, batch) -> torch.Tensor:
    """MSE of the regression head applied to the facet representation.

    The last hidden layer is sampled at ``batch["facet_pos"]`` (via
    ``extract_facet_Z``), split with ``split_facet_z``, and its second part is fed
    to ``value_head``. The target is the supervised label token parsed with
    ``token_to_float`` (0.0 on failure). All rows participate, including those
    ``extract_facet_Z`` marks invalid.
    """
    facet_z, _valid = extract_facet_Z(out.hidden_states[-1], batch["facet_pos"])
    _, z_pred = split_facet_z(facet_z)
    predictions = value_head(z_pred).squeeze(-1).float()

    labels = batch["labels"]
    positions = get_supervised_positions(labels)
    targets = []
    for k in range(labels.size(0)):
        tok = tokenizer.convert_ids_to_tokens(int(labels[positions[k, 0], positions[k, 1]].item()))
        try:
            targets.append(float(token_to_float(tok)))
        except Exception:
            targets.append(0.0)
    target_t = torch.tensor(targets, device=cfg.DEVICE, dtype=torch.float32)
    return F.mse_loss(predictions, target_t)

@torch.no_grad()
def _dtw_dist(i, j, ser, traj_vals_list):
    """Banded DTW distance between the normalized ``ser`` trajectories of samples i and j.

    Each trajectory is passed through ``_norm_seq`` and compared with
    ``dtw_sakoe_chiba`` using ``cfg.DTW_BAND``. Returns 0.0 when either trajectory
    is missing or empty; callers cannot tell this apart from a perfect match.
    """
    seq_a = traj_vals_list[i].get(ser, []) or []
    seq_b = traj_vals_list[j].get(ser, []) or []
    if not len(seq_a) or not len(seq_b):
        return 0.0
    norm_a, norm_b = (_norm_seq(np.array(s, dtype=float)).tolist() for s in (seq_a, seq_b))
    return float(dtw_sakoe_chiba(norm_a, norm_b, band=cfg.DTW_BAND))

def facet_triplet_loss(hidden_states_last: torch.Tensor, batch) -> torch.Tensor:
    """Triplet hinge aligning facet geometry with per-series DTW ordering.

    For random triplets (i, j, k) drawn within one target series, whichever of
    j/k has the smaller DTW distance to i (``_dtw_dist``) is the positive. The
    loss is ``relu(||z_i - z_pos|| - ||z_i - z_neg|| + cfg.TRIPLET_MARGIN)`` on
    L2-normalized Z_geom facet vectors. Triplets with tied DTW distances are
    skipped.

    Only samples with a valid facet position AND a non-empty trajectory for
    their series participate. ``_dtw_dist`` returns 0.0 for a missing
    trajectory, which would otherwise read as "identical" and pull unrelated
    samples together.

    Returns:
        Scalar tensor: the mean hinge over informative triplets, or 0.0 (no
        grad) on ``hidden_states_last.device`` when there are none.
    """
    B, _, _ = hidden_states_last.shape
    series_list = batch["target_series"]
    traj_vals_list = batch["traj_vals"]

    Z, valid = extract_facet_Z(hidden_states_last, batch["facet_pos"])
    Z_geom, _ = split_facet_z(Z)
    Z_geom = F.normalize(Z_geom.float(), p=2, dim=-1)

    idx_by_ser: Dict[str, List[int]] = {}
    for i in range(B):
        if not bool(valid[i].item()):
            continue
        ser = series_list[i]
        if len(traj_vals_list[i].get(ser, []) or []) == 0:
            continue
        idx_by_ser.setdefault(ser, []).append(i)

    rng = np.random.default_rng(cfg.SEED + 12345 + int(torch.randint(0, 100000, (1,)).item()))
    margin = float(cfg.TRIPLET_MARGIN)

    losses = []
    for ser, idxs in idx_by_ser.items():
        if len(idxs) < 3:
            continue
        n_trip = min(int(cfg.TRIPLET_PER_SER), len(idxs) * 4)
        for _ in range(n_trip):
            i, j, k = (int(x) for x in rng.choice(idxs, size=3, replace=False))
            dij = _dtw_dist(i, j, ser, traj_vals_list)
            dik = _dtw_dist(i, k, ser, traj_vals_list)
            if abs(dij - dik) < 1e-6:
                continue  # no ordering signal

            pos_idx, neg_idx = (j, k) if dij < dik else (k, j)
            z_close = torch.norm(Z_geom[i] - Z_geom[pos_idx], p=2)
            z_far = torch.norm(Z_geom[i] - Z_geom[neg_idx], p=2)
            losses.append(F.relu(z_close - z_far + margin))

    if not losses:
        return torch.tensor(0.0, device=hidden_states_last.device)
    return torch.stack(losses).mean()

def geom_dtw_loss_from_batch(hidden_states_last: torch.Tensor, batch: Dict[str, Any]) -> torch.Tensor:
    """
    Rank loss:
      if DTW(a,b) < DTW(a,c) then dist(z_a,z_b)+m < dist(z_a,z_c)

    ``dist`` is cosine distance (1 - dot product) between L2-normalized Z_geom
    facet vectors. DTW is computed (``dtw_sakoe_chiba`` with ``cfg.DTW_BAND``) on
    the ``_norm_seq``-normalized last ``cfg.K`` values of each sample's target-series
    trajectory. Triplets are drawn within a series among samples with a valid
    facet position. A triplet is skipped when any trajectory has fewer than 3
    points, the DTW distances tie, or an embedding distance is non-finite.

    Returns:
        Scalar tensor on ``hidden_states_last.device``. It is the mean hinge over
        informative triplets, or 0.0 if there are none or the mean is non-finite.

    This is deliberately best-effort: any exception yields 0.0 so training can
    continue, but the failure is reported via ``warnings.warn`` (RuntimeWarning)
    rather than being discarded silently.
    """
    import warnings

    device = hidden_states_last.device
    try:
        B, T, Hdim = hidden_states_last.shape
        series_list = batch["target_series"]
        traj_vals_list = batch["traj_vals"]

        Z_facet, valid = extract_facet_Z(hidden_states_last, batch["facet_pos"])
        Z_geom, _ = split_facet_z(Z_facet)
        Z_geom = F.normalize(Z_geom.float(), p=2, dim=-1)

        idx_by_ser: Dict[str, List[int]] = {}
        for i in range(B):
            if bool(valid[i].item()):
                idx_by_ser.setdefault(series_list[i], []).append(i)

        rng = np.random.default_rng(cfg.SEED + int(torch.randint(0, 10**6, (1,)).item()))
        margin = float(getattr(cfg, "GEOM_MARGIN", 0.05))

        losses = []
        for ser, idxs in idx_by_ser.items():
            if len(idxs) < 3:
                continue

            # Split the per-batch triplet budget across the series present, with a
            # floor of 64 triplets per series.
            n_trip = max(64, int(cfg.PAIRS_PER_BATCH // max(1, len(idx_by_ser))))
            for _ in range(n_trip):
                a, b, c = rng.choice(idxs, size=3, replace=False)

                xa = np.array((traj_vals_list[a].get(ser, []) or [])[-cfg.K:], dtype=float)
                xb = np.array((traj_vals_list[b].get(ser, []) or [])[-cfg.K:], dtype=float)
                xc = np.array((traj_vals_list[c].get(ser, []) or [])[-cfg.K:], dtype=float)
                if xa.size < 3 or xb.size < 3 or xc.size < 3:
                    continue

                xa = _norm_seq(xa).tolist()
                xb = _norm_seq(xb).tolist()
                xc = _norm_seq(xc).tolist()

                dab = float(dtw_sakoe_chiba(xa, xb, band=cfg.DTW_BAND))
                dac = float(dtw_sakoe_chiba(xa, xc, band=cfg.DTW_BAND))
                if abs(dab - dac) < 1e-6:
                    continue

                zab = 1.0 - torch.sum(Z_geom[a] * Z_geom[b])
                zac = 1.0 - torch.sum(Z_geom[a] * Z_geom[c])
                if not (torch.isfinite(zab) and torch.isfinite(zac)):
                    continue

                if dab < dac:
                    losses.append(F.relu(zab - zac + margin))
                else:
                    losses.append(F.relu(zac - zab + margin))

        if len(losses) == 0:
            return torch.tensor(0.0, device=device)

        loss = torch.stack(losses).mean()
        return loss if torch.isfinite(loss) else torch.tensor(0.0, device=device)

    except Exception as exc:
        warnings.warn(
            f"geom_dtw_loss_from_batch failed ({type(exc).__name__}: {exc}); "
            f"using 0.0 for this batch.",
            RuntimeWarning,
            stacklevel=2,
        )
        return torch.tensor(0.0, device=device)

@torch.no_grad()
def train_geom_spearman_from_batch(hidden_states_last: torch.Tensor, batch: Dict[str, Any],
                                   max_per_ser: int, max_pairs: int) -> Dict[str, float]:
    """
    Cheap in-batch proxy for the geometry/DTW alignment.

    For each of ESG/ENV/SOC/GOV, up to ``max_per_ser`` valid samples of this batch
    are kept, and up to ``max_pairs`` random pairs are drawn (with replacement
    across pairs). The result is the Spearman correlation between the banded DTW
    distance of their last ``cfg.K`` normalized values and the cosine distance of
    their L2-normalized Z_geom vectors. Pairs with a trajectory shorter than 3
    points are dropped. A series gets NaN when it has fewer than 6 samples or
    fewer than 10 usable pairs.
    """
    series_order = ["ESG", "ENV", "SOC", "GOV"]
    batch_size, _, _ = hidden_states_last.shape
    series_list = batch["target_series"]
    traj_vals_list = batch["traj_vals"]

    facet_z, valid = extract_facet_Z(hidden_states_last, batch["facet_pos"])
    z_geom, _ = split_facet_z(facet_z)
    z_geom = F.normalize(z_geom.float(), p=2, dim=-1)

    members: Dict[str, List[int]] = {name: [] for name in series_order}
    for row in range(batch_size):
        if not bool(valid[row].item()):
            continue
        bucket = members.get(series_list[row])
        if bucket is not None and len(bucket) < int(max_per_ser):
            bucket.append(row)

    rng = np.random.default_rng(cfg.SEED + 4242 + int(torch.randint(0, 10**6, (1,)).item()))
    result = {}

    for name in series_order:
        rows = members.get(name, [])
        count = len(rows)
        # Fewer than 6 samples -> no pairs; otherwise cap by the number of distinct pairs.
        budget = min(int(max_pairs), count * (count - 1) // 2) if count >= 6 else 0
        if budget < 10:
            result[name] = float("nan")
            continue

        dtw_vals = []
        geo_vals = []
        for _ in range(budget):
            a, b = rng.choice(rows, size=2, replace=False)
            seq_a = np.array((traj_vals_list[a].get(name, []) or [])[-cfg.K:], dtype=float)
            seq_b = np.array((traj_vals_list[b].get(name, []) or [])[-cfg.K:], dtype=float)
            if seq_a.size < 3 or seq_b.size < 3:
                continue
            dtw_vals.append(float(dtw_sakoe_chiba(_norm_seq(seq_a).tolist(),
                                                  _norm_seq(seq_b).tolist(),
                                                  band=cfg.DTW_BAND)))
            geo_vals.append(float((1.0 - torch.sum(z_geom[a] * z_geom[b])).detach().cpu().item()))

        if len(dtw_vals) >= 10:
            result[name] = float(spearmanr(dtw_vals, geo_vals).correlation)
        else:
            result[name] = float("nan")

    return result

# Module-level training state. global_step is incremented once per training batch
# in run_epoch_train and drives the warm-up scale in compute_total_loss; stage
# ("MSE" or "SPEARMAN") selects which loss-weight set compute_total_loss uses.
global_step = 0
stage = "MSE"

def compute_total_loss(out, batch, use_geom=True, use_valtok=True, use_reg=True, use_trip=True):
    """Combine CE with the auxiliary losses using the current stage's weights.

    The weight set depends on the module-level ``stage`` ("SPEARMAN" vs. any other
    value). The regression weight is always ``L_REG_BASE``. Disabled auxiliary
    terms are zero tensors. The value-token and regression terms are jointly
    scaled by ``cfg.NUM_WARMUP_SCALE`` while ``global_step <= cfg.NUM_WARMUP_STEPS``.

    Returns:
        ``(total, ce, geom, val_tok, reg, trip)`` as tensors.
    """
    if stage == "SPEARMAN":
        w_ce, w_valtok, w_geom, w_trip = L_CE_SPEARMAN, L_VALTOK_SPEARMAN, L_GEOM_SPEARMAN, L_TRIP_SPEARMAN
    else:
        w_ce, w_valtok, w_geom, w_trip = L_CE_BASE, L_VALTOK_BASE, L_GEOM_BASE, L_TRIP_BASE
    w_reg = L_REG_BASE

    def _zero():
        return torch.tensor(0.0, device=cfg.DEVICE)

    ce = out.loss
    # Evaluation order matters: geom and trip each draw from the torch RNG.
    geom = geom_dtw_loss_from_batch(out.hidden_states[-1], batch) if use_geom else _zero()
    val_tok = value_mse_loss_from_batch(out.logits, batch["labels"], batch) if use_valtok else _zero()
    reg = regression_mse_from_batch(out, batch) if use_reg else _zero()
    trip = facet_triplet_loss(out.hidden_states[-1], batch) if use_trip else _zero()

    warm = cfg.NUM_WARMUP_SCALE if global_step <= cfg.NUM_WARMUP_STEPS else 1.0
    total = w_ce * ce + w_geom * geom + warm * (w_valtok * val_tok + w_reg * reg) + w_trip * trip
    return total, ce, geom, val_tok, reg, trip

def _first_trainable_param():
    """Return ``(name, param)`` for the first ``lm`` parameter with ``requires_grad``, or ``(None, None)``."""
    trainable = ((pname, param) for pname, param in lm.named_parameters() if param.requires_grad)
    return next(trainable, (None, None))

@torch.no_grad()
def batch_reg_mse_numpy(out, batch):
    """Regression-head MSE for one batch, overall and per series (no grad, numpy).

    Predictions come from ``value_head`` on the Z_pred half of the facet
    representation. Targets are the supervised label tokens parsed with
    ``token_to_float`` (0.0 on failure).

    Returns:
        ``(mse, per_ser)``. ``mse`` is NaN for an empty batch. ``per_ser`` maps
        each known series present in the batch to its MSE.
    """
    facet_z, _valid = extract_facet_Z(out.hidden_states[-1], batch["facet_pos"])
    _, z_pred = split_facet_z(facet_z)
    preds = value_head(z_pred).squeeze(-1).float().detach().cpu().numpy()

    labels = batch["labels"]
    positions = get_supervised_positions(labels)
    parsed = []
    for k in range(labels.size(0)):
        tok = tokenizer.convert_ids_to_tokens(int(labels[positions[k, 0], positions[k, 1]].item()))
        try:
            parsed.append(float(token_to_float(tok)))
        except Exception:
            parsed.append(0.0)
    targets = np.array(parsed, dtype=float)

    overall = float(np.mean((preds - targets) ** 2)) if targets.size else float("nan")
    by_series = {}
    for name in ["ESG", "ENV", "SOC", "GOV"]:
        members = [k for k, tag in enumerate(batch["target_series"]) if tag == name]
        if members:
            by_series[name] = float(np.mean([(preds[k] - targets[k]) ** 2 for k in members]))
    return overall, by_series

@torch.no_grad()
def val_ev_mse_by_series() -> dict:
    """Validation MSE of the expected-value token prediction, per series.

    Runs ``lm`` (eval mode, under the per-series attention bias) over ``val_dl``.
    Predictions are decoded with ``expected_numeric_from_labels_position``.

    Returns:
        Dict for ESG/ENV/SOC/GOV. A series with no validation samples maps to
        ``inf``, so it can never satisfy the MSE gate.
    """
    lm.eval()
    names = ["ESG", "ENV", "SOC", "GOV"]
    truth = {name: [] for name in names}
    guess = {name: [] for name in names}

    for batch in val_dl:
        attn_bias = build_attn_bias(batch["input_ids"], batch["target_series"])
        with attn_bias_context(attn_bias):
            out = lm(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
                output_hidden_states=False,
            )
        yhat, ytrue = expected_numeric_from_labels_position(out.logits, batch["labels"], batch["target_series"])
        for k, ser in enumerate(batch["target_series"]):
            if ser not in truth:
                continue
            truth[ser].append(float(ytrue[k]))
            guess[ser].append(float(yhat[k]))

    result = {}
    for ser in names:
        if not truth[ser]:
            result[ser] = float("inf")
            continue
        diff = np.array(guess[ser], dtype=float) - np.array(truth[ser], dtype=float)
        result[ser] = float(np.mean(diff ** 2))
    return result

@torch.no_grad()
def val_geom_spearman_by_series(max_per_ser=None, max_pairs=None) -> dict:
    """Validation Spearman between DTW distance and Z_geom distance, per series.

    Collects up to ``max_per_ser`` (default ``cfg.GEOM_PROBE_MAX_PER_SER``) valid
    samples per series from ``val_dl``, each as (L2-normalized Z_geom,
    last ``cfg.K`` trajectory values). It then samples up to ``max_pairs``
    (default ``cfg.GEOM_PROBE_PAIRS``) random pairs with a fixed seed. Euclidean
    distance on unit vectors is a monotonic function of the cosine distance used
    in training, so the rank correlation is the same.

    The forward pass runs under the same per-series attention bias
    (``build_attn_bias``/``attn_bias_context``) as training and
    ``val_ev_mse_by_series``. The probed geometry is therefore the one the model
    produces in training, and both gates are measured under the same attention
    pattern.

    Returns:
        Dict for ESG/ENV/SOC/GOV. NaN when fewer than 20 samples or pairs are
        available.
    """
    lm.eval(); value_head.eval()
    if max_per_ser is None: max_per_ser = int(cfg.GEOM_PROBE_MAX_PER_SER)
    if max_pairs is None:   max_pairs = int(cfg.GEOM_PROBE_PAIRS)

    store = {s: [] for s in ["ESG","ENV","SOC","GOV"]}

    for batch in val_dl:
        attn_bias = build_attn_bias(batch["input_ids"], batch["target_series"])
        with attn_bias_context(attn_bias):
            out = lm(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
                output_hidden_states=True,
            )
        h_last = out.hidden_states[-1]
        Z, valid = extract_facet_Z(h_last, batch["facet_pos"])
        Z_geom, _ = split_facet_z(Z)
        Z_geom = F.normalize(Z_geom.float(), p=2, dim=-1)

        B = Z_geom.size(0)
        for i in range(B):
            ser = batch["target_series"][i]
            if ser not in store or len(store[ser]) >= max_per_ser:
                continue
            if not bool(valid[i].item()):
                continue
            z = Z_geom[i].detach().cpu().numpy()
            seq = np.array((batch["traj_vals"][i].get(ser, []) or [])[-cfg.K:], dtype=float)
            if seq.size == 0:
                continue
            store[ser].append((z, seq))

        # Stop early once every series has its quota.
        if all(len(store[s]) >= max_per_ser for s in store):
            break

    rng = np.random.default_rng(cfg.SEED + 999)
    out = {}
    for ser in ["ESG","ENV","SOC","GOV"]:
        arr = store[ser]
        if len(arr) < 20:
            out[ser] = float("nan")
            continue

        n = len(arr)
        idx = np.arange(n)
        max_possible = n * (n - 1) // 2
        n_pairs = min(int(max_pairs), int(max_possible))

        dtw_list = []
        z_list = []
        for _ in range(n_pairs):
            i, j = rng.choice(idx, size=2, replace=False)
            zi, hi = arr[i]
            zj, hj = arr[j]

            hi_n = _norm_seq(hi)
            hj_n = _norm_seq(hj)
            dd = float(dtw_sakoe_chiba(hi_n.tolist(), hj_n.tolist(), band=cfg.DTW_BAND))

            zd = float(np.linalg.norm(zi - zj))
            dtw_list.append(dd)
            z_list.append(zd)

        out[ser] = float(spearmanr(dtw_list, z_list).correlation) if len(dtw_list) >= 20 else float("nan")

    return out

def run_epoch_train(epoch: int, use_geom=True, use_valtok=True, use_reg=True, use_trip=True):
    """Run one training epoch over ``train_dl`` and return mean loss components.

    Each batch is forwarded (and its loss computed) under the per-series attention
    bias from ``build_attn_bias``. ``compute_total_loss`` combines CE, the DTW
    rank geometry loss, value-token MSE, regression-head MSE and the triplet loss;
    the ``use_*`` flags toggle each auxiliary term. A batch whose total or any
    component is non-finite is skipped without an optimizer step. On step 1 and
    every ``cfg.PRINT_EVERY_STEPS`` steps, a diagnostic line is printed with
    running loss means, in-batch EV/regression MSE and an in-batch geometry
    Spearman proxy.

    Side effects: increments the module-level ``global_step`` once per batch,
    updates ``lm`` and ``value_head`` through ``opt``, and prints ``[SANITY]``
    lines showing how far the first trainable LM parameter and
    ``value_head.weight`` moved during the epoch.

    Returns:
        Dict with the mean ``total``/``ce``/``geom``/``val_tok``/``reg``/``trip``
        over non-skipped batches, plus the ``skipped`` count.
    """
    global global_step
    lm.train(); value_head.train()

    # Snapshot the first trainable LM parameter and the head weight so the
    # end-of-epoch [SANITY] prints can show whether the optimizer changed them.
    name, p0 = _first_trainable_param()
    w_before = p0.detach().float().clone() if p0 is not None else None
    head_before = value_head.weight.detach().float().clone()

    total_sum = ce_sum = geom_sum = vt_sum = reg_sum = trip_sum = 0.0
    step = 0
    skipped = 0

    for batch in train_dl:
        step += 1
        global_step += 1
        opt.zero_grad(set_to_none=True)

        # Forward pass and loss both run inside the attention-bias context.
        bias = build_attn_bias(batch["input_ids"], batch["target_series"])
        with attn_bias_context(bias):
            if USE_AMP:
                with autocast(device_type="cuda", dtype=AMP_DTYPE):
                    out = lm(
                        input_ids=batch["input_ids"],
                        attention_mask=batch["attention_mask"],
                        labels=batch["labels"],
                        output_hidden_states=True,
                    )
                    total, ce, geom, vt, reg, trip = compute_total_loss(out, batch, use_geom, use_valtok, use_reg, use_trip)
            else:
                out = lm(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"],
                    output_hidden_states=True,
                )
                total, ce, geom, vt, reg, trip = compute_total_loss(out, batch, use_geom, use_valtok, use_reg, use_trip)

        # Skip the update entirely if any component is NaN/Inf. Skipped batches are
        # excluded from the running means (denominator step - skipped).
        if not (torch.isfinite(total) and torch.isfinite(ce) and torch.isfinite(geom)
                and torch.isfinite(vt) and torch.isfinite(reg) and torch.isfinite(trip)):
            skipped += 1
            continue

        if USE_AMP:
            scaler.scale(total).backward()
            # Unscale before clipping so the 1.0 max-norm applies to the true gradients.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(trainable_params + list(value_head.parameters()), 1.0)
            scaler.step(opt)
            scaler.update()
        else:
            total.backward()
            torch.nn.utils.clip_grad_norm_(trainable_params + list(value_head.parameters()), 1.0)
            opt.step()

        total_sum += float(total.detach().cpu().item())
        ce_sum    += float(ce.detach().cpu().item())
        geom_sum  += float(geom.detach().cpu().item())
        vt_sum    += float(vt.detach().cpu().item())
        reg_sum   += float(reg.detach().cpu().item())
        trip_sum  += float(trip.detach().cpu().item())

        do_print = ((step % int(cfg.PRINT_EVERY_STEPS)) == 0 or step == 1)
        if do_print:
            # Diagnostics only: computed after opt.step() from this batch's outputs
            # and not fed back into training.
            yhat, ytrue = expected_numeric_from_labels_position(out.logits.detach(), batch["labels"], batch["target_series"])
            ev_mse_batch = float(np.mean([(a - b) ** 2 for a, b in zip(yhat, ytrue)])) if len(yhat) else float("nan")
            ev_by_ser = per_series_mse_numpy(yhat, ytrue, batch["target_series"])

            reg_mse_batch, reg_by_ser = batch_reg_mse_numpy(out, batch)

            sp_train = train_geom_spearman_from_batch(
                out.hidden_states[-1], batch,
                max_per_ser=int(cfg.TRAIN_SPEARMAN_MAX_PER_SER),
                max_pairs=int(cfg.TRAIN_SPEARMAN_PAIRS),
            )
            # nanmin ignores series whose in-batch Spearman could not be computed.
            sp_worst = np.nanmin([sp_train.get(s, np.nan) for s in ["ESG","ENV","SOC","GOV"]])

            print(
                f"[TRAIN] stage={stage} epoch={epoch} step={step}/{len(train_dl)} "
                f"tot={total_sum/max(1, step-skipped):.4f} ce={ce_sum/max(1, step-skipped):.4f} "
                f"geom={geom_sum/max(1, step-skipped):.4f} valTok={vt_sum/max(1, step-skipped):.4f} "
                f"regLoss={reg_sum/max(1, step-skipped):.4f} trip={trip_sum/max(1, step-skipped):.4f} | "
                f"EV_MSE={ev_mse_batch:.3f} EV_by_ser={ev_by_ser} | "
                f"REG_MSE={reg_mse_batch:.3f} REG_by_ser={reg_by_ser} | "
                f"SP_train={sp_train} SP_worst={sp_worst:.3f} "
                f"cuda_alloc={torch.cuda.memory_allocated()/1e9:.2f}GB"
            )

        # Periodic memory cleanup. NOTE(review): `gc` is not imported in the visible
        # header; confirm it is imported elsewhere in the file.
        if (step % 50) == 0:
            gc.collect()
            torch.cuda.empty_cache()

    denom = max(1, step - skipped)
    if p0 is not None and w_before is not None:
        # The label says "LoRA"; presumably the trainable LM params are LoRA adapters (confirm).
        w_after = p0.detach().float()
        delta_max = (w_after - w_before).abs().max().item()
        delta_mean = (w_after - w_before).abs().mean().item()
        print(f"[SANITY] LoRA param={name} delta_max={delta_max:.6e} delta_mean_abs={delta_mean:.6e}")
    head_delta = (value_head.weight.detach().float() - head_before).abs().max().item()
    print(f"[SANITY] value_head.weight delta_max={head_delta:.6e}")

    return {"total": total_sum/denom, "ce": ce_sum/denom, "geom": geom_sum/denom,
            "val_tok": vt_sum/denom, "reg": reg_sum/denom, "trip": trip_sum/denom,
            "skipped": skipped}

@torch.no_grad()
def run_epoch_val_losses(epoch: int, use_geom=True, use_valtok=True, use_reg=True, use_trip=True):
    """Average the training loss components over ``val_dl`` without updating weights.

    Uses the same forward pass (under the per-series attention bias) and the same
    ``compute_total_loss`` as training, so the stage-dependent weights apply.
    Batches with any non-finite component are counted as skipped and excluded
    from the means. A one-line ``[VAL-LOSSES]`` summary is printed.

    Returns:
        Dict with the mean ``total``/``ce``/``geom``/``val_tok``/``reg``/``trip``
        and the ``skipped`` count.
    """
    lm.eval(); value_head.eval()

    keys = ("total", "ce", "geom", "val_tok", "reg", "trip")
    sums = dict.fromkeys(keys, 0.0)
    n_batches = 0
    n_skipped = 0

    for batch in val_dl:
        n_batches += 1
        attn_bias = build_attn_bias(batch["input_ids"], batch["target_series"])
        with attn_bias_context(attn_bias):
            out = lm(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
                output_hidden_states=True,
            )
            parts = compute_total_loss(out, batch, use_geom, use_valtok, use_reg, use_trip)

        if not all(torch.isfinite(t) for t in parts):
            n_skipped += 1
            continue
        for key, tensor in zip(keys, parts):
            sums[key] += float(tensor.detach().cpu().item())

    denom = max(1, n_batches - n_skipped)
    means = {key: value / denom for key, value in sums.items()}
    print(
        f"[VAL-LOSSES] stage={stage} epoch={epoch} tot={means['total']:.4f} ce={means['ce']:.4f} "
        f"geom={means['geom']:.4f} valTok={means['val_tok']:.4f} reg={means['reg']:.4f} trip={means['trip']:.4f} "
        f"skipped={n_skipped}"
    )
    return {**means, "skipped": n_skipped}

best_state = None
best_head = None

patience = 3
bad = 0

stage = "MSE"
best_mse_metric = float("inf")
best_spearman_metric = -float("inf")

use_geom = True
use_valtok = True
use_reg = True
use_trip = True

GATE_ON = "EV"

# Two-stage early stopping:
#   * "MSE" stage: checkpoint on the worst per-series validation EV-MSE.
#   * "SPEARMAN" stage (entered once the EV-MSE gate is met): checkpoint on the
#     worst per-series validation geometry Spearman, but only while the EV-MSE
#     gate still holds. Epochs in which the Spearman probe is not run (and the EV
#     gate holds) are neutral and do not consume patience.
for epoch in range(1, cfg.EPOCHS + 1):
    print(f"========== EPOCH {epoch} ==========")
    tr = run_epoch_train(epoch, use_geom, use_valtok, use_reg, use_trip)
    _ = run_epoch_val_losses(epoch, use_geom, use_valtok, use_reg, use_trip)

    mse_by_ser = val_ev_mse_by_series()
    worst_mse = max(mse_by_ser[s] for s in ["ESG", "ENV", "SOC", "GOV"])

    # max(1, ...) guards against a zero/negative probe interval.
    do_geom_probe = ((epoch % max(1, int(cfg.GEOM_PROBE_EVERY_EPOCH))) == 0)
    if do_geom_probe:
        sp = val_geom_spearman_by_series(cfg.GEOM_PROBE_MAX_PER_SER, cfg.GEOM_PROBE_PAIRS)
        # The built-in min() is order-dependent with NaN (every comparison with NaN
        # is False), so an unmeasurable series could win or be ignored depending on
        # its position. Ignore NaN explicitly (same as the np.nanmin used for
        # SP_worst in the training prints); NaN only if no series was measured.
        measured_sp = [v for v in (sp.get(s, float("nan")) for s in ["ESG", "ENV", "SOC", "GOV"])
                       if not math.isnan(v)]
        worst_sp = min(measured_sp) if measured_sp else float("nan")
    else:
        sp = {s: float("nan") for s in ["ESG", "ENV", "SOC", "GOV"]}
        worst_sp = float("nan")

    mse_gate_ok = (worst_mse < cfg.MSE_GATE)
    sp_gate_ok  = (do_geom_probe and all(sp[s] > cfg.SPEARMAN_GATE for s in ["ESG","ENV","SOC","GOV"]))

    print(f"[VAL-GATE] (EV) MSE_by_series={mse_by_ser} worst={worst_mse:.4f} (gate<{cfg.MSE_GATE}) -> {mse_gate_ok}")
    if do_geom_probe:
        print(f"[VAL-GATE] spearman_by_series={sp} worst={worst_sp:.4f} (gate>{cfg.SPEARMAN_GATE}) -> {sp_gate_ok}")
    else:
        print(f"[VAL-GATE] spearman probe skipped this epoch (every {cfg.GEOM_PROBE_EVERY_EPOCH} epochs)")

    if stage == "MSE" and mse_gate_ok:
        stage = "SPEARMAN"
        bad = 0
        print("[EARLYSTOP] Stage transition: EV-MSE gate met -> now optimizing SPEARMAN (while keeping EV gate).")

    improved = False
    probe_skipped = False
    if stage == "MSE":
        metric = worst_mse
        if metric < best_mse_metric:
            best_mse_metric = metric
            improved = True
    elif mse_gate_ok and do_geom_probe:
        metric = worst_sp
        if metric > best_spearman_metric:
            best_spearman_metric = metric
            improved = True
    else:
        metric = -float("inf")
        # EV gate holds but Spearman was not measured: no evidence either way.
        probe_skipped = mse_gate_ok and not do_geom_probe

    if improved:
        bad = 0
        best_state = {k: v.detach().cpu().clone() for k, v in lm.state_dict().items()}
        best_head  = {k: v.detach().cpu().clone() for k, v in value_head.state_dict().items()}
        print(f"[EARLYSTOP] New best checkpoint saved. stage={stage} metric={metric}")
    elif probe_skipped:
        print(f"[EARLYSTOP] Spearman probe skipped; patience unchanged. stage={stage} bad={bad}/{patience}")
    else:
        bad += 1
        print(f"[EARLYSTOP] No improvement. stage={stage} bad={bad}/{patience}")
        if bad >= patience:
            print("[EARLYSTOP] stopping.")
            break

# Restore the best checkpoint (if any epoch improved) before evaluation.
if best_state is not None:
    lm.load_state_dict(best_state, strict=True)
    lm.to(cfg.DEVICE)
    if best_head is not None:
        value_head.load_state_dict(best_head, strict=True)
        value_head.to(cfg.DEVICE)
    print("[RESTORE] Loaded best checkpoint (lm + value_head).")

print("Phase 9 complete: EV-MSE gate + Spearman gate + train-step EV/REG/Spearman prints.")

"""# Geometric Booster"""


import torch.nn as nn

def has_cycle(root: nn.Module):
    """Return True if the child-module graph reachable from ``root`` contains a cycle.

    Performs an iterative depth-first search over ``Module.children()``, tracking
    the modules on the current path (by ``id``). Revisiting a module that is on
    the current path means there is a back-edge, i.e. a cycle. A module shared by
    two parents (a diamond) is not a cycle.
    """
    visited = set()
    on_path = set()
    frames = []  # stack of (module id, iterator over that module's children)

    def _enter(module):
        ident = id(module)
        visited.add(ident)
        on_path.add(ident)
        frames.append((ident, iter(module.children())))

    _enter(root)
    while frames:
        ident, kids = frames[-1]
        child = next(kids, None)
        if child is None:
            # All children explored: leave this module's path.
            frames.pop()
            on_path.discard(ident)
            continue
        child_id = id(child)
        if child_id in on_path:
            return True
        if child_id not in visited:
            _enter(child)
    return False

# Diagnostic: report whether lm's child-module graph contains a cycle (see has_cycle).
print("Cycle in lm module tree?", has_cycle(lm))

"""# Phase 10

"""




import numpy as np
import pandas as pd
import torch
from typing import List, Tuple, Dict
from collections import defaultdict

@torch.no_grad()
def extract_supervised_positions(labels: torch.Tensor) -> torch.Tensor:
    """
    Return the supervised token index of each row.

    Args:
        labels: ``[B, T]`` label tensor in which ``-100`` marks ignored positions.
            Every row must contain exactly one other entry.

    Returns:
        Long tensor ``tpos`` of shape ``[B]``, where ``tpos[i]`` is the supervised
        column of row ``i``.

    Raises:
        ValueError: if any row has zero or more than one supervised token.
            Without this check the result would have the wrong length or be
            silently misaligned with the batch.
    """
    supervised = labels != -100
    per_row = supervised.sum(dim=1)
    if not bool(torch.all(per_row == 1)):
        bad_rows = torch.nonzero(per_row != 1, as_tuple=False).flatten().tolist()
        raise ValueError(
            f"Expected exactly 1 supervised token per row; rows {bad_rows} violate this "
            f"(B={labels.size(0)})."
        )
    pos = supervised.nonzero(as_tuple=False)
    pos = pos[pos[:, 0].argsort()]
    return pos[:, 1]

@torch.no_grad()
def expected_numeric_pred_from_logits(
    out_logits: torch.Tensor,
    labels: torch.Tensor,
    series_list: List[str]
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Decode ``(y_true, y_pred)`` numeric arrays of length B from LM logits.

    The target is the supervised label token parsed with ``token_to_float``
    (0.0 on failure). The prediction is the expected value over the series'
    allowed ids (``expected_value_from_logits``). If a series has no allowed ids,
    the argmax token is parsed instead; if that also fails to parse, the
    prediction falls back to the target value.
    """
    batch_size, _, _ = out_logits.shape
    cols = extract_supervised_positions(labels)
    rows = torch.arange(batch_size, device=out_logits.device)
    supervised_logits = out_logits[rows, cols, :]

    def _parse(tok, fallback):
        try:
            return float(token_to_float(tok))
        except Exception:
            return fallback

    trues = []
    preds = []
    for k in range(batch_size):
        series = series_list[k]
        target_id = int(labels[k, int(cols[k].item())].item())
        target = _parse(tokenizer.convert_ids_to_tokens(target_id), 0.0)

        value_ids = allowed_ids[series]
        if value_ids.numel() == 0:
            best_id = int(torch.argmax(supervised_logits[k]).item())
            pred = _parse(tokenizer.convert_ids_to_tokens(best_id), target)
        else:
            expected = expected_value_from_logits(supervised_logits[k:k + 1], value_ids, id_to_val[series])
            pred = float(expected[0].item())

        trues.append(target)
        preds.append(pred)

    return np.array(trues, dtype=float), np.array(preds, dtype=float)

@torch.no_grad()
def numeric_pred_from_reg_head(
    hidden_states_last: torch.Tensor,
    labels: torch.Tensor,
    series_list: List[str],
    facet_pos: torch.Tensor
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Regression-head predictions at the facet position.

    Targets are the supervised label tokens parsed with ``token_to_float`` (0.0
    on failure). Predictions are ``value_head`` applied to the Z_pred part of the
    facet representation (``extract_facet_Z`` + ``split_facet_z``).
    ``series_list`` is accepted for signature parity with the token-EV decoder
    but is not used.

    Returns:
        ``(y_true, y_pred)`` numpy arrays of length B.
    """
    cols = extract_supervised_positions(labels)
    targets = []
    for k in range(labels.size(0)):
        tok = tokenizer.convert_ids_to_tokens(int(labels[k, int(cols[k].item())].item()))
        try:
            targets.append(float(token_to_float(tok)))
        except Exception:
            targets.append(0.0)
    y_true = np.array(targets, dtype=float)

    facet_z, _valid = extract_facet_Z(hidden_states_last, facet_pos)
    _z_geom, z_pred = split_facet_z(facet_z)
    y_pred = value_head(z_pred).squeeze(-1).float().detach().cpu().numpy()
    return y_true, y_pred

def safe_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean absolute percentage error, in percent.

    The denominator ``|y_true|`` is floored at 1e-6, so zero targets do not
    divide by zero (they produce very large percentages instead).
    """
    floored = np.maximum(np.abs(y_true), 1e-6)
    relative_err = np.abs((y_pred - y_true) / floored)
    return float(np.mean(relative_err) * 100.0)

def bias(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean signed error ``mean(y_pred - y_true)``; positive means over-prediction."""
    residual = y_pred - y_true
    return float(np.mean(residual))

def theils_u2_with_prev(y_true: np.ndarray, y_pred: np.ndarray, y_prev: np.ndarray) -> float:
    """Theil's U2: model RMSE divided by the RMSE of a naive "previous value" forecast.

    Values below 1 mean the model beats persistence. A 1e-12 term inside each
    square root and in the denominator keeps the ratio finite when errors are zero.
    """
    def _rmse(candidate):
        return np.sqrt(np.mean((candidate - y_true) ** 2) + 1e-12)

    model_rmse = _rmse(y_pred)
    naive_rmse = _rmse(y_prev)
    return float(model_rmse / (naive_rmse + 1e-12))

def _forecast_batch_rows(out, batch, split_name: str, mode: str) -> List[dict]:
    """Convert one batch's model output into per-sample prediction rows (helper for eval_forecast_metrics)."""
    if mode == "token_ev":
        y_true, y_pred = expected_numeric_pred_from_logits(out.logits, batch["labels"], batch["target_series"])
    else:
        y_true, y_pred = numeric_pred_from_reg_head(
            out.hidden_states[-1], batch["labels"], batch["target_series"], batch["facet_pos"]
        )

    rows = []
    for i, ser in enumerate(batch["target_series"]):
        hv = batch["hist_vals"][i].get(ser, [])
        # Naive persistence baseline: last observed history value, or the target
        # itself when there is no history for this series.
        y_prev = float(y_true[i]) if hv is None or len(hv) == 0 else float(hv[-1])
        rows.append({
            "split": split_name,
            "mode": mode,
            "ticker": batch["ticker"][i],
            "series": ser,
            "t_index": int(batch["t_index"][i]),
            "y_true": float(y_true[i]),
            "y_pred": float(y_pred[i]),
            "y_prev": y_prev,
        })
    return rows


def _forecast_company_table(df_rows: pd.DataFrame, split_name: str, mode: str) -> pd.DataFrame:
    """Per (series, ticker) error metrics, sorted by series then MSE (helper for eval_forecast_metrics)."""
    columns = ["split", "mode", "series", "ticker", "n", "MSE", "MAE", "MAPE", "TheilsU2", "Bias"]
    records = []
    for (ser, tick), g in df_rows.groupby(["series", "ticker"]):
        yt = g["y_true"].to_numpy(float)
        yp = g["y_pred"].to_numpy(float)
        yv = g["y_prev"].to_numpy(float)
        records.append({
            "split": split_name,
            "mode": mode,
            "series": ser,
            "ticker": tick,
            "n": len(g),
            "MSE": float(np.mean((yp - yt) ** 2)),
            "MAE": float(np.mean(np.abs(yp - yt))),
            "MAPE": safe_mape(yt, yp),
            "TheilsU2": theils_u2_with_prev(yt, yp, yv),
            "Bias": bias(yt, yp),
        })
    return pd.DataFrame(records, columns=columns).sort_values(["series", "MSE"])


def _forecast_macro_table(df_company: pd.DataFrame, split_name: str, mode: str) -> pd.DataFrame:
    """Per-series macro averages (equal weight per company) of the company metrics."""
    columns = ["split", "mode", "series", "n_companies", "MSE", "MAE", "MAPE", "TheilsU2", "Bias"]
    records = []
    for ser, g in df_company.groupby("series"):
        records.append({
            "split": split_name,
            "mode": mode,
            "series": ser,
            "n_companies": int(g["ticker"].nunique()),
            "MSE": float(g["MSE"].mean()),
            "MAE": float(g["MAE"].mean()),
            "MAPE": float(g["MAPE"].mean()),
            "TheilsU2": float(g["TheilsU2"].mean()),
            "Bias": float(g["Bias"].mean()),
        })
    return pd.DataFrame(records, columns=columns).sort_values("series")


@torch.no_grad()
def eval_forecast_metrics(
    dataloader,
    split_name: str = "TEST",
    mode: str = "token_ev",
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Evaluate one prediction head on ``dataloader``.

    Args:
        dataloader: iterable of batches (same schema as the training batches).
        split_name: label written into every output row (e.g. "VAL", "TEST").
        mode: "token_ev" (expected value over allowed value tokens) or
            "reg_head" (``value_head`` on the facet representation).

    The forward pass runs under the same per-series attention bias
    (``build_attn_bias``/``attn_bias_context``) as training and
    ``val_ev_mse_by_series``, so reported metrics describe the model as trained.

    Returns:
      df_series_macro: per series macro metrics (equal weight per company)
      df_company: per (series,ticker) metrics
      df_rows: per-sample rows with y_true/y_pred/y_prev (for later analysis)
    An empty dataloader yields empty frames with the usual columns.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    if mode not in ("token_ev", "reg_head"):
        raise ValueError(f"Unknown mode={mode}")

    lm.eval()
    value_head.eval()

    rows = []
    for batch in dataloader:
        # Named attn_bias so it does not shadow the module-level bias() metric.
        attn_bias = build_attn_bias(batch["input_ids"], batch["target_series"])
        with attn_bias_context(attn_bias):
            out = lm(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
                output_hidden_states=(mode == "reg_head"),
            )
        rows.extend(_forecast_batch_rows(out, batch, split_name, mode))

    row_columns = ["split", "mode", "ticker", "series", "t_index", "y_true", "y_pred", "y_prev"]
    df_rows = pd.DataFrame(rows, columns=row_columns)
    df_company = _forecast_company_table(df_rows, split_name, mode)
    df_series_macro = _forecast_macro_table(df_company, split_name, mode)
    return df_series_macro, df_company, df_rows

# Evaluate both prediction heads (token expected value and regression head) on VAL and TEST.
val_macro_tok, val_company_tok, _ = eval_forecast_metrics(val_dl, "VAL", mode="token_ev")
test_macro_tok, test_company_tok, df_pred_test_tok = eval_forecast_metrics(test_dl, "TEST", mode="token_ev")

val_macro_reg, val_company_reg, _ = eval_forecast_metrics(val_dl, "VAL", mode="reg_head")
test_macro_reg, test_company_reg, df_pred_test_reg = eval_forecast_metrics(test_dl, "TEST", mode="reg_head")

# Per-series macro tables.
for _title, _table in (
    ("\n[METRICS] VAL per-series macro (token_ev):", val_macro_tok),
    ("\n[METRICS] VAL per-series macro (reg_head):", val_macro_reg),
    ("\n[METRICS] TEST per-series macro (token_ev):", test_macro_tok),
    ("\n[METRICS] TEST per-series macro (reg_head):", test_macro_reg),
):
    print(_title)
    display(_table)

# Ten best and ten worst companies (tables are sorted by series, then MSE).
for _head_name, _company_table in (("token_ev", test_company_tok), ("reg_head", test_company_reg)):
    print(f"\n[METRICS] TEST best/worst by company ({_head_name}):")
    display(_company_table.head(10))
    display(_company_table.tail(10))

del _title, _table, _head_name, _company_table

@torch.no_grad()
def compute_facet_vectors_geom(dataloader) -> pd.DataFrame:
    """
    Extract the "geometry" half of each sample's facet hidden state.

    For every sample, the last-layer hidden state at ``batch["facet_pos"][i]``
    is taken and only its first half is kept (``z_geom``). The last ``cfg.K``
    history values of the sample's target series are stored alongside.

    Args:
        dataloader: iterable of batch dicts with keys input_ids, attention_mask,
            labels, target_series, facet_pos, hist_vals, ticker, t_index.
            NOTE(review): tensors are passed to ``lm`` unchanged, so they are
            assumed to already live on the model's device — confirm in collate.

    Returns:
        DataFrame with columns ticker, series, t_index, z (np.ndarray, float32),
        hist (np.ndarray, float). Samples whose facet position falls outside
        the sequence are skipped.
    """
    lm.eval()
    rows = []
    for batch in dataloader:
        out = lm(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            output_hidden_states=True,
        )
        h_last = out.hidden_states[-1]
        B, T, _ = h_last.shape

        for i in range(B):
            ser = batch["target_series"][i]
            p = int(batch["facet_pos"][i].item())
            if p < 0 or p >= T:
                continue

            z_full = h_last[i, p, :].detach()
            # Cast to float32 before .numpy(): NumPy has no bfloat16 dtype, so a
            # bf16 model would raise here. Matches extract_facet_z_and_hist.
            z_geom = z_full[: z_full.numel() // 2].float().cpu().numpy()

            hist = np.array((batch["hist_vals"][i].get(ser, []) or [])[-cfg.K:], dtype=float)
            rows.append({
                "ticker": batch["ticker"][i],
                "series": ser,
                "t_index": int(batch["t_index"][i]),
                "z": z_geom,
                "hist": hist,
            })
    return pd.DataFrame(rows)

# Z_geom facet vectors for the TEST split; reused by the geometry cells below.
df_z = compute_facet_vectors_geom(test_dl)
print("[GEOM] extracted Z_geom facet vectors on TEST: n={}".format(len(df_z)))
display(df_z.head(3))

def pairwise_geom_metrics_for_series(df_ser: pd.DataFrame, max_pairs: int = 2000) -> Dict[str, float]:
    """Spearman correlation between history DTW distance and facet-vector L2 distance.

    Draws ``min(max_pairs, n*(n-1)/2)`` random position pairs (distinct within a
    pair, independent across draws) with a generator seeded by ``cfg.SEED``.
    Pairs where either "hist" is empty are dropped.

    Returns:
        {"spearman": float, "pairs": int}. spearman is NaN when ``df_ser`` has
        fewer than 5 rows or fewer than 10 usable pairs were collected.
    """
    n_rows = len(df_ser)
    if n_rows < 5:
        return {"spearman": float("nan"), "pairs": 0}

    rng = np.random.default_rng(cfg.SEED)
    positions = np.arange(n_rows)
    draws = min(max_pairs, n_rows * (n_rows - 1) // 2)

    # Pull both columns out once, in positional order, instead of .iloc per draw.
    histories = list(df_ser["hist"])
    vectors = list(df_ser["z"])

    dtw_dists, z_dists = [], []
    for _ in range(draws):
        a, b = rng.choice(positions, size=2, replace=False)
        ha, hb = histories[a], histories[b]
        if len(ha) == 0 or len(hb) == 0:
            continue
        dtw_ab = float(dtw_sakoe_chiba(ha.tolist(), hb.tolist(), band=cfg.DTW_BAND))
        z_ab = float(np.linalg.norm(vectors[a] - vectors[b]))
        dtw_dists.append(dtw_ab)
        z_dists.append(z_ab)

    if len(dtw_dists) < 10:
        return {"spearman": float("nan"), "pairs": len(dtw_dists)}

    from scipy.stats import spearmanr
    rho = float(spearmanr(dtw_dists, z_dists).correlation)
    return {"spearman": rho, "pairs": len(dtw_dists)}

# Per-series geometry check on TEST: Spearman correlation between DTW distance
# of raw histories and L2 distance of Z_geom facet vectors, using
# pairwise_geom_metrics_for_series on up to 2000 sampled pairs per series.
geom_rows = []
for ser in ["ESG", "ENV", "SOC", "GOV"]:
    df_ser = df_z[df_z["series"] == ser]
    m = pairwise_geom_metrics_for_series(df_ser, max_pairs=2000)
    geom_rows.append({"series": ser, **m})
df_geom = pd.DataFrame(geom_rows)

print("\n[GEOM] Spearman corr(DTW, ||Z_geom - Z_geom'||) per series:")
display(df_geom)

def knn_overlap(df_ser: pd.DataFrame, k: int = 5, n_queries: int = 30) -> float:
    """Mean overlap between k nearest neighbours in Z space and in DTW space.

    Up to ``n_queries`` query rows are drawn without replacement (generator
    seeded with ``cfg.SEED``). For each query, the k closest other rows are
    found by L2 distance on "z" and by DTW distance on "hist"; the score is
    |intersection| / k.

    Returns NaN when ``df_ser`` has fewer than k + 2 rows.
    """
    n_rows = len(df_ser)
    if n_rows < k + 2:
        return float("nan")

    rng = np.random.default_rng(cfg.SEED)
    positions = np.arange(n_rows)
    n_q = min(n_queries, n_rows)

    vectors = list(df_ser["z"])
    histories = list(df_ser["hist"])

    def _top_k(dist_fn, q):
        # Stable sort, so equal distances keep positional order.
        scored = [(j, dist_fn(q, j)) for j in positions if j != q]
        scored.sort(key=lambda item: item[1])
        return {j for j, _ in scored[:k]}

    def _z_dist(a, b):
        return float(np.linalg.norm(vectors[a] - vectors[b]))

    def _dtw_dist(a, b):
        return float(dtw_sakoe_chiba(histories[a].tolist(), histories[b].tolist(), band=cfg.DTW_BAND))

    scores = []
    for q in rng.choice(positions, size=n_q, replace=False):
        near_z = _top_k(_z_dist, q)
        near_dtw = _top_k(_dtw_dist, q)
        scores.append(len(near_z & near_dtw) / float(k))

    return float(np.mean(scores))

# Per-series retrieval check on TEST: overlap of the 5 nearest neighbours in
# Z_geom space vs DTW space (DataFrame-based knn_overlap, 30 queries/series).
overlap_rows = []
for ser in ["ESG","ENV","SOC","GOV"]:
    df_ser = df_z[df_z["series"] == ser]
    ov = knn_overlap(df_ser, k=5, n_queries=30)
    overlap_rows.append({"series": ser, "knn_overlap@5": ov})
df_overlap = pd.DataFrame(overlap_rows)

print("\n[RETRIEVE] kNN overlap@5 (Z_geom-space vs DTW) per series:")
display(df_overlap)

# Persist every Phase 10b table as a CSV in the current working directory,
# in the same order as before (token-EV, regression head, geometry, rows).
_csv_outputs = [
    (val_macro_tok, "val_series_macro_token_ev.csv"),
    (test_macro_tok, "test_series_macro_token_ev.csv"),
    (test_company_tok, "test_company_metrics_token_ev.csv"),
    (val_macro_reg, "val_series_macro_reg_head.csv"),
    (test_macro_reg, "test_series_macro_reg_head.csv"),
    (test_company_reg, "test_company_metrics_reg_head.csv"),
    (df_geom, "geom_spearman_by_series_Zgeom.csv"),
    (df_overlap, "knn_overlap_by_series_Zgeom.csv"),
    (df_pred_test_tok, "test_pred_rows_token_ev.csv"),
    (df_pred_test_reg, "test_pred_rows_reg_head.csv"),
]
for _frame, _fname in _csv_outputs:
    _frame.to_csv(_fname, index=False)
del _csv_outputs, _frame, _fname

print("[SAVED] Phase 10b CSVs written to current folder.")
print("Phase 10b complete: compared token-EV vs regression-head + geometry metrics on Z_geom.")

"""# SAVE MODEL"""

import os, json, time, zipfile, shutil, torch
from pathlib import Path

# Checkpoint lm + value_head + tokenizer under /content/<RUN_NAME>/, zip it,
# offer it for download (Colab only), and always put the models back on
# cfg.DEVICE afterwards.
RUN_NAME = f"facet_phase9_{time.strftime('%Y%m%d_%H%M%S')}"
OUT_DIR = Path(f"/content/{RUN_NAME}")
OUT_DIR.mkdir(parents=True, exist_ok=True)

MODEL_DIR = OUT_DIR / "lm"
HEAD_DIR  = OUT_DIR / "value_head"
TOK_DIR   = OUT_DIR / "tokenizer"

MODEL_DIR.mkdir(exist_ok=True)
HEAD_DIR.mkdir(exist_ok=True)
TOK_DIR.mkdir(exist_ok=True)

ZIP_PATH = Path(f"/content/{RUN_NAME}.zip")


def zipdir(src_dir: Path, zip_path: Path):
    """Write every file under ``src_dir`` into a deflate-compressed zip at ``zip_path``.

    Archive names are relative to ``src_dir.parent``, so the archive unpacks
    into one top-level folder named after ``src_dir``.
    """
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as z:
        for root, _, filenames in os.walk(src_dir):
            for name in filenames:
                full = Path(root) / name
                rel = full.relative_to(src_dir.parent)
                z.write(full, rel.as_posix())


print("Saving to:", OUT_DIR)

# nn.Module.to() moves parameters in place and returns the same module, so
# lm_cpu / value_head_cpu ARE lm / value_head. Everything that can fail runs
# inside try/finally so the models always return to cfg.DEVICE. Previously
# any exception (disk full, pickling error, or `google.colab` not importable
# outside Colab) left them on CPU, while later cells send inputs to cfg.DEVICE.
lm_cpu = lm.to("cpu")
value_head_cpu = value_head.to("cpu")
torch.cuda.empty_cache()

try:
    try:
        tokenizer.save_pretrained(str(TOK_DIR))
        print("[OK] tokenizer saved")
    except Exception as e:
        print("[WARN] tokenizer save_pretrained failed:", e)

    try:
        lm_cpu.save_pretrained(str(MODEL_DIR), safe_serialization=True)
        print("[OK] lm.save_pretrained saved")
    except Exception as e:
        print("[WARN] lm.save_pretrained failed, falling back to state_dict:", e)
        torch.save(lm_cpu.state_dict(), str(MODEL_DIR / "pytorch_model.bin"))
        print("[OK] lm state_dict saved")

    torch.save(value_head_cpu.state_dict(), str(HEAD_DIR / "value_head_state_dict.pt"))
    torch.save(value_head_cpu, str(HEAD_DIR / "value_head_full_module.pt"))
    print("[OK] value_head saved (state_dict + full module)")

    manifest = {
        "run_name": RUN_NAME,
        "saved_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "notes": "Phase 9 checkpoint: lm + value_head + tokenizer",
    }

    try:
        if isinstance(cfg, dict):
            manifest["cfg"] = cfg
        else:
            manifest["cfg"] = {k: getattr(cfg, k) for k in dir(cfg) if k.isupper() or k.startswith("MSE") or k.startswith("SPEARMAN")}
    except Exception as e:
        manifest["cfg_error"] = str(e)

    with open(OUT_DIR / "manifest.json", "w") as f:
        json.dump(manifest, f, indent=2)
    print("[OK] manifest.json saved")

    if ZIP_PATH.exists():
        ZIP_PATH.unlink()
    zipdir(OUT_DIR, ZIP_PATH)
    print("[OK] zipped:", ZIP_PATH)

    try:
        from google.colab import files
    except ImportError:
        # Not running in Colab: the archive is still on disk, just not pushed.
        print("[INFO] google.colab unavailable; download skipped. Archive at:", ZIP_PATH)
    else:
        files.download(str(ZIP_PATH))
finally:
    lm = lm_cpu.to(cfg.DEVICE)
    value_head = value_head_cpu.to(cfg.DEVICE)
    print("[OK] restored lm + value_head back to", cfg.DEVICE)

"""# PLOTS"""


import numpy as np
import pandas as pd
import torch

from collections import defaultdict

# Series covered by the geometry diagnostics below.
SERIES_LIST = "ESG ENV SOC GOV".split()

# Sampling budgets: per-series cap on extracted vectors, random pairs for the
# correlation/bin analyses, random triplets for ordering accuracy.
MAX_PER_SER, PAIR_SAMPLES, TRIPLET_SAMPLES = 1000, 4000, 3000

# kNN-overlap settings: neighbourhood size and number of query points.
KNN_K, KNN_QUERIES = 5, 60

@torch.no_grad()
def extract_facet_z_and_hist(dataloader, max_per_ser=1000):
    """Collect full facet hidden states and histories, bucketed by target series.

    ``z`` is the entire last-layer hidden state at ``facet_pos`` (cast to
    float32). Samples are skipped when the series is not in SERIES_LIST, its
    bucket already holds ``max_per_ser`` records, the facet position is out of
    range, or the last ``cfg.K`` history values are empty. Iteration stops
    early once every bucket is full.

    Returns:
        dict series -> list of {"ticker", "t_index", "z", "hist"} records.
    """
    lm.eval()
    store = {name: [] for name in SERIES_LIST}

    for batch in dataloader:
        hidden = lm(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            output_hidden_states=True,
        ).hidden_states[-1]
        n_batch, seq_len, _ = hidden.shape

        for row in range(n_batch):
            ser = batch["target_series"][row]
            if ser not in store:
                continue
            bucket = store[ser]
            if len(bucket) >= max_per_ser:
                continue

            pos = int(batch["facet_pos"][row].item())
            if not 0 <= pos < seq_len:
                continue

            recent = (batch["hist_vals"][row].get(ser, []) or [])[-cfg.K:]
            if len(recent) == 0:
                continue

            vector = hidden[row, pos, :].detach().float().cpu().numpy()
            bucket.append({
                "ticker": batch["ticker"][row],
                "t_index": int(batch["t_index"][row]),
                "z": vector,
                "hist": np.asarray(recent, dtype=float)
            })

        if all(len(store[name]) >= max_per_ser for name in store):
            break

    return store

# Gather up to MAX_PER_SER full facet vectors + histories per series from the
# TEST loader and report how many were collected for each series.
store = extract_facet_z_and_hist(test_dl, max_per_ser=MAX_PER_SER)
for s in SERIES_LIST:
    print(f"[GEOM DATA] {s}: n={len(store[s])}")

def _l2(a, b):
    return float(np.linalg.norm(a - b))

def _dtw(a, b):
    """DTW distance between two numpy histories via dtw_sakoe_chiba(band=cfg.DTW_BAND)."""
    left, right = a.tolist(), b.tolist()
    return float(dtw_sakoe_chiba(left, right, band=cfg.DTW_BAND))

def sample_pairs(n, m, rng):
    """Draw ``m`` random index pairs from ``range(n)`` using generator ``rng``.

    The two indices in a pair are distinct; pairs are independent draws, so a
    pair may repeat. Returns a list of ``(int, int)`` tuples.
    """
    candidates = np.arange(n)
    return [
        tuple(int(v) for v in rng.choice(candidates, size=2, replace=False))
        for _ in range(m)
    ]

def corr_metrics_for_series(arr, pair_samples=4000, seed=123):
    """Correlate history DTW distances with facet-vector L2 distances.

    Args:
        arr: list of records with numpy "hist" and "z" entries.
        pair_samples: maximum number of random pairs (capped at n*(n-1)/2).
        seed: seed for ``np.random.default_rng``.

    Returns:
        {"spearman", "pearson_logdtw", "pairs"}. The first is Spearman(DTW, L2);
        the second is Pearson(log1p(DTW), L2). Both are NaN when ``arr`` has
        fewer than 25 records (pairs=0) or fewer than 20 pairs were drawn.
    """
    size = len(arr)
    if size < 25:
        return {"spearman": np.nan, "pearson_logdtw": np.nan, "pairs": 0}

    rng = np.random.default_rng(seed)
    budget = min(pair_samples, size * (size - 1) // 2)
    pairs = sample_pairs(size, budget, rng)

    dtw_list = [_dtw(arr[i]["hist"], arr[j]["hist"]) for i, j in pairs]
    z_list = [_l2(arr[i]["z"], arr[j]["z"]) for i, j in pairs]
    n_used = len(dtw_list)

    if n_used < 20:
        return {"spearman": np.nan, "pearson_logdtw": np.nan, "pairs": n_used}

    from scipy.stats import spearmanr, pearsonr
    return {
        "spearman": float(spearmanr(dtw_list, z_list).correlation),
        "pearson_logdtw": float(pearsonr(np.log1p(dtw_list), z_list)[0]),
        "pairs": n_used,
    }

# [A/B] Per series: Spearman(DTW, L2) and Pearson(log1p(DTW), L2) between raw
# histories and full facet vectors, via corr_metrics_for_series on `store`.
rows = []
for ser in SERIES_LIST:
    m = corr_metrics_for_series(store[ser], pair_samples=PAIR_SAMPLES, seed=cfg.SEED + 7)
    rows.append({"series": ser, **m})
df_corr = pd.DataFrame(rows)
print("\n[A/B] Correlations per series:")
print(df_corr.to_string(index=False))

def knn_overlap(arr, k=5, n_queries=50, seed=0):
    """Average fraction of shared k-nearest neighbours between Z space and DTW space.

    At most ``n_queries`` query records are drawn without replacement with
    ``np.random.default_rng(seed)``. For each query, the k closest other
    records are ranked by ``_l2`` on "z" and by ``_dtw`` on "hist"; the score
    is |overlap| / k.

    Returns NaN when ``arr`` holds fewer than k + 5 records.
    """
    size = len(arr)
    if size < k + 5:
        return np.nan

    rng = np.random.default_rng(seed)
    positions = np.arange(size)
    queries = rng.choice(positions, size=min(n_queries, size), replace=False)

    def nearest(query, dist):
        # Stable sort, so equal distances keep positional order.
        ranked = [(int(j), dist(query, j)) for j in positions if j != query]
        ranked.sort(key=lambda pair: pair[1])
        return {j for j, _ in ranked[:k]}

    scores = []
    for q in queries:
        in_z = nearest(q, lambda a, b: _l2(arr[a]["z"], arr[b]["z"]))
        in_dtw = nearest(q, lambda a, b: _dtw(arr[a]["hist"], arr[b]["hist"]))
        scores.append(len(in_z & in_dtw) / float(k))

    return float(np.mean(scores))

# [C] Per-series kNN overlap between facet-vector space and DTW space.
# At this point `knn_overlap` is the list-based (arr, k, n_queries, seed)
# definition, which replaced the earlier DataFrame-based one of the same name.
rows = []
for ser in SERIES_LIST:
    ov = knn_overlap(store[ser], k=KNN_K, n_queries=KNN_QUERIES, seed=cfg.SEED + 11)
    rows.append({"series": ser, f"knn_overlap@{KNN_K}": ov})
df_knn = pd.DataFrame(rows)
print("\n[C] kNN overlap per series:")
print(df_knn.to_string(index=False))

def triplet_order_accuracy(arr, triplet_samples=2000, seed=0):
    """Fraction of random triplets where facet-space distance ranks like DTW.

    For an anchor a and two others b, c (distinct, drawn with
    ``np.random.default_rng(seed)``), the triplet counts as correct when the
    point that is closer by DTW on "hist" is also closer by L2 on "z".
    Triplets whose two DTW distances differ by less than 1e-9 are ignored.

    Returns:
        {"triplet_acc": float, "triplets": int}; NaN accuracy when ``arr`` has
        fewer than 10 records.
    """
    if len(arr) < 10:
        return {"triplet_acc": np.nan, "triplets": 0}

    rng = np.random.default_rng(seed)
    positions = np.arange(len(arr))

    hits = 0
    counted = 0
    for _ in range(triplet_samples):
        a, b, c = rng.choice(positions, size=3, replace=False)
        d_ab = _dtw(arr[a]["hist"], arr[b]["hist"])
        d_ac = _dtw(arr[a]["hist"], arr[c]["hist"])
        if abs(d_ab - d_ac) < 1e-9:
            continue

        z_ab = _l2(arr[a]["z"], arr[b]["z"])
        z_ac = _l2(arr[a]["z"], arr[c]["z"])

        # Agreement: the DTW-closer point is also the Z-closer point.
        agree = (z_ab < z_ac) if d_ab < d_ac else (z_ac < z_ab)
        if agree:
            hits += 1
        counted += 1

    return {"triplet_acc": float(hits / max(1, counted)), "triplets": int(counted)}

# [D] Per-series triplet ordering accuracy (does facet L2 rank pairs the same
# way DTW on raw histories does?), via triplet_order_accuracy on `store`.
rows = []
for ser in SERIES_LIST:
    m = triplet_order_accuracy(store[ser], triplet_samples=TRIPLET_SAMPLES, seed=cfg.SEED + 19)
    rows.append({"series": ser, **m})
df_trip = pd.DataFrame(rows)
print("\n[D] Triplet ordering accuracy per series:")
print(df_trip.to_string(index=False))

def dtw_bins_vs_z(arr, pair_samples=4000, n_bins=6, seed=0):
    """Summarise facet-vector L2 distance inside DTW-quantile bins.

    Random pairs are drawn with ``sample_pairs``; their DTW ("hist") distances
    are split at ``n_bins`` + 1 quantile edges. Bins are [lo, hi) except the
    last, which is [lo, hi]. Each row reports bin, dtw_lo, dtw_hi, n, z_mean and
    z_median (empty bins get n=0, z_mean=NaN and no z_median value).

    Returns:
        DataFrame with one row per bin, or None when ``arr`` has fewer than 25 records.
    """
    if len(arr) < 25:
        return None

    size = len(arr)
    rng = np.random.default_rng(seed)
    pairs = sample_pairs(size, min(pair_samples, size * (size - 1) // 2), rng)

    dtw_arr = np.asarray([_dtw(arr[i]["hist"], arr[j]["hist"]) for i, j in pairs], dtype=float)
    z_arr = np.asarray([_l2(arr[i]["z"], arr[j]["z"]) for i, j in pairs], dtype=float)

    edges = np.quantile(dtw_arr, np.linspace(0, 1, n_bins + 1))
    last_bin = n_bins - 1
    records = []
    for b, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        below_hi = (dtw_arr <= hi) if b == last_bin else (dtw_arr < hi)
        mask = (dtw_arr >= lo) & below_hi
        if mask.sum() == 0:
            records.append({"bin": b, "dtw_lo": lo, "dtw_hi": hi, "n": 0, "z_mean": np.nan})
            continue
        in_bin = z_arr[mask]
        records.append({
            "bin": b,
            "dtw_lo": float(lo),
            "dtw_hi": float(hi),
            "n": int(mask.sum()),
            "z_mean": float(in_bin.mean()),
            "z_median": float(np.median(in_bin)),
        })
    return pd.DataFrame(records)

# [E] For each series, print mean/median facet-vector distance per DTW
# quantile bin (6 bins); dtw_bins_vs_z returns None for series with < 25 records.
print("\n[E] DTW bins -> mean z-dist (per series):")
for ser in SERIES_LIST:
    df_bins = dtw_bins_vs_z(store[ser], pair_samples=PAIR_SAMPLES, n_bins=6, seed=cfg.SEED + 23)
    print("\nSeries:", ser)
    if df_bins is None:
        print("  Not enough data.")
    else:
        print(df_bins.to_string(index=False))

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import spearmanr

def cosine_sim(a, b, eps=1e-12):
    """Cosine similarity of two vectors; ``eps`` keeps zero vectors from dividing by zero."""
    u = np.asarray(a, dtype=float)
    v = np.asarray(b, dtype=float)
    denom = np.linalg.norm(u) * np.linalg.norm(v) + eps
    return float(np.dot(u, v) / denom)

def sample_pairs(n, max_pairs, seed=0):
    """Draw ``max_pairs`` random index pairs (distinct within a pair) from ``range(n)``.

    Uses a fresh ``np.random.default_rng(seed)``; pairs are independent draws,
    so repeats are possible. Indices are returned as numpy integers.
    """
    rng = np.random.default_rng(seed)
    universe = np.arange(n)
    return [tuple(rng.choice(universe, size=2, replace=False)) for _ in range(max_pairs)]

def cos_vs_dtw_for_series(df_ser, max_pairs=3000, seed=0):
    """
    Pair up rows of one series and measure DTW(hist) against cosine(z).

    df_ser: subset of df_z for one series, with columns: z (np array), hist (np array).
    Pairs come from ``sample_pairs(n, max_pairs, seed=seed)``; pairs with a
    missing/empty history or a missing vector are skipped.

    Returns arrays: dtw_list, cos_list (aligned; both empty if fewer than 5 rows).
    """
    frame = df_ser.reset_index(drop=True)
    if len(frame) < 5:
        return np.array([]), np.array([])

    dtw_vals, cos_vals = [], []
    for a, b in sample_pairs(len(frame), max_pairs, seed=seed):
        h_a = frame.loc[a, "hist"]
        h_b = frame.loc[b, "hist"]
        if h_a is None or h_b is None or len(h_a) == 0 or len(h_b) == 0:
            continue

        z_a = frame.loc[a, "z"]
        z_b = frame.loc[b, "z"]
        if z_a is None or z_b is None:
            continue

        dist = float(dtw_sakoe_chiba(h_a.tolist(), h_b.tolist(), band=cfg.DTW_BAND))
        sim = cosine_sim(z_a, z_b)
        dtw_vals.append(dist)
        cos_vals.append(sim)

    return np.array(dtw_vals, dtype=float), np.array(cos_vals, dtype=float)

def plot_cos_vs_dtw(df_z, series="ESG", max_pairs=3000, seed=0, bins=30, save_prefix=None):
    df_ser = df_z[df_z["series"] == series].copy()
    dtw_list, cos_list = cos_vs_dtw_for_series(df_ser, max_pairs=max_pairs, seed=seed)

    if len(dtw_list) < 50:
        print(f"[COSvsDTW] Not enough pairs for series={series}. Got {len(dtw_list)}")
        return {"series": series, "pairs": len(dtw_list), "spearman": np.nan}

    sp = float(spearmanr(dtw_list, cos_list).correlation)
    print(f"[COSvsDTW] series={series} pairs={len(dtw_list)} spearman(dtw,cos)={sp:.4f}")

    rng = np.random.default_rng(seed + 123)
    keep = min(len(dtw_list), 2000)
    sel = rng.choice(np.arange(len(dtw_list)), size=keep, replace=False)

    plt.figure()
    plt.scatter(dtw_list[sel], cos_list[sel], s=8, alpha=0.35)
    plt.xlabel("DTW distance (raw trajectory)")
    plt.ylabel("Cosine similarity (FACET embedding)")
    plt.title(f"{series}: cosine(FACET) vs DTW(raw) | spearman={sp:.3f}")

    if save_prefix is not None:
        plt.savefig(f"{save_prefix}_{series}_scatter.png", dpi=150, bbox_inches="tight")

    order = np.argsort(dtw_list)
    x = dtw_list[order]
    y = cos_list[order]
    edges = np.linspace(x.min(), x.max(), bins + 1)

    xb, yb = [], []
    for b in range(bins):
        m = (x >= edges[b]) & (x < edges[b + 1])
        if np.sum(m) < 5:
            continue
        xb.append(float(np.mean(x[m])))
        yb.append(float(np.mean(y[m])))

    plt.figure()
    plt.plot(xb, yb, marker="o")
    plt.xlabel("DTW distance (raw trajectory)")
    plt.ylabel("Mean cosine similarity (FACET)")
    plt.title(f"{series}: binned mean cosine vs DTW")

    if save_prefix is not None:
        plt.savefig(f"{save_prefix}_{series}_binned.png", dpi=150, bbox_inches="tight")

    plt.show()

    return {"series": series, "pairs": len(dtw_list), "spearman": sp}

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def plot_cos_vs_dtw(
    df,
    series,
    max_pairs=3000,
    seed=0,
    bins=30,
    save_prefix="cos_vs_dtw",
    series_col="series",
    vec_col="z",
    ticker_col="ticker",
    t_index_col="t_index",
    restrict_same_t_index=False,
):
    rng = np.random.default_rng(seed)

    d = df[df[series_col] == series].copy()
    if d.empty:
        raise ValueError(f"No rows where {series_col} == {series!r}")

    vecs = [np.asarray(v, dtype=float) for v in d[vec_col].values]
    n = len(vecs)
    if n < 2:
        raise ValueError(f"Need >=2 rows for series={series!r}; got {n}")

    def normalize_ts(x):
        x = np.asarray(x, dtype=float)
        mu = x.mean()
        sigma = x.std()
        if sigma < 1e-12:
            return x * 0.0
        return (x - mu) / sigma

    def dtw_distance(a, b):
        na, nb = len(a), len(b)
        D = np.full((na + 1, nb + 1), np.inf)
        D[0, 0] = 0.0
        for i in range(1, na + 1):
            ai = a[i - 1]
            for j in range(1, nb + 1):
                cost = abs(ai - b[j - 1])
                D[i, j] = cost + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])
        return float(D[na, nb])

    total_pairs = n * (n - 1) // 2
    target = min(max_pairs, total_pairs)

    if restrict_same_t_index:
        groups = d.groupby(t_index_col).indices
        pos_of = {idx: k for k, idx in enumerate(d.index)}
        grouped_pos = [np.array([pos_of[idx] for idx in idxs], dtype=int) for idxs in groups.values()]
        grouped_pos = [g for g in grouped_pos if len(g) >= 2]
        if not grouped_pos:
            raise ValueError("restrict_same_t_index=True but no t_index group has >=2 rows")

        pairs = set()
        while len(pairs) < target:
            g = grouped_pos[int(rng.integers(0, len(grouped_pos)))]
            i = int(rng.choice(g))
            j = int(rng.choice(g))
            if i == j:
                continue
            a, b = (i, j) if i < j else (j, i)
            pairs.add((a, b))
    else:
        pairs = set()
        while len(pairs) < target:
            i = int(rng.integers(0, n))
            j = int(rng.integers(0, n))
            if i == j:
                continue
            a, b = (i, j) if i < j else (j, i)
            pairs.add((a, b))

    cos_vals, dtw_vals = [], []
    for i, j in pairs:
        s1 = normalize_ts(vecs[i])
        s2 = normalize_ts(vecs[j])

        num = float(np.dot(s1, s2))
        den = float(np.linalg.norm(s1) * np.linalg.norm(s2) + 1e-12)
        cos = num / den

        dtw = dtw_distance(s1, s2)

        cos_vals.append(cos)
        dtw_vals.append(dtw)

    cos_vals = np.asarray(cos_vals, dtype=float)
    dtw_vals = np.asarray(dtw_vals, dtype=float)

    spearman = pd.Series(cos_vals).corr(pd.Series(dtw_vals), method="spearman")

    plt.figure(figsize=(7, 5))
    plt.scatter(cos_vals, dtw_vals, s=6, alpha=0.4)
    plt.xlabel("Cosine similarity (per-vector z-normalized)")
    plt.ylabel("DTW distance (per-vector z-normalized)")
    plt.title(f"{series}: Spearman ρ = {spearman:.3f} | pairs={len(cos_vals)} | nrows={n}")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    fname = f"{save_prefix}_{series}.png"
    plt.savefig(fname, dpi=150)
    plt.close()

    return {
        "series": series,
        "spearman": float(spearman),
        "cos_mean": float(cos_vals.mean()),
        "dtw_mean": float(dtw_vals.mean()),
        "n_pairs": int(len(cos_vals)),
        "n_rows": int(n),
        "vec_len": int(len(vecs[0])) if n > 0 else None,
        "restrict_same_t_index": bool(restrict_same_t_index),
    }

# Cosine-vs-DTW plot per series. At this point `plot_cos_vs_dtw` is the
# keyword-rich (series_col / vec_col / ...) definition; each call saves
# cos_vs_dtw_<SER>.png and returns a summary dict.
# NOTE(review): with its default vec_col="z", that function computes BOTH the
# cosine similarity and the DTW distance on the same z-normalised facet
# vectors, not DTW on the raw "hist" trajectories. The log label
# "corr(DTW, cosine(FACET))" may suggest otherwise; confirm intent.
# `bins=30` is not used by that implementation.
cos_rows = []
for ser in ["ESG", "ENV", "SOC", "GOV"]:
    r = plot_cos_vs_dtw(
        df_z,
        series=ser,
        max_pairs=3000,
        seed=cfg.SEED,
        bins=30,
        save_prefix="cos_vs_dtw"
    )
    cos_rows.append(r)

df_cos = pd.DataFrame(cos_rows)
log("[GEOM] Spearman corr(DTW, cosine(FACET)) per series:")
display(df_cos)

df_cos.to_csv("geom_cosine_vs_dtw_spearman.csv", index=False)
log("[SAVED] cos_vs_dtw_*.png + geom_cosine_vs_dtw_spearman.csv written.")

# Bare expressions below only render output when run as a notebook cell; they
# have no effect in a plain script.
df_z

# Debug: inspect df_z's column index (type, count, first 100 names).
print("df_z.columns type:", type(df_z.columns))
print("Number of columns:", len(df_z.columns))

for i, c in enumerate(df_z.columns[:100]):
    print(i, repr(c))

df_z

df_cos

import os, json, time, zipfile, shutil, torch
from pathlib import Path

# Checkpoint lm + value_head + tokenizer under /content/<RUN_NAME>/, zip it,
# offer it for download (Colab only), and always put the models back on
# cfg.DEVICE afterwards.
RUN_NAME = f"facet_phase9_{time.strftime('%Y%m%d_%H%M%S')}"
OUT_DIR = Path(f"/content/{RUN_NAME}")
OUT_DIR.mkdir(parents=True, exist_ok=True)

MODEL_DIR = OUT_DIR / "lm"
HEAD_DIR  = OUT_DIR / "value_head"
TOK_DIR   = OUT_DIR / "tokenizer"

MODEL_DIR.mkdir(exist_ok=True)
HEAD_DIR.mkdir(exist_ok=True)
TOK_DIR.mkdir(exist_ok=True)

ZIP_PATH = Path(f"/content/{RUN_NAME}.zip")


def zipdir(src_dir: Path, zip_path: Path):
    """Write every file under ``src_dir`` into a deflate-compressed zip at ``zip_path``.

    Archive names are relative to ``src_dir.parent``, so the archive unpacks
    into one top-level folder named after ``src_dir``.
    """
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as z:
        for root, _, filenames in os.walk(src_dir):
            for name in filenames:
                full = Path(root) / name
                rel = full.relative_to(src_dir.parent)
                z.write(full, rel.as_posix())


print("Saving to:", OUT_DIR)

# nn.Module.to() moves parameters in place and returns the same module, so
# lm_cpu / value_head_cpu ARE lm / value_head. Everything that can fail runs
# inside try/finally so the models always return to cfg.DEVICE. Previously
# any exception (disk full, pickling error, or `google.colab` not importable
# outside Colab) left them on CPU, while later cells send inputs to cfg.DEVICE.
lm_cpu = lm.to("cpu")
value_head_cpu = value_head.to("cpu")
torch.cuda.empty_cache()

try:
    try:
        tokenizer.save_pretrained(str(TOK_DIR))
        print("[OK] tokenizer saved")
    except Exception as e:
        print("[WARN] tokenizer save_pretrained failed:", e)

    try:
        lm_cpu.save_pretrained(str(MODEL_DIR), safe_serialization=True)
        print("[OK] lm.save_pretrained saved")
    except Exception as e:
        print("[WARN] lm.save_pretrained failed, falling back to state_dict:", e)
        torch.save(lm_cpu.state_dict(), str(MODEL_DIR / "pytorch_model.bin"))
        print("[OK] lm state_dict saved")

    torch.save(value_head_cpu.state_dict(), str(HEAD_DIR / "value_head_state_dict.pt"))
    torch.save(value_head_cpu, str(HEAD_DIR / "value_head_full_module.pt"))
    print("[OK] value_head saved (state_dict + full module)")

    manifest = {
        "run_name": RUN_NAME,
        "saved_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "notes": "Phase 9 checkpoint: lm + value_head + tokenizer",
    }

    try:
        if isinstance(cfg, dict):
            manifest["cfg"] = cfg
        else:
            manifest["cfg"] = {k: getattr(cfg, k) for k in dir(cfg) if k.isupper() or k.startswith("MSE") or k.startswith("SPEARMAN")}
    except Exception as e:
        manifest["cfg_error"] = str(e)

    with open(OUT_DIR / "manifest.json", "w") as f:
        json.dump(manifest, f, indent=2)
    print("[OK] manifest.json saved")

    if ZIP_PATH.exists():
        ZIP_PATH.unlink()
    zipdir(OUT_DIR, ZIP_PATH)
    print("[OK] zipped:", ZIP_PATH)

    try:
        from google.colab import files
    except ImportError:
        # Not running in Colab: the archive is still on disk, just not pushed.
        print("[INFO] google.colab unavailable; download skipped. Archive at:", ZIP_PATH)
    else:
        files.download(str(ZIP_PATH))
finally:
    lm = lm_cpu.to(cfg.DEVICE)
    value_head = value_head_cpu.to(cfg.DEVICE)
    print("[OK] restored lm + value_head back to", cfg.DEVICE)

"""# OTHER EXCHANGES"""


import os
import pandas as pd
from collections import defaultdict
from math import sqrt

# Put the model in eval mode for the cross-exchange evaluation below.
lm.eval()

# Exchange sub-folders (relative to the working directory) and the cap on
# one-step forecasts per company after the first cfg.K observations.
EXCHANGES = "AMEX CBOE NASDAQ NYSE OTC PNK".split()
MAX_PRED_STEPS = 10

def mse(y, yhat):
    """Mean squared error between targets ``y`` and predictions ``yhat``."""
    residual = np.array(y) - np.array(yhat)
    return float(np.mean(np.square(residual)))

def mae(y, yhat):
    """Mean absolute error between targets ``y`` and predictions ``yhat``."""
    gap = np.abs(np.array(y) - np.array(yhat))
    return float(np.mean(gap))

def mape(y, yhat):
    """Mean absolute percentage error, in percent.

    The denominator is ``y + 1e-8`` (not ``|y|``), so targets at or near zero
    produce extremely large values.
    """
    target = np.array(y)
    ratio = (target - np.array(yhat)) / (target + 1e-8)
    return float(np.mean(np.abs(ratio)) * 100.0)

def bias(y, yhat):
    """Mean signed error ``yhat - y`` (positive means over-prediction)."""
    delta = np.array(yhat) - np.array(y)
    return float(np.mean(delta))

def theils_u(y, yhat):
    """RMSE of ``yhat`` divided by the root-mean-square of ``y`` (+1e-8).

    NOTE(review): this RMSE / RMS(y) ratio is neither Theil's U1 nor the
    naive-forecast U2 used by the main evaluation (theils_u2_with_prev).
    """
    actual, predicted = np.array(y), np.array(yhat)
    rmse = np.sqrt(np.mean((predicted - actual) ** 2))
    scale = np.sqrt(np.mean(actual ** 2)) + 1e-8
    return float(rmse / scale)

def load_series_file(path):
    """Parse a whitespace-separated ``TICKER v1 v2 ...`` file into ``{ticker: [floats]}``.

    Lines with fewer than two tokens are skipped. A later line for the same
    ticker replaces an earlier one.

    Raises:
        ValueError: a value token is not a float. The message now names the
            file and 1-based line number; previously the bare float() error
            gave no location (e.g. on "Company: X ESG: ..." tagged files).
    """
    data = {}
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            parts = line.split()
            if len(parts) < 2:
                continue
            ticker, *raw_vals = parts
            try:
                data[ticker] = [float(x) for x in raw_vals]
            except ValueError as err:
                raise ValueError(
                    f"{path}:{lineno}: cannot parse values for {ticker!r}: {err}"
                ) from err
    return data

@torch.no_grad()
def forecast_company(
    ticker, esg, env, soc, gov, ret, dates
):
    """One-step-ahead greedy token forecasts of ESG/ENV/SOC/GOV for one company.

    Starting from the first ``cfg.K`` observations, at each step t the model is
    prompted (via ``build_prompt``, no news) with the last ``cfg.K`` values of
    every series plus RET. The highest-logit token among
    ``allowed_ids[series]`` is decoded with ``token_to_float``. After each
    step the TRUE values at t are appended to the history (teacher forcing),
    so every forecast is exactly one step ahead.

    Args:
        ticker: company identifier passed to ``build_prompt``.
        esg, env, soc, gov, ret: per-step value lists. Their ``[:cfg.K]``
            prefixes are copied and appended to.
        dates: unused; kept for call compatibility.

    Returns:
        (preds, trues): dicts series -> list of floats, aligned by step.
    """
    targets = {"ESG": esg, "ENV": env, "SOC": soc, "GOV": gov}
    hist = {
        "ESG": esg[:cfg.K].copy(),
        "ENV": env[:cfg.K].copy(),
        "SOC": soc[:cfg.K].copy(),
        "GOV": gov[:cfg.K].copy(),
        "RET": ret[:cfg.K].copy(),
    }

    preds = {"ESG": [], "ENV": [], "SOC": [], "GOV": []}
    trues = {"ESG": [], "ENV": [], "SOC": [], "GOV": []}

    # Forecast every index that exists in ALL inputs, capped at MAX_PRED_STEPS.
    # The old bound `len(esg) - 1` never scored the last observation and
    # ignored shorter env/soc/gov/ret lists (IndexError).
    n_avail = min(len(esg), len(env), len(soc), len(gov), len(ret))
    max_t = min(n_avail, cfg.K + MAX_PRED_STEPS)

    for t in range(cfg.K, max_t):
        # The history tokens are identical for all four targets at step t.
        hist_tokens = {
            k: [f"<{k}_{v:.2f}>" for v in vals[-cfg.K:]] for k, vals in hist.items()
        }

        for series in ["ESG", "ENV", "SOC", "GOV"]:
            prompt = build_prompt(
                ticker=ticker,
                start_dt=None,
                end_dt=None,
                hist_tokens_map=hist_tokens,
                senti_tokens=[],
                news_text="(no news)",
                target_series=series
            )

            enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(cfg.DEVICE)
            logits = lm(**enc, output_hidden_states=False).logits

            # Greedy decode restricted to this series' value tokens.
            ids = allowed_ids[series]
            sel = logits[0, -1].index_select(0, ids)
            j = int(torch.argmax(sel).item())
            pid = int(ids[j].item())
            y_pred = token_to_float(tokenizer.convert_ids_to_tokens(pid))

            preds[series].append(y_pred)
            trues[series].append(targets[series][t])

        # Teacher forcing: extend every history with the observed value at t.
        hist["ESG"].append(esg[t])
        hist["ENV"].append(env[t])
        hist["SOC"].append(soc[t])
        hist["GOV"].append(gov[t])
        hist["RET"].append(ret[t])

    return preds, trues

# Evaluate the model on every exchange folder: load the four score files,
# forecast each company (forecast_company), and report per-series means of
# MSE / MAE / MAPE / TheilU / Bias.
GLOBAL_RESULTS = {}

for EX in EXCHANGES:
    print("\n" + "="*80)
    print("EXCHANGE:", EX)
    print("="*80)

    base = EX

    path_esg = os.path.join(base, "esg_risk_ratings_1.txt")
    path_env = os.path.join(base, "PaperReady_e_scores.txt")
    path_soc = os.path.join(base, "PaperReady_s_scores.txt")
    path_gov = os.path.join(base, "PaperReady_g_scores.txt")

    if not os.path.exists(path_esg):
        print("Missing folder:", EX, "-> skipping")
        continue
    # ENV/SOC/GOV are loaded unconditionally below; a missing one used to abort
    # the whole loop with FileNotFoundError instead of skipping this exchange.
    missing = [p for p in (path_env, path_soc, path_gov) if not os.path.exists(p)]
    if missing:
        print("Missing series files:", missing, "-> skipping", EX)
        continue

    ESG = load_series_file(path_esg)
    ENV = load_series_file(path_env)
    SOC = load_series_file(path_soc)
    GOV = load_series_file(path_gov)

    RETURNS = {}
    n_ret_fallback = 0
    for tkr in ESG.keys():
        dates = list(range(len(ESG[tkr])))
        try:
            RETURNS[tkr] = align_and_fill_returns(tkr, dates)
        except Exception:
            # Best-effort zero returns. `except Exception` (not a bare except)
            # so KeyboardInterrupt / SystemExit still stop the run.
            RETURNS[tkr] = [0.0]*len(ESG[tkr])
            n_ret_fallback += 1
    if n_ret_fallback:
        print(f"[WARN] {EX}: returns unavailable for {n_ret_fallback}/{len(ESG)} tickers; used zeros")

    METRICS = defaultdict(lambda: defaultdict(list))
    n_missing_series = 0

    for ticker in ESG.keys():
        if len(ESG[ticker]) < cfg.K + 5:
            continue
        # Tickers absent from any other series file used to raise KeyError.
        if ticker not in ENV or ticker not in SOC or ticker not in GOV:
            n_missing_series += 1
            continue

        preds, trues = forecast_company(
            ticker,
            ESG[ticker],
            ENV[ticker],
            SOC[ticker],
            GOV[ticker],
            RETURNS[ticker],
            None
        )

        for series in ["ESG","ENV","SOC","GOV"]:
            y = trues[series]
            yhat = preds[series]
            if len(y) == 0:
                continue

            METRICS[series]["MSE"].append(mse(y,yhat))
            METRICS[series]["MAE"].append(mae(y,yhat))
            METRICS[series]["MAPE"].append(mape(y,yhat))
            METRICS[series]["TheilU"].append(theils_u(y,yhat))
            METRICS[series]["Bias"].append(bias(y,yhat))

    if n_missing_series:
        print(f"[WARN] {EX}: skipped {n_missing_series} tickers missing from ENV/SOC/GOV files")

    print("\nRESULTS FOR", EX)
    print("-"*80)
    for series in ["ESG","ENV","SOC","GOV"]:
        print("\nSeries:", series)
        for m in ["MSE","MAE","MAPE","TheilU","Bias"]:
            if len(METRICS[series][m]) == 0:
                print(f"  {m}: N/A")
            else:
                print(f"  {m}: {np.mean(METRICS[series][m]):.4f}")

    GLOBAL_RESULTS[EX] = METRICS

print("\n\n==================== GLOBAL SUMMARY ====================")
for EX in GLOBAL_RESULTS:
    print("\n", EX)
    for series in ["ESG","ENV","SOC","GOV"]:
        print(" ", series)
        for m in ["MSE","MAE","MAPE","TheilU","Bias"]:
            vals = GLOBAL_RESULTS[EX][series][m]
            if len(vals)==0:
                print("   ", m, ": N/A")
            else:
                print("   ", m, ":", float(np.mean(vals)))

# Debug: inspect df_z's column index (type, count, and the first 100 names).
print("df_z.columns type:", type(df_z.columns))
print("Number of columns:", len(df_z.columns))

for i, c in enumerate(df_z.columns[:100]):
    print(i, repr(c))


import os
import re
import numpy as np
import pandas as pd
from collections import defaultdict

# Eval mode again for the second cross-exchange pass.
lm.eval()

# Same exchange folders; this pass forecasts a single step per company.
EXCHANGES = "AMEX CBOE NASDAQ NYSE OTC PNK".split()
MAX_PRED_STEPS = 1

def mse(y, yhat):
    """Mean squared error: average of (y - yhat)**2."""
    diff = np.array(y) - np.array(yhat)
    return float(np.mean(np.square(diff)))

def mae(y, yhat):
    """Mean absolute error: average of |y - yhat|."""
    abs_err = np.abs(np.array(y) - np.array(yhat))
    return float(np.mean(abs_err))

def mape(y, yhat):
    """Mean absolute percentage error in percent; denominator is ``y + 1e-8``.

    Targets at or near zero make the result explode; no guarding is done.
    """
    actual = np.array(y)
    pct = np.abs((actual - np.array(yhat)) / (actual + 1e-8))
    return float(np.mean(pct) * 100.0)

def bias(y, yhat):
    """Average signed error yhat - y."""
    signed = np.array(yhat) - np.array(y)
    return float(np.mean(signed))

def theils_u(y, yhat):
    """RMSE(yhat, y) / (RMS(y) + 1e-8).

    NOTE(review): despite the name, this is an RMS-normalised RMSE, not
    Theil's U1/U2 as usually defined.
    """
    truth = np.array(y)
    guess = np.array(yhat)
    numerator = np.sqrt(np.mean((guess - truth) ** 2))
    denominator = np.sqrt(np.mean(truth ** 2)) + 1e-8
    return float(numerator / denominator)

_num_re = re.compile(r"[-+]?\d*\.\d+|[-+]?\d+")

def load_series_file(path, expected_series=None):
    """
    Returns dict[ticker] = list[float]

    Supports:
      A) Plain:        TICKER 1.0 2.0 3.0 ...
      B) Tagged:       Company: TICKER ESG: 1.0 2.0 ...
                       Company: TICKER ENV: ...
                       Company: TICKER SOC: ...
                       Company: TICKER GOV: ...
      C) Lines wrapped in dict-ish dumps: {'text': 'Company: ...'} etc.
      D) Quotes around Company token: "'Company:" etc.

    expected_series: optional string like "ESG"/"ENV"/"SOC"/"GOV".
                     If provided, we only accept tagged lines for that series.
                     (Plain format still accepted.)

    Numbers in tagged lines are extracted with the module-level ``_num_re``.
    Rows that cannot be parsed are skipped. A later row for the same ticker
    overwrites an earlier one.

    Raises:
        OSError (e.g. FileNotFoundError) if ``path`` cannot be opened.
    """
    wanted = expected_series.upper() if expected_series is not None else None
    data = {}

    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip().strip(",")
            if not line:
                continue

            # Unwrap dict-style dumps such as {'text': '...'}: drop the key prefix,
            # then the closing brace and any quotes/commas left around the payload.
            # Leaving the trailing "}" in place made wrapped plain rows ("T 1 2'}")
            # fail float parsing and get silently dropped.
            line = line.replace("{'text':", "").replace('{"text":', "")
            line = line.strip().rstrip(",").rstrip("}").strip().strip("'\"").strip()
            if not line:
                continue

            if "Company:" in line:
                m = re.search(r"Company:\s*([A-Za-z0-9\.\-\_]+)\s+([A-Za-z]+)\s*:\s*(.*)$", line)
                if m:
                    ticker = m.group(1)
                    series = m.group(2).upper()
                    tail = m.group(3)
                    if wanted is None or series == wanted:
                        nums = _num_re.findall(tail)
                        if nums:
                            data[ticker] = [float(x) for x in nums]
                    # A tagged line is never a valid plain row (its label tokens are
                    # not numbers), so skip the plain-format fallback.
                    continue

            parts = line.split()
            if len(parts) >= 2:
                ticker = parts[0].strip("'\"")
                try:
                    data[ticker] = [float(x) for x in parts[1:]]
                except ValueError:
                    pass  # header or free text rather than a numeric row

    return data

@torch.no_grad()
def forecast_company_multiseries(ticker, esg, env, soc, gov, ret):
    """
    Teacher-forced, greedy one-step forecasts of ESG/ENV/SOC/GOV for one company.

    The first ``cfg.K`` observations seed the histories. The target indices are
    ``t`` in ``[cfg.K, min(len(esg), cfg.K + MAX_PRED_STEPS))``. For each ``t``
    and each series:
      - a prompt is built from the last ``cfg.K`` values of every history
        (RET included, rendered as ``<NAME_v.vv>`` tokens);
      - the model is run once;
      - the highest-logit token among ``allowed_ids[series]`` (any token if that
        set is empty) is converted with ``token_to_float``.
    The true values at ``t`` are then appended to the histories (teacher forcing).

    ``env``, ``soc``, ``gov`` and ``ret`` must have at least ``len(esg)`` elements.

    Returns:
        (preds, trues): dicts keyed by "ESG"/"ENV"/"SOC"/"GOV", each holding
        one value per forecast step.
    """
    observed = {"ESG": esg, "ENV": env, "SOC": soc, "GOV": gov}
    # list() rather than .copy(): identical for lists, and still appendable if
    # a caller passes numpy arrays. Key order (ESG, ENV, SOC, GOV, RET) is kept.
    hist = {name: list(vals[:cfg.K]) for name, vals in observed.items()}
    hist["RET"] = list(ret[:cfg.K])

    preds = {name: [] for name in observed}
    trues = {name: [] for name in observed}

    # range() excludes its upper bound, so len(esg) (not len(esg) - 1) is needed
    # for the final observation to be usable as a forecast target.
    max_t = min(len(esg), cfg.K + MAX_PRED_STEPS)

    for t in range(cfg.K, max_t):
        for series in observed:
            hist_tokens = {k: [f"<{k}_{v:.2f}>" for v in hist[k][-cfg.K:]] for k in hist}

            prompt = build_prompt(
                ticker=ticker,
                start_dt=None,
                end_dt=None,
                hist_tokens_map=hist_tokens,
                senti_tokens=[],
                news_text="(no news)",
                target_series=series
            )

            enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(cfg.DEVICE)
            logit_last = lm(**enc, output_hidden_states=False).logits[0, -1]

            ids = allowed_ids[series]
            if ids.numel() == 0:
                # No restricted vocabulary for this series: fall back to a plain argmax.
                pid = int(torch.argmax(logit_last).item())
            else:
                j = int(torch.argmax(logit_last.index_select(0, ids)).item())
                pid = int(ids[j].item())

            preds[series].append(token_to_float(tokenizer.convert_ids_to_tokens(pid)))
            trues[series].append(observed[series][t])

        for name, vals in observed.items():
            hist[name].append(vals[t])
        hist["RET"].append(ret[t])

    return preds, trues

@torch.no_grad()
def forecast_company_esg_only(ticker, esg, ret):
    """
    ESG-only variant of the teacher-forced greedy forecaster.

    ESG and RET histories are seeded with the first ``cfg.K`` values. For each
    target index ``t`` in ``[cfg.K, min(len(esg), cfg.K + MAX_PRED_STEPS))``:
      - the model is prompted for the "ESG" series;
      - the highest-logit token among ``allowed_ids["ESG"]`` (any token if that
        set is empty) is decoded with ``token_to_float``;
      - the true ESG/RET values at ``t`` are appended to the histories.

    ``ret`` must have at least ``len(esg)`` elements.

    Returns:
        (preds, trues): ``{"ESG": [...]}`` dicts with one value per forecast step.
    """
    hist = {
        "ESG": list(esg[:cfg.K]),
        "RET": list(ret[:cfg.K]),
    }

    preds = {"ESG": []}
    trues = {"ESG": []}

    # range() excludes its upper bound, so len(esg) lets the final observation be a target.
    max_t = min(len(esg), cfg.K + MAX_PRED_STEPS)

    for t in range(cfg.K, max_t):
        hist_tokens = {k: [f"<{k}_{v:.2f}>" for v in hist[k][-cfg.K:]] for k in hist}

        prompt = build_prompt(
            ticker=ticker,
            start_dt=None,
            end_dt=None,
            hist_tokens_map=hist_tokens,
            senti_tokens=[],
            news_text="(no news)",
            target_series="ESG"
        )

        enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(cfg.DEVICE)
        logit_last = lm(**enc, output_hidden_states=False).logits[0, -1]

        ids = allowed_ids["ESG"]
        if ids.numel() == 0:
            # No restricted vocabulary: fall back to a plain argmax over all tokens.
            pid = int(torch.argmax(logit_last).item())
        else:
            j = int(torch.argmax(logit_last.index_select(0, ids)).item())
            pid = int(ids[j].item())

        preds["ESG"].append(token_to_float(tokenizer.convert_ids_to_tokens(pid)))
        trues["ESG"].append(esg[t])

        hist["ESG"].append(esg[t])
        hist["RET"].append(ret[t])

    return preds, trues

def _append_forecast_metrics(metrics, series, y, yhat):
    """Append MSE/MAE/MAPE/TheilU/Bias of one company's ``series`` forecast to ``metrics``."""
    metrics[series]["MSE"].append(mse(y, yhat))
    metrics[series]["MAE"].append(mae(y, yhat))
    metrics[series]["MAPE"].append(mape(y, yhat))
    metrics[series]["TheilU"].append(theils_u(y, yhat))
    metrics[series]["Bias"].append(bias(y, yhat))


GLOBAL_RESULTS = {}

for EX in EXCHANGES:
    print("\n" + "="*80)
    print("EXCHANGE:", EX)
    print("="*80)

    base = EX

    path_esg = os.path.join(base, "esg_risk_ratings_1.txt")
    path_env = os.path.join(base, "PaperReady_e_scores.txt")
    path_soc = os.path.join(base, "PaperReady_s_scores.txt")
    path_gov = os.path.join(base, "PaperReady_g_scores.txt")

    if not os.path.exists(path_esg):
        print("Missing folder:", EX, "-> skipping")
        continue

    ESG = load_series_file(path_esg, expected_series="ESG")
    ENV = load_series_file(path_env, expected_series="ENV") if os.path.exists(path_env) else {}
    SOC = load_series_file(path_soc, expected_series="SOC") if os.path.exists(path_soc) else {}
    GOV = load_series_file(path_gov, expected_series="GOV") if os.path.exists(path_gov) else {}

    print(f"Loaded: ESG={len(ESG)} ENV={len(ENV)} SOC={len(SOC)} GOV={len(GOV)}")

    have_full = (len(ENV) > 0 and len(SOC) > 0 and len(GOV) > 0)

    # Per-ticker returns, always exactly as long as the ticker's ESG series.
    RETURNS = {}
    n_ret_fallback = 0
    for tkr in ESG.keys():
        L = len(ESG[tkr])
        # NOTE(review): integer positions 0..L-1 are passed as "dates" — confirm this
        # is what align_and_fill_returns expects.
        dates = list(range(L))
        try:
            # list() matters: with a numpy result, "r + [0.0]*L" would broadcast-add
            # instead of concatenating.
            r = list(align_and_fill_returns(tkr, dates))
        except Exception:
            # Best effort: tickers whose returns cannot be aligned get zero returns.
            n_ret_fallback += 1
            r = []
        RETURNS[tkr] = (r + [0.0] * L)[:L]  # zero-pad short results, truncate long ones
    if n_ret_fallback:
        print(f"[WARN] {EX}: return alignment failed for {n_ret_fallback} ticker(s); using zero returns")

    METRICS = defaultdict(lambda: defaultdict(list))

    if have_full:
        COMMON = sorted(set(ESG) & set(ENV) & set(SOC) & set(GOV))
        print(f"Common tickers ESG∩ENV∩SOC∩GOV = {len(COMMON)}")

        for ticker in COMMON:
            L = min(len(ESG[ticker]), len(ENV[ticker]), len(SOC[ticker]), len(GOV[ticker]), len(RETURNS[ticker]))
            if L < cfg.K + 5:
                continue  # too short for a K-step history plus targets

            preds, trues = forecast_company_multiseries(
                ticker,
                ESG[ticker][:L],
                ENV[ticker][:L],
                SOC[ticker][:L],
                GOV[ticker][:L],
                RETURNS[ticker][:L],
            )

            for series in ["ESG", "ENV", "SOC", "GOV"]:
                y = trues[series]
                yhat = preds[series]
                if len(y) == 0:
                    continue
                _append_forecast_metrics(METRICS, series, y, yhat)

    else:
        print("ENV/SOC/GOV appear empty for this exchange -> running ESG-only evaluation")
        for ticker in ESG.keys():
            L = min(len(ESG[ticker]), len(RETURNS[ticker]))
            if L < cfg.K + 5:
                continue

            preds, trues = forecast_company_esg_only(
                ticker,
                ESG[ticker][:L],
                RETURNS[ticker][:L],
            )

            y = trues["ESG"]
            yhat = preds["ESG"]
            if len(y) == 0:
                continue
            _append_forecast_metrics(METRICS, "ESG", y, yhat)

    print("\nRESULTS FOR", EX)
    print("-"*80)
    for series in ["ESG", "ENV", "SOC", "GOV"]:
        print("\nSeries:", series)
        for m in ["MSE", "MAE", "MAPE", "TheilU", "Bias"]:
            vals = METRICS[series][m]
            if len(vals) == 0:
                print(f"  {m}: N/A")
            else:
                print(f"  {m}: {np.mean(vals):.4f}")

    GLOBAL_RESULTS[EX] = METRICS

print("\n\n==================== GLOBAL SUMMARY ====================")
# Exchange -> series -> mean of each metric across companies (N/A if none recorded).
for EX, per_series in GLOBAL_RESULTS.items():
    print("\n", EX)
    for series in ["ESG", "ENV", "SOC", "GOV"]:
        print(" ", series)
        for m in ["MSE", "MAE", "MAPE", "TheilU", "Bias"]:
            vals = per_series[series][m]
            if len(vals) > 0:
                print("   ", m, ":", float(np.mean(vals)))
                continue
            print("   ", m, ": N/A")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def plot_cos_vs_dtw(
    df,
    series,
    max_pairs=3000,
    seed=0,
    bins=30,
    save_prefix="cos_vs_dtw"
):
    """
    Scatter cosine similarity against DTW distance for random pairs of rows.

    Each cell of ``df[series]`` is treated as one 1-D numeric sequence. Both
    sequences of a sampled pair are z-normalised (constant sequences become zeros).
    - Cosine similarity is computed on their common trailing window, because the
      dot product needs equal lengths. Equal-length pairs use the whole sequence.
    - DTW (unconstrained, absolute-difference cost) is computed on the full
      normalised sequences.
    The scatter is saved to ``f"{save_prefix}_{series}.png"``.

    Args:
        df: DataFrame holding one sequence per row in column ``series``.
        series: column name to analyse.
        max_pairs: cap on the number of distinct (i, j) pairs sampled.
        seed: seed for ``np.random.default_rng``.
        bins: accepted for API compatibility; currently unused.
        save_prefix: output file name prefix.

    Returns:
        dict with keys series, spearman (rank correlation of cosine vs DTW),
        cos_mean, dtw_mean and n_pairs. The statistics are NaN when no pair
        could be sampled (fewer than two rows).
    """
    rng = np.random.default_rng(seed)

    def normalize_ts(x):
        x = np.asarray(x, dtype=float)
        mu = x.mean()
        sigma = x.std()
        if sigma < 1e-12:
            return x * 0.0
        return (x - mu) / sigma

    def dtw_distance(a, b):
        n, m = len(a), len(b)
        D = np.full((n+1, m+1), np.inf)
        D[0, 0] = 0.0
        for i in range(1, n+1):
            for j in range(1, m+1):
                cost = abs(a[i-1] - b[j-1])
                D[i, j] = cost + min(D[i-1, j], D[i, j-1], D[i-1, j-1])
        return D[n, m]

    def mean_or_nan(vals):
        return float(np.mean(vals)) if len(vals) else float("nan")

    X = df[series].values
    n = len(X)

    # Distinct unordered pairs, sampled until the cap (or all possible pairs) is reached.
    target = min(max_pairs, n * (n - 1) // 2)
    pairs = set()
    while len(pairs) < target:
        i = rng.integers(0, n)
        j = rng.integers(0, n)
        if i != j:
            pairs.add((min(i, j), max(i, j)))

    cos_vals = []
    dtw_vals = []

    for i, j in pairs:
        s1 = normalize_ts(X[i])
        s2 = normalize_ts(X[j])

        # np.dot raises on unequal lengths; compare the shared trailing window instead.
        w = min(len(s1), len(s2))
        a = s1[len(s1) - w:]
        b = s2[len(s2) - w:]
        num = np.dot(a, b)
        den = np.linalg.norm(a) * np.linalg.norm(b) + 1e-12
        cos_vals.append(num / den)

        dtw_vals.append(dtw_distance(s1, s2))

    cos_vals = np.array(cos_vals)
    dtw_vals = np.array(dtw_vals)

    if len(cos_vals) >= 2:
        spearman = pd.Series(cos_vals).corr(pd.Series(dtw_vals), method="spearman")
    else:
        spearman = float("nan")

    plt.figure(figsize=(7, 5))
    plt.scatter(cos_vals, dtw_vals, s=6, alpha=0.4)
    plt.xlabel("Cosine similarity (z-normalized series)")
    plt.ylabel("DTW distance (z-normalized series)")
    plt.title(f"{series}: Spearman ρ = {spearman:.3f}")
    plt.grid(True, alpha=0.3)

    fname = f"{save_prefix}_{series}.png"
    plt.tight_layout()
    plt.savefig(fname, dpi=150)
    plt.close()

    return {
        "series": series,
        "spearman": float(spearman),
        "cos_mean": mean_or_nan(cos_vals),
        "dtw_mean": mean_or_nan(dtw_vals),
        "n_pairs": len(cos_vals),
    }

"""# Revised phase 10"""


import numpy as np
import pandas as pd
import torch
from typing import List, Tuple, Dict
from collections import defaultdict

# Phase-10 settings: fill in only the attributes cfg does not already define.
for _attr, _fallback in (
    ("K", 32),
    ("DTW_BAND", 8),
    ("SEED", 1337),
    ("PRINT_ALL_TEST_SAMPLES", False),
    ("ENFORCE_THRESHOLDS", False),
    ("MSE_TARGET", 2.0),
    ("SPEARMAN_TARGET", 0.70),
):
    if not hasattr(cfg, _attr):
        setattr(cfg, _attr, _fallback)

# Optional mixed-precision evaluation; off unless cfg enables it.
USE_AMP_EVAL = getattr(cfg, "USE_AMP_EVAL", False)
AMP_DTYPE_EVAL = getattr(cfg, "AMP_DTYPE_EVAL", torch.float16)

def _log(msg: str):
   print(msg)

def _display(df):
    if "display" in globals() and callable(globals()["display"]):
        globals()["display"](df)
    else:
        print(df)

@torch.no_grad()
def get_supervised_positions_strict(labels: torch.Tensor) -> torch.Tensor:
    """
    Locate the single supervised token (label != -100) of every row.

    labels: [B, T]; each row must contain exactly one entry different from -100.
    Returns an index tensor [B, 2] whose row i is [i, position of row i's supervised token].
    Raises AssertionError when the number of supervised entries is not B, or
    when they are not spread one per row.
    """
    n_rows = labels.size(0)
    hits = torch.nonzero(labels != -100, as_tuple=False)
    n_hits = hits.size(0)
    assert n_hits == n_rows, (
        f"Expected exactly 1 supervised token per sample; got {n_hits} for B={n_rows}. "
        "Fix labels masking."
    )
    hits = hits[torch.argsort(hits[:, 0])]
    rows_match = torch.all(hits[:, 0].cpu() == torch.arange(n_rows))
    assert rows_match, (
        "Supervised positions not aligned to batch order."
    )
    return hits

@torch.no_grad()
def supervised_tpos_strict(labels: torch.Tensor) -> torch.Tensor:
    """Per-row position [B] of the single supervised token in ``labels`` ([B, T])."""
    # Column 1 of the (row, position) pairs is the token position.
    return get_supervised_positions_strict(labels)[:, 1]

@torch.no_grad()
def predict_token_at_supervised_pos(out_logits: torch.Tensor, labels: torch.Tensor, series_list: List[str]) -> List[str]:
    """
    Greedy-decode one token per row at that row's supervised position.

    out_logits: [B, T, V]. The argmax is restricted to ``allowed_ids[series_list[i]]``;
    when that set is empty, all tokens are considered. Returns token strings
    from ``tokenizer.convert_ids_to_tokens``.
    """
    n_rows, _seq_len, _vocab = out_logits.shape
    positions = get_supervised_positions_strict(labels)
    decoded = []
    for row in range(n_rows):
        t_idx = int(positions[row, 1].item())
        candidates = allowed_ids[series_list[row]]
        row_logits = out_logits[row, t_idx, :]

        if candidates.numel() > 0:
            best = int(torch.argmax(row_logits.index_select(0, candidates)).item())
            token_id = int(candidates[best].item())
        else:
            token_id = int(torch.argmax(row_logits).item())

        decoded.append(tokenizer.convert_ids_to_tokens(token_id))
    return decoded

def _token_value_or_nan(tok) -> float:
    """Convert a value token with ``token_to_float``; NaN if it cannot be parsed."""
    try:
        return float(token_to_float(tok))
    except Exception:
        return float("nan")


@torch.no_grad()
def expected_numeric_pred(
    out_logits: torch.Tensor,
    labels: torch.Tensor,
    series_list: List[str],
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Decode numeric targets and predictions at each row's supervised position.

    out_logits: [B, T, V]; labels: [B, T] with one supervised token per row.

    Returns:
      y_true_numeric [B]: the supervised label token converted with token_to_float
      y_pred_numeric [B]: expected value over allowed_ids[series], or the parsed
                          argmax token when that set is empty

    Values that cannot be decoded are NaN. NaN is used rather than a placeholder
    because a placeholder equal to the true value would be scored as a perfect
    forecast, and 0.0 would masquerade as real data.
    """
    B, _, _ = out_logits.shape
    tpos = get_supervised_positions_strict(labels)[:, 1]
    logits_last = out_logits[torch.arange(B, device=out_logits.device), tpos, :]

    y_true = np.full(B, np.nan, dtype=float)
    y_pred = np.full(B, np.nan, dtype=float)

    for i in range(B):
        ser = series_list[i]

        tid = int(labels[i, int(tpos[i].item())].item())
        y_true[i] = _token_value_or_nan(tokenizer.convert_ids_to_tokens(tid))

        ids = allowed_ids[ser]
        if ids.numel() == 0:
            pid = int(torch.argmax(logits_last[i]).item())
            y_pred[i] = _token_value_or_nan(tokenizer.convert_ids_to_tokens(pid))
        else:
            ev = expected_value_from_logits(logits_last[i:i+1], ids, id_to_val[ser])[0].item()
            y_pred[i] = float(ev)

    return y_true, y_pred

def safe_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean absolute percentage error (percent), with denominators floored at 1e-6.

    Accepts any array-like. Inputs are converted with ``np.asarray`` so that
    plain lists work too; without the conversion ``y_pred - y_true`` raises
    TypeError for lists.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    denom = np.maximum(np.abs(y_true), 1e-6)
    return float(np.mean(np.abs((y_pred - y_true) / denom)) * 100.0)

def bias(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean signed error ``mean(y_pred - y_true)``; positive means over-forecasting.

    Accepts any array-like. Lists are converted with ``np.asarray``, matching
    the list-accepting ``bias`` helpers used by the per-exchange evaluation.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean(y_pred - y_true))

def theils_u2_with_prev(y_true: np.ndarray, y_pred: np.ndarray, y_prev: np.ndarray) -> float:
    """Theil's U2 against a naive baseline: RMSE(pred, true) / RMSE(prev, true).

    Values below 1 beat the "repeat the previous value" forecast. Small
    epsilons keep the ratio finite when the baseline is exact. Accepts any
    array-like; inputs are converted with ``np.asarray`` so lists work too.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    y_prev = np.asarray(y_prev, dtype=float)
    rmse_p = np.sqrt(np.mean((y_pred - y_true) ** 2) + 1e-12)
    rmse_n = np.sqrt(np.mean((y_prev - y_true) ** 2) + 1e-12)
    return float(rmse_p / (rmse_n + 1e-12))

@torch.no_grad()
def eval_forecast_metrics(dataloader, split_name="TEST") -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Run the LM over ``dataloader`` and score its expected-value forecasts.

    For each batch:
      - the attention bias from ``build_attn_bias`` is applied;
      - the model is run, under CUDA autocast when ``USE_AMP_EVAL`` is set and
        ``cfg.DEVICE`` is a CUDA device;
      - ``expected_numeric_pred`` decodes one (y_true, y_pred) pair per row.
    ``y_prev`` is the last history value of the row's target series, falling
    back to ``y_true`` when there is no history. It is the naive baseline of
    Theil's U2.

    Returns:
      df_series_macro: rows per series with macro metrics (equal weight per company)
      df_company: rows per (series,ticker) with metrics, sorted by series then MSE
      df_rows: per-sample rows (useful for ranking / reuse)
    All three frames keep their columns even when the dataloader yields no batches.
    """
    from contextlib import nullcontext

    # Autocast only matters for CUDA execution. On CPU a "cuda" autocast just
    # warns and disables itself.
    use_amp = bool(USE_AMP_EVAL) and str(cfg.DEVICE).startswith("cuda")
    if use_amp:
        from torch.amp import autocast

        def amp_ctx():
            return autocast(device_type="cuda", dtype=AMP_DTYPE_EVAL)
    else:
        amp_ctx = nullcontext

    row_cols = ["split", "ticker", "series", "t_index", "y_true", "y_pred", "y_prev"]
    metric_cols = ["MSE", "MAE", "MAPE", "TheilsU2", "Bias"]

    lm.eval()
    rows = []

    for batch_i, batch in enumerate(dataloader):
        series_list = batch["target_series"]
        bias_mat = build_attn_bias(batch["input_ids"], series_list)
        with attn_bias_context(bias_mat), amp_ctx():
            out = lm(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
                output_hidden_states=False,
            )

        y_true, y_pred = expected_numeric_pred(out.logits, batch["labels"], series_list)

        y_prev = np.zeros(len(series_list), dtype=float)
        for i in range(len(series_list)):
            ser = series_list[i]
            hv = batch["hist_vals"][i].get(ser, [])
            y_prev[i] = float(hv[-1]) if (hv is not None and len(hv) > 0) else float(y_true[i])

        if split_name == "TEST" and cfg.PRINT_ALL_TEST_SAMPLES and batch_i < 3:
            preds_tok = predict_token_at_supervised_pos(out.logits, batch["labels"], batch["target_series"])
            for i in range(min(5, len(preds_tok))):
                _log(
                    f"[{split_name} PRED] ticker={batch['ticker'][i]} series={batch['target_series'][i]} "
                    f"true={batch['target_token'][i]} pred_tok={preds_tok[i]} pred_val={y_pred[i]:.2f}"
                )
                print("----- PROMPT (HEAD) -----")
                print(batch["text"][i][:900])
                print("-------------------------\n")

        for i in range(len(y_true)):
            rows.append({
                "split": split_name,
                "ticker": batch["ticker"][i],
                "series": batch["target_series"][i],
                "t_index": int(batch["t_index"][i]),
                "y_true": float(y_true[i]),
                "y_pred": float(y_pred[i]),
                "y_prev": float(y_prev[i]),
            })

    # Explicit columns keep groupby/sort_values working on an empty split.
    df_rows = pd.DataFrame(rows, columns=row_cols)

    comp_rows = []
    for (ser, tick), g in df_rows.groupby(["series", "ticker"]):
        yt = g["y_true"].to_numpy(float)
        yp = g["y_pred"].to_numpy(float)
        yv = g["y_prev"].to_numpy(float)
        comp_rows.append({
            "split": split_name,
            "series": ser,
            "ticker": tick,
            "n": int(len(g)),
            "MSE": float(np.mean((yp - yt) ** 2)),
            "MAE": float(np.mean(np.abs(yp - yt))),
            "MAPE": safe_mape(yt, yp),
            "TheilsU2": theils_u2_with_prev(yt, yp, yv),
            "Bias": bias(yt, yp),
        })
    df_company = pd.DataFrame(
        comp_rows, columns=["split", "series", "ticker", "n", *metric_cols]
    ).sort_values(["series", "MSE"], ascending=[True, True])

    macro_rows = []
    for ser, g in df_company.groupby("series"):
        macro_rows.append({
            "split": split_name,
            "series": ser,
            "n_companies": int(g["ticker"].nunique()),
            "MSE": float(g["MSE"].mean()),
            "MAE": float(g["MAE"].mean()),
            "MAPE": float(g["MAPE"].mean()),
            "TheilsU2": float(g["TheilsU2"].mean()),
            "Bias": float(g["Bias"].mean()),
        })
    df_series_macro = pd.DataFrame(
        macro_rows, columns=["split", "series", "n_companies", *metric_cols]
    ).sort_values("series")

    return df_series_macro, df_company, df_rows

val_macro, val_company, df_pred_val = eval_forecast_metrics(val_dl, split_name="VAL")
test_macro, test_company, df_pred_test = eval_forecast_metrics(test_dl, split_name="TEST")

_log("[METRICS] VAL per-series macro:")
_display(val_macro)

_log("[METRICS] TEST per-series macro:")
_display(test_macro)

# test_company is ordered by (series, MSE), so its head()/tail() would only cover
# the alphabetically first/last series. Rank all company-series pairs by MSE instead,
# leaving out pairs whose MSE is undefined (NaN).
test_company_ranked = test_company.dropna(subset=["MSE"]).sort_values("MSE", kind="mergesort")

_log("[METRICS] TEST top-10 best company-series by MSE:")
_display(test_company_ranked.head(10))

_log("[METRICS] TEST bottom-10 worst company-series by MSE:")
_display(test_company_ranked.tail(10))

if cfg.ENFORCE_THRESHOLDS:
    bad = test_macro[test_macro["MSE"] > float(cfg.MSE_TARGET)]
    assert len(bad) == 0, f"[THRESHOLD FAIL] Some series macro MSE > {cfg.MSE_TARGET}:\n{bad}"

@torch.no_grad()
def compute_facet_vectors(dataloader) -> pd.DataFrame:
    """
    For each sample, compute z = hidden_state at facet_pos (FACET token position).

    The last element of ``out.hidden_states`` is used. Samples whose
    ``facet_pos`` lies outside ``[0, T)`` are skipped. ``hist`` holds the last
    ``cfg.K`` history values of the sample's target series.

    Return df columns: ticker, series, t_index, z (np array), hist (np array).
    The columns are present even when no sample qualifies, so callers can
    filter on "series" safely.
    """
    from contextlib import nullcontext

    # Autocast only matters for CUDA execution; elsewhere use a no-op context.
    use_amp = bool(USE_AMP_EVAL) and str(cfg.DEVICE).startswith("cuda")
    if use_amp:
        from torch.amp import autocast

        def amp_ctx():
            return autocast(device_type="cuda", dtype=AMP_DTYPE_EVAL)
    else:
        amp_ctx = nullcontext

    lm.eval()
    rows = []

    for batch in dataloader:
        bias_mat = build_attn_bias(batch["input_ids"], batch["target_series"])
        with attn_bias_context(bias_mat), amp_ctx():
            out = lm(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
                output_hidden_states=True,
            )

        h_last = out.hidden_states[-1]
        B, T, _ = h_last.shape

        for i in range(B):
            ser = batch["target_series"][i]
            p = int(batch["facet_pos"][i].item())
            if p < 0 or p >= T:
                continue

            z = h_last[i, p, :].detach().float().cpu().numpy()
            hist = np.array(batch["hist_vals"][i].get(ser, [])[-cfg.K:], dtype=float)

            rows.append({
                "ticker": batch["ticker"][i],
                "series": ser,
                "t_index": int(batch["t_index"][i]),
                "z": z,
                "hist": hist,
            })

    return pd.DataFrame(rows, columns=["ticker", "series", "t_index", "z", "hist"])

# Hidden-state facet vectors for every TEST sample.
df_z = compute_facet_vectors(test_dl)
_n_facet_rows = len(df_z)
_log(f"[GEOM] extracted facet vectors on TEST: n={_n_facet_rows}")
_display(df_z.head(3))

def _rank_spearman(a, b) -> float:
    """Spearman's rho as the Pearson correlation of average ranks (used when scipy is absent)."""
    ra = pd.Series(a, dtype=float).rank().to_numpy()
    rb = pd.Series(b, dtype=float).rank().to_numpy()
    if np.std(ra) == 0.0 or np.std(rb) == 0.0:
        return float("nan")  # correlation undefined for constant input
    return float(np.corrcoef(ra, rb)[0, 1])


def pairwise_geom_metrics_for_series(df_ser: pd.DataFrame, max_pairs: int = 2000) -> Dict[str, float]:
    """
    Sample random pairs, compute:
      - DTW distance in raw space (hist), via dtw_sakoe_chiba with band cfg.DTW_BAND
      - L2 distance in z space
    Report Spearman correlation.

    Pairs are drawn independently with a ``cfg.SEED``-seeded generator, so a pair
    may repeat. Pairs in which either ``hist`` is empty are skipped.

    Returns {"spearman": rho, "pairs": n_used}. rho is NaN when there are fewer
    than 5 rows or fewer than 10 usable pairs. scipy's spearmanr is used when
    installed, with an equivalent rank-based computation otherwise.
    """
    if len(df_ser) < 5:
        return {"spearman": float("nan"), "pairs": 0}

    rng = np.random.default_rng(cfg.SEED)
    idx = np.arange(len(df_ser))

    n_possible = len(df_ser) * (len(df_ser) - 1) // 2
    n_pairs = int(min(max_pairs, n_possible))

    pairs = [tuple(rng.choice(idx, size=2, replace=False)) for _ in range(n_pairs)]

    # Pull the columns out once; per-row .iloc inside the loop is slow.
    hists = df_ser["hist"].tolist()
    zs = df_ser["z"].tolist()

    dtw_list = []
    z_list = []
    for i, j in pairs:
        xi = hists[i]
        xj = hists[j]
        if len(xi) == 0 or len(xj) == 0:
            continue

        dd = dtw_sakoe_chiba(xi.tolist(), xj.tolist(), band=cfg.DTW_BAND)
        zd = float(np.linalg.norm(zs[i] - zs[j]))

        dtw_list.append(float(dd))
        z_list.append(zd)

    if len(dtw_list) < 10:
        return {"spearman": float("nan"), "pairs": len(dtw_list)}

    try:
        from scipy.stats import spearmanr
    except ImportError:
        sp = _rank_spearman(dtw_list, z_list)
    else:
        sp = float(spearmanr(dtw_list, z_list).correlation)

    return {"spearman": sp, "pairs": len(dtw_list)}

# Spearman correlation between DTW distance and z-space distance, per series.
geom_rows = []
for ser in ("ESG", "ENV", "SOC", "GOV"):
    df_ser = df_z.loc[df_z["series"] == ser]
    m = pairwise_geom_metrics_for_series(df_ser, max_pairs=2000)
    geom_rows.append({"series": ser, **m})
df_geom = pd.DataFrame(geom_rows)

_log("[GEOM] Spearman corr(DTW, z-dist) per series:")
_display(df_geom)

if cfg.ENFORCE_THRESHOLDS:
    _rho = df_geom["spearman"]
    bad = df_geom[_rho.notna() & (_rho < float(cfg.SPEARMAN_TARGET))]
    assert len(bad) == 0, f"[THRESHOLD FAIL] Some series Spearman < {cfg.SPEARMAN_TARGET}:\n{bad}"

def knn_overlap(df_ser: pd.DataFrame, k: int = 5, n_queries: int = 30) -> float:
    """
    Mean overlap between z-space and DTW k-nearest-neighbour sets.

    For up to ``n_queries`` rows drawn without replacement (seeded by ``cfg.SEED``):
      - find the top-k neighbours by L2 distance between ``z`` vectors
      - find the top-k neighbours by dtw_sakoe_chiba (band ``cfg.DTW_BAND``) on ``hist``
      - record |intersection| / k
    Both searches use the same candidate set: every other row with a non-empty
    ``hist``. Rows DTW cannot score therefore cannot distort the z-space side.
    Queries with an empty ``hist``, or with fewer than ``k`` candidates, are skipped.

    Returns the mean overlap in [0, 1]. Returns NaN if fewer than ``k + 2`` rows
    are given or no query could be scored.
    """
    if len(df_ser) < k + 2:
        return float("nan")

    rng = np.random.default_rng(cfg.SEED)
    idxs = np.arange(len(df_ser))
    Q = int(min(n_queries, len(df_ser)))

    # Extract columns once; .iloc per row inside the O(Q * n) loops is slow.
    zs = df_ser["z"].tolist()
    hists = df_ser["hist"].tolist()
    has_hist = [len(h) > 0 for h in hists]

    overlaps = []
    for qi in rng.choice(idxs, size=Q, replace=False):
        if not has_hist[qi]:
            continue  # DTW neighbours are undefined without a raw history

        candidates = [j for j in idxs if j != qi and has_hist[j]]
        if len(candidates) < k:
            continue

        zq = zs[qi]
        zd = [(j, float(np.linalg.norm(zq - zs[j]))) for j in candidates]
        z_neighbors = [j for j, _ in sorted(zd, key=lambda x: x[1])[:k]]

        xq = hists[qi].tolist()
        dd = [
            (j, float(dtw_sakoe_chiba(xq, hists[j].tolist(), band=cfg.DTW_BAND)))
            for j in candidates
        ]
        d_neighbors = [j for j, _ in sorted(dd, key=lambda x: x[1])[:k]]

        ov = len(set(z_neighbors).intersection(d_neighbors)) / float(k)
        overlaps.append(ov)

    return float(np.mean(overlaps)) if len(overlaps) else float("nan")

# Agreement between z-space and DTW neighbourhoods, per series.
overlap_rows = []
for ser in ("ESG", "ENV", "SOC", "GOV"):
    df_ser = df_z.loc[df_z["series"] == ser]
    ov = knn_overlap(df_ser, k=5, n_queries=30)
    overlap_rows.append({"series": ser, "knn_overlap@5": ov})
df_overlap = pd.DataFrame(overlap_rows)

_log("[RETRIEVE] kNN overlap@5 (z-space vs DTW) per series:")
_display(df_overlap)

def ranking_experiment(df_pred: pd.DataFrame) -> pd.DataFrame:
    """
    Rank tickers by a simple predicted-change score.

    For each ticker:
      - esg_delta = mean(y_pred - y_prev) over its "ESG" rows;
      - ret_delta = the same over its "RET" rows;
      - score = ret_delta - esg_delta.
    A delta is 0.0 when the ticker has no rows of that series. Both a predicted
    ESG decrease and a predicted RET increase raise the score.

    NOTE(review): if the evaluated rows contain no "RET" series, ret_delta is 0.0
    for every ticker and the ranking is ESG-only — confirm this is intended.

    Returns a DataFrame (ticker, esg_delta_pred, ret_delta_pred, score) sorted by
    score, highest first. For empty input it is empty but keeps those columns.
    """
    columns = ["ticker", "esg_delta_pred", "ret_delta_pred", "score"]
    out = []
    for tick, g in df_pred.groupby("ticker"):
        g_esg = g[g["series"] == "ESG"]
        g_ret = g[g["series"] == "RET"]

        esg_delta = float(np.mean(g_esg["y_pred"] - g_esg["y_prev"])) if len(g_esg) > 0 else 0.0
        ret_delta = float(np.mean(g_ret["y_pred"] - g_ret["y_prev"])) if len(g_ret) > 0 else 0.0

        score = (-esg_delta) + (ret_delta)
        out.append({
            "ticker": tick,
            "esg_delta_pred": esg_delta,
            "ret_delta_pred": ret_delta,
            "score": score,
        })
    # Explicit columns: pd.DataFrame([]) has no "score" column and sort_values would fail.
    return pd.DataFrame(out, columns=columns).sort_values("score", ascending=False)

rank_df = ranking_experiment(df_pred_test)
_log("[RANK] Top-15 tickers by combined (ESG down good + RET up good) predicted score:")
_display(rank_df.head(15))

# Persist every result table to the working directory.
for _table, _csv_name in (
    (val_macro, "val_series_macro.csv"),
    (test_macro, "test_series_macro.csv"),
    (test_company, "test_company_metrics.csv"),
    (df_geom, "geom_spearman_by_series.csv"),
    (df_overlap, "knn_overlap_by_series.csv"),
    (rank_df, "ranking_experiment.csv"),
):
    _table.to_csv(_csv_name, index=False)
df_z.to_pickle("facet_vectors_test.pkl")

_log("[SAVED] metrics CSV files + facet_vectors_test.pkl written to current folder.")
_log("Phase 10 complete (revised, consistent with Phase 9).")

"""# Out of sample

"""


import os
import re
import numpy as np
import pandas as pd
from collections import defaultdict
import torch

# Exchanges to evaluate out-of-sample; each name is also a data sub-folder.
EXCHANGES = ["AMEX", "CBOE", "NASDAQ", "NYSE", "OTC", "PNK"]
# Maximum number of forecast steps per company.
MAX_PRED_STEPS = 10

# False: feed observed values back into the history after each step (teacher forcing).
# True: feed the model's own ESG/ENV/SOC/GOV predictions back instead.
FREE_RUN = False

# Inference only: switch the model to evaluation mode.
lm.eval()

def mse(y, yhat):
    """Mean squared error of forecasts ``yhat`` against actuals ``y`` (both cast to float)."""
    residual = np.asarray(yhat, dtype=float) - np.asarray(y, dtype=float)
    return float(np.square(residual).mean())

def mae(y, yhat):
    """Mean absolute error of forecasts ``yhat`` against actuals ``y`` (both cast to float)."""
    gap = np.abs(np.asarray(yhat, dtype=float) - np.asarray(y, dtype=float))
    return float(gap.mean())

def safe_mape(y, yhat):
    """Mean absolute percentage error in percent; denominators are floored at 1e-6."""
    truth = np.asarray(y, dtype=float)
    guess = np.asarray(yhat, dtype=float)
    scale = np.maximum(np.abs(truth), 1e-6)
    relative = np.abs((guess - truth) / scale)
    return float(relative.mean() * 100.0)

def bias(y, yhat):
    """Mean signed error ``mean(yhat - y)``; positive values mean over-forecasting."""
    signed_err = np.asarray(yhat, dtype=float) - np.asarray(y, dtype=float)
    return float(signed_err.mean())

def theils_u2_with_prev(y_true, y_pred, y_prev):
    """
    Phase 10 style Theil-U2:
      U2 = RMSE(pred vs true) / RMSE(prev vs true)

    ``y_prev`` is the naive "repeat last value" forecast. Values below 1 beat it.
    Small epsilons keep the ratio finite when either RMSE is zero.
    """
    actual, forecast, naive = (np.asarray(v, dtype=float) for v in (y_true, y_pred, y_prev))

    def _rmse(estimate):
        return np.sqrt(np.mean((estimate - actual)**2) + 1e-12)

    return float(_rmse(forecast) / (_rmse(naive) + 1e-12))

_num_re = re.compile(r"[-+]?\d*\.\d+|[-+]?\d+")

def load_series_file(path, expected_series=None):
    """
    Returns dict[ticker] = list[float]

    Supports:
      A) Plain:  TICKER 1.0 2.0 3.0 ...
      B) Tagged: Company: TICKER ESG: 1.0 2.0 ...
                 Company: TICKER ENV: ...
                 Company: TICKER SOC: ...
                 Company: TICKER GOV: ...
      C) Lines wrapped in dumps: {'text': 'Company: ...'} or quotes "'Company:" etc.

    expected_series: if set (e.g. "ESG"), only accepts tagged lines for that series.
                     Plain format is still accepted.

    A missing file yields an empty dict. Numbers in tagged lines are extracted
    with the module-level ``_num_re``. Unparseable rows are skipped. A later row
    for the same ticker overwrites an earlier one.
    """
    data = {}
    if not os.path.exists(path):
        return data

    wanted = expected_series.upper() if expected_series is not None else None

    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip().strip(",")
            if not line:
                continue

            # Unwrap dict-style dumps such as {'text': '...'}: drop the key prefix,
            # then the closing brace and any quotes/commas left around the payload.
            # Leaving the trailing "}" in place made wrapped plain rows ("T 1 2'}")
            # fail float parsing and get silently dropped.
            line = line.replace("{'text':", "").replace('{"text":', "")
            line = line.strip().rstrip(",").rstrip("}").strip().strip("'\"").strip()
            if not line:
                continue

            if "Company:" in line:
                m = re.search(r"Company:\s*([A-Za-z0-9\.\-\_]+)\s+([A-Za-z]+)\s*:\s*(.*)$", line)
                if m:
                    ticker = m.group(1)
                    series = m.group(2).upper()
                    tail = m.group(3)
                    if wanted is None or series == wanted:
                        nums = _num_re.findall(tail)
                        if nums:
                            data[ticker] = [float(x) for x in nums]
                    # A tagged line is never a valid plain row (its label tokens are
                    # not numbers), so skip the plain-format fallback.
                    continue

            parts = line.split()
            if len(parts) >= 2:
                ticker = parts[0].strip("'\"")
                try:
                    data[ticker] = [float(x) for x in parts[1:]]
                except ValueError:
                    pass  # header or free text rather than a numeric row

    return data

@torch.no_grad()
def expected_value_decode(logits_last_1v: torch.Tensor, series: str) -> float:
    """
    Numeric forecast for ``series`` from one row of next-token logits.

    logits_last_1v: [1, V] logits for the position being predicted.
    When ``allowed_ids[series]`` is non-empty, returns the expected value
    computed by ``expected_value_from_logits`` over those ids. Otherwise
    returns the argmax token parsed with ``token_to_float``, or 0.0 if parsing
    fails.
    NOTE(review): that 0.0 fallback is indistinguishable from a genuine
    prediction of 0.0 downstream.
    """
    ids = allowed_ids[series]
    has_vocab = ids is not None and ids.numel() > 0

    if has_vocab:
        expected = expected_value_from_logits(logits_last_1v, ids, id_to_val[series])[0]
        return float(expected.item())

    best_id = int(torch.argmax(logits_last_1v[0]).item())
    best_tok = tokenizer.convert_ids_to_tokens(best_id)
    try:
        return float(token_to_float(best_tok))
    except Exception:
        return 0.0

@torch.no_grad()
def forecast_company_oos(
    ticker,
    esg, env, soc, gov,
    ret,
):
    """
    Rolling one-step-ahead out-of-sample forecast for one company.

    The first cfg.K points of each series seed the history. For every step
    t = cfg.K .. min(len(esg), cfg.K + MAX_PRED_STEPS) - 1 and every score series, a
    prompt is built from the last cfg.K history values (formatted as "<SER_x.xx>"
    strings), scored by ``lm``, and decoded with expected_value_decode.

    After each step every history grows by one point: ground truth when FREE_RUN is
    False, otherwise the model's own predictions (returns always use ground truth).

    Args:
      ticker: company identifier passed to build_prompt.
      esg, env, soc, gov, ret: per-step values. ``esg`` determines the number of
        steps, so the other sequences must be at least as long.
    Returns:
      (preds, trues, prevs): dicts keyed by "ESG"/"ENV"/"SOC"/"GOV", each holding one
      float per step. ``prevs`` is the last history value before the step.
    """
    series_names = ["ESG", "ENV", "SOC", "GOV"]
    truth = {"ESG": esg, "ENV": env, "SOC": soc, "GOV": gov}

    # list(...) rather than slice.copy(): works for lists, tuples and numpy arrays
    # alike and always yields something that supports .append().
    hist = {name: list(vals[:cfg.K]) for name, vals in truth.items()}
    hist["RET"] = list(ret[:cfg.K])

    preds = {s: [] for s in series_names}
    trues = {s: [] for s in series_names}
    prevs = {s: [] for s in series_names}

    max_t = min(len(esg), cfg.K + MAX_PRED_STEPS)

    for t in range(cfg.K, max_t):
        # The history only changes after all four series have been forecast, so the
        # token map is built once per step instead of once per series.
        hist_tokens = {k: [f"<{k}_{v:.2f}>" for v in vals[-cfg.K:]] for k, vals in hist.items()}

        for series in series_names:
            prompt = build_prompt(
                ticker=ticker,
                start_dt=None,
                end_dt=None,
                hist_tokens_map=hist_tokens,
                senti_tokens=[],
                news_text="(no news)",
                target_series=series
            )

            enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(cfg.DEVICE)
            out = lm(**enc, output_hidden_states=False)

            # Next-token distribution after the prompt.
            logits_last = out.logits[:, -1, :]
            y_pred = expected_value_decode(logits_last, series)

            y_true = truth[series][t]
            y_prev = hist[series][-1] if len(hist[series]) else y_true

            preds[series].append(float(y_pred))
            trues[series].append(float(y_true))
            prevs[series].append(float(y_prev))

        # Teacher forcing (ground truth) unless FREE_RUN feeds predictions back in.
        for name in series_names:
            hist[name].append(preds[name][-1] if FREE_RUN else truth[name][t])
        hist["RET"].append(ret[t])

    return preds, trues, prevs

GLOBAL_RESULTS = {}

# First-pass OOS evaluation over every exchange folder. For each exchange:
# - load the four score files;
# - align returns;
# - roll forecast_company_oos over every ticker present in all four files;
# - report per-company metrics, and macro metrics (mean over companies) per series.
# NOTE(review): EXCHANGES, FREE_RUN, MAX_PRED_STEPS, pd and the metric helpers must
# already be defined when this runs; in this file they are (re)defined further down.
for EX in EXCHANGES:
    print("\n" + "=" * 80)
    print("EXCHANGE:", EX, "| FREE_RUN:", FREE_RUN)
    print("=" * 80)

    base = EX
    path_esg = os.path.join(base, "esg_risk_ratings_1.txt")
    path_env = os.path.join(base, "PaperReady_e_scores.txt")
    path_soc = os.path.join(base, "PaperReady_s_scores.txt")
    path_gov = os.path.join(base, "PaperReady_g_scores.txt")

    if not os.path.exists(path_esg):
        print("Missing folder:", EX, "-> skipping")
        continue

    ESG = load_series_file(path_esg, expected_series="ESG")
    ENV = load_series_file(path_env, expected_series="ENV")
    SOC = load_series_file(path_soc, expected_series="SOC")
    GOV = load_series_file(path_gov, expected_series="GOV")

    print(f"Loaded: ESG={len(ESG)} ENV={len(ENV)} SOC={len(SOC)} GOV={len(GOV)}")

    # One return series per ESG ticker, zero-padded or truncated to the ESG length.
    # Alignment stays best-effort (zeros on failure), but failures are now reported
    # instead of being silently swallowed.
    RETURNS = {}
    for tkr in ESG.keys():
        n = len(ESG[tkr])
        dates = list(range(n))
        try:
            r = align_and_fill_returns(tkr, dates)
            if r is None:
                r = [0.0] * n
            if len(r) != n:
                if len(r) < n:
                    r = r + [0.0] * (n - len(r))
                else:
                    r = r[:n]
            RETURNS[tkr] = r
        except Exception as _ret_err:
            print(f"[WARN] {EX}:{tkr} return alignment failed "
                  f"({type(_ret_err).__name__}: {_ret_err}) -> using zeros")
            RETURNS[tkr] = [0.0] * n

    rows = []

    for ticker in ESG.keys():
        if ticker not in ENV or ticker not in SOC or ticker not in GOV:
            continue

        # Common length across all five series. At least cfg.K history points plus
        # one target are needed; this check also covers the ESG-only length check.
        L = min(len(ESG[ticker]), len(ENV[ticker]), len(SOC[ticker]), len(GOV[ticker]), len(RETURNS[ticker]))
        if L < cfg.K + 1:
            continue

        preds, trues, prevs = forecast_company_oos(
            ticker,
            ESG[ticker][:L],
            ENV[ticker][:L],
            SOC[ticker][:L],
            GOV[ticker][:L],
            RETURNS[ticker][:L],
        )

        for series in ["ESG", "ENV", "SOC", "GOV"]:
            yhat = preds[series]
            y = trues[series]
            yprev = prevs[series]
            for i in range(len(y)):
                rows.append({
                    "exchange": EX,
                    "ticker": ticker,
                    "series": series,
                    "horizon_step": i + 1,
                    "y_true": float(y[i]),
                    "y_pred": float(yhat[i]),
                    "y_prev": float(yprev[i]),
                })

    df = pd.DataFrame(rows)
    if len(df) == 0:
        print("No rows for exchange:", EX, "(likely no ticker overlap across ESG/ENV/SOC/GOV)")
        continue

    # Per (series, company) metrics over that company's forecast steps.
    comp_rows = []
    for (ser, tick), g in df.groupby(["series", "ticker"]):
        yt = g["y_true"].to_numpy(float)
        yp = g["y_pred"].to_numpy(float)
        yv = g["y_prev"].to_numpy(float)
        comp_rows.append({
            "exchange": EX,
            "series": ser,
            "ticker": tick,
            "n": int(len(g)),
            "MSE": mse(yt, yp),
            "MAE": mae(yt, yp),
            "MAPE": safe_mape(yt, yp),
            "TheilsU2": theils_u2_with_prev(yt, yp, yv),
            "Bias": bias(yt, yp),
        })
    df_company = pd.DataFrame(comp_rows).sort_values(["series", "MSE"])

    # Macro average: unweighted mean of the per-company metrics for each series.
    macro_rows = []
    for ser, g in df_company.groupby("series"):
        macro_rows.append({
            "exchange": EX,
            "series": ser,
            "n_companies": int(g["ticker"].nunique()),
            "MSE": float(g["MSE"].mean()),
            "MAE": float(g["MAE"].mean()),
            "MAPE": float(g["MAPE"].mean()),
            "TheilsU2": float(g["TheilsU2"].mean()),
            "Bias": float(g["Bias"].mean()),
        })
    df_macro = pd.DataFrame(macro_rows).sort_values("series")

    print("\n[OOS METRICS] Macro per series (Phase-10 style):")
    print(df_macro.to_string(index=False))

    GLOBAL_RESULTS[EX] = {
        "df_rows": df,
        "df_company": df_company,
        "df_macro": df_macro,
    }

# Print the macro-averaged MSE per series for every exchange stored in GLOBAL_RESULTS.
print("\n\n==================== GLOBAL SUMMARY (Macro MSE) ====================")
for EX, pack in GLOBAL_RESULTS.items():
    df_macro = pack["df_macro"]
    print("\n", EX, "| FREE_RUN:", FREE_RUN)
    for ser in ["ESG", "ENV", "SOC", "GOV"]:
        g = df_macro[df_macro["series"] == ser]
        # df_macro only has a row for series that produced forecast rows on this exchange.
        if len(g) == 0:
            print(" ", ser, ": N/A")
        else:
            print(" ", ser, "MSE:", float(g["MSE"].iloc[0]))

"""# OUT OF SAMPLE"""


import os
import re
import numpy as np
import pandas as pd
from collections import defaultdict
import torch

# Switch the LM to eval mode for the out-of-sample forecasting below.
lm.eval()

# Exchange sub-folders to evaluate; each is joined with the score-file names in the loop below.
EXCHANGES = ["AMEX", "CBOE", "NASDAQ", "NYSE", "OTC", "PNK"]
# Maximum number of one-step-ahead forecasts per company (after the first cfg.K points).
MAX_PRED_STEPS = 10

# False: history is extended with ground truth after each step; True: with the model's predictions.
FREE_RUN = False
# Logging toggles: print the full prompt / one line per prediction.
PRINT_ALL_PROMPTS = True
PRINT_PRED_LINES = True

# Debug limits: cap tickers per exchange / forecast steps per company (None = no cap).
LIMIT_TICKERS = None
LIMIT_STEPS = None
# Series to forecast and score.
ONLY_SERIES = ["ESG","ENV","SOC","GOV"]

def mse(y, yhat):
    """Mean squared error between targets ``y`` and predictions ``yhat`` (as a Python float)."""
    err = np.asarray(yhat, dtype=float) - np.asarray(y, dtype=float)
    return float(np.mean(err ** 2))

def mae(y, yhat):
    """Mean absolute error between targets ``y`` and predictions ``yhat`` (as a Python float)."""
    diff = np.asarray(yhat, dtype=float) - np.asarray(y, dtype=float)
    return float(np.abs(diff).mean())

def safe_mape(y, yhat):
    """Mean absolute percentage error in percent; |y| is floored at 1e-6 so zeros don't divide by zero."""
    actual = np.asarray(y, dtype=float)
    forecast = np.asarray(yhat, dtype=float)
    floor = np.maximum(np.abs(actual), 1e-6)
    ratio = np.abs((forecast - actual) / floor)
    return float(100.0 * ratio.mean())

def bias(y, yhat):
    """Mean signed error ``yhat - y`` (positive means over-prediction), as a Python float."""
    residual = np.asarray(yhat, dtype=float) - np.asarray(y, dtype=float)
    return float(residual.mean())

def theils_u2_with_prev(y_true, y_pred, y_prev):
    """
    Theil's U2: RMSE of the model divided by RMSE of the "previous value" baseline.

    Both RMSEs get 1e-12 added under the square root. The baseline RMSE is also
    floored at 1e-3, so a near-perfect naive forecast cannot blow up the ratio.
    Values < 1 mean the model beats the naive forecast.
    """
    truth, model, naive = (np.asarray(a, dtype=float) for a in (y_true, y_pred, y_prev))
    model_rmse = np.sqrt(np.mean((model - truth) ** 2) + 1e-12)
    naive_rmse = max(float(np.sqrt(np.mean((naive - truth) ** 2) + 1e-12)), 1e-3)
    return float(model_rmse / naive_rmse)

# Signed decimals/integers with an optional exponent, e.g. "1.5", "-.3", "7", "2e-05".
# (The exponent part keeps scientific notation from being split into two numbers.)
_num_re = re.compile(r"[-+]?(?:\d+\.\d*|\.\d+|\d+)(?:[eE][-+]?\d+)?")

# "Company: <TICKER> <SERIES>: <values...>" — compiled once instead of per line.
_TAGGED_RE = re.compile(r"Company:\s*([A-Za-z0-9\.\-\_]+)\s+([A-Za-z]+)\s*:\s*(.*)$")


def _clean_series_line(line):
    """Strip ``{'text': '...'},`` dump wrappers and surrounding quotes/whitespace from one line."""
    line = line.strip().strip(",").strip()
    line = line.replace("{'text':", "").replace('{"text":', "").strip()
    # A dumped dict line ends with "}" once the trailing comma is gone. Drop it so the
    # closing quote in front of it can be stripped too; otherwise the last value of a
    # plain-format line would read e.g. "2.0'}" and the whole row would be lost.
    if line.endswith("}"):
        line = line[:-1].rstrip()
    return line.strip("'\"").strip()


def load_series_file(path, expected_series=None):
    """
    Returns dict[ticker] = list[float]
    Supports:
      A) Plain:  TICKER 1.0 2.0 3.0 ...
      B) Tagged: Company: TICKER ESG: 1.0 2.0 ...
      C) Dump artifacts: {'text': 'Company: ...'} or quoted lines (either format inside)

    Args:
      path: text file to read (UTF-8). A missing file yields an empty dict.
      expected_series: if given, tagged lines whose series label differs
        (case-insensitive) are ignored.

    Rows that cannot be parsed as numbers are skipped. When a ticker appears more
    than once, the last row wins.
    """
    data = {}
    if not os.path.exists(path):
        return data

    want = None if expected_series is None else expected_series.upper()

    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            line = _clean_series_line(raw)
            if not line:
                continue

            if "Company:" in line:
                m = _TAGGED_RE.search(line)
                if m:
                    ticker = m.group(1)
                    series = m.group(2).upper()
                    if want is not None and series != want:
                        # Tagged row for another series: never valid plain data either.
                        continue
                    nums = _num_re.findall(m.group(3))
                    if nums:
                        data[ticker] = [float(x) for x in nums]
                    continue

            parts = line.split()
            if len(parts) < 2:
                continue
            ticker = parts[0].strip("'").strip('"')
            try:
                data[ticker] = [float(x) for x in parts[1:]]
            except ValueError:
                # Header or malformed row: skip it (best-effort parsing).
                pass

    return data

_SER_CACHE = {}


def _build_series_cache(series: str):
    """
    Store ``(values, tokens)`` for ``series`` in ``_SER_CACHE``.

    ``values`` is a float numpy array from ``id_to_val[series]``. ``tokens`` holds the
    token strings for ``allowed_ids[series]`` in the same order. If the series has no
    allowed ids, a single placeholder ``<SERIES_0.00>`` with value 0.0 is cached.
    """
    ids = allowed_ids.get(series, None)
    if ids is not None and ids.numel() > 0:
        values = id_to_val[series].detach().cpu().numpy().astype(float)
        tokens = tokenizer.convert_ids_to_tokens(ids.detach().cpu().tolist())
        _SER_CACHE[series] = (values, tokens)
    else:
        _SER_CACHE[series] = (np.array([0.0], dtype=float), [f"<{series}_0.00>"])


def snap_value_to_vadt_token(series: str, v: float) -> str:
    """Return the cached token whose numeric value is nearest to ``v`` (the first one wins ties)."""
    cached = _SER_CACHE.get(series)
    if cached is None:
        _build_series_cache(series)
        cached = _SER_CACHE[series]
    values, tokens = cached
    nearest = np.abs(values - float(v)).argmin()
    return tokens[int(nearest)]

def build_hist_tokens_map_from_float_hist(hist_float_map: dict) -> dict:
    """
    Convert each series' float history (its last cfg.K values) into prompt tokens.

    ESG/ENV/SOC/GOV/RET values are snapped to the nearest vocabulary token via
    snap_value_to_vadt_token. Any other key is stringified with ``str``.
    """
    snapped_series = {"ESG", "ENV", "SOC", "GOV", "RET"}
    result = {}
    for key, values in hist_float_map.items():
        window = values[-cfg.K:]
        if key in snapped_series:
            result[key] = [snap_value_to_vadt_token(key, float(x)) for x in window]
        else:
            result[key] = [str(x) for x in window]
    return result

@torch.no_grad()
def expected_value_decode(logits_last_1v: torch.Tensor, series: str) -> float:
    """
    Numeric prediction for ``series`` from next-token logits (row 0 of ``logits_last_1v``).

    With allowed ids for the series, returns the expectation from
    expected_value_from_logits. Otherwise it greedily decodes the full vocabulary and
    parses the token with token_to_float; an unparseable token gives 0.0.
    """
    ids = allowed_ids.get(series, None)
    has_vocab = ids is not None and ids.numel() > 0
    if has_vocab:
        expectation = expected_value_from_logits(logits_last_1v, ids, id_to_val[series])
        return float(expectation[0].item())

    best_id = int(logits_last_1v[0].argmax().item())
    token = tokenizer.convert_ids_to_tokens(best_id)
    try:
        return float(token_to_float(token))
    except Exception:
        return 0.0

@torch.no_grad()
def restricted_argmax_token(logits_last_1v: torch.Tensor, series: str) -> str:
    """
    Greedy next token for ``series``.

    The argmax is taken only over ``allowed_ids[series]`` when that set exists and is
    non-empty, otherwise over the whole vocabulary (row 0 of the logits).
    """
    ids = allowed_ids.get(series, None)
    row = logits_last_1v[0]
    if ids is None or ids.numel() == 0:
        best = int(torch.argmax(row).item())
    else:
        local = int(torch.argmax(row.index_select(0, ids)).item())
        best = int(ids[local].item())
    return tokenizer.convert_ids_to_tokens(best)

@torch.no_grad()
def forecast_company_oos_verbose(ticker, esg, env, soc, gov, ret, exchange_name=""):
    """
    Rolling one-step-ahead OOS forecast for one company, with optional per-step logging.

    The first cfg.K points seed the history. The number of steps is capped by
    len(esg), MAX_PRED_STEPS and (if set) LIMIT_STEPS. At each step, for each series
    in ONLY_SERIES:
    - the prompt is built from the history snapped to vocabulary tokens;
    - it is scored by ``lm``;
    - the next-token logits are decoded both greedily (restricted argmax token) and
      as an expected value.

    After each step the history grows by one point. It uses ground truth when
    FREE_RUN is False, otherwise the model's expected-value predictions. Series not in
    ONLY_SERIES have no predictions, so in free-run mode they fall back to ground truth
    (this previously raised KeyError). Returns always use ground truth.

    Returns:
      (preds, trues, prevs, argmax_toks): dicts keyed by ONLY_SERIES, one entry per step.
    """
    truth = {"ESG": esg, "ENV": env, "SOC": soc, "GOV": gov}
    hist = {name: list(vals[:cfg.K]) for name, vals in truth.items()}
    hist["RET"] = list(ret[:cfg.K])

    preds = {s: [] for s in ONLY_SERIES}
    trues = {s: [] for s in ONLY_SERIES}
    prevs = {s: [] for s in ONLY_SERIES}
    argmax_toks = {s: [] for s in ONLY_SERIES}

    max_t = min(len(esg), cfg.K + MAX_PRED_STEPS)
    if LIMIT_STEPS is not None:
        max_t = min(max_t, cfg.K + int(LIMIT_STEPS))

    for t in range(cfg.K, max_t):
        # History is fixed within a step, so the token map is shared by all series.
        hist_tokens_map = build_hist_tokens_map_from_float_hist(hist)

        for series in ONLY_SERIES:
            prompt = build_prompt(
                ticker=ticker,
                start_dt=None,
                end_dt=None,
                hist_tokens_map=hist_tokens_map,
                senti_tokens=[],
                news_text="(no news)",
                target_series=series
            )

            enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(cfg.DEVICE)
            out = lm(**enc, output_hidden_states=False)
            logits_last = out.logits[:, -1, :]

            pred_tok = restricted_argmax_token(logits_last, series)
            pred_ev  = expected_value_decode(logits_last, series)

            y_true = truth[series][t]
            y_prev = hist[series][-1] if len(hist[series]) else y_true

            argmax_toks[series].append(pred_tok)
            preds[series].append(float(pred_ev))
            trues[series].append(float(y_true))
            prevs[series].append(float(y_prev))

            if PRINT_PRED_LINES:
                print(
                    f"[OOS PRED] ex={exchange_name} ticker={ticker} t={t} horizon={t-cfg.K+1:02d} series={series} "
                    f"argmax_tok={pred_tok} pred_ev={pred_ev:.4f} true={float(y_true):.4f} prev={float(y_prev):.4f}"
                )

            if PRINT_ALL_PROMPTS:
                print("----- OOS PROMPT (FULL) -----")
                print(prompt)
                print("----- END PROMPT -----\n")

        # Extend every history (same order as before: ESG, ENV, SOC, GOV, then RET).
        for name, vals in truth.items():
            if FREE_RUN and name in preds:
                hist[name].append(float(preds[name][-1]))
            else:
                hist[name].append(float(vals[t]))
        hist["RET"].append(float(ret[t]))

    return preds, trues, prevs, argmax_toks

def _describe_dict(d, name, nmax=200):
    """
    Print "[SCALE] <name> min=.. mean=.. max=.." over the first ``nmax`` values of
    every non-empty list in ``d``. Prints "<name> EMPTY" when there is nothing to
    summarise.
    """
    pooled = [x for seq in d.values() if seq for x in seq[:nmax]] if d else []
    if not pooled:
        print(name, "EMPTY")
        return
    a = np.asarray(pooled, dtype=float)
    print(f"[SCALE] {name} min={a.min():.4f} mean={a.mean():.4f} max={a.max():.4f}")

GLOBAL_RESULTS = {}

# Verbose OOS evaluation over every exchange folder. Per exchange:
# - load the four score files and print their value ranges;
# - align returns;
# - roll forecast_company_oos_verbose over every ticker present in all four files;
# - report per-company and macro metrics plus a previous-value baseline;
# - write the tables to oos_<exchange>_*.csv.
for EX in EXCHANGES:
    print("\n" + "=" * 80)
    print("EXCHANGE:", EX, "| FREE_RUN:", FREE_RUN)
    print("=" * 80)

    base = EX
    path_esg = os.path.join(base, "esg_risk_ratings_1.txt")
    path_env = os.path.join(base, "PaperReady_e_scores.txt")
    path_soc = os.path.join(base, "PaperReady_s_scores.txt")
    path_gov = os.path.join(base, "PaperReady_g_scores.txt")

    if not os.path.exists(path_esg):
        print("Missing folder:", EX, "-> skipping")
        continue

    ESG = load_series_file(path_esg, expected_series="ESG")
    ENV = load_series_file(path_env, expected_series="ENV")
    SOC = load_series_file(path_soc, expected_series="SOC")
    GOV = load_series_file(path_gov, expected_series="GOV")

    print(f"Loaded: ESG={len(ESG)} ENV={len(ENV)} SOC={len(SOC)} GOV={len(GOV)}")

    _describe_dict(ESG, f"{EX}:ESG")
    _describe_dict(ENV, f"{EX}:ENV")
    _describe_dict(SOC, f"{EX}:SOC")
    _describe_dict(GOV, f"{EX}:GOV")

    # One return series per ESG ticker, zero-padded or truncated to the ESG length.
    # Alignment stays best-effort (zeros on failure), but failures are now reported
    # instead of being silently swallowed.
    RETURNS = {}
    for tkr in ESG.keys():
        n = len(ESG[tkr])
        dates = list(range(n))
        try:
            r = align_and_fill_returns(tkr, dates)
            if r is None:
                r = [0.0] * n
            if len(r) != n:
                if len(r) < n:
                    r = r + [0.0] * (n - len(r))
                else:
                    r = r[:n]
            RETURNS[tkr] = r
        except Exception as _ret_err:
            print(f"[WARN] {EX}:{tkr} return alignment failed "
                  f"({type(_ret_err).__name__}: {_ret_err}) -> using zeros")
            RETURNS[tkr] = [0.0] * n

    rows = []

    tickers = list(ESG.keys())
    if LIMIT_TICKERS is not None:
        tickers = tickers[:int(LIMIT_TICKERS)]

    for ticker in tickers:
        if ticker not in ENV or ticker not in SOC or ticker not in GOV:
            continue

        # Common length across all five series; need cfg.K history points + 1 target.
        L = min(len(ESG[ticker]), len(ENV[ticker]), len(SOC[ticker]), len(GOV[ticker]), len(RETURNS[ticker]))
        if L < cfg.K + 1:
            continue

        preds, trues, prevs, argmax_toks = forecast_company_oos_verbose(
            ticker,
            ESG[ticker][:L],
            ENV[ticker][:L],
            SOC[ticker][:L],
            GOV[ticker][:L],
            RETURNS[ticker][:L],
            exchange_name=EX
        )

        for series in ONLY_SERIES:
            yhat = preds[series]
            y = trues[series]
            yprev = prevs[series]
            toks = argmax_toks[series]
            for h in range(len(y)):
                rows.append({
                    "exchange": EX,
                    "ticker": ticker,
                    "series": series,
                    "horizon_step": h + 1,
                    "y_true": float(y[h]),
                    "y_pred": float(yhat[h]),
                    "y_prev": float(yprev[h]),
                    "pred_argmax_tok": toks[h],
                })

    df = pd.DataFrame(rows)
    if len(df) == 0:
        print("No rows for exchange:", EX, "(likely no overlap or too-short sequences)")
        continue

    # Per (series, company) metrics over that company's forecast steps.
    comp_rows = []
    for (ser, tick), g in df.groupby(["series", "ticker"]):
        yt = g["y_true"].to_numpy(float)
        yp = g["y_pred"].to_numpy(float)
        yv = g["y_prev"].to_numpy(float)
        comp_rows.append({
            "exchange": EX,
            "series": ser,
            "ticker": tick,
            "n": int(len(g)),
            "MSE": mse(yt, yp),
            "MAE": mae(yt, yp),
            "MAPE": safe_mape(yt, yp),
            "TheilsU2": theils_u2_with_prev(yt, yp, yv),
            "Bias": bias(yt, yp),
        })
    df_company = pd.DataFrame(comp_rows).sort_values(["series", "MSE"])

    # Macro average: unweighted mean of the per-company metrics for each series.
    macro_rows = []
    for ser, g in df_company.groupby("series"):
        macro_rows.append({
            "exchange": EX,
            "series": ser,
            "n_companies": int(g["ticker"].nunique()),
            "MSE": float(g["MSE"].mean()),
            "MAE": float(g["MAE"].mean()),
            "MAPE": float(g["MAPE"].mean()),
            "TheilsU2": float(g["TheilsU2"].mean()),
            "Bias": float(g["Bias"].mean()),
        })
    df_macro = pd.DataFrame(macro_rows).sort_values("series")

    print("\n[OOS METRICS] Macro per series (Phase-10 style):")
    print(df_macro.to_string(index=False))

    # Naive "previous value" baseline. NOTE: pooled over all rows of a series, unlike
    # the macro table above, which averages per-company MSEs.
    base_rows = []
    for ser, g in df.groupby("series"):
        yt = g["y_true"].to_numpy(float)
        yprev = g["y_prev"].to_numpy(float)
        base_rows.append({
            "exchange": EX,
            "series": ser,
            "baseline_prev_MSE": float(np.mean((yprev - yt) ** 2)),
        })
    df_base = pd.DataFrame(base_rows).sort_values("series")
    print("\n[OOS BASELINE] Prev-value baseline MSE:")
    print(df_base.to_string(index=False))

    out_prefix = f"oos_{EX.lower()}"
    df.to_csv(f"{out_prefix}_rows.csv", index=False)
    df_company.to_csv(f"{out_prefix}_company.csv", index=False)
    df_macro.to_csv(f"{out_prefix}_macro.csv", index=False)
    df_base.to_csv(f"{out_prefix}_baseline.csv", index=False)
    print(f"[SAVED] {out_prefix}_*.csv written.")

    GLOBAL_RESULTS[EX] = {
        "df_rows": df,
        "df_company": df_company,
        "df_macro": df_macro,
        "df_baseline": df_base,
    }

# Print the macro-averaged MSE per series in ONLY_SERIES for every exchange in GLOBAL_RESULTS.
print("\n\n==================== GLOBAL SUMMARY (Macro MSE) ====================")
for EX, pack in GLOBAL_RESULTS.items():
    df_macro = pack["df_macro"]
    print("\n", EX, "| FREE_RUN:", FREE_RUN)
    for ser in ONLY_SERIES:
        g = df_macro[df_macro["series"] == ser]
        # df_macro only has a row for series that produced forecast rows on this exchange.
        if len(g) == 0:
            print(" ", ser, ": N/A")
        else:
            print(" ", ser, "MSE:", float(g["MSE"].iloc[0]))

"""# Experiment"""


import numpy as np
import torch
import matplotlib.pyplot as plt

# Target series whose validation samples are used in the DTW-vs-FACET experiment.
SERIES_TO_TEST = "ESG"
# Maximum number of (FACET vector, trajectory) samples collected from val_dl.
MAX_SAMPLES = 180
# Number of distinct random index pairs used for the Spearman test.
PAIR_SPEARMAN = 2000
# k used for the nearest-neighbour overlap test.
NN_K = 5
# Band width passed to dtw_sakoe_chiba (2 if cfg has no DTW_BAND attribute).
DTW_BAND = int(getattr(cfg, "DTW_BAND", 2))

print("\n" + "="*80)
print("[EXPERIMENT] DTW vs FACET Geometry Alignment")
print(f"[SETTINGS] series={SERIES_TO_TEST} MAX_SAMPLES={MAX_SAMPLES} PAIR_SPEARMAN={PAIR_SPEARMAN} NN_K={NN_K} DTW_BAND={DTW_BAND}")
print("="*80)

def _spearman_np(x, y):
    """
    Spearman rank correlation of ``x`` and ``y``; tied values receive their average rank.

    Returns NaN when fewer than 3 observations are given, or when either ranking has
    zero variance (e.g. all values equal).
    """
    xs = np.asarray(x, dtype=float)
    ys = np.asarray(y, dtype=float)
    if xs.size < 3:
        return np.nan

    def _avg_ranks(values):
        # 0-based ranks by sort order; then every run of equal values is assigned
        # the mean of the ranks it spans.
        order = values.argsort()
        ranked = np.empty_like(order, dtype=float)
        ranked[order] = np.arange(len(values), dtype=float)
        ordered = values[order]
        breaks = np.flatnonzero(ordered[1:] != ordered[:-1]) + 1
        run_starts = np.concatenate(([0], breaks))
        run_stops = np.concatenate((breaks, [len(values)])) - 1  # inclusive run ends
        for lo, hi in zip(run_starts, run_stops):
            if hi > lo:
                ranked[order[lo:hi + 1]] = (lo + hi) / 2.0
        return ranked

    dx = _avg_ranks(xs)
    dy = _avg_ranks(ys)
    dx = dx - dx.mean()
    dy = dy - dy.mean()
    scale = np.sqrt((dx ** 2).sum()) * np.sqrt((dy ** 2).sum())
    if scale < 1e-12:
        return np.nan
    return float((dx * dy).sum() / scale)

def _extract_traj(sample_hist_or_traj, ser):
    """Return the finite values of ``sample_hist_or_traj[ser]`` as floats (a missing or None entry gives [])."""
    raw = sample_hist_or_traj.get(ser, []) or []
    as_floats = (float(v) for v in raw)
    return [v for v in as_floats if np.isfinite(v)]

# Collect, for validation samples of SERIES_TO_TEST, one FACET vector (last-layer
# hidden state at facet_pos) and one value trajectory each, up to MAX_SAMPLES.
lm.eval()
Z_list = []
traj_list = []

with torch.no_grad():
    for batch in val_dl:
        series_list = batch["target_series"]
        idx = [i for i, s in enumerate(series_list) if s == SERIES_TO_TEST]
        if len(idx) == 0:
            continue

        # NOTE(review): batch tensors are passed without .to(device); presumably the
        # collate function already places them on the model's device — confirm.
        out = lm(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            output_hidden_states=True,
        )
        h_last = out.hidden_states[-1]
        fpos = batch["facet_pos"]

        # Prefer "traj_vals" when the batch provides it; otherwise use "hist_vals".
        has_traj = ("traj_vals" in batch)
        source_key = "traj_vals" if has_traj else "hist_vals"

        for i in idx:
            p = int(fpos[i].item())
            # Skip samples whose facet position is negative or past the sequence end.
            if p < 0 or p >= h_last.size(1):
                continue

            z = h_last[i, p, :].detach().float().cpu().numpy()

            sample_dict = batch[source_key][i]
            tr = _extract_traj(sample_dict, SERIES_TO_TEST)
            # Require at least 3 finite values per trajectory.
            if len(tr) < 3:
                continue

            Z_list.append(z)
            traj_list.append(tr)

            if len(Z_list) >= MAX_SAMPLES:
                break
        if len(Z_list) >= MAX_SAMPLES:
            break

N = len(Z_list)
print(f"[COLLECT] gathered N={N} samples for series={SERIES_TO_TEST}")
# Stop the script here when fewer than 30 samples were collected.
if N < 30:
    print("[WARNING] Too few samples to make the plots convincing. Increase MAX_SAMPLES or check data.")
    raise SystemExit

# (N, H) matrix of collected FACET vectors.
Z = np.stack(Z_list, axis=0)

# Draw n_pairs distinct unordered index pairs (i < j) by rejection sampling. n_pairs
# is capped at N*(N-1)/2, so the loop always terminates.
rng = np.random.default_rng(int(getattr(cfg, "SEED", 0)) + 2026)
n_pairs = min(PAIR_SPEARMAN, N*(N-1)//2)
pairs = set()
while len(pairs) < n_pairs:
    i, j = rng.integers(0, N), rng.integers(0, N)
    if i == j:
        continue
    if i > j:
        i, j = j, i
    pairs.add((i, j))
pairs = list(pairs)

# Paired distances: DTW between trajectories vs Euclidean distance between FACET vectors.
# Pairs where either distance is non-finite are dropped.
dtw_vals = []
z_vals = []

for (i, j) in pairs:
    xi = traj_list[i]
    xj = traj_list[j]
    dd = float(dtw_sakoe_chiba(xi, xj, band=DTW_BAND))
    zd = float(np.linalg.norm(Z[i] - Z[j]))
    if np.isfinite(dd) and np.isfinite(zd):
        dtw_vals.append(dd)
        z_vals.append(zd)

dtw_vals = np.asarray(dtw_vals, dtype=float)
z_vals = np.asarray(z_vals, dtype=float)

# Rank correlation; the pair order (set iteration order) does not affect it.
rho = _spearman_np(dtw_vals, z_vals)

print("\n" + "-"*80)
print("[RESULT A] Pairwise alignment test (random pairs)")
print(f"  #pairs used = {len(dtw_vals)}")
print(f"  Spearman( DTW(traj_i,traj_j), ||z_i - z_j|| ) = {rho:.4f}")
print("  Interpretation:")
print("   - If FACET space matches trajectory geometry, this Spearman should be strongly POSITIVE.")
print("   - Values ~0 mean no monotonic relationship; negative means inverted geometry.")
print("-"*80)

plt.figure()
plt.scatter(dtw_vals, z_vals, s=6, alpha=0.25)
plt.xlabel("DTW distance between trajectories")
plt.ylabel("FACET embedding distance ||z_i - z_j||")
plt.title(f"{SERIES_TO_TEST}: DTW vs FACET distance (Spearman={rho:.3f})")
plt.show()

print("\n" + "-"*80)
print("[RESULT B] Nearest-neighbor consistency (this is a stronger claim than correlation)")
print("  We compare top-k neighbors under FACET distance vs top-k under DTW distance.")
print("  overlap@k close to 1.0 is excellent; random baseline is about k/(N-1).")
print("-"*80)

# All pairwise Euclidean FACET distances. The broadcast builds an (N, N, H)
# intermediate array. The diagonal is set to inf so a sample is never its own neighbour.
Dz = np.sqrt(((Z[:, None, :] - Z[None, :, :]) ** 2).sum(axis=2))
np.fill_diagonal(Dz, np.inf)

# Full symmetric DTW distance matrix: N*(N-1)/2 DTW evaluations (this also recomputes
# the pairs already scored in Result A).
Dd = np.full((N, N), np.inf, dtype=float)
for i in range(N):
    Dd[i, i] = np.inf
    for j in range(i+1, N):
        dd = float(dtw_sakoe_chiba(traj_list[i], traj_list[j], band=DTW_BAND))
        Dd[i, j] = dd
        Dd[j, i] = dd

# For each sample: fraction of its NN_K nearest FACET neighbours that are also among
# its NN_K nearest DTW neighbours.
overlaps = []
for i in range(N):
    nn_z = np.argsort(Dz[i])[:NN_K]
    nn_d = np.argsort(Dd[i])[:NN_K]
    ov = len(set(nn_z.tolist()) & set(nn_d.tolist())) / float(NN_K)
    overlaps.append(ov)

overlaps = np.asarray(overlaps, dtype=float)
mean_ov = float(np.mean(overlaps))
median_ov = float(np.median(overlaps))
# Expected overlap fraction if the neighbour sets were unrelated: NN_K / (N - 1).
rand_base = NN_K / float(max(1, N-1))

print(f"[NN] overlap@{NN_K}: mean={mean_ov:.4f}  median={median_ov:.4f}")
print(f"[NN] random-baseline≈{rand_base:.4f}  (this is what you'd get if FACET neighbors were unrelated to DTW)")
print("  Interpretation:")
print("   - If mean overlap is MUCH larger than baseline, FACET neighborhood preserves DTW neighborhood.")
print("   - This directly supports: “FACET space represents the trajectory geometry.”")

plt.figure()
plt.hist(overlaps, bins=20)
plt.xlabel(f"overlap@{NN_K} (DTW-NN ∩ FACET-NN)")
plt.ylabel("count")
plt.title(f"{SERIES_TO_TEST}: NN overlap@{NN_K} (mean={mean_ov:.3f}, baseline≈{rand_base:.3f})")
plt.show()

print("\n" + "="*80)
print("[SUMMARY FOR PAPER / REPORT]")
print(f"Series: {SERIES_TO_TEST}")
print(f"1) Spearman(DTW, ||z-z'||) = {rho:.4f}  (should be high +)")
print(f"2) NN overlap@{NN_K} mean = {mean_ov:.4f} vs random baseline {rand_base:.4f} (should be >> baseline)")
print("These two together are strong evidence that FACET embeddings encode trajectory geometry.")
print("="*80 + "\n")

"""# FACET Embeddings"""

import numpy as np
import torch

@torch.no_grad()
def collect_facet_embeddings(lm, dl, device, company_key="company_id", normalize=True):
    """
    Run ``lm`` over every batch of ``dl`` and gather the last-layer hidden state at
    each sample's ``facet_pos``.

    Args:
      lm: model returning ``hidden_states`` when called with output_hidden_states=True.
      dl: iterable of batch dicts with input_ids, attention_mask, facet_pos,
        ``company_key`` and optionally labels.
      device: device the batch tensors are moved to.
      company_key: batch key holding one company id per sample.
      normalize: L2-normalise each embedding.

    Returns:
      company_ids: np.ndarray shape (N,)
      embeddings : np.ndarray shape (N, H), float32
      For an empty ``dl`` the shapes are (0,) and (0, 0).

    facet_pos values outside [0, T-1] are clamped into range, so there is still exactly
    one row per sample.
    """
    lm.eval()

    all_ids = []
    all_emb = []

    for batch in dl:
        # Build a moved copy instead of mutating the dict the loader handed us.
        batch = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()}

        out = lm(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch.get("labels", None),
            output_hidden_states=True,
        )

        h_last = out.hidden_states[-1]
        B, T, _ = h_last.shape

        fpos = torch.clamp(batch["facet_pos"].long(), 0, T - 1).to(h_last.device)

        # One vector per sample: row b, position fpos[b].
        z = h_last[torch.arange(B, device=h_last.device), fpos, :].float()

        if normalize:
            z = torch.nn.functional.normalize(z, p=2, dim=-1)

        all_emb.append(z.detach().cpu().numpy())

        cids = batch[company_key]
        if torch.is_tensor(cids):
            cids = cids.detach().cpu().numpy()
        else:
            cids = np.array(cids, dtype=object)
        all_ids.append(cids)

    if not all_emb:
        # np.concatenate would raise on an empty list.
        return np.empty((0,), dtype=object), np.empty((0, 0), dtype=np.float32)

    company_ids = np.concatenate(all_ids, axis=0)
    embeddings  = np.concatenate(all_emb, axis=0).astype(np.float32)

    return company_ids, embeddings

import os
import numpy as np

def save_split_embeddings(save_dir, split_name, company_ids, embeddings, extra_meta=None):
    """
    Write one split to ``<save_dir>/facet_<split_name>.npz`` (compressed) and return that path.

    The archive holds ``company_ids``, ``embeddings`` (cast to float32), and one
    object array per ``extra_meta`` entry. The directory is created if needed.
    """
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, f"facet_{split_name}.npz")

    arrays = dict(company_ids=company_ids, embeddings=embeddings.astype(np.float32))
    meta = extra_meta or {}
    arrays.update({key: np.array(val, dtype=object) for key, val in meta.items()})

    np.savez_compressed(path, **arrays)
    print(f"Saved {split_name}: ids={len(company_ids)} emb_shape={embeddings.shape} -> {path}")
    return path

# Output directory for the cached embeddings (looks like a Google Colab path; adjust elsewhere).
SAVE_DIR = "/content/facet_cache"

# L2-normalised FACET embeddings for every sample of each split.
train_ids, train_emb = collect_facet_embeddings(lm, train_dl, cfg.DEVICE, company_key="company_id", normalize=True)
val_ids,   val_emb   = collect_facet_embeddings(lm, val_dl,   cfg.DEVICE, company_key="company_id", normalize=True)
test_ids,  test_emb  = collect_facet_embeddings(lm, test_dl,  cfg.DEVICE, company_key="company_id", normalize=True)

# Writes facet_train.npz / facet_val.npz / facet_test.npz under SAVE_DIR.
save_split_embeddings(SAVE_DIR, "train", train_ids, train_emb, extra_meta={"normalized": True})
save_split_embeddings(SAVE_DIR, "val",   val_ids,   val_emb,   extra_meta={"normalized": True})
save_split_embeddings(SAVE_DIR, "test",  test_ids,  test_emb,  extra_meta={"normalized": True})

"""# Load Later"""

import numpy as np

def load_split_embeddings(path):
    """
    Load ``(company_ids, embeddings)`` from an ``.npz`` archive such as those written by
    save_split_embeddings; embeddings are returned as float32.

    NOTE(security): ``allow_pickle=True`` is needed because ids and metadata are stored
    as object arrays. Unpickling can execute code, so only load files you created.
    """
    # The NpzFile keeps the archive open; ``with`` closes it once both arrays have been
    # read into memory (the original version leaked the file handle).
    with np.load(path, allow_pickle=True) as data:
        company_ids = data["company_ids"]
        embeddings  = data["embeddings"].astype(np.float32)
    return company_ids, embeddings

"""# Build Look up"""

def build_company_index(company_ids):
    """
    Map each company id to the list of row positions where it occurs, in order of
    appearance. Companies with several rows keep all of their indices.
    """
    positions = {}
    for row, company in enumerate(company_ids):
        if company not in positions:
            positions[company] = []
        positions[company].append(row)
    return positions

"""# Retrieve"""

# Reload the cached training-split embeddings and index their rows by company id.
# NOTE(review): the path is hard-coded (the same location as SAVE_DIR), presumably so this
# section can run in a fresh session.
train_ids, train_emb = load_split_embeddings("/content/facet_cache/facet_train.npz")
train_index = build_company_index(train_ids)

def get_company_embeddings(company_id, embeddings, index_map, reduce="mean"):
    """
    Look up a company's embedding rows via ``index_map`` (see build_company_index).

    Returns None when the company is unknown. Otherwise:
    - ``reduce="mean"`` returns the mean row;
    - ``reduce="first"`` returns the first row;
    - any other value returns all rows stacked.
    """
    rows = index_map.get(company_id, [])
    if not rows:
        return None
    stacked = embeddings[rows]
    if reduce == "mean":
        return stacked.mean(axis=0)
    return stacked[0] if reduce == "first" else stacked

"""# New section"""

# Quick sanity check: the mean FACET embedding of the first training company, and the
# keys a training batch provides.
cid = train_ids[0]
vec = get_company_embeddings(cid, train_emb, train_index, reduce="mean")
print(cid, vec.shape)

batch = next(iter(train_dl))
print(batch.keys())


import numpy as np
import pandas as pd
import torch
from collections import defaultdict

# Series for which FACET vectors and histories are collected.
SERIES_LIST = ["ESG", "ENV", "SOC", "GOV"]

# Per-series sample cap passed to extract_facet_z_and_hist.
MAX_PER_SER = 1000
# NOTE(review): these sampling sizes are not used in this section; presumably consumed by
# later analysis cells.
PAIR_SAMPLES = 4000
TRIPLET_SAMPLES = 3000
KNN_K = 5
KNN_QUERIES = 60

# Slice of the FACET hidden state to use: "geom" (first half), "pred" (second half), "full".
Z_MODE = "geom"

# L2-normalise each extracted vector.
NORMALIZE_Z = True

def _l2(a, b):
    """Euclidean distance between two vectors, as a Python float."""
    delta = a - b
    return float(np.linalg.norm(delta))

def _normalize(v, eps=1e-12):
    """Scale ``v`` to unit L2 norm; a vector whose norm is below ``eps`` is returned as-is (same object)."""
    length = float(np.linalg.norm(v))
    return v if length < eps else v / length

def _dtw_list(a_list, b_list):
    """DTW distance between two sequences, using the Sakoe-Chiba band width from cfg.DTW_BAND."""
    distance = dtw_sakoe_chiba(a_list, b_list, band=cfg.DTW_BAND)
    return float(distance)

def sample_pairs(n, m, rng):
    """
    Draw ``m`` ordered index pairs ``(i, j)`` with ``i != j`` from ``range(n)`` using the
    numpy Generator ``rng``. Pairs may repeat across draws.
    """
    population = np.arange(n)

    def _draw():
        first, second = rng.choice(population, size=2, replace=False)
        return int(first), int(second)

    return [_draw() for _ in range(m)]

@torch.no_grad()
def extract_facet_z_and_hist(dataloader, max_per_ser=1000, z_mode="geom", normalize_z=True):
    """
    Collect FACET vectors and the matching score histories from ``dataloader``.

    For each sample whose ``target_series`` is in SERIES_LIST, take the last-layer
    hidden state of the global ``lm`` at ``facet_pos`` and slice it by ``z_mode``:
      "geom" -> first half of the hidden dimension,
      "pred" -> second half,
      "full" -> the whole vector.

    Args:
      dataloader: iterable of batch dicts with input_ids, attention_mask, facet_pos,
        target_series, hist_vals, ticker, t_index (labels optional).
      max_per_ser: cap on collected samples per series.
      z_mode: "geom", "pred" or "full".
      normalize_z: L2-normalise each vector (near-zero vectors are left unchanged).

    Returns:
      store[ser] = list of dicts {ticker, t_index, z, hist}. ``hist`` holds the last
      cfg.K values of that series (10 if cfg.K is missing or non-positive). Samples
      with an out-of-range facet_pos or an empty history are skipped.

    Raises:
      ValueError: if ``z_mode`` is not one of the accepted values.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if z_mode not in ("geom", "pred", "full"):
        raise ValueError(f"z_mode must be geom/pred/full, got {z_mode}")

    lm.eval()
    store = {s: [] for s in SERIES_LIST}

    K = int(getattr(cfg, "K", 0))
    if K <= 0:
        K = 10

    def _wants(ser):
        # True while `ser` is tracked and still below its cap.
        return ser in store and len(store[ser]) < max_per_ser

    def _to_device(x):
        # Same target device as the prompts elsewhere in this file; no-op if already there.
        return x.to(cfg.DEVICE) if torch.is_tensor(x) else x

    for batch in dataloader:
        # Skip the expensive forward pass when no sample in this batch can be stored.
        if not any(_wants(ser) for ser in batch["target_series"]):
            continue

        out = lm(
            input_ids=_to_device(batch["input_ids"]),
            attention_mask=_to_device(batch["attention_mask"]),
            labels=_to_device(batch.get("labels")),
            output_hidden_states=True,
        )
        h_last = out.hidden_states[-1]
        B, T, H = h_last.shape
        h2 = H // 2

        for i in range(B):
            ser = batch["target_series"][i]
            if not _wants(ser):
                continue

            p = int(batch["facet_pos"][i].item())
            if p < 0 or p >= T:
                continue

            hist = (batch["hist_vals"][i].get(ser, []) or [])[-K:]
            if len(hist) == 0:
                continue

            z_full = h_last[i, p, :].detach().float().cpu().numpy()

            if z_mode == "geom":
                z = z_full[:h2]
            elif z_mode == "pred":
                z = z_full[h2:]
            else:
                z = z_full

            if normalize_z:
                z = _normalize(z)

            store[ser].append({
                "ticker": batch["ticker"][i],
                "t_index": int(batch["t_index"][i]),
                "z": z,
                "hist": np.asarray(hist, dtype=float),
            })

        if all(len(store[s]) >= max_per_ser for s in store):
            break

    return store

# Collect facet embeddings + histories from the test split and report counts per series.
store = extract_facet_z_and_hist(test_dl, max_per_ser=MAX_PER_SER, z_mode=Z_MODE, normalize_z=NORMALIZE_Z)
for s in SERIES_LIST:
    print(f"[GEOM DATA] {s}: n={len(store[s])}  (Z_MODE={Z_MODE}, normalize={NORMALIZE_Z})")

def make_dtw_cache(arr):
    """
    Build a memoized DTW distance function over the histories in ``arr``.

    ``arr`` is a list of sample dicts, each holding a numpy ``"hist"`` array.
    The returned ``dtw_cached(i, j)`` is symmetric, returns 0.0 when ``i == j``
    and computes each unordered pair at most once for this cache instance.
    """
    histories = [sample["hist"].tolist() for sample in arr]
    memo = {}

    def dtw_cached(i, j):
        if i == j:
            return 0.0
        pair = (min(i, j), max(i, j))
        if pair not in memo:
            memo[pair] = _dtw_list(histories[pair[0]], histories[pair[1]])
        return memo[pair]

    return dtw_cached

def corr_metrics_for_series(arr, pair_samples=4000, seed=123):
    """
    Correlate history DTW distance with z distance over random sample pairs.

    Draws ``min(pair_samples, n*(n-1)//2)`` pairs via ``sample_pairs`` (pairs
    may repeat) and returns a dict with
      spearman       - Spearman rho between DTW and ||z_i - z_j||
      pearson_logdtw - Pearson r between log1p(DTW) and ||z_i - z_j||
      pairs          - number of pairs used
    Both correlations are NaN when ``arr`` has < 25 samples or < 20 pairs.
    """
    too_few = {"spearman": np.nan, "pearson_logdtw": np.nan}
    if len(arr) < 25:
        return {**too_few, "pairs": 0}

    n_items = len(arr)
    rng = np.random.default_rng(seed)
    n_pairs = min(pair_samples, n_items * (n_items - 1) // 2)
    pairs = sample_pairs(n_items, n_pairs, rng)

    dtw_of = make_dtw_cache(arr)
    dtw_dists, z_dists = [], []
    for a, b in pairs:
        dtw_dists.append(dtw_of(a, b))
        z_dists.append(_l2(arr[a]["z"], arr[b]["z"]))

    if len(dtw_dists) < 20:
        return {**too_few, "pairs": len(dtw_dists)}

    from scipy.stats import spearmanr, pearsonr
    return {
        "spearman": float(spearmanr(dtw_dists, z_dists).correlation),
        "pearson_logdtw": float(pearsonr(np.log1p(dtw_dists), z_dists)[0]),
        "pairs": len(dtw_dists),
    }

# [A/B] Per-series correlation between history DTW distance and z distance.
rows = []
for ser in SERIES_LIST:
    m = corr_metrics_for_series(store[ser], pair_samples=PAIR_SAMPLES, seed=int(cfg.SEED) + 7)
    rows.append({"series": ser, **m})
df_corr = pd.DataFrame(rows)
print("\n[A/B] Correlations per series:")
print(df_corr.to_string(index=False))

def knn_overlap(arr, k=5, n_queries=50, seed=0):
    """
    Mean fraction of shared k-nearest neighbours under z distance vs DTW.

    For up to ``n_queries`` randomly chosen query samples, the ``k`` nearest
    other samples are found by ``||z_q - z_j||`` and by DTW between histories;
    each query contributes |intersection| / k. Returns the mean over queries,
    or NaN when ``arr`` has fewer than ``k + 5`` samples.
    """
    if len(arr) < k + 5:
        return np.nan

    rng = np.random.default_rng(seed)
    positions = np.arange(len(arr))
    n_q = min(n_queries, len(arr))
    dtw_of = make_dtw_cache(arr)

    def k_nearest(query, distance_to):
        # list.sort is stable, so ties keep lower indices first.
        candidates = [(int(j), distance_to(int(j))) for j in positions if j != query]
        candidates.sort(key=lambda item: item[1])
        return {j for j, _ in candidates[:k]}

    scores = []
    for query in rng.choice(positions, size=n_q, replace=False):
        z_query = arr[int(query)]["z"]
        by_z = k_nearest(query, lambda j, zq=z_query: _l2(zq, arr[j]["z"]))
        by_dtw = k_nearest(query, lambda j, q=int(query): dtw_of(q, j))
        scores.append(len(by_z & by_dtw) / float(k))

    return float(np.mean(scores))

# [C] Per-series kNN overlap between z-space and DTW neighbourhoods.
# (This statement was previously cut off mid-call, a SyntaxError for the file.)
rows = []
for ser in SERIES_LIST:
    ov = knn_overlap(store[ser], k=KNN_K, n_queries=KNN_QUERIES, seed=int(cfg.SEED) + 11)
    rows.append({"series": ser, f"knn_overlap@{KNN_K}": ov})
df_knn = pd.DataFrame(rows)
print("\n[C] kNN overlap per series:")
print(df_knn.to_string(index=False))


import numpy as np
import pandas as pd
import torch
from collections import defaultdict

# Series examined by the geometry diagnostics below.
SERIES_LIST = ["ESG", "ENV", "SOC", "GOV"]

# Sampling budgets:
#   MAX_PER_SER     - max samples collected per series from the dataloader
#   PAIR_SAMPLES    - random pairs for the correlation and DTW-bin tests
#   TRIPLET_SAMPLES - random triplets for the ordering-accuracy test
#   KNN_K / KNN_QUERIES - neighbourhood size / number of queries for kNN overlap
MAX_PER_SER = 1000
PAIR_SAMPLES = 4000
TRIPLET_SAMPLES = 3000
KNN_K = 5
KNN_QUERIES = 60

# Slice of the facet hidden state to use (see extract_facet_z_and_hist):
# "geom" = first half, "pred" = second half, "full" = whole vector.
Z_MODE = "geom"

# L2-normalize each z before measuring distances.
NORMALIZE_Z = True

def _l2(a, b):
    return float(np.linalg.norm(a - b))

def _normalize(v, eps=1e-12):
    n = float(np.linalg.norm(v))
    if n < eps:
        return v
    return v / n

def _dtw_list(a_list, b_list):
    """DTW distance between two value lists via dtw_sakoe_chiba with band=cfg.DTW_BAND."""
    band = cfg.DTW_BAND
    return float(dtw_sakoe_chiba(a_list, b_list, band=band))

def sample_pairs(n, m, rng):
    """
    Draw ``m`` random ordered index pairs ``(i, j)`` with ``i != j`` from ``range(n)``.

    Each pair is drawn independently (pairs may repeat across draws) with
    ``rng.choice(..., size=2, replace=False)``.
    """
    population = np.arange(n)
    draws = (rng.choice(population, size=2, replace=False) for _ in range(m))
    return [(int(first), int(second)) for first, second in draws]

@torch.no_grad()
def extract_facet_z_and_hist(dataloader, max_per_ser=1000, z_mode="geom", normalize_z=True):
    """
    Collect facet-token embeddings z and the matching history for each series.

    Runs the module-level ``lm`` over ``dataloader`` (no gradients, eval mode)
    and, for every sample whose ``target_series`` is in ``SERIES_LIST``, takes
    the last-layer hidden state at ``batch["facet_pos"][i]``.

    Args:
        dataloader: iterable of batch dicts; the keys read here are
            "input_ids", "attention_mask", "labels", "target_series",
            "facet_pos", "hist_vals", "ticker" and "t_index".
        max_per_ser: stop collecting a series once it holds this many samples;
            iteration ends early once every series is full.
        z_mode: "geom" keeps the first half of the hidden vector, "pred" the
            second half, "full" the whole vector.
        normalize_z: if True, L2-normalize z with ``_normalize``.

    Returns:
        store[ser] = list of dicts: {ticker, t_index, z, hist}, where ``hist``
        holds the last K values of ``hist_vals[i][ser]`` as a float array
        (K = cfg.K, or 10 when cfg.K is missing or <= 0).

    Raises:
        AssertionError: if ``z_mode`` is not geom/pred/full.
    """
    assert z_mode in ["geom", "pred", "full"], f"z_mode must be geom/pred/full, got {z_mode}"

    lm.eval()
    store = {s: [] for s in SERIES_LIST}

    # History window length; fall back to 10 when cfg.K is unset or non-positive.
    K = int(getattr(cfg, "K", 0))
    if K <= 0:
        K = 10

    for batch in dataloader:
        # NOTE(review): batch tensors are passed as-is, which assumes the
        # collate fn already placed them on lm's device — confirm. Passing
        # labels presumably only adds a loss, and out.loss is not read here.
        out = lm(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            output_hidden_states=True,
        )
        # Last-layer hidden states, unpacked as (batch, seq_len, hidden).
        h_last = out.hidden_states[-1]
        B, T, H = h_last.shape
        h2 = H // 2

        for i in range(B):
            ser = batch["target_series"][i]
            # Ignore unknown series and series that already reached the cap.
            if ser not in store:
                continue
            if len(store[ser]) >= max_per_ser:
                continue

            # Skip samples whose facet position falls outside the sequence.
            p = int(batch["facet_pos"][i].item())
            if p < 0 or p >= T:
                continue

            # Last K history values for this series; empty histories are skipped.
            hist = (batch["hist_vals"][i].get(ser, []) or [])
            hist = hist[-K:]
            if len(hist) == 0:
                continue

            z_full = h_last[i, p, :].detach().float().cpu().numpy()

            # Select the requested slice of the hidden vector.
            if z_mode == "geom":
                z = z_full[:h2]
            elif z_mode == "pred":
                z = z_full[h2:]
            else:
                z = z_full

            if normalize_z:
                z = _normalize(z)

            store[ser].append({
                "ticker": batch["ticker"][i],
                "t_index": int(batch["t_index"][i]),
                "z": z,
                "hist": np.asarray(hist, dtype=float),
            })

        # Stop reading batches once every series is full.
        if all(len(store[s]) >= max_per_ser for s in store):
            break

    return store

# Collect facet embeddings + histories from the test split and report counts per series.
store = extract_facet_z_and_hist(test_dl, max_per_ser=MAX_PER_SER, z_mode=Z_MODE, normalize_z=NORMALIZE_Z)
for s in SERIES_LIST:
    print(f"[GEOM DATA] {s}: n={len(store[s])}  (Z_MODE={Z_MODE}, normalize={NORMALIZE_Z})")

def make_dtw_cache(arr):
    """
    Build a memoized DTW distance function over the histories in ``arr``.

    ``arr`` is a list of sample dicts, each holding a numpy ``"hist"`` array.
    The returned ``dtw_cached(i, j)`` is symmetric, returns 0.0 when ``i == j``
    and computes each unordered pair at most once for this cache instance.
    """
    histories = [sample["hist"].tolist() for sample in arr]
    memo = {}

    def dtw_cached(i, j):
        if i == j:
            return 0.0
        pair = (min(i, j), max(i, j))
        if pair not in memo:
            memo[pair] = _dtw_list(histories[pair[0]], histories[pair[1]])
        return memo[pair]

    return dtw_cached

def corr_metrics_for_series(arr, pair_samples=4000, seed=123):
    """
    Correlate history DTW distance with z distance over random sample pairs.

    Draws ``min(pair_samples, n*(n-1)//2)`` pairs via ``sample_pairs`` (pairs
    may repeat) and returns a dict with
      spearman       - Spearman rho between DTW and ||z_i - z_j||
      pearson_logdtw - Pearson r between log1p(DTW) and ||z_i - z_j||
      pairs          - number of pairs used
    Both correlations are NaN when ``arr`` has < 25 samples or < 20 pairs.
    """
    too_few = {"spearman": np.nan, "pearson_logdtw": np.nan}
    if len(arr) < 25:
        return {**too_few, "pairs": 0}

    n_items = len(arr)
    rng = np.random.default_rng(seed)
    n_pairs = min(pair_samples, n_items * (n_items - 1) // 2)
    pairs = sample_pairs(n_items, n_pairs, rng)

    dtw_of = make_dtw_cache(arr)
    dtw_dists, z_dists = [], []
    for a, b in pairs:
        dtw_dists.append(dtw_of(a, b))
        z_dists.append(_l2(arr[a]["z"], arr[b]["z"]))

    if len(dtw_dists) < 20:
        return {**too_few, "pairs": len(dtw_dists)}

    from scipy.stats import spearmanr, pearsonr
    return {
        "spearman": float(spearmanr(dtw_dists, z_dists).correlation),
        "pearson_logdtw": float(pearsonr(np.log1p(dtw_dists), z_dists)[0]),
        "pairs": len(dtw_dists),
    }

# [A/B] Per-series correlation between history DTW distance and z distance.
rows = []
for ser in SERIES_LIST:
    m = corr_metrics_for_series(store[ser], pair_samples=PAIR_SAMPLES, seed=int(cfg.SEED) + 7)
    rows.append({"series": ser, **m})
df_corr = pd.DataFrame(rows)
print("\n[A/B] Correlations per series:")
print(df_corr.to_string(index=False))

def knn_overlap(arr, k=5, n_queries=50, seed=0):
    """
    Mean fraction of shared k-nearest neighbours under z distance vs DTW.

    For up to ``n_queries`` randomly chosen query samples, the ``k`` nearest
    other samples are found by ``||z_q - z_j||`` and by DTW between histories;
    each query contributes |intersection| / k. Returns the mean over queries,
    or NaN when ``arr`` has fewer than ``k + 5`` samples.
    """
    if len(arr) < k + 5:
        return np.nan

    rng = np.random.default_rng(seed)
    positions = np.arange(len(arr))
    n_q = min(n_queries, len(arr))
    dtw_of = make_dtw_cache(arr)

    def k_nearest(query, distance_to):
        # list.sort is stable, so ties keep lower indices first.
        candidates = [(int(j), distance_to(int(j))) for j in positions if j != query]
        candidates.sort(key=lambda item: item[1])
        return {j for j, _ in candidates[:k]}

    scores = []
    for query in rng.choice(positions, size=n_q, replace=False):
        z_query = arr[int(query)]["z"]
        by_z = k_nearest(query, lambda j, zq=z_query: _l2(zq, arr[j]["z"]))
        by_dtw = k_nearest(query, lambda j, q=int(query): dtw_of(q, j))
        scores.append(len(by_z & by_dtw) / float(k))

    return float(np.mean(scores))

# [C] Per-series kNN overlap between z-space and DTW neighbourhoods.
rows = []
for ser in SERIES_LIST:
    ov = knn_overlap(store[ser], k=KNN_K, n_queries=KNN_QUERIES, seed=int(cfg.SEED) + 11)
    rows.append({"series": ser, f"knn_overlap@{KNN_K}": ov})
df_knn = pd.DataFrame(rows)
print("\n[C] kNN overlap per series:")
print(df_knn.to_string(index=False))

def triplet_order_accuracy(arr, triplet_samples=2000, seed=0):
    """
    Fraction of random triplets whose DTW ordering is preserved in z space.

    For each sampled anchor ``a`` and two others ``b``, ``c``: whichever of
    b/c is closer to ``a`` under DTW must also be strictly closer under
    ``||z_a - z_x||`` to count as a hit. Triplets whose two DTW distances
    differ by less than 1e-9 are skipped. Returns
    ``{"triplet_acc", "triplets"}``; accuracy is NaN when ``arr`` has < 10 items.
    """
    if len(arr) < 10:
        return {"triplet_acc": np.nan, "triplets": 0}

    rng = np.random.default_rng(seed)
    pool = np.arange(len(arr))
    dtw_of = make_dtw_cache(arr)

    hits = 0
    used = 0
    for _ in range(triplet_samples):
        anchor, first, second = (int(x) for x in rng.choice(pool, size=3, replace=False))

        d_first = dtw_of(anchor, first)
        d_second = dtw_of(anchor, second)
        if abs(d_first - d_second) < 1e-9:
            continue

        z_first = _l2(arr[anchor]["z"], arr[first]["z"])
        z_second = _l2(arr[anchor]["z"], arr[second]["z"])

        if d_first < d_second:
            hits += int(z_first < z_second)
        else:
            hits += int(z_second < z_first)
        used += 1

    return {"triplet_acc": float(hits / max(1, used)), "triplets": int(used)}

# [D] Per-series triplet ordering accuracy (DTW order preserved in z space).
rows = []
for ser in SERIES_LIST:
    m = triplet_order_accuracy(store[ser], triplet_samples=TRIPLET_SAMPLES, seed=int(cfg.SEED) + 19)
    rows.append({"series": ser, **m})
df_trip = pd.DataFrame(rows)
print("\n[D] Triplet ordering accuracy per series:")
print(df_trip.to_string(index=False))

def dtw_bins_vs_z(arr, pair_samples=4000, n_bins=6, seed=0):
    """
    Bin random pairs by DTW-distance quantile and summarize z distance per bin.

    Returns a DataFrame with one row per bin (columns bin, dtw_lo, dtw_hi, n,
    z_mean, z_median), or None when ``arr`` has fewer than 25 samples. Bins
    are half-open [lo, hi) except the last, which also includes ``hi``; a bin
    that receives no pairs (e.g. repeated quantile edges) reports NaN stats.
    """
    if len(arr) < 25:
        return None

    n_items = len(arr)
    rng = np.random.default_rng(seed)
    pairs = sample_pairs(n_items, min(pair_samples, n_items * (n_items - 1) // 2), rng)
    dtw_of = make_dtw_cache(arr)

    dtw_arr = np.asarray([dtw_of(int(i), int(j)) for i, j in pairs], dtype=float)
    z_arr = np.asarray([_l2(arr[int(i)]["z"], arr[int(j)]["z"]) for i, j in pairs], dtype=float)

    edges = np.quantile(dtw_arr, np.linspace(0, 1, n_bins + 1))
    table = []
    for b, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        below_hi = (dtw_arr <= hi) if b == n_bins - 1 else (dtw_arr < hi)
        mask = (dtw_arr >= lo) & below_hi
        count = int(mask.sum())
        table.append({
            "bin": b,
            "dtw_lo": float(lo),
            "dtw_hi": float(hi),
            "n": count,
            "z_mean": float(z_arr[mask].mean()) if count else np.nan,
            "z_median": float(np.median(z_arr[mask])) if count else np.nan,
        })
    return pd.DataFrame(table)

# [E] Per-series z-distance statistics inside DTW-quantile bins; each table is
# also kept in bins_tables[series] (None when a series has < 25 samples).
print("\n[E] DTW bins -> mean z-dist (per series):")
bins_tables = {}
for ser in SERIES_LIST:
    df_bins = dtw_bins_vs_z(store[ser], pair_samples=PAIR_SAMPLES, n_bins=6, seed=int(cfg.SEED) + 23)
    bins_tables[ser] = df_bins
    print("\nSeries:", ser)
    if df_bins is None:
        print("  Not enough data.")
    else:
        print(df_bins.to_string(index=False))



import os
import numpy as np
import pandas as pd
import torch
from collections import defaultdict

# Put the model (and the optional regression head, if one exists) into eval
# mode for the out-of-sample evaluation below.
lm.eval()
if "value_head" in globals() and value_head is not None:
    value_head.eval()

# Exchange sub-folders scanned for score files (joined onto the working dir).
EXCHANGES = ["AMEX", "CBOE", "NASDAQ", "NYSE",  "PNK"]
# Max walk-forward forecast steps per company after the first cfg.K points.
MAX_PRED_STEPS = 10
# Target series forecast in the OOS evaluation.
SERIES = ["ESG","ENV","SOC","GOV"]

def mse(y, yhat):
    """Mean squared error between targets ``y`` and predictions ``yhat``."""
    target = np.asarray(y, dtype=float)
    pred = np.asarray(yhat, dtype=float)
    err = target - pred
    return float(np.mean(err ** 2))

def mae(y, yhat):
    """Mean absolute error between targets ``y`` and predictions ``yhat``."""
    target = np.asarray(y, dtype=float)
    pred = np.asarray(yhat, dtype=float)
    return float(np.mean(np.abs(target - pred)))

def mape(y, yhat):
    """Mean absolute percentage error, relative to ``|y| + 1e-8`` (targets near zero inflate it)."""
    target = np.asarray(y, dtype=float)
    pred = np.asarray(yhat, dtype=float)
    relative = np.abs((target - pred) / (np.abs(target) + 1e-8))
    return float(np.mean(relative) * 100.0)

def bias(y, yhat):
    """Mean signed error ``mean(yhat - y)``; positive means over-prediction."""
    target = np.asarray(y, dtype=float)
    pred = np.asarray(yhat, dtype=float)
    return float(np.mean(pred - target))

def theils_u(y, yhat):
    """
    RMSE of ``yhat`` divided by the RMS of ``y`` (plus 1e-8 in the denominator).

    Note: this is RMSE / RMS(y), not the textbook Theil U1, whose denominator
    is RMS(y) + RMS(yhat).
    """
    target = np.asarray(y, dtype=float)
    pred = np.asarray(yhat, dtype=float)
    rmse = np.sqrt(np.mean((pred - target) ** 2))
    rms_target = np.sqrt(np.mean(target ** 2)) + 1e-8
    return float(rmse / rms_target)

def compute_metrics_dict(y, yhat):
    """Return MSE, MAE, MAPE, TheilU and Bias for one (y, yhat) pair as a dict."""
    metric_fns = (
        ("MSE", mse),
        ("MAE", mae),
        ("MAPE", mape),
        ("TheilU", theils_u),
        ("Bias", bias),
    )
    return {name: fn(y, yhat) for name, fn in metric_fns}

def load_series_file(path):
    """
    Parse a whitespace-separated score file into ``{ticker: [float, ...]}``.

    Each line is ``TICKER v1 v2 ...``. Lines with fewer than two fields are
    skipped; tokens that are not valid floats are recorded as 0.0 so the
    series keeps its positional alignment. If a ticker appears on several
    lines, the last line wins.
    """
    def _to_float(token):
        # float() on a str raises ValueError for malformed numbers; only that
        # is swallowed (the former bare except also hid KeyboardInterrupt).
        try:
            return float(token)
        except ValueError:
            return 0.0

    data = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) < 2:
                continue
            ticker, *raw_vals = parts
            data[ticker] = [_to_float(x) for x in raw_vals]
    return data

def _facet_token_for_series(series: str) -> str:
    return f"<FACET_{series}>"

def _find_facet_pos(input_ids_1d: torch.Tensor, series: str) -> int:
    """
    Find the position of the series-specific FACET token in the encoded prompt.

    Args:
        input_ids_1d: 1-D tensor of token ids for a single prompt.
        series: series name; the token looked up is "<FACET_{series}>".

    Returns:
        Index of the first occurrence of the facet token, or the last index
        (``numel() - 1``) when the token is unknown to the tokenizer or absent.
    """
    fallback = int(input_ids_1d.numel() - 1)
    tid = tokenizer.convert_tokens_to_ids(_facet_token_for_series(series))
    # Hugging Face tokenizers map unknown tokens to unk_token_id (for GPT-2 that
    # is <|endoftext|>) instead of None/-1. Treat that id as "not found" too,
    # otherwise an unrelated unk/eos position in the prompt could be returned.
    unk_id = getattr(tokenizer, "unk_token_id", None)
    if tid is None or tid < 0 or (unk_id is not None and tid == unk_id):
        return fallback

    hits = (input_ids_1d == tid).nonzero(as_tuple=False)
    if hits.numel() == 0:
        return fallback
    return int(hits[0].item())

@torch.no_grad()
def predict_token_ev(logits_last_1v: torch.Tensor, series: str) -> float:
    """
    Turn last-position logits into a scalar prediction for ``series``.

    Args:
        logits_last_1v: [V] tensor of logits at the last position.
        series: key into ``allowed_ids`` and ``id_to_val``.

    Returns:
        The expected value over the allowed value tokens (via
        ``expected_value_from_logits``). If the series has no allowed ids, the
        arg-max token is decoded with ``token_to_float`` instead, falling back
        to 0.0 when that fails.
    """
    ids = allowed_ids[series]
    if ids.numel() == 0:
        pid = int(torch.argmax(logits_last_1v).item())
        tok = tokenizer.convert_ids_to_tokens(pid)
        try:
            return float(token_to_float(tok))
        except Exception:
            # Best-effort fallback; unlike a bare ``except:`` this no longer
            # swallows KeyboardInterrupt / SystemExit.
            return 0.0

    # Add a leading batch dimension ([1, V]) and take the single result back out.
    logits_1V = logits_last_1v.unsqueeze(0)
    ev = expected_value_from_logits(logits_1V, ids, id_to_val[series])[0].item()
    return float(ev)

@torch.no_grad()
def predict_reg_head(out_hidden_last: torch.Tensor, input_ids_1d: torch.Tensor, series: str) -> float:
    """
    Predict a scalar for ``series`` with the module-level regression ``value_head``.

    Args:
        out_hidden_last: [T, H] last-layer hidden states for one sample.
        input_ids_1d: [T] token ids of the same sample (used to locate the facet token).
        series: series whose facet token is looked up.

    Returns:
        ``value_head`` applied to the second half of the hidden vector at the
        facet position, or NaN when no ``value_head`` global is available.
    """
    head = globals().get("value_head")
    if head is None:
        return float("nan")

    facet_idx = _find_facet_pos(input_ids_1d, series)
    z_full = out_hidden_last[facet_idx, :]
    half = z_full.numel() // 2
    prediction = head(z_full[half:].unsqueeze(0)).squeeze(-1).item()
    return float(prediction)

@torch.no_grad()
def forecast_company_compare(
    ticker, esg, env, soc, gov, ret
):
    """
    Walk-forward one-step forecasts for one company, comparing two predictors.

    Starting from the first ``cfg.K`` observations, for each step ``t`` and each
    series in ``SERIES`` a prompt is built from the last ``cfg.K`` values of every
    series (including returns) and scored with
      * TokenEV - expected value over the allowed value tokens at the last position
      * RegHead - the regression head applied at the series' facet token.
    After all series are scored, the observed values at ``t`` are appended to the
    history (teacher forcing), so predictions never feed back into prompts.

    Args:
        ticker: company identifier inserted into the prompt.
        esg, env, soc, gov: per-period score sequences.
        ret: per-period returns aligned with the scores.

    Returns:
        (preds_ev, preds_head, trues): dicts mapping series -> list of floats,
        each with at most ``MAX_PRED_STEPS`` entries.
    """
    targets = {"ESG": esg, "ENV": env, "SOC": soc, "GOV": gov}
    # list(...) rather than .copy(): the histories are appended to below, which
    # also works when a caller passes numpy arrays (ndarray has no .append).
    hist = {
        "ESG": list(esg[:cfg.K]),
        "ENV": list(env[:cfg.K]),
        "SOC": list(soc[:cfg.K]),
        "GOV": list(gov[:cfg.K]),
        "RET": list(ret[:cfg.K]),
    }

    preds_ev   = {s: [] for s in SERIES}
    preds_head = {s: [] for s in SERIES}
    trues      = {s: [] for s in SERIES}

    # Score every index present in *all* inputs, capped at MAX_PRED_STEPS steps.
    # The previous bound min(len(esg) - 1, ...) never scored the final
    # observation and raised IndexError when env/soc/gov/ret were shorter.
    max_t = min(len(esg), len(env), len(soc), len(gov), len(ret), cfg.K + MAX_PRED_STEPS)

    for t in range(cfg.K, max_t):
        for series in SERIES:
            # Rebuilt per series in case build_prompt mutates its input map.
            hist_tokens = {
                k: [f"<{k}_{v:.2f}>" for v in vals[-cfg.K:]]
                for k, vals in hist.items()
            }

            prompt = build_prompt(
                ticker=ticker,
                start_dt=None,
                end_dt=None,
                hist_tokens_map=hist_tokens,
                senti_tokens=[],
                news_text="(no news)",
                target_series=series
            )

            enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(cfg.DEVICE)
            out = lm(**enc, output_hidden_states=True)

            # TokenEV from the last-position logits.
            y_ev = predict_token_ev(out.logits[0, -1, :], series)
            # RegHead at this series' facet token in the last hidden layer.
            y_head = predict_reg_head(out.hidden_states[-1][0], enc["input_ids"][0], series)

            preds_ev[series].append(float(y_ev))
            preds_head[series].append(float(y_head))
            trues[series].append(float(targets[series][t]))

        # Teacher forcing: extend every history with the observed values at t.
        for k, vals in (("ESG", esg), ("ENV", env), ("SOC", soc), ("GOV", gov), ("RET", ret)):
            hist[k].append(vals[t])

    return preds_ev, preds_head, trues

# Out-of-sample evaluation: for each exchange folder, run walk-forward forecasts
# for every company that has all four score series and collect TokenEV vs
# RegHead metrics per series into `rows`.
rows = []

for EX in EXCHANGES:
    print("\n" + "="*80)
    print("EXCHANGE:", EX)
    print("="*80)

    base = EX
    path_esg = os.path.join(base, "esg_risk_ratings_1.txt")
    path_env = os.path.join(base, "PaperReady_e_scores.txt")
    path_soc = os.path.join(base, "PaperReady_s_scores.txt")
    path_gov = os.path.join(base, "PaperReady_g_scores.txt")

    # Skip the exchange unless all four files exist; checking only the ESG file
    # let a missing E/S/G file abort the whole loop with FileNotFoundError.
    missing = [p for p in (path_esg, path_env, path_soc, path_gov) if not os.path.exists(p)]
    if missing:
        print("Missing files for", EX, "->", missing, "-> skipping")
        continue

    ESG = load_series_file(path_esg)
    ENV = load_series_file(path_env)
    SOC = load_series_file(path_soc)
    GOV = load_series_file(path_gov)

    RETURNS = {}
    for tkr in ESG.keys():
        # NOTE(review): integer positions are passed as the "dates" argument;
        # confirm that align_and_fill_returns expects this.
        dates = list(range(len(ESG[tkr])))
        try:
            RETURNS[tkr] = align_and_fill_returns(tkr, dates)
        except Exception:
            # Best-effort: fall back to zero returns. `except Exception` (not a
            # bare except) keeps Ctrl-C able to interrupt the loop.
            RETURNS[tkr] = [0.0] * len(ESG[tkr])

    agg_true = {s: [] for s in SERIES}
    agg_ev   = {s: [] for s in SERIES}
    agg_head = {s: [] for s in SERIES}

    n_used = 0
    for ticker in ESG.keys():
        # Require all four series and at least cfg.K + 5 ESG observations.
        if ticker not in ENV or ticker not in SOC or ticker not in GOV:
            continue
        if len(ESG[ticker]) < cfg.K + 5:
            continue

        preds_ev, preds_head, trues = forecast_company_compare(
            ticker,
            ESG[ticker],
            ENV[ticker],
            SOC[ticker],
            GOV[ticker],
            RETURNS.get(ticker, [0.0]*len(ESG[ticker])),
        )
        n_used += 1

        for s in SERIES:
            agg_true[s].extend(trues[s])
            agg_ev[s].extend(preds_ev[s])
            agg_head[s].extend(preds_head[s])

    print(f"\n[OOS] companies used in {EX}: {n_used}")

    # Pool all companies' steps per series and score both methods.
    for s in SERIES:
        if len(agg_true[s]) == 0:
            print(f"  {s}: N/A (no samples)")
            continue

        met_ev = compute_metrics_dict(agg_true[s], agg_ev[s])
        met_hd = compute_metrics_dict(agg_true[s], agg_head[s])

        print("\nSeries:", s)
        print("  TokenEV :", {k: round(v, 4) for k, v in met_ev.items()})
        print("  RegHead :", {k: round(v, 4) for k, v in met_hd.items()})

        for method_name, met in [("TokenEV", met_ev), ("RegHead", met_hd)]:
            rows.append({
                "exchange": EX,
                "series": s,
                "method": method_name,
                "n_samples": len(agg_true[s]),
                **met
            })

# Per-row OOS metrics table plus an exchange/series x method pivot.
df_oos = pd.DataFrame(rows)
if df_oos.empty:
    # With no rows the frame has no columns, so sort_values / pivot_table would
    # raise KeyError; report the situation instead of crashing.
    print("[OOS] No results collected: every exchange was skipped or had no samples.")
    pivot = None
else:
    display(df_oos.sort_values(["exchange","series","method"]).reset_index(drop=True))

    pivot = df_oos.pivot_table(
        index=["exchange","series"],
        columns="method",
        values=["MSE","MAE","MAPE","TheilU","Bias"],
        aggfunc="mean"
    )
    display(pivot)


import numpy as np
import torch
import matplotlib.pyplot as plt

# Settings for the "improved" DTW-vs-FACET geometry experiment on the
# validation split. NOTE: these reassign module-level names (PAIR_SAMPLES,
# TRIPLET_SAMPLES, Z_MODE, ...) that earlier sections also define. Beyond the
# settings printout, PAIR_SAMPLES / NN_* / TRIPLET_SAMPLES are not used by the
# code visible in this section.
SERIES_TO_TEST = "ESG"
MAX_SAMPLES    = 180
PAIR_SAMPLES   = 2000
DTW_BAND       = int(getattr(cfg, "DTW_BAND", 2))

NN_K           = 5
NN_QUERIES     = 40

TRIPLET_SAMPLES = 1500

# Slice returned by _choose_z: "geom" (first half), "pred" (second half),
# anything else -> the full hidden vector.
Z_MODE = "geom"

# L2-normalize each collected z before storing it.
L2_NORMALIZE_Z = True

print("\n" + "="*80)
print("[EXPERIMENT] DTW vs FACET Geometry Alignment (Improved)")
print(f"[SETTINGS] series={SERIES_TO_TEST} MAX_SAMPLES={MAX_SAMPLES} "
      f"PAIR={PAIR_SAMPLES} NN_K={NN_K} NN_QUERIES={NN_QUERIES} "
      f"TRIPLET={TRIPLET_SAMPLES} DTW_BAND={DTW_BAND} Z_MODE={Z_MODE} "
      f"L2_NORM={L2_NORMALIZE_Z}")
print("="*80)

def _extract_traj(d, ser):
    arr = d.get(ser, []) or []
    out = []
    for v in arr:
        try:
            fv = float(v)
            if np.isfinite(fv):
                out.append(fv)
        except:
            pass
    return out

def _split_z(z):
    H = z.shape[-1]
    h2 = H // 2
    return z[..., :h2], z[..., h2:]

def _choose_z(z_full):
    """Return the part of ``z_full`` chosen by the global Z_MODE: "geom" -> first half, "pred" -> second half, otherwise the full vector."""
    geom_part, pred_part = _split_z(z_full)
    if Z_MODE not in ("geom", "pred"):
        return z_full
    return geom_part if Z_MODE == "geom" else pred_part

def _l2(a, b):
    return float(np.linalg.norm(a - b))

def _dtw(a, b):
    """DTW distance between sequences ``a`` and ``b`` via dtw_sakoe_chiba with the module-level DTW_BAND."""
    distance = dtw_sakoe_chiba(a, b, band=DTW_BAND)
    return float(distance)

def _spearman_np(x, y):
    """
    Spearman rank correlation implemented with NumPy only.

    Ranks both inputs (tied values receive the mean of their 0-based sorted
    positions) and returns the Pearson correlation of the ranks. Returns NaN
    for fewer than 3 points or when either rank vector has (near-)zero spread.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.size < 3:
        return np.nan

    def rankdata(a):
        # Provisional 0-based ranks from the sort order (the constant offset vs
        # 1-based ranks does not affect the correlation) ...
        temp = a.argsort()
        ranks = np.empty_like(temp, dtype=float)
        ranks[temp] = np.arange(len(a), dtype=float)
        # ... then each run of equal sorted values [i..j] gets the mean rank.
        # NaNs never compare equal, so each NaN keeps its own rank.
        sorted_a = a[temp]
        i = 0
        while i < len(a):
            j = i
            while j + 1 < len(a) and sorted_a[j + 1] == sorted_a[i]:
                j += 1
            if j > i:
                avg = (i + j) / 2.0
                ranks[temp[i:j+1]] = avg
            i = j + 1
        return ranks

    # Pearson correlation of the centered ranks.
    rx = rankdata(x); ry = rankdata(y)
    rx = rx - rx.mean(); ry = ry - ry.mean()
    denom = (np.sqrt((rx**2).sum()) * np.sqrt((ry**2).sum()))
    if denom < 1e-12:
        return np.nan
    return float((rx * ry).sum() / denom)

def _pearson_np(x, y):
    x = np.asarray(x, dtype=float); y = np.asarray(y, dtype=float)
    if x.size < 3:
        return np.nan
    x = x - x.mean()
    y = y - y.mean()
    denom = np.sqrt((x*x).sum()) * np.sqrt((y*y).sum())
    if denom < 1e-12:
        return np.nan
    return float((x*y).sum() / denom)

# Collect up to MAX_SAMPLES (z, trajectory) pairs for SERIES_TO_TEST from the
# validation loader. z is the last-layer hidden state at the facet token,
# sliced by _choose_z and optionally L2-normalized. The trajectory comes from
# "traj_vals" when the batch has that key, otherwise from "hist_vals".
lm.eval()
Z_list, traj_list = [], []

with torch.no_grad():
    for batch in val_dl:
        series_list = batch["target_series"]
        # Only samples whose target is the series under test are used.
        idxs = [i for i, s in enumerate(series_list) if s == SERIES_TO_TEST]
        if len(idxs) == 0:
            continue

        # NOTE(review): the forward pass covers the whole batch even though
        # only `idxs` are used; tensors are passed as-is (assumed already on
        # lm's device — confirm).
        out = lm(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            output_hidden_states=True,
        )
        h_last = out.hidden_states[-1]
        fpos = batch["facet_pos"]

        source_key = "traj_vals" if ("traj_vals" in batch) else "hist_vals"

        for i in idxs:
            p = int(fpos[i].item())
            # Skip facet positions outside the sequence.
            if p < 0 or p >= h_last.size(1):
                continue

            z_full = h_last[i, p, :].detach().float().cpu().numpy()
            z = _choose_z(z_full)

            sample_dict = batch[source_key][i]
            tr = _extract_traj(sample_dict, SERIES_TO_TEST)
            # Skip trajectories with fewer than 3 finite values.
            if len(tr) < 3:
                continue

            if L2_NORMALIZE_Z:
                n = np.linalg.norm(z) + 1e-12
                z = z / n

            Z_list.append(z)
            traj_list.append(tr)

            if len(Z_list) >= MAX_SAMPLES:
                break
        if len(Z_list) >= MAX_SAMPLES:
            break

N = len(Z_list)
print(f"[COLLECT] gathered N={N} samples for series={SERIES_TO_TEST}")

import time

# Optional keep-alive loop (presumably to stop an idle hosted notebook runtime
# from being reclaimed). It never terminates, so it is opt-in: set the
# environment variable FACET_KEEP_ALIVE=1 to enable it. Looping here
# unconditionally made every statement below this point unreachable.
if os.environ.get("FACET_KEEP_ALIVE") == "1":
    while True:
        print("Still alive...")
        time.sleep(60)


import numpy as np
import torch
import matplotlib.pyplot as plt

# Settings for the Z_GEOM-only DTW-vs-FACET alignment experiment:
#   MAX_SAMPLES   - number of (z_geom, trajectory) samples to collect
#   PAIR_SPEARMAN - random pairs for the Spearman test [A]
#   NN_K          - neighbourhood size for the overlap test [B]
SERIES_TO_TEST = "ESG"
MAX_SAMPLES = 180
PAIR_SPEARMAN = 2000
NN_K = 5
DTW_BAND = int(getattr(cfg, "DTW_BAND", 2))

# L2-normalize each z_geom before storing it.
L2_NORMALIZE_Z = True

print("\n" + "="*80, flush=True)
print("[EXPERIMENT] DTW vs FACET Geometry Alignment (Z_GEOM only)", flush=True)
print(f"[SETTINGS] series={SERIES_TO_TEST} MAX_SAMPLES={MAX_SAMPLES} "
      f"PAIR_SPEARMAN={PAIR_SPEARMAN} NN_K={NN_K} DTW_BAND={DTW_BAND} "
      f"L2_NORMALIZE_Z={L2_NORMALIZE_Z}", flush=True)
print("="*80, flush=True)

def _spearman_np(x, y):
    """
    Spearman rank correlation implemented with NumPy only.

    Ranks both inputs (tied values receive the mean of their 0-based sorted
    positions) and returns the Pearson correlation of the ranks. Returns NaN
    for fewer than 3 points or when either rank vector has (near-)zero spread.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.size < 3:
        return np.nan

    def rankdata(a):
        # Provisional 0-based ranks from the sort order (the constant offset vs
        # 1-based ranks does not affect the correlation) ...
        temp = a.argsort()
        ranks = np.empty_like(temp, dtype=float)
        ranks[temp] = np.arange(len(a), dtype=float)
        # ... then each run of equal sorted values [i..j] gets the mean rank.
        # NaNs never compare equal, so each NaN keeps its own rank.
        sorted_a = a[temp]
        i = 0
        while i < len(a):
            j = i
            while j + 1 < len(a) and sorted_a[j + 1] == sorted_a[i]:
                j += 1
            if j > i:
                avg = (i + j) / 2.0
                ranks[temp[i:j+1]] = avg
            i = j + 1
        return ranks

    # Pearson correlation of the centered ranks.
    rx = rankdata(x)
    ry = rankdata(y)
    rx = rx - rx.mean()
    ry = ry - ry.mean()
    denom = (np.sqrt((rx**2).sum()) * np.sqrt((ry**2).sum()))
    if denom < 1e-12:
        return np.nan
    return float((rx * ry).sum() / denom)

def _extract_traj(sample_hist_or_traj, ser):
    arr = sample_hist_or_traj.get(ser, []) or []
    out = []
    for v in arr:
        try:
            fv = float(v)
            if np.isfinite(fv):
                out.append(fv)
        except:
            pass
    return out

def _z_geom_from_full(z_full_np):
    H = z_full_np.shape[-1]
    h2 = H // 2
    return z_full_np[:h2]

# Collect up to MAX_SAMPLES (z_geom, trajectory) pairs for SERIES_TO_TEST from
# the validation loader. z_geom is the first half of the last-layer hidden state
# at the facet token (optionally L2-normalized). The trajectory comes from
# "traj_vals" when the batch has that key, otherwise from "hist_vals".
lm.eval()
Z_list = []
traj_list = []

with torch.no_grad():
    for batch in val_dl:
        series_list = batch["target_series"]
        idx = [i for i, s in enumerate(series_list) if s == SERIES_TO_TEST]
        if len(idx) == 0:
            continue

        # NOTE(review): full-batch forward pass; tensors are passed as-is
        # (assumed already on lm's device — confirm).
        out = lm(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            output_hidden_states=True,
        )
        h_last = out.hidden_states[-1]
        fpos = batch["facet_pos"]

        source_key = "traj_vals" if ("traj_vals" in batch) else "hist_vals"

        for i in idx:
            p = int(fpos[i].item())
            # Skip facet positions outside the sequence.
            if p < 0 or p >= h_last.size(1):
                continue

            z_full = h_last[i, p, :].detach().float().cpu().numpy()
            z = _z_geom_from_full(z_full)

            if L2_NORMALIZE_Z:
                z = z / (np.linalg.norm(z) + 1e-12)

            sample_dict = batch[source_key][i]
            tr = _extract_traj(sample_dict, SERIES_TO_TEST)
            # Skip trajectories with fewer than 3 finite values.
            if len(tr) < 3:
                continue

            Z_list.append(z)
            traj_list.append(tr)

            if len(Z_list) >= MAX_SAMPLES:
                break

        if len(Z_list) >= MAX_SAMPLES:
            break

N = len(Z_list)
print(f"[COLLECT] gathered N={N} samples for series={SERIES_TO_TEST}", flush=True)
# Abort the experiment (raise SystemExit) when fewer than 30 samples were collected.
if N < 30:
    print("[WARNING] Too few samples. Increase MAX_SAMPLES or check data.", flush=True)
    raise SystemExit

# [A] Pairwise alignment: Spearman correlation between DTW(traj_i, traj_j) and
# ||z_i - z_j|| over distinct random pairs.
Z = np.stack(Z_list, axis=0)

rng = np.random.default_rng(int(getattr(cfg, "SEED", 0)) + 2026)
# Never ask for more than the number of distinct unordered pairs; otherwise the
# rejection-sampling loop below could not finish.
n_pairs = min(PAIR_SPEARMAN, N*(N-1)//2)

# Rejection-sample distinct unordered pairs stored as (i, j) with i < j.
pairs = set()
while len(pairs) < n_pairs:
    i, j = rng.integers(0, N), rng.integers(0, N)
    if i == j:
        continue
    if i > j:
        i, j = j, i
    pairs.add((int(i), int(j)))
pairs = list(pairs)

dtw_vals = []
z_vals = []

print(f"[A] computing DTW + z distances for {len(pairs)} random pairs...", flush=True)
for t, (i, j) in enumerate(pairs, 1):
    xi = traj_list[i]
    xj = traj_list[j]
    dd = float(dtw_sakoe_chiba(xi, xj, band=DTW_BAND))
    zd = float(np.linalg.norm(Z[i] - Z[j]))
    # Keep only pairs where both distances are finite.
    if np.isfinite(dd) and np.isfinite(zd):
        dtw_vals.append(dd)
        z_vals.append(zd)

    # Progress report every 250 pairs.
    if (t % 250) == 0:
        print(f"  [A] done {t}/{len(pairs)} pairs", flush=True)

dtw_vals = np.asarray(dtw_vals, dtype=float)
z_vals = np.asarray(z_vals, dtype=float)

rho = _spearman_np(dtw_vals, z_vals)

print("\n" + "-"*80, flush=True)
print("[RESULT A] Pairwise alignment test (random pairs) [Z_GEOM]", flush=True)
print(f"  #pairs used = {len(dtw_vals)}", flush=True)
print(f"  Spearman( DTW(traj_i,traj_j), ||z_geom_i - z_geom_j|| ) = {rho:.4f}", flush=True)
print("  Interpretation:", flush=True)
print("   - If FACET z_geom matches trajectory geometry, this should be strongly POSITIVE.", flush=True)
print("   - Values ~0 mean no monotonic relationship; negative means inverted geometry.", flush=True)
print("-"*80, flush=True)

# Scatter of DTW distance vs z_geom distance for the sampled pairs.
plt.figure()
plt.scatter(dtw_vals, z_vals, s=6, alpha=0.25)
plt.xlabel("DTW distance between trajectories")
plt.ylabel("FACET z_geom distance ||z_i - z_j||")
plt.title(f"{SERIES_TO_TEST}: DTW vs z_geom distance (Spearman={rho:.3f})")
plt.show()

# [B] Nearest-neighbour consistency between z_geom space and DTW space.
print("\n" + "-"*80, flush=True)
print("[RESULT B] Nearest-neighbor consistency (z_geom-NN vs DTW-NN)", flush=True)
print("  overlap@k close to 1.0 is excellent; random baseline is about k/(N-1).", flush=True)
print("-"*80, flush=True)

# Pairwise z_geom distances (N x N) via broadcasting, which materializes an
# N x N x dim temporary. The diagonal is +inf so a sample is never its own neighbour.
Dz = np.sqrt(((Z[:, None, :] - Z[None, :, :]) ** 2).sum(axis=2))
np.fill_diagonal(Dz, np.inf)

# Full symmetric DTW matrix; the diagonal stays +inf for the same reason.
# NOTE(review): this recomputes DTW for the pairs already scored in [A].
Dd = np.full((N, N), np.inf, dtype=float)
print(f"[B] computing full DTW distance matrix N={N} (this is the slow part)...", flush=True)

for i in range(N):
    for j in range(i+1, N):
        dd = float(dtw_sakoe_chiba(traj_list[i], traj_list[j], band=DTW_BAND))
        Dd[i, j] = dd
        Dd[j, i] = dd
    if ((i + 1) % 10) == 0:
        print(f"  [B] DTW rows done {i+1}/{N}", flush=True)

# overlap@k per sample = |kNN by z_geom ∩ kNN by DTW| / k.
overlaps = []
for i in range(N):
    nn_z = np.argsort(Dz[i])[:NN_K]
    nn_d = np.argsort(Dd[i])[:NN_K]
    ov = len(set(nn_z.tolist()) & set(nn_d.tolist())) / float(NN_K)
    overlaps.append(ov)

overlaps = np.asarray(overlaps, dtype=float)
mean_ov = float(np.mean(overlaps))
median_ov = float(np.median(overlaps))
# Expected overlap fraction if both neighbour sets were independent random
# k-subsets of the N-1 other samples.
rand_base = NN_K / float(max(1, N-1))

print(f"[NN] overlap@{NN_K}: mean={mean_ov:.4f}  median={median_ov:.4f}", flush=True)
print(f"[NN] random-baseline≈{rand_base:.4f}  (expected if z_geom neighbors were unrelated to DTW)", flush=True)
print("  Interpretation:", flush=True)
print("   - If mean overlap >> baseline, z_geom neighborhoods preserve DTW neighborhoods.", flush=True)
print("   - This supports: “z_geom encodes trajectory geometry.”", flush=True)

# Histogram of per-sample overlap values.
plt.figure()
plt.hist(overlaps, bins=20)
plt.xlabel(f"overlap@{NN_K} (DTW-NN ∩ z_geom-NN)")
plt.ylabel("count")
plt.title(f"{SERIES_TO_TEST}: NN overlap@{NN_K} (mean={mean_ov:.3f}, baseline≈{rand_base:.3f})")
plt.show()

print("\n" + "="*80, flush=True)
print("[SUMMARY FOR PAPER / REPORT] (Z_GEOM)", flush=True)
print(f"Series: {SERIES_TO_TEST}", flush=True)
print(f"1) Spearman(DTW, ||z_geom-z_geom'||) = {rho:.4f}  (want high +)", flush=True)
print(f"2) NN overlap@{NN_K} mean = {mean_ov:.4f} vs random baseline {rand_base:.4f} (want >> baseline)", flush=True)
print("These two together are strong evidence that FACET z_geom encodes trajectory geometry.", flush=True)
print("="*80 + "\n", flush=True)

"""# BIAS CORRECTION"""


import numpy as np
import pandas as pd

def fit_series_bias(df_rows: pd.DataFrame) -> dict:
    """
    Estimate a constant additive bias for every series.

    Args:
      df_rows: per-sample rows with at least the columns 'series', 'y_true'
        and 'y_pred'.

    Returns:
      Dict mapping each series value to mean(y_pred - y_true) over its rows.
      A positive value means that series is over-predicted on average.
    """
    def _mean_error(grp: pd.DataFrame) -> float:
        residual = grp["y_pred"].to_numpy(dtype=float) - grp["y_true"].to_numpy(dtype=float)
        return float(np.mean(residual))

    return {series: _mean_error(grp) for series, grp in df_rows.groupby("series")}

def apply_series_bias(df_rows: pd.DataFrame, bias_map: dict, pred_col: str = "y_pred") -> pd.DataFrame:
    """
    Subtract a per-series bias from a prediction column.

    Returns a copy of `df_rows` (the input frame is not modified) with two
    extra columns:
      bias_hat  -- float(bias_map[series]), or 0.0 for series not in the map
      y_pred_bc -- df_rows[pred_col] - bias_hat
    """
    def _lookup(series) -> float:
        return float(bias_map.get(series, 0.0))

    corrected = df_rows.copy()
    corrected["bias_hat"] = corrected["series"].map(_lookup)
    corrected["y_pred_bc"] = corrected[pred_col] - corrected["bias_hat"]
    return corrected

def metrics_from_rows(df_rows: pd.DataFrame, split_name: str, pred_col: str = "y_pred") -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Recompute forecast metrics from per-sample prediction rows.

    Args:
      df_rows: per-sample rows with columns 'series', 'ticker', 'y_true',
        'y_prev' and the prediction column named by `pred_col`.
      split_name: label written into the 'split' column of both outputs.
      pred_col: column holding the predictions to score (e.g. 'y_pred', or
        the bias-corrected 'y_pred_bc').

    Returns:
      (df_series_macro, df_company):
        df_series_macro -- one row per series; each metric is the unweighted
          mean over that series' company rows (equal weight per company).
        df_company -- one row per (series, ticker), sorted by series then MSE.
      Both frames keep their column schema even when `df_rows` is empty, so
      downstream `.head()`, `.sort_values()` and `.to_csv()` still work.
    """
    company_cols = ["split", "series", "ticker", "n", "MSE", "MAE", "MAPE", "TheilsU2", "Bias"]
    macro_cols = ["split", "series", "n_companies", "MSE", "MAE", "MAPE", "TheilsU2", "Bias"]

    comp_rows = []
    # A frame with no rows may also have no columns (pd.DataFrame([])), in
    # which case groupby would raise KeyError; there is nothing to score.
    groups = df_rows.groupby(["series", "ticker"]) if not df_rows.empty else []
    for (ser, tick), g in groups:
        yt = g["y_true"].to_numpy(dtype=float)
        yp = g[pred_col].to_numpy(dtype=float)
        yv = g["y_prev"].to_numpy(dtype=float)

        comp_rows.append({
            "split": split_name,
            "series": ser,
            "ticker": tick,
            "n": int(len(g)),
            "MSE": float(np.mean((yp - yt) ** 2)),
            "MAE": float(np.mean(np.abs(yp - yt))),
            "MAPE": safe_mape(yt, yp),
            "TheilsU2": theils_u2_with_prev(yt, yp, yv),
            "Bias": bias(yt, yp),
        })

    # Explicit columns: without them an empty comp_rows produces a
    # column-less frame and sort_values raises KeyError.
    df_company = pd.DataFrame(comp_rows, columns=company_cols).sort_values(["series", "MSE"])

    macro_rows = []
    for ser, g in df_company.groupby("series"):
        macro_rows.append({
            "split": split_name,
            "series": ser,
            "n_companies": int(g["ticker"].nunique()),
            "MSE": float(g["MSE"].mean()),
            "MAE": float(g["MAE"].mean()),
            "MAPE": float(g["MAPE"].mean()),
            "TheilsU2": float(g["TheilsU2"].mean()),
            "Bias": float(g["Bias"].mean()),
        })
    df_series_macro = pd.DataFrame(macro_rows, columns=macro_cols).sort_values("series")
    return df_series_macro, df_company

@torch.no_grad()
def eval_forecast_metrics(dataloader, split_name="TEST", return_rows: bool = False):
    """
    Score the LM's numeric forecasts on one split.

    The model forecast for each sample is the expected value over the target
    series' allowed value tokens (`expected_numeric_pred`). A naive
    previous-value forecast `y_prev` (last history value of the target
    series) is stored alongside it for Theil's U2. Aggregation is delegated
    to `metrics_from_rows`, so raw and bias-corrected metrics are computed by
    the same code.

    Args:
      dataloader: yields batches with input_ids, attention_mask, labels,
        target_series, hist_vals, ticker, t_index (plus target_token / text
        for the optional sample printout).
      split_name: label stored in the 'split' column. "TEST" together with
        cfg.PRINT_ALL_TEST_SAMPLES prints a few samples from the first batches.
      return_rows: also return the per-sample DataFrame.

    Returns:
      df_series_macro: rows per series with macro metrics
      df_company: rows per (series, ticker)
      df_rows (only if return_rows): one row per sample with
        y_true / y_pred / y_prev / text_head etc.
    """
    lm.eval()
    rows = []

    for batch_i, batch in enumerate(dataloader):
        out = lm(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            output_hidden_states=False,
        )

        y_true, y_pred = expected_numeric_pred(out.logits, batch["labels"], batch["target_series"])

        # Naive forecast = last history value of the target series. With no
        # history it falls back to y_true, making that sample's naive error 0.
        y_prev = []
        for i, ser in enumerate(batch["target_series"]):
            hv = batch["hist_vals"][i].get(ser, [])
            if hv is None or len(hv) == 0:
                y_prev.append(float(y_true[i]))
            else:
                y_prev.append(float(hv[-1]))
        y_prev = np.array(y_prev, dtype=float)

        if split_name == "TEST" and getattr(cfg, "PRINT_ALL_TEST_SAMPLES", False) and batch_i < 3:
            preds_tok = predict_from_labels_position(out.logits, batch["labels"], batch["target_series"])
            for i in range(min(5, len(preds_tok))):
                log(f"[{split_name} PRED] ticker={batch['ticker'][i]} series={batch['target_series'][i]} "
                    f"true={batch['target_token'][i]} pred_tok={preds_tok[i]} pred_val={y_pred[i]:.2f}")
                print("----- PROMPT (HEAD) -----")
                print(batch["text"][i][:900])
                print("-------------------------\n")

        has_text = "text" in batch
        for i in range(len(y_true)):
            rows.append({
                "split": split_name,
                "ticker": batch["ticker"][i],
                "series": batch["target_series"][i],
                "t_index": int(batch["t_index"][i]),
                "y_true": float(y_true[i]),
                "y_pred": float(y_pred[i]),
                "y_prev": float(y_prev[i]),
                "text_head": batch["text"][i][:240] if has_text else "",
            })

    df = pd.DataFrame(rows)
    # Same per-company / per-series aggregation used for the bias-corrected
    # metrics; previously this was a verbatim copy of that code.
    df_series_macro, df_company = metrics_from_rows(df, split_name, pred_col="y_pred")

    if return_rows:
        return df_series_macro, df_company, df
    return df_series_macro, df_company

# Evaluate both splits, keeping the per-sample rows needed for bias correction.
val_macro, val_company, df_val_rows = eval_forecast_metrics(val_dl, split_name="VAL", return_rows=True)
test_macro, test_company, df_test_rows = eval_forecast_metrics(test_dl, split_name="TEST", return_rows=True)

log("[METRICS] VAL per-series macro (raw):")
display(val_macro)

log("[METRICS] TEST per-series macro (raw):")
display(test_macro)

# Fit the per-series additive bias on VAL only and apply it to TEST, so TEST
# targets are not used to estimate the correction.
bias_map = fit_series_bias(df_val_rows)

log("[BIAS] Learned per-series bias on VAL (mean(y_pred - y_true)):")
display(pd.DataFrame([{"series": k, "bias_hat": v} for k, v in bias_map.items()]).sort_values("series"))

df_test_bc = apply_series_bias(df_test_rows, bias_map, pred_col="y_pred")

# Re-score TEST using the corrected predictions (column y_pred_bc).
test_macro_bc, test_company_bc = metrics_from_rows(df_test_bc, split_name="TEST_BC", pred_col="y_pred_bc")

log("[METRICS] TEST per-series macro (AFTER bias correction):")
display(test_macro_bc)

# NOTE: the company tables are sorted by (series, MSE), so head()/tail() show
# the best rows of the first series and the worst rows of the last series,
# not the global best/worst.
log("[METRICS] TEST top-10 best company-series by MSE (raw):")
display(test_company.head(10))

log("[METRICS] TEST bottom-10 worst company-series by MSE (raw):")
display(test_company.tail(10))

log("[METRICS] TEST top-10 best company-series by MSE (AFTER bias correction):")
display(test_company_bc.head(10))

log("[METRICS] TEST bottom-10 worst company-series by MSE (AFTER bias correction):")
display(test_company_bc.tail(10))

log("[SAMPLES] TEST sample rows (first 25) showing BEFORE/AFTER bias correction:")
display(df_test_bc[[
    "ticker","series","t_index","y_prev","y_true","y_pred","bias_hat","y_pred_bc","text_head"
]].head(25))

# Defensive copy (apply_series_bias already returns a new frame), then
# per-sample squared errors before/after correction.
df_test_bc = df_test_bc.copy()
df_test_bc["se2_before"] = (df_test_bc["y_pred"] - df_test_bc["y_true"])**2
df_test_bc["se2_after"]  = (df_test_bc["y_pred_bc"] - df_test_bc["y_true"])**2

log("[SAMPLES] Worst-10 by squared error BEFORE (and their AFTER):")
display(df_test_bc.sort_values("se2_before", ascending=False)[
    ["ticker","series","t_index","y_true","y_pred","y_pred_bc","se2_before","se2_after","text_head"]
].head(10))

# Positive delta_se2 = the correction reduced that sample's squared error.
df_test_bc["delta_se2"] = df_test_bc["se2_before"] - df_test_bc["se2_after"]
log("[SAMPLES] Top-10 biggest squared-error improvements from bias correction:")
display(df_test_bc.sort_values("delta_se2", ascending=False)[
    ["ticker","series","t_index","y_true","y_pred","y_pred_bc","se2_before","se2_after","delta_se2","text_head"]
].head(10))

@torch.no_grad()
def compute_facet_vectors(dataloader) -> pd.DataFrame:
    """
    Collect the geometry half of each sample's FACET hidden state.

    For every row of every batch, reads the last entry of the model's
    `hidden_states` at position `batch["facet_pos"]` and keeps the first half
    of that vector (z_geom). Rows whose facet position falls outside the
    sequence are dropped.

    Returns:
      DataFrame with columns ticker, series, t_index, z (numpy z_geom vector)
      and hist (numpy array of the last cfg.K history values of the row's
      target series).
    """
    lm.eval()
    records = []
    for batch in dataloader:
        model_out = lm(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            output_hidden_states=True,
        )
        final_hidden = model_out.hidden_states[-1]
        n_rows, seq_len, _ = final_hidden.shape
        for row in range(n_rows):
            target = batch["target_series"][row]
            pos = int(batch["facet_pos"][row].item())
            if not (0 <= pos < seq_len):
                continue

            facet_vec = final_hidden[row, pos, :].detach()
            geom_half = facet_vec[: facet_vec.numel() // 2].cpu().numpy()
            history = np.array(batch["hist_vals"][row].get(target, [])[-cfg.K:], dtype=float)

            records.append({
                "ticker": batch["ticker"][row],
                "series": target,
                "t_index": int(batch["t_index"][row]),
                "z": geom_half,
                "hist": history,
            })
    return pd.DataFrame(records)

# Extract TEST facet vectors with the compute_facet_vectors defined just above
# and preview the first rows.
df_z = compute_facet_vectors(test_dl)
log(f"[GEOM] extracted facet vectors on TEST: n={len(df_z)}")
display(df_z.head(3))

# Persist raw and bias-corrected metrics plus the per-sample rows as CSV in the
# current working directory (existing files with these names are overwritten).
val_macro.to_csv("val_series_macro.csv", index=False)
test_macro.to_csv("test_series_macro.csv", index=False)
test_company.to_csv("test_company_metrics.csv", index=False)

test_macro_bc.to_csv("test_series_macro_bias_corrected.csv", index=False)
test_company_bc.to_csv("test_company_metrics_bias_corrected.csv", index=False)
df_test_bc.to_csv("test_pred_rows_with_bias_correction.csv", index=False)

log("[SAVED] wrote raw + bias-corrected metrics + sample-level rows.")


import os
import numpy as np
import pandas as pd
import torch
from collections import defaultdict

# Put the LM (and the optional regression head, if one was created earlier)
# into eval mode for the out-of-sample run below.
lm.eval()
if "value_head" in globals() and value_head is not None:
    value_head.eval()

# Exchange sub-directories (relative to the working directory) scanned by the
# out-of-sample loop; each is expected to hold its own score files.
EXCHANGES = ["AMEX", "CBOE", "NASDAQ", "NYSE", "OTC", "PNK"]
# Upper bound on walk-forward forecast steps per company after the first cfg.K points.
MAX_PRED_STEPS = 10
# Series forecast at each step (RET only appears as prompt context).
SERIES = ["ESG","ENV","SOC","GOV"]

def mse(y, yhat):
    """Mean squared error between targets `y` and predictions `yhat` (array-likes)."""
    err = np.asarray(y, dtype=float) - np.asarray(yhat, dtype=float)
    return float(np.mean(np.square(err)))

def mae(y, yhat):
    """Mean absolute error between targets `y` and predictions `yhat` (array-likes)."""
    err = np.asarray(y, dtype=float) - np.asarray(yhat, dtype=float)
    return float(np.mean(np.abs(err)))

def mape(y, yhat):
    """
    Mean absolute percentage error, in percent.

    Each absolute error is divided by |y| + 1e-8, so a zero target yields a
    very large but finite term instead of a division by zero.
    """
    actual = np.asarray(y, dtype=float)
    predicted = np.asarray(yhat, dtype=float)
    ratio = np.abs((actual - predicted) / (np.abs(actual) + 1e-8))
    return float(np.mean(ratio) * 100.0)

def bias(y, yhat):
    """Mean signed error mean(yhat - y); positive means over-prediction."""
    predicted = np.asarray(yhat, dtype=float)
    actual = np.asarray(y, dtype=float)
    return float(np.mean(predicted - actual))

def theils_u(y, yhat):
    """
    RMSE of `yhat` against `y`, divided by the RMS magnitude of `y`.

    Computes sqrt(mean((yhat - y)^2)) / (sqrt(mean(y^2)) + 1e-8); the 1e-8
    keeps the ratio finite for an all-zero `y`. Note: the textbook Theil U1
    also adds sqrt(mean(yhat^2)) to the denominator; this variant does not.
    """
    target = np.asarray(y, dtype=float)
    forecast = np.asarray(yhat, dtype=float)
    rmse = np.sqrt(np.mean(np.square(forecast - target)))
    scale = np.sqrt(np.mean(np.square(target))) + 1e-8
    return float(rmse / scale)

def compute_metrics_dict(y, yhat):
    """
    Score predictions `yhat` against targets `y` with every point metric.

    Returns a dict with keys "MSE", "MAE", "MAPE", "TheilU" and "Bias",
    computed by mse, mae, mape, theils_u and bias respectively.
    """
    scorers = (
        ("MSE", mse),
        ("MAE", mae),
        ("MAPE", mape),
        ("TheilU", theils_u),
        ("Bias", bias),
    )
    return {name: fn(y, yhat) for name, fn in scorers}

def load_series_file(path):
    """
    Parse a whitespace-separated series file into {ticker: [float, ...]}.

    Each line is `TICKER v1 v2 ...`. Lines with fewer than two fields are
    skipped. Values that do not parse as floats are stored as 0.0 so
    positions stay aligned. A ticker that appears twice keeps the later line.
    """
    data = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 2:
                continue
            ticker = parts[0]
            vals = []
            for x in parts[1:]:
                try:
                    vals.append(float(x))
                except ValueError:
                    # Narrowed from a bare `except:` — float(str) only raises
                    # ValueError, and KeyboardInterrupt must not be swallowed.
                    vals.append(0.0)
            data[ticker] = vals
    return data

def _facet_token_for_series(series: str) -> str:
    return f"<FACET_{series}>"

def _find_facet_pos(input_ids_1d: torch.Tensor, series: str) -> int:
    """
    Position of the first `<FACET_{series}>` token in an encoded prompt.

    Args:
      input_ids_1d: token ids of a single prompt.
      series: target series name.

    Returns:
      Index of the first occurrence of the FACET token. Falls back to the
      last index when the token is not in the vocabulary or not in the prompt.
    """
    last = int(input_ids_1d.numel() - 1)
    tid = tokenizer.convert_tokens_to_ids(_facet_token_for_series(series))
    # Hugging Face tokenizers (presumably what `tokenizer` is) return
    # `unk_token_id` for an unregistered token string rather than None, so
    # treat that as "not found" instead of searching for the UNK id.
    unk_id = getattr(tokenizer, "unk_token_id", None)
    if tid is None or tid < 0 or (unk_id is not None and tid == unk_id):
        return last

    hits = (input_ids_1d == tid).nonzero(as_tuple=False)
    if hits.numel() == 0:
        return last
    return int(hits[0].item())

@torch.no_grad()
def predict_token_ev(logits_last_1v: torch.Tensor, series: str) -> float:
    """
    Numeric forecast for `series` from one position's vocabulary logits.

    Args:
      logits_last_1v: 1-D logits over the vocabulary ([V]), e.g. the last
        prompt position as passed by forecast_company_compare.
      series: target series name; key into `allowed_ids` and `id_to_val`.

    Returns:
      Expected value over the series' allowed value tokens
      (`expected_value_from_logits`). If the series has no allowed ids, the
      argmax token parsed with `token_to_float`, or 0.0 when that fails.
    """
    ids = allowed_ids[series]
    if ids.numel() == 0:
        pid = int(torch.argmax(logits_last_1v).item())
        tok = tokenizer.convert_ids_to_tokens(pid)
        try:
            return float(token_to_float(tok))
        except Exception:
            # Best-effort fallback; unlike the former bare `except:` this no
            # longer swallows KeyboardInterrupt / SystemExit.
            return 0.0

    logits_1V = logits_last_1v.unsqueeze(0)  # add a batch dimension of 1
    ev = expected_value_from_logits(logits_1V, ids, id_to_val[series])[0].item()
    return float(ev)

@torch.no_grad()
def predict_reg_head(out_hidden_last: torch.Tensor, input_ids_1d: torch.Tensor, series: str) -> float:
    """
    Regression-head forecast for a single sample.

    Args:
      out_hidden_last: one sample's hidden states, indexed [position, :].
      input_ids_1d: that sample's token ids, used to locate the FACET token.
      series: target series name.

    Returns:
      `value_head` applied to the SECOND half of the hidden vector at the
      series' FACET position (the first half is the z_geom part), as a float.
      NaN when no module-level `value_head` exists or it is None.
    """
    head = globals().get("value_head")
    if head is None:
        return float("nan")

    facet_vec = out_hidden_last[_find_facet_pos(input_ids_1d, series), :]
    pred_half = facet_vec[facet_vec.numel() // 2:]
    return float(head(pred_half.unsqueeze(0)).squeeze(-1).item())

@torch.no_grad()
def forecast_company_compare(
    ticker, esg, env, soc, gov, ret
):
    """
    Walk-forward one-step forecasts for one company, using both prediction heads.

    Starting from the first cfg.K observations, at each step t the model is
    prompted (no news, no sentiment) with the last cfg.K values of every
    series, once per target in SERIES. Two forecasts are recorded per
    (t, series): the token expected value (`predict_token_ev` on the last
    prompt position) and the regression head (`predict_reg_head`). After all
    targets at step t, the OBSERVED values at t (not the forecasts) are
    appended to the history.

    Args:
      ticker: company identifier inserted into the prompt.
      esg, env, soc, gov, ret: per-period sequences (lists or 1-D arrays),
        aligned by index.

    Returns:
      (preds_ev, preds_head, trues): dicts mapping each series in SERIES to
      a list of floats with one entry per forecast step.
    """
    # list(...) instead of .copy(): the history must support .append(), and
    # numpy arrays' .copy() returns an array without .append().
    hist = {
        "ESG": list(esg[:cfg.K]),
        "ENV": list(env[:cfg.K]),
        "SOC": list(soc[:cfg.K]),
        "GOV": list(gov[:cfg.K]),
        "RET": list(ret[:cfg.K]),
    }

    preds_ev   = {s: [] for s in SERIES}
    preds_head = {s: [] for s in SERIES}
    trues      = {s: [] for s in SERIES}

    # Bound by the SHORTEST input so every series can be indexed at t;
    # bounding by len(esg) alone raised IndexError when env/soc/gov/ret was
    # shorter. NOTE(review): the "- 1" (kept from the original) means the
    # final aligned observation is never forecast — confirm it is intended.
    n_aligned = min(len(esg), len(env), len(soc), len(gov), len(ret))
    max_t = min(n_aligned - 1, cfg.K + MAX_PRED_STEPS)

    for t in range(cfg.K, max_t):
        for series in SERIES:
            hist_tokens = {}
            for k in hist:
                hist_tokens[k] = [f"<{k}_{v:.2f}>" for v in hist[k][-cfg.K:]]

            prompt = build_prompt(
                ticker=ticker,
                start_dt=None,
                end_dt=None,
                hist_tokens_map=hist_tokens,
                senti_tokens=[],
                news_text="(no news)",
                target_series=series
            )

            enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(cfg.DEVICE)

            out = lm(**enc, output_hidden_states=True)
            logits = out.logits
            h_last = out.hidden_states[-1]

            # Distribution over the token that would follow the prompt.
            logit_last = logits[0, -1, :]
            y_ev = predict_token_ev(logit_last, series)

            y_head = predict_reg_head(h_last[0], enc["input_ids"][0], series)

            y_true = {"ESG": esg[t], "ENV": env[t], "SOC": soc[t], "GOV": gov[t]}[series]

            preds_ev[series].append(float(y_ev))
            preds_head[series].append(float(y_head))
            trues[series].append(float(y_true))

        hist["ESG"].append(esg[t])
        hist["ENV"].append(env[t])
        hist["SOC"].append(soc[t])
        hist["GOV"].append(gov[t])
        hist["RET"].append(ret[t])

    return preds_ev, preds_head, trues

rows = []

for EX in EXCHANGES:
    print("\n" + "="*80)
    print("EXCHANGE:", EX)
    print("="*80)

    # Each exchange is a directory (relative to the working directory) holding
    # its own copies of the score files.
    base = EX
    path_esg = os.path.join(base, "esg_risk_ratings_1.txt")
    path_env = os.path.join(base, "PaperReady_e_scores.txt")
    path_soc = os.path.join(base, "PaperReady_s_scores.txt")
    path_gov = os.path.join(base, "PaperReady_g_scores.txt")

    # All four files are loaded below, so all four must exist; checking only
    # the ESG file let a missing E/S/G file crash the whole loop.
    missing = [p for p in (path_esg, path_env, path_soc, path_gov) if not os.path.exists(p)]
    if missing:
        print("Missing input file(s) for", EX, "->", missing, "-> skipping")
        continue

    ESG = load_series_file(path_esg)
    ENV = load_series_file(path_env)
    SOC = load_series_file(path_soc)
    GOV = load_series_file(path_gov)

    RETURNS = {}
    for tkr in ESG.keys():
        # NOTE(review): integer positions are passed where
        # align_and_fill_returns presumably expects dates — confirm. Any
        # failure falls back to all-zero returns for this ticker.
        dates = list(range(len(ESG[tkr])))
        try:
            RETURNS[tkr] = align_and_fill_returns(tkr, dates)
        except Exception:
            RETURNS[tkr] = [0.0] * len(ESG[tkr])

    agg_true = {s: [] for s in SERIES}
    agg_ev   = {s: [] for s in SERIES}
    agg_head = {s: [] for s in SERIES}

    n_used = 0
    for ticker in ESG.keys():
        if ticker not in ENV or ticker not in SOC or ticker not in GOV:
            continue
        if len(ESG[ticker]) < cfg.K + 5:
            continue

        preds_ev, preds_head, trues = forecast_company_compare(
            ticker,
            ESG[ticker],
            ENV[ticker],
            SOC[ticker],
            GOV[ticker],
            RETURNS.get(ticker, [0.0]*len(ESG[ticker])),
        )
        n_used += 1

        for s in SERIES:
            agg_true[s].extend(trues[s])
            agg_ev[s].extend(preds_ev[s])
            agg_head[s].extend(preds_head[s])

    print(f"\n[OOS] companies used in {EX}: {n_used}")

    # Metrics are pooled over all samples of the exchange (micro average).
    for s in SERIES:
        if len(agg_true[s]) == 0:
            print(f"  {s}: N/A (no samples)")
            continue

        met_ev = compute_metrics_dict(agg_true[s], agg_ev[s])
        met_hd = compute_metrics_dict(agg_true[s], agg_head[s])

        print("\nSeries:", s)
        print("  TokenEV :", {k: round(v, 4) for k, v in met_ev.items()})
        print("  RegHead :", {k: round(v, 4) for k, v in met_hd.items()})

        for method_name, met in [("TokenEV", met_ev), ("RegHead", met_hd)]:
            rows.append({
                "exchange": EX,
                "series": s,
                "method": method_name,
                "n_samples": len(agg_true[s]),
                **met
            })

# Explicit columns keep the schema (and sort_values working) when no exchange
# produced any rows.
df_oos = pd.DataFrame(
    rows,
    columns=["exchange", "series", "method", "n_samples", "MSE", "MAE", "MAPE", "TheilU", "Bias"],
)
display(df_oos.sort_values(["exchange","series","method"]).reset_index(drop=True))

if df_oos.empty:
    print("[OOS] no exchange produced any samples; skipping pivot table.")
    pivot = None
else:
    pivot = df_oos.pivot_table(
        index=["exchange","series"],
        columns="method",
        values=["MSE","MAE","MAPE","TheilU","Bias"],
        aggfunc="mean"
    )
    display(pivot)


import pandas as pd
from collections import defaultdict

@torch.no_grad()
def extract_supervised_positions(labels: torch.Tensor) -> torch.Tensor:
    """
    Column index of the single supervised token in each row.

    Args:
      labels: [B, T] label ids; every row must contain exactly one entry
        != -100.

    Returns:
      tpos: [B] tensor with the supervised token index of each row.

    Raises:
      ValueError: if any row has zero or several supervised entries. Without
        this check `nonzero` (row-major order) could silently pair a row
        having two labels with a neighbouring row having none.
    """
    mask = labels != -100
    counts = mask.sum(dim=1)
    if not bool((counts == 1).all()):
        bad = (counts != 1).nonzero(as_tuple=False).flatten().tolist()
        raise ValueError(
            f"expected exactly one supervised label per row; rows {bad} "
            f"have counts {counts[bad].tolist()}"
        )
    return mask.nonzero(as_tuple=False)[:, 1]

@torch.no_grad()
def expected_numeric_pred(out_logits: torch.Tensor, labels: torch.Tensor, series_list: List[str]) -> Tuple[np.ndarray, np.ndarray]:
    """
    Decode the numeric target and the model's numeric forecast for each row.

    Args:
      out_logits: model logits, unpacked below as (B, T, V).
      labels: (B, T) label ids with exactly one supervised entry per row.
      series_list: target series name per row (key into allowed_ids /
        id_to_val).

    Returns:
      y_true_numeric [B]: the supervised token parsed by `token_to_float`
        (0.0 if it cannot be parsed).
      y_pred_numeric [B]: expected value over the series' allowed tokens
        (`expected_value_from_logits`); if the series has no allowed ids,
        the argmax token parsed as a number, or 0.0 if that fails.

    NOTE(review): logits are read AT the label position `tpos`. For a
    standard Hugging Face causal LM, logits at position p score token p+1,
    so the distribution for the labelled token would be at tpos-1 — confirm
    whether labels are pre-shifted in this pipeline.
    """
    B, _, _ = out_logits.shape
    tpos = extract_supervised_positions(labels)
    logits_last = out_logits[torch.arange(B, device=out_logits.device), tpos, :]

    y_true = []
    y_pred = []
    for i in range(B):
        ser = series_list[i]
        tid = int(labels[i, int(tpos[i].item())].item())
        true_tok = tokenizer.convert_ids_to_tokens(tid)
        try:
            y_true_val = token_to_float(true_tok)
        except Exception:
            y_true_val = 0.0

        ids = allowed_ids[ser]
        if ids.numel() == 0:
            pid = int(torch.argmax(logits_last[i]).item())
            pred_tok = tokenizer.convert_ids_to_tokens(pid)
            try:
                y_pred_val = token_to_float(pred_tok)
            except Exception:
                # Previously fell back to y_true_val, which copied the label
                # into the prediction and reported zero error for these rows.
                # 0.0 matches the fallback used by predict_token_ev.
                y_pred_val = 0.0
        else:
            ev = expected_value_from_logits(logits_last[i:i+1], ids, id_to_val[ser])[0].item()
            y_pred_val = float(ev)

        y_true.append(float(y_true_val))
        y_pred.append(float(y_pred_val))

    return np.array(y_true, dtype=float), np.array(y_pred, dtype=float)

def theils_u(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Scale-free forecast error: RMSE(y_pred vs y_true) / RMS(y_true).

    Computes sqrt(mean((y_pred - y_true)^2)) / (sqrt(mean(y_true^2)) + 1e-8),
    the same quantity as the earlier `theils_u` that compute_metrics_dict
    reports as "TheilU". It is neither the textbook U1 (whose denominator
    also adds RMS(y_pred)) nor U2; for U2 against a previous-value forecast
    use `theils_u2_with_prev`.
    """
    # This definition replaces the earlier one at module level. It used to be
    # a stub returning NaN, which made every later compute_metrics_dict call
    # report TheilU = NaN.
    y = np.asarray(y_true, dtype=float)
    yhat = np.asarray(y_pred, dtype=float)
    num = np.sqrt(np.mean((yhat - y) ** 2))
    den = np.sqrt(np.mean(y ** 2)) + 1e-8
    return float(num / den)

def theils_u2_with_prev(y_true: np.ndarray, y_pred: np.ndarray, y_prev: np.ndarray) -> float:
    """
    Theil's U2 against the naive previous-value forecast.

      U2 = RMSE(y_pred vs y_true) / RMSE(y_prev vs y_true)

    U2 < 1 means the model beats "repeat the last value". Small epsilons keep
    the ratio finite; when the naive forecast is exact (RMSE 0), the
    denominator is about 1e-6 and U2 becomes very large.
    Inputs may be lists or arrays.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    y_prev = np.asarray(y_prev, dtype=float)
    rmse_p = np.sqrt(np.mean((y_pred - y_true)**2) + 1e-12)
    rmse_n = np.sqrt(np.mean((y_prev - y_true)**2) + 1e-12)
    return float(rmse_p / (rmse_n + 1e-12))

def safe_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Mean absolute percentage error in percent, with |y_true| floored at 1e-6.

    Inputs may be lists or arrays. The floor avoids division by zero for zero
    targets, but such terms can still dominate the mean.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    denom = np.maximum(np.abs(y_true), 1e-6)
    return float(np.mean(np.abs((y_pred - y_true) / denom)) * 100.0)

def bias(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Mean signed error mean(y_pred - y_true); positive means over-prediction.

    Accepts lists as well as arrays. This definition replaces the earlier
    list-safe `bias`, which compute_metrics_dict calls with plain lists.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean(y_pred - y_true))

@torch.no_grad()
def eval_forecast_metrics(dataloader, split_name="TEST", return_rows: bool = False) -> Tuple[pd.DataFrame, ...]:
    """
    Score the LM's numeric forecasts on one split.

    The model forecast for each sample is the expected value over the target
    series' allowed value tokens (`expected_numeric_pred`). A naive
    previous-value forecast `y_prev` (last history value of the target
    series) is stored for Theil's U2. Aggregation is delegated to
    `metrics_from_rows`.

    Args:
      dataloader: yields batches with input_ids, attention_mask, labels,
        target_series, hist_vals, ticker, t_index (plus target_token / text
        for the optional sample printout).
      split_name: label stored in the 'split' column. "TEST" together with
        cfg.PRINT_ALL_TEST_SAMPLES prints a few samples from the first batches.
      return_rows: also return the per-sample DataFrame. This keeps the
        `return_rows=True` call style of the earlier definition working after
        this redefinition.

    Returns:
      df_series_macro: rows per series with overall macro metrics
      df_company: rows per (series, ticker) with metrics
      df_rows (only if return_rows): one row per sample
    """
    lm.eval()
    rows = []

    for batch_i, batch in enumerate(dataloader):
        out = lm(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            output_hidden_states=False,
        )

        y_true, y_pred = expected_numeric_pred(out.logits, batch["labels"], batch["target_series"])

        # Naive forecast = last history value of the target series. With no
        # history it falls back to y_true, making that sample's naive error 0.
        y_prev = []
        for i, ser in enumerate(batch["target_series"]):
            hv = batch["hist_vals"][i].get(ser, [])
            if hv is None or len(hv) == 0:
                y_prev.append(float(y_true[i]))
            else:
                y_prev.append(float(hv[-1]))
        y_prev = np.array(y_prev, dtype=float)

        if split_name == "TEST" and cfg.PRINT_ALL_TEST_SAMPLES and batch_i < 3:
            preds_tok = predict_from_labels_position(out.logits, batch["labels"], batch["target_series"])
            for i in range(min(5, len(preds_tok))):
                log(f"[{split_name} PRED] ticker={batch['ticker'][i]} series={batch['target_series'][i]} "
                    f"true={batch['target_token'][i]} pred_tok={preds_tok[i]} pred_val={y_pred[i]:.2f}")

                print("----- PROMPT (HEAD) -----")
                print(batch["text"][i][:900])
                print("-------------------------\n")

        has_text = "text" in batch
        for i in range(len(y_true)):
            rows.append({
                "split": split_name,
                "ticker": batch["ticker"][i],
                "series": batch["target_series"][i],
                # int() so the frame holds plain ints, not tensor elements.
                "t_index": int(batch["t_index"][i]),
                "y_true": float(y_true[i]),
                "y_pred": float(y_pred[i]),
                "y_prev": float(y_prev[i]),
                "text_head": batch["text"][i][:240] if has_text else "",
            })

    df = pd.DataFrame(rows)
    df_series_macro, df_company = metrics_from_rows(df, split_name, pred_col="y_pred")

    if return_rows:
        return df_series_macro, df_company, df
    return df_series_macro, df_company

# Re-run evaluation with the eval_forecast_metrics defined just above; this
# rebinds val_macro / val_company / test_macro / test_company from earlier.
val_macro, val_company = eval_forecast_metrics(val_dl, split_name="VAL")
test_macro, test_company = eval_forecast_metrics(test_dl, split_name="TEST")

log("[METRICS] VAL per-series macro:")
display(val_macro)

log("[METRICS] TEST per-series macro:")
display(test_macro)

# test_company is sorted by (series, MSE): head/tail are best rows of the
# first series and worst rows of the last series, not global extremes.
log("[METRICS] TEST top-10 best company-series by MSE:")
display(test_company.head(10))

log("[METRICS] TEST bottom-10 worst company-series by MSE:")
display(test_company.tail(10))

@torch.no_grad()
def compute_facet_vectors(dataloader) -> pd.DataFrame:
    """
    Extract the geometry half of each sample's FACET hidden state.

    For every row, takes the last entry of the model's `hidden_states` at
    position `batch["facet_pos"]` and keeps the first half of that vector
    (z_geom). Rows whose facet position lies outside the sequence are
    skipped.

    Returns:
      DataFrame with columns ticker, series, t_index, z (numpy z_geom vector)
      and hist (numpy array of the last cfg.K history values of the row's
      target series).
    """
    lm.eval()
    rows = []
    for batch in dataloader:
        out = lm(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            output_hidden_states=True,
        )
        h_last = out.hidden_states[-1]
        B, T, Hdim = h_last.shape
        for i in range(B):
            ser = batch["target_series"][i]
            p = int(batch["facet_pos"][i].item())
            if p < 0 or p >= T:
                continue
            z_full = h_last[i, p, :].detach()
            z_geom = z_full[: z_full.numel()//2].cpu().numpy()

            hist = np.array(batch["hist_vals"][i].get(ser, [])[-cfg.K:], dtype=float)
            rows.append({
                "ticker": batch["ticker"][i],
                "series": ser,
                "t_index": int(batch["t_index"][i]),
                # Was `"z": z` — an undefined name (NameError, or a stale
                # global silently stored for every row).
                "z": z_geom,
                "hist": hist,
            })
    return pd.DataFrame(rows)

# Re-extract TEST facet vectors with the most recent compute_facet_vectors
# definition; df_z feeds the geometry diagnostics below.
df_z = compute_facet_vectors(test_dl)
log(f"[GEOM] extracted facet vectors on TEST: n={len(df_z)}")
display(df_z.head(3))

def pairwise_geom_metrics_for_series(df_ser: pd.DataFrame, max_pairs: int = 2000) -> Dict[str, float]:
    """
    Spearman correlation between raw-history DTW distance and z-space distance.

    Args:
      df_ser: rows for one series with columns 'hist' (1-D numpy array) and
        'z' (numpy vector).
      max_pairs: upper bound on the number of distinct (i, j) pairs scored.

    Returns:
      {"spearman": rho, "pairs": n_scored}. rho is NaN when there are fewer
      than 5 rows or fewer than 10 scorable pairs.

    Pairs are distinct and unordered. When there are at most `max_pairs`
    possible pairs, all of them are used. Otherwise `max_pairs` distinct pairs
    are drawn at random, seeded by cfg.SEED. Drawing with replacement (the
    previous behaviour) duplicated pairs and biased the rank correlation.
    Pairs where either history is empty are skipped.
    """
    import itertools

    n = len(df_ser)
    if n < 5:
        return {"spearman": float("nan"), "pairs": 0}

    total = n * (n - 1) // 2
    if total <= max_pairs:
        pairs = list(itertools.combinations(range(n), 2))
    else:
        rng = np.random.default_rng(cfg.SEED)
        chosen = set()
        while len(chosen) < max_pairs:
            i, j = rng.choice(n, size=2, replace=False)
            chosen.add((int(min(i, j)), int(max(i, j))))
        pairs = sorted(chosen)

    # Pull the object columns once instead of calling .iloc per pair.
    hists = df_ser["hist"].tolist()
    zs = df_ser["z"].tolist()

    dtw_list = []
    z_list = []
    for i, j in pairs:
        xi, xj = hists[i], hists[j]
        if len(xi) == 0 or len(xj) == 0:
            continue
        dtw_list.append(dtw_sakoe_chiba(xi.tolist(), xj.tolist(), band=cfg.DTW_BAND))
        z_list.append(float(np.linalg.norm(zs[i] - zs[j])))

    if len(dtw_list) < 10:
        return {"spearman": float("nan"), "pairs": len(dtw_list)}

    from scipy.stats import spearmanr
    sp = float(spearmanr(dtw_list, z_list).correlation)
    return {"spearman": sp, "pairs": len(dtw_list)}

# Spearman correlation of DTW (raw history) vs z-space distance, one row per
# target series.
geom_rows = []
for ser in ["ESG","ENV","SOC","GOV"]:
    df_ser = df_z[df_z["series"] == ser]
    m = pairwise_geom_metrics_for_series(df_ser, max_pairs=2000)
    geom_rows.append({"series": ser, **m})
df_geom = pd.DataFrame(geom_rows)

log("[GEOM] Spearman corr(DTW, z-dist) per series:")
display(df_geom)

def knn_overlap(df_ser: pd.DataFrame, k: int = 5, n_queries: int = 30) -> float:
    """
    Mean agreement between z-space and DTW nearest neighbours.

    Up to `n_queries` distinct query rows are drawn (seeded by cfg.SEED).
    For each query, the k closest other rows by L2 distance of 'z' and the
    k closest by DTW distance of 'hist' (band cfg.DTW_BAND) are found; the
    query's score is |intersection| / k. Returns the mean score, or NaN
    when `df_ser` has fewer than k + 2 rows.
    """
    n_rows = len(df_ser)
    if n_rows < k + 2:
        return float("nan")

    rng = np.random.default_rng(cfg.SEED)
    positions = np.arange(n_rows)
    n_q = min(n_queries, n_rows)

    def _k_closest(scored):
        # Stable sort on distance, so ties keep their row order.
        return [pos for pos, _ in sorted(scored, key=lambda item: item[1])[:k]]

    scores = []
    for q in rng.choice(positions, size=n_q, replace=False):
        query = df_ser.iloc[q]
        z_q = query["z"]
        x_q = query["hist"]
        others = [pos for pos in positions if pos != q]

        by_z = _k_closest(
            [(pos, float(np.linalg.norm(z_q - df_ser.iloc[pos]["z"]))) for pos in others]
        )
        by_dtw = _k_closest(
            [
                (pos, float(dtw_sakoe_chiba(x_q.tolist(), df_ser.iloc[pos]["hist"].tolist(), band=cfg.DTW_BAND)))
                for pos in others
            ]
        )
        scores.append(len(set(by_z) & set(by_dtw)) / float(k))

    return float(np.mean(scores))

# kNN overlap@5 between z-space and DTW neighbourhoods (up to 30 seeded query
# rows per series).
overlap_rows = []
for ser in ["ESG","ENV","SOC","GOV"]:
    df_ser = df_z[df_z["series"] == ser]
    ov = knn_overlap(df_ser, k=5, n_queries=30)
    overlap_rows.append({"series": ser, "knn_overlap@5": ov})
df_overlap = pd.DataFrame(overlap_rows)

log("[RETRIEVE] kNN overlap@5 (z-space vs DTW) per series:")
display(df_overlap)


def ranking_experiment(df_pred: pd.DataFrame) -> pd.DataFrame:
    """
    Rank tickers by predicted ESG-risk decrease plus predicted return increase.

    For each ticker:
      esg_delta_pred = mean(y_pred - y_prev) over its ESG rows (0.0 if none)
      ret_delta_pred = mean(y_pred - y_prev) over its RET rows (0.0 if none)
      score          = -esg_delta_pred + ret_delta_pred
    NOTE(review): if 'RET' never occurs as a target series in `df_pred`,
    ret_delta_pred is always 0 and the score reduces to -esg_delta_pred.

    Returns:
      DataFrame with columns ticker, esg_delta_pred, ret_delta_pred, score,
      sorted by score descending. It is empty (but keeps its columns) when
      `df_pred` has no rows.
    """
    out_cols = ["ticker", "esg_delta_pred", "ret_delta_pred", "score"]
    out = []
    # A column-less empty frame would make groupby raise KeyError.
    groups = df_pred.groupby("ticker") if not df_pred.empty else []
    for tick, g in groups:
        g_esg = g[g["series"]=="ESG"]
        g_ret = g[g["series"]=="RET"]
        esg_improve = float(np.mean(g_esg["y_pred"] - g_esg["y_prev"])) if len(g_esg)>0 else 0.0
        ret_improve = float(np.mean(g_ret["y_pred"] - g_ret["y_prev"])) if len(g_ret)>0 else 0.0
        score = (-esg_improve) + (ret_improve)
        out.append({"ticker": tick, "esg_delta_pred": esg_improve, "ret_delta_pred": ret_improve, "score": score})
    # Explicit columns: sort_values("score") raised KeyError on empty input.
    return pd.DataFrame(out, columns=out_cols).sort_values("score", ascending=False)

@torch.no_grad()
def collect_pred_rows(dataloader, split_name="TEST") -> pd.DataFrame:
    """
    Per-sample numeric predictions for every row of `dataloader`.

    Returns:
      DataFrame with columns split, ticker, series, t_index, y_true, y_pred
      (expected value from `expected_numeric_pred`) and y_prev (last history
      value of the target series, or y_true when there is no history).
    """
    lm.eval()
    rows = []
    for batch in dataloader:
        out = lm(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            output_hidden_states=False,
        )
        y_true, y_pred = expected_numeric_pred(out.logits, batch["labels"], batch["target_series"])

        y_prev = []
        for i in range(len(batch["target_series"])):
            ser = batch["target_series"][i]
            hv = batch["hist_vals"][i].get(ser, [])
            # Explicit None/len test (as in eval_forecast_metrics): `hv and ...`
            # raises "truth value of an array is ambiguous" for numpy arrays.
            y_prev.append(float(hv[-1]) if hv is not None and len(hv) > 0 else float(y_true[i]))

        for i in range(len(y_true)):
            rows.append({
                "split": split_name,
                "ticker": batch["ticker"][i],
                "series": batch["target_series"][i],
                "t_index": int(batch["t_index"][i]),
                "y_true": float(y_true[i]),
                "y_pred": float(y_pred[i]),
                "y_prev": float(y_prev[i]),
            })
    return pd.DataFrame(rows)

# Per-sample TEST predictions feed the heuristic ticker ranking.
df_pred_test = collect_pred_rows(test_dl, "TEST")

rank_df = ranking_experiment(df_pred_test)
log("[RANK] Top-15 tickers by combined (ESG down good + RET up good) predicted score:")
display(rank_df.head(15))

# Persist this phase's tables to the working directory. The first three file
# names are also written earlier in the notebook, so these calls overwrite
# them with the most recently computed metrics.
val_macro.to_csv("val_series_macro.csv", index=False)
test_macro.to_csv("test_series_macro.csv", index=False)
test_company.to_csv("test_company_metrics.csv", index=False)
df_geom.to_csv("geom_spearman_by_series.csv", index=False)
df_overlap.to_csv("knn_overlap_by_series.csv", index=False)
rank_df.to_csv("ranking_experiment.csv", index=False)

log("[SAVED] metrics CSV files written to current folder.")
log("Phase 10 complete.")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def plot_cos_vs_dtw(
    df,
    series,
    max_pairs=3000,
    seed=0,
    bins=30,
    save_prefix="cos_vs_dtw"
):
    """Scatter cosine similarity vs. DTW distance for random row pairs and save a PNG.

    Each cell of ``df[series]`` is expected to be a 1-D numeric sequence (it is
    z-normalised, passed to ``len`` and indexed by the DTW routine). For every
    sampled pair, both sequences are z-normalised, then cosine similarity and
    full (unbanded) DTW with absolute-difference cost are computed.

    Args:
        df: DataFrame whose ``series`` column holds one sequence per row.
        series: Column name; also used in the plot title and output file name.
        max_pairs: Maximum number of distinct unordered row pairs to sample;
            capped at ``n * (n - 1) // 2``.
        seed: Seed for ``np.random.default_rng``.
        bins: Unused; kept only so existing callers keep working.
        save_prefix: Figure is written to ``f"{save_prefix}_{series}.png"``.

    Returns:
        dict with ``series``, ``spearman`` (Spearman correlation between the
        cosine and DTW values), ``cos_mean``, ``dtw_mean`` and ``n_pairs``.
        The means are NaN when no pair can be formed (fewer than two rows).
    """
    rng = np.random.default_rng(seed)

    def normalize_ts(x):
        # z-normalise; a (near-)constant sequence maps to all zeros.
        x = np.asarray(x, dtype=float)
        mu = x.mean()
        sigma = x.std()
        if sigma < 1e-12:
            return x * 0.0
        return (x - mu) / sigma

    def dtw_distance(a, b):
        # Classic O(len(a) * len(b)) DTW without a warping-window constraint.
        n_a, n_b = len(a), len(b)
        D = np.full((n_a + 1, n_b + 1), np.inf)
        D[0, 0] = 0.0
        for i in range(1, n_a + 1):
            for j in range(1, n_b + 1):
                cost = abs(a[i - 1] - b[j - 1])
                D[i, j] = cost + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])
        return D[n_a, n_b]

    X = df[series].values
    n = len(X)

    # Rejection-sample distinct unordered pairs (i < j). The target is capped at
    # the number of possible pairs, so the loop always terminates.
    n_target = min(max_pairs, n * (n - 1) // 2)
    pairs = set()
    while len(pairs) < n_target:
        i = rng.integers(0, n)
        j = rng.integers(0, n)
        if i != j:
            pairs.add((min(i, j), max(i, j)))

    cos_vals = []
    dtw_vals = []
    for i, j in pairs:
        s1 = normalize_ts(X[i])
        s2 = normalize_ts(X[j])

        den = np.linalg.norm(s1) * np.linalg.norm(s2) + 1e-12
        cos_vals.append(np.dot(s1, s2) / den)
        dtw_vals.append(dtw_distance(s1, s2))

    cos_vals = np.array(cos_vals)
    dtw_vals = np.array(dtw_vals)

    spearman = pd.Series(cos_vals).corr(pd.Series(dtw_vals), method="spearman")

    fname = f"{save_prefix}_{series}.png"
    fig = plt.figure(figsize=(7, 5))
    try:
        plt.scatter(cos_vals, dtw_vals, s=6, alpha=0.4)
        plt.xlabel("Cosine similarity (z-normalized series)")
        plt.ylabel("DTW distance (z-normalized series)")
        plt.title(f"{series}: Spearman ρ = {spearman:.3f}")
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(fname, dpi=150)
    finally:
        # Release the figure even if layout/saving fails, so repeated calls do
        # not accumulate open figures.
        plt.close(fig)

    has_pairs = cos_vals.size > 0
    return {
        "series": series,
        "spearman": spearman,
        # Avoid "Mean of empty slice" warnings when no pair could be drawn.
        "cos_mean": float(np.mean(cos_vals)) if has_pairs else float("nan"),
        "dtw_mean": float(np.mean(dtw_vals)) if has_pairs else float("nan"),
        "n_pairs": len(cos_vals),
    }

# One summary dict per series from plot_cos_vs_dtw (each call also saves a PNG).
cos_rows = [
    plot_cos_vs_dtw(
        df_z,
        series=ser,
        max_pairs=3000,
        seed=cfg.SEED,
        bins=30,
        save_prefix="cos_vs_dtw",
    )
    for ser in ["ESG", "ENV", "SOC", "GOV"]
]

df_cos = pd.DataFrame(cos_rows)
log("[GEOM] Spearman corr(DTW, cosine(FACET)) per series:")
display(df_cos)

df_cos.to_csv("geom_cosine_vs_dtw_spearman.csv", index=False)
log("[SAVED] cos_vs_dtw_*.png + geom_cosine_vs_dtw_spearman.csv written.")

df_cos

import matplotlib.pyplot as plt

# One point per row of df_cos: mean DTW (x) against mean cosine (y).
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(df_cos["dtw_mean"], df_cos["cos_mean"], alpha=0.7)
ax.set_xlabel("dtw_mean")
ax.set_ylabel("cos_mean")
ax.set_title("Relationship between dtw_mean and cos_mean")
ax.grid(True)
plt.show()

import matplotlib.pyplot as plt

# Same scatter as above, but one legend entry (colour) per series value.
fig, ax = plt.subplots(figsize=(10, 7))

for series_name in df_cos["series"].unique():
    rows = df_cos[df_cos["series"] == series_name]
    ax.scatter(rows["dtw_mean"], rows["cos_mean"], label=str(series_name), alpha=0.7)

ax.set_xlabel("dtw_mean")
ax.set_ylabel("cos_mean")
ax.set_title("dtw_mean vs cos_mean by series")
# Legend is placed outside the axes, to the right.
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
ax.grid(True)
plt.show()

import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(10, 7))

# Rows sorted by mean DTW so each series' line is drawn left-to-right.
# NOTE(review): if df_cos holds a single row per series, each hue is one point.
df_plot = df_cos.sort_values("dtw_mean")

ax = sns.lineplot(data=df_plot, x="dtw_mean", y="cos_mean", hue="series", marker="o")
ax.set_xlabel("dtw_mean")
ax.set_ylabel("cos_mean")
ax.set_title("dtw_mean vs cos_mean by series")
ax.grid(True)
plt.show()

import matplotlib.pyplot as plt

series_vals = df_cos["series"].dropna().unique()

# One separate figure per series value, points ordered by mean DTW.
for s in series_vals:
    sub = df_cos[df_cos["series"] == s].sort_values("dtw_mean")

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(sub["dtw_mean"], sub["cos_mean"], marker="o")
    ax.set_xlabel("dtw_mean")
    ax.set_ylabel("cos_mean")
    ax.set_title(f"dtw_mean vs cos_mean | series = {s}")
    ax.grid(True)
    plt.show()

# Number of rows per series (notebook display value).
df_cos.groupby("series").size()

import numpy as np
import matplotlib.pyplot as plt


# Pairs sampled per series, and number of DTW bins for the binned-mean curve.
PAIR_SAMPLES, N_BINS = 6000, 30

def _dtw(a, b):
    """Return the DTW distance from ``dtw_sakoe_chiba`` as a Python float.

    ``a`` and ``b`` must provide ``.tolist()`` (e.g. NumPy arrays). The band
    width is taken from ``cfg.DTW_BAND``.
    """
    band = cfg.DTW_BAND
    return float(dtw_sakoe_chiba(a.tolist(), b.tolist(), band=band))

def cosine_sim(a, b, eps=1e-12):
    """Cosine similarity of two vectors, computed in float32.

    The denominator is floored at ``eps``, so a zero vector yields 0.0 instead
    of dividing by zero.
    """
    vec_a, vec_b = (np.asarray(v, dtype=np.float32) for v in (a, b))
    denom = max(np.linalg.norm(vec_a) * np.linalg.norm(vec_b), eps)
    return float(np.dot(vec_a, vec_b) / denom)

def sample_pairs(n, m, rng):
    """Draw random index pairs ``(i, j)`` with ``i != j`` from ``range(n)``.

    Each pair is drawn independently via ``rng.choice(..., replace=False)``,
    so duplicate pairs (and both orientations of a pair) may occur. The number
    of draws is ``min(m, n * (n - 1) // 2)``.
    """
    population = np.arange(n)
    n_draws = min(m, n * (n - 1) // 2)
    return [
        tuple(int(k) for k in rng.choice(population, size=2, replace=False))
        for _ in range(n_draws)
    ]

def dtw_vs_cos_line(arr, pair_samples=6000, n_bins=30, seed=0):
    """Average cosine similarity within quantile bins of DTW distance.

    Random pairs of ``arr`` records are compared by DTW on ``"hist"`` (via
    ``_dtw``) and cosine similarity on ``"z"`` (via ``cosine_sim``). The DTW
    values are split into ``n_bins`` equal-count (quantile) bins; the last bin
    includes its upper edge.

    Returns:
        ``None`` when ``arr`` has fewer than 25 records, otherwise
        ``(xs, ys, ns, edges)``: the bin-centre DTW values, the mean cosine per
        bin, the number of pairs per bin (empty bins omitted), and the quantile
        edges.
    """
    if len(arr) < 25:
        return None

    pairs = sample_pairs(len(arr), pair_samples, np.random.default_rng(seed))

    dtw_arr = np.asarray(
        [_dtw(arr[i]["hist"], arr[j]["hist"]) for i, j in pairs], dtype=np.float32
    )
    cos_arr = np.asarray(
        [cosine_sim(arr[i]["z"], arr[j]["z"]) for i, j in pairs], dtype=np.float32
    )

    edges = np.quantile(dtw_arr, np.linspace(0, 1, n_bins + 1))

    last = n_bins - 1
    centers, means, counts = [], [], []
    for b in range(n_bins):
        lo, hi = edges[b], edges[b + 1]
        below_hi = (dtw_arr <= hi) if b == last else (dtw_arr < hi)
        in_bin = (dtw_arr >= lo) & below_hi
        count = int(in_bin.sum())
        if not count:
            continue
        centers.append(float(0.5 * (lo + hi)))
        means.append(float(cos_arr[in_bin].mean()))
        counts.append(count)

    return np.asarray(centers), np.asarray(means), np.asarray(counts), edges

import zlib

for ser in SERIES_LIST:
    # Per-series seed offset. Built-in hash() of a str is salted per process
    # (PYTHONHASHSEED), so the previous `hash(ser)` made the sampled pairs and
    # plots change on every run. CRC32 is stable across runs and platforms.
    ser_seed = cfg.SEED + zlib.crc32(str(ser).encode("utf-8")) % 10_000
    out = dtw_vs_cos_line(
        store[ser],
        pair_samples=PAIR_SAMPLES,
        n_bins=N_BINS,
        seed=ser_seed,
    )
    if out is None:
        print(f"{ser}: not enough data to plot.")
        continue

    xs, ys, ns, edges = out

    plt.figure(figsize=(7, 4))
    plt.plot(xs, ys, marker="o")
    plt.xlabel("DTW distance")
    plt.ylabel("FACET cosine similarity")
    plt.title(f"{ser}: DTW vs FACET cosine similarity (binned means)")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

import numpy as np
import matplotlib.pyplot as plt

# Smaller run: 1000 sampled pairs per series, 10 DTW bins.
PAIR_SAMPLES, N_BINS = 1000, 10

def _dtw(a, b):
    """Compute ``dtw_sakoe_chiba`` on list copies of ``a`` and ``b``; return a float.

    The inputs must support ``.tolist()``; ``band`` comes from ``cfg.DTW_BAND``.
    """
    seq_a = a.tolist()
    seq_b = b.tolist()
    distance = dtw_sakoe_chiba(seq_a, seq_b, band=cfg.DTW_BAND)
    return float(distance)

def cosine_sim(a, b, eps=1e-12):
    """Return the float32 cosine similarity of ``a`` and ``b``.

    The product of norms is floored at ``eps`` (zero vectors give 0.0).
    """
    u = np.asarray(a, dtype=np.float32)
    v = np.asarray(b, dtype=np.float32)
    norm_product = np.linalg.norm(u) * np.linalg.norm(v)
    return float(np.dot(u, v) / max(norm_product, eps))

def sample_pairs(n, m, rng):
    """Return ``min(m, n*(n-1)//2)`` independently drawn index pairs ``(i, j)``, ``i != j``.

    Duplicates are possible because draws are independent of each other.
    """
    indices = np.arange(n)
    limit = min(m, n * (n - 1) // 2)

    def draw():
        first, second = rng.choice(indices, size=2, replace=False)
        return int(first), int(second)

    return [draw() for _ in range(limit)]

def dtw_vs_cos_line(arr, pair_samples=6000, n_bins=30, seed=0):
    """Binned mean of min-max-scaled cosine vs. min-max-scaled DTW (quantile bins).

    Both the DTW distances (on ``"hist"``) and cosine similarities (on ``"z"``)
    of randomly sampled record pairs are min-max scaled to roughly [0, 1]. The
    scaled DTW values are split into ``n_bins`` quantile bins (last bin closed
    on the right), and the scaled cosine is averaged per non-empty bin.

    Returns:
        ``None`` for fewer than 25 records, else ``(xs, ys)`` arrays of bin
        centres and per-bin mean scaled cosine.
    """
    if len(arr) < 25:
        return None

    rng = np.random.default_rng(seed)
    pairs = sample_pairs(len(arr), pair_samples, rng)

    dtw_arr = np.asarray([_dtw(arr[i]["hist"], arr[j]["hist"]) for i, j in pairs], dtype=np.float32)
    cos_arr = np.asarray([cosine_sim(arr[i]["z"], arr[j]["z"]) for i, j in pairs], dtype=np.float32)

    def _minmax(values):
        # Scale to [0, 1); the 1e-9 avoids division by zero for constant input.
        lo_v, hi_v = values.min(), values.max()
        return (values - lo_v) / (hi_v - lo_v + 1e-9)

    dtw_n = _minmax(dtw_arr)
    cos_n = _minmax(cos_arr)

    edges = np.quantile(dtw_n, np.linspace(0, 1, n_bins + 1))

    centers, means = [], []
    for b in range(n_bins):
        lo, hi = edges[b], edges[b + 1]
        is_last = b == n_bins - 1
        in_bin = (dtw_n >= lo) & ((dtw_n <= hi) if is_last else (dtw_n < hi))
        if not in_bin.any():
            continue
        centers.append(float(0.5 * (lo + hi)))
        means.append(float(cos_n[in_bin].mean()))

    return np.asarray(centers), np.asarray(means)

import zlib

plt.rcParams.update({
    "font.size": 14,
    "font.weight": "bold",
    "axes.labelweight": "bold",
    "axes.titleweight": "bold",
})

COLORS = {
    "ESG": "#ff0055",
    "ENV": "#00cc66",
    "SOC": "#0066ff",
    "GOV": "#ff9900",
}

for ser in SERIES_LIST:
    # Stable per-series seed offset. Built-in hash() of a str is salted per
    # process (PYTHONHASHSEED), which made these plots irreproducible.
    ser_seed = cfg.SEED + zlib.crc32(str(ser).encode("utf-8")) % 10_000
    out = dtw_vs_cos_line(
        store[ser],
        pair_samples=PAIR_SAMPLES,
        n_bins=N_BINS,
        seed=ser_seed,
    )

    if out is None:
        print(f"{ser}: not enough data.")
        continue

    xs, ys = out

    plt.figure(figsize=(5.2, 3.6))
    plt.plot(xs, ys, marker="o", linewidth=2.5, color=COLORS.get(ser, None))

    plt.xlabel("Normalized DTW distance", fontsize=14, fontweight="bold")
    plt.ylabel("Normalized FACET cosine similarity", fontsize=14, fontweight="bold")
    plt.title(f"{ser}", fontsize=15, fontweight="bold")

    plt.grid(True, alpha=0.3)
    # NOTE(review): both axes are min-max scaled to [0, 1], but the view is
    # cropped to x in [0, 0.5] and y in [0.6, 1]; binned points outside this
    # window are not shown. Confirm this crop is intended.
    plt.xlim(0, 0.5)
    plt.ylim(0.6, 1)

    plt.tight_layout()
    plt.show()

import numpy as np
import matplotlib.pyplot as plt

# 1000 sampled pairs, 10 equal-width bins; do not re-stretch the binned means.
PAIR_SAMPLES, N_BINS = 1000, 10
STRETCH_BINNED_Y_TO_01 = False

def _dtw(a, b):
    """DTW distance between ``a`` and ``b`` via ``dtw_sakoe_chiba``, as a float.

    Inputs must support ``.tolist()``; the band is ``cfg.DTW_BAND``.
    """
    as_lists = (a.tolist(), b.tolist())
    return float(dtw_sakoe_chiba(*as_lists, band=cfg.DTW_BAND))

def cosine_sim(a, b, eps=1e-12):
    """Cosine similarity in float32 with the norm product floored at ``eps``."""
    left = np.asarray(a, dtype=np.float32)
    right = np.asarray(b, dtype=np.float32)
    numerator = np.dot(left, right)
    denominator = max(np.linalg.norm(left) * np.linalg.norm(right), eps)
    return float(numerator / denominator)

def sample_pairs(n, m, rng):
    """Independently draw up to ``m`` (capped at ``n*(n-1)//2``) pairs of distinct indices."""
    candidates = np.arange(n)
    result = []
    remaining = min(m, n * (n - 1) // 2)
    while remaining > 0:
        first, second = rng.choice(candidates, size=2, replace=False)
        result.append((int(first), int(second)))
        remaining -= 1
    return result

def dtw_vs_cos_line(arr, pair_samples=1000, n_bins=10, seed=0, stretch_y=None):
    """Mean min-max-scaled cosine per equal-width bin of min-max-scaled DTW.

    Random record pairs are compared by DTW on ``"hist"`` (via ``_dtw``) and
    cosine similarity on ``"z"`` (via ``cosine_sim``). Both are min-max scaled
    to roughly [0, 1]; the scaled DTW range is cut into ``n_bins`` equal-width
    bins (the last one closed on the right) and scaled cosine is averaged per
    bin. Empty bins are dropped.

    Args:
        arr: Position-indexable records with ``"hist"`` and ``"z"`` entries.
        pair_samples: Number of random pairs to draw (see ``sample_pairs``).
        n_bins: Number of equal-width bins over [0, 1].
        seed: Seed for ``np.random.default_rng``.
        stretch_y: If true, min-max stretch the binned means to [0, 1] (only when
            more than one bin survives). ``None`` (default) reads the
            module-level ``STRETCH_BINNED_Y_TO_01`` flag, treating a missing
            flag as ``False``.

    Returns:
        ``None`` when ``arr`` has fewer than 25 records, otherwise ``(xs, ys)``
        float32 arrays of bin centres and mean scaled cosine.
    """
    if stretch_y is None:
        # This used to be a hard dependency on a notebook global (NameError if
        # absent); keep it as the default so existing calls behave the same.
        stretch_y = bool(globals().get("STRETCH_BINNED_Y_TO_01", False))

    if len(arr) < 25:
        return None

    rng = np.random.default_rng(seed)
    pairs = sample_pairs(len(arr), pair_samples, rng)

    dtw_list, cos_list = [], []
    for i, j in pairs:
        dtw_list.append(_dtw(arr[i]["hist"], arr[j]["hist"]))
        cos_list.append(cosine_sim(arr[i]["z"], arr[j]["z"]))

    dtw_arr = np.asarray(dtw_list, dtype=np.float32)
    cos_arr = np.asarray(cos_list, dtype=np.float32)

    # Min-max scale; 1e-9 guards against a zero range.
    dtw_n = (dtw_arr - dtw_arr.min()) / (dtw_arr.max() - dtw_arr.min() + 1e-9)
    cos_n = (cos_arr - cos_arr.min()) / (cos_arr.max() - cos_arr.min() + 1e-9)

    edges = np.linspace(0.0, 1.0, n_bins + 1)

    xs = []
    ys = []
    for b in range(n_bins):
        lo, hi = edges[b], edges[b + 1]
        if b == n_bins - 1:
            mask = (dtw_n >= lo) & (dtw_n <= hi)
        else:
            mask = (dtw_n >= lo) & (dtw_n < hi)

        xs.append(float(0.5 * (lo + hi)))
        ys.append(float(np.mean(cos_n[mask])) if mask.sum() > 0 else np.nan)

    xs = np.asarray(xs, dtype=np.float32)
    ys = np.asarray(ys, dtype=np.float32)

    keep = ~np.isnan(ys)
    xs, ys = xs[keep], ys[keep]

    if stretch_y and len(ys) > 1:
        ys = (ys - ys.min()) / (ys.max() - ys.min() + 1e-9)

    return xs, ys

import zlib

plt.rcParams.update({
    "font.size": 14,
    "font.weight": "bold",
    "axes.labelweight": "bold",
    "axes.titleweight": "bold",
})

COLORS = {
    "ESG": "#ff0055",
    "ENV": "#00cc66",
    "SOC": "#0066ff",
    "GOV": "#ff9900",
}

for ser in SERIES_LIST:
    # Stable per-series seed offset. Built-in hash() of a str is salted per
    # process (PYTHONHASHSEED), which made these plots irreproducible.
    ser_seed = cfg.SEED + zlib.crc32(str(ser).encode("utf-8")) % 10_000
    out = dtw_vs_cos_line(
        store[ser],
        pair_samples=PAIR_SAMPLES,
        n_bins=N_BINS,
        seed=ser_seed,
    )
    if out is None:
        print(f"{ser}: not enough data.")
        continue

    xs, ys = out

    plt.figure(figsize=(5.2, 3.6))
    plt.plot(xs, ys, marker="o", linewidth=2.5, color=COLORS.get(ser, None))

    plt.xlabel("DTW distance (scaled 0→1)", fontsize=14, fontweight="bold")
    plt.ylabel("Cosine similarity (scaled 0→1)", fontsize=14, fontweight="bold")
    plt.title(f"{ser}", fontsize=15, fontweight="bold")

    plt.grid(True, alpha=0.3)
    plt.xlim(0, 1)
    plt.ylim(0, 1)

    plt.tight_layout()
    plt.show()

import numpy as np
import matplotlib.pyplot as plt

# Larger run for the uniform-bin variant: 6000 pairs, 30 bins.
PAIR_SAMPLES, N_BINS = 6000, 30

def _dtw(a, b):
    """Float DTW distance from ``dtw_sakoe_chiba`` (band ``cfg.DTW_BAND``).

    ``a`` and ``b`` are converted with ``.tolist()`` before the call.
    """
    result = dtw_sakoe_chiba(
        a.tolist(),
        b.tolist(),
        band=cfg.DTW_BAND,
    )
    return float(result)

def cosine_sim(a, b, eps=1e-12):
    """Return cos(a, b) computed on float32 copies; the denominator is at least ``eps``."""
    x, y = np.asarray(a, dtype=np.float32), np.asarray(b, dtype=np.float32)
    nx, ny = np.linalg.norm(x), np.linalg.norm(y)
    safe = max(nx * ny, eps)
    return float(np.dot(x, y) / safe)

def sample_pairs(n, m, rng):
    """Draw ``min(m, n*(n-1)//2)`` index pairs; each pair has distinct members.

    Pairs are independent draws, so repeats across draws can happen.
    """
    pool = np.arange(n)
    count = min(m, n * (n - 1) // 2)
    out = []
    for _ in range(count):
        picked = rng.choice(pool, size=2, replace=False)
        out.append((int(picked[0]), int(picked[1])))
    return out

def dtw_vs_cos_line_uniform_bins(arr, pair_samples=6000, n_bins=30, seed=0):
    """Mean affinely-mapped cosine per equal-width bin of max-scaled DTW.

    DTW (on ``"hist"``) is divided by its maximum (no min subtraction), and
    cosine (on ``"z"``) is mapped from [-1, 1] to [0, 1] via ``(c + 1) / 2``.
    The scaled DTW range [0, 1] is split into ``n_bins`` equal-width bins (the
    last closed on the right); empty bins are skipped.

    Returns:
        ``None`` for fewer than 25 records, else ``(xs, ys)`` arrays of bin
        centres and per-bin mean mapped cosine.
    """
    if len(arr) < 25:
        return None

    rng = np.random.default_rng(seed)
    pairs = sample_pairs(len(arr), pair_samples, rng)

    dtw = np.asarray([_dtw(arr[i]["hist"], arr[j]["hist"]) for i, j in pairs], dtype=np.float32)
    cos = np.asarray([cosine_sim(arr[i]["z"], arr[j]["z"]) for i, j in pairs], dtype=np.float32)

    dtw01 = dtw / (dtw.max() + 1e-9)
    cos01 = (cos + 1.0) / 2.0

    edges = np.linspace(0.0, 1.0, n_bins + 1)

    centers, means = [], []
    for b in range(n_bins):
        lo, hi = edges[b], edges[b + 1]
        upper_ok = (dtw01 <= hi) if b == n_bins - 1 else (dtw01 < hi)
        selected = (dtw01 >= lo) & upper_ok
        if selected.sum() == 0:
            continue
        centers.append(float(0.5 * (lo + hi)))
        means.append(float(cos01[selected].mean()))

    return np.asarray(centers), np.asarray(means)

import zlib

plt.rcParams.update({
    "font.size": 14,
    "font.weight": "bold",
    "axes.labelweight": "bold",
    "axes.titleweight": "bold",
})

COLORS = {
    "ESG": "#ff0055",
    "ENV": "#00cc66",
    "SOC": "#0066ff",
    "GOV": "#ff9900",
}

for ser in SERIES_LIST:
    # Stable per-series seed offset. Built-in hash() of a str is salted per
    # process (PYTHONHASHSEED), which made these plots irreproducible.
    ser_seed = cfg.SEED + (zlib.crc32(str(ser).encode("utf-8")) % 10_000)
    out = dtw_vs_cos_line_uniform_bins(
        store[ser],
        pair_samples=PAIR_SAMPLES,
        n_bins=N_BINS,
        seed=ser_seed,
    )
    if out is None:
        print(f"{ser}: not enough data.")
        continue

    xs, ys = out

    plt.figure(figsize=(5.2, 3.6))
    plt.plot(xs, ys, marker="o", linewidth=2.5, color=COLORS.get(ser))

    plt.xlabel("DTW distance (scaled to [0,1])", fontsize=14, fontweight="bold")
    plt.ylabel("FACET cosine similarity (mapped to [0,1])", fontsize=14, fontweight="bold")
    plt.title(f"{ser}", fontsize=15, fontweight="bold")

    plt.grid(True, alpha=0.3)
    plt.xlim(0, 1)
    # NOTE(review): ylim(1.0, 0.6) inverts the y-axis (higher similarity is
    # drawn lower) and hides values below 0.6. Confirm this is intended.
    plt.ylim(1.0, 0.6)
    plt.tight_layout()
    plt.show()

import numpy as np
import matplotlib.pyplot as plt

# 6000 sampled pairs per series, 30 equal-width DTW bins.
PAIR_SAMPLES, N_BINS = 6000, 30

def _dtw(a, b):
    """Return ``float(dtw_sakoe_chiba(...))`` for list copies of ``a`` and ``b``.

    Uses ``cfg.DTW_BAND`` as the band; inputs need a ``.tolist()`` method.
    """
    left, right = a.tolist(), b.tolist()
    return float(dtw_sakoe_chiba(left, right, band=cfg.DTW_BAND))

def cosine_sim(a, b, eps=1e-12):
    """Float32 cosine similarity; a zero-norm input yields 0.0 thanks to the ``eps`` floor."""
    va = np.asarray(a, dtype=np.float32)
    vb = np.asarray(b, dtype=np.float32)
    scale = max(np.linalg.norm(va) * np.linalg.norm(vb), eps)
    similarity = np.dot(va, vb) / scale
    return float(similarity)

def sample_pairs(n, m, rng):
    """Return a list of independently drawn ``(i, j)`` pairs with ``i != j``.

    At most ``n * (n - 1) // 2`` pairs are drawn, even if ``m`` is larger.
    """
    idx = np.arange(n)
    n_pairs = min(m, n * (n - 1) // 2)
    drawn = (rng.choice(idx, size=2, replace=False) for _ in range(n_pairs))
    return [(int(p[0]), int(p[1])) for p in drawn]

def make_line_xy(arr, pair_samples=6000, n_bins=30, seed=0):
    """Per-bin mean of mapped cosine over equal-width bins of min-max-scaled DTW.

    DTW (on ``"hist"``) is min-max scaled to [0, 1); cosine (on ``"z"``) is
    mapped via ``(c + 1) / 2`` and clipped to [0, 1]. Bins are equal-width over
    [0, 1] (last bin closed on the right); empty bins are dropped.

    Returns:
        ``None`` for fewer than 25 records, else float32 ``(xs, ys)`` arrays
        of bin centres and per-bin mean mapped cosine.
    """
    if len(arr) < 25:
        return None

    rng = np.random.default_rng(seed)
    pairs = sample_pairs(len(arr), pair_samples, rng)

    dtw_vals = np.empty(len(pairs), dtype=np.float32)
    cos_vals = np.empty(len(pairs), dtype=np.float32)
    for t, (i, j) in enumerate(pairs):
        dtw_vals[t] = _dtw(arr[i]["hist"], arr[j]["hist"])
        cos_vals[t] = cosine_sim(arr[i]["z"], arr[j]["z"])

    lo_dtw, hi_dtw = float(dtw_vals.min()), float(dtw_vals.max())
    x01 = (dtw_vals - lo_dtw) / (hi_dtw - lo_dtw + 1e-9)
    y01 = np.clip((cos_vals + 1.0) / 2.0, 0.0, 1.0)

    edges = np.linspace(0.0, 1.0, n_bins + 1)

    def _bin_mean(b):
        # NaN marks an empty bin so it can be filtered out below.
        lo, hi = edges[b], edges[b + 1]
        upper = (x01 <= hi) if b == n_bins - 1 else (x01 < hi)
        mask = (x01 >= lo) & upper
        return float(np.nanmean(y01[mask])) if mask.any() else np.nan

    xs = np.asarray([0.5 * (edges[b] + edges[b + 1]) for b in range(n_bins)], dtype=np.float32)
    ys = np.asarray([_bin_mean(b) for b in range(n_bins)], dtype=np.float32)

    valid = ~np.isnan(ys)
    return xs[valid], ys[valid]

import zlib

plt.rcParams.update({
    "font.size": 14,
    "font.weight": "bold",
    "axes.labelweight": "bold",
    "axes.titleweight": "bold",
})

COLORS = {
    "ESG": "#ff0055",
    "ENV": "#00e676",
    "SOC": "#2979ff",
    "GOV": "#ff9100",
}

for ser in SERIES_LIST:
    # Stable per-series seed offset. Built-in hash() of a str is salted per
    # process (PYTHONHASHSEED), which made these plots irreproducible.
    ser_seed = cfg.SEED + (zlib.crc32(str(ser).encode("utf-8")) % 10_000)
    out = make_line_xy(
        store[ser],
        pair_samples=PAIR_SAMPLES,
        n_bins=N_BINS,
        seed=ser_seed,
    )
    if out is None:
        print(f"{ser}: not enough data.")
        continue

    xs, ys = out

    plt.figure(figsize=(5.0, 3.3))
    plt.plot(xs, ys, marker="o", linewidth=3.0, color=COLORS.get(ser, "#ff00ff"))

    plt.xlabel("DTW distance (min-max scaled to [0,1])", fontsize=14, fontweight="bold")
    plt.ylabel("Cosine similarity (scaled to [0,1])", fontsize=14, fontweight="bold")
    plt.title(ser, fontsize=15, fontweight="bold")

    plt.grid(True, alpha=0.3)
    plt.xlim(0.0, 1.0)
    plt.ylim(0.0, 1.0)
    plt.tight_layout()
    plt.show()

import numpy as np
import matplotlib.pyplot as plt

# Settings for the rescaled (min-max on both axes) variant.
PAIR_SAMPLES, N_BINS = 6000, 30

def _dtw(a, b):
    """DTW via ``dtw_sakoe_chiba`` with ``band=cfg.DTW_BAND``, returned as float.

    ``a`` and ``b`` must support ``.tolist()``.
    """
    sequences = [a.tolist(), b.tolist()]
    value = dtw_sakoe_chiba(sequences[0], sequences[1], band=cfg.DTW_BAND)
    return float(value)

def cosine_sim(a, b, eps=1e-12):
    """Cosine similarity on float32 arrays; the norm product is floored at ``eps``."""
    arr_a = np.asarray(a, dtype=np.float32)
    arr_b = np.asarray(b, dtype=np.float32)
    norms = np.linalg.norm(arr_a) * np.linalg.norm(arr_b)
    return float(np.dot(arr_a, arr_b) / max(norms, eps))

def sample_pairs(n, m, rng):
    """Return ``min(m, n*(n-1)//2)`` independently drawn distinct-index pairs."""
    universe = np.arange(n)
    total = min(m, n * (n - 1) // 2)
    pairs = []
    pairs.extend(
        tuple(map(int, rng.choice(universe, size=2, replace=False)))
        for _ in range(total)
    )
    return pairs

def minmax01(x, eps=1e-12):
    """Min-max scale ``x`` (as float32) and return ``(scaled, lo, hi)``.

    ``lo``/``hi`` are the original minimum and maximum as Python floats; ``eps``
    keeps a constant input from dividing by zero (it maps to all zeros).
    """
    arr = np.asarray(x, dtype=np.float32)
    lo = float(arr.min())
    hi = float(arr.max())
    span = hi - lo + eps
    return (arr - lo) / span, lo, hi

def dtw_vs_cos_line_rescaled(arr, pair_samples=6000, n_bins=30, seed=0):
    """Per-bin mean of min-max-scaled cosine over equal-width bins of scaled DTW.

    Both DTW (on ``"hist"``) and cosine (on ``"z"``) are min-max scaled with
    ``minmax01``. The scaled DTW range [0, 1] is cut into ``n_bins`` equal-width
    bins (the last one closed on the right); empty bins are dropped.

    Returns:
        ``None`` for fewer than 25 records, else
        ``(xs, ys, (dtw_min, dtw_max), (cos_min, cos_max))`` where the ranges
        are the raw extremes before scaling.
    """
    if len(arr) < 25:
        return None

    rng = np.random.default_rng(seed)
    pairs = sample_pairs(len(arr), pair_samples, rng)

    dtw_vals = np.array([_dtw(arr[i]["hist"], arr[j]["hist"]) for i, j in pairs], dtype=np.float32)
    cos_vals = np.array([cosine_sim(arr[i]["z"], arr[j]["z"]) for i, j in pairs], dtype=np.float32)

    x01, dtw_min, dtw_max = minmax01(dtw_vals)
    y01, cos_min, cos_max = minmax01(cos_vals)

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    xs = np.empty(n_bins, dtype=np.float32)
    ys = np.full(n_bins, np.nan, dtype=np.float32)  # NaN = empty bin

    last = n_bins - 1
    for b in range(n_bins):
        lo, hi = edges[b], edges[b + 1]
        xs[b] = 0.5 * (lo + hi)
        upper = (x01 <= hi) if b == last else (x01 < hi)
        in_bin = (x01 >= lo) & upper
        if in_bin.any():
            ys[b] = float(np.mean(y01[in_bin]))

    valid = ~np.isnan(ys)
    return xs[valid], ys[valid], (dtw_min, dtw_max), (cos_min, cos_max)

import zlib

plt.rcParams.update({
    "font.size": 14,
    "font.weight": "bold",
    "axes.labelweight": "bold",
    "axes.titleweight": "bold",
})

COLORS = {
    "ESG": "#ff0055",
    "ENV": "#00e676",
    "SOC": "#2979ff",
    "GOV": "#ff9100",
}

for ser in SERIES_LIST:
    # Stable per-series seed offset. Built-in hash() of a str is salted per
    # process (PYTHONHASHSEED), which made these plots irreproducible.
    ser_seed = cfg.SEED + (zlib.crc32(str(ser).encode("utf-8")) % 10_000)
    out = dtw_vs_cos_line_rescaled(
        store[ser],
        pair_samples=PAIR_SAMPLES,
        n_bins=N_BINS,
        seed=ser_seed,
    )

    if out is None:
        print(f"{ser}: not enough data.")
        continue

    xs, ys, (dtw_min, dtw_max), (cos_min, cos_max) = out

    plt.figure(figsize=(5.0, 3.3))
    plt.plot(xs, ys, marker="o", linewidth=3.0, color=COLORS.get(ser, "#ff00ff"))

    plt.xlabel("DTW distance (rescaled 0→1)", fontsize=14, fontweight="bold")
    plt.ylabel("Cosine similarity (rescaled 0→1)", fontsize=14, fontweight="bold")
    plt.title(f"{ser}", fontsize=15, fontweight="bold")

    plt.grid(True, alpha=0.3)
    plt.xlim(0.0, 1.0)
    plt.ylim(0.0, 1.0)

    # Raw ranges used for the rescaling, so the [0, 1] axes can be interpreted.
    print(f"{ser}  DTW range: [{dtw_min:.4f}, {dtw_max:.4f}]   Cos range: [{cos_min:.4f}, {cos_max:.4f}]")

    plt.tight_layout()
    plt.show()

import sys
import time

# The message below shows this loop exists only to keep a Colab session busy.
# Run unconditionally, it hangs every other environment (plain
# `python script.py`, CI, local Jupyter) forever after all work is done, so
# only enter it when the Colab runtime module is loaded.
if "google.colab" in sys.modules:
    while True:
        print("Keeping Colab alive...")
        time.sleep(60)
