import os
import torch
import random
import argparse
import numpy as np

from model import InvarGC
from train import train_adam
from regularizer import *
from sklearn.metrics import roc_auc_score, average_precision_score

def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds, in order: Python's ``random`` module, NumPy's global RNG, PyTorch's
    CPU generator and the generators of all CUDA devices. It then puts cuDNN
    in deterministic mode with benchmarking disabled, trading some speed for
    run-to-run reproducibility.

    Args:
        seed: integer seed applied to every generator (default 42).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def _load_synthetic(data_path, label_path):
    """Load a synthetic dataset stored as a 3-D array plus a label matrix.

    A singleton axis is inserted at position 1, so the returned tensor has
    4 axes. Both shapes are printed as a quick sanity check.
    """
    arr = np.load(data_path)
    X = torch.from_numpy(arr[:, np.newaxis, :, :]).float()
    label = np.load(label_path)
    print("X shape:", X.shape)
    print("label shape:", label.shape)
    return X, label


def _load_conftep(path):
    """Stack every ``x_*.npy`` file found in ``path`` into one float tensor.

    Each file's array gets a leading singleton axis, and the files are then
    stacked along a new axis 0, giving ``[num_files, 1, *file_shape]``.
    If ``gt.npy`` is present in the same directory it is returned as the
    label; otherwise the label is ``None``.

    Raises:
        FileNotFoundError: if ``path`` does not exist (from ``os.listdir``)
            or contains no ``x_*.npy`` files.
    """
    # NOTE(review): sorting is lexicographic, so "x_10.npy" precedes
    # "x_2.npy"; confirm that file order along axis 0 does not matter.
    files = sorted(
        f for f in os.listdir(path) if f.startswith("x_") and f.endswith(".npy")
    )
    if not files:
        # torch.stack([]) would otherwise fail with an opaque RuntimeError.
        raise FileNotFoundError(f"No x_*.npy data files found in: {path}")
    X = torch.stack(
        [
            torch.from_numpy(np.load(os.path.join(path, f))).float().unsqueeze(0)
            for f in files
        ],
        dim=0,
    )
    label_path = os.path.join(path, "gt.npy")
    label = np.load(label_path) if os.path.exists(label_path) else None
    return X, label


