from pathlib import Path
from tqdm import tqdm
import torch
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

from spastra import configs
from spastra.data import get_dataloaders
from spastra.astra import IHTSparsifier
from spastra.evaluate import evaluate_accuracy
from spastra.evaluate import get_model_sparsity
from spastra.stats import StatsCollector


def single_batch_train(model, inputs, targets, criterion, optimizer):
    """Run one optimisation step on a single mini-batch.

    Clears the gradients held by ``optimizer``, runs the forward pass,
    computes the loss, back-propagates it and lets ``optimizer`` update the
    parameters.

    Args:
        model: callable module, invoked as ``model(inputs)``.
        inputs: the mini-batch fed to ``model``.
        targets: labels passed to ``criterion`` together with the outputs.
        criterion: callable ``criterion(outputs, targets)`` returning a loss
            tensor; it must be scalar because ``.backward()`` is called on it
            without arguments.
        optimizer: optimizer to step (presumably over ``model``'s
            parameters — the caller decides).

    Returns:
        Tuple ``(loss, outputs)``: the loss tensor (still attached to the
        autograd graph) and the raw forward-pass outputs.
    """
    optimizer.zero_grad()
    predictions = model(inputs)
    batch_loss = criterion(predictions, targets)
    batch_loss.backward()
    optimizer.step()
    return batch_loss, predictions


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig):
    """Train a model with IHT sparsification, then train on a frozen support.

    Everything is driven by the Hydra config ``cfg``. Training runs in two
    phases:

    1. Epochs ``[0, freeze)``: after every optimizer step the sparsifier is
       stepped (asked to sparsify only once ``epoch >= warmup``), then
       parameter statistics are collected.
    2. Epochs ``[freeze, num_epochs)``: a new ``IHTSparsifier`` over the same
       groups freezes the support through ``freeze_support(optimizer)``, and
       training continues with statistics collection only.

    ``warmup`` and ``freeze`` are the fractions ``cfg.sparsifier.warmup`` and
    ``cfg.sparsifier.freeze`` of ``cfg.num_epochs``, truncated to whole
    epochs. Metrics are logged to the run returned by
    ``configs.init_wandb``. The logging step counts training samples
    (batches seen times ``cfg.batch_size``).
    """
    print(OmegaConf.to_yaml(cfg, resolve=True))
    run = configs.init_wandb(cfg.wandb, cfg)

    # The dataset root comes from the environment variable whose name is
    # cfg.data_dir_env, not from the config itself.
    env_vars = configs.get_envs([cfg.data_dir_env])
    datasets_path = env_vars[cfg.data_dir_env]
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")

    model = configs.get_model(cfg.model).to(device)

    optimizer = configs.get_optimizer(cfg.optimizer, model.parameters())
    lr_scheduler = configs.get_lr_scheduler(cfg.lr_scheduler, optimizer)

    # Resolve per-parameter sparsity specs, then couple them into groups
    # for the sparsifier.
    sp_cfg = cfg.sparsifier
    name_to_spec = configs.get_sparsity_specs(
        sp_cfg.specs, list(model.named_parameters()), sp_cfg.exclude
    )
    groups = configs.get_sparsity_groups(
        sp_cfg.coupling, name_to_spec, default_sparsity=sp_cfg.sparsity
    )
    sparsifier = IHTSparsifier(groups=groups, device=device)

    criterion = instantiate({"_target_": cfg.criterion})
    train_loader, test_loader = get_dataloaders(
        cfg.dataset.name,
        datasets_path,
        cfg.batch_size,
        # NOTE(review): this previously read the misspelled key
        # `num_workersr`; confirm every dataset config defines `num_workers`.
        cfg.dataset.num_workers,
    )

    num_epochs = cfg.num_epochs
    batch_size = cfg.batch_size
    num_batches = len(train_loader)
    # Phase boundaries in whole epochs, derived from fractions of the budget.
    warmup = int(num_epochs * sp_cfg.warmup)
    freeze = int(num_epochs * sp_cfg.freeze)

    stats_controller = StatsCollector(
        list(model.named_parameters()),
        refresh_every=int(cfg.stats.refresh_every * num_batches),
    )

    def one_epoch_train(epoch, sandwich_call):
        """Train for one epoch, then evaluate and log epoch-level metrics.

        ``sandwich_call(epoch, num_step)`` runs after every optimizer step.
        ``num_step`` is the global step in training samples.
        """
        model.train()
        train_loss = 0
        correct = 0
        total = 0

        prog_bar = tqdm(
            enumerate(train_loader),
            total=num_batches,
            desc="[LR=?]",
            leave=True,
        )

        for batch_idx, (inputs, targets) in prog_bar:
            # Global step counted in samples (batches * batch_size).
            num_step = (epoch * num_batches + batch_idx) * batch_size
            inputs, targets = inputs.to(device), targets.to(device)
            loss, outputs = single_batch_train(
                model, inputs, targets, criterion, optimizer
            )

            sandwich_call(epoch, num_step)

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            desc = (
                f"[LR={lr_scheduler.get_last_lr()[0]:.8f}] "
                f"Loss: {train_loss / (batch_idx + 1):.4f} "
                f"| Acc:{100.0 * correct / total:.3f} {correct}/{total}"
            )
            prog_bar.set_description(desc, refresh=True)

        lr_scheduler.step()
        test_accuracy = evaluate_accuracy(model, test_loader)
        model_sparsity = get_model_sparsity(model)

        print(
            f"Epoch {epoch + 1}/{num_epochs},",
            f"Test Accuracy: {100.0 * test_accuracy:.4f}%,",
            f"Model sparsity: {100.0 * model_sparsity:.4f}%",
        )
        # Logged after lr_scheduler.step(), so "lr" is the rate for the next
        # epoch. "sparse_accuracy" subtracts 0.05 per unit by which the
        # measured sparsity falls more than 0.025 below sp_cfg.sparsity.
        run.log(
            {
                "lr": lr_scheduler.get_last_lr()[0],
                "test_accuracy": test_accuracy,
                "real_sparsity": model_sparsity,
                "train_accuracy": correct / total,
                "train_loss": train_loss / num_batches,
                "sparse_accuracy": test_accuracy
                + 0.05 * min(model_sparsity + 0.025 - sp_cfg.sparsity, 0),
            },
            step=(epoch + 1) * num_batches * batch_size,
        )

    def stats_sandwich(epoch, n_step):
        """Collect parameter statistics; log them whenever the collector
        returns a non-None result."""
        stats = stats_controller.step()
        if stats is not None:
            run.log(data=stats, step=n_step)

    def sparsifier_sandwich(epoch, n_step):
        """Step the sparsifier (sparsifying only after warm-up), then
        collect statistics."""
        sparsifier.step(sparsify=epoch >= warmup)
        stats_sandwich(epoch, n_step)

    # Phase 1: dense warm-up followed by iterative hard thresholding.
    for epoch in range(freeze):
        one_epoch_train(epoch, sparsifier_sandwich)

    # Phase 2: freeze the support, then keep training without re-thresholding.
    iht = IHTSparsifier(groups=sparsifier.groups, device=device)
    print("Freezing support of parameters via hard thresholding")
    iht.freeze_support(optimizer)

    for epoch in range(freeze, num_epochs):
        one_epoch_train(epoch, stats_sandwich)


if __name__ == "__main__":
    # `hydra.main` builds and supplies `cfg`, so `main` is invoked with no
    # arguments here.
    main()
