import logging
import importlib
from pathlib import Path

import torch
import torch.nn.functional as F
import wandb
from sklearn.model_selection import train_test_split
from torch import nn
from torch.cuda.amp import GradScaler
from tqdm import trange
import hydra
from omegaconf import OmegaConf
from matplotlib import pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.manifold import TSNE
import numpy as np
import pandas as pd
import seaborn as sns

from experiments.ssl.loss import info_nce_loss
from experiments.data import INRDataset
from experiments.utils import (
    count_parameters,
    get_device,
    set_logger,
    set_seed,
)
from nn.models import DWSModelForClassification, MLPModelForClassification

# Configure logging once at import time through the project's `set_logger`
# helper (imported from experiments.utils; its exact setup is not visible here).
set_logger()


def create_object(name, **kwargs):
    """Instantiate the object named by the dotted path ``name``.

    ``name`` is split at its last dot into a module path and an attribute
    name; the module is imported, the attribute is looked up on it and then
    called with ``kwargs``.

    Args:
        name: dotted path such as ``"package.module.ClassName"``.
        **kwargs: keyword arguments forwarded to the resolved callable.

    Returns:
        Whatever the resolved callable returns.
    """
    module_path, attribute = name.rsplit(".", 1)
    factory = getattr(importlib.import_module(module_path), attribute)
    return factory(**kwargs)



@torch.no_grad()
def evaluate(model, projection, loader, temperature, device):
    """Compute contrastive metrics and collect features over ``loader``.

    Every batch is moved to ``device`` and each weight/bias tensor is
    concatenated with its augmented counterpart along dim 0, so the first
    ``real_bs`` rows of the model output belong to the original inputs and the
    remaining rows to the augmented views.

    Args:
        model: encoder called on the ``(weights, biases)`` tuple. It is put in
            eval mode for the duration of the call and restored to its
            previous mode afterwards.
        projection: callable applied to the encoder features before the loss.
            When it is an ``nn.Module`` it is also evaluated in eval mode and
            restored afterwards.
        loader: iterable of batches exposing ``to(device)``, ``weights``,
            ``aug_weights``, ``biases``, ``aug_biases`` and ``label``.
        temperature: forwarded to ``info_nce_loss``.
        device: target passed to ``batch.to``.

    Returns:
        dict with
        ``avg_loss`` / ``avg_acc``: summed cross-entropy and number of correct
        predictions divided by the number of logits rows (0-dim tensors),
        ``features``: float64 array with the encoder features of the
        non-augmented inputs, and
        ``labels``: array built from every ``batch.label``.

    Raises:
        ValueError: if ``loader`` yields no samples, since no average can be
            computed.
    """
    # Remember the current modes so evaluation leaves the modules as it found
    # them instead of unconditionally forcing train mode.
    model_was_training = model.training
    projection_was_training = (
        projection.training if isinstance(projection, torch.nn.Module) else None
    )
    model.eval()
    if projection_was_training is not None:
        projection.eval()

    loss = 0.0
    correct = 0.0
    total = 0.0
    feature_chunks = []
    all_labels = []
    try:
        for batch in loader:
            batch = batch.to(device)
            inputs = (
                tuple(
                    torch.cat([w, aug_w])
                    for w, aug_w in zip(batch.weights, batch.aug_weights)
                ),
                tuple(
                    torch.cat([b, aug_b])
                    for b, aug_b in zip(batch.biases, batch.aug_biases)
                ),
            )
            features = model(inputs)
            zs = projection(features)
            logits, labels = info_nce_loss(zs, temperature)
            loss += F.cross_entropy(logits, labels, reduction="sum")
            total += len(labels)
            pred = logits.argmax(1)
            correct += pred.eq(labels).sum()

            # Only the first half of the stacked batch are real (non-augmented)
            # inputs; keep their features as arrays rather than nested lists.
            real_bs = batch.weights[0].shape[0]
            feature_chunks.append(features[:real_bs, :].cpu().numpy())
            all_labels.extend(batch.label.cpu().numpy().tolist())
    finally:
        model.train(model_was_training)
        if projection_was_training is not None:
            projection.train(projection_was_training)

    if total == 0:
        raise ValueError(
            "evaluate() received an empty loader; cannot compute average metrics"
        )

    avg_loss = loss / total
    avg_acc = correct / total

    return dict(
        avg_loss=avg_loss,
        avg_acc=avg_acc,
        # float64 matches what the previous list round-trip produced.
        features=np.concatenate(feature_chunks).astype(np.float64),
        labels=np.array(all_labels),
    )

