import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import recall_score, confusion_matrix
from torch.utils.tensorboard import SummaryWriter

def train_epoch(model, dataloader, optimizer, criterion, device, writer: SummaryWriter, epoch: int,
                max_grad_norm: float = 5.0):
    """
    Run a single training epoch with gradient clipping and TensorBoard logging.

    Each batch ``(images, labels, support_set)`` is moved to ``device``. Models
    exposing a ``support_set_size`` attribute are called as
    ``model(images, None, labels, support_set)``; all others as ``model(images)``.
    Gradients are clipped to a total L2 norm of ``max_grad_norm`` before every
    optimizer step.

    Args:
        model: Module to train; switched to ``train()`` mode.
        dataloader: Iterable yielding ``(images, labels, support_set)`` batches
            and supporting ``len()``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the batch tensors are moved to.
        writer: TensorBoard writer receiving epoch scalars and histograms.
        epoch: Step value used for every TensorBoard entry.
        max_grad_norm: Clipping threshold for the total gradient norm
            (default 5.0, the previously hard-coded value).

    Returns:
        Mean per-sample training loss over the samples actually processed this
        epoch (``0.0`` if the dataloader yielded no batches).
    """
    model.train()
    uses_support_set = hasattr(model, "support_set_size")  # ClassificationWrapper or Ensemble

    running_loss = 0.0
    total_grad_norm = 0.0
    num_samples = 0
    num_batches = 0
    avg_grad_norm = 0.0  # defined even if the loader is empty

    pbar = tqdm(dataloader, total=len(dataloader), desc="Training")
    for images, labels, support_set in pbar:
        images = images.to(device)
        labels = labels.long().to(device)
        support_set = support_set.to(device)

        optimizer.zero_grad()
        if uses_support_set:
            logits = model(images, None, labels, support_set)
        else:
            logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()

        # clip_grad_norm_ returns the total norm *before* clipping. Logging that
        # value (rather than the post-clip norm, which saturates at
        # max_grad_norm) keeps gradient spikes visible, and it avoids a
        # per-parameter .item() device sync.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm).item()

        optimizer.step()

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        num_samples += batch_size
        total_grad_norm += grad_norm
        num_batches += 1

        # Divide by samples actually seen: the last batch may be smaller, and
        # dataloader.batch_size is None when a batch_sampler is used.
        avg_loss = running_loss / num_samples
        avg_grad_norm = total_grad_norm / num_batches
        pbar.set_description(f"Train Loss: {avg_loss:.4f}, Grad Norm: {avg_grad_norm:.4f}")

    # Use the processed-sample count rather than len(dataset): they differ
    # with drop_last, subset samplers, or iterable datasets.
    epoch_loss = running_loss / num_samples if num_samples else 0.0
    writer.add_scalar("Loss/Train", epoch_loss, epoch)
    writer.add_scalar("GradNorm/Train", avg_grad_norm, epoch)

    for name, param in model.named_parameters():
        writer.add_histogram(name, param, epoch)
        if param.grad is not None:
            writer.add_histogram(f"{name}_grad", param.grad, epoch)

    return epoch_loss

def save_support_set(support_set_tensor, num_classes, support_size, out_path):
    """
    Save the support set as an image grid (one row per class, one column per shot).

    Each image is un-normalized with the ImageNet mean/std used below, clamped
    to [0, 1], and drawn without axes. The figure is saved at 300 dpi and
    always closed, even if drawing or saving fails.

    Args:
        support_set_tensor: Tensor indexed as ``support_set_tensor[i, j]`` for
            class ``i`` and shot ``j``; each element is presumably a normalized
            ``(3, H, W)`` image -- TODO confirm against the caller.
        num_classes: Number of grid rows.
        support_size: Number of grid columns.
        out_path: Destination image path; its parent directory is created if
            it does not exist.
    """
    # Loop-invariant un-normalization constants (ImageNet mean/std).
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # squeeze=False always yields a 2-D array of Axes, even for a single row
    # or column, so axs[i, j] indexing works without manual reshaping.
    fig, axs = plt.subplots(
        num_classes,
        support_size,
        figsize=(support_size * 2, num_classes * 2),
        squeeze=False,
    )
    try:
        for i in range(num_classes):
            for j in range(support_size):
                img_tensor = support_set_tensor[i, j].detach().cpu() * std + mean
                # to_pil_image multiplies floats by 255 and casts to uint8;
                # values outside [0, 1] would wrap around, so clamp first.
                img = TF.to_pil_image(img_tensor.clamp(0.0, 1.0))
                ax = axs[i, j]
                ax.imshow(img)
                ax.axis("off")

        plt.tight_layout()
        out_dir = os.path.dirname(out_path)
        if out_dir:  # dirname is "" for a bare filename and makedirs("") raises
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path, dpi=300)
        print(f"Saved support grid to {out_path}")
    finally:
        plt.close(fig)

