# %%
import time
from collections import defaultdict
import copy
from pathlib import Path
import os
from tqdm import tqdm
import random

import numpy as np
import pandas as pd

from sklearn.utils import check_random_state
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, f1_score

import torch
from torch import nn
from torch.amp import autocast


from temporal_norm.utils import get_subject_ids, get_dataloader
from temporal_norm.utils.unet import USleep
from temporal_norm.utils.transformer import CNNTransformer
from temporal_norm.utils.caresleepnet import CareSleepNet
from temporal_norm.utils import get_center_label
from temporal_norm.utils._psdnorm import welch_psd

import argparse

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"


# %%

# Command-line interface, declared as an ordered table so that every option
# and its argparse settings can be read on a single line. Options are
# registered in exactly the order listed here.
cli_options = [
    ("--dataset", {"type": str, "default": "ABC"}),
    ("--n_subjects", {"type": int, "default": 40}),
    ("--filter_size", {"type": int, "default": 1}),
    ("--auto_filter_size", {"action": "store_true"}),
    ("--norm", {"type": str, "default": "BatchNorm"}),
    ("--bias_learnable", {"action": "store_true"}),
    ("--target_learnable", {"action": "store_true"}),
    ("--batch_size", {"type": int, "default": 64}),
    ("--model_name", {"type": str, "default": "USleep"}),
    ("--balanced", {"action": "store_true"}),
    ("--balanced_acc", {"action": "store_true"}),
    ("--use_amp", {"action": "store_true"}),
    ("--num_workers", {"type": int, "default": 5}),
    ("--print_tqdm", {"action": "store_true"}),
    ("--seed", {"type": int, "default": 42}),
    ("--lr", {"type": float, "default": 1e-3}),
    ("--compile", {"action": "store_true"}),
    ("--torchinductor", {"action": "store_true"}),
    ("--results_path", {"type": str, "default": "results_LODO"}),
    ("--norm_apply_to", {"type": str, "default": "encoder"}),
    ("--detrend", {"action": "store_true"}),
    ("--deterministic", {"action": "store_true"}),
    ("--eager", {"action": "store_true"}),
    ("--whitening", {"action": "store_true"}),
    ("--filter_size_reduce", {"action": "store_true"}),
    ("--n_epochs", {"type": int, "default": 15}),
    ("--tma", {"action": "store_true", "help": "Use TMA as preprocessing"}),
]

parser = argparse.ArgumentParser()
for option_name, option_spec in cli_options:
    parser.add_argument(option_name, **option_spec)

args = parser.parse_args()

# Point the TorchInductor cache at the hard-coded cluster directory when
# requested on the command line.
if args.torchinductor:
    os.environ["TORCHINDUCTOR_CACHE_DIR"] = (
        "/lustre/fswork/projects/rech/chr/ujq48hj/.cache/"
    )

# Mirror the parsed options into module-level names used by the rest of the
# script. --auto_filter_size replaces the numeric filter size with "auto".
n_subjects, norm = args.n_subjects, args.norm
filter_size = "auto" if args.auto_filter_size else args.filter_size
bias_learnable, target_learnable = args.bias_learnable, args.target_learnable
norm_apply_to = args.norm_apply_to
whitening, filter_size_reduce = args.whitening, args.filter_size_reduce
batch_size = args.batch_size
dataset_target = args.dataset
model_name = args.model_name
balanced, balanced_acc = args.balanced, args.balanced_acc
use_amp = args.use_amp
num_workers = args.num_workers
print_tqdm = args.print_tqdm
lr = args.lr
if use_amp:
    print("BE CAREFUL! AMP is enabled.")

# %%
# All datasets known to this experiment. The one selected with --dataset is
# removed from a copy of this list later in the script to form the source set.
dataset_names = [
    "ABC",
    "CHAT",
    "CFS",
    "SHHS",
    "HOMEPAP",
    "CCSHS",
    "MASS",
    "PhysioNet",
    "SOF",
    "MROS",
]
print("loading metadata ...")
print("")
# Read only the five columns listed here from the parquet file; the path is
# relative to the current working directory.
metadata = pd.read_parquet(
    "metadata/metadata_sleep.parquet",
    columns=["dataset_name", "subject_id", "session", "y", "sample"],
)

# %%

print(f"N_subjects: {n_subjects}")
# NOTE(review): `modules` is not referenced again in the visible part of this
# script — confirm whether it can be removed.
modules = []

