import logging
import importlib
from pathlib import Path
from copy import deepcopy

import torch
import torch.nn.functional as F
import wandb
from sklearn.model_selection import train_test_split
from torch import nn
from torch.cuda.amp import GradScaler
from tqdm import trange
import hydra
from omegaconf import OmegaConf
from einops import rearrange
from torchvision.utils import make_grid

from experiments.data import INRDataset
from experiments.utils import (
    count_parameters,
    get_device,
    set_logger,
    set_seed,
)
from experiments.mnist.flow_matching import noise_to_params, flow_matching_loss, VFWrapper
from nn.models import DWSModelForClassification, MLPModelForClassification
from nn.inr import INRWrapper
from diffusion import create_diffusion
from diffusion.timestep_sampler import UniformSampler
from nn.rtt_time_diff import flat_to_inputs, inputs_to_flat
# Module-level side effect: call the project's `set_logger` helper
# (imported from experiments.utils) once, when this module is imported.
set_logger()


def create_object(name, **kwargs):
    """Build an object from a fully qualified dotted path.

    ``name`` is split at its last dot into a module path and an attribute
    name. The module is imported, the attribute is looked up on it, and the
    attribute is called with ``kwargs``.

    Args:
        name: Dotted path such as ``"package.module.ClassName"``.
        **kwargs: Keyword arguments forwarded to the resolved callable.

    Returns:
        The value returned by the resolved callable.

    Raises:
        ValueError: If ``name`` contains no dot (nothing to split).
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_path, attr_name = name.rsplit(".", 1)
    factory = getattr(importlib.import_module(module_path), attr_name)
    return factory(**kwargs)


def scaled_all_reduce(cfg, tensors):
    """Sum-reduce ``tensors`` across processes and scale by ``1 / cfg.num_gpus``.

    The input tensors are modified in place. Only the default (sum)
    reduction of ``torch.distributed.all_reduce`` is used, and the reduced
    values are multiplied by the inverse of ``cfg.num_gpus``.

    Args:
        cfg: Configuration object exposing an integer ``num_gpus`` attribute.
        tensors: List of tensors to reduce in place.

    Returns:
        The same ``tensors`` list that was passed in.
    """
    world_size = cfg.num_gpus
    # Nothing to reduce with a single process. ``<= 1`` also covers a
    # CPU-only configuration (0 GPUs), which would otherwise reach
    # ``all_reduce`` and the division by ``num_gpus`` below.
    if world_size <= 1:
        return tensors
    # Queue every reduction asynchronously before waiting on any of them.
    pending = [
        torch.distributed.all_reduce(tensor, async_op=True) for tensor in tensors
    ]
    for handle in pending:
        handle.wait()
    # Turn the summed values into an average over the process group.
    scale = 1.0 / world_size
    for tensor in tensors:
        tensor.mul_(scale)
    return tensors


def accumulate(model1, model2, decay=0.9999):
    """Blend ``model2``'s parameters into ``model1`` in place (EMA update).

    For every named parameter of ``model1`` this computes
    ``p1 = decay * p1 + (1 - decay) * p2`` directly on the ``.data`` tensors,
    so no autograd history is recorded. With ``decay=0`` the parameters of
    ``model2`` are copied into ``model1``.

    Adapted from
    https://github.com/rosinality/stylegan2-pytorch/blob/master/train.py

    Args:
        model1: Module whose parameters are updated in place.
        model2: Module providing the new parameter values; it must expose a
            parameter for every parameter name of ``model1``.
        decay: Weight kept from ``model1``'s current values.

    Raises:
        KeyError: If ``model2`` lacks a parameter name present in ``model1``.
    """
    source_params = dict(model2.named_parameters())
    blend = 1 - decay

    for param_name, target in model1.named_parameters():
        target.data.mul_(decay).add_(source_params[param_name].data, alpha=blend)


class Normalize(nn.Module):
    """Affine (de)normalisation of per-layer weights and biases.

    ``forward`` maps each entry to ``in_coeff * (x - mean) / std`` and
    ``inverse`` undoes that mapping. Statistics are stored as plain Python
    lists with one entry per layer; three layers are configured.
    """

    def __init__(self):
        super().__init__()
        # Alternative per-layer statistics, currently disabled:
        #   weights mean: [-0.0001166215879493393, -3.2710825053072767e-06, 7.234242366394028e-05]
        #   weights std:  [0.06279338896274567, 0.01827024295926094, 0.11813738197088242]
        #   biases mean:  [4.912401891488116e-06, -3.210141949239187e-05, -0.012279038317501545]
        #   biases std:   [0.021347912028431892, 0.0109943225979805, 0.09998151659965515]
        n_layers = 3
        shared_std = 0.0552541275198261

        self.weights_mean = [0] * n_layers
        self.weights_std = [shared_std] * n_layers
        self.biases_mean = [0] * n_layers
        self.biases_std = [shared_std] * n_layers

        # Extra multiplicative factor applied after standardisation.
        self.in_coeff = 0.53

    def forward(self, weights, biases):
        """Normalise each layer's weights and biases.

        Args:
            weights: Sequence of per-layer weight values.
            biases: Sequence of per-layer bias values, indexed like ``weights``.

        Returns:
            Tuple ``(new_weights, new_biases)`` of lists.
        """
        scaled_weights, scaled_biases = [], []
        for idx, layer_weight in enumerate(weights):
            scaled_weights.append(
                self.in_coeff * (layer_weight - self.weights_mean[idx]) / self.weights_std[idx]
            )
            scaled_biases.append(
                self.in_coeff * (biases[idx] - self.biases_mean[idx]) / self.biases_std[idx]
            )
        return scaled_weights, scaled_biases

    def inverse(self, weights, biases):
        """Map normalised weights and biases back to their original scale.

        Args:
            weights: Sequence of normalised per-layer weight values.
            biases: Sequence of normalised per-layer bias values.

        Returns:
            Tuple ``(new_weights, new_biases)`` of lists.
        """
        restored_weights, restored_biases = [], []
        for idx, layer_weight in enumerate(weights):
            restored_weights.append(
                layer_weight * self.weights_std[idx] / self.in_coeff + self.weights_mean[idx]
            )
            restored_biases.append(
                biases[idx] * self.biases_std[idx] / self.in_coeff + self.biases_mean[idx]
            )
        return restored_weights, restored_biases
    

def run_diffusion_vlb(cfg, diffusion, model, timestep_sampler, inputs, model_kwargs):
    """Compute the weighted diffusion training loss for one batch.

    Timesteps and their importance weights are drawn from
    ``timestep_sampler`` (one per row of ``inputs``), the per-sample losses
    come from ``diffusion.training_losses``, and the ``"loss"`` entry is
    weighted and averaged.

    Args:
        cfg: Unused here; kept for signature compatibility.
        diffusion: Object exposing ``training_losses(model, x, t, model_kwargs=...)``.
        model: Network passed through to ``diffusion.training_losses``.
        timestep_sampler: Object exposing ``sample(batch_size, device)`` that
            returns ``(timesteps, weights)``.
        inputs: Batch tensor; its first dimension is the batch size.
        model_kwargs: Extra keyword arguments forwarded to the model.

    Returns:
        Tuple ``(loss, losses)``: the weighted mean loss and the raw
        dictionary returned by ``diffusion.training_losses``.
    """
    batch_size = inputs.shape[0]
    timesteps, vlb_weights = timestep_sampler.sample(batch_size, inputs.device)
    losses = diffusion.training_losses(
        model, inputs, timesteps, model_kwargs=model_kwargs
    )
    weighted = losses["loss"] * vlb_weights
    return weighted.mean(), losses


@torch.no_grad()
def evaluate(cfg, diffusion, model, noise, normalizer):
    """Sample INR parameters from fixed noise and decode them into images.

    The noise is flattened with ``inputs_to_flat``, denoised with
    ``diffusion.p_sample_loop``, unflattened with ``flat_to_inputs``,
    de-normalised with ``normalizer.inverse`` and finally rendered by an
    ``INRWrapper`` into 28-pixel-high images.

    Args:
        cfg: Configuration; ``cfg.sample_kwargs`` is forwarded to
            ``diffusion.p_sample_loop``.
        diffusion: Diffusion object exposing ``p_sample_loop``.
        model: Denoising network. It is switched to eval mode for sampling
            and back to train mode before returning.
        noise: Arguments unpacked into ``inputs_to_flat`` (the caller in this
            file passes ``(noise_weights, noise_biases)``).
        normalizer: Object exposing ``inverse(weights, biases)``.

    Returns:
        Tuple ``(stats, (weights, biases))``. ``stats`` holds the rendered
        ``"imgs"`` plus, for every layer index ``i``, the entries
        ``sampled/{i}/weights_avg``, ``sampled/{i}/weights_std``,
        ``sampled/{i}/biases_avg`` and ``sampled/{i}/biases_std``.
    """
    model.eval()
    noise_flat, noise_shape = inputs_to_flat(*noise)
    sample = diffusion.p_sample_loop(
        model,
        noise_flat.shape,
        noise=noise_flat,
        clip_denoised=False,
        model_kwargs=dict(input_shapes=noise_shape),
        **cfg.sample_kwargs,
    )
    weights, biases = flat_to_inputs(sample, noise_shape)
    # Undo the training-time normalisation before rendering.
    weights, biases = normalizer.inverse(weights, biases)
    siren_model = INRWrapper(inr_kwargs=dict(in_dim=2, n_layers=3, up_scale=16, out_channels=1))
    imgs = siren_model(weights, biases)
    imgs = rearrange(imgs, "b (h w) 1 -> b 1 h w", h=28)

    model.train()

    # Keys carry the layer index. A constant key would make each layer
    # overwrite the previous one, leaving only the last layer's statistic.
    stats = {"imgs": imgs}
    for i, w in enumerate(weights):
        stats[f"sampled/{i}/weights_avg"] = w.mean()
        stats[f"sampled/{i}/weights_std"] = w.std()
    for i, b in enumerate(biases):
        stats[f"sampled/{i}/biases_avg"] = b.mean()
        stats[f"sampled/{i}/biases_std"] = b.std()
    return stats, (weights, biases)


@torch.inference_mode()
def test_epoch(cfg, diffusion, model, test_loader, timestep_sampler, device, normalizer):
    """Average the diffusion losses of ``model`` over ``test_loader``.

    Each batch is moved to ``device``, normalised, flattened and passed to
    ``run_diffusion_vlb``. Every loss term is reduced to its mean, averaged
    across processes with ``scaled_all_reduce`` and collected per key.

    Args:
        cfg: Configuration forwarded to the loss and reduction helpers.
        diffusion: Diffusion object used to compute the losses.
        model: Network to evaluate; it is put in eval mode and left there.
        test_loader: Iterable of batches exposing ``to``, ``weights`` and
            ``biases``.
        timestep_sampler: Sampler passed to ``run_diffusion_vlb``.
        device: Device the batches are moved to.
        normalizer: Callable applied to ``(batch.weights, batch.biases)``.

    Returns:
        Dict mapping ``test/<loss name>`` to its mean over all batches.
    """
    model.eval()
    history = {}
    for batch in test_loader:
        batch = batch.to(device)
        normalized = normalizer(batch.weights, batch.biases)
        inputs_flat, input_shapes = inputs_to_flat(*normalized)
        loss, loss_dict = run_diffusion_vlb(
            cfg,
            diffusion,
            model,
            timestep_sampler,
            inputs_flat,
            dict(input_shapes=input_shapes),
        )
        # Store the weighted loss alongside the raw terms so it is reduced
        # and averaged the same way.
        loss_dict["loss"] = loss.view(1)
        for key, value in loss_dict.items():
            reduced = scaled_all_reduce(cfg, [value.mean()])[0].item()
            history.setdefault(key, []).append(reduced)
    return {f"test/{key}": sum(values) / len(values) for key, values in history.items()}

def train(cfg):
    """Train a diffusion model over INR weights/biases and log samples to wandb.

    Builds the datasets and loaders, the model (plus an optional EMA copy),
    the optimizer and an optional LR scheduler from ``cfg``, optionally
    resumes from ``cfg.load_ckpt``, and runs the mixed-precision training
    loop. Every ``cfg.eval_every`` epochs, and after the last epoch, it
    writes ``latest.ckpt`` into a run-specific directory and logs images
    decoded from weights sampled with ``evaluate``.

    Args:
        cfg: Hydra/OmegaConf configuration. Fields read here include
            ``wandb``, ``gpu``, ``data``, ``model``, ``optim``, ``scheduler``,
            ``use_scheduler``, ``compile``, ``compile_kwargs``, ``ema``,
            ``ema_decay``, ``load_ckpt``, ``n_samples``, ``n_epochs``,
            ``eval_every``, ``debug``, ``clip_grad``, ``clip_grad_max_norm``,
            ``gradscaler``, ``autocast``, ``diffusion``, ``batch_size``,
            ``num_workers``, ``lr`` and ``seed``.

    Returns:
        None.
    """
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    # NOTE(review): this directory is created but nothing is written to it;
    # checkpoints go to the run-specific directory created further below.
    ckpt_dir = Path(hydra_cfg.runtime.output_dir) / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    if cfg.wandb.name is None:
        cfg.wandb.name = (
            f"mnist_clf_{cfg.model.name}_lr_{cfg.lr}"
            f"_bs_{cfg.batch_size}_seed_{cfg.seed}"
        )

    wandb.init(
        **cfg.wandb,
        settings=wandb.Settings(start_method="fork"),
        config=cfg,
    )
    OmegaConf.resolve(cfg)

    device = get_device(gpus=cfg.gpu)

    # load dataset
    train_set = create_object(cfg.data.cls, **cfg.data.train)
    test_set = create_object(cfg.data.cls, **cfg.data.test)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )
    # Only consumed by the (currently disabled) test_epoch call below.
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )

    logging.info(
        f"train size {len(train_set)}, "
    )

    # Derive the per-layer shapes from a single sample.
    point = train_set[0]
    weight_shapes = tuple(w.shape[:2] for w in point.weights)
    bias_shapes = tuple(b.shape[:1] for b in point.biases)

    layer_layout = [weight_shapes[0][0]] + [b[0] for b in bias_shapes]

    logging.info(f"weight shapes: {weight_shapes}, bias shapes: {bias_shapes}")

    # todo: make defaults for MLP so that parameters for MLP and DWS are the same
    if cfg.model.name == "mlp":
        model = MLPModelForClassification(
            in_dim=sum([w.numel() for w in weight_shapes + bias_shapes]),
            **cfg.model.kwargs,
        ).to(device)
    elif cfg.model.name == "dwsnet":
        model = DWSModelForClassification(
            weight_shapes=weight_shapes, bias_shapes=bias_shapes, **cfg.model.kwargs
        ).to(device)
    else:
        model = create_object(
            cfg.model.cls,
            layer_layout=layer_layout,
            **cfg.model.kwargs,
        ).to(device)

    logging.info(f"number of parameters: {count_parameters(model)}")

    # Keep a handle on the uncompiled module. The torch.compile wrapper
    # shares its parameters but exposes them under prefixed names, which
    # would break the name-based lookup in `accumulate`.
    raw_model = model

    if cfg.ema:
        # An exact copy of the model that was actually built, whichever
        # branch above produced it. Created before compilation so the copy
        # is a plain module.
        ema = deepcopy(raw_model)
        for p in ema.parameters():
            p.requires_grad = False
    else:
        ema = None

    if cfg.compile:
        model = torch.compile(model, **cfg.compile_kwargs)

    parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = create_object(cfg.optim.cls, params=parameters, **cfg.optim.kwargs)
    # Always bound, so the checkpoint code can test it even when no
    # scheduler is configured.
    scheduler = None
    if cfg.use_scheduler:
        scheduler = create_object(
            cfg.scheduler.cls,
            optimizer=optimizer,
            **cfg.scheduler.kwargs,
        )

    # Placeholders shown in the progress bar; never updated in this function.
    test_acc, test_loss = -1.0, -1.0
    global_step = 0
    start_epoch = 0

    if cfg.load_ckpt:
        # The checkpoint stores the pickled cfg next to the state dicts, so
        # it needs full unpickling: only load checkpoints you trust.
        # Loading onto the CPU avoids depending on the device it was saved
        # from; load_state_dict copies the values onto the live tensors.
        ckpt = torch.load(cfg.load_ckpt, map_location="cpu", weights_only=False)
        model.load_state_dict(ckpt["model"])
        if ema is not None:
            if ckpt.get("ema") is not None:
                ema.load_state_dict(ckpt["ema"])
            else:
                # No EMA weights stored: restart the average from the
                # restored model rather than from the random initialisation.
                accumulate(ema, raw_model, 0)
        if "optimizer" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer"])
        if scheduler is not None and ckpt.get("scheduler") is not None:
            scheduler.load_state_dict(ckpt["scheduler"])
        if "epoch" in ckpt:
            # "epoch" is the index of the last completed epoch, so resume
            # with the one after it.
            start_epoch = ckpt["epoch"] + 1
        if "global_step" in ckpt:
            global_step = ckpt["global_step"]
        logging.info(f"loaded checkpoint {cfg.load_ckpt}")

    # Fixed noise for periodic sampling: every sample reuses the first
    # row, so all cfg.n_samples entries start from the same noise.
    batch_ = next(iter(train_loader)).to(device)
    noise_weights = [torch.randn_like(w)[:cfg.n_samples] for w in batch_.weights]
    noise_biases = [torch.randn_like(b)[:cfg.n_samples] for b in batch_.biases]
    for i in range(len(batch_.weights)):
        noise_weights[i][:] = noise_weights[i][:1]
        noise_biases[i][:] = noise_biases[i][:1]

    # Log the ground-truth images and parameter statistics of that batch.
    siren_model = INRWrapper(inr_kwargs=dict(in_dim=2, n_layers=3, up_scale=16, out_channels=1))
    imgs = siren_model(batch_.weights, batch_.biases)
    imgs = rearrange(imgs, "b (h w) 1 -> b 1 h w", h=28)
    imgs = imgs.clamp(0, 1)
    img_grid = make_grid(imgs)
    log = {
        "gt_imgs": wandb.Image(img_grid),
        "epoch": start_epoch,
        "global_step": global_step,
        **{f"gt/{i}/weights_avg": w.mean() for i, w in enumerate(batch_.weights)},
        **{f"gt/{i}/weights_std": w.std() for i, w in enumerate(batch_.weights)},
        **{f"gt/{i}/biases_avg": b.mean() for i, b in enumerate(batch_.biases)},
        **{f"gt/{i}/biases_std": b.std() for i, b in enumerate(batch_.biases)},
    }
    wandb.log(log)
    normalizer = Normalize()

    diffusion = create_diffusion(**cfg.diffusion)
    timestep_sampler = UniformSampler(diffusion)

    epoch_iter = trange(start_epoch, cfg.n_epochs)
    model.train()

    # Checkpoints are written to a directory named after the wandb run id.
    ckpt_dir = Path(hydra_cfg.runtime.output_dir) / wandb.run.path.split("/")[-1]
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    scaler = GradScaler(**cfg.gradscaler)
    autocast_kwargs = dict(cfg.autocast)
    # The config stores the dtype by name; unknown names fall back to float32.
    autocast_kwargs["dtype"] = getattr(torch, cfg.autocast.dtype, torch.float32)
    for epoch in epoch_iter:
        for i, batch in enumerate(train_loader):
            optimizer.zero_grad()

            if cfg.debug:
                if i > 10:
                    break

            batch = batch.to(device)
            inputs = normalizer(batch.weights, batch.biases)
            inputs_flat, input_shapes = inputs_to_flat(*inputs)

            with torch.autocast(**autocast_kwargs):
                loss, loss_dict = run_diffusion_vlb(cfg, diffusion, model, timestep_sampler,
                                                    inputs_flat, model_kwargs=dict(input_shapes=input_shapes))
                loss = loss.mean()

            scaler.scale(loss).backward()
            # Unscale first so gradient clipping sees the true gradient norm.
            scaler.unscale_(optimizer)
            log = {
                "train/loss": loss.item(),
                "global_step": global_step,
            }
            if cfg.clip_grad:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    parameters, cfg.clip_grad_max_norm
                )
                log["grad_norm"] = grad_norm
            scaler.step(optimizer)
            scaler.update()

            if ema is not None:
                accumulate(ema, raw_model, cfg.ema_decay)

            if scheduler is not None:
                log["lr"] = scheduler.get_last_lr()[0]
                scheduler.step()
            wandb.log(log)

            epoch_iter.set_description(
                f"[{epoch} {i+1}], train loss: {loss.item():.3f}, test_loss: {test_loss:.3f}, test_acc: {test_acc:.3f}"
            )
            global_step += 1

        if (epoch + 1) % cfg.eval_every == 0 or epoch == cfg.n_epochs - 1:
            torch.save(
                {
                    "model": model.state_dict(),
                    "ema": ema.state_dict() if ema is not None else None,
                    "optimizer": optimizer.state_dict(),
                    # Saved so a resumed run continues the LR schedule; the
                    # loading code above already looks for this key.
                    "scheduler": scheduler.state_dict() if scheduler is not None else None,
                    "epoch": epoch,
                    "cfg": cfg,
                    "global_step": global_step,
                },
                ckpt_dir / "latest.ckpt",
            )

            # log_dict = test_epoch(cfg, diffusion, ema if cfg.ema else model, test_loader, timestep_sampler, device, normalizer)

            eval_dict, (weights, biases) = evaluate(cfg, diffusion, ema if cfg.ema else model,
                                               (noise_weights, noise_biases), normalizer)

            # NOTE(review): batch_ holds a full batch while the sampled
            # tensors have at most cfg.n_samples rows, so these differences
            # rely on the two being broadcast-compatible.
            mse_w = [(w_ - w).square().mean() for w_, w in zip(batch_.weights, weights)]
            mse_b = [(b_ - b).square().mean() for b_, b in zip(batch_.biases, biases)]
            imgs = eval_dict["imgs"]
            imgs = imgs.clamp(0, 1)
            img_grid = make_grid(imgs)
            log_dict = {
                "imgs": wandb.Image(img_grid),
                "epoch": epoch,
                "global_step": global_step,
                **{f"test/{i}/mse_w": w for i, w in enumerate(mse_w)},
                **{f"test/{i}/mse_b": b for i, b in enumerate(mse_b)},
            }

            wandb.log(log_dict)


@hydra.main(config_path="configs", config_name="config_diff", version_base=None)
def main(cfg):
    """Hydra entry point: apply global torch settings, seed, then train.

    Args:
        cfg: Configuration composed by Hydra from ``configs/config_diff``.
            Reads ``matmul_precision``, ``cudnn_benchmark`` and ``seed``
            before handing the whole config to ``train``.
    """
    torch.set_float32_matmul_precision(cfg.matmul_precision)
    torch.backends.cudnn.benchmark = cfg.cudnn_benchmark

    # Seeding is optional; a ``None`` seed leaves the RNGs untouched.
    seed = cfg.seed
    if seed is not None:
        set_seed(seed)

    train(cfg)


# Run the Hydra-decorated entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
