import os
import argparse
import pandas as pd
import numpy as np
import torch
import random
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score
from collections import defaultdict

def set_seed(seed):
    """Seed every RNG the script draws from with the same ``seed``.

    Covers Python's ``random`` module, NumPy's global RNG, PyTorch's CPU RNG
    and the RNGs of all CUDA devices, in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

def load_table(ehr, name, root, train_ids=None, test_ids=None, mode=None):
    """Load ``{root}/{name}.csv`` and normalise its item column to ``itemid``.

    For ``ehr == "eicu"`` the MIMIC-IV style table stems in ``name`` are first
    mapped to their eICU file names (``labevents`` -> ``lab``, ``inputevents``
    -> ``infusiondrug``, ``prescriptions`` -> ``medication``). The table's item
    column (``drug`` / ``labname`` / ``drugname``) is then renamed to
    ``itemid``. Names may carry a suffix such as ``prescriptions_syn``.

    Args:
        ehr: dataset flavour, e.g. ``"mimiciv"`` or ``"eicu"``.
        name: table stem (optionally suffixed), without ``.csv``.
        root: directory containing the CSV.
        train_ids: stay ids kept when ``mode == "train"``.
        test_ids: stay ids kept when ``mode == "test"``.
        mode: ``None`` (no filtering), ``"train"`` or ``"test"``.

    Returns:
        The loaded DataFrame, filtered on ``stay_id`` when ``mode`` is set.

    Raises:
        ValueError: if ``mode`` is unrecognised, or the id collection needed
            for the requested split was not supplied.
    """
    if ehr == "eicu":
        name = (
            name.replace("labevents", "lab")
            .replace("inputevents", "infusiondrug")
            .replace("prescriptions", "medication")
        )
    df = pd.read_csv(os.path.join(root, f"{name}.csv"))

    # (table prefix, source column) pairs, checked in order. Matching on the
    # prefix rather than exact equality / substring lets suffixed synthetic
    # tables ("prescriptions_xyz") get renamed too, without a suffix that
    # happens to contain "lab" hijacking the match.
    item_columns = {
        "mimiciv": (("prescriptions", "drug"),),
        "eicu": (
            ("lab", "labname"),
            ("infusiondrug", "drugname"),
            ("medication", "drugname"),
        ),
    }
    stem = os.path.basename(name)
    for prefix, column in item_columns.get(ehr, ()):
        if stem.startswith(prefix):
            # rename() ignores missing columns, so tables that already use
            # `itemid` pass through unchanged.
            df = df.rename(columns={column: "itemid"})
            break

    if mode is None:
        return df
    ids_by_mode = {"train": train_ids, "test": test_ids}
    if mode not in ids_by_mode:
        raise ValueError(f"mode must be None, 'train' or 'test', got {mode!r}")
    ids = ids_by_mode[mode]
    if ids is None:
        raise ValueError(f"mode={mode!r} requires {mode}_ids to be provided")
    return df[df.stay_id.isin(ids)]


class RandomTEventDataset(Dataset):
    """Next-event samples built from per-stay event sequences.

    Each sequence is a list of ``(time, item_vector)`` pairs. Only sequences
    with at least two events are kept, since a sample needs a non-empty
    history plus a target. Sample ``idx`` uses the first ``t`` event vectors
    as input and the vector of event ``t`` as target. ``t`` is taken from
    ``t_indices[idx]`` when given (and not ``None``). Otherwise it is drawn
    with ``np.random.randint(1, len(events))`` on every access.
    """

    def __init__(self, seq_data_list, t_indices=None):
        """
        Args:
            seq_data_list: iterable of event sequences.
            t_indices: optional per-sample target indices, aligned with the
                sequences that survive the ``len >= 2`` filter (``None``
                entries fall back to random draws).

        Raises:
            ValueError: if ``t_indices`` does not have one entry per kept
                sequence, or an index is outside ``[1, len(events) - 1]``.
        """
        self.data = [events for events in seq_data_list if len(events) >= 2]
        if t_indices is None:
            self.t_indices = [None] * len(self.data)
        else:
            if len(t_indices) != len(self.data):
                raise ValueError(
                    f"t_indices has {len(t_indices)} entries but {len(self.data)} "
                    "sequences have >= 2 events; t_indices must align with them"
                )
            for i, (events, t) in enumerate(zip(self.data, t_indices)):
                # t == 0 would give an empty history; t >= len has no target.
                if t is not None and not 1 <= t < len(events):
                    raise ValueError(
                        f"t_indices[{i}]={t} is outside [1, {len(events) - 1}]"
                    )
            self.t_indices = t_indices

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(history_vectors, target_vector, history_length)``."""
        events = self.data[idx]
        if self.t_indices[idx] is not None:
            t = self.t_indices[idx]
        else:
            t = np.random.randint(1, len(events))
        input_seq = np.array([vec for _, vec in events[:t]])
        target = np.array(events[t][1])
        return input_seq, target, len(input_seq)

    def collate_fn(self, batch):
        """Pad a list of samples into ``(padded, targets, lengths)`` tensors.

        Samples are sorted by history length, longest first (in place). This is
        the order ``pack_padded_sequence`` expects by default. Histories are
        right-padded with zeros, batch-first.
        """
        batch.sort(key=lambda x: x[2], reverse=True)
        sequences, targets, lengths = zip(*batch)
        sequences = [torch.tensor(seq, dtype=torch.float32) for seq in sequences]
        padded_sequences = nn.utils.rnn.pad_sequence(sequences, batch_first=True)
        targets = torch.tensor(np.stack(targets), dtype=torch.float32)
        lengths = torch.tensor(lengths, dtype=torch.long)
        return padded_sequences, targets, lengths


