import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

from enums import LossFunction


def _get_scheduler(opt_config, optimizer):
    """Build the learning-rate scheduler selected by ``opt_config.lr_scheduler``.

    Args:
        opt_config: Object exposing ``max_steps``, ``warmup_iters``,
            ``lr_scheduler`` and, when warm-up is enabled,
            ``warmup_start_factor``. ``warmup_iters`` is a percentage of
            ``max_steps`` (e.g. 10 means 10%).
        optimizer: The optimizer whose learning rate is scheduled.

    Returns:
        * ``"cosine"`` -> ``CosineAnnealingLR`` decay,
        * ``"linear"`` -> ``LinearLR`` decay from factor 1.0 to 0.0,
        * anything else -> ``None`` (no scheduling).

        When the warm-up step count is positive, a linear warm-up from
        ``warmup_start_factor`` to 1.0 is chained in front of the decay with
        ``SequentialLR``. When the warm-up covers the whole run (no decay
        steps left), only the warm-up scheduler is returned.
    """
    total_steps = opt_config.max_steps
    # warmup_iters is interpreted as a percentage of the total step count.
    if opt_config.warmup_iters > 0:
        warmup_steps = int(total_steps * opt_config.warmup_iters / 100.0)
    else:
        warmup_steps = 0

    kind = opt_config.lr_scheduler
    if kind not in ("cosine", "linear"):
        return None

    def _make_decay(num_steps):
        # Decay phase spanning ``num_steps`` scheduler steps.
        if kind == "cosine":
            return CosineAnnealingLR(optimizer, T_max=num_steps)
        return LinearLR(
            optimizer,
            start_factor=1.0,
            end_factor=0.0,
            total_iters=num_steps,
        )

    if warmup_steps <= 0:
        return _make_decay(total_steps)

    # Warm-up is created before the decay scheduler, as scheduler constructors
    # touch the optimizer's learning rate.
    warmup_scheduler = LinearLR(
        optimizer,
        start_factor=opt_config.warmup_start_factor,
        end_factor=1.0,
        total_iters=warmup_steps,
    )

    decay_steps = total_steps - warmup_steps
    if decay_steps <= 0:
        # A zero-length decay phase would divide by zero (T_max / total_iters
        # == 0) the moment SequentialLR switches to it, so when warm-up spans
        # the entire run we schedule the warm-up only.
        return warmup_scheduler

    return SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, _make_decay(decay_steps)],
        milestones=[warmup_steps],
    )


def _log_s_regularizer(transforms, reg_lambda):
    """Return the summed ``log_s`` penalty over all ``transforms``.

    For each transform, every parameter whose name contains ``"log_s"`` is
    flattened and concatenated; the transform contributes
    ``reg_lambda * (sum of all entries) ** 2``. Transforms without such
    parameters contribute nothing.
    """
    penalty = 0
    for transform in transforms:
        flat_log_s = [
            p.view(-1)
            for name, p in transform.named_parameters()
            if "log_s" in name
        ]
        if flat_log_s:
            penalty = penalty + reg_lambda * torch.cat(flat_log_s).sum() ** 2
    return penalty


def _print_r1_diagnostics(transform, total_loss, block_size=32):
    """Print the loss plus orthogonality / block-structure stats of ``transform.matrix()``."""
    with torch.no_grad():
        R1 = transform.matrix().to(torch.float32)
        gram = R1.T @ R1
        # Deviation of R1^T R1 from the identity, shared by both norms below.
        deviation = gram - torch.eye(R1.shape[0], device=R1.device)
        ort_dist = torch.norm(deviation)
        spec = torch.linalg.norm(deviation, ord=2)
        det = torch.det(gram)

        # Zero out entries whose row and column fall in the same
        # ``block_size``-wide index block; what remains is the off-block part.
        block_ids = torch.arange(R1.shape[0], device=R1.device) // block_size
        off_block_mask = block_ids[:, None] != block_ids[None, :]
        R1_off = R1 * off_block_mask
        off_block_norm = torch.linalg.norm(R1_off)
        off_block_norm_2 = torch.linalg.norm(R1_off, ord=2)

        print(f"loss: {total_loss.item()}, "
              f"R1 ortho dist: {ort_dist.item():.6f}, "
              f"spec: {spec.item():.6f}, det(R1): det(R1^T R1): {det.item():.6f}, "
              f"off norm: {off_block_norm.item():.6f}, "
              f"off norm 2: {off_block_norm_2.item():.6f}")


