from __future__ import annotations

import argparse
import random
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

# Put the parent of this script's directory at the front of sys.path (only if
# it is not already listed). Presumably this lets the project-local imports
# below resolve when the file is run directly as a script — verify layout.
_SCRIPT_DIR = Path(__file__).resolve().parent
_ROOT_DIR = _SCRIPT_DIR.parent
if str(_ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(_ROOT_DIR))

from common.io import maybe_mkdir
from common.logging import setup_logging
from inference.attention_features import features_from_attentions
from inference.model import Qwen8B, ModelInputs
from inference.sources import (
    Source,
    SourceSpec,
    VisionTokenMap,
    build_source_membership,
    build_sources,
)
from inference.text_targets import targets_in_span
from training.loss import pearson_corr_loss
from training.model import LinearHeadEstimator


@dataclass(frozen=True)
class OnlineTrainConfig:
    """Immutable settings for one online-training run on a single image/question.

    Instances are built from the command line by ``_parse_args`` and read by
    ``main``; the per-field notes below describe how ``main`` uses each value.
    """

    # Passed (as a string) to the Qwen8B backend; None when the flag is omitted.
    weights_dir: Optional[Path]
    # Image file that is opened with PIL and converted to RGB.
    image: Path
    # Question text handed to the backend when building model inputs.
    question: str
    # Directory that receives intermediate checkpoints and the final estimator.
    out_dir: Path
    # Number of training iterations.
    steps: int = 100
    # Learning rate of the Adam optimizer.
    lr: float = 1e-3
    # Number of random source masks sampled at every step.
    num_masks: int = 32
    # Probability that a given source is kept (mask value True) in a sample.
    alpha: float = 0.5
    # Generation budget forwarded to ``backend.generate``.
    max_new_tokens: int = 256
    # The next five fields are forwarded unchanged to ``SourceSpec``; their
    # exact semantics are defined in ``inference.sources``.
    source_mode: str = "block"
    block_h: int = 2
    block_w: int = 2
    num_regions: int = 48
    source_seed: int = 0
    # Forwarded as ``mode`` to ``targets_in_span``.
    target_mode: str = "whole"
    # Save an intermediate checkpoint every N steps; values <= 0 disable it.
    checkpoint_every: int = 50
    # Seeds the Python/NumPy/torch RNGs; step ``i`` samples masks with seed + i.
    seed: int = 42
    # Device string used for the backend and for the estimator.
    device: str = "cuda"
    # Forwarded to the Qwen8B backend; presumably allows fetching weights that
    # are not available locally — TODO confirm against ``inference.model``.
    allow_download: bool = False


def _parse_args() -> OnlineTrainConfig:
    """Parse the process command line into an ``OnlineTrainConfig``.

    ``--image``, ``--question`` and ``--out_dir`` are mandatory; every other
    flag falls back to the default registered below.
    """
    parser = argparse.ArgumentParser(
        description="End-to-end online training for vStream"
    )
    parser.add_argument("--weights_dir", type=Path, default=None)

    # Mandatory flags as (flag, converter) pairs.
    mandatory = (("--image", Path), ("--question", str), ("--out_dir", Path))
    for flag, convert in mandatory:
        parser.add_argument(flag, type=convert, required=True)

    # Optional flags as (flag, converter, default), kept in --help order.
    defaulted = (
        ("--steps", int, 100),
        ("--lr", float, 1e-3),
        ("--num_masks", int, 32),
        ("--alpha", float, 0.5),
        ("--max_new_tokens", int, 256),
        ("--source_mode", str, "block"),
        ("--block_h", int, 2),
        ("--block_w", int, 2),
        ("--num_regions", int, 48),
        ("--source_seed", int, 0),
        ("--target_mode", str, "whole"),
        ("--checkpoint_every", int, 50),
        ("--seed", int, 42),
        ("--device", str, "cuda"),
    )
    for flag, convert, fallback in defaulted:
        parser.add_argument(flag, type=convert, default=fallback)

    parser.add_argument("--allow_download", action="store_true")

    # Every argparse destination is named after the config field it fills and
    # already carries the field's type, so the namespace maps straight across.
    namespace = parser.parse_args()
    return OnlineTrainConfig(**vars(namespace))


