import pickle
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split, Dataset
from sklearn.metrics import accuracy_score
from TimesNet import Model as Model_Timesnet
from PatchTST import Model as Model_PatchTST
import time
import matplotlib.pyplot as plt
import pandas as pd
import os
import gc
from collections import Counter



class Configs_Timesnet:
    """Default hyperparameters handed to the TimesNet ``Model`` constructor.

    Every value is stored as a plain instance attribute so callers can
    override individual fields after construction (the script overrides
    ``num_class`` and ``enc_in`` from the loaded data; ``epochs`` is read
    by the training loop).  The remaining fields are consumed by the
    external TimesNet implementation; their exact meaning is defined
    there -- verify against that model before changing them.
    """

    def __init__(self):
        defaults = {
            'task_name': 'classification',
            'seq_len': 60,       # number of time steps per sample
            'label_len': 0,
            'pred_len': 0,
            'enc_in': 14,        # number of input features per time step
            'num_class': 128,    # number of output classes
            'd_model': 32,
            'd_ff': 64,
            'num_kernels': 6,
            'top_k': 3,
            'e_layers': 2,
            'embed': 'fixed',
            'freq': 'h',
            'dropout': 0.1,
            'epochs': 1,         # training epochs
        }
        # Materialise each default as an instance attribute, in order.
        for field_name, field_value in defaults.items():
            setattr(self, field_name, field_value)

class Configs_PatchTST:
    """Default hyperparameters handed to the PatchTST ``Model`` constructor.

    Values are plain instance attributes so they can be overridden after
    construction (the script overrides ``seq_len``, ``num_class`` and
    ``enc_in`` from the loaded data, and shrinks ``d_model`` / enlarges
    ``stride`` for the 2000-bus case; ``epochs`` is read by the training
    loop).  The other fields are interpreted by the external PatchTST
    implementation -- verify against that model before changing them.
    """

    def __init__(self):
        defaults = {
            'task_name': 'classification',
            'seq_len': 243,      # number of time steps per sample
            'pred_len': 1,
            'd_model': 64,
            'enc_in': 14,        # number of input features per time step
            'n_heads': 4,
            'd_ff': 128,
            'e_layers': 3,
            'dropout': 0.1,
            'activation': 'gelu',
            'factor': 5,
            'num_class': 5,      # number of output classes
            'epochs': 1,         # training epochs
            'patch_len': 16,
            'stride': 8,
        }
        # Materialise each default as an instance attribute, in order.
        for field_name, field_value in defaults.items():
            setattr(self, field_name, field_value)

class ClassificationDataset(Dataset):
    """Map-style dataset pairing NumPy samples with integer class labels.

    Samples stay in NumPy form and are converted to tensors lazily, one
    item at a time, so no second full copy of ``X`` is created up front.

    Args:
        X: NumPy array whose first axis indexes samples.
        Y: NumPy array of labels aligned with ``X``; viewed/cast as int64.
        dtype: torch dtype each sample is converted to on access.
    """

    def __init__(self, X, Y, dtype=torch.float32):
        assert isinstance(X, np.ndarray), "X should be NumPy"
        self.X = X
        # copy=False avoids duplicating labels that are already int64.
        self.Y = Y.astype(np.int64, copy=False)
        self.dtype = dtype

    def __len__(self):
        """Return the number of samples (size of the first axis of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return ``(sample, label)`` tensors for position ``idx``."""
        sample = torch.from_numpy(self.X[idx])
        sample = sample.to(self.dtype)
        target = torch.tensor(self.Y[idx], dtype=torch.long)
        return sample, target


