from pathlib import Path
from tqdm import tqdm
import torch
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

from spastra import configs
from spastra.astra import SASTRA
from spastra.astra import IHTSparsifier
from spastra.evaluate import get_model_sparsity
from spastra.stats import StatsCollector
from spastra.data.llm import RandomTokens

from transformers import AutoModelForCausalLM, AutoTokenizer

import torch
import os
from hydra import initialize, compose

# Hugging Face access token; the committed value appears to be a placeholder.
# NOTE(review): not referenced anywhere else in the visible part of this file.
# If a real token is required, read it from the environment rather than
# committing it here.
HF_TOKEN = "???????????????????"
# Hugging Face cache directory. NOTE(review): only assigned here, never
# exported to the environment or passed to a loader in the visible code --
# confirm whether it is still needed.
HF_HOME = "/buckets/datasets/huggingface"


def single_batch_train(model, inputs, targets, criterion, optimizer):
    """Perform one optimisation step on a single batch.

    Gradients held by ``optimizer`` are cleared, ``model`` is run on
    ``inputs``, the result is scored against ``targets`` with ``criterion``,
    the loss is back-propagated and one optimizer update is applied.

    Returns:
        A ``(loss, outputs)`` tuple: the value returned by ``criterion`` and
        the forward result of ``model``, both computed before the update.
    """
    # Start from clean gradients so this batch is the only contribution.
    optimizer.zero_grad()

    predictions = model(inputs)
    batch_loss = criterion(predictions, targets)

    batch_loss.backward()
    optimizer.step()

    return batch_loss, predictions


# Compose the "prune" config at import time and keep it as the module-level
# `cfg`. Inside `main` the name `cfg` refers to the function's own argument,
# which shadows this one.
with initialize(config_path="config", version_base=None):
    cfg = compose("prune")