def validate_epoch(model, dataloader, criterion, device, writer, epoch, class_names, plot_embeddings_flag=True):
    """
    Run a single validation epoch and log metrics to TensorBoard.

    Computes the mean loss, top-1 accuracy, per-class recall and macro/weighted
    recall over the whole dataloader with autograd disabled. Models exposing a
    ``support_set_size`` attribute are called as
    ``model(images, None, labels, support_set)``; all others as ``model(images)``.

    Args:
        model: Module to evaluate; switched to ``eval()`` mode.
        dataloader: Iterable yielding ``(images, labels, support_set)`` batches
            and supporting ``len()``.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the batch tensors are moved to.
        writer: TensorBoard writer receiving the epoch scalars.
        epoch: Step value used for every TensorBoard entry.
        class_names: Currently unused; kept for interface compatibility.
        plot_embeddings_flag: Currently unused; kept for interface compatibility.

    Returns:
        ``(epoch_loss, accuracy)``: mean per-sample loss and top-1 accuracy over
        the samples processed.
    """
    model.eval()
    uses_support_set = hasattr(model, "support_set_size")

    running_loss = 0.0
    correct = 0
    total = 0
    pred_batches = []
    label_batches = []

    # Validation never back-propagates; building an autograd graph here would
    # only waste memory and time.
    with torch.no_grad():
        pbar = tqdm(dataloader, total=len(dataloader), desc="Validation")
        for images, labels, support_set in pbar:
            images = images.to(device)
            labels = labels.long().to(device)
            support_set = support_set.to(device)

            if uses_support_set:
                logits = model(images, None, labels, support_set)
            else:
                logits = model(images)

            loss = criterion(logits, labels)
            running_loss += loss.item() * images.size(0)

            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += images.size(0)

            pred_batches.append(preds.cpu().numpy())
            label_batches.append(labels.cpu().numpy())

            avg_loss = running_loss / total
            pbar.set_description(f"Val Loss: {avg_loss:.4f}")

    # Normalize by samples actually seen (consistent with accuracy and the
    # progress bar) instead of len(dataset), which differs with drop_last or
    # subset samplers.
    epoch_loss = running_loss / total
    accuracy = correct / total

    y_true = np.concatenate(label_batches)
    y_pred = np.concatenate(pred_batches)

    # recall_score(average=None) returns one value per label in the sorted
    # union of y_true and y_pred. Pass those labels explicitly and tag each
    # scalar with its real class id; an enumerate index mislabels classes as
    # soon as one class never appears in this epoch.
    present_labels = np.union1d(y_true, y_pred)
    per_class_recall = recall_score(y_true, y_pred, labels=present_labels, average=None, zero_division=0)
    macro_recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
    weighted_recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)

    writer.add_scalar("Loss/Val", epoch_loss, epoch)
    writer.add_scalar("Accuracy/Val", accuracy, epoch)
    writer.add_scalar("Recall/Macro", macro_recall, epoch)
    writer.add_scalar("Recall/Weighted", weighted_recall, epoch)
    for label, rec in zip(present_labels, per_class_recall):
        writer.add_scalar(f"Recall/Class_{int(label)}", rec, epoch)

    return epoch_loss, accuracy