# Set experiment randomness
# The same seed is pushed into every random source used here: torch (CPU and
# all CUDA devices), NumPy's global state, Python's `random`, and the random
# state object returned by scikit-learn's `check_random_state` (`rng`), which
# is used further down to sample subjects.
seed = args.seed
print(f"seed: {seed}")
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
rng = check_random_state(seed)


def seed_worker(worker_id):
    """Re-seed NumPy and Python's ``random`` from torch's initial seed.

    Passed to the dataloaders as ``worker_init_fn``. ``worker_id`` is accepted
    for that callback signature but is not used: the seed is taken from
    ``torch.initial_seed()`` reduced modulo 2**32.
    """
    derived_seed = torch.initial_seed() % 2**32
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)


# One separate generator per dataloader (train / validation / target), each
# seeded with the experiment seed. `manual_seed` returns the generator itself,
# so construction and seeding are chained.
g_train, g_val, g_target = (torch.Generator().manual_seed(seed) for _ in range(3))


# Optional deterministic cuDNN behaviour: disable the benchmark mode and
# request deterministic kernels.
if args.deterministic:
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# dataloader
n_windows = 35
n_windows_stride = 21
# The two derived values below are computed from the default stride (21),
# i.e. before the --balanced override is applied.
if model_name == "DeepSleepNet":
    n_windows_stride_inference = 1
else:
    n_windows_stride_inference = n_windows_stride
n_sequences_balanced = int(len(metadata) / n_windows_stride)
# Balanced sampling switches the train/validation stride to a single window.
n_windows_stride = 1 if balanced else n_windows_stride
batch_size_inference = batch_size
pin_memory, persistent_workers = True, False

# model
in_chans, n_classes, input_size_samples = 2, 5, 3000

# Only PSDNorm keeps the requested filter size; the other supported
# normalisation layers force it to 0. Anything else is rejected.
known_norms = ("BatchNorm", "PSDNorm", "InstanceNorm", "LayerNorm")
if norm not in known_norms:
    raise ValueError(f"Unknown normalization layer: {norm}")
if norm != "PSDNorm":
    filter_size = 0
print(f"Model: {model_name}")
print(f"Normalization Layer: {norm} (filter_size: {filter_size})")

# training
n_epochs = args.n_epochs
patience = 3
# Only the central `n_windows_stride` windows of each sequence are scored; the
# margin on both sides must therefore split evenly.
# NOTE(review): with --balanced the stride is 1, so this slice is a single
# centre window, while the target dataloader still advances by
# `n_windows_stride_inference` — confirm that this is intended.
window_margin = n_windows - n_windows_stride
assert window_margin % 2 == 0, "n_windows - n_windows_stride must be even"
first_window_idx = window_margin // 2
last_window_idx = first_window_idx + n_windows_stride

# %%
# Subject ids per dataset; the held-out dataset keeps all of its subjects.
subject_ids = get_subject_ids(metadata, dataset_names)
subject_id_target = subject_ids[dataset_target]

# Every dataset except the target acts as a source.
dataset_sources = list(dataset_names)
dataset_sources.remove(dataset_target)

# %%
subject_ids_train = {}
subject_ids_val = {}
n_subject_tot = 0

print("Datasets used for training and validation:")
for dataset_name in dataset_sources:
    candidates = subject_ids[dataset_name]
    # Cap at the requested number of subjects, or fewer if the dataset is small.
    n_subjects_ = min(n_subjects, len(candidates))
    n_subject_tot += n_subjects_

    print(f"Dataset: {dataset_name}, n_subjects: {n_subjects_}")

    # Draw without replacement from the seeded random state, then hold out
    # 20% of the drawn subjects for validation.
    sampled = rng.choice(candidates, n_subjects_, replace=False)
    train_part, val_part = train_test_split(
        sampled, test_size=0.2, random_state=seed
    )
    subject_ids_train[dataset_name] = train_part
    subject_ids_val[dataset_name] = val_part

print(f"Target dataset: {dataset_target}")

# %%
# probs = get_probs(metadata, dataset_sources, alpha=0.5)


# Source train dataloader
# Built over the training subjects of every source dataset. `n_windows_stride`
# is 1 here when --balanced was given, otherwise 21. Labels are reduced with
# `get_center_label` only for the "DeepSleepNet" model name. Incomplete final
# batches are dropped.
dataloader_train = get_dataloader(
    metadata=metadata,
    dataset_names=dataset_sources,
    subject_ids=subject_ids_train,
    n_windows=n_windows,
    n_windows_stride=n_windows_stride,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_memory,
    persistent_workers=persistent_workers,
    balanced=balanced,
    n_sequences_balanced=n_sequences_balanced,
    randomize=True,
    target_transform=get_center_label if model_name == "DeepSleepNet" else None,
    drop_last=True,
    generator=g_train,
    worker_init_fn=seed_worker,
)