def _load_causalrivers(data_path, label_path):
    """Load a pre-shaped CausalRivers data array and its label matrix.

    The data array is converted to a float32 tensor as-is; no axes are added.

    Raises:
        FileNotFoundError: if either file is missing (data file checked first).
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")
    if not os.path.exists(label_path):
        raise FileNotFoundError(f"Label file not found: {label_path}")
    X = torch.from_numpy(np.load(data_path)).float()
    label = np.load(label_path)
    return X, label


def load_dataset(name: str):
    """
    Load a dataset by (case-insensitive) name.

    Returns:
        X: float32 torch.Tensor with 4 axes. Callers such as ``run`` unpack
           it as [num_env, B, T, d]. The synthetic and ConfTEP loaders insert
           the singleton axis at position 1 themselves; the CausalRivers
           arrays are used exactly as stored on disk. (NOTE(review): older
           inline comments described those arrays as [n, 1, d, t]; confirm
           the on-disk axis order.)
        label: np.ndarray ground-truth matrix, or None when a ConfTEP
           directory has no ``gt.npy``.

    Raises:
        ValueError: for an unknown dataset name.
        FileNotFoundError: when the dataset's files/directory are missing.
    """
    name = name.lower()
    # Lazily-evaluated loaders so only the requested dataset touches disk.
    loaders = {
        "linear_synthetic_1": lambda: _load_synthetic(
            "synthetic/linear_data/linear_envs3_conf1_itv1.npy",
            "synthetic/ob_gc_1.npy",
        ),
        "nonlinear_synthetic_1": lambda: _load_synthetic(
            "synthetic/non_linear_data/non_linear_envs3_conf1_itv1_leaky.npy",
            "synthetic/ob_gc_1.npy",
        ),
        "conftep_woi": lambda: _load_conftep("conftep/conftep_woi"),
        "conftep_wi": lambda: _load_conftep("conftep/conftep_wi"),
        "confounder": lambda: _load_causalrivers(
            os.path.join("causalrivers/confounder", "conf_5_data.npy"),
            os.path.join("causalrivers/confounder", "conf_5_label.npy"),
        ),
        "flood": lambda: _load_causalrivers(
            os.path.join("causalrivers/flood", "flood_data.npy"),
            os.path.join("causalrivers/flood", "flood_label.npy"),
        ),
        "no_rain_flood": lambda: _load_causalrivers(
            os.path.join("causalrivers/flood", "flood_mixed.npy"),
            os.path.join("causalrivers/flood", "flood_label.npy"),
        ),
        "random_flood": lambda: _load_causalrivers(
            os.path.join("causalrivers/random_flood", "random_flood_data.npy"),
            os.path.join("causalrivers/random_flood", "random_flood_label.npy"),
        ),
    }
    loader = loaders.get(name)
    if loader is None:
        raise ValueError(f"Unknown dataset: {name}")
    return loader()


def run(args):
    """Train a Granger-causality model on one dataset and report its graph.

    Steps: seed all RNGs, load ``args.dataset``, build the model selected by
    ``args.method``, train it with ``train_adam``, restore the best weights,
    print the estimated graph and, when a label is available, AUROC/AUPRC.

    Args:
        args: parsed CLI namespace. Fields read here: seed, dataset, method,
            num_series, num_confound, hidden, lr, max_iter, lam_h, lam_c,
            lam_v, lookback, check_every, device.

    Raises:
        NotImplementedError: if ``args.method`` is not ``"invargc"``.
        FileNotFoundError: if training returned no best weights and no
            checkpoint exists at ``saved_model/best_model.pt``.
        ValueError: if the label and the estimated graph differ in shape.
    """
    set_seed(args.seed)

    X, label = load_dataset(args.dataset)  # X: [n_env, 1, T, d]; label: [d, d]
    num_env, _batch, T, _num_vars = X.shape

    if args.method == "invargc":
        model = InvarGC(
            num_series=args.num_series,
            num_confound=args.num_confound,
            num_env=num_env,
            timestep=T - 1,
            hidden=args.hidden,
        )
    else:
        raise NotImplementedError(f"Method {args.method} not implemented.")

    best_model = train_adam(
        model,
        X,
        lr=args.lr,
        max_iter=args.max_iter,
        lam_h=args.lam_h,
        lam_c=args.lam_c,
        lam_v=args.lam_v,
        lookback=args.lookback,
        check_every=args.check_every,
        verbose=True,
        label=label,
        device=args.device,
    )

    ckpt_dir = "saved_model"
    ckpt_path = os.path.join(ckpt_dir, "best_model.pt")
    if best_model is not None:
        # torch.save does not create missing parent directories.
        os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(best_model, ckpt_path)
        # best_model is consumed by load_state_dict, i.e. it is a state dict;
        # use it directly instead of re-reading it from disk.
        state_dict = best_model
    else:
        # Falling back to the on-disk checkpoint: it may have been written by
        # an earlier run (possibly a different dataset/config), so say so.
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(
                f"train_adam returned no best model and no checkpoint exists at {ckpt_path}"
            )
        print(
            f"Warning: train_adam returned no best model; loading existing "
            f"checkpoint {ckpt_path}, which may come from an earlier run."
        )
        state_dict = torch.load(ckpt_path)
    model.load_state_dict(state_dict)
    model.eval()

    # Inference
    with torch.no_grad():
        G = model.est_gc(threshold=None).detach().cpu().numpy()
    print("Estimated Granger Graph:")
    print(G)

    if label is None:
        return

    label_np = (
        label.detach().cpu().numpy() if isinstance(label, torch.Tensor) else label
    )
    if label_np.shape != G.shape:
        raise ValueError(f"Shape mismatch: label {label_np.shape}, pred {G.shape}")

    # Scores are computed over all flattened matrix entries.
    auc_roc = roc_auc_score(label_np.flatten(), G.flatten())
    au_prc = average_precision_score(label_np.flatten(), G.flatten())
    print(f"Final AUROC: {auc_roc:.4f}, AUPRC: {au_prc:.4f}")