def _set_seeds(seed: int) -> None:
    """Seed the Python, NumPy and torch random generators with ``seed``.

    When CUDA is available, all CUDA devices are seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def sample_masks(
    num_masks: int, num_sources: int, alpha: float, seed: int
) -> np.ndarray:
    """Draw a reproducible ``(num_masks, num_sources)`` array of keep flags.

    Each entry is drawn independently and is ``True`` with probability
    ``alpha``. The draw uses a private ``RandomState`` seeded with ``seed``,
    so the global NumPy generator is left untouched.
    """
    generator = np.random.RandomState(seed)
    outcomes = [False, True]
    weights = [1 - alpha, alpha]
    shape = (num_masks, num_sources)
    return generator.choice(outcomes, size=shape, p=weights)


def masks_to_sign(masks_bool: np.ndarray) -> np.ndarray:
    """Map a keep/drop mask to float32 signs: ``True`` -> +1.0, ``False`` -> -1.0.

    The input array is not modified.
    """
    # astype() hands back a fresh copy, so the in-place arithmetic below
    # never touches the caller's array.
    signs = masks_bool.astype(np.float32)
    signs *= 2.0
    signs -= 1.0
    return signs


def token_logit_scores(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return the log-probability that ``logits`` assign to each label.

    A log-softmax is taken over the last dimension of ``logits`` and the entry
    selected by the matching element of ``labels`` is picked out, so the
    result has the shape of ``labels``.
    """
    label_index = labels[..., None]
    per_class = torch.log_softmax(logits, dim=-1)
    picked = torch.gather(per_class, -1, label_index)
    return picked[..., 0]


def sequence_log_odds(token_log_odds: torch.Tensor) -> np.ndarray:
    """Sum per-token scores over the last dimension and return a NumPy array."""
    totals = torch.sum(token_log_odds, dim=-1)
    host_totals = totals.cpu()
    return host_totals.numpy()


def _save_checkpoint(
    estimator: Any, path: Path, step: int, *, finalized: bool = False
) -> None:
    """Write the estimator's CPU state dict plus bookkeeping fields to ``path``.

    The ``"finalized"`` key is only present (and ``True``) when ``finalized``
    is set, which matches the layout of the final checkpoint.
    """
    payload: dict[str, Any] = {
        "state_dict": {
            name: tensor.detach().cpu()
            for name, tensor in estimator.state_dict().items()
        },
        "num_features": int(estimator.num_features),
    }
    if finalized:
        payload["finalized"] = True
    payload["step"] = int(step)
    torch.save(payload, path)


def _ablated_attention_mask(
    base_mask: torch.Tensor,
    keep_row: np.ndarray,
    source_index: list[torch.Tensor],
) -> torch.Tensor:
    """Copy 1-D ``base_mask`` and zero the positions of every dropped source.

    ``keep_row[s]`` being falsy means source ``s`` is dropped; its sequence
    positions are given by ``source_index[s]``.
    """
    mask = base_mask.clone()
    dropped = [source_index[s] for s in np.flatnonzero(np.logical_not(keep_row))]
    if dropped:
        # One indexed write instead of one write per masked token.
        mask[torch.cat(dropped)] = 0
    return mask


