import copy
from typing import Dict, List, Tuple

import hydra
import numpy as np
import omegaconf
import torch
from loguru import logger
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from tqdm import tqdm

import wandb
from dataset import get_dataloader
from utils import initialize_and_load_model, retry, seed_everything, test

# Select the GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Make float64 the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float64)


def interpolate_weights(
    model_a: torch.nn.Module, model_b: torch.nn.Module, lam: float
) -> Dict[str, torch.Tensor]:
    """Linearly interpolate the state dicts of two models.

    Every entry of ``model_a.state_dict()`` is combined with the entry of the
    same name in ``model_b.state_dict()`` as ``(1 - lam) * a + lam * b``.

    Args:
        model_a: Model whose state-dict keys drive the interpolation; its
            entries are weighted by ``1 - lam``.
        model_b: Model whose entries are weighted by ``lam``.
        lam: Interpolation coefficient (a Python number or a 0-dim tensor).

    Returns:
        A new dict mapping each key of ``model_a``'s state dict to a freshly
        allocated interpolated tensor. Neither model is modified.

    Raises:
        KeyError: If ``model_b`` has no entry for a key present in ``model_a``.
    """
    state_a = model_a.state_dict()
    state_b = model_b.state_dict()

    missing = [name for name in state_a if name not in state_b]
    if missing:
        raise KeyError(f"model_b has no state-dict entries for: {missing}")

    # The arithmetic below allocates a new tensor for every entry, so the
    # result never aliases either model and no defensive deep copy is needed.
    return {
        name: (1 - lam) * tensor_a + lam * state_b[name]
        for name, tensor_a in state_a.items()
    }


def interpolate_and_test(
    model_a: torch.nn.Module,
    model_b: torch.nn.Module,
    lambdas: torch.Tensor,
    device: torch.device,
    loaders: Dict[str, DataLoader],
    process_train_data: bool = True,
) -> Tuple[Dict[str, List[float]], Dict[str, List[float]]]:
    """Evaluate the models obtained by interpolating between two models.

    For every coefficient in ``lambdas`` the interpolated weights are loaded
    into a scratch copy of ``model_a`` and that copy is passed to ``test``
    together with each loader.

    Args:
        model_a: Model weighted by ``1 - lam``; also the template that is
            deep-copied to hold the interpolated weights.
        model_b: Model weighted by ``lam``.
        lambdas: Interpolation coefficients, evaluated in iteration order.
        device: Device the scratch model is moved to before evaluation.
        loaders: Mapping from phase name (e.g. ``"train"``, ``"test"``) to
            the data loader evaluated for that phase.
        process_train_data: When ``False`` the ``"train"`` loader is skipped
            and its result lists stay empty.

    Returns:
        A ``(losses, accuracies)`` tuple. Each element maps a phase name to
        the list of values returned by ``test``, one per evaluated lambda.
    """
    losses: Dict[str, List[float]] = {"train": [], "test": []}
    accuracies: Dict[str, List[float]] = {"train": [], "test": []}

    skip_train = not process_train_data
    if skip_train and "train" in loaders:
        # Report the skip once up front instead of once per lambda.
        logger.info("Debug mode is detected. Training data evaluation is skipped.")

    # Scratch model that receives every interpolated state. Module.to()
    # moves the module in place, so a single call before the loop suffices.
    model_a_interpolated = copy.deepcopy(model_a).to(device)
    for lam in tqdm(lambdas, desc="Interpolating"):
        model_a_interpolated.load_state_dict(
            interpolate_weights(model_a, model_b, lam)
        )
        for phase, loader in loaders.items():
            if phase == "train" and skip_train:
                continue  # skip processing of the training data
            loss, acc = test(model_a_interpolated, loader)
            # setdefault lets loaders with phase names other than
            # "train"/"test" be recorded instead of raising KeyError.
            losses.setdefault(phase, []).append(loss)
            accuracies.setdefault(phase, []).append(acc)
    return losses, accuracies


