import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

from zen import ZenGrad, ZenGrad_M

# === Optional dependency: segmentation_models_pytorch ===
# Nothing is installed here: if the import fails, only a hint is printed and
# execution continues with `smp` left undefined.
# NOTE(review): the model section below calls smp.Unet, so a missing package
# will surface there as a NameError — consider re-raising instead of printing.
try:
    import segmentation_models_pytorch as smp
except ImportError:
    print("Please install segmentation_models_pytorch using pip.")

# === Command-line options ===
# `args` was referenced without ever being defined, so the script failed with
# a NameError before doing any work. It is now built from the command line.
# allow_abbrev=False and parse_known_args() keep unrelated argv entries
# (e.g. those injected by a notebook kernel or a launcher) from being
# mistaken for, or rejected alongside, our single option.
_parser = argparse.ArgumentParser(
    description="Train a U-Net on a Pascal VOC style segmentation dataset.",
    allow_abbrev=False,
)
_parser.add_argument(
    "--voc_root",
    # NOTE(review): default assumes the layout of the extracted VOC2012
    # archive relative to the working directory — adjust if yours differs.
    default=os.path.join("VOCdevkit", "VOC2012"),
    help="Dataset root containing ImageSets/, JPEGImages/ and SegmentationClass/.",
)
args, _unknown_args = _parser.parse_known_args()

# Split file listing the training ids, plus the image and mask folders.
txt_file = os.path.join(args.voc_root, "ImageSets", "Segmentation", "train.txt")
image_dir = os.path.join(args.voc_root, "JPEGImages")
mask_dir = os.path.join(args.voc_root, "SegmentationClass")

# === Dataset ===
class VOCDataset(Dataset):
    """Image/mask pairs listed by id in a text file.

    Each non-blank line of ``txt_file`` is an id; the sample for an id is
    ``<image_dir>/<id>.jpg`` together with ``<mask_dir>/<id>.png``.

    Args:
        txt_file: Path to the text file with one id per line.
        image_dir: Directory holding the ``.jpg`` images.
        mask_dir: Directory holding the ``.png`` masks.
        transform: Optional callable ``transform(image, mask)`` returning the
            pair to hand out. When omitted, the PIL images are returned as-is.
    """

    def __init__(self, txt_file, image_dir, mask_dir, transform=None):
        with open(txt_file, 'r') as f:
            lines = f.read().splitlines()
        # Trim padding and drop blank rows: an empty id would otherwise turn
        # into a lookup for ".jpg"/".png" and fail mid-epoch.
        self.image_ids = [line.strip() for line in lines if line.strip()]
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

    def __len__(self):
        """Return the number of ids read from the list file."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load the ``idx``-th image (as RGB) and its mask, then transform them."""
        img_id = self.image_ids[idx]
        img_path = os.path.join(self.image_dir, f"{img_id}.jpg")
        mask_path = os.path.join(self.mask_dir, f"{img_id}.png")
        image = Image.open(img_path).convert("RGB")
        # The mask is opened without any mode conversion, so its stored pixel
        # values are passed through unchanged (presumably class indices).
        mask = Image.open(mask_path)
        # Compare against None explicitly: a callable's truthiness is not a
        # reliable "was a transform supplied?" test.
        if self.transform is not None:
            image, mask = self.transform(image, mask)
        return image, mask

# === Transform ===
class SegmentationTransform:
    """Resize an image/mask pair to a square and convert both to tensors.

    Args:
        size: Edge length, in pixels, of the square output.
    """

    def __init__(self, size=256):
        self.size = size

    def __call__(self, image, mask):
        """Return ``(image_tensor, mask_tensor)`` for one sample.

        The image goes through ``TF.to_tensor``; the mask becomes an integer
        (``long``) tensor built from its raw pixel values.
        """
        target_hw = (self.size, self.size)
        resized_image = TF.resize(image, target_hw)
        # Nearest-neighbour keeps mask values intact instead of blending
        # neighbouring labels into new, meaningless ones.
        resized_mask = TF.resize(mask, target_hw, interpolation=Image.NEAREST)
        image_tensor = TF.to_tensor(resized_image)
        mask_tensor = torch.from_numpy(np.array(resized_mask)).long()
        return image_tensor, mask_tensor

