import logging
import importlib
from pathlib import Path

import torch
import torch.nn.functional as F
import wandb
from sklearn.model_selection import train_test_split
from torch import nn
from torch.cuda.amp import GradScaler
from tqdm import trange
import hydra
from omegaconf import OmegaConf
import cv2
import numpy as np
from torchvision import transforms
from einops import rearrange

from experiments.data_nfn import SirenAndOriginalDataset
from experiments.utils import (
    count_parameters,
    get_device,
    set_logger,
    set_seed,
)
from nn.models import DWSModel, MLPModel
from experiments import data_nfn, data as dws_data

# Executed once at import time. Presumably configures the logging output used by
# the logging.info calls in this script -- confirm in experiments.utils.
set_logger()


def create_object(name, **kwargs):
    """Instantiate an object from a dotted ``"package.module.Attr"`` path.

    The text after the last dot is looked up as an attribute of the module
    named by the text before it, and that attribute is called with ``kwargs``.

    Args:
        name: Fully qualified dotted path; must contain at least one dot.
        **kwargs: Keyword arguments forwarded to the resolved callable.

    Returns:
        Whatever the resolved callable returns.
    """
    dotted_module, attr_name = name.rsplit(".", 1)
    factory = getattr(importlib.import_module(dotted_module), attr_name)
    return factory(**kwargs)


def params_to_func_params(params):
    """Convert our WeightSpaceFeatures object to a tuple of parameters for the functional model.

    ``params`` is unpacked into two parallel sequences (weights, biases). For
    every layer, dimension 1 of both tensors must have size 1; it is squeezed
    away and the results are interleaved as ``(w0, b0, w1, b1, ...)``.
    """
    flat = []
    for layer_weight, layer_bias in zip(*params):
        # Both tensors must carry a singleton dimension at index 1.
        assert layer_weight.shape[1] == 1 and layer_bias.shape[1] == 1
        flat.extend((layer_weight.squeeze(1), layer_bias.squeeze(1)))
    return tuple(flat)


@torch.no_grad()
def evaluate(model, loader, device, batch_siren, log_n_imgs=0, num_batches=None):
    """Compute the mean per-sample reconstruction MSE of ``model`` over ``loader``.

    For every batch the model predicts per-layer deltas, the deltas are added
    to the input weights/biases, ``batch_siren`` renders the updated parameters
    and the rendering (reshaped to 28 rows) is compared with the target image.

    Args:
        model: Module mapping ``(weights, biases)`` to ``(delta_weights, delta_biases)``.
        loader: Iterable yielding ``(params, img)`` batches. ``params`` is either
            an object exposing ``.to()``, ``.weights`` and ``.biases``, or a plain
            ``(weights, biases)`` pair of tensor sequences.
        device: Device the batch is moved to.
        batch_siren: Callable rendering ``(weights, biases)`` into an output that
            can be rearranged with ``"b (h w) c -> b c h w"``.
        log_n_imgs: Number of images from the first batch to wrap as
            ``wandb.Image``; clamped to the size of that batch.
        num_batches: If given, stop after this many batches.

    Returns:
        dict with ``"avg_loss"`` (0-dim CPU tensor), ``"imgs/gt"`` and
        ``"imgs/pred"`` (lists of ``wandb.Image``, empty when ``log_n_imgs`` is 0).

    Note:
        ``torch.cat`` raises if no batch was evaluated (empty loader or
        ``num_batches=0``).
    """
    # Remember the caller's mode so evaluation does not silently flip it.
    was_training = model.training
    model.eval()

    imgs, preds = [], []
    losses = []
    for i, batch in enumerate(loader):
        if num_batches is not None and i >= num_batches:
            break
        params, img = batch
        img = img.to(device)
        weights, biases = _unpack_eval_params(params, device)

        delta_weights, delta_biases = model((weights, biases))
        new_weights = [w + dw for w, dw in zip(weights, delta_weights)]
        new_biases = [b + db for b, db in zip(biases, delta_biases)]

        outs = batch_siren(new_weights, new_biases)
        outs = rearrange(outs, "b (h w) c -> b c h w", h=28)
        # Keep one loss value per sample so the final mean is not skewed by a
        # smaller last batch.
        losses.append(((outs - img) ** 2).mean(dim=(1, 2, 3)).cpu())

        if i == 0 and log_n_imgs > 0:
            # Clamp to the real batch size; the first batch may be smaller than
            # the requested number of images.
            n_log = min(log_n_imgs, img.shape[0])
            imgs.extend(wandb.Image(img[n].cpu().numpy()) for n in range(n_log))
            preds.extend(
                wandb.Image(outs[n].clamp(min=0, max=1).cpu().numpy())
                for n in range(n_log)
            )

    avg_loss = torch.cat(losses).mean()

    model.train(was_training)
    return {
        "avg_loss": avg_loss,
        "imgs/gt": imgs,
        "imgs/pred": preds,
    }