def main() -> None:
    """Train a ``LinearHeadEstimator`` online for a single image/question pair.

    Pipeline, in order:

    1. Load the backend, generate a response and locate its thought span.
    2. Group the vision tokens into sources and pick target token ranges
       inside the thought span.
    3. Run one attention-returning forward pass over the full sequence and
       turn the attentions into per-target, per-source features.
    4. For ``cfg.steps`` iterations: sample random keep/drop masks over the
       sources, re-score the thought tokens with the dropped sources zeroed
       in the attention mask, and fit the estimator to the summed target
       log-probabilities with a Pearson-correlation loss.
    5. Save periodic checkpoints and a final ``estimator.pt`` to
       ``cfg.out_dir``.

    Raises:
        RuntimeError: If the generation contains no ``<think>`` token, has no
            valid thought span, or yields no targets inside that span.
    """
    cfg = _parse_args()
    logger = setup_logging("INFO")

    _set_seeds(cfg.seed)
    maybe_mkdir(cfg.out_dir)

    logger.info("Loading Qwen3-VL-8B-Thinking...")
    backend = Qwen8B(
        weights_dir=str(cfg.weights_dir) if cfg.weights_dir else None,
        device=cfg.device,
        allow_download=cfg.allow_download,
        attn_implementation="eager",
    )

    logger.info("Loading image: %s", cfg.image)
    # convert() returns an independent image, so the file handle can be
    # released as soon as the conversion is done.
    with Image.open(cfg.image) as raw_image:
        image = raw_image.convert("RGB")

    logger.info("Building inputs and generating...")
    inputs, prepared_image = backend.build_inputs_and_image(
        image=image, question=cfg.question, force_think=True
    )

    ids, prompt_len = backend.generate(
        inputs, max_new_tokens=cfg.max_new_tokens, stop_at_end_think=True
    )

    spans = backend.find_spans(ids, prompt_len=prompt_len)
    if spans.think_start is None:
        raise RuntimeError("No <think> token found in generation")

    thought_span = spans.thought_span(seq_len=len(ids))
    if thought_span is None:
        raise RuntimeError("No valid thought span")
    thought_start, thought_end = thought_span

    token_positions, grid_h, grid_w = backend.build_vision_token_map(ids, inputs=inputs)
    token_map = VisionTokenMap(
        token_positions=token_positions, grid_h=grid_h, grid_w=grid_w
    )

    source_spec = SourceSpec(
        mode=cfg.source_mode,
        block_h=cfg.block_h,
        block_w=cfg.block_w,
        num_regions=cfg.num_regions,
        seed=cfg.source_seed,
        image_path=str(cfg.image),
    )
    sources = build_sources(token_map, source_spec)
    num_sources = len(sources)
    membership = build_source_membership(
        sources, num_visual_tokens=len(token_positions)
    )

    logger.info(
        "Generation: %d tokens, thought span [%d, %d), %d sources",
        len(ids),
        thought_start,
        thought_end,
        num_sources,
    )

    targets = targets_in_span(
        tokenizer=backend.tokenizer,
        full_input_ids=ids,
        span_start=thought_start,
        span_end=thought_end,
        mode=cfg.target_mode,
    )
    if not targets:
        raise RuntimeError("No targets found in thought span")
    logger.info("Targets: %d", len(targets))

    # Re-run the model on prompt + generation as one fixed sequence.
    full_ids = torch.tensor([ids], device=backend.device, dtype=torch.long)
    full_attention_mask = torch.ones_like(full_ids, dtype=torch.long)

    new_inputs = ModelInputs(
        input_ids=full_ids,
        attention_mask=full_attention_mask,
        pixel_values=inputs.pixel_values,
        image_grid_thw=inputs.image_grid_thw,
    )

    position_ids = backend.compute_position_ids(
        input_ids=full_ids, attention_mask=full_attention_mask, inputs=new_inputs
    )

    logger.info("Extracting attention features...")
    with torch.no_grad():
        out = backend.forward_with_attentions(
            input_ids=full_ids,
            attention_mask=full_attention_mask,
            position_ids=position_ids,
            inputs=new_inputs,
        )

    feats = (
        features_from_attentions(
            attentions=out.attentions,
            target_token_ranges=targets,
            vision_token_positions=token_positions,
            source_membership=membership,
        )
        .cpu()
        .numpy()
    )  # (T, S, F)

    # The attention tensors are no longer needed once the features exist;
    # drop them before the ablation passes start.
    del out
    torch.cuda.empty_cache()

    logger.info("Features shape: %s", feats.shape)
    num_features = int(feats.shape[2])

    logger.info("Preparing inputs_embeds for ablation...")
    with torch.no_grad():
        cached_inputs_embeds = backend.prepare_inputs_embeds(new_inputs)

    # Sequence positions covered by each source, built once as index tensors
    # so that every ablation mask needs a single indexed write.
    base_mask = full_attention_mask[0]
    source_index = [
        torch.tensor(
            [int(token_positions[i]) for i in src.token_indices],
            dtype=torch.long,
            device=base_mask.device,
        )
        for src in sources
    ]

    estimator = LinearHeadEstimator(num_features=num_features).to(cfg.device)
    opt = torch.optim.Adam(estimator.parameters(), lr=cfg.lr)

    logger.info("Starting online training for %d steps...", cfg.steps)
    t0 = time.perf_counter()

    # Loop invariants: the scored span and its labels never change.
    thought_len = thought_end - thought_start
    labels_slice = full_ids[:, thought_start:thought_end]

    pbar = tqdm(range(cfg.steps), desc="train")
    for step in pbar:
        # A distinct but reproducible set of masks for every step.
        mask_seed = cfg.seed + step
        masks_bool = sample_masks(cfg.num_masks, num_sources, cfg.alpha, mask_seed)
        masks_sign = masks_to_sign(masks_bool)

        token_log_odds_all = np.zeros((cfg.num_masks, thought_len), dtype=np.float32)

        for m_idx in range(cfg.num_masks):
            attn_mask_batch = _ablated_attention_mask(
                base_mask, masks_bool[m_idx], source_index
            ).unsqueeze(0)
            inputs_embeds_batch = cached_inputs_embeds.clone()

            with torch.no_grad():
                abl_out = backend.forward_ablation(
                    input_ids=full_ids,
                    attention_mask=attn_mask_batch,
                    position_ids=position_ids,
                    inputs_embeds=inputs_embeds_batch,
                )

            # Logits at position p predict token p + 1, hence the shift by one.
            # NOTE(review): the log-softmax runs in whatever dtype the backend
            # returns; if that is half precision, casting to float32 first may
            # be preferable — confirm against the backend.
            logits_slice = abl_out.logits[:, thought_start - 1 : thought_end - 1, :]
            log_odds = token_logit_scores(logits_slice, labels_slice)
            token_log_odds_all[m_idx] = log_odds[0].detach().cpu().float().numpy()

        # Sum the per-token scores over each target's token range.
        outputs_by_target = np.zeros((len(targets), cfg.num_masks), dtype=np.float32)
        for t_idx, (t0_tok, t1_tok) in enumerate(targets):
            rs = t0_tok - thought_start
            re = t1_tok - thought_start
            cur = torch.tensor(token_log_odds_all[:, rs:re], dtype=torch.float32)
            outputs_by_target[t_idx] = sequence_log_odds(cur)

        # Signed sum of source features under each mask: (T, M, F).
        feats_by_mask = np.einsum("ms,tsf->tmf", masks_sign, feats).astype(np.float32)

        x = torch.tensor(feats_by_mask, device=cfg.device, dtype=torch.float32)
        y = torch.tensor(outputs_by_target, device=cfg.device, dtype=torch.float32)

        opt.zero_grad(set_to_none=True)
        pred = estimator(x)
        loss = pearson_corr_loss(pred, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(estimator.parameters(), 1.0)
        opt.step()

        loss_val = float(loss.item())
        pbar.set_postfix(loss=f"{loss_val:.4f}")

        if cfg.checkpoint_every > 0 and (step + 1) % cfg.checkpoint_every == 0:
            ckpt_path = cfg.out_dir / f"estimator-step{step + 1}.pt"
            _save_checkpoint(estimator, ckpt_path, step + 1)
            logger.info("Checkpoint: %s", ckpt_path)

    estimator.finalize()
    final_path = cfg.out_dir / "estimator.pt"
    _save_checkpoint(estimator, final_path, cfg.steps, finalized=True)

    elapsed = time.perf_counter() - t0
    logger.info("Training complete. Saved: %s (%.1fs)", final_path, elapsed)


# Script entry point: run the training pipeline only when executed directly,
# not when the module is imported.
if __name__ == "__main__":
    main()
