import argparse
import math
import time
import logging
from collections.abc import Iterable, Sequence
from pathlib import Path
from typing import Any

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset

try:
    from diffusers.models.transformers import FluxTransformer2DModel
except ImportError as exc:  # pragma: no cover
    raise RuntimeError("diffusers>=0.29.0 is required to load FluxTransformer2DModel.") from exc


# Configure the root logger first so that INFO-level progress messages are
# emitted, then create the logger used throughout this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# LoRA
# ---------------------------------------------------------------------------
class LoRALinear(nn.Module):
    """Wrap an nn.Linear layer with a learnable low-rank update."""

    def __init__(
        self,
        base: nn.Linear,
        rank: int = 64,
        alpha: float = 128.0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.base = base
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank if rank > 0 else 0.0
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self.base.weight.requires_grad_(False)
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)

        if rank > 0:
            device = base.weight.device
            self.lora_a = nn.Parameter(torch.zeros(rank, base.in_features, device=device))
            self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank, device=device))
            nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
            nn.init.zeros_(self.lora_b)
        else:
            self.register_parameter("lora_a", None)
            self.register_parameter("lora_b", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.base(x)
        if self.rank == 0:
            return out

        shape = x.shape
        x_2d = self.dropout(x.reshape(-1, shape[-1]))
        lora = x_2d @ self.lora_a.t()
        lora = lora @ self.lora_b.t()
        lora = lora.view(*shape[:-1], self.base.out_features)
        return out + lora * self.scaling

    @property
    def weight(self) -> torch.Tensor:  # pragma: no cover
        return self.base.weight

    @property
    def bias(self) -> torch.Tensor | None:  # pragma: no cover
        return self.base.bias


def _locate_attr(module: nn.Module, path: str) -> tuple[nn.Module, str]:
    parts = path.split(".")
    parent = module
    for name in parts[:-1]:
        parent = getattr(parent, name)
    return parent, parts[-1]


def inject_lora(
    module: nn.Module,
    targets: Sequence[str],
    rank: int = 64,
    alpha: float = 128.0,
    dropout: float = 0.0,
) -> list[LoRALinear]:
    """Replace the ``nn.Linear`` layers named by ``targets`` with LoRA wrappers.

    Args:
        module: Root module; it is modified in place.
        targets: Dotted attribute paths, relative to ``module``, of the layers
            to wrap.
        rank: Rank passed to every ``LoRALinear``.
        alpha: Scaling numerator passed to every ``LoRALinear``.
        dropout: Dropout probability passed to every ``LoRALinear``.

    Returns:
        The newly created wrappers, in the order of ``targets``.

    Raises:
        TypeError: If a path resolves to something other than ``nn.Linear``.
    """
    adapters: list[LoRALinear] = []
    for dotted in targets:
        owner, attr = _locate_attr(module, dotted)
        layer = getattr(owner, attr)
        if not isinstance(layer, nn.Linear):
            raise TypeError(f"{dotted} is not an nn.Linear (got {type(layer).__name__})")
        adapter = LoRALinear(layer, rank=rank, alpha=alpha, dropout=dropout)
        # Swap the wrapper in where the plain layer used to be.
        setattr(owner, attr, adapter)
        adapters.append(adapter)
    return adapters


def freeze_except_lora(root: nn.Module) -> None:
    """Disable gradients everywhere under ``root`` except the LoRA factors.

    Only ``lora_a`` and ``lora_b`` of each ``LoRALinear`` are re-enabled.
    Iterating ``LoRALinear.parameters()`` would also reach the wrapped
    ``base`` layer (it is a registered sub-module) and would silently
    re-enable training of the pretrained weight and bias.

    Args:
        root: Module tree to update in place.
    """
    for param in root.parameters():
        param.requires_grad = False
    for sub in root.modules():
        if not isinstance(sub, LoRALinear):
            continue
        for factor in (sub.lora_a, sub.lora_b):
            # Factors are None when the adapter was built with rank <= 0.
            if factor is not None:
                factor.requires_grad = True


def iter_lora_parameters(root: nn.Module) -> Iterable[nn.Parameter]:
    """Yield the trainable low-rank factors of every ``LoRALinear`` in ``root``.

    Only ``lora_a`` and ``lora_b`` are considered. The wrapped base layer's
    weight and bias are never yielded, even if something re-enabled their
    ``requires_grad`` flag, so an optimizer built from this iterator cannot
    update the pretrained weights.

    Args:
        root: Module tree to search.

    Yields:
        Each existing LoRA factor whose ``requires_grad`` is true.
    """
    for sub in root.modules():
        if not isinstance(sub, LoRALinear):
            continue
        for factor in (sub.lora_a, sub.lora_b):
            # Factors are None when the adapter was built with rank <= 0.
            if factor is not None and factor.requires_grad:
                yield factor


# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
class FluxMultiRootDataset(Dataset):
    """
    Each __getitem__ returns a list of samples collected from multiple roots
    for the same step index.
    """

    def __init__(
        self,
        roots: Sequence[Path],
        block_idx: int,
        prev_block_idx: int,
        steps: Sequence[int],
        offset: int,
        meta_offset: int = 0,
        map_location: str | torch.device = "cpu",
        dtype: torch.dtype = torch.float32,
    ) -> None:
        self.roots = [Path(r) for r in roots]
        self.block_idx = block_idx
        self.prev_block_idx = prev_block_idx
        self.steps = list(steps)
        self.offset = offset
        self.meta_offset = meta_offset
        self.map_location = map_location
        self.dtype = dtype

    def _load_joint_kwargs(self, path: Path) -> dict[str, Any] | None:
        if path.exists():
            data = torch.load(path, map_location=self.map_location)
            if isinstance(data, dict):
                return data
        return None

    def _load_sample(self, root: Path, step: int) -> dict[str, Any] | None:
        src_step = step - self.offset
        meta_step = step - self.meta_offset

        x_path = root / f"singleblock_{self.prev_block_idx}_output_step_{src_step}.pt"
        y_path = root / f"singleblock_{self.block_idx}_output_step_{step}.pt"
        temb_path = root / f"temb_step_{meta_step}.pt"
        rot_path = root / f"image_rotary_emb_step_{meta_step}.pt"
        if not (x_path.exists() and y_path.exists() and temb_path.exists() and rot_path.exists()):
            return None

        try:
            return {
                "x_input": torch.load(x_path, map_location=self.map_location).to(self.dtype),
                "target": torch.load(y_path, map_location=self.map_location).to(self.dtype),
                "temb": torch.load(temb_path, map_location=self.map_location).to(self.dtype),
                "image_rotary_emb": torch.load(rot_path, map_location=self.map_location),
                "joint_attention_kwargs": self._load_joint_kwargs(root / f"joint_attention_kwargs_step_{meta_step}.pt"),
            }
        except Exception:
            return None

    def __getitem__(self, idx: int) -> list[dict[str, Any]]:
        step = self.steps[idx]
        batch_samples: list[dict[str, Any]] = []
        for root in self.roots:
            s = self._load_sample(root, step)
            if s is not None:
                s["step"] = step
                batch_samples.append(s)

        if not batch_samples:
            raise RuntimeError(f"No valid samples found for step {step}")
        return batch_samples

    def __len__(self) -> int:
        return len(self.steps)


# ---------------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------------
def find_linear_paths(module: nn.Module, prefix: str = "") -> list[str]:
    """List dotted paths of every ``nn.Linear`` nested below ``module``.

    Paths are built from child names joined with ``"."`` and are reported in
    depth-first order. ``prefix`` is prepended to each path when non-empty.
    ``module`` itself is never reported, only its descendants.
    """

    def walk(node: nn.Module, base: str):
        for child_name, child in node.named_children():
            dotted = child_name if not base else f"{base}.{child_name}"
            if isinstance(child, nn.Linear):
                yield dotted
            else:
                yield from walk(child, dotted)

    return list(walk(module, prefix))


def move_to(obj: Any, device: torch.device, dtype: torch.dtype) -> Any:
    """Recursively move tensors inside ``obj`` to ``device``.

    Floating-point tensors are also cast to ``dtype``; all other tensors keep
    their dtype. Dicts, lists and tuples (including namedtuples) are rebuilt
    with converted contents; dicts come back as plain ``dict``. Any other
    object is returned unchanged.

    Args:
        obj: Tensor, container of tensors, or arbitrary object.
        device: Target device for every tensor.
        dtype: Target dtype for floating-point tensors.

    Returns:
        A structure mirroring ``obj`` with its tensors moved.
    """
    if isinstance(obj, torch.Tensor):
        if torch.is_floating_point(obj):
            return obj.to(device=device, dtype=dtype)
        return obj.to(device=device)
    if isinstance(obj, dict):
        return {key: move_to(value, device, dtype) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        items = [move_to(value, device, dtype) for value in obj]
        if isinstance(obj, tuple) and hasattr(obj, "_fields"):
            # namedtuple constructors take positional fields, not one iterable.
            return type(obj)(*items)
        return type(obj)(items)
    return obj


def autocast_dtype(name: str) -> torch.dtype:
    """Translate a precision label (``fp16``, ``bf16``, ``fp32``) to a torch dtype.

    Raises:
        ValueError: If ``name`` is not one of the supported labels.
    """
    known = (
        ("fp16", torch.float16),
        ("bf16", torch.bfloat16),
        ("fp32", torch.float32),
    )
    for label, resolved in known:
        if name == label:
            return resolved
    raise ValueError(f"Unsupported precision {name}")


def save_predictions(
    block: nn.Module,
    loader: DataLoader,
    device: torch.device,
    dtype: torch.dtype,
    out_dir: Path,
    block_idx: int,
) -> None:
    """Run ``block`` over ``loader`` without gradients and save each prediction.

    Each item produced by ``loader`` may be a single sample dict or a list of
    sample dicts (the format the training loader yields: one dict per feature
    root for the same step). Indexing the list form directly with
    ``sample["step"]`` would raise ``TypeError``, so both forms are accepted.

    Output files are written to ``out_dir`` as
    ``pred_singleblock_{block_idx}_output_step_{step}.pt``. When an item holds
    several samples, ``_sample{i}`` is appended before ``.pt`` so they do not
    overwrite each other; ``i`` is the position within the item.

    The block is put in eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards.

    Args:
        block: Module called with ``hidden_states``, ``temb``,
            ``image_rotary_emb`` and ``joint_attention_kwargs`` keyword args.
        loader: Iterable of sample dicts or lists of sample dicts.
        device: Device the inputs are moved to.
        dtype: Dtype floating-point inputs are cast to.
        out_dir: Destination directory (created if missing).
        block_idx: Block index used in the output file names.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    was_training = block.training
    block.eval()
    try:
        with torch.no_grad():
            for batch in loader:
                samples = [batch] if isinstance(batch, dict) else list(batch)
                for pos, sample in enumerate(samples):
                    step = sample["step"]
                    x_input = move_to(sample["x_input"], device, dtype)
                    temb = move_to(sample["temb"], device, dtype)
                    image_rotary_emb = move_to(sample["image_rotary_emb"], device, dtype)
                    joint_kwargs = move_to(sample["joint_attention_kwargs"], device, dtype)

                    pred = block(
                        hidden_states=x_input,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                        joint_attention_kwargs=joint_kwargs,
                    )
                    # Keep the historical file name when there is no ambiguity.
                    suffix = "" if len(samples) == 1 else f"_sample{pos}"
                    torch.save(
                        pred.cpu(),
                        out_dir / f"pred_singleblock_{block_idx}_output_step_{step}{suffix}.pt",
                    )
    finally:
        block.train(was_training)


# ---------------------------------------------------------------------------
# Train
# ---------------------------------------------------------------------------
def train(args: argparse.Namespace) -> None:
    """Fit LoRA adapters on one Flux single transformer block or a stack of them.

    Steps: load the transformer from ``args.ckpt_dir``; select the block(s);
    wrap the chosen ``nn.Linear`` layers with ``LoRALinear``; minimise the MSE
    between the block output and the cached target tensors; optionally dump
    predictions; save the block ``state_dict`` to ``args.save_path``.

    Args:
        args: Parsed command-line options. Fields read here: ``ckpt_dir``,
            ``stack_blocks``, ``block_idx``, ``prev_block_idx``,
            ``target_modules``, ``rank``, ``alpha``, ``dropout``,
            ``feature_roots``, ``steps``, ``offset``, ``lr``, ``epochs``,
            ``precision``, ``log_every``, ``earlystop_window``,
            ``earlystop_threshold``, ``earlystop_patience``, ``eval_dir`` and
            ``save_path``.

    Raises:
        ValueError: If explicitly requested target modules do not exist in the
            selected block.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32
    amp_dtype = autocast_dtype(args.precision)
    # Mixed precision is only used on CUDA and only for reduced precisions.
    amp_enabled = amp_dtype != torch.float32 and device.type == "cuda"

    # Load model
    model = FluxTransformer2DModel.from_pretrained(args.ckpt_dir, subfolder="transformer")

    # Optional: stack multiple single blocks for joint training
    stack_indices = [int(i) for i in args.stack_blocks] if args.stack_blocks else []
    if stack_indices:

        class StackedBlocks(nn.Module):
            """Apply several blocks in sequence, sharing the conditioning inputs."""

            def __init__(self, blocks: list[nn.Module]) -> None:
                super().__init__()
                self.blocks = nn.ModuleList(blocks)

            def forward(
                self,
                hidden_states: torch.Tensor,
                temb: torch.Tensor,
                image_rotary_emb: Any,
                joint_attention_kwargs: Any,
            ) -> torch.Tensor:
                x = hidden_states
                for b in self.blocks:
                    x = b(
                        hidden_states=x,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                        joint_attention_kwargs=joint_attention_kwargs,
                    )
                return x

        blocks_to_stack = [model.single_transformer_blocks[i] for i in stack_indices]
        block: nn.Module = StackedBlocks(blocks_to_stack)
    else:
        block = model.single_transformer_blocks[args.block_idx]

    # Select target Linear modules
    available_paths = find_linear_paths(block)
    if args.target_modules == ["auto"]:
        # Prefer layers with these leaf names; otherwise wrap every Linear.
        preferred = {"linear1", "linear2", "lin"}
        target_paths = [p for p in available_paths if p.split(".")[-1] in preferred]
        target_paths = target_paths or available_paths
    else:
        target_paths = args.target_modules
        missing = [p for p in target_paths if p not in available_paths]
        if missing:
            raise ValueError(f"Requested target modules not found: {missing}")

    logger.info("LoRA target modules: %s", target_paths)

    inject_lora(block, target_paths, rank=args.rank, alpha=args.alpha, dropout=args.dropout)
    freeze_except_lora(block)
    block.to(device=device, dtype=dtype)

    # Dataset / Loader
    dataset = FluxMultiRootDataset(
        roots=[Path(p) for p in args.feature_roots],
        block_idx=args.block_idx,
        prev_block_idx=args.prev_block_idx,
        steps=args.steps,
        offset=args.offset,
        dtype=dtype,
    )
    # Each dataset item is already a list of per-root samples, so the collate
    # function just unwraps the single-item batch.
    loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda b: b[0])

    # The trainable set does not change during training; build it once and
    # reuse it for both the optimizer and gradient clipping.
    trainable = list(iter_lora_parameters(block))
    optimizer = torch.optim.AdamW(trainable, lr=args.lr, weight_decay=0.0)
    scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)

    logger.info(
        "Start training: %d step-indices × %d roots per step",
        len(dataset),
        len(args.feature_roots),
    )

    def sample_loss(sample: dict[str, Any]) -> torch.Tensor:
        """Move one sample to the device and return its MSE loss in float32."""
        x_input = move_to(sample["x_input"], device, dtype)
        target = move_to(sample["target"], device, dtype)
        temb = move_to(sample["temb"], device, dtype)
        image_rotary_emb = move_to(sample["image_rotary_emb"], device, dtype)
        joint_kwargs = move_to(sample["joint_attention_kwargs"], device, dtype)
        pred = block(
            hidden_states=x_input,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
            joint_attention_kwargs=joint_kwargs,
        )
        return F.mse_loss(pred.float(), target.float())

    # Early stop settings (kept generic)
    window = args.earlystop_window
    threshold = args.earlystop_threshold
    patience = args.earlystop_patience

    prev_window_avg: float | None = None
    stagnant_windows = 0
    early_stopped = False

    for epoch in range(args.epochs):
        block.train()
        total_loss = 0.0
        steps_taken = 0
        t0 = time.time()

        for samples in loader:
            optimizer.zero_grad(set_to_none=True)

            batch_loss = 0.0
            valid = 0

            for sample in samples:
                if amp_enabled:
                    with torch.cuda.amp.autocast(dtype=amp_dtype):
                        loss = sample_loss(sample)
                else:
                    loss = sample_loss(sample)

                if not torch.isfinite(loss):
                    logger.warning("Non-finite loss at epoch=%d step=%s; skipped.", epoch + 1, sample["step"])
                    continue

                batch_loss += loss
                valid += 1

            if valid == 0:
                continue

            batch_loss = batch_loss / valid

            scaler.scale(batch_loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(trainable, max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            total_loss += float(batch_loss.item())
            steps_taken += 1

        # Average over the batches that contributed a loss. Dividing by
        # len(loader) would bias the mean low whenever batches were skipped.
        avg_loss = total_loss / steps_taken if steps_taken else math.nan
        dt = time.time() - t0

        if (epoch + 1) % args.log_every == 0 or epoch == 0:
            logger.info(
                "Epoch %d/%d | mean_loss=%.6f | time=%.2fs",
                epoch + 1,
                args.epochs,
                avg_loss,
                dt,
            )

        # Every `window` epochs, compare this epoch's mean loss with the one
        # recorded at the previous checkpoint.
        if window > 0 and (epoch + 1) % window == 0:
            if prev_window_avg is not None:
                rel_change = abs(avg_loss - prev_window_avg) / max(prev_window_avg, 1e-12)
                if rel_change < threshold:
                    stagnant_windows += 1
                else:
                    stagnant_windows = 0

                if stagnant_windows >= patience:
                    logger.info("Early stopping triggered (loss change below threshold across windows).")
                    early_stopped = True
                    break
            prev_window_avg = avg_loss

    if early_stopped:
        logger.info("Training stopped before reaching max epochs.")

    # Optional: dump predictions for inspection
    if args.eval_dir is not None:
        save_predictions(
            block=block,
            loader=loader,
            device=device,
            dtype=dtype,
            out_dir=Path(args.eval_dir),
            block_idx=args.block_idx,
        )

    # Save
    save_path = Path(args.save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    payload: dict[str, Any] = {
        "state_dict": block.state_dict(),
        "block_idx": args.block_idx,
    }
    if stack_indices:
        payload["stack_blocks"] = stack_indices

    torch.save(payload, save_path)
    logger.info("Saved adapter checkpoint: %s", str(save_path))


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="LoRA training script for Flux single_transformer_blocks.")
    parser.add_argument("--ckpt-dir", type=str, required=True, help="Directory containing the original Flux weights.")
    parser.add_argument(
        "--feature-roots",
        type=str,
        nargs="+",
        required=True,
        help="List of feature directories.",
    )
    parser.add_argument("--steps", type=int, nargs="+", required=True, help="Training step indices.")
    parser.add_argument("--block-idx", type=int, default=37, help="Target single block index.")
    parser.add_argument("--prev-block-idx", type=int, default=36, help="Index that produced x_input.")
    parser.add_argument("--offset", type=int, default=1, help="x_input comes from step - offset.")
    parser.add_argument("--rank", type=int, default=64)
    parser.add_argument("--alpha", type=float, default=128.0)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--precision", choices=("fp16", "bf16", "fp32"), default="fp16")
    parser.add_argument("--target-modules", nargs="+", default=["auto"], help="Specific linear module paths to wrap.")
    parser.add_argument("--save-path", type=str, default="flux_block_lora.pt")
    parser.add_argument(
        "--stack-blocks",
        type=int,
        nargs="+",
        default=[37],
        help="Optional list of block indices to stack for joint training.",
    )
    parser.add_argument("--eval-dir", type=str, default=None, help="Optional directory to dump predictions.")

    # Logging / early stop knobs (kept generic and stable)
    parser.add_argument("--log-every", type=int, default=100)
    parser.add_argument("--earlystop-window", type=int, default=100)
    parser.add_argument("--earlystop-threshold", type=float, default=0.0)
    parser.add_argument("--earlystop-patience", type=int, default=2)
    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: read the CLI options, then run training with them.
    cli_args = parse_args()
    train(cli_args)
