import argparse
import logging
import torch
import numpy as np
import warnings
from diffusers import DDPMScheduler

from data_utils import get_dataset
from pp_utils import PostProcess
from plot_helper import plot_rul
import matplotlib.pyplot as plt

# Drop UserWarning messages whose reported module name matches the regex
# "torch" (start-anchored match, so torch.* submodules are covered too).
warnings.filterwarnings("ignore", category=UserWarning, module="torch")

# Root logger: INFO and above, each line prefixed with a timestamp and level.
logging.basicConfig(
    format="[%(asctime)s] %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)


def load_flowbatt(input_shape, mask_size, device, battery_dataset):
    """Build the FlowBatt DiT model and load its trained weights.

    Args:
        input_shape: Per-sample input shape forwarded to ``DiT``.
        mask_size: Mask size forwarded to both ``DiT`` and ``InitializerNet``.
        device: Device the modules are moved to; checkpoint tensors are also
            mapped onto it so a GPU-saved checkpoint can load on CPU.
        battery_dataset: Dataset name used to pick the checkpoint file.

    Returns:
        Tuple ``(model, initializer, generate_samples, config)``. ``model`` is
        in eval mode with weights from
        ``./workspaces/trained_models_FM/{config.output_dir}_{battery_dataset}.pth``.
    """
    from nns import DiT, InitializerNet, generate_samples
    from Configs import TrainingConfig

    config = TrainingConfig()
    model = DiT(
        input_dim=config.sequence_length,
        input_shape=input_shape,
        num_blocks=config.num_blocks,
        num_channels=config.channels,
        class_dropout_prob=config.class_dropout_prob,
        mask_size=mask_size,
    ).to(device)

    # NOTE(review): the initializer is freshly constructed and no checkpoint
    # is loaded into it, nor is it switched to eval mode here — confirm this
    # is intended (e.g. that generate_samples handles it).
    initializer = InitializerNet(
        input_dim=config.sequence_length, mask_size=mask_size
    ).to(device)

    ckpt = f"./workspaces/trained_models_FM/{config.output_dir}_{battery_dataset}.pth"
    # map_location: without it, tensors saved from CUDA are restored to CUDA
    # and loading fails on a CPU-only machine.
    model.load_state_dict(torch.load(ckpt, weights_only=True, map_location=device))
    model.eval()

    return model, initializer, generate_samples, config


def load_diffbatt(input_shape, mask_size, device, battery_dataset):
    """Build the DiffBatt DiT model, load its weights and create a DDPM scheduler.

    Args:
        input_shape: Per-sample input shape forwarded to ``DiT``.
        mask_size: Mask size forwarded to both ``DiT`` and ``InitializerNet``.
        device: Device the modules are moved to; checkpoint tensors are also
            mapped onto it so a GPU-saved checkpoint can load on CPU.
        battery_dataset: Dataset name used to pick the checkpoint file.

    Returns:
        Tuple ``(model, initializer, generate_samples, config, scheduler)``.
        ``model`` is in eval mode with weights from
        ``./workspaces/trained_models_diffbatt/{config.output_dir}_{battery_dataset}.pth``.
        ``generate_samples`` comes from ``nns_diffbatt``, and ``scheduler`` is a
        ``DDPMScheduler`` with 1000 training timesteps.
    """
    from nns import DiT, InitializerNet
    from nns_diffbatt import generate_samples
    from Configs_diffbatt import TrainingConfig

    config = TrainingConfig()
    model = DiT(
        input_dim=config.sequence_length,
        input_shape=input_shape,
        num_blocks=config.num_blocks,
        num_channels=config.channels,
        class_dropout_prob=config.class_dropout_prob,
        mask_size=mask_size,
    ).to(device)

    # NOTE(review): the initializer is freshly constructed and no checkpoint
    # is loaded into it, nor is it switched to eval mode here — confirm this
    # is intended.
    initializer = InitializerNet(
        input_dim=config.sequence_length, mask_size=mask_size
    ).to(device)

    ckpt = f"./workspaces/trained_models_diffbatt/{config.output_dir}_{battery_dataset}.pth"
    # map_location: without it, tensors saved from CUDA are restored to CUDA
    # and loading fails on a CPU-only machine.
    model.load_state_dict(torch.load(ckpt, weights_only=True, map_location=device))
    model.eval()

    # NOTE(review): timestep count is hard-coded; presumably it must match the
    # value used during training — confirm.
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    return model, initializer, generate_samples, config, scheduler


def load_trf(input_shape, device, battery_dataset):
    """Construct the TrF model on ``device`` without loading any weights.

    Args:
        input_shape: Per-sample input shape forwarded to ``TrF``.
        device: Device the model is moved to.
        battery_dataset: Not used by this function; kept so the signature
            parallels the other ``load_*`` helpers.

    Returns:
        Tuple ``(model, generate_samples, config)``. The model has freshly
        initialized weights; checkpoints must be loaded by the caller.
    """
    from nns_trf import TrF, generate_samples
    from Configs_trf import TrainingConfig

    cfg = TrainingConfig()
    trf = TrF(
        input_shape=input_shape,
        input_dim=cfg.sequence_length,
        num_channels=cfg.channels,
        num_blocks=cfg.num_blocks,
        class_dropout_prob=cfg.class_dropout_prob,
    )
    trf = trf.to(device)

    return trf, generate_samples, cfg


def main(args):
    """Generate predictions with the selected model and report RUL/SOH metrics.

    For each seed in ``range(args.seeds)``, the chosen model produces one set of
    predictions. Each set is post-processed and scored, and the mean ± std of
    RUL RMSE, RUL MAPE and SOH RMSE across seeds is logged. The last seed's
    sample and RUL curves are then plotted.

    Args:
        args: Parsed CLI namespace with ``dataset`` (name passed to
            ``get_dataset``), ``model`` (``"flowbatt"``, ``"diffbatt"`` or
            ``"trf"``) and ``seeds`` (number of seeds; must be >= 1).

    Raises:
        ValueError: If ``args.seeds`` < 1 or ``args.model`` is not recognised.
    """
    # The evaluation and plotting below read x_test, refs, preds, rul_ref and
    # rul_pred, which are only bound inside the seed loops. With zero seeds
    # they would be undefined (NameError), so fail fast before loading
    # anything expensive.
    if args.seeds < 1:
        raise ValueError(f"args.seeds must be >= 1, got {args.seeds}")

    # Device setup
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logging.info("Using device: %s", device)

    # Load dataset. Drop the leading (presumably per-sample) axis of the
    # training features and prepend a singleton axis.
    dataset = get_dataset(args.dataset)
    input_shape = (1,) + dataset.train_data.feature.cpu().numpy().shape[1:]
    mask_size = 2 if args.dataset == "mix_20_Q" else 10

    # One prediction tensor per seed.
    x_preds = []

    if args.model == "flowbatt":
        model, initializer, generate_samples, config = load_flowbatt(
            input_shape, mask_size, device, args.dataset
        )
        for seed in range(args.seeds):
            torch.manual_seed(seed)
            x_test, x_pred, _records = generate_samples(
                dataset, model, initializer, device, config, return_records=True
            )
            x_preds.append(x_pred)

    elif args.model == "diffbatt":
        model, initializer, generate_samples, config, scheduler = load_diffbatt(
            input_shape, mask_size, device, args.dataset
        )
        for seed in range(args.seeds):
            torch.manual_seed(seed)
            x_test, x_pred, _records = generate_samples(
                dataset,
                model,
                initializer,
                scheduler,
                device,
                config,
                return_records=True,
            )
            x_preds.append(x_pred)

    elif args.model == "trf":
        model, generate_samples, config = load_trf(
            input_shape, device, args.dataset
        )
        for seed in range(args.seeds):
            torch.manual_seed(seed)
            # TrF keeps a separate checkpoint per seed, so the weights are
            # reloaded each iteration. map_location lets a checkpoint saved
            # from CUDA load on a CPU-only machine.
            ckpt = (
                f"./workspaces/trained_models_TrF/{config.output_dir}_{seed}_{args.dataset}.pth"
            )
            model.load_state_dict(
                torch.load(ckpt, weights_only=True, map_location=device)
            )
            model.eval()

            x_test, x_pred = generate_samples(dataset, model, device)
            x_preds.append(x_pred)

    else:
        raise ValueError(f"Unknown model: {args.model}")

    # Post-processing and evaluation
    eol = 90 if args.dataset == "mix_20_Q" else 80
    post_process = PostProcess(eol=eol)

    rmse, mape, soh_rmse, rul_preds = [], [], [], []

    # Every prediction set is compared against x_test from the final seed's
    # iteration. This assumes generate_samples returns the same test targets
    # for every seed — TODO confirm.
    for x_pred in x_preds:
        refs, preds = post_process.post_process(x_test.cpu(), x_pred.cpu())
        soh_rmse_ = post_process.eval_soh(refs, preds)
        rul_rmse, rul_mape, rul_ref, rul_pred = post_process.eval_rul(refs, preds)

        rmse.append(rul_rmse)
        mape.append(rul_mape)
        soh_rmse.append(soh_rmse_)
        rul_preds.append(rul_pred)

    logging.info(
        "RUL RMSE %.0f±%.0f | RUL MAPE %.0f±%.0f | SOH RMSE %.2f±%.2f",
        np.mean(rmse),
        np.std(rmse),
        np.mean(mape),
        np.std(mape),
        np.mean(soh_rmse),
        np.std(soh_rmse),
    )

    # refs/preds/rul_ref/rul_pred hold the last seed's values. The RUL band is
    # the per-position std over all seeds' RUL predictions.
    post_process.plot_sample(refs, preds)
    plot_rul(rul_ref, rul_pred, std=np.array(rul_preds).std(0), name=args.dataset, save=False)
    plt.show()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Battery RUL/SoH prediction runner")
    parser.add_argument(
        "--dataset",
        type=str,
        default="matr_1_Q",
        choices=["matr_1_Q", "matr_2_Q", "hust_Q", "mix_100_Q", "mix_20_Q"],
    )
    parser.add_argument(
        "--model", type=str, default="flowbatt", choices=["flowbatt", "diffbatt", "trf"]
    )
    parser.add_argument("--seeds", type=int, default=10, help="Number of random seeds")

    args = parser.parse_args()
    # main() needs at least one seed to produce predictions and metrics.
    # Report a bad value as a CLI usage error (exit code 2) instead of a
    # traceback later on.
    if args.seeds < 1:
        parser.error(f"--seeds must be >= 1, got {args.seeds}")
    main(args)