@retry(max_attempts=3, timeout=120, initial_wait=15, backoff_factor=4)
def wandb_logging(
    cfg: DictConfig,
    results: Dict[str, Dict[str, List[float]]],
    lambdas: torch.Tensor,
    mode: str,
    strategy: str,
) -> None:
    """Log interpolation results and their barriers to wandb.

    For every non-empty curve in ``results`` the barrier is the largest
    deviation of the curve from the straight line joining its first and last
    values: ``value - line`` for keys containing ``"loss"`` and
    ``line - value`` for keys containing ``"accuracy"``. Keys matching
    neither are logged as-is without a barrier entry.

    Args:
        cfg: Run configuration (unused here; kept for the call signature).
        results: Mapping ``metric_name -> phase -> values``, one value per
            interpolation coefficient.
        lambdas: Interpolation coefficients matching the curves. If a curve
            and ``lambdas`` differ in length, only the common prefix is used.
        mode: Model type name, logged verbatim under ``"mode"``.
        strategy: Matching strategy name, logged verbatim under ``"strategy"``.
    """
    log_dict = dict(results)
    # Plain floats keep the logged barriers as ordinary numbers rather than
    # 0-dim tensors.
    lam_values = [float(lam) for lam in lambdas]

    for metric_name, curves in results.items():
        if "loss" in metric_name:
            sign = 1.0  # barrier = curve above the straight line
        elif "accuracy" in metric_name:
            sign = -1.0  # barrier = curve below the straight line
        else:
            # No barrier definition for this metric; skipping avoids using an
            # unbound or stale ``barrier`` value.
            continue

        for phase, values in curves.items():
            if not len(values):
                continue
            start, end = values[0], values[-1]
            barrier = max(
                sign * (value - ((1 - lam) * start + lam * end))
                for lam, value in zip(lam_values, values)
            )
            log_dict[f"{phase}_barrier_{metric_name}"] = barrier

    log_dict.update({"mode": mode, "strategy": strategy})

    wandb.log(log_dict)


def initialize_and_interpolate_test_models(
    strategy, mode, cfg, lambdas, device, loaders
) -> Dict[str, Dict[str, List[float]]]:
    """Load model pairs for ``mode`` and evaluate their interpolation.

    Model B is loaded from ``./output/model/{mode}_{cfg.model_b}``. For each
    variant of model A (the unmodified "naive" checkpoint plus the permuted
    checkpoints available for ``mode``) the corresponding checkpoint is
    loaded and interpolated against model B with ``interpolate_and_test``.

    Args:
        strategy: Prefix of the permuted checkpoint names.
        mode: Model type; selects which variants of model A exist.
        cfg: Configuration providing ``model_a``, ``model_b``, ``depth``,
            ``alpha``, ``n_tree``, ``seed_a``, ``seed_b`` and
            ``dataset.input_dim`` / ``dataset.output_dim``.
        lambdas: Interpolation coefficients.
        device: Device the models are moved to.
        loaders: Mapping from phase name to data loader.

    Returns:
        Dict with entries ``loss_{variant}`` and ``accuracy_{variant}`` for
        every evaluated variant, each holding the per-phase result lists.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported modes.
    """
    model_b_name = f"{mode}_{cfg.model_b}"
    model_a_suffix = f"{mode}_{cfg.model_a}"

    # Variant name -> checkpoint-name prefix; insertion order is the
    # evaluation order.
    model_variants = {"naive": ""}
    if mode == "non_oblivious":
        model_variants["perm"] = f"{strategy}_permuted_noop_"
        if cfg.depth <= 3:
            model_variants["flip_perm"] = f"{strategy}_permuted_flip_"
    elif mode in ("decision_list", "decision_list_no_binary"):
        model_variants["perm"] = f"{strategy}_permuted_noop_"
    elif mode == "oblivious":
        model_variants["perm_noop"] = f"{strategy}_permuted_noop_"
        if cfg.depth <= 3:
            model_variants["perm_onlyflip"] = f"{strategy}_permuted_onlyflip_"
            model_variants["perm_onlyorder"] = f"{strategy}_permuted_onlyorder_"
            model_variants["perm_orderflip"] = f"{strategy}_permuted_orderflip_"
    elif mode in ("relu_mlp", "sigmoid_mlp"):
        model_variants["perm"] = f"{strategy}_permuted_"
    else:
        raise NotImplementedError(f"Unsupported mode: {mode!r}")

    def load_model(model_name):
        # Build the model for ``mode`` and load the named checkpoint.
        return initialize_and_load_model(
            f"./output/model/{model_name}",
            cfg.dataset.input_dim,
            cfg.dataset.output_dim,
            cfg.depth,
            cfg.alpha,
            cfg.n_tree,
            mode,
        ).to(device)

    # Model B is identical for every variant and is only read during
    # interpolation, so it is loaded once instead of once per variant.
    model_b = load_model(model_b_name)

    results = {}
    for variant_name, variant_prefix in model_variants.items():
        model_a = load_model(f"{variant_prefix}{model_a_suffix}")
        logger.info(f"model-a ({variant_name}) vs model-b")
        losses, accuracies = interpolate_and_test(
            model_a,
            model_b,
            lambdas,
            device,
            loaders,
            # Identical seeds indicate a debug run; training data is skipped.
            process_train_data=(cfg.seed_a != cfg.seed_b),
        )
        results[f"loss_{variant_name}"] = losses
        results[f"accuracy_{variant_name}"] = accuracies
    return results