class EventPredictor(nn.Module):
    """Single-layer LSTM mapping an event history to per-item probabilities.

    Args:
        input_size: width of each input event vector.
        hidden_size: LSTM hidden-state width.
        output_size: number of items scored per sample.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Attribute names are part of the state_dict keys used by saved
        # checkpoints, so keep them stable.
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, lengths):
        """Return sigmoid scores of shape ``(batch, output_size)``.

        ``x`` is a right-padded, batch-first tensor and ``lengths`` holds the
        true length of each row. Rows need not be sorted by length:
        ``enforce_sorted=False`` makes PyTorch sort internally and hand back
        ``hn`` in the caller's row order. Already-sorted batches give the same
        result as before.
        """
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        # With packed input, hn[-1] is each row's state at its last real step,
        # not at a padding position.
        _, (hn, _) = self.lstm(packed)
        return self.sigmoid(self.fc(hn[-1]))
 
 
def train_model(args, train_loader, valid_loader, criterion, device, model=None):
    """Train with Adam and early stopping on validation micro-F1.

    The weights of the best validation epoch are saved to a checkpoint in the
    working directory.

    Args:
        args: parsed CLI namespace; reads ``lr``, ``ehr``, ``suffix``, ``seed``,
            ``threshold``, ``epochs`` and ``patience``.
        train_loader: iterable of ``(batch_x, batch_y, lengths)`` batches.
        valid_loader: iterable of ``(batch_x, batch_y, lengths)`` batches.
        criterion: loss applied to ``model(batch_x, lengths)`` and ``batch_y``.
        device: torch device each batch is moved to.
        model: module to train. When omitted, the module-level ``model``
            global (created by the ``__main__`` block) is used.

    Returns:
        ``(ckpt_name, best_valid_f1)``: checkpoint path and its validation F1.

    Raises:
        ValueError: if no model is passed and no module-level ``model`` exists.
    """
    if model is None:
        # Backward-compatible default: the script defines `model` globally.
        model = globals().get("model")
    if model is None:
        raise ValueError("train_model needs a model: pass model=... or define a module-level `model`.")

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    ckpt_name = f"best_model_lstm_{args.ehr}_{args.suffix}_seed{args.seed}_lr{args.lr}_th{args.threshold}.pt"

    # Start below any achievable F1 so the first epoch always writes a
    # checkpoint; otherwise a run stuck at F1 == 0 leaves nothing to load.
    best_valid_f1 = -1.0
    patience_counter = 0

    for epoch in range(args.epochs):
        print(f"\n Epoch {epoch+1}/{args.epochs}")
        model.train()
        total_loss = 0.0
        for batch_x, batch_y, lengths in train_loader:
            batch_x, batch_y, lengths = batch_x.to(device), batch_y.to(device), lengths.to(device)
            optimizer.zero_grad()
            loss = criterion(model(batch_x, lengths), batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # max(..., 1) guards against an empty training loader.
        print(f"Train Loss: {total_loss / max(len(train_loader), 1):.4f}")

        # Validation: binarise probabilities at args.threshold, score micro-F1.
        model.eval()
        preds_all, targets_all = [], []
        with torch.no_grad():
            for batch_x, batch_y, lengths in valid_loader:
                batch_x, batch_y, lengths = batch_x.to(device), batch_y.to(device), lengths.to(device)
                outputs = model(batch_x, lengths)
                preds_all.append((outputs > args.threshold).int().cpu().numpy())
                targets_all.append(batch_y.int().cpu().numpy())

        if preds_all:
            f1 = f1_score(np.vstack(targets_all), np.vstack(preds_all), average='micro', zero_division=0)
        else:
            # np.vstack([]) would raise; an empty validation set scores 0.
            f1 = 0.0
        print(f"Valid F1: {f1:.4f}")

        if f1 > best_valid_f1:
            best_valid_f1 = f1
            torch.save(model.state_dict(), ckpt_name)
            print(f"Best model saved → {ckpt_name}")
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"⏳ No improvement: {patience_counter}/{args.patience} epochs")
            if patience_counter >= args.patience:
                print("Early stopping triggered.")
                break

    return ckpt_name, best_valid_f1

def _build_test_loader(args, mlb, test_ids):
    """Build the real-data test loader, with one cached target index per stay."""
    real_lab_test = load_table(args.ehr, "labevents", root=args.real_data_root,
                               train_ids=None, test_ids=test_ids, mode='test')
    real_input_test = load_table(args.ehr, "inputevents", root=args.real_data_root,
                                 train_ids=None, test_ids=test_ids, mode='test')
    real_pre_test = load_table(args.ehr, "prescriptions", root=args.real_data_root,
                               train_ids=None, test_ids=test_ids, mode='test').rename(columns={"drug": "itemid"})

    real_test_df = pd.concat([real_lab_test, real_input_test, real_pre_test])[['stay_id', 'time', 'itemid']].dropna()
    real_test_df = real_test_df.sort_values(by=['stay_id', 'time'])
    # One row per (stay, time) holding the set of items observed at that time.
    test_grouped = real_test_df.groupby(['stay_id', 'time'])['itemid'].apply(set).reset_index()
    test_grouped['item_vector'] = list(mlb.transform(test_grouped['itemid']))

    test_seq_data = defaultdict(list)
    for sid, time, vec in zip(test_grouped['stay_id'], test_grouped['time'], test_grouped['item_vector']):
        test_seq_data[sid].append((time, vec))
    # Keep stays that have both a history and a target, each ordered by time.
    test_seq_list = [
        sorted(events, key=lambda x: x[0])
        for events in test_seq_data.values()
        if len(events) >= 2
    ]

    # Target indices are cached on disk so repeated evaluations score the
    # same (history, target) pairs.
    t_indices_file = f"{args.ehr}_test_t_indices.npy"
    if os.path.exists(t_indices_file):
        t_indices = np.load(t_indices_file, allow_pickle=True)
        print(f"{t_indices_file} loaded.")
        if len(t_indices) != len(test_seq_list):
            raise ValueError(
                f"{t_indices_file} has {len(t_indices)} entries but there are "
                f"{len(test_seq_list)} test sequences; delete it to regenerate."
            )
    else:
        t_indices = [np.random.randint(1, len(events)) for events in test_seq_list]
        np.save(t_indices_file, t_indices)
        print(f"t indices saved to {t_indices_file}.")

    test_dataset = RandomTEventDataset(test_seq_list, t_indices=t_indices)
    return DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False,
        collate_fn=test_dataset.collate_fn, num_workers=0
    )


def test_model(args, test_loader, model, device, criterion, threshold, mlb=None, test_ids=None):
    """Evaluate the best saved checkpoint on the real-data test split.

    Args:
        args: parsed CLI namespace; reads ``ehr``, ``real_data_root``,
            ``batch_size`` and the checkpoint-name fields (``ehr``,
            ``suffix``, ``seed``, ``lr``, ``threshold``).
        test_loader: loader of ``(batch_x, batch_y, lengths)``. When ``None``,
            it is built from the real test tables.
        model: module whose weights are replaced by the checkpoint.
        device: torch device for evaluation; also the ``map_location`` used
            when loading the checkpoint.
        criterion: unused; kept for call compatibility.
        threshold: probability above which an item counts as predicted.
        mlb: fitted ``MultiLabelBinarizer``. Defaults to the module-level
            ``mlb``.
        test_ids: test stay ids. Defaults to the module-level ``test_ids``.

    Returns:
        Micro-averaged F1 over all test samples.

    Raises:
        ValueError: if ``mlb``/``test_ids`` are unavailable, the cached t-index
            file does not match the test sequences, or there are no samples.
    """
    print("\n Test start")
    if test_loader is None:
        if mlb is None:
            mlb = globals().get("mlb")
        if test_ids is None:
            test_ids = globals().get("test_ids")
        if mlb is None or test_ids is None:
            raise ValueError("test_model needs `mlb` and `test_ids` (pass them or define them at module level).")
        test_loader = _build_test_loader(args, mlb, test_ids)

    # Same naming scheme train_model uses when saving the best weights.
    ckpt_name = f"best_model_lstm_{args.ehr}_{args.suffix}_seed{args.seed}_lr{args.lr}_th{args.threshold}.pt"
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    model.load_state_dict(torch.load(ckpt_name, map_location=device))
    model.eval()

    preds_all, targets_all = [], []
    with torch.no_grad():
        for batch_x, batch_y, lengths in test_loader:
            batch_x, batch_y, lengths = batch_x.to(device), batch_y.to(device), lengths.to(device)
            outputs = model(batch_x, lengths)
            preds_all.append((outputs > threshold).int().cpu().numpy())
            targets_all.append(batch_y.int().cpu().numpy())

    if not preds_all:
        raise ValueError("No test samples: no test stay has at least two events.")
    test_f1 = f1_score(np.vstack(targets_all), np.vstack(preds_all), average='micro', zero_division=0)
    print(f" micro F1-score: {test_f1:.4f}")
    return test_f1
   


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ehr", type=str, default='mimiciv')
    parser.add_argument("--obs_size", type=int, default=12)
    parser.add_argument('--real_data_root', type=str, default='data/real_data/mimiciv/')
    parser.add_argument('--syn_data_root', type=str, default='data/syn_data/mimiciv/sdv/')
    parser.add_argument("--suffix", type=str, default='')
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument('--lr', type=float, default=0.0005)
    parser.add_argument('--threshold', type=float, default=0.1)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--patience', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--test_only', action='store_true', help='Run test only without training')
    args = parser.parse_args()

    set_seed(args.seed)

    # Load Real
    splits = pd.read_csv(os.path.join(args.real_data_root, f"{args.ehr}_split.csv")).reset_index()
    # Stay ids come from an explicit `stay_id` column when the split file has
    # one. Otherwise the positional row index is used, which assumes row i
    # describes stay i. NOTE(review): confirm the split file's schema.
    if "stay_id" in splits.columns:
        split_ids = splits["stay_id"]
    else:
        split_ids = splits.index.to_series()
    train_ids = split_ids[splits["seed0"] == "train"].unique()
    test_ids = split_ids[splits["seed0"] == "test"].unique()

    # Item vocabulary is fitted on the real training stays only.
    real_lab_for_vocab = load_table(args.ehr, "labevents", root=args.real_data_root, train_ids=train_ids, test_ids=test_ids, mode='train')
    real_input_for_vocab = load_table(args.ehr, "inputevents", root=args.real_data_root, train_ids=train_ids, test_ids=test_ids, mode='train')
    real_pre_for_vocab = load_table(args.ehr, "prescriptions", root=args.real_data_root, train_ids=train_ids, test_ids=test_ids, mode='train').rename(columns={"drug": "itemid"})

    real_vocab_df = pd.concat([real_lab_for_vocab, real_input_for_vocab, real_pre_for_vocab])[['stay_id', 'time', 'itemid']].dropna()
    real_vocab_df = real_vocab_df.sort_values(by=['stay_id', 'time'])
    real_grouped = real_vocab_df.groupby(['stay_id', 'time'])['itemid'].apply(set).reset_index()

    mlb = MultiLabelBinarizer()
    mlb.fit(real_grouped['itemid'])

    # Load Syn
    train_lab = load_table(args.ehr, f"labevents{args.suffix}", root=args.syn_data_root, train_ids=train_ids, test_ids=test_ids, mode='train')
    train_input = load_table(args.ehr, f"inputevents{args.suffix}", root=args.syn_data_root, train_ids=train_ids, test_ids=test_ids, mode='train')
    train_pre = load_table(args.ehr, f"prescriptions{args.suffix}", root=args.syn_data_root, train_ids=train_ids, test_ids=test_ids, mode='train')

    train_df = pd.concat([train_lab, train_input, train_pre])[['stay_id', 'time', 'itemid']].dropna()
    train_df = train_df.sort_values(by=['stay_id', 'time'])
    grouped = train_df.groupby(['stay_id', 'time'])['itemid'].apply(set).reset_index()

    item_matrix = mlb.transform(grouped['itemid'])
    grouped['item_vector'] = list(item_matrix)

    seq_data = defaultdict(list)
    for _, row in grouped.iterrows():
        seq_data[row['stay_id']].append((row['time'], row['item_vector']))
    for sid in seq_data:
        seq_data[sid] = sorted(seq_data[sid], key=lambda x: x[0])

    all_seq_data = list(seq_data.values())
    train_size = int(len(all_seq_data) * 0.8)
    train_data, valid_data = random_split(all_seq_data, [train_size, len(all_seq_data) - train_size], generator=torch.Generator().manual_seed(args.seed))
    train_dataset = RandomTEventDataset(train_data)
    # Fix one target index per validation stay. A fresh random t every epoch
    # makes validation F1 (and so early stopping / checkpointing) noisy.
    valid_seqs = [events for events in valid_data if len(events) >= 2]
    valid_t_indices = [np.random.randint(1, len(events)) for events in valid_seqs]
    valid_dataset = RandomTEventDataset(valid_seqs, t_indices=valid_t_indices)

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        collate_fn=train_dataset.collate_fn, num_workers=0
    )

    valid_loader = DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False,
        collate_fn=valid_dataset.collate_fn, num_workers=0
    )

    model = EventPredictor(len(mlb.classes_), 128, len(mlb.classes_))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Inverse-log-frequency class weights, normalised to mean 1.
    item_counts = np.sum(item_matrix, axis=0)
    weights = 1.0 / np.log(item_counts + 1.1 + 1e-5)
    weights /= np.mean(weights)
    class_weights = torch.tensor(weights, dtype=torch.float32, device=device)
    criterion = nn.BCELoss(weight=class_weights)

    if not args.test_only:
        train_model(args, train_loader, valid_loader, criterion, device)
    else:
        # No prebuilt loader: test_model builds one from the real test split.
        test_loader = None
        test_model(args, test_loader, model, device, criterion, args.threshold)
    