# === Metrics ===
def calculate_iou(pred, target, num_classes=21, ignore_index=255):
    """Mean intersection-over-union across the classes present in the input.

    Args:
        pred: Tensor of predicted class ids.
        target: Tensor of ground-truth class ids, same number of elements.
        num_classes: Classes ``0 .. num_classes - 1`` are scored.
        ignore_index: Target value marking pixels to leave out of the score.

    Returns:
        The mean IoU over every class that occurs in ``pred`` or ``target``
        (after dropping ignored pixels), or ``0.0`` when no class occurs at
        all — the same convention ``calculate_dice`` uses, and one that keeps
        a running average finite.
    """
    # reshape (unlike view) also accepts non-contiguous tensors.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    keep = target != ignore_index
    pred = pred[keep]
    target = target[keep]

    ious = []
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = (pred_inds & target_inds).sum().item()
        union = pred_inds.sum().item() + target_inds.sum().item() - intersection
        # A class absent from both prediction and target has no defined IoU;
        # leave it out rather than counting it as 0 or 1.
        if union > 0:
            ious.append(intersection / union)
    if not ious:
        return 0.0
    return np.mean(ious)

def calculate_dice(pred, target, num_classes=21, ignore_index=255):
    """Mean Dice coefficient across the classes present in the input.

    Args:
        pred: Tensor of predicted class ids.
        target: Tensor of ground-truth class ids, same number of elements.
        num_classes: Classes ``0 .. num_classes - 1`` are scored.
        ignore_index: Target value marking pixels to leave out of the score.

    Returns:
        The mean of ``2 * |P ∩ T| / (|P| + |T|)`` over every class that occurs
        in ``pred`` or ``target`` (after dropping ignored pixels), or ``0.0``
        when no class occurs at all.
    """
    # reshape (unlike view) also accepts non-contiguous tensors.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    keep = target != ignore_index
    pred = pred[keep]
    target = target[keep]

    dice_scores = []
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        denominator = pred_inds.sum().item() + target_inds.sum().item()
        if denominator == 0:
            continue
        intersection = (pred_inds & target_inds).sum().item()
        # The denominator is non-zero here, so no epsilon is needed and the
        # ratio is already within [0, 1] without clipping.
        dice_scores.append(2.0 * intersection / denominator)
    if not dice_scores:
        return 0.0
    return np.mean(dice_scores)