# Optional TMA preprocessing: one full pass over the training dataloader to
# compute a PSD barycenter (squared mean of the square-rooted PSDs over all
# training samples). Without --tma the models receive `barycenter=None`.
barycenter = None
if args.tma:
    psd_chunks = []
    progress = tqdm(
        dataloader_train, desc="Training", unit="batch", disable=not print_tqdm
    )
    for signals, _, _, _ in progress:
        # Swap axes 1 and 2, then merge everything from axis 2 onwards
        # (presumably (B, S, C, T) -> (B, C, S*T) — TODO confirm).
        signals = signals.permute(0, 2, 1, 3).flatten(start_dim=2)
        psd_chunks.append(
            welch_psd(signals, window=None, nperseg=5, detrend=False)[1]
        )
    sqrt_psd = torch.sqrt(torch.cat(psd_chunks, dim=0))
    barycenter = (sqrt_psd.mean(dim=0) ** 2).to(device)
    print("Barycenter shape:", barycenter.shape)
# Source val dataloader
# Same source datasets as training but over the validation subjects, without
# randomisation. It shares `n_windows_stride` with the training loader and
# also drops the last incomplete batch.
dataloader_val = get_dataloader(
    metadata=metadata,
    dataset_names=dataset_sources,
    subject_ids=subject_ids_val,
    n_windows=n_windows,
    n_windows_stride=n_windows_stride,
    batch_size=batch_size_inference,
    num_workers=num_workers,
    pin_memory=pin_memory,
    persistent_workers=persistent_workers,
    randomize=False,
    target_transform=get_center_label if model_name == "DeepSleepNet" else None,
    drop_last=True,
    generator=g_val,
    worker_init_fn=seed_worker,
)

# Target dataloader
# Covers every subject of the held-out dataset, uses the inference stride and
# keeps the final incomplete batch (drop_last=False).
dataloader_target = get_dataloader(
    metadata=metadata,
    dataset_names=[dataset_target],
    subject_ids={dataset_target: subject_id_target},
    n_windows=n_windows,
    n_windows_stride=n_windows_stride_inference,
    batch_size=batch_size_inference,
    num_workers=num_workers,
    pin_memory=pin_memory,
    persistent_workers=persistent_workers,
    randomize=False,
    target_transform=get_center_label if model_name == "DeepSleepNet" else None,
    drop_last=False,
    generator=g_target,
    worker_init_fn=seed_worker,
)


# Summary of the split sizes, in subjects and then in batches.
n_train_subjects = sum(len(ids) for ids in subject_ids_train.values())
n_val_subjects = sum(len(ids) for ids in subject_ids_val.values())

print()
print(f"Number of source subjects: {n_subject_tot}")
print(f"Number of training subjects: {n_train_subjects}")
print(f"Number of validation subjects: {n_val_subjects}")
print(f"Number of target subjects: {len(subject_id_target)}")
print()

print(f"Number of training batches: {len(dataloader_train)}")
print(f"Number of validation batches: {len(dataloader_val)}")
print(f"Number of target batches: {len(dataloader_target)}")
print()

# %%


