import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from sklearn.metrics import accuracy_score, roc_auc_score
import numpy as np


from data.pairs_dataset.dataset import KNEE_TREATMENTS_OF_INTEREST, X_RAY_GRADES, CLINICAL_INFO, DEMOGRAPHIC_INFO, TimeAgnosticPropensityDataset

# Class vocabulary: every treatment of interest, followed by a trailing "No Treatment" class.
TREATMENT_LABELS = [*KNEE_TREATMENTS_OF_INTEREST.values(), "No Treatment"]
NUM_TREATMENTS = len(TREATMENT_LABELS)  # the original author expected this to be 13
# Tabular input width: X-ray grades counted twice (one set per knee) plus clinical and demographic fields.
FEATURE_DIM = 2 * len(X_RAY_GRADES) + len(CLINICAL_INFO) + len(DEMOGRAPHIC_INFO)

def compute_distribution(dataset):
    """Count positive labels per treatment class and compute their prevalence.

    Args:
        dataset: object exposing ``__len__`` and ``samples``, an iterable of
            3-tuples whose last element is a multi-hot label vector
            (per the original author: ``(img_path, features_tensor, label_tensor)``).

    Returns:
        ``(counter, prevalence)`` where ``counter[i]`` is the number of samples
        whose label has a 1 at position ``i`` (for ``i`` in
        ``range(NUM_TREATMENTS)``) and ``prevalence[i]`` is that count divided
        by ``len(dataset)``. For an empty dataset every prevalence is 0.0.
    """
    counter = {i: 0 for i in range(NUM_TREATMENTS)}
    for _, _, label in dataset.samples:
        # label is a multi-hot vector; count the positions set to exactly 1.
        for i, bit in enumerate(label):
            if float(bit) == 1.0:
                counter[i] += 1
    total = len(dataset)
    # An empty split (e.g. a tiny CSV after the 80/20 split) must not raise
    # ZeroDivisionError; report zero prevalence instead.
    prevalence = {
        cls: (count / total if total else 0.0) for cls, count in counter.items()
    }
    return counter, prevalence

def log_distribution(log_file, split_name, dataset):
    """Append a per-class label distribution report for one split to ``log_file``.

    Writes a header line naming ``split_name`` and then, for every class index
    in ``range(NUM_TREATMENTS)``, its treatment name, positive count and
    prevalence as computed by ``compute_distribution``.
    """
    counts, prevalences = compute_distribution(dataset)
    report = [f"\nDistribution for {split_name} set:\n"]
    for idx in range(NUM_TREATMENTS):
        treat_name = TREATMENT_LABELS[idx]
        count = counts.get(idx, 0)
        prev = prevalences.get(idx, 0.0)
        report.append(f"  Class {idx} ({treat_name}): {count} samples, prevalence {prev:.4f}\n")
    with open(log_file, "a") as handle:
        handle.writelines(report)

# -----------------------------
# Helper Function to Compute Positive Weights for Loss Reweighting
# -----------------------------
def compute_pos_weight(dataset):
    """Return per-class ``pos_weight`` values (#negatives / #positives) for BCEWithLogitsLoss.

    The label vectors (third element of each tuple in ``dataset.samples``,
    either a tensor or an array-like) are summed to obtain per-class positive
    counts. A 1e-6 epsilon keeps the division finite, so a class with no
    positives receives a very large weight (``len(dataset) / 1e-6``).

    Returns:
        A float32 tensor of length NUM_TREATMENTS.
    """
    n_samples = len(dataset)
    pos_counts = np.zeros(NUM_TREATMENTS)
    for _img, _feat, raw in dataset.samples:
        pos_counts += raw.numpy() if torch.is_tensor(raw) else np.array(raw)
    neg_counts = n_samples - pos_counts
    return torch.tensor(neg_counts / (pos_counts + 1e-6), dtype=torch.float32)


