#%%
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from PIL import Image
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from tqdm import tqdm
import torch.nn as nn
import os
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau

#%%
def rgba_to_binary(image, background_color=(255, 255, 255)):
    """Turn a PIL image into a 0/1 mask where pure-white pixels are 0.

    RGBA inputs are first composited onto a solid ``background_color`` canvas
    using their alpha channel; every other mode is converted to RGB directly.

    Args:
        image: PIL image to binarise.
        background_color: RGB colour placed behind transparent pixels.

    Returns:
        ``np.uint8`` array of shape (height, width): 0 where all three RGB
        channels equal 255, 1 everywhere else.

    Note:
        The white test is always against 255, independent of
        ``background_color``; with a non-white background, transparent
        pixels are therefore reported as 1.
    """
    if image.mode != 'RGBA':
        rgb = image.convert("RGB")
    else:
        # Flatten transparency: the alpha band decides where the original
        # pixels overwrite the solid background.
        *_, alpha = image.split()
        rgb = Image.new("RGB", image.size, background_color)
        rgb.paste(image, mask=alpha)

    pixels = np.array(rgb)
    is_white = (pixels == 255).all(axis=-1)
    return np.logical_not(is_white).astype(np.uint8)

def get_mask_bbox(binary_mask):
    """Return the bounding box of the zero-valued pixels of a 2-D mask.

    The box is ``(left, upper, right, lower)`` with *exclusive* right/lower
    edges, i.e. the convention used by ``PIL.Image.crop`` and by the
    whole-image fallback below. The previous implementation returned the
    inclusive maximum coordinates, which cropped one pixel short and yielded
    a zero-area box for masks one pixel wide or tall.

    Args:
        binary_mask: 2-D mask given as a ``torch.Tensor``, ``np.ndarray`` or
            anything ``np.array`` accepts. Pixels equal to 0 (after casting
            to ``uint8``) are treated as the region of interest.

    Returns:
        Tuple of four Python ints ``(x_min, y_min, x_max + 1, y_max + 1)``.
        If the mask contains no zero pixel, a notice is printed and the
        whole-image box ``(0, 0, width, height)`` is returned.
    """
    if isinstance(binary_mask, torch.Tensor):
        binary_mask = binary_mask.detach().cpu().numpy()
    # If input is not a numpy array, convert it
    elif not isinstance(binary_mask, np.ndarray):
        binary_mask = np.array(binary_mask)

    binary_mask = binary_mask.astype(np.uint8)
    foreground_coords = np.argwhere(binary_mask == 0)  # rows of (y, x)
    if foreground_coords.size == 0:
        print("None mask returned, get whole image!!!!!!!!!!!!!")
        h, w = binary_mask.shape
        return (0, 0, w, h)

    y_min, x_min = foreground_coords.min(axis=0)
    y_max, x_max = foreground_coords.max(axis=0)
    # +1 converts the inclusive max index into an exclusive edge.
    return (int(x_min), int(y_min), int(x_max) + 1, int(y_max) + 1)