def count_params(model):
    """Return the total number of elements in the trainable parameters.

    Iterates ``model.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is set.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Instantiate the network selected with --model_name. Every branch passes the
# channel count, class count and filter size; the remaining keyword arguments
# differ per architecture.
if model_name == "USleep":
    model = USleep(
        n_chans=in_chans,
        sfreq=100,
        depth=12,
        with_skip_connection=True,
        n_outputs=n_classes,
        n_times=input_size_samples,
        norm=norm,
        filter_size=filter_size,
        bias_learnable=bias_learnable,
        target_learnable=target_learnable,
        norm_apply_to=norm_apply_to,
        detrend="constant" if args.detrend else False,
        whitening=whitening,
        filter_size_reduce=filter_size_reduce,
        barycenter=barycenter,
    )

elif model_name == "CareSleepNet":
    model = CareSleepNet(
        n_chans=in_chans,
        n_outputs=n_classes,
        n_windows=n_windows,
        filter_size=filter_size,
    )

elif model_name == "CNNTransformer":
    model = CNNTransformer(
        n_channels=in_chans,
        n_classes=n_classes,
        transformer_layers=2,
        filter_size=filter_size,
        nhead=8,
        d_model=768,
        dropout=0.1,
        filter_size_reduce=filter_size_reduce,
        bias_learnable=bias_learnable,
        target_learnable=target_learnable,
        norm=norm,
        detrend="constant" if args.detrend else False,
        whitening=whitening,
    )
    print(f"CNNTransformer: CNN trainable params: {count_params(model.cnn):,}")
    print(
        "CNNTransformer: Transformer trainable params: "
        f"{count_params(model.transformer):,}"
    )

else:
    # Fail immediately with a clear message. Without this guard an unsupported
    # name (including "DeepSleepNet", which other parts of this script test
    # for but which has no constructor here) would only surface below as a
    # NameError on `model`.
    raise ValueError(f"Unknown model: {model_name}")

num_trainable_params = count_params(model)
print(f"Trainable parameters: {num_trainable_params:,}")

# Move the model to the selected device.
model.to(device)
# NOTE(review): with --use_amp the model itself is cast to bfloat16 in
# addition to the autocast contexts used during training and inference —
# confirm this combination is intended.
if use_amp:
    model = model.to(torch.bfloat16)
# Optional torch.compile, using the "eager" backend when --eager is also set.
if args.compile and args.eager:
    print("Compiling model with torch.compile (eager)")
    model = torch.compile(model, backend="eager")
elif args.compile:
    print("Compiling model with torch.compile")
    model = torch.compile(model)
# Optional class weighting of the loss (--balanced_acc): "balanced" weights
# computed by scikit-learn from the labels of the training subjects.
if balanced_acc:
    # Select the rows of the training subjects dataset by dataset, so that
    # each dataset is matched against its own subject list. Matching every
    # source dataset against the first dataset's subjects would produce
    # weights that do not describe the actual training set.
    train_rows = None
    for source_name in dataset_sources:
        rows = (metadata["dataset_name"] == source_name) & metadata[
            "subject_id"
        ].isin(subject_ids_train[source_name])
        train_rows = rows if train_rows is None else train_rows | rows
    metadata_source = metadata[train_rows]
    class_weights = compute_class_weight(
        class_weight="balanced",
        classes=np.unique(metadata_source.y),
        y=metadata_source.y,
    )
    class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
    print("Class weights:", class_weights)
else:
    class_weights = None


# Loss, optimiser and per-epoch log. `class_weights` is None unless
# --balanced_acc was given.
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
history = []

print()
print("Start training")
min_val_loss = np.inf
# Early-stopping state is initialised before the loop so that a first epoch
# whose validation loss does not compare as an improvement (e.g. NaN) cannot
# hit an unbound name.
patience_counter = 0
best_model = None
time_epochs = []
for epoch in range(n_epochs):
    time_start = time.time()
    model.train()
    train_loss = np.zeros(len(dataloader_train))
    y_pred_all, y_true_all = list(), list()

    running_loss = 0.0
    # Number of batches for averaging loss (about 20 reports per epoch).
    # Clamped to at least 1: with fewer than 20 batches the integer division
    # yields 0 and the modulo below would raise ZeroDivisionError.
    running_window = max(1, len(dataloader_train) // 20)
    for i, (batch_X, batch_y, _, _) in enumerate(
        tqdm(dataloader_train, desc="Training", unit="batch", disable=not print_tqdm)
    ):
        optimizer.zero_grad()
        batch_X = batch_X.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)

        with autocast(device_type=device, dtype=torch.bfloat16, enabled=use_amp):
            output = model(batch_X)
            loss_batch = criterion(output, batch_y)

        loss_batch.backward()
        optimizer.step()

        y_pred_all.append(output.argmax(axis=1).detach())
        y_true_all.append(batch_y.detach())
        # Read the scalar loss back from the device once per step.
        loss_value = loss_batch.item()
        train_loss[i] = loss_value

        # Update tqdm progress bar every running_window batches with average loss
        running_loss += loss_value
        if (i + 1) % running_window == 0 and print_tqdm:
            avg_loss = running_loss / running_window
            tqdm.write(f"Batch {i+1}/{len(dataloader_train)}, Avg Loss: {avg_loss:.3f}")
            running_loss = 0.0

    # Training metrics over the whole epoch, restricted to the central windows
    # of each sequence for sequence-to-sequence models.
    y_pred_all = [y.cpu().numpy() for y in y_pred_all]
    y_true_all = [y.cpu().numpy() for y in y_true_all]
    y_pred = np.concatenate(y_pred_all)
    y_true = np.concatenate(y_true_all)
    if model_name != "DeepSleepNet":
        y_pred = y_pred[:, first_window_idx:last_window_idx]
        y_true = y_true[:, first_window_idx:last_window_idx]

    perf = accuracy_score(y_true.flatten(), y_pred.flatten())
    f1 = f1_score(y_true.flatten(), y_pred.flatten(), average="weighted")

    model.eval()
    with torch.no_grad():
        val_loss = np.zeros(len(dataloader_val))
        y_pred_all, y_true_all = list(), list()
        for i, (batch_X, batch_y, _, _) in enumerate(
            tqdm(
                dataloader_val, desc="Validation", unit="batch", disable=not print_tqdm
            )
        ):
            batch_X = batch_X.to(device, non_blocking=True)
            batch_y = batch_y.to(device, non_blocking=True)

            with autocast(device_type=device, dtype=torch.bfloat16, enabled=use_amp):
                output = model(batch_X)

            loss_batch = criterion(output, batch_y)

            y_pred_all.append(output.argmax(axis=1).detach())
            y_true_all.append(batch_y.detach())
            val_loss[i] = loss_batch.item()

        y_pred_all = [y.cpu().numpy() for y in y_pred_all]
        y_true_all = [y.cpu().numpy() for y in y_true_all]

        y_pred = np.concatenate(y_pred_all)
        y_true = np.concatenate(y_true_all)
        if model_name != "DeepSleepNet":
            y_pred = y_pred[:, first_window_idx:last_window_idx]
            y_true = y_true[:, first_window_idx:last_window_idx]
        perf_val = accuracy_score(y_true.flatten(), y_pred.flatten())
        # `perf_val` and `f1_val` are single scalars, so both standard
        # deviations are always 0.0; they are kept so that the columns of the
        # saved history stay unchanged.
        std_val = np.std(perf_val)
        f1_val = f1_score(y_true.flatten(), y_pred.flatten(), average="weighted")
        std_f1_val = np.std(f1_val)
    time_end = time.time()
    time_epoch = time_end - time_start
    time_epochs.append(time_epoch)
    history.append(
        {
            "epoch": epoch,
            "train_loss": np.mean(train_loss),
            "train_acc": perf,
            "train_f1": f1,
            "val_loss": np.mean(val_loss),
            "val_acc": perf_val,
            "val_std": std_val,
            "val_f1": f1_val,
            "val_f1_std": std_f1_val,
            "time_epoch": time_epoch,
        }
    )

    print(
        "Ep:",
        epoch,
        "Loss:",
        round(np.mean(train_loss), 4),
        "Acc:",
        round(np.mean(perf), 2),
        "LossVal:",
        round(np.mean(val_loss), 4),
        "AccVal:",
        round(np.mean(perf_val), 2),
        "Time:",
        round(time_end - time_start, 2),
    )

    # do early stopping
    # A strictly lower mean validation loss resets the counter and snapshots
    # the model; training stops once more than `patience` consecutive epochs
    # have failed to improve.
    if min_val_loss > np.mean(val_loss):
        min_val_loss = np.mean(val_loss)
        patience_counter = 0
        best_model = copy.deepcopy(model)
    else:
        patience_counter += 1
        if patience_counter > patience:
            print("Early stopping")
            break

# No snapshot exists if no epoch ran or no epoch ever improved the validation
# loss; fall back to the current model so the saving and inference code that
# follows still has a model to work with.
if best_model is None:
    print("No epoch improved the validation loss; keeping the last model state.")
    best_model = copy.deepcopy(model)

# Root output directory; every artifact of this run goes into a sub-folder.
folder = Path(args.results_path)
folder.mkdir(parents=True, exist_ok=True)

# Per-epoch training history, pickled as a DataFrame.
folder_history = folder / "history"
folder_history.mkdir(parents=True, exist_ok=True)
history_filename = (
    f"history_{model_name}_{norm}_{filter_size}_{n_subjects}"
    f"_LODO_{dataset_target}_bias_{bias_learnable}_target_{target_learnable}_{seed}.pkl"
)
history_path = folder_history / history_filename
pd.DataFrame(history).to_pickle(history_path)

# Best model (saved as a whole object) and the optimizer state dict.
folder_model = folder / "models"
folder_model.mkdir(parents=True, exist_ok=True)
model_filename = (
    f"models_{model_name}_{norm}_{filter_size}_{n_subjects}"
    f"_LODO_{dataset_target}_bias_{bias_learnable}_target_{target_learnable}_{seed}.pt"
)
optimizer_filename = (
    f"optimizer_{model_name}_{norm}_{filter_size}_{n_subjects}"
    f"_LODO_{dataset_target}_bias_{bias_learnable}_target_{target_learnable}_{seed}.pt"
)
torch.save(best_model, folder_model / model_filename)
# save optimizer
torch.save(optimizer.state_dict(), folder_model / optimizer_filename)

# Destination of the per-subject target results written at the end of the run.
results = []
folder_pickle = folder / "pickles"
folder_pickle.mkdir(parents=True, exist_ok=True)
results_filename = (
    f"results_{model_name}_{norm}_{filter_size}_{n_subjects}_LODO_"
    f"{dataset_target}_bias_{bias_learnable}_target_{target_learnable}_{seed}.pkl"
)
results_path = folder_pickle / results_filename

# Accumulate predictions and targets per subject; the label and prediction
# tensors stay on `device` until they are converted to NumPy below.
results_by_subject = defaultdict(lambda: {"y_pred": [], "y_true": []})

best_model.eval()
time_start = time.time()
with torch.no_grad():
    for batch_X, batch_y, batch_sub_id, batch_session_id in tqdm(
        dataloader_target,
        desc="Inference on target",
        unit="batch",
        disable=not print_tqdm,
    ):
        batch_X = batch_X.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)
        # The subject ids are only used as dictionary keys, so they are left
        # where the dataloader produced them: copying them to the GPU and
        # calling `.item()` on each element would force a device
        # synchronisation per sample.

        with autocast(device_type=device, dtype=torch.bfloat16, enabled=use_amp):
            output = best_model(batch_X)

        preds = output.argmax(dim=1)

        # Gather predictions per subject
        for y_t, y_p, subj in zip(batch_y, preds, batch_sub_id):
            if model_name != "DeepSleepNet":
                # Keep only the central windows of each sequence.
                y_t = y_t[first_window_idx:last_window_idx]
                y_p = y_p[first_window_idx:last_window_idx]
            subject_key = int(subj.item())
            results_by_subject[subject_key]["y_true"].append(y_t)
            results_by_subject[subject_key]["y_pred"].append(y_p)
time_end = time.time()

time_inference = time_end - time_start
# One row per target subject: the run configuration plus the flattened
# predictions and labels of that subject.
results = []
for subj_id, data in results_by_subject.items():
    if model_name != "DeepSleepNet":
        y_pred_tensor = torch.cat(data["y_pred"])
        y_true_tensor = torch.cat(data["y_true"])
    else:
        # Entries are per-sample values here, so give each a leading axis
        # before concatenating.
        y_pred_tensor = torch.cat([t.unsqueeze(0) for t in data["y_pred"]])
        y_true_tensor = torch.cat([t.unsqueeze(0) for t in data["y_true"]])

    results.append(
        {
            "subject": subj_id,
            "seed": seed,
            "dataset": dataset_target,
            "dataset_type": "target",
            "norm": norm,
            "filter_size": filter_size,
            "bias_learnable": bias_learnable,
            "target_learnable": target_learnable,
            "norm_apply_to": norm_apply_to,
            "detrend": args.detrend,
            "whitening": whitening,
            "filter_size_reduce": filter_size_reduce,
            "n_subject_train": n_subject_tot,
            "n_subject_test": len(subject_id_target),
            "n_windows": n_windows,
            "n_windows_stride": n_windows_stride,
            "batch_size": batch_size,
            "batch_size_inference": batch_size_inference,
            "num_workers": num_workers,
            "n_epochs": n_epochs,
            "patience": patience,
            "n_subjects": n_subjects,
            "model_name": model_name,
            "y_pred": y_pred_tensor.cpu().numpy().flatten(),
            "y_true": y_true_tensor.cpu().numpy().flatten(),
            "time_inference": time_inference,
        }
    )

# Append to an existing results file for the same configuration if there is
# one; otherwise the new rows are the whole result. Only the read is inside
# the `try`, so a FileNotFoundError raised anywhere else is not swallowed.
df_new = pd.DataFrame(results)
try:
    df_previous = pd.read_pickle(results_path)
except FileNotFoundError:
    df_results = df_new
else:
    df_results = pd.concat((df_previous, df_new))
df_results.to_pickle(results_path)

print("Target Inference Done")