# -----------------------------
# Propensity Model Definition
# -----------------------------
class PropensityModel(nn.Module):
    """Fusion network mapping (image, tabular features) to one logit per treatment.

    The tabular input is the concatenated feature vector built from both knees'
    X_RAY_GRADES, CLINICAL_INFO and DEMOGRAPHIC_INFO. The image passes through a
    pretrained EfficientFormerV2-L backbone followed by a 128-unit projection.
    The tabular vector goes through a 64-unit projection. Both projections use
    ReLU, are concatenated, and are mapped to ``num_treatments`` logits.
    """

    def __init__(self, feature_dim=FEATURE_DIM, num_treatments=NUM_TREATMENTS):
        super().__init__()
        # num_classes=0 drops timm's classification head, so the backbone
        # returns its pooled feature vector of width ``num_features``.
        self.backbone = timm.create_model(
            'efficientformerv2_l.snap_dist_in1k', pretrained=True, num_classes=0
        )
        img_width, tab_width = 128, 64
        self.fc_img = nn.Linear(self.backbone.num_features, img_width)
        self.fc_feat = nn.Linear(feature_dim, tab_width)
        self.classifier = nn.Linear(img_width + tab_width, num_treatments)

    def forward(self, image, features):
        """Return raw (pre-sigmoid) logits of shape (batch, num_treatments)."""
        image_branch = F.relu(self.fc_img(self.backbone(image)))
        tabular_branch = F.relu(self.fc_feat(features))
        return self.classifier(torch.cat([image_branch, tabular_branch], dim=1))