def _stack_with_augmentations(batch):
    """Concatenate each weight/bias tensor with its augmented view along dim 0.

    Returns the ``(weights, biases)`` tuple of tuples that is fed to the model;
    the first half of every tensor holds the originals and the second half the
    augmented views.
    """
    weights = tuple(
        torch.cat([w, aug_w]) for w, aug_w in zip(batch.weights, batch.aug_weights)
    )
    biases = tuple(
        torch.cat([b, aug_b]) for b, aug_b in zip(batch.biases, batch.aug_biases)
    )
    return weights, biases


def _checkpoint_state(
    model, projection, optimizer, scheduler, epoch, global_step, cfg
):
    """Build the dictionary written to disk for a training checkpoint."""
    state = {
        "model": model.state_dict(),
        "projection": projection.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "cfg": cfg,
        "global_step": global_step,
    }
    if scheduler is not None:
        state["scheduler"] = scheduler.state_dict()
    return state


def train(cfg):
    """Run contrastive pre-training as described by the Hydra config ``cfg``.

    The function resolves the config, starts a wandb run, builds the
    train/val/test datasets and loaders, the encoder, a two-layer projection
    head, the optimizer and (optionally) a scheduler, then trains with the
    InfoNCE loss under autocast/GradScaler. Every ``cfg.eval_every`` epochs
    (and on the last epoch) it writes ``latest.ckpt``, evaluates on the val
    and test loaders, writes ``best_val.ckpt`` when the validation loss
    improves, and logs the metrics to wandb. On epochs that are a multiple of
    ``cfg.eval_every`` it additionally fits a linear regression from the
    training features to the labels and logs its errors together with a 2-D
    scatter plot of the test features.

    Args:
        cfg: Hydra/OmegaConf config. Fields read here include ``wandb``,
            ``gpu``, ``data``, ``batch_size``, ``num_workers``, ``model``,
            ``embedding_dim``, ``compile``, ``optim``, ``use_scheduler``,
            ``scheduler``, ``load_ckpt``, ``n_epochs``, ``gradscaler``,
            ``autocast``, ``temperature``, ``clip_grad``,
            ``clip_grad_max_norm`` and ``eval_every``.

    Checkpoints are written to ``<hydra output dir>/<wandb run id>/``.
    """
    OmegaConf.resolve(cfg)
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()

    if cfg.wandb.name is None:
        cfg.wandb.name = (
            f"mnist_clf_{cfg.model.name}_lr_{cfg.lr}"
            f"_bs_{cfg.batch_size}_seed_{cfg.seed}"
        )

    wandb.init(
        **cfg.wandb,
        settings=wandb.Settings(start_method="fork"),
        config=cfg,
    )

    device = get_device(gpus=cfg.gpu)

    # load dataset
    train_set = create_object(cfg.data.cls, **cfg.data.train)
    val_set = create_object(cfg.data.cls, **cfg.data.val)
    test_set = create_object(cfg.data.cls, **cfg.data.test)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset=val_set,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        shuffle=False,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )

    logging.info(
        f"train size {len(train_set)}, "
        f"val size {len(val_set)}, "
        f"test size {len(test_set)}"
    )

    # Derive the architecture description of the input networks from one sample.
    point = train_set[0]
    weight_shapes = tuple(w.shape[:2] for w in point.weights)
    bias_shapes = tuple(b.shape[:1] for b in point.biases)

    layer_layout = [weight_shapes[0][0]] + [b[0] for b in bias_shapes]

    logging.info(f"weight shapes: {weight_shapes}, bias shapes: {bias_shapes}")

    # todo: make defaults for MLP so that parameters for MLP and DWS are the same
    if cfg.model.name == "mlp":
        model = MLPModelForClassification(
            in_dim=sum(w.numel() for w in weight_shapes + bias_shapes),
            **cfg.model.kwargs,
        ).to(device)
    elif cfg.model.name == "dwsnet":
        model = DWSModelForClassification(
            weight_shapes=weight_shapes, bias_shapes=bias_shapes, **cfg.model.kwargs
        ).to(device)
    else:
        model = create_object(
            cfg.model.cls,
            layer_layout=layer_layout,
            **cfg.model.kwargs,
        ).to(device)

    projection = nn.Sequential(
        nn.Linear(cfg.embedding_dim, cfg.embedding_dim),
        nn.ReLU(),
        nn.Linear(cfg.embedding_dim, cfg.embedding_dim),
    ).to(device)

    logging.info(f"number of parameters: {count_parameters(model)}")

    if cfg.compile:
        model = torch.compile(model, **cfg.compile_kwargs)

    # NOTE(review): the projection head used to be left out of the optimizer,
    # so it stayed at its random initialisation for the whole run. It is now
    # trained together with the encoder — confirm this matches the intended
    # experimental setup.
    parameters = [
        p
        for p in list(model.parameters()) + list(projection.parameters())
        if p.requires_grad
    ]
    optimizer = create_object(cfg.optim.cls, params=parameters, **cfg.optim.kwargs)
    scheduler = None
    if cfg.use_scheduler:
        scheduler = create_object(
            cfg.scheduler.cls,
            optimizer=optimizer,
            **cfg.scheduler.kwargs,
        )

    criterion = nn.CrossEntropyLoss()
    best_val_loss = float("inf")
    best_test_results, best_val_results = None, None
    test_acc, test_loss = -1.0, -1.0
    global_step = 0
    start_epoch = 0

    if cfg.load_ckpt:
        # NOTE: checkpoints written by this script contain the pickled config
        # object, so they cannot be read with ``weights_only=True``. Only load
        # checkpoints from a trusted source.
        ckpt = torch.load(cfg.load_ckpt, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model"])
        if "projection" in ckpt:
            projection.load_state_dict(ckpt["projection"])
        if "optimizer" in ckpt:
            try:
                optimizer.load_state_dict(ckpt["optimizer"])
            except ValueError:
                # Older checkpoints hold optimizer state for the encoder only
                # and no longer match the parameter group built above.
                logging.warning(
                    f"optimizer state in {cfg.load_ckpt} does not match the "
                    "current parameters; continuing with a fresh optimizer state"
                )
        if scheduler is not None and "scheduler" in ckpt:
            scheduler.load_state_dict(ckpt["scheduler"])
        if "epoch" in ckpt:
            # The stored epoch has already been completed; resume with the next.
            start_epoch = ckpt["epoch"] + 1
        if "global_step" in ckpt:
            global_step = ckpt["global_step"]
        logging.info(f"loaded checkpoint {cfg.load_ckpt}")

    epoch_iter = trange(start_epoch, cfg.n_epochs)
    model.train()
    projection.train()

    # Checkpoints live in a directory named after the wandb run id.
    ckpt_dir = Path(hydra_cfg.runtime.output_dir) / wandb.run.path.split("/")[-1]
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    scaler = GradScaler(**cfg.gradscaler)
    autocast_kwargs = dict(cfg.autocast)
    autocast_kwargs["dtype"] = getattr(torch, cfg.autocast.dtype, torch.float32)
    for epoch in epoch_iter:
        for i, batch in enumerate(train_loader):
            optimizer.zero_grad()

            batch = batch.to(device)
            inputs = _stack_with_augmentations(batch)

            with torch.autocast(**autocast_kwargs):
                features = model(inputs)
                zs = projection(features)
                logits, labels = info_nce_loss(zs, cfg.temperature)
                loss = criterion(logits, labels)

            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on true gradients.
            scaler.unscale_(optimizer)
            loss_value = loss.item()
            log = {
                "train/loss": loss_value,
                "global_step": global_step,
            }
            if cfg.clip_grad:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    parameters, cfg.clip_grad_max_norm
                )
                log["grad_norm"] = grad_norm
            scaler.step(optimizer)
            scaler.update()

            if scheduler is not None:
                log["lr"] = scheduler.get_last_lr()[0]
                scheduler.step()
            wandb.log(log)

            epoch_iter.set_description(
                f"[{epoch} {i+1}], train loss: {loss_value:.3f}, test_loss: {test_loss:.3f}, test_acc: {test_acc:.3f}"
            )
            global_step += 1

        if (epoch + 1) % cfg.eval_every == 0 or epoch == cfg.n_epochs - 1:
            torch.save(
                _checkpoint_state(
                    model, projection, optimizer, scheduler, epoch, global_step, cfg
                ),
                ckpt_dir / "latest.ckpt",
            )

            val_loss_dict = evaluate(
                model, projection, val_loader, cfg.temperature, device
            )
            test_loss_dict = evaluate(
                model, projection, test_loader, cfg.temperature, device
            )
            val_loss = val_loss_dict["avg_loss"]
            val_acc = val_loss_dict["avg_acc"]
            test_loss = test_loss_dict["avg_loss"]
            test_acc = test_loss_dict["avg_acc"]

            if val_loss <= best_val_loss:
                best_val_loss = val_loss
                best_test_results = test_loss_dict
                best_val_results = val_loss_dict
                torch.save(
                    _checkpoint_state(
                        model,
                        projection,
                        optimizer,
                        scheduler,
                        epoch,
                        global_step,
                        cfg,
                    ),
                    ckpt_dir / "best_val.ckpt",
                )

            log = {
                "val/loss": val_loss,
                "val/acc": val_acc,
                "val/best_loss": best_val_results["avg_loss"],
                "val/best_acc": best_val_results["avg_acc"],
                "test/loss": test_loss,
                "test/acc": test_acc,
                "test/best_loss": best_test_results["avg_loss"],
                "test/best_acc": best_test_results["avg_acc"],
                "epoch": epoch,
                "global_step": global_step,
            }
            if (epoch + 1) % cfg.eval_every == 0:
                train_loss_dict = evaluate(
                    model, projection, train_loader, cfg.temperature, device
                )

                # Linear probe: regress the labels from the frozen features.
                reg = LinearRegression().fit(
                    train_loss_dict["features"], train_loss_dict["labels"]
                )
                test_residual = test_loss_dict["labels"] - reg.predict(
                    test_loss_dict["features"]
                )
                val_residual = val_loss_dict["labels"] - reg.predict(
                    val_loss_dict["features"]
                )

                reg_mse_loss = np.square(test_residual).mean()
                reg_mae_loss = np.abs(test_residual).mean()
                val_reg_mse_loss = np.square(val_residual).mean()
                val_reg_mae_loss = np.abs(val_residual).mean()

                if cfg.embedding_dim == 2:
                    low_dim_features = test_loss_dict["features"]
                else:
                    low_dim_features = TSNE(
                        n_components=2, random_state=42
                    ).fit_transform(test_loss_dict["features"])

                # Each row is the 2-D feature followed by the label entries;
                # the column names below expect exactly two label values.
                columns = ["f1", "f2", "label1", "label2"]
                data = [
                    [*x, *y]
                    for (x, y) in zip(low_dim_features, test_loss_dict["labels"])
                ]
                table = wandb.Table(data=data, columns=columns)
                df = pd.DataFrame(data, columns=columns)
                fig, ax = plt.subplots()
                sns.scatterplot(
                    data=df,
                    x="f1",
                    y="f2",
                    hue="label1",
                    size="label2",
                    ax=ax,
                    palette="RdBu",
                )

                log.update(
                    {
                        # Log the figure that was just drawn explicitly instead
                        # of relying on pyplot's notion of the current figure.
                        "test/scatter": wandb.Image(fig),
                        "test/reg_mse_loss": reg_mse_loss,
                        "test/reg_mae_loss": reg_mae_loss,
                        "val/reg_mse_loss": val_reg_mse_loss,
                        "val/reg_mae_loss": val_reg_mae_loss,
                        "pred_table": table,
                    }
                )
                plt.close(fig)

            wandb.log(log)


@hydra.main(version_base=None, config_path="configs", config_name="config_ssl")
def main(cfg):
    """Hydra entry point: apply global torch settings, seed, then train.

    Reads ``matmul_precision``, ``cudnn_benchmark`` and ``seed`` from ``cfg``;
    seeding is skipped when ``cfg.seed`` is ``None``.
    """
    # Global backend switches come straight from the config.
    torch.set_float32_matmul_precision(cfg.matmul_precision)
    torch.backends.cudnn.benchmark = cfg.cudnn_benchmark

    seed = cfg.seed
    if seed is not None:
        set_seed(seed)

    train(cfg)


# Run the Hydra-decorated entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
