import logging
import importlib
from pathlib import Path

import torch
import torch.nn.functional as F
import wandb
from sklearn.model_selection import train_test_split
from torch import nn
from torch.cuda.amp import GradScaler
from tqdm import trange
import hydra
from omegaconf import OmegaConf

from experiments.data import INRDataset
from experiments.utils import (
    count_parameters,
    get_device,
    set_logger,
    set_seed,
)
from nn.models import DWSModelForClassification, MLPModelForClassification


# Run the project's logger setup once, at import time. `set_logger` comes from
# experiments.utils; what it configures is not visible in this file.
set_logger()

def compute_psnr(original_image, reconstructed_image, max_pixel_value=None):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    PSNR is computed as ``20 * log10(peak / sqrt(MSE))``.

    Args:
        original_image: Reference tensor.
        reconstructed_image: Tensor compared against the reference; it must be
            compatible with ``original_image`` for ``F.mse_loss``.
        max_pixel_value: Peak signal value used in the formula. When ``None``
            (the default, matching the previous behaviour) the maximum of
            ``original_image`` is used. Pass the nominal data range (e.g.
            ``1.0`` or ``255``) to get values comparable across images.

    Returns:
        The PSNR as a Python float. Identical inputs (zero MSE) yield
        ``float("inf")``.
    """
    mse = F.mse_loss(original_image, reconstructed_image)

    # A perfect reconstruction has unbounded PSNR. Returning early also avoids
    # the 0 / 0 -> NaN case that arises when the reference image is all zeros.
    if mse.item() == 0:
        return float("inf")

    if max_pixel_value is None:
        max_pixel_value = torch.max(original_image)
    # Accept plain Python numbers as well as tensors for the peak value.
    peak = torch.as_tensor(max_pixel_value, dtype=mse.dtype, device=mse.device)

    psnr = 20 * torch.log10(peak / torch.sqrt(mse))
    return psnr.item()

def create_object(name, **kwargs):
    """Instantiate the object named by a dotted import path.

    Args:
        name: Fully qualified dotted path, ``"package.module.Attr"``. The part
            after the last dot is looked up as an attribute of the module
            named by the part before it.
        **kwargs: Keyword arguments forwarded to the resolved callable.

    Returns:
        Whatever the resolved callable returns when called with ``kwargs``.
    """
    module_path, attr_name = name.rsplit(".", maxsplit=1)
    factory = getattr(importlib.import_module(module_path), attr_name)
    return factory(**kwargs)


def train(cfg):
    """Run the training loop described by a Hydra/OmegaConf config.

    The function resolves ``cfg`` in place, starts a wandb run, builds the
    datasets/loaders, model, optimizer and optional scheduler from dotted
    class paths in the config, optionally restores a checkpoint, and then
    trains with autocast + ``GradScaler`` while logging every step to wandb.

    Config keys read here: ``wandb``, ``model`` (``name``, ``cls``,
    ``kwargs``), ``lr``, ``batch_size``, ``seed``, ``gpu``, ``data`` (``cls``,
    ``train``, ``val``, ``test``), ``n_samples``, ``num_workers``,
    ``compile``, ``compile_kwargs``, ``optim`` (``cls``, ``kwargs``),
    ``use_scheduler``, ``scheduler`` (``cls``, ``kwargs``), ``load_ckpt``,
    ``n_epochs``, ``gradscaler``, ``autocast``, ``clip_grad`` and
    ``clip_grad_max_norm``.

    Args:
        cfg: The composed config object passed in by ``main``.

    Returns:
        None.
    """
    OmegaConf.resolve(cfg)
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    ckpt_dir = Path(hydra_cfg.runtime.output_dir) / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    if cfg.wandb.name is None:
        cfg.wandb.name = (
            f"mnist_clf_{cfg.model.name}_lr_{cfg.lr}"
            f"_bs_{cfg.batch_size}_seed_{cfg.seed}"
        )

    wandb.init(
        **cfg.wandb,
        settings=wandb.Settings(start_method="fork"),
        config=cfg,
    )

    device = get_device(gpus=cfg.gpu)

    # load dataset
    train_set = create_object(cfg.data.cls, **cfg.data.train)
    val_set = create_object(cfg.data.cls, **cfg.data.val)
    test_set = create_object(cfg.data.cls, **cfg.data.test)

    if cfg.n_samples is not None:
        # Keep a random subset of `n_samples` training items (the "test" side
        # of the split is the part that is retained).
        _, train_indices = train_test_split(
            range(len(train_set)), test_size=cfg.n_samples
        )
        train_set = torch.utils.data.Subset(train_set, train_indices)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )
    # NOTE: val_loader / test_loader are built but no evaluation loop in this
    # function consumes them; test_loss / test_acc below keep their -1.0
    # sentinel values for the whole run.
    val_loader = torch.utils.data.DataLoader(
        dataset=val_set,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        shuffle=False,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )

    # Build the model. Previously `model` was used without ever being created,
    # so the function always failed with UnboundLocalError.
    # NOTE(review): assumes the model config follows the same `cls` / `kwargs`
    # layout as `cfg.optim` and `cfg.scheduler` -- confirm against the configs.
    model = create_object(cfg.model.cls, **cfg.model.kwargs).to(device)

    if cfg.compile:
        model = torch.compile(model, **cfg.compile_kwargs)

    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = create_object(cfg.optim.cls, params=parameters, **cfg.optim.kwargs)
    # Always bind `scheduler` so later checks cannot hit an undefined name.
    scheduler = None
    if cfg.use_scheduler:
        scheduler = create_object(
            cfg.scheduler.cls,
            optimizer=optimizer,
            **cfg.scheduler.kwargs,
        )

    criterion = nn.MSELoss()
    best_val_acc = -1
    best_test_results, best_val_results = None, None
    test_acc, test_loss = -1.0, -1.0
    global_step = 0
    start_epoch = 0

    if cfg.load_ckpt:
        # map_location lets a checkpoint saved on one device be restored on
        # whichever device this run uses.
        ckpt = torch.load(cfg.load_ckpt, map_location=device)
        model.load_state_dict(ckpt["model"])
        if "optimizer" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer"])
        # Only restore scheduler state when a scheduler exists in this run.
        if scheduler is not None and "scheduler" in ckpt:
            scheduler.load_state_dict(ckpt["scheduler"])
        if "epoch" in ckpt:
            start_epoch = ckpt["epoch"]
        if "global_step" in ckpt:
            global_step = ckpt["global_step"]
        logging.info(f"loaded checkpoint {cfg.load_ckpt}")

    epoch_iter = trange(start_epoch, cfg.n_epochs)
    model.train()

    # Per-run directory named after the last component of the wandb run path.
    ckpt_dir = Path(hydra_cfg.runtime.output_dir) / wandb.run.path.split("/")[-1]
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    scaler = GradScaler(**cfg.gradscaler)
    autocast_kwargs = dict(cfg.autocast)
    # The config stores the dtype by name; fall back to float32 if torch has
    # no attribute with that name.
    autocast_kwargs["dtype"] = getattr(torch, cfg.autocast.dtype, torch.float32)
    for epoch in epoch_iter:
        for i, batch in enumerate(train_loader):
            optimizer.zero_grad()

            batch = batch.to(device)
            inputs = (batch.weights, batch.biases)

            with torch.autocast(**autocast_kwargs):
                out = model(inputs)
                loss = criterion(out, batch.label)

            scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to the
            # true gradients rather than the scaled ones.
            scaler.unscale_(optimizer)
            # Read the loss once per step; it is used for logging and for the
            # progress-bar text.
            loss_value = loss.item()
            log = {
                "train/loss": loss_value,
                "global_step": global_step,
            }
            if cfg.clip_grad:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    parameters, cfg.clip_grad_max_norm
                )
                log["grad_norm"] = grad_norm
            scaler.step(optimizer)
            scaler.update()

            if scheduler is not None:
                log["lr"] = scheduler.get_last_lr()[0]
                scheduler.step()
            wandb.log(log)

            epoch_iter.set_description(
                f"[{epoch} {i+1}], train loss: {loss_value:.3f}, test_loss: {test_loss:.3f}, test_acc: {test_acc:.3f}"
            )
            global_step += 1


@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg):
    """Apply global torch settings and seeding from ``cfg``, then train.

    Args:
        cfg: Config composed by Hydra from ``configs/config``. This function
            reads ``matmul_precision``, ``cudnn_benchmark`` and ``seed``
            before handing the whole object to ``train``.
    """
    # Global backend knobs come first, then seeding, then training.
    torch.set_float32_matmul_precision(cfg.matmul_precision)
    torch.backends.cudnn.benchmark = cfg.cudnn_benchmark

    seed = cfg.seed
    if seed is not None:
        set_seed(seed)

    train(cfg)


# Script entry point. `main` is wrapped by `@hydra.main`, so it is called
# without arguments here and receives its config from the decorator.
if __name__ == "__main__":
    main()