# -----------------------------
# Evaluation Function
# -----------------------------
def evaluate(model, dataloader, device, threshold=0.5):
    """Evaluate a multi-label model with a single inference pass over ``dataloader``.

    The model is switched to eval mode (and left in eval mode). Logits are
    passed through a sigmoid and thresholded at ``threshold`` to obtain binary
    predictions. The same probabilities are reused for the ROC AUC, so every
    batch is run through the model exactly once and predictions stay aligned
    with their labels.

    Args:
        model: module called as ``model(images, features)`` returning logits
            with one column per class.
        dataloader: yields ``(images, features, labels)`` batches; ``labels``
            are multi-hot tensors with the same shape as the logits.
        device: device the batches are moved to.
        threshold: probability cutoff for a positive prediction.

    Returns:
        ``(avg_loss, overall_acc, macro_auc, per_class_acc, per_class_recall,
        overall_recall)``:
          - avg_loss: unweighted ``BCEWithLogitsLoss`` averaged over
            ``len(dataloader.dataset)`` (no ``pos_weight``, so it is not
            directly comparable to a pos_weight-ed training loss),
          - overall_acc: fraction of correctly predicted label positions,
          - macro_auc: mean ROC AUC over classes whose labels contain both
            values (0.0 when no class qualifies),
          - per_class_acc: {class index: accuracy rounded to 3 decimals},
          - per_class_recall: {class index: recall rounded to 3 decimals, or
            None when the class has no positive labels},
          - overall_recall: micro recall over all label positions, rounded to
            3 decimals (0.0 when there are no positive labels).

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.eval()
    criterion = nn.BCEWithLogitsLoss()
    running_loss = 0.0
    prob_batches = []
    label_batches = []
    with torch.no_grad():
        for images, features, labels in dataloader:
            images = images.to(device)
            features = features.to(device)
            labels = labels.to(device)
            outputs = model(images, features)
            running_loss += criterion(outputs, labels).item() * images.size(0)
            prob_batches.append(torch.sigmoid(outputs).cpu().numpy())
            label_batches.append(labels.cpu().numpy())
    if not prob_batches:
        raise ValueError("evaluate() received a dataloader that yielded no batches")

    avg_loss = running_loss / len(dataloader.dataset)
    all_probs = np.concatenate(prob_batches, axis=0)
    all_labels = np.concatenate(label_batches, axis=0)
    all_preds = (all_probs >= threshold).astype(np.float32)
    overall_acc = (all_preds == all_labels).mean()

    per_class_acc = {}
    per_class_recall = {}
    aucs = []
    for cls in range(all_labels.shape[1]):
        cls_preds = all_preds[:, cls]
        cls_labels = all_labels[:, cls]
        per_class_acc[cls] = float(round((cls_preds == cls_labels).mean(), 3))
        # Recall = TP / (TP + FN); undefined (None) for a class with no positives.
        tp = np.sum((cls_preds == 1) & (cls_labels == 1))
        fn = np.sum((cls_preds == 0) & (cls_labels == 1))
        per_class_recall[cls] = float(round(tp / (tp + fn), 3)) if (tp + fn) > 0 else None
        # AUC is only defined when both label values are present.
        if np.unique(cls_labels).size > 1:
            try:
                aucs.append(roc_auc_score(cls_labels, all_probs[:, cls]))
            except ValueError as err:
                # e.g. NaN probabilities: skip this class instead of aborting evaluation.
                print(f"Warning: skipping AUC for class {cls}: {err}")
    macro_auc = float(np.mean(aucs)) if aucs else 0.0

    # Micro recall across every label position.
    positives = all_labels == 1
    tp = np.sum((all_preds == 1) & positives)
    fn = np.sum((all_preds == 0) & positives)
    overall_recall = float(round(tp / (tp + fn), 3)) if (tp + fn) > 0 else 0.0

    return avg_loss, overall_acc, macro_auc, per_class_acc, per_class_recall, overall_recall



# -----------------------------
# Training Function with Logging and Loss Reweighting
# -----------------------------
def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over ``loader``; return (avg loss, accuracy, micro recall).

    Accuracy and recall are computed over all label positions from predictions
    thresholded at 0.5 on the sigmoid of the logits.
    """
    # evaluate() switches the model to eval mode and never switches it back, so
    # training mode (BatchNorm statistics, dropout) must be re-enabled every epoch.
    model.train()
    total_loss = 0.0
    pred_batches = []
    label_batches = []
    for images, features, labels in tqdm(loader, desc=desc):
        images = images.to(device)
        features = features.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images, features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        pred_batches.append((torch.sigmoid(outputs) >= 0.5).float().cpu().numpy())
        label_batches.append(labels.cpu().numpy())
    avg_loss = total_loss / len(loader.dataset)
    all_preds = np.concatenate(pred_batches, axis=0)
    all_labels = np.concatenate(label_batches, axis=0)
    acc = (all_preds == all_labels).mean()
    positives = all_labels == 1
    tp = np.sum((all_preds == 1) & positives)
    fn = np.sum((all_preds == 0) & positives)
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return avg_loss, acc, recall


