from __future__ import annotations

from copy import deepcopy
from typing import Iterable, Optional
import torch
from tqdm.auto import tqdm

from src.model.config import Transformer
from src.run.eval import eval_loss
from src.run.config import RunConfig
from src.run.utils import get_batch, get_select_mask
from src.run.logger import get_tqdm_kwargs
from src.run.dataloader import InterleavedDataLoader
from src.run.distributed import barrier, get_raw_model, reduce_tensor

def _averaged_val_loss(
    model: Transformer,
    config: RunConfig,
    data_labels: list[str],
    expert_labels: Optional[Iterable[str]],
    num_val_batches: int | None,
) -> float:
    """Return the validation loss averaged over ``data_labels``, reduced across ranks.

    The per-label losses from ``eval_loss`` are averaged locally, then passed
    through ``reduce_tensor`` so that every rank makes the same early-stopping
    decision from the same number.
    """
    with torch.inference_mode():
        per_label = [
            eval_loss(
                model,
                config,
                data_label=label,
                expert_labels=expert_labels,
                num_batches=num_val_batches,
            )
            for label in data_labels
        ]
        local_mean = sum(per_label) / len(per_label)
        # Synchronize val_loss across GPUs so all make the same early stopping decision
        return reduce_tensor(torch.tensor(local_mean, device=config.device)).item()