#%%
class GatedCrossAttentionFusion(nn.Module):
    """Fuse a region embedding with a full-image embedding.

    Each embedding is passed through its own single-channel ``Conv1d`` + ReLU
    applied along the embedding axis. The processed region vector then
    queries the processed full-image vector through multi-head attention,
    and the attention output is added back to the region vector scaled by a
    sigmoid gate computed from the region vector. The sum is layer-normalised
    and linearly projected.

    Args:
        embed_dim: Size of the incoming embeddings and of the output.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, num_heads=8):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys and the order in which random init is drawn.
        self.cnn_region = self._conv_relu()
        self.cnn_full = self._conv_relu()
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.gate_fc = nn.Sequential(nn.Linear(embed_dim, 1), nn.Sigmoid())
        self.ln = nn.LayerNorm(embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)

    @staticmethod
    def _conv_relu():
        """Build one single-channel, length-preserving Conv1d + ReLU stack."""
        return nn.Sequential(
            nn.Conv1d(1, 1, kernel_size=3, padding=1),
            nn.ReLU(),
        )

    def forward(self, region_emb, full_emb):
        """Return the fused embedding for a batch.

        Args:
            region_emb: Region embeddings; assumed (batch, embed_dim) -- TODO confirm.
            full_emb: Full-image embeddings with the same layout.

        Returns:
            Tensor with the same layout as ``region_emb``.
        """
        # Insert a channel axis for Conv1d, then drop it again.
        region_feat = self.cnn_region(region_emb.unsqueeze(1)).squeeze(1)
        context_feat = self.cnn_full(full_emb.unsqueeze(1)).squeeze(1)

        # Each sample contributes one query token and one key/value token.
        query = region_feat.unsqueeze(1)
        context = context_feat.unsqueeze(1)
        attended, _ = self.cross_attn(query, context, context)
        attended = attended.squeeze(1)

        # The gate yields one scalar per sample, broadcast over the embedding.
        gate = self.gate_fc(region_feat)
        residual = region_feat + gate * attended

        return self.proj(self.ln(residual))

#%%
class RegionCLIPDataset(Dataset):
    """Map-style dataset yielding ``(full_image, region_crop, text)`` triples.

    Every row of the wrapped dataset must provide ``"OUTPUT_IMG"``,
    ``"MASK_IMG"`` and ``"OUTPUT_DESCRIPTION"``. The two images may be PIL
    images or file paths.

    Args:
        hf_dataset: Indexable collection of rows (supports ``len`` and
            integer indexing).
        resolution: Side length in pixels that both image and mask are
            resized to (square) before cropping. Defaults to 512, the value
            that used to be hard-coded.
    """

    def __init__(self, hf_dataset, resolution=512):
        self.dataset = hf_dataset
        self.resolution = resolution

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load one row and crop the image to the mask's bounding box.

        Returns:
            Tuple ``(image, region, text)``: the resized full image, the crop
            of that image given by ``get_mask_bbox`` on the binarised mask,
            and the row's description string.
        """
        row = self.dataset[idx]
        image = row["OUTPUT_IMG"]
        mask = row["MASK_IMG"]
        text = row["OUTPUT_DESCRIPTION"]

        if isinstance(image, str):
            image = Image.open(image)
        if isinstance(mask, str):
            mask = Image.open(mask)

        size = (self.resolution, self.resolution)
        image = image.resize(size, Image.BICUBIC)
        # Nearest-neighbour keeps the mask free of interpolated in-between
        # colours, which would otherwise blur the white/non-white boundary.
        mask = mask.resize(size, Image.NEAREST)

        # get_mask_bbox always returns a 4-tuple (whole image when the mask
        # has no matching pixel), so no None/empty check is needed here.
        bbox = get_mask_bbox(rgba_to_binary(mask))
        region = image.crop(bbox)

        return image, region, text

def custom_collate(batch):
    """Regroup a batch of ``(full_image, region, text)`` samples by field.

    Args:
        batch: Non-empty sequence of 3-tuples.

    Returns:
        Tuple of three lists: all full images, all regions, all texts,
        each in batch order.
    """
    # Transpose sample-major tuples into three field-major lists.
    full_imgs, regions, texts = map(list, zip(*batch))
    return full_imgs, regions, texts

#%%
def _contrastive_loss(model, processor, fusion, batch, device):
    """Compute the fused-region -> text contrastive loss for one batch.

    Args:
        model: CLIP model providing ``get_image_features``,
            ``get_text_features`` and ``logit_scale``.
        processor: Matching CLIP processor.
        fusion: Module combining region and full-image embeddings.
        batch: ``(full_imgs, regions, texts)`` lists of equal length.
        device: Device the inputs are moved to.

    Returns:
        Scalar cross-entropy loss in which sample ``i``'s fused embedding
        must select text ``i`` among all texts of the batch.
    """
    full_imgs, regions, texts = batch

    inputs_region = processor(images=regions, return_tensors="pt", padding=True, truncation=True).to(device)
    inputs_full = processor(images=full_imgs, return_tensors="pt", padding=True, truncation=True).to(device)
    inputs_text = processor(text=texts, return_tensors="pt", padding=True, truncation=True).to(device)

    E_r = model.get_image_features(**inputs_region)
    E_f = model.get_image_features(**inputs_full)
    E_t = model.get_text_features(**inputs_text)

    E_rf = nn.functional.normalize(fusion(E_r, E_f), dim=-1)
    E_t = nn.functional.normalize(E_t, dim=-1)

    # Cosine similarities live in [-1, 1]; without a temperature the softmax
    # over a large batch stays almost uniform and the loss cannot get close
    # to zero. Scale with CLIP's learned temperature, as CLIP's own
    # contrastive objective does.
    logits = (E_rf @ E_t.T) * model.logit_scale.exp()
    labels = torch.arange(logits.size(0), device=device)
    return nn.functional.cross_entropy(logits, labels)