def train_propensity_model(csv_path, epochs=10, batch_size=64, lr=1e-4, log_file="checkpoints/Propensity_Model/train_propensity.log"):
    """Train the multi-label propensity model and save the best checkpoint.

    The CSV at ``csv_path`` is split 80/20 (``random_state=42``) into train and
    validation sets. The two splits are written to ``temp_train.csv`` /
    ``temp_val.csv`` in the current working directory for
    ``TimeAgnosticPropensityDataset``. These files are removed when the
    function finishes, including on failure. Training uses ``BCEWithLogitsLoss``
    with per-class ``pos_weight`` computed from the train split. After every
    epoch the model is evaluated on the validation split, and a copy of the
    weights with the highest validation macro AUC is saved to
    ``checkpoints/Propensity_Model/propensity_model.pth``.

    Args:
        csv_path: path of the full dataset CSV.
        epochs: number of training epochs.
        batch_size: batch size for both loaders.
        lr: Adam learning rate.
        log_file: text log path; the file is truncated at the start of the run
            and its parent directory is created if missing.

    Returns:
        The model in its state after the final epoch. This is not necessarily
        the best checkpoint that was saved to disk.
    """
    import copy  # only needed to snapshot the best weights

    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    eval_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # The default log path lives under checkpoints/, which may not exist yet;
    # opening the log for writing would otherwise raise FileNotFoundError.
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Create training and validation splits.
    df = pd.read_csv(csv_path)
    train_df = df.sample(frac=0.8, random_state=42)
    val_df = df.drop(train_df.index)
    train_csv = "temp_train.csv"
    val_csv = "temp_val.csv"
    try:
        train_df.to_csv(train_csv, index=False)
        val_df.to_csv(val_csv, index=False)

        train_dataset = TimeAgnosticPropensityDataset(train_csv, image_transform=train_transforms)
        val_dataset = TimeAgnosticPropensityDataset(val_csv, image_transform=eval_transforms)

        # Log dataset distributions.
        with open(log_file, "w") as f:
            f.write("Propensity Model Training Log\n")
        log_distribution(log_file, "Train", train_dataset)
        log_distribution(log_file, "Validation", val_dataset)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = PropensityModel(feature_dim=FEATURE_DIM, num_treatments=NUM_TREATMENTS).to(device)

        # Per-class positive weights counter the label imbalance in the loss.
        pos_weight = compute_pos_weight(train_dataset).to(device)
        print("Computed pos_weight:", pos_weight)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        best_val_auc = -float("inf")
        best_model_state = None
        for epoch in range(epochs):
            avg_loss, train_acc, train_recall = _train_one_epoch(
                model, train_loader, criterion, optimizer, device,
                desc=f"Epoch {epoch+1}/{epochs}",
            )
            print(f"Epoch {epoch+1}: Train Loss = {avg_loss:.4f}, Train Accuracy = {train_acc*100:.2f}%, Train Recall = {train_recall*100:.2f}%")

            val_loss, val_acc, val_macro_auc, per_class_acc, per_class_recall, val_recall = evaluate(model, val_loader, device)
            log_str = (
                f"Epoch {epoch+1}: Train Loss = {avg_loss:.4f}, Train Acc = {train_acc*100:.2f}%, Train Recall = {train_recall*100:.2f}%\n"
                f"Validation Loss = {val_loss:.4f}, Val Acc = {val_acc*100:.2f}%, Val Recall = {val_recall*100:.2f}%, Macro AUC = {val_macro_auc:.4f}\n"
                f"Per-Class Accuracy: {per_class_acc}\n"
                f"Per-Class Recall: {per_class_recall}\n\n"
            )
            print(log_str)
            with open(log_file, "a") as f:
                f.write(log_str)

            if val_macro_auc > best_val_auc:
                best_val_auc = val_macro_auc
                # state_dict() holds references to the live parameter tensors.
                # Without a deep copy the "best" snapshot would silently follow
                # later epochs, and the final weights would be saved instead.
                best_model_state = copy.deepcopy(model.state_dict())

        save_path = "checkpoints/Propensity_Model/propensity_model.pth"
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        if best_model_state is not None:
            torch.save(best_model_state, save_path)
            print(f"Saved best model to {save_path}")
    finally:
        # Remove the temporary split files even when training fails part-way.
        for tmp_csv in (train_csv, val_csv):
            if os.path.exists(tmp_csv):
                os.remove(tmp_csv)
    return model

if __name__ == "__main__":
    import argparse

    # Command-line entry point: every option maps 1:1 onto a train_propensity_model argument.
    cli = argparse.ArgumentParser(
        description="Pretrain a propensity model for knee treatments (Multi-Label)."
    )
    for flag, kind, fallback, help_text in (
        ("--csv", str, "data/pairs_dataset/train.csv", "CSV file path"),
        ("--epochs", int, 50, "Number of epochs"),
        ("--batch_size", int, 64, "Batch size"),
        ("--lr", float, 5e-5, "Learning rate"),
        ("--log_file", str, "checkpoints/Propensity_Model/train_propensity.log", "Log file path"),
    ):
        cli.add_argument(flag, type=kind, default=fallback, help=help_text)
    opts = cli.parse_args()

    train_propensity_model(
        opts.csv,
        epochs=opts.epochs,
        batch_size=opts.batch_size,
        lr=opts.lr,
        log_file=opts.log_file,
    )