def _unpack_eval_params(params, device):
    """Move one batch of parameters to ``device`` and return ``(weights, biases)``.

    Objects exposing ``weights``/``biases`` are moved with their own ``.to()``.
    A plain ``(weights, biases)`` pair gets the same permutes that ``train``
    applies to such batches before feeding the model.
    """
    if hasattr(params, "weights") and hasattr(params, "biases"):
        params = params.to(device)
        return params.weights, params.biases
    raw_weights, raw_biases = params
    weights = [w.to(device).permute(0, 3, 2, 1) for w in raw_weights]
    biases = [b.to(device).permute(0, 2, 1) for b in raw_biases]
    return weights, biases


def train(cfg):
    """Train a weight-space model that predicts SIREN parameter deltas.

    Each step feeds a batch of network weights/biases to the model, adds the
    predicted deltas to the inputs, renders the result with ``batch_siren`` and
    minimises the MSE against the target image from the dataset. Metrics and
    sample images are logged to wandb; ``latest.ckpt`` and ``best_val.ckpt`` are
    written to ``<hydra output dir>/<wandb run id>/``.

    Args:
        cfg: Hydra/OmegaConf config. Keys read here include ``wandb``, ``model``,
            ``data``, ``optim``, ``scheduler``, ``use_scheduler``, ``lr``,
            ``batch_size``, ``seed``, ``gpu``, ``style``, ``siren_path``,
            ``num_workers``, ``compile``, ``compile_kwargs``, ``load_ckpt``,
            ``n_epochs``, ``gradscaler``, ``autocast``, ``clip_grad``,
            ``clip_grad_max_norm``, ``log_n_imgs`` and ``eval_every``.

    Raises:
        ValueError: If the dataset configuration matches neither supported
            dataset branch.
    """
    OmegaConf.resolve(cfg)
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()

    if cfg.wandb.name is None:
        cfg.wandb.name = (
            f"mnist_clf_{cfg.model.name}_lr_{cfg.lr}"
            f"_bs_{cfg.batch_size}_seed_{cfg.seed}"
        )

    kernel = np.ones((3, 3), np.uint8)
    style_to_function = {
        "dilate": lambda im: cv2.dilate(im, kernel, iterations=1),
    }

    wandb.init(
        **cfg.wandb,
        settings=wandb.Settings(start_method="fork"),
        config=cfg,
    )

    device = get_device(gpus=cfg.gpu)
    is_dws = cfg.data.name == "dws_mnist"

    # Target-image pipeline: PIL -> ndarray -> style function -> tensor.
    data_tfm = transforms.Compose([
        transforms.Lambda(np.array),
        transforms.Lambda(style_to_function[cfg.style]),
        transforms.ToTensor(),
    ])

    if cfg.data.cls.endswith("SirenAndOriginalDataset"):
        dset = SirenAndOriginalDataset(cfg.siren_path, "randinit_smaller", "./experiments/data", data_tfm)

        # Fixed index split: 45k train / 5k val / 10k test.
        train_set = torch.utils.data.Subset(dset, range(45_000))
        val_set = torch.utils.data.Subset(dset, range(45_000, 50_000))
        test_set = torch.utils.data.Subset(dset, range(50_000, 60_000))
        batch_siren = data_nfn.BatchSiren(dset.data_type, input_init=data_nfn.get_mgrid(28, 2)).to(device)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=cfg.batch_size, shuffle=True, num_workers=8, drop_last=True)
        # NOTE(review): drop_last=True here discards the final partial validation
        # batch (5000 is not a multiple of 32) -- confirm this is intended.
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=False, num_workers=8, drop_last=True)
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=cfg.batch_size, num_workers=8)
    elif is_dws:
        train_set = create_object(cfg.data.cls, **cfg.data.train)
        val_set = create_object(cfg.data.cls, **cfg.data.val)
        test_set = create_object(cfg.data.cls, **cfg.data.test)

        batch_siren = dws_data.BatchSiren(2, inr_d_out=1).to(device)

        train_loader = torch.utils.data.DataLoader(
            dataset=train_set,
            batch_size=cfg.batch_size,
            shuffle=True,
            num_workers=cfg.num_workers,
            pin_memory=True,
        )
        val_loader = torch.utils.data.DataLoader(
            dataset=val_set,
            batch_size=cfg.batch_size,
            num_workers=cfg.num_workers,
            shuffle=False,
            pin_memory=True,
        )
        test_loader = torch.utils.data.DataLoader(
            dataset=test_set,
            batch_size=cfg.batch_size,
            shuffle=False,
            num_workers=cfg.num_workers,
            pin_memory=True,
        )
    else:
        # Fail loudly instead of hitting an unbound-variable error further down.
        raise ValueError(
            f"unsupported dataset config: cls={cfg.data.cls!r}, name={cfg.data.name!r}"
        )

    logging.info(
        f"train size {len(train_set)}, "
        f"val size {len(val_set)}, "
        f"test size {len(test_set)}"
    )

    # Derive per-layer shapes from the first training sample.
    point = train_set[0][0]
    if is_dws:
        weight_shapes = tuple(w.shape[:2] for w in point.weights)
        bias_shapes = tuple(b.shape[:1] for b in point.biases)
    else:
        weight_shapes = tuple(w.transpose(-1, -2).shape[1:] for w in point[0])
        bias_shapes = tuple(b.shape[1:] for b in point[1])

    layer_layout = [weight_shapes[0][0]] + [b[0] for b in bias_shapes]

    logging.info(f"weight shapes: {weight_shapes}, bias shapes: {bias_shapes}")

    # todo: make defaults for MLP so that parameters for MLP and DWS are the same
    if cfg.model.name == "mlp":
        model = MLPModel(
            in_dim=sum(w.numel() for w in weight_shapes + bias_shapes),
            **cfg.model.kwargs,
        ).to(device)
    elif cfg.model.name == "dwsnet":
        model = DWSModel(
            weight_shapes=weight_shapes, bias_shapes=bias_shapes, **cfg.model.kwargs
        ).to(device)
    else:
        model = create_object(
            cfg.model.cls,
            layer_layout=layer_layout,
            **cfg.model.kwargs,
        ).to(device)

    logging.info(f"number of parameters: {count_parameters(model)}")

    if cfg.compile:
        model = torch.compile(model, **cfg.compile_kwargs)

    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = create_object(cfg.optim.cls, params=parameters, **cfg.optim.kwargs)
    # Stays None when scheduling is disabled so later code can test for it.
    scheduler = None
    if cfg.use_scheduler:
        scheduler = create_object(
            cfg.scheduler.cls,
            optimizer=optimizer,
            **cfg.scheduler.kwargs,
        )

    # inf guarantees the first evaluation is recorded as the best so far.
    best_val_loss = float("inf")
    best_test_results, best_val_results = None, None
    test_loss = -1.0
    global_step = 0
    start_epoch = 0

    if cfg.load_ckpt:
        # map_location keeps resuming independent of the device the checkpoint
        # was written from.
        # NOTE(review): checkpoints written here also pickle `cfg`; newer torch
        # releases may need weights_only=False to load them -- confirm against
        # the installed version.
        ckpt = torch.load(cfg.load_ckpt, map_location=device)
        model.load_state_dict(ckpt["model"])
        if "optimizer" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer"])
        if scheduler is not None and "scheduler" in ckpt:
            scheduler.load_state_dict(ckpt["scheduler"])
        if "epoch" in ckpt:
            # The stored epoch had already finished when the checkpoint was
            # written, so resume with the following one.
            start_epoch = ckpt["epoch"] + 1
        if "global_step" in ckpt:
            global_step = ckpt["global_step"]
        logging.info(f"loaded checkpoint {cfg.load_ckpt}")

    epoch_iter = trange(start_epoch, cfg.n_epochs)
    model.train()

    # Checkpoints go into a directory named after the last component of the
    # wandb run path.
    ckpt_dir = Path(hydra_cfg.runtime.output_dir) / wandb.run.path.split("/")[-1]
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    scaler = GradScaler(**cfg.gradscaler)
    autocast_kwargs = dict(cfg.autocast)
    autocast_kwargs["dtype"] = getattr(torch, cfg.autocast.dtype, torch.float32)
    for epoch in epoch_iter:
        for i, batch in enumerate(train_loader):
            optimizer.zero_grad()

            params, img = batch
            img = img.to(device)
            if is_dws:
                params = params.to(device)
                inputs = (params.weights, params.biases)
            else:
                raw_weights, raw_biases = params
                # Move once, then permute into the layout fed to the model.
                inputs = (
                    [w.to(device).permute(0, 3, 2, 1) for w in raw_weights],
                    [b.to(device).permute(0, 2, 1) for b in raw_biases],
                )

            with torch.autocast(**autocast_kwargs):
                delta_weights, delta_biases = model(inputs)
                # The model predicts residuals on top of the input parameters.
                new_weights = [w + dw for w, dw in zip(inputs[0], delta_weights)]
                new_biases = [b + db for b, db in zip(inputs[1], delta_biases)]

                outs = batch_siren(new_weights, new_biases)
                outs = rearrange(outs, "b (h w) c -> b c h w", h=28)
                loss = ((outs - img) ** 2).mean()

            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on true gradients.
            scaler.unscale_(optimizer)
            loss_value = loss.item()
            log = {
                "train/loss": loss_value,
                "global_step": global_step,
            }
            if cfg.clip_grad:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    parameters, cfg.clip_grad_max_norm
                )
                log["grad_norm"] = grad_norm
            scaler.step(optimizer)
            scaler.update()

            if i == 0:
                # Never index past the end of the batch.
                n_log = min(cfg.log_n_imgs, img.shape[0])
                log["img/train/gt"] = [
                    wandb.Image(img[n].cpu().numpy()) for n in range(n_log)
                ]
                log["img/train/pred"] = [
                    wandb.Image(outs[n].clamp(min=0, max=1).detach().cpu().numpy())
                    for n in range(n_log)
                ]

            if scheduler is not None:
                log["lr"] = scheduler.get_last_lr()[0]
                scheduler.step()
            wandb.log(log)

            epoch_iter.set_description(
                f"[{epoch} {i+1}], train loss: {loss_value:.3f}, test_loss: {test_loss:.3f}"
            )
            global_step += 1

        if (epoch + 1) % cfg.eval_every == 0 or epoch == cfg.n_epochs - 1:
            _save_checkpoint(
                ckpt_dir / "latest.ckpt",
                model, optimizer, scheduler, epoch, global_step, cfg,
            )

            val_loss_dict = evaluate(model, val_loader, device, batch_siren,
                                     log_n_imgs=cfg.log_n_imgs)
            test_loss_dict = evaluate(model, test_loader, device, batch_siren,
                                      log_n_imgs=cfg.log_n_imgs)
            val_loss = val_loss_dict["avg_loss"]
            test_loss = test_loss_dict["avg_loss"]
            # Only a bounded number of training batches is scored.
            train_loss_dict = evaluate(model, train_loader, device, batch_siren,
                                       log_n_imgs=cfg.log_n_imgs, num_batches=100)

            # Always accept the first evaluation so the "best" entries logged
            # below are never None.
            if best_val_results is None or val_loss < best_val_loss:
                _save_checkpoint(
                    ckpt_dir / "best_val.ckpt",
                    model, optimizer, scheduler, epoch, global_step, cfg,
                )
                best_test_results = test_loss_dict
                best_val_results = val_loss_dict
                best_val_loss = val_loss

            log = {
                "train/avg_loss": train_loss_dict["avg_loss"],
                "val/best_loss": best_val_results["avg_loss"],
                "test/best_loss": best_test_results["avg_loss"],
                **{f"val/{k}": v for k, v in val_loss_dict.items()},
                **{f"test/{k}": v for k, v in test_loss_dict.items()},
                "epoch": epoch,
                "global_step": global_step,
            }

            wandb.log(log)


def _save_checkpoint(path, model, optimizer, scheduler, epoch, global_step, cfg):
    """Write model/optimizer (and scheduler, if any) state plus progress counters to ``path``."""
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "cfg": cfg,
        "global_step": global_step,
    }
    if scheduler is not None:
        # Needed so a resumed run continues the learning-rate schedule.
        state["scheduler"] = scheduler.state_dict()
    torch.save(state, path)


@hydra.main(config_path="configs", config_name="config_style", version_base=None)
def main(cfg):
    """Hydra entry point: apply global torch settings, seed if requested, then train.

    Reads ``matmul_precision``, ``cudnn_benchmark`` and ``seed`` from ``cfg``;
    seeding is skipped when ``seed`` is ``None``.
    """
    torch.set_float32_matmul_precision(cfg.matmul_precision)
    torch.backends.cudnn.benchmark = cfg.cudnn_benchmark

    seed_is_set = cfg.seed is not None
    if seed_is_set:
        set_seed(cfg.seed)

    train(cfg)


# Script entry point: invoke the Hydra-decorated main() only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
