import logging
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from utils import *
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, average_precision_score
import pandas as pd



def train_model(model, dataloader, optimizer, device, feature_weight, pos_weight):
    """Train the model for one epoch, ignoring -1 padding values in labels.

    Each batch is indexed positionally: ``batch[0]`` times, ``batch[1]``
    biomarkers, ``batch[2]`` risk factors, ``batch[3]`` labels, ``batch[4]``
    mask and ``batch[5]`` ``H``. All are moved to ``device`` and cast to float.
    Labels are passed through ``retain_first_jump_v2`` (provided by ``utils``;
    its exact semantics are not visible in this file) before the loss is
    computed.

    The loss is an element-wise BCE-with-logits, scaled by a per-label
    ``feature_weight`` and by ``pos_weight`` wherever the label equals 1,
    zeroed where the label equals -1, and normalised by the number of valid
    (non -1) label entries in the batch.

    Args:
        model: Module called as ``model(times, biomarkers, risk_factor, H, mask)``.
        dataloader: Sized iterable of batches laid out as described above.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors and weights are moved to.
        feature_weight: Tensor of per-label loss weights (one entry per label).
        pos_weight: Tensor of per-label weights applied to positive labels.

    Returns:
        A 4-tuple ``(mean_loss, mean_accuracy, accuracy_per_label, all_embeddings)``:
        the loss and overall accuracy averaged over ``len(dataloader)``, the
        per-label accuracy reported by ``calculate_accuracy`` for the *last*
        batch only, and the model outputs of every batch concatenated along
        dim 0 (detached from the autograd graph).

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    model.train()
    criterion = nn.BCEWithLogitsLoss(reduction='none')
    eps = 1e-8  # guards the normalisation when a batch has no valid labels

    # Loop-invariant: reshape the per-label weights once so they broadcast
    # over the leading (batch, sequence) dimensions of the loss.
    feature_weight = feature_weight.view(1, 1, -1).to(device)
    pos_weight = pos_weight.view(1, 1, -1).to(device)

    total_loss = 0.0
    total_acc = 0
    accuracy_per_label = None
    batch_outputs = []

    for batch in dataloader:
        times = batch[0].to(device).float()
        train_biomarkers = batch[1].to(device).float()
        risk_factor = batch[2].to(device).float()
        labels = retain_first_jump_v2(batch[3].to(device).float())
        mask = batch[4].to(device).float()
        H = batch[5].to(device).float()

        optimizer.zero_grad()
        outputs = model(times, train_biomarkers, risk_factor, H, mask)  # [batch_size, seq_len, 21]

        # Keep only the values: holding the attached tensor would keep every
        # batch's autograd graph alive for the whole epoch.
        batch_outputs.append(outputs.detach())

        valid_mask = (labels != -1).float()  # 1.0 = valid, 0.0 = padding

        loss = criterion(outputs, labels)
        class_weight = torch.where(labels == 1, pos_weight, 1.0)
        loss = loss * feature_weight * class_weight * valid_mask
        loss = loss.sum() / (valid_mask.sum() + eps)

        overall_accuracy, accuracy_per_label, _, _ = calculate_accuracy(outputs, labels, valid_mask)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_acc += overall_accuracy

    if not batch_outputs:
        raise ValueError("train_model received a dataloader that yielded no batches")

    # Single concatenation instead of re-copying the growing tensor per batch.
    all_embeddings = torch.cat(batch_outputs, dim=0)
    num_batches = len(dataloader)
    return total_loss / num_batches, total_acc / num_batches, accuracy_per_label, all_embeddings


def test_model(model, dataloader, device, feature_weight, pos_weight):
    """Evaluate the model on a dataloader and report per-label metrics.

    Batches are unpacked positionally exactly as in training (times,
    biomarkers, risk factors, labels, mask, ``H``) and labels are passed
    through ``retain_first_jump_v2`` (provided by ``utils``; its exact
    semantics are not visible in this file). Entries whose label equals -1
    are treated as padding and excluded from both the loss and the metrics.

    Predictions are ``sigmoid(outputs) >= 0.5``. For every label, precision,
    recall, F1 and accuracy are computed per batch with scikit-learn and then
    averaged over ``len(dataloader)``.

    NOTE(review): a batch with no valid entries for a label adds nothing to
    that label's sums but still counts in the divisor, and a mean of
    per-batch scores is not the same as a score over the pooled predictions.
    This mirrors the existing behaviour; confirm it is intended.

    Args:
        model: Module called as ``model(times, biomarkers, risk_factor, H, mask)``.
        dataloader: Sized iterable of batches.
        device: Device the batch tensors and weights are moved to.
        feature_weight: Tensor of per-label loss weights; its element count
            defines the number of labels reported.
        pos_weight: Tensor of per-label weights applied to positive labels.

    Returns:
        A 4-tuple ``(avg_loss, weighted_accuracy, df_results, all_embeddings)``:
        the mean batch loss, the uniformly weighted mean of per-label
        accuracies, a DataFrame with one row per label plus a
        ``"Weighted Avg"`` row, and the model outputs of every batch
        concatenated along dim 0.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    model.eval()
    num_batches = len(dataloader)
    num_labels = feature_weight.numel()
    eps = 1e-8  # avoid divide by 0 when a batch has no valid labels

    # Loop-invariant: reshape the per-label weights once for broadcasting.
    feature_weight = feature_weight.view(1, 1, -1).to(device)
    pos_weight = pos_weight.view(1, 1, -1).to(device)

    # Host-side accumulators: the metrics come back from scikit-learn as
    # Python floats, so there is no reason to scatter them into device memory.
    precision_sum = np.zeros(num_labels)
    recall_sum = np.zeros(num_labels)
    f1_sum = np.zeros(num_labels)
    accuracy_sum = np.zeros(num_labels)
    positive_count = np.zeros(num_labels)

    total_loss = 0.0
    batch_outputs = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Testing", ncols=100):
            times = batch[0].to(device).float()
            train_biomarkers = batch[1].to(device).float()
            risk_factor = batch[2].to(device).float()
            labels = retain_first_jump_v2(batch[3].to(device).float())
            mask = batch[4].to(device).float()
            H = batch[5].to(device).float()

            outputs = model(times, train_biomarkers, risk_factor, H, mask)  # [batch_size, seq_len, 21]
            batch_outputs.append(outputs)

            valid_mask = (labels != -1).float()

            loss = nn.functional.binary_cross_entropy_with_logits(outputs, labels, reduction='none')
            class_weight = torch.where(labels == 1, pos_weight, 1.0)
            loss = loss * feature_weight * class_weight * valid_mask
            loss = loss.sum() / (valid_mask.sum() + eps)
            total_loss += loss.item()

            preds_binary = (torch.sigmoid(outputs) >= 0.5).float()

            # One device-to-host copy per batch; everything below is NumPy.
            label_np = labels.cpu().numpy()
            preds_binary_np = preds_binary.cpu().numpy()
            valid_np = label_np != -1

            for i in range(num_labels):
                task_labels = label_np[:, :, i]
                keep = valid_np[:, :, i]
                filtered_labels = task_labels[keep]
                filtered_preds = preds_binary_np[:, :, i][keep]

                if len(filtered_labels) > 0:
                    precision_sum[i] += precision_score(
                        filtered_labels, filtered_preds, average='binary', zero_division=0)
                    recall_sum[i] += recall_score(
                        filtered_labels, filtered_preds, average='binary', zero_division=0)
                    f1_sum[i] += f1_score(
                        filtered_labels, filtered_preds, average='binary', zero_division=0)
                    accuracy_sum[i] += accuracy_score(filtered_labels, filtered_preds)

                # A label equal to 1 is never padding (-1), so no extra
                # validity check is needed when counting positives.
                positive_count[i] += (task_labels == 1).sum()

    if not batch_outputs:
        raise ValueError("test_model received a dataloader that yielded no batches")

    # Single concatenation instead of re-copying the growing tensor per batch.
    all_embeddings = torch.cat(batch_outputs, dim=0)

    avg_loss = total_loss / num_batches
    avg_precision_per_label = (precision_sum / num_batches).tolist()
    avg_recall_per_label = (recall_sum / num_batches).tolist()
    avg_f1_per_label = (f1_sum / num_batches).tolist()
    avg_accuracy_per_label = (accuracy_sum / num_batches).tolist()
    subject_count_per_label = positive_count.tolist()

    # Uniform weights, i.e. a plain macro average across labels (the
    # positive counts are reported but not used for weighting).
    weights = np.ones(num_labels) / num_labels
    weighted_precision = np.sum(weights * avg_precision_per_label)
    weighted_recall = np.sum(weights * avg_recall_per_label)
    weighted_f1 = np.sum(weights * avg_f1_per_label)
    weighted_accuracy = np.sum(weights * avg_accuracy_per_label)

    df_results = pd.DataFrame({
        "Task": [f"Task {i+1}" for i in range(num_labels)] + ["Weighted Avg"],
        "Accuracy": avg_accuracy_per_label + [weighted_accuracy],
        "Precision": avg_precision_per_label + [weighted_precision],
        "Recall": avg_recall_per_label + [weighted_recall],
        "F1-score": avg_f1_per_label + [weighted_f1],
        "Count": subject_count_per_label + [sum(subject_count_per_label)],
    })

    return avg_loss, weighted_accuracy, df_results, all_embeddings