@hydra.main(version_base=None, config_path="config", config_name="prune_llm")
def main(cfg: DictConfig):
    """Prune a causal LM with SASTRA, then train with a frozen support.

    The run has two phases, split at
    ``freeze = int(num_epochs * sparsifier.freeze)``:

    1. Epochs ``[0, freeze)``: every batch is followed by ``sastra.step``;
       its ``sparsify`` flag is switched on once the epoch index reaches
       ``warmup = int(num_epochs * sparsifier.warmup)``.
    2. Epochs ``[freeze, num_epochs)``: SASTRA is detached from the
       optimizer, ``IHTSparsifier.freeze_support`` is applied, and training
       continues with statistics collection only.

    Args:
        cfg: Hydra config (``config/prune_llm``). Keys read here: ``wandb``,
            ``device``, ``optimizer``, ``lr_scheduler``, ``sparsifier``,
            ``criterion``, ``num_batches``, ``num_epochs``, ``batch_size``
            and ``stats.refresh_every``.
    """
    print(OmegaConf.to_yaml(cfg, resolve=True))
    run = configs.init_wandb(cfg.wandb, cfg)

    model_name = "HuggingFaceTB/SmolLM3-3B"
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Both models are placed on `device` (not a hard-coded "cuda") so that they
    # live where the batches are moved to, including the CPU fallback.
    # NOTE(review): `teacher` is not used anywhere below -- confirm whether it
    # is still needed, since loading it costs memory.
    teacher = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype="auto", device_map=device
    ).eval()

    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype="float32", device_map=device
    )
    # NOTE(review): `ds` is built but never consumed by the training loop
    # below -- see the note on the batch loop in `one_epoch_train`.
    ds = RandomTokens(tokenizer, seq_len=512)

    optimizer = configs.get_optimizer(cfg.optimizer, model.parameters())
    lr_scheduler = configs.get_lr_scheduler(cfg.lr_scheduler, optimizer)

    sp_cfg = cfg.sparsifier
    name_to_spec = configs.get_sparsity_specs(
        sp_cfg.specs, list(model.named_parameters()), sp_cfg.exclude
    )

    groups = configs.get_sparsity_groups(
        sp_cfg.coupling, name_to_spec, default_sparsity=sp_cfg.sparsity
    )
    ema_grad = configs.get_ema(sp_cfg.ema)
    # "lambda" is a Python keyword, hence item access instead of attribute access.
    lamb_controller = configs.get_lambdas(sp_cfg["lambda"], device=device)
    alphas = configs.get_alphas(sp_cfg.alphas, name_to_spec)

    sastra = SASTRA(
        groups=groups,
        lambdas=lamb_controller,
        ema_grad=ema_grad,
        alphas=alphas,
        device=device,
    )
    sastra.attach_optimizer(optimizer)

    criterion = instantiate({"_target_": cfg.criterion})

    num_batches = cfg.num_batches
    num_epochs = cfg.num_epochs
    batch_size = cfg.batch_size
    # `warmup` and `freeze` are configured as fractions of `num_epochs`.
    warmup = int(num_epochs * sp_cfg.warmup)
    freeze = int(num_epochs * sp_cfg.freeze)

    stats_controller = StatsCollector(
        list(model.named_parameters()),
        refresh_every=int(cfg.stats.refresh_every * num_batches),
    )

    def one_epoch_train(epoch, sandwich_call):
        """Train for one epoch and log epoch-level metrics to ``run``.

        Args:
            epoch: Zero-based epoch index.
            sandwich_call: Callable invoked as ``sandwich_call(epoch, n_step)``
                after every optimizer step.
        """
        model.train()
        train_loss = 0
        correct = 0
        total = 0

        prog_bar = tqdm(
            total=num_batches,
            desc="[LR=?]",
            leave=True,
        )

        # NOTE(review): `prog_bar` is created with only `total=`, i.e. without
        # an iterable to draw batches from, so this loop has no data source.
        # It presumably should iterate over batches built from `ds`; the batch
        # format of `RandomTokens` is not visible here -- wire this up.
        for batch_idx, (inputs, targets) in prog_bar:
            # Logging step expressed as (batches completed so far) * batch_size.
            num_step = (epoch * num_batches + batch_idx) * batch_size
            inputs, targets = inputs.to(device), targets.to(device)
            loss, outputs = single_batch_train(
                model, inputs, targets, criterion, optimizer
            )

            # Use this function's own `epoch` argument rather than the loop
            # variable of the enclosing scope.
            sandwich_call(epoch, num_step)

            train_loss += loss.item()
            # NOTE(review): assumes `outputs` is a tensor with classes on
            # dim 1; a Hugging Face causal-LM forward normally returns a
            # ModelOutput object -- confirm.
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            desc = (
                f"[LR={lr_scheduler.get_last_lr()[0]:.8f}] "
                f"Loss: {train_loss / (batch_idx + 1):.4f} "
                f"| Acc:{100.0 * correct / total:.3f} {correct}/{total}"
            )
            prog_bar.set_description(desc, refresh=True)

        # The learning-rate schedule advances once per epoch.
        lr_scheduler.step()
        model_sparsity = get_model_sparsity(model)

        print(
            f"Epoch {epoch + 1}/{num_epochs},",
            f"Model sparsity: {100.0 * model_sparsity:.4f}%",
        )
        run.log(
            {
                "lr": lr_scheduler.get_last_lr()[0],
                "real_sparsity": model_sparsity,
                "train_accuracy": correct / total,
                "train_loss": train_loss / num_batches,
            },
            step=(epoch + 1) * num_batches * batch_size,
        )

    def stats_sandwich(epoch, n_step):
        """Advance the stats collector and log whatever it returns, if anything."""
        stats = stats_controller.step()
        if stats is not None:
            run.log(data=stats, step=n_step)

    def sparsifier_sandwich(epoch, n_step):
        """Step SASTRA (sparsifying once ``epoch >= warmup``), then log stats."""
        sastra.step(sparsify=epoch >= warmup)
        stats_sandwich(epoch, n_step)

    # Phase 1: training with SASTRA attached to the optimizer.
    for e in range(freeze):
        one_epoch_train(e, sparsifier_sandwich)

    sastra.detach_all_optimizers()

    iht = IHTSparsifier(groups=sastra.groups)
    print("Freezing support of parameters via hard thresholding")
    iht.freeze_support(optimizer)

    # Phase 2: continue training after the support has been frozen.
    for e in range(freeze, num_epochs):
        one_epoch_train(e, stats_sandwich)


if __name__ == "__main__":
    main()