# ---------------- Load Data ----------------
def load_and_prepare_data(ts_file, label_file, label_target='type'):
    """Load time-series samples and build integer class labels for them.

    Args:
        ts_file: Path to the samples, either a pickle (``.pckl``) or a NumPy
            ``.npy`` file, holding an array laid out as ``(N, feature_dim, T)``.
        label_file: Path to a pickle containing one dict per sample.  Each
            dict provides ``'type'`` (for ``label_target='type'``) or
            ``'bus1'`` and ``'bus2'`` (for ``label_target='location'``).
            Values exposing ``.iloc`` (e.g. pandas Series) are reduced to
            their first element.
        label_target: ``'type'`` or ``'location'``.

    Returns:
        Tuple ``(ts_data, labels, num_classes, feature_dim)`` where
        ``ts_data`` is transposed to ``(N, T, feature_dim)``, ``labels`` is
        an integer array of class indices (classes are numbered in sorted
        order of their names) and ``feature_dim`` is ``ts_data.shape[2]``.

    Raises:
        ValueError: If ``label_target`` is unknown or ``ts_file`` has an
            unsupported extension.
    """
    # Fail fast: reject a bad target before reading potentially large files.
    if label_target not in ('type', 'location'):
        raise ValueError("Unknown label target")

    # NOTE(security): pickle.load can execute arbitrary code; only use files
    # obtained from a trusted source.
    ext = os.path.splitext(ts_file)[1].lower()
    if ext == '.pckl':
        with open(ts_file, 'rb') as f:
            ts_data = pickle.load(f)  # (N, feature_dim, T)
    elif ext == '.npy':
        ts_data = np.load(ts_file)   # (N, feature_dim, T)
    else:
        raise ValueError(f"Unsupported ts_file extension: {ext}. Expected .pckl or .npy")

    with open(label_file, 'rb') as f:
        label_dicts_raw = pickle.load(f)

    # Unwrap pandas-like values so every label field is a plain scalar.
    label_dicts = [
        {k: v.iloc[0] if hasattr(v, 'iloc') else v for k, v in entry.items()}
        for entry in label_dicts_raw
    ]

    ts_data = np.transpose(ts_data, (0, 2, 1))  # (N, T, feature_dim)

    if label_target == 'type':
        sample_keys = [d['type'] for d in label_dicts]
    else:
        # NOTE(review): "a_b" and "b_a" are kept as distinct classes, exactly
        # as before (the old reverse-key lookup could never trigger because
        # every forward key is present in the map).  Confirm whether bus
        # pairs are meant to be treated as unordered.
        sample_keys = [f"{d['bus1']}_{d['bus2']}" for d in label_dicts]

    # Class index = position of the class name in sorted order.
    class_names = sorted(set(sample_keys))
    class_index = {name: idx for idx, name in enumerate(class_names)}
    labels = np.array([class_index[key] for key in sample_keys])

    return ts_data, labels, len(class_names), ts_data.shape[2]


def split_classification_data(X, y, batch_size=32, train_frac=0.75, val_frac=0.15, seed=None):
    """Randomly split samples into train/val/test sets and wrap them in loaders.

    Args:
        X: NumPy array of samples (first axis indexes samples).
        y: NumPy array of integer labels aligned with ``X``.
        batch_size: Batch size used by all three loaders.
        train_frac: Fraction of samples assigned to the training set.
        val_frac: Fraction of samples assigned to the validation set; the
            test set receives whatever remains.
        seed: Optional integer seed for the random split.  ``None`` keeps
            the previous behaviour (split drawn from torch's global RNG).
            The seed only fixes the split, not the training shuffle order.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training
        loader shuffles.

    Raises:
        ValueError: If the fractions are negative or sum to more than 1.
    """
    if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1:
        raise ValueError(
            "train_frac and val_frac must be non-negative and sum to at most 1"
        )

    dataset = ClassificationDataset(X, y)
    total_len = len(dataset)
    train_len = int(total_len * train_frac)
    val_len = int(total_len * val_frac)
    # The remainder absorbs rounding so the three lengths always sum to total_len.
    test_len = total_len - train_len - val_len

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_data, val_data, test_data = random_split(
        dataset, [train_len, val_len, test_len], **split_kwargs
    )

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader


def train_classification(model, loaders, cfg, label_target, device=None, eval_test_every=1):
    """Train ``model`` as a classifier and record per-epoch metrics.

    The model is optimised with Adam (lr=1e-3) on cross-entropy loss.  After
    every epoch the validation set is evaluated; the test set is evaluated
    every ``eval_test_every`` epochs and, if the last epoch was not covered
    by that schedule, once more after training so a final test score is
    always reported.

    Args:
        model: torch module invoked as ``model(x, x_mark_enc, None, None)``
            whose output is fed to ``CrossEntropyLoss`` (class logits).
        loaders: ``(train_loader, val_loader, test_loader)`` iterables that
            yield ``(inputs, integer_labels)`` batches.
        cfg: Object exposing an integer ``epochs`` attribute.
        label_target: Name of the label being predicted; used in log output.
        device: Device to train on; defaults to CUDA when available, else CPU.
        eval_test_every: Positive int interval (in epochs) for test
            evaluation.  Any other value disables periodic evaluation and
            leaves only the final one.

    Returns:
        Dict with lists ``train_loss``, ``val_loss``, ``val_acc`` (one entry
        per epoch) and ``test_loss``, ``test_acc`` (length ``cfg.epochs``,
        ``NaN`` for epochs in which the test set was not evaluated).
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()
    train_loader, val_loader, test_loader = loaders

    history = {
        "train_loss": [],
        "val_loss": [],
        "val_acc": [],
        "test_loss": [np.nan] * cfg.epochs,
        "test_acc":  [np.nan] * cfg.epochs,
    }

    periodic_test = isinstance(eval_test_every, int) and eval_test_every > 0

    def _forward(xb):
        # All-ones (batch, seq) tensor passed as the model's second argument;
        # presumably it marks every time step as valid -- confirm against
        # the model implementation.
        x_mark_enc = torch.ones((xb.shape[0], xb.shape[1]), dtype=torch.float32, device=device)
        return model(xb, x_mark_enc, None, None)

    def _evaluate(loader):
        """Return (sample-weighted mean loss, accuracy) of the model on ``loader``."""
        model.eval()
        total_loss, n = 0.0, 0
        preds, trues = [], []
        with torch.no_grad():
            for xb, yb in loader:
                xb, yb = xb.to(device), yb.to(device)
                out = _forward(xb)
                loss = loss_fn(out, yb)
                bs = yb.size(0)
                # Weight by batch size so a short last batch does not skew the mean.
                total_loss += loss.item() * bs
                n += bs
                preds.append(out.argmax(dim=-1).cpu().numpy())
                trues.append(yb.cpu().numpy())
        if n == 0:
            # Empty loader: nothing to score.
            return float("nan"), 0.0
        preds = np.concatenate(preds)
        trues = np.concatenate(trues)
        return total_loss / n, accuracy_score(trues, preds)

    total_start_time = time.time()
    tested_last_epoch = False

    for epoch in range(1, cfg.epochs + 1):
        # Train (re-enable train mode; _evaluate switches the model to eval).
        model.train()
        running_loss, n_train = 0.0, 0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)

            out = _forward(xb)
            loss = loss_fn(out, yb)
            opt.zero_grad()
            loss.backward()
            opt.step()

            bs = yb.size(0)
            running_loss += loss.item() * bs
            n_train += bs

        train_loss = running_loss / max(1, n_train)
        history["train_loss"].append(train_loss)

        # Validation
        val_loss, val_acc = _evaluate(val_loader)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        msg = (f"[Epoch {epoch}/{cfg.epochs}] "
               f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc ({label_target}): {val_acc:.4f}")

        # Test
        tested_last_epoch = periodic_test and epoch % eval_test_every == 0
        if tested_last_epoch:
            test_loss, test_acc = _evaluate(test_loader)
            history["test_loss"][epoch - 1] = test_loss
            history["test_acc"][epoch - 1]  = test_acc
            msg += f" | Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}"

        print(msg)

    total_time = time.time() - total_start_time
    print(f"[{label_target.upper()}] Total Training Time: {total_time:.2f} seconds")

    # Always finish with a test score: covers both disabled periodic
    # evaluation and intervals that do not land on the final epoch.
    if not tested_last_epoch:
        test_loss, test_acc = _evaluate(test_loader)
        if history["test_loss"]:  # no slot to fill when cfg.epochs == 0
            history["test_loss"][-1] = test_loss
            history["test_acc"][-1]  = test_acc
        print(f"[Final Test] Loss: {test_loss:.4f} | Acc: {test_acc:.4f}")

    return history




if __name__ == '__main__':
    # All data is stored at Harvard Dataverse. Due to the size limit imposed
    # by Harvard Dataverse, the transient_data and transient_label files are
    # split into smaller sub-files.
    #
    # Before running this code, please download the data files and run
    # Datafile_merge.py to merge the sub-files into a single file for each
    # topology.

    data_path = r"/IEEE14_transient_data.pckl"
    label_path = r"/IEEE14_transient_label.pckl"
    n_bus = 14

    current_model = 'Timesnet'  # 'PatchTST' or 'Timesnet'

    # One training run per label target; each writes its own history CSV.
    experiments = (
        ('type', "/{model}_class_{target}_{n_bus}.csv"),
        ('location', "/{model}_{target}_{n_bus}.csv"),
    )

    for label_target, out_template in experiments:
        out_path = out_template.format(model=current_model, target=label_target, n_bus=n_bus)

        # load_and_prepare_data returns exactly four values.
        X, y, num_classes, input_dim = load_and_prepare_data(data_path, label_path, label_target)
        print(f"X shape: {X.shape}, y shape: {len(y)}, num_classes: {num_classes}, input_dim: {input_dim}")

        if current_model == 'PatchTST':
            cfg = Configs_PatchTST()
            if n_bus == 2000:
                # Smaller model / larger stride for the largest topology.
                cfg.d_model = 16
                cfg.stride = 32
                cfg.patch_len = 16
            model_cls = Model_PatchTST
        else:
            cfg = Configs_Timesnet()
            model_cls = Model_Timesnet

        # Size the model from the data actually loaded, for both models and
        # both targets, so the configured sequence length always matches X.
        cfg.seq_len = X.shape[1]
        cfg.num_class = num_classes
        cfg.enc_in = input_dim
        model = model_cls(cfg)

        loaders = split_classification_data(X, y)
        history = train_classification(model, loaders, cfg, label_target, eval_test_every=10)

        epochs = np.arange(1, cfg.epochs + 1)
        df_hist = pd.DataFrame({
            "epoch": epochs,
            "train_loss": history["train_loss"],
            "val_loss": history["val_loss"],
            "val_acc": history["val_acc"],
        })
        df_hist.to_csv(out_path, index=False)
        print(f"Saved training history to {out_path}")