def _run_epoch(model, processor, fusion, loader, optimizer, device, desc):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    total_loss = 0.0
    for batch in tqdm(loader, desc=desc):
        loss = _contrastive_loss(model, processor, fusion, batch, device)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(loader)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CLIPModel.from_pretrained("pretrained_frameworks/clip-vit-base-patch16").to(device)
    processor = CLIPProcessor.from_pretrained("pretrained_frameworks/clip-vit-base-patch16")
    fusion = GatedCrossAttentionFusion(embed_dim=model.config.projection_dim).to(device)

    # Optional warm start of the fusion module from an earlier run:
    # fusion_ckpt = "fusion_phase1.pth"
    # if os.path.exists(fusion_ckpt):
    #     print(f"Loading GatedCrossAttentionFusion weights from {fusion_ckpt}")
    #     fusion.load_state_dict(torch.load(fusion_ckpt, map_location=device))
    # else:
    #     print(f"No fusion checkpoint found at {fusion_ckpt}, using random init.")

    hf_dataset = load_dataset("downloaded_datatset/HumanEdit", split="train", streaming=False)
    # Quick smoke test on a subset:
    # limited_samples = list(hf_dataset.take(100))
    # custom_dataset = RegionCLIPDataset(limited_samples)
    custom_dataset = RegionCLIPDataset(hf_dataset)
    train_loader = DataLoader(custom_dataset, batch_size=512, shuffle=True, collate_fn=custom_collate)

    # Phase 1: train only fusion (CLIP weights frozen)
    for param in model.parameters():
        param.requires_grad = False

    optimizer = torch.optim.AdamW(fusion.parameters(), lr=1e-4)
    FUSION_EPOCHS = 50

    print("\n=== Phase 1: Train Fusion ===")
    best_loss = float('inf')  # lowest mean epoch loss seen so far
    best_model_path = "fusion_phase1_best.pth"

    for epoch in range(FUSION_EPOCHS):
        model.eval()
        fusion.train()
        avg_loss = _run_epoch(
            model, processor, fusion, train_loader, optimizer, device,
            desc=f"Epoch {epoch+1}/{FUSION_EPOCHS}",
        )
        print(f"[Phase 1][Epoch {epoch+1}] Loss: {avg_loss:.4f}")

        # Save checkpoint if this is the best model so far
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(fusion.state_dict(), best_model_path)
            print(f"✅ Saved best model at epoch {epoch+1} with loss {avg_loss:.4f}")

    # Phase 2: unfreeze CLIP and train all
    # NOTE: this continues from the last Phase 1 weights, not from the best
    # checkpoint written above.
    print("\n=== Phase 2: Fine-Tune CLIP + Fusion ===")
    for param in model.parameters():
        param.requires_grad = True

    optimizer = torch.optim.AdamW(
        list(model.parameters()) + list(fusion.parameters()), lr=1e-5
    )
    # `verbose=` is deprecated in recent PyTorch releases, so learning-rate
    # changes are reported explicitly below instead.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

    EPOCHS = 20

    for epoch in range(EPOCHS):
        model.train()
        fusion.train()
        avg_loss = _run_epoch(
            model, processor, fusion, train_loader, optimizer, device,
            desc=f"Epoch {epoch+1}/{EPOCHS}",
        )

        previous_lr = optimizer.param_groups[0]["lr"]
        scheduler.step(avg_loss)
        current_lr = optimizer.param_groups[0]["lr"]
        if current_lr != previous_lr:
            print(f"[Phase 2] Reducing learning rate to {current_lr:.4e}")
        print(f"[Phase 2][Epoch {epoch+1}] Loss: {avg_loss:.4f}")

    # Save full model
    os.makedirs("finetuned_models/finetuned_clip_region", exist_ok=True)
    file_path = f"finetuned_models/finetuned_clip_region/clip_fusion_joint_{FUSION_EPOCHS}fusion_{EPOCHS}joint.pth"

    # NOTE(review): "clip_config" is stored as a pickled config object, so
    # loading this file presumably needs torch.load(..., weights_only=False)
    # on PyTorch versions that default to weights-only loading -- confirm
    # with the loader code before changing the format.
    torch.save({
        "clip_model_state_dict": model.state_dict(),
        "fusion_state_dict": fusion.state_dict(),
        "clip_config": model.config,
        "fusion_config": {
            "embed_dim": fusion.cross_attn.embed_dim,
            "num_heads": fusion.cross_attn.num_heads
        }
    }, file_path)
    print(f"Model is saved successfully at {file_path}")

# %%