# === Decode Mask to RGB ===
def decode_segmap(label_mask, num_classes=21):
    """Colour a class-id mask using a fixed 21-entry RGB palette.

    Args:
        label_mask: Array of class ids.
        num_classes: Ids ``0 .. num_classes - 1`` are coloured; every other
            value keeps the initial black.

    Returns:
        A ``uint8`` array made by stacking the red, green and blue planes
        along axis 2 (``H x W x 3`` for a 2-D mask).
    """
    palette = np.array([
        (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0),
        (0, 0, 128), (128, 0, 128), (0, 128, 128), (128, 128, 128),
        (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0),
        (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
        (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)
    ])
    # One zero-initialised plane per colour channel.
    red, green, blue = (
        np.zeros_like(label_mask).astype(np.uint8) for _ in range(3)
    )
    for class_id in range(num_classes):
        selected = label_mask == class_id
        red[selected], green[selected], blue[selected] = palette[class_id]
    return np.stack([red, green, blue], axis=2)

# === Visualize Predictions ===
def visualize_predictions(model, dataloader, device, num_images=5,
                          save_path="/workspace/ZenGrad_Unet.png"):
    """Plot input, ground truth and prediction for samples of one batch.

    The first batch yielded by ``dataloader`` is run through ``model``; one
    row of three panels is drawn per sample, the figure is saved to
    ``save_path`` and then shown.

    Args:
        model: Segmentation network; its output is arg-maxed over dim 1.
        dataloader: Iterable yielding ``(images, masks)`` batches.
        device: Device the images are moved to before the forward pass.
        num_images: Maximum number of rows; capped at the batch size.
        save_path: Where the figure is written.
    """
    model.eval()
    images, masks = next(iter(dataloader))
    images = images.to(device)
    # Inference only: no need to record an autograd graph for the batch.
    with torch.no_grad():
        outputs = model(images)
    preds = torch.argmax(outputs, dim=1).cpu()

    # Never index past the end of the batch (e.g. a short final batch).
    num_images = min(num_images, images.size(0))
    # squeeze=False keeps `axs` two-dimensional even when only one row is
    # drawn, so the axs[i, j] indexing below always works.
    fig, axs = plt.subplots(num_images, 3, figsize=(15, 4 * num_images),
                            squeeze=False)
    for i in range(num_images):
        # Channel-first tensor -> channel-last array, as imshow expects.
        img = images[i].cpu().permute(1, 2, 0).numpy()
        true_mask = masks[i].cpu().numpy()
        pred_mask = preds[i].numpy()
        axs[i, 0].imshow(img)
        axs[i, 0].set_title("Input Image")
        axs[i, 1].imshow(decode_segmap(true_mask))
        axs[i, 1].set_title("Ground Truth")
        axs[i, 2].imshow(decode_segmap(pred_mask))
        axs[i, 2].set_title("Prediction")
        for j in range(3):
            axs[i, j].axis('off')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()

# === Evaluation ===
def evaluate_model(model, dataloader, device, epoch):
    """Average per-batch IoU and Dice of ``model`` over ``dataloader``.

    Args:
        model: Segmentation network; its output is arg-maxed over dim 1.
        dataloader: Iterable yielding ``(images, masks)`` batches.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index, used only in the printed summary.

    Returns:
        ``(avg_iou, avg_dice)``. Batches whose IoU is NaN are left out of the
        IoU average instead of turning it into NaN; an average with nothing
        to score (e.g. an empty dataloader) is reported as ``0.0``.
    """
    model.eval()
    iou_scores, dice_scores = [], []
    with torch.no_grad():
        for images, masks in dataloader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            iou_scores.append(calculate_iou(preds, masks))
            dice_scores.append(calculate_dice(preds, masks))

    # Drop undefined (NaN) batch scores so one of them cannot poison the mean.
    defined_iou = [score for score in iou_scores if not np.isnan(score)]
    avg_iou = float(np.mean(defined_iou)) if defined_iou else 0.0
    avg_dice = float(np.mean(dice_scores)) if dice_scores else 0.0
    print(f"Epoch {epoch+1} | Avg IoU: {avg_iou:.4f} | Avg Dice: {avg_dice:.4f}")
    return avg_iou, avg_dice

# === Prepare Data ===
transform = SegmentationTransform(size=256)
dataset = VOCDataset(txt_file, image_dir, mask_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# === Model ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# U-Net with a ResNet-50 encoder, no encoder weights requested, 3 input
# channels and 21 output channels (one per class).
model = smp.Unet(encoder_name="resnet50", encoder_weights=None, in_channels=3, classes=21)
model = model.to(device)

# === Loss and Optimizer ===
# Target pixels equal to 255 are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255)
optimizer = ZenGrad(model.parameters())  # Change Optimisers

# === Train Loop ===
metrics_csv_path = "/workspace/ZenGrad_Unet.csv"
# Make sure the output folder exists before training starts, so a missing
# directory is found (or fixed) now rather than after the final epoch.
os.makedirs(os.path.dirname(metrics_csv_path), exist_ok=True)

metrics_per_epoch = []
for epoch in range(500):
    model.train()
    total_loss = 0
    for images, masks in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        loss = criterion(outputs, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    # Metrics are computed on the same dataloader that was just trained on.
    avg_iou, avg_dice = evaluate_model(model, dataloader, device, epoch)
    metrics_per_epoch.append({
        "Epoch": epoch + 1,
        "Loss": avg_loss,
        "Mean IoU": avg_iou,
        "Mean Dice": avg_dice
    })
    # Persist after every epoch so an interrupted run keeps the metrics
    # gathered so far instead of losing all of them.
    pd.DataFrame(metrics_per_epoch).to_csv(metrics_csv_path, index=False)

# === Save CSV ===
df_metrics = pd.DataFrame(metrics_per_epoch)
df_metrics.to_csv(metrics_csv_path, index=False)
print("Saved: ZenGrad_Unet.csv")

# === Final Prediction Visualization ===
visualize_predictions(model, dataloader, device, num_images=7)
