from typing import Iterable, Dict

from reconstruction.losses import *
from reconstruction.model import TimeSeriesPretrainer




class TrainConfig:
    """Settings read by the training loop.

    Attributes (stored in ``__slots__``, so no other attributes can be set):
        task_mix_recon: value the training loop compares a uniform random
            draw against when picking the plain-reconstruction task;
            coerced to ``float``.
        cutfill_cfg: configuration object for the cut-and-fill task; a
            default-constructed ``CutFillConfig`` when ``None`` is passed.
        grad_clip: gradient clipping threshold coerced to ``float``, or
            ``None`` to leave it unset.
    """

    __slots__ = ("task_mix_recon", "cutfill_cfg", "grad_clip")

    def __init__(self, task_mix_recon=0.5, cutfill_cfg=None, grad_clip=1.0):
        self.task_mix_recon = float(task_mix_recon)

        # Only ``None`` triggers the default; falsy-but-valid objects are kept.
        if cutfill_cfg is None:
            cutfill_cfg = CutFillConfig()
        self.cutfill_cfg = cutfill_cfg

        if grad_clip is not None:
            grad_clip = float(grad_clip)
        self.grad_clip = grad_clip

    def __repr__(self):
        # Fields are rendered in slot order using their ``str()`` form.
        rendered = ", ".join(
            f"{field}={getattr(self, field)}" for field in TrainConfig.__slots__
        )
        return f"TrainConfig({rendered})"


def train_epoch(model: TimeSeriesPretrainer,
                loader: Iterable,
                optimizer: torch.optim.Optimizer,
                device: str,
                cfg: TrainConfig,
                subject: bool) -> Dict[str, float]:
    """Train ``model`` for one pass over ``loader`` and report mean losses.

    For every batch one of two tasks is picked at random: plain
    reconstruction when a uniform draw in ``[0, 1)`` is below
    ``cfg.task_mix_recon``, cut-and-fill otherwise. Setting
    ``cfg.task_mix_recon`` to ``1.0`` therefore always selects plain
    reconstruction and ``0.0`` always selects cut-and-fill.

    Args:
        model: network to optimise; called as ``model(x)`` or, when
            ``subject`` is true, ``model(x, x_sub)``.
        loader: iterable of batches. Each batch unpacks to
            ``(x_orig, x_sub, y)`` when ``subject`` is true and to
            ``(x_orig, y)`` otherwise; ``y`` is ignored here.
        optimizer: optimiser stepped once per batch.
        device: target passed to ``Tensor.to`` for the batch tensors.
        cfg: supplies ``task_mix_recon``, ``cutfill_cfg`` and ``grad_clip``.
        subject: whether batches carry the extra ``x_sub`` tensor.

    Returns:
        Dict with ``"loss"`` (mean over all steps), ``"recon_loss"`` and
        ``"cutfill_loss"`` (means over the steps that ran that task, ``0.0``
        if it never ran) and ``"steps"`` (number of batches processed).
    """
    model.train()
    total_recon, total_cutfill = 0.0, 0.0
    n_recon, n_cutfill = 0, 0

    for batch in loader:
        if subject:
            x_orig, x_sub, _ = batch
            extra_inputs = (x_sub.to(device),)
        else:
            x_orig, _ = batch
            extra_inputs = ()
        x_orig = x_orig.to(device)

        # Honour the configured task mix; the draw must not be overridden,
        # otherwise the cut-and-fill branch can never run.
        do_recon = bool(torch.rand(()) < cfg.task_mix_recon)
        optimizer.zero_grad(set_to_none=True)

        if do_recon:
            # Plain reconstruction: the model sees the unmodified input.
            x_hat = model(x_orig, *extra_inputs)
            # NOTE(review): argument order here is (prediction, target) while
            # the cut-and-fill calls pass (target, prediction) -- confirm
            # against reconstruction_loss's signature.
            loss = reconstruction_loss(x_hat, x_orig)
        else:
            # Cut-and-fill: the model sees the modified input and is scored
            # against the original.
            x_in, mask, _ = apply_cut_and_fill(x_orig, cfg.cutfill_cfg)
            x_hat = model(x_in, *extra_inputs)
            if cfg.cutfill_cfg.loss_on_full:
                loss = reconstruction_loss(x_orig, x_hat)
            else:
                loss = reconstruction_loss(x_orig, x_hat, mask=mask)

        loss.backward()
        if cfg.grad_clip is not None:
            # Apply the configured gradient-norm clipping before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        optimizer.step()

        # Read the scalar once per step.
        step_loss = loss.item()
        if do_recon:
            total_recon += step_loss
            n_recon += 1
        else:
            total_cutfill += step_loss
            n_cutfill += 1

    steps = n_recon + n_cutfill
    return {
        "loss": (total_recon + total_cutfill) / max(1, steps),
        "recon_loss": total_recon / max(1, n_recon),
        "cutfill_loss": total_cutfill / max(1, n_cutfill),
        "steps": steps,
    }


@torch.no_grad()
def evaluate(model: TimeSeriesPretrainer, loader: Iterable, device: str,
             cutfill_eval_cfg: CutFillConfig, subject: bool) -> Dict[str, float]:
    """Evaluate ``model`` on ``loader`` for both pre-training tasks.

    Runs with gradients disabled and the model in eval mode. Every batch is
    scored twice: once reconstructing the unmodified input, and once
    reconstructing from the cut-and-fill version of the input with the loss
    restricted by the mask that ``apply_cut_and_fill`` returns.

    Args:
        model: network to evaluate; called as ``model(x)`` or, when
            ``subject`` is true, ``model(x, x_sub)``.
        loader: iterable of batches. Each batch unpacks to
            ``(x_orig, x_sub, y)`` when ``subject`` is true and to
            ``(x_orig, y)`` otherwise; ``y`` is ignored here.
        device: target passed to ``Tensor.to`` for the batch tensors.
        cutfill_eval_cfg: configuration forwarded to ``apply_cut_and_fill``.
        subject: whether batches carry the extra ``x_sub`` tensor.

    Returns:
        Dict with ``"val_recon_mse"`` and ``"val_cutfill_mse"``, each the mean
        per-batch loss (``0.0`` when ``loader`` is empty).
    """
    model.eval()
    total_recon, total_cutfill = 0.0, 0.0
    n_batches = 0

    for batch in loader:
        if subject:
            x_orig, x_sub, _ = batch
            extra_inputs = (x_sub.to(device),)
        else:
            x_orig, _ = batch
            extra_inputs = ()
        x_orig = x_orig.to(device)

        # Full reconstruction error on the unmodified input.
        x_hat = model(x_orig, *extra_inputs)
        total_recon += reconstruction_loss(x_orig, x_hat).item()

        # Cut-and-fill error: the model must be fed the modified input
        # (x_in), not the original, or this metric just repeats the plain
        # reconstruction above restricted to the mask.
        x_in, mask, _ = apply_cut_and_fill(x_orig, cutfill_eval_cfg)
        x_hat_fill = model(x_in, *extra_inputs)
        total_cutfill += reconstruction_loss(x_orig, x_hat_fill, mask=mask).item()

        n_batches += 1

    denom = max(1, n_batches)
    return {
        "val_recon_mse": total_recon / denom,
        "val_cutfill_mse": total_cutfill / denom,
    }