def do_finetune(
    model: Transformer,
    config: RunConfig,
    data_labels: Iterable[str],
    expert_labels: Optional[Iterable[str]] = None,
    lr: float = 5e-5,
    num_batches: int | None = 75,
    num_val_batches: int | None = 10,
    early_stop: bool = True,
    single_batch: bool = True,
    patience: int = 15,
) -> tuple[Transformer, dict]:
    """
    Finetune a model on specified data labels with optional expert routing.

    Gradients are accumulated over ``len(data_labels) * config.accumulation_steps``
    batches per optimizer step. When ``early_stop`` is set, validation runs after
    every optimizer step and the best-scoring weights are restored at the end.

    Args:
        model: Model to finetune (may be a distributed wrapper; the raw model is
            only used to read ``model_type``).
        config: Run configuration (loaders, logger, aux labels, accumulation
            steps, device).
        data_labels: Labels to train on; each must be in ``config.aux_labels``
            or be ``"core"``. Any iterable is accepted; it is materialized once.
        expert_labels: Expert labels for routing (routed models only)
        lr: Learning rate
        num_batches: Number of batches to train for; ``None`` means one full
            pass of the interleaved loader.
        num_val_batches: Batches per label passed to ``eval_loss`` during
            early-stopping validation.
        early_stop: Whether to use early stopping
        single_batch: If True, train repeatedly on one stored batch per label,
            cycling through the labels; otherwise draw from the interleaved loader.
        patience: Early stopping patience; training stops once the number of
            consecutive non-improving validations exceeds this value.

    Returns:
        A ``(model, state)`` tuple; ``state`` is currently always an empty dict.

    Raises:
        ValueError: If ``data_labels`` is empty.
    """

    # unpack run config
    loaders = config.loaders
    logger = config.logger
    labels = config.aux_labels + ["core"]
    acc_steps = config.accumulation_steps

    logger.info(f"---- Begin FT | Expert Labels: {expert_labels} | Data Labels: {data_labels} ----")

    # Materialize once: the labels are iterated several times and indexed below,
    # which a one-shot iterator (e.g. a generator) would not survive.
    data_labels = list(data_labels)
    if not data_labels:
        raise ValueError("data_labels must contain at least one label")
    assert all(label in labels for label in data_labels), f"all data labels must be in {labels}"

    # Routing inputs are loop-invariant, so resolve them once.
    raw_model = get_raw_model(model)
    is_routed = raw_model.model_type == "routed"
    routing_labels = ["core"] + config.aux_labels  # select-mask order: "core" first

    model.train()
    opt = torch.optim.AdamW(model.parameters(), lr=lr, fused=True)
    # Clear gradients possibly left on the parameters by earlier code so the
    # first accumulation window only contains this run's gradients.
    opt.zero_grad(set_to_none=True)

    # Loader for data labels
    loader = InterleavedDataLoader([loaders[label]["train"] for label in data_labels], weighted=False)
    loader.reset()

    total_loss = 0.0  # sum of *unscaled* per-batch losses
    best_val_loss = float("inf")
    best_state_dict = {}
    no_improvement_steps = 0
    if early_stop:
        best_state_dict = deepcopy(model.state_dict())

    # Store one batch per individual loader when using single_batch mode
    first_batches = {}
    if single_batch:
        for label in data_labels:
            label_loader = loaders[label]["train"]
            label_loader.reset()
            first_batches[label] = get_batch(label_loader)

    train_batches = len(loader) if num_batches is None else num_batches

    # Effective accumulation: accumulate over data_labels * acc_steps batches
    eff_acc_steps = len(data_labels) * acc_steps

    pbar = tqdm(range(train_batches), **get_tqdm_kwargs(logger, desc="FT", ncols=150))
    batch_idx = -1  # stays -1 when the loop body never runs (train_batches <= 0)
    batch_loss = 0.0  # mean loss over the current accumulation window
    # Interleaved batches do not belong to a single label; only single_batch
    # mode overwrites this with the label actually being trained on.
    current_label = "interleaved"
    for batch_idx in pbar:
        if batch_idx % eff_acc_steps == 0:
            batch_loss = 0.0

        if single_batch:
            # Cycle through the stored batches from each loader
            current_label = data_labels[batch_idx % len(data_labels)]
            batch_data = first_batches[current_label]
        else:
            batch_data = get_batch(loader)
        x, y, _ = batch_data

        if is_routed:
            sel_mask = get_select_mask(routing_labels, expert_labels, device=x.device)
            loss = model(x, targets=y, select_mask=sel_mask)[1]
        else:
            loss = model(x, targets=y)[1]

        # One host sync per step. total_loss keeps the raw value so the final
        # average is a true per-batch mean; batch_loss keeps the window mean.
        raw_loss = loss.item()
        total_loss += raw_loss
        batch_loss += raw_loss / eff_acc_steps
        # Scale loss for gradient accumulation (data_labels * acc_steps)
        (loss / eff_acc_steps).backward()

        if batch_idx % eff_acc_steps != eff_acc_steps - 1:
            continue  # still accumulating

        opt.step()
        opt.zero_grad(set_to_none=True)

        if not early_stop:
            pbar.set_description(f"FT Loss: {batch_loss:.4f}")
            continue

        val_loss = _averaged_val_loss(model, config, data_labels, expert_labels, num_val_batches)
        # NOTE(review): re-assert train mode in case eval_loss leaves the model
        # in eval mode (not visible from here); a no-op if it already restores it.
        model.train()

        pbar.set_description(f"FT Loss: {batch_loss:.4f} | Val Loss: {val_loss:.4f} | current label: {current_label}")
        logger.info(f"FT Loss: {batch_loss:.4f} | Val Loss: {val_loss:.4f} | current label: {current_label} @ step {batch_idx}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state_dict = deepcopy(model.state_dict())
            logger.info(f"New best validation loss: {best_val_loss:.4f} @ step {batch_idx}")
            no_improvement_steps = 0
        else:
            no_improvement_steps += 1
            # Stops after patience + 1 consecutive non-improving validations.
            if no_improvement_steps > patience:
                logger.info(
                    f"Stop FT @ {batch_idx+1}: no val improvement for {no_improvement_steps} steps"
                )
                break
    pbar.close()

    actual_num_batches = batch_idx + 1

    # Handle any leftover accumulated gradients from a partial final window.
    # (An early-stop break always happens right after a full optimizer step.)
    if actual_num_batches % eff_acc_steps != 0:
        opt.step()
        opt.zero_grad(set_to_none=True)

    avg_loss = total_loss / actual_num_batches if actual_num_batches > 0 else float("nan")
    logger.info(f"Final FT train loss: {avg_loss:.4f}")

    if early_stop:
        # Load best model state (the initial snapshot if no validation improved on it)
        model.load_state_dict(best_state_dict)
        logger.info(f"Best validation loss {best_val_loss:.4f}")

    barrier()
    state = {}

    return model, state