def train_transform_matrix(model, shared_learned_transforms, opt_config, dataloader, loss_fn, device):
    """Optimize the parameters of ``shared_learned_transforms`` with AdamW.

    Only the transforms' parameters are handed to the optimizer. Each step
    draws one batch (restarting ``dataloader`` when it is exhausted), runs
    ``model(batch)``, evaluates ``loss_fn(outputs.logits, batch)``, adds the
    ``log_s`` regularizer, prints diagnostics, and takes an optimizer (and
    scheduler, if any) step.

    Args:
        model: Callable on a batch; its result must expose ``.logits``.
        shared_learned_transforms: Sequence of transform modules. The first
            one forms its own parameter group and is the one whose
            ``matrix()`` is reported in the diagnostics.
        opt_config: Object exposing ``learning_rate``, ``weight_decay``,
            ``betas``, ``max_steps``, ``loss_function``, ``reg_lambda`` and
            the fields read by ``_get_scheduler``.
        dataloader: Re-iterable source of batches; each batch must support
            ``.to(device)``.
        loss_fn: Callable ``(logits, batch) -> loss``. For the distillation
            loss functions it must also provide ``set_batch``,
            ``compute_float_outputs`` and ``remove_hooks``.
        device: Target device for each batch.

    Returns:
        None. The transforms are updated in place.
    """
    # Make sure transform params are float32
    for s in shared_learned_transforms:
        s.to(torch.float32)

    # Parameter groups: R1 (first transform) vs R2+ (all remaining transforms)
    r1_params = list(shared_learned_transforms[0].parameters()) if len(shared_learned_transforms) > 0 else []
    r2_params = [p for s in shared_learned_transforms[1:] for p in s.parameters()]

    param_groups = []
    if r1_params:
        param_groups.append({"params": r1_params, "lr": opt_config.learning_rate})
    if r2_params:
        param_groups.append({"params": r2_params, "lr": opt_config.learning_rate})

    optimizer = AdamW(
        param_groups,
        weight_decay=opt_config.weight_decay,
        betas=opt_config.betas,
    )
    scheduler = _get_scheduler(opt_config, optimizer)

    # Both distillation losses need the same per-batch preparation and the
    # same hook cleanup afterwards.
    uses_distillation = opt_config.loss_function in (
        LossFunction.UNEMBED_DISTILLATION.value,
        LossFunction.FLAT_Q_DISTILLATION.value,
    )

    step = 0
    dataloader_iter = iter(dataloader)

    try:
        while step < opt_config.max_steps:
            print(f"Transform optimization step {step}...")
            try:
                batch = next(dataloader_iter)
            except StopIteration:
                # Dataloader exhausted: start another pass over it.
                dataloader_iter = iter(dataloader)
                batch = next(dataloader_iter)

            batch = batch.to(device)

            optimizer.zero_grad()

            if uses_distillation:
                loss_fn.set_batch(batch)
                loss_fn.compute_float_outputs()

            # optional: if your transform has a cache API
            for s in shared_learned_transforms:
                if hasattr(s, "reset_cache"):
                    s.reset_cache()

            outputs = model(batch)
            loss = loss_fn(outputs.logits, batch)
            total_loss = loss + _log_s_regularizer(shared_learned_transforms, opt_config.reg_lambda)

            # NOTE(review): ``matrix`` is invoked as a callable in the
            # diagnostics, so this check only skips them if the attribute
            # itself is set to None -- confirm that is the intent.
            if shared_learned_transforms[0].matrix is not None:
                _print_r1_diagnostics(shared_learned_transforms[0], total_loss)

            total_loss.backward()

            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            step += 1
    finally:
        # Cleanup hooks -- also when a step raised, so they are never left
        # attached after a failed run.
        if uses_distillation:
            loss_fn.remove_hooks()

def print_optimizer_param_norms(optimizer: torch.optim.Optimizer) -> None:
    """Print norm statistics for the first four parameters of each param group.

    For every reported parameter one line is printed with the group index,
    the group's ``lr``, the parameter index and shape, its L2 norm, its
    largest absolute entry, and the L2 norm of its gradient (``None`` when
    the parameter has no gradient). Each line also carries ``U_preview``: the
    first ten entries of row 0 of the group's fourth parameter when that
    parameter has at least two dimensions, otherwise ``None``.

    Side effect: sets torch's global print options to ``precision=7,
    sci_mode=False`` and does not restore them.
    """
    for group_idx, group in enumerate(optimizer.param_groups):
        members = group["params"]
        # Only indices 0..3 are ever reported.
        for param_idx, tensor in enumerate(members[:4]):
            if tensor is None:
                continue

            with torch.no_grad():
                values = tensor.detach().float()
                l2_norm = values.norm(p=2).item()
                abs_peak = values.abs().max().item()
                grad = tensor.grad
                if grad is None:
                    grad_l2 = None
                else:
                    grad_l2 = grad.detach().float().norm(p=2).item()

            torch.set_printoptions(precision=7, sci_mode=False)

            preview = None
            if len(members) >= 4 and members[3].ndim >= 2:
                preview = members[3][0, :10]

            grad_text = "None" if grad_l2 is None else f"{grad_l2:.6g}"
            print(
                f"group={group_idx} lr={group.get('lr', None)} param={param_idx} shape={tuple(tensor.shape)} "
                f"param_norm={l2_norm:.6g} param_max={abs_peak:.6g} grad_norm={grad_text} "
                f"U_preview={preview}"
            )