import torch
import torchvision.models as models
from pathlib import Path
from tqdm import tqdm
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from spastra import configs
from spastra.data import get_dataloaders
from spastra.astra import SASTRA
from spastra.astra import IHTSparsifier
from spastra.evaluate import evaluate_accuracy
from spastra.evaluate import get_model_sparsity
from spastra.stats import StatsCollector
from torchvision.datasets import ImageNet
from torch.utils.data import DataLoader
import time


def single_batch_train(model, inputs, targets, criterion, optimizer):
    """Perform a single optimizer update on one mini-batch.

    Args:
        model: Module being trained; called as ``model(inputs)``.
        inputs: Batch fed to ``model``.
        targets: Whatever ``criterion`` expects as its second argument
            (class labels, or teacher outputs when distilling).
        criterion: Callable ``criterion(outputs, targets)`` returning a
            scalar loss tensor.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        ``(loss, outputs)``: the loss tensor (still attached to the graph)
        and the raw model outputs for this batch.
    """
    # Drop gradients left over from the previous update before back-propagating.
    optimizer.zero_grad()
    predictions = model(inputs)
    batch_loss = criterion(predictions, targets)
    batch_loss.backward()
    optimizer.step()
    return (batch_loss, predictions)


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig):
    """Train a sparse student against a pretrained ResNet-50 teacher with SASTRA.

    Training runs in two phases:

    1. Epochs ``[0, freeze)``: SASTRA is attached to the optimizer and
       ``sastra.step`` runs after every batch; sparsification is enabled once
       ``epoch >= warmup``.
    2. Epochs ``[freeze, num_epochs)``: SASTRA is detached, the support of the
       sparsity groups is frozen via ``IHTSparsifier.freeze_support`` and
       training continues with only statistics collection.

    The student's loss targets are the teacher's raw outputs on each batch;
    the dataset labels are only used to report training accuracy. A
    checkpoint of the student is written after every epoch.

    Args:
        cfg: Hydra config composed from ``configs/config``. Keys read here
            include ``wandb``, ``data_dir_env``, ``checkpoint_dir_env``,
            ``device``, ``model``, ``optimizer``, ``lr_scheduler``,
            ``sparsifier``, ``criterion``, ``dataset``, ``batch_size``,
            ``num_epochs`` and ``stats``.
    """
    print(OmegaConf.to_yaml(cfg, resolve=True))
    run = configs.init_wandb(cfg.wandb, cfg)

    # Request every environment variable we index below in a single call;
    # asking only for the data dir and then reading the checkpoint dir from
    # the same mapping is not guaranteed to find it.
    env_vars = configs.get_envs([cfg.data_dir_env, cfg.checkpoint_dir_env])
    datasets_path = env_vars[cfg.data_dir_env]
    # Wrap in Path so checkpoint files are joined with `/` (str.join would
    # interleave characters, Path has no .join at all).
    checkpoint_dir = Path(env_vars[cfg.checkpoint_dir_env])
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    print(f"Using {device} for training")

    # Pretrained teacher: kept in eval mode and only run under no_grad below.
    teacher = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    teacher = teacher.to(device)
    teacher.eval()

    # Student model and its optimization machinery.
    model = configs.get_model(cfg.model).to(device)
    optimizer = configs.get_optimizer(cfg.optimizer, model.parameters())
    lr_scheduler = configs.get_lr_scheduler(cfg.lr_scheduler, optimizer)

    # SASTRA sparsifier configuration.
    sp_cfg = cfg.sparsifier
    name_to_spec = configs.get_sparsity_specs(
        sp_cfg.specs, list(model.named_parameters()), sp_cfg.exclude
    )
    groups = configs.get_sparsity_groups(
        sp_cfg.coupling, name_to_spec, default_sparsity=sp_cfg.sparsity
    )
    ema_grad = configs.get_ema(sp_cfg.ema)
    # "lambda" is a Python keyword, hence item access instead of attribute access.
    lamb_controller = configs.get_lambdas(sp_cfg["lambda"], device=device)
    alphas = configs.get_alphas(sp_cfg.alphas, name_to_spec)

    sastra = SASTRA(
        groups=groups,
        lambdas=lamb_controller,
        ema_grad=ema_grad,
        alphas=alphas,
        device=device,
    )
    sastra.attach_optimizer(optimizer)

    criterion = instantiate({"_target_": cfg.criterion})
    train_loader, test_loader = get_dataloaders(
        cfg.dataset.name,
        datasets_path,
        cfg.batch_size,
        cfg.dataset.num_workers,
    )

    num_epochs = cfg.num_epochs
    batch_size = cfg.batch_size
    num_batches = len(train_loader)
    # warmup / freeze are configured as fractions of the total epoch count.
    warmup = int(num_epochs * sp_cfg.warmup)
    freeze = int(num_epochs * sp_cfg.freeze)

    stats_controller = StatsCollector(
        list(model.named_parameters()),
        refresh_every=int(cfg.stats.refresh_every * num_batches),
    )

    def one_epoch_train(epoch, sandwich_call):
        """Train the student for one epoch, then evaluate, log and checkpoint.

        Args:
            epoch: Zero-based epoch index.
            sandwich_call: Callback ``(epoch, n_step)`` invoked after every
                optimizer step (SASTRA step and/or stats collection).
        """
        model.train()
        train_loss = 0
        correct = 0
        total = 0

        prog_bar = tqdm(
            enumerate(train_loader),
            total=num_batches,
            desc="[LR=?]",
            leave=True,
        )

        for batch_idx, (inputs, raw_targets) in prog_bar:
            # Global step counted in samples seen, used as the W&B step.
            num_step = (epoch * num_batches + batch_idx) * batch_size
            inputs, raw_targets = inputs.to(device), raw_targets.to(device)
            # Distillation: the teacher's outputs are the loss targets.
            with torch.no_grad():
                targets = teacher(inputs)

            loss, outputs = single_batch_train(
                model, inputs, targets, criterion, optimizer
            )

            sandwich_call(epoch, num_step)

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += raw_targets.size(0)
            correct += predicted.eq(raw_targets).sum().item()
            desc = (
                f"[LR={lr_scheduler.get_last_lr()[0]:.8f}] "
                f"Loss: {train_loss / (batch_idx + 1):.4f} "
                f"| Acc:{100.0 * correct / total:.3f} {correct}/{total}"
            )
            prog_bar.set_description(desc, refresh=True)

        lr_scheduler.step()
        test_accuracy = evaluate_accuracy(model, test_loader)
        model_sparsity = get_model_sparsity(model)

        print(
            f"Epoch {epoch + 1}/{num_epochs},",
            f"Test Accuracy: {100.0 * test_accuracy:.4f}%,",
            f"Model sparsity: {100.0 * model_sparsity:.4f}%",
        )
        run.log(
            {
                "lr": lr_scheduler.get_last_lr()[0],
                "test_accuracy": test_accuracy,
                "real_sparsity": model_sparsity,
                "train_accuracy": correct / total,
                "train_loss": train_loss / num_batches,
                # Penalize test accuracy when real sparsity falls more than
                # 2.5 points below the configured target.
                "sparse_accuracy": test_accuracy
                + 0.05 * min(model_sparsity + 0.025 - sp_cfg.sparsity, 0),
            },
            step=(epoch + 1) * num_batches * batch_size,
        )
        torch.save(
            model.state_dict(),
            checkpoint_dir / f"resnet50_imagenet_{epoch}.ckpt",
        )

    def stats_sandwich(epoch, n_step):
        """Log parameter statistics whenever the collector produces them."""
        stats = stats_controller.step()
        if stats is not None:
            run.log(data=stats, step=n_step)

    def sparsifier_sandwich(epoch, n_step):
        """Run a SASTRA step (sparsifying after warmup), then collect stats."""
        sastra.step(sparsify=epoch >= warmup)
        stats_sandwich(epoch, n_step)

    # Phase 1: SASTRA-driven sparse training.
    for epoch in range(freeze):
        one_epoch_train(epoch, sparsifier_sandwich)

    sastra.detach_all_optimizers()

    # Phase 2: freeze the sparsity support and fine-tune.
    iht = IHTSparsifier(groups=sastra.groups)
    print("Freezing support of parameters via hard thresholding")
    iht.freeze_support(optimizer)

    for epoch in range(freeze, num_epochs):
        one_epoch_train(epoch, stats_sandwich)

if __name__ == "__main__":
    # `main` is wrapped by `@hydra.main`, which composes `cfg` from the
    # `configs/config` Hydra config, so it is invoked without arguments.
    main()