@hydra.main(config_path="./config/", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    """Run the interpolation evaluation for every strategy/mode pair.

    Seeds the run with ``cfg.seed_a``, builds the train/test loaders, and for
    each combination that is not skipped evaluates 25 evenly spaced
    interpolation coefficients in ``[0, 1]``. When ``cfg.wandb`` is set, each
    combination is logged as its own offline wandb run; a logging failure is
    reported and the loop continues with the next combination.

    Args:
        cfg: Hydra configuration for the run.
    """
    seed_everything(cfg.seed_a)
    train_loader, test_loader = get_dataloader(
        cfg.dataset.name, batch_size=cfg.batch_size
    )
    loaders = {"train": train_loader, "test": test_loader}

    lambdas = torch.linspace(0, 1, steps=25)

    for strategy in ["wm", "am"]:
        for mode in [
            "non_oblivious",
            "oblivious",
            "decision_list",
            "decision_list_no_binary",
            "relu_mlp",
            "sigmoid_mlp",
        ]:
            if (cfg.depth > 3) and mode == "non_oblivious":
                logger.info(f"depth={cfg.depth}. {mode} is skipped.")
                continue
            if strategy == "am" and mode in ("relu_mlp", "sigmoid_mlp"):
                logger.info(f"stategy={strategy}. {mode} is skipped.")
                continue
            logger.info(f"Strategy: {strategy}, Mode: {mode}")
            results = initialize_and_interpolate_test_models(
                strategy, mode, cfg, lambdas, device, loaders
            )
            if not cfg.wandb:
                continue

            wandb.init(
                project=f"{cfg.wandb_project}",
                config=omegaconf.OmegaConf.to_container(
                    cfg, resolve=True, throw_on_missing=True
                ),
                mode="offline",
            )
            try:
                wandb_logging(cfg, results, lambdas, mode, strategy)
            except Exception:
                # Logging stays best-effort, but the failure is recorded and
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                logger.exception(
                    f"wandb logging failed (strategy={strategy}, mode={mode})."
                )
                # wandb.finish does not rely on wandb.run being set.
                wandb.finish(exit_code=1)
            else:
                wandb.finish()


# Run the Hydra-decorated entry point only when executed as a script.
if __name__ == "__main__":
    main()
