from PIL import Image
import torch
import torch.nn as nn

from transformers import Owlv2Processor, Owlv2ForObjectDetection

# Device every model tensor and learnable vector is placed on.
device = "cuda:1"

processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)

# State dict of the vision tower only; these are the weights that get decomposed.
base_params = model.owlv2.vision_model.state_dict()

# CPU copies of every entry whose name contains "classifier".
# NOTE(review): not referenced anywhere else in the visible code - confirm it is still needed.
original_model_params = {
    k: v.clone().detach().cpu() for k, v in base_params.items() if "classifier" in k
}

# Maps "<param name>.U" / ".S" / ".V" to the reduced SVD factors of that weight.
decomposed_params = {}

for k, v in base_params.items():
    if 'layernorm' in k or 'bias' in k or 'embeddings' in k or 'layrnorm' in k or 'layer_norm' in k:
        continue  # skip this param

    # torch.svd is deprecated; torch.linalg.svd(full_matrices=False) is the
    # documented replacement. It returns Vh (V transposed), so transpose it back
    # to keep the stored ".V" factor in the layout consumers expect
    # (weight == U @ diag(S) @ V.T).
    U, S, Vh = torch.linalg.svd(v, full_matrices=False)
    V = Vh.transpose(-2, -1)
    decomposed_params[f"{k}.U"] = U
    decomposed_params[f"{k}.S"] = S
    decomposed_params[f"{k}.V"] = V

class Policy(nn.Module):
    """Container for one learnable vector per decomposable weight.

    For every entry of ``base_params`` whose name does not contain one of the
    skip substrings, a bfloat16 ``nn.Parameter`` of length ``min(v.shape)``
    is created (for a 2-D weight that is its number of singular values).
    ``get_mask`` turns such a vector into a multiplicative mask.

    Args:
        base_params: Mapping from parameter name to tensor (e.g. a state dict).
        gpu: Device on which the learnable vectors are allocated.
        init_val: Standard deviation of the Gaussian initialisation.
        max_mult: Upper bound of the mask values produced by ``get_mask``.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    # Substrings marking state-dict entries that get no learnable vector.
    _SKIP_SUBSTRINGS = ('layernorm', 'bias', 'embeddings', 'layrnorm', 'layer_norm')

    def __init__(self, base_params, gpu, init_val=0.1, max_mult=1, **kwargs):
        super().__init__()
        # name -> nn.Parameter; a plain dict, so registration with the module
        # happens through the ParameterList created at the end of __init__.
        self.learnable_params = {}
        self.num_params = 0
        self.max_mult = max_mult
        self.enable_mask = True
        for k, v in base_params.items():
            if any(tag in k for tag in self._SKIP_SUBSTRINGS):
                continue
            # each param initialized with small gaussian noise
            self.learnable_params[k] = torch.nn.Parameter(
                data=(
                    torch.randn(
                        min(v.shape),
                        device=gpu,
                        dtype=torch.bfloat16,
                    )
                    * init_val
                ),
                requires_grad=True,
            )
            self.num_params += self.learnable_params[k].numel()
        print(f"#params={self.num_params}")
        self.learnable_params_list = list(self.learnable_params.values())
        self.trainable_params = self.learnable_params_list
        self.learnable_params_module_list = nn.ParameterList(self.learnable_params_list)

    def get_learnable_params(self, detach=False):
        """Return the name -> learnable vector mapping.

        Args:
            detach: When true, return a new dict of tensors detached from the
                autograd graph (they still share storage with the live
                parameters). When false, return the live parameter dict.
        """
        if detach:
            # Previously this flag was silently ignored.
            return {k: p.detach() for k, p in self.learnable_params.items()}
        return self.learnable_params

    def set_trainable_params_values(self, new_values):
        """Copy ``new_values`` element-wise into the trainable parameters, in order."""
        with torch.no_grad():
            for p, v in zip(self.trainable_params, new_values):
                p.data.copy_(v)

    def get_mask(self, p):
        """Map a learnable vector to a bfloat16 mask.

        Returns ``sigmoid(p) * max_mult`` when masking is enabled, otherwise a
        tensor of ones with the same shape as ``p``.
        """
        if self.enable_mask:
            return torch.sigmoid(p).to(torch.bfloat16) * self.max_mult
        else:
            return torch.ones_like(p).to(torch.bfloat16)

# Build the policy over the vision-tower weights. The extra keyword arguments
# (decomposed_params, mode) are accepted by Policy.__init__ through **kwargs
# but are not read there.
policy = Policy(base_params, gpu=device, decomposed_params=decomposed_params, mode=1)
# Dict mapping each decomposed parameter name to its learnable vector.
learnable_params = policy.get_learnable_params()

def compose_new_params(
    policy,
    param_name,
    decomposed_params,
    learnable_params,
):
    """Rebuild one weight matrix from its SVD factors with masked singular values.

    The singular values stored under ``"<param_name>.S"`` are multiplied by the
    mask that ``policy.get_mask`` derives from ``learnable_params[param_name]``,
    the matrix ``U @ diag(S * mask) @ V.T`` is formed, and the result is
    rescaled by ``sum(S) / sum(S * mask)``.

    Args:
        policy: Object exposing ``get_mask(tensor)``.
        param_name: Key prefix shared by the ``.U``, ``.S`` and ``.V`` entries.
        decomposed_params: Mapping holding the SVD factors.
        learnable_params: Mapping from parameter name to its learnable vector.

    Returns:
        The recomposed, rescaled matrix.
    """
    mask = policy.get_mask(learnable_params[param_name])

    left_vectors = decomposed_params[f"{param_name}.U"]
    singular_values = decomposed_params[f"{param_name}.S"]
    right_vectors = decomposed_params[f"{param_name}.V"]

    masked_values = singular_values * mask
    recomposed = left_vectors @ torch.diag_embed(masked_values) @ right_vectors.T

    # Keep the total singular-value mass equal to that of the unmasked matrix.
    rescale = singular_values.sum() / masked_values.sum()
    return recomposed * rescale

def backward(
    policy,
    model,
    base_params,
    decomposed_params,
    learnable_params,
):
    """Push the vision-tower weight gradients back onto the policy vectors.

    For every non-skipped key of ``base_params`` the weight is recomposed with
    ``compose_new_params`` and ``.backward`` is called on it, using the
    ``.grad`` of the matching ``model.owlv2.vision_model`` parameter as the
    upstream gradient. Every call except the final one passes
    ``retain_graph=True``.

    Raises:
        IndexError: If no key of ``base_params`` survives the skip filter.
    """
    skip_markers = ('layernorm', 'bias', 'embeddings', 'layrnorm', 'layer_norm')
    keys_to_backprop = [
        name
        for name in base_params
        if not any(marker in name for marker in skip_markers)
    ]
    # Indexing first keeps the failure mode for an empty key list unchanged.
    keys_to_backprop[-1]
    final_position = len(keys_to_backprop) - 1
    vision_model = model.owlv2.vision_model

    for position, name in enumerate(keys_to_backprop):
        recomposed = compose_new_params(policy, name, decomposed_params, learnable_params)
        # Only the last call releases the graph.
        recomposed.backward(
            vision_model.get_parameter(name).grad,
            retain_graph=position != final_position,
        )

# Initial weights for the vision tower: skipped entries are passed through
# unchanged, every other entry is recomposed from its SVD factors and mask.
new_params = {
    name: (
        base_params[name]
        if any(
            marker in name
            for marker in ('layernorm', 'bias', 'embeddings', 'layrnorm', 'layer_norm')
        )
        else compose_new_params(policy, name, decomposed_params, learnable_params)
    )
    for name in base_params
}

model.owlv2.vision_model.load_state_dict(new_params)

# Make sure every recomposed weight tracks gradients.
for k in learnable_params:
    model.owlv2.vision_model.get_parameter(k).requires_grad_(True)
    
    
def apply_policy_to_model(policy, base_params, decomposed_params, learnable_params):
    """Recompose the masked weights and load them into the vision tower.

    Entries of ``base_params`` whose name contains a skip substring are passed
    through unchanged; every other entry is rebuilt with
    ``compose_new_params``. The result is loaded into the module-level
    ``model.owlv2.vision_model`` with ``strict=False``.

    Args:
        policy: Policy providing the masks.
        base_params: Mapping of parameter names to tensors (defines the keys).
        decomposed_params: Mapping holding the SVD factors.
        learnable_params: Mapping from parameter name to its learnable vector.
    """
    skip_markers = ('layernorm', 'bias', 'embeddings', 'layer_norm', 'layrnorm')
    updated_params = {}
    # load_state_dict only copies values, so recomposing under no_grad avoids
    # building (and immediately discarding) an autograd graph per weight.
    with torch.no_grad():
        for k in base_params:
            if any(skip in k for skip in skip_markers):
                updated_params[k] = base_params[k]
                continue
            updated_params[k] = compose_new_params(policy, k, decomposed_params, learnable_params)
    model.owlv2.vision_model.load_state_dict(updated_params, strict=False)

# NOTE(review): loss_fn is not referenced in the visible code; the training
# loop calls torch.nn.functional.cross_entropy directly - confirm it is needed.
loss_fn = torch.nn.CrossEntropyLoss()
# Adam updates only the policy's learnable vectors, not the model weights.
optimizer = torch.optim.Adam(policy.trainable_params, lr=1e-2)

# loss
def classification_loss_fn(tgt_ids, pred_logits, alpha=0.25, gamma=2.0):
    """Compute a focal-loss-style classification cost for each (prediction, target) pair.

    The first two dimensions of ``pred_logits`` are flattened and a sigmoid is
    applied; the cost for a pair is the positive focal term minus the negative
    focal term of the target's class column.

    Args:
        tgt_ids: 1-D integer tensor of target class indices.
        pred_logits: Logits whose first two dimensions are flattened together
            and whose last dimension is indexed by class.
        alpha: Weight of the positive term (the negative term gets ``1 - alpha``).
        gamma: Focusing exponent of the focal terms.

    Returns:
        Tensor of shape ``[flattened predictions, len(tgt_ids)]``.
    """
    out_prob = pred_logits.flatten(0, 1).sigmoid()

    # The 1e-8 offset guards the logarithms against log(0).
    neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
    pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())

    return pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

from torch.nn.utils.rnn import pad_sequence
from torchvision.ops.boxes import box_area
from scipy.optimize import linear_sum_assignment

# modified from torchvision to also return the union
def box_iou(boxes1, boxes2, all_pairs = True):
    """Compute IoU and union area for boxes given as [x0, y0, x1, y1].

    With ``all_pairs=True`` every row of ``boxes1`` is compared with every row
    of ``boxes2``; otherwise the two inputs are compared row by row.

    Returns:
        Tuple ``(iou, union)``.
    """
    first_area = box_area(boxes1)
    second_area = box_area(boxes2)

    if all_pairs:
        # Extra axis so each row of boxes1 broadcasts against all of boxes2.
        boxes1 = boxes1[:, None]
        first_area = first_area[:, None]

    top_left = torch.max(boxes1[..., :2], boxes2[:, :2])
    bottom_right = torch.min(boxes1[..., 2:], boxes2[:, 2:])

    # Non-overlapping boxes get a zero-sized intersection.
    overlap_wh = (bottom_right - top_left).clamp(min=0)
    inter = overlap_wh[..., 0] * overlap_wh[..., 1]
    union = first_area + second_area - inter

    return inter / union, union

def generalized_box_iou(boxes1, boxes2, all_pairs = True):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    With all_pairs=True this returns a [N, M] pairwise matrix, where
    N = len(boxes1) and M = len(boxes2); with all_pairs=False the boxes are
    compared row by row and one value per row is returned.
    """
    # Degenerate boxes would produce inf / nan, so reject them up front.
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2, all_pairs)

    # Extra axis on the first operand gives the pairwise comparison.
    lhs = boxes1[:, None] if all_pairs else boxes1

    # Smallest box enclosing both operands.
    hull_min = torch.min(lhs[..., :2], boxes2[:, :2])
    hull_max = torch.max(lhs[..., 2:], boxes2[:, 2:])
    hull_wh = (hull_max - hull_min).clamp(min=0)
    hull_area = hull_wh[..., 0] * hull_wh[..., 1]

    return iou - (hull_area - union) / hull_area

def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center x, center y, width, height) to (x0, y0, x1, y1).

    The conversion acts on the last dimension, which must have exactly four
    entries; all leading dimensions are preserved.
    """
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=-1)

def box_loss_fn(tgt_bbox, pred_boxes):
    """Compute pairwise box costs between predictions and targets.

    Args:
        tgt_bbox: Target boxes in (cx, cy, w, h) format.
        pred_boxes: Predicted boxes in (cx, cy, w, h) format; the first two
            dimensions are flattened together.

    Returns:
        Tuple ``(cost_bbox, cost_giou)``: the pairwise L1 distance and the
        negated pairwise generalized IoU.
    """
    flat_pred = pred_boxes.flatten(0, 1)  # [batch_size * num_queries, 4]

    l1_cost = torch.cdist(flat_pred, tgt_bbox, p=1)

    pairwise_giou = generalized_box_iou(
        box_cxcywh_to_xyxy(flat_pred),
        box_cxcywh_to_xyxy(tgt_bbox),
    )

    return l1_cost, -pairwise_giou

# Relative weights of the three terms in the matching cost.
_cost_bbox = 1.0
_cost_class = 1.0
_cost_giou = 1.0

# =========================================================== #
# loop (changing cat to dog) [Class label changing]
# Target boxes; they are passed through box_cxcywh_to_xyxy below, i.e. treated
# as (cx, cy, w, h).
gt = torch.tensor([
    [0.1647, 0.1484, 0.2198, 0.0675],
    [0.5488, 0.2054, 0.0566, 0.1714]
])

# Every target gets label 1, the index of "a photo of a dog" in text_labels.
targets = [{
    "labels": torch.tensor([1 for _ in range(gt.shape[0])]),
    "boxes": gt.detach(),
}]

# The image and text prompts never change between steps, so preprocess once
# instead of re-reading and re-tokenising on every iteration.
url = "000000039769.jpg"
image = Image.open(url)
text_labels = [["a photo of a cat", "a photo of a dog"]]
inputs = processor(text=text_labels, images=image, return_tensors="pt")
for key in inputs:
    inputs[key] = inputs[key].to(device)

for i in range(10):
    outputs = model(**inputs)

    bs, num_queries = outputs["logits"].shape[:2]

    tgt_ids = torch.cat([v["labels"] for v in targets]).to(device)
    tgt_bbox = torch.cat([v["boxes"] for v in targets]).to(device)

    pred_logits = outputs["logits"] # (batch, num_queries, num_text_labels)
    pred_boxes = outputs["pred_boxes"] # (batch, num_queries, 4)

    # Hungarian matching between predictions and targets; no gradients needed.
    with torch.no_grad():
        cost_class = classification_loss_fn(tgt_ids, pred_logits)
        cost_bbox, cost_giou = box_loss_fn(tgt_bbox, pred_boxes)

        C = _cost_bbox * cost_bbox + _cost_class * cost_class + _cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        # Distinct index names so the outer step counter is not shadowed.
        indices = [linear_sum_assignment(c[b]) for b, c in enumerate(C.split(sizes, -1))]
        indices = [
            (torch.as_tensor(rows, dtype=torch.int64), torch.as_tensor(cols, dtype=torch.int64))
            for rows, cols in indices
        ] # per image: (matched prediction indices, matched target indices)

    src_logits = torch.cat([pred_logits[b, idx[0], :] for b, idx in enumerate(indices)]) # [batch * num_targets, num_text_labels]
    trg_labels = torch.cat([targets[b]["labels"][idx[1]] for b, idx in enumerate(indices)]).to(src_logits.device) # [batch * num_targets]

    src_boxes = torch.cat([pred_boxes[b, idx[0], :] for b, idx in enumerate(indices)]) # [batch * num_targets, 4]
    trg_bbox = torch.cat([targets[b]["boxes"][idx[1],:] for b, idx in enumerate(indices)]).to(src_logits.device)

    # box regression losses
    loss_bbox = torch.nn.functional.l1_loss(src_boxes, trg_bbox) # L1 loss

    loss_giou = 1 - generalized_box_iou(box_cxcywh_to_xyxy(src_boxes),
                                        box_cxcywh_to_xyxy(trg_bbox), all_pairs=False).mean()

    # classification loss
    loss_class = torch.nn.functional.cross_entropy(src_logits,
                                                trg_labels)

    loss = loss_class + loss_bbox + loss_giou

    # Backward
    optimizer.zero_grad()
    # The optimizer only owns the policy vectors, so the model's own weight
    # gradients must be cleared separately; otherwise loss.backward() keeps
    # accumulating into them and backward() would propagate stale sums.
    model.zero_grad()
    loss.backward()
    backward(policy, model, base_params, decomposed_params, learnable_params)
    optimizer.step()

    # Compose current masked weights
    apply_policy_to_model(policy, base_params, decomposed_params, learnable_params)

    for j in range(src_logits.shape[0]):
        pred_label = torch.argmax(src_logits[j, :].detach().cpu())
        pred_box = src_boxes[j]

        print("target_boxes", gt.detach(), "pred boxes", pred_box.detach(), "pred_label", text_labels[0][int(pred_label)])
    print(f"Step {i:03d} | Total Loss: {loss.item():.4f} | Class Loss: {loss_class.item():.4f} | Box Loss: {loss_bbox.item():.4f} | Giou Loss: {loss_giou.item():.4f}")

## INFERENCE
url = "000000039769.jpg"
image = Image.open(url)
text_labels = [["a photo of a cat", "a photo of a dog"]]
inputs = processor(text=text_labels, images=image, return_tensors="pt")
for key in inputs:
    inputs[key] = inputs[key].to(device)
# Pure inference: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    outputs = model(**inputs)

# Target image sizes (height, width) to rescale box predictions [batch_size, 2]
target_sizes = torch.tensor([(image.height, image.width)])
# Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
results = processor.post_process_grounded_object_detection(
    outputs=outputs, target_sizes=target_sizes, threshold=0.4, text_labels=text_labels
)
# Retrieve predictions for the first image for the corresponding text queries
result = results[0]
boxes, scores, text_labels = result["boxes"], result["scores"], result["text_labels"]

import cv2

# Re-read the image with OpenCV so the detections can be drawn on it.
image = cv2.imread(url)

for box, score, text_label in zip(boxes, scores, text_labels):
    box = [round(i, 2) for i in box.tolist()]
    print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}")
    
    x1, y1, x2, y2 = box
    image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
    image = cv2.putText(image, text_label, (int(x1)+10, int(y1)+10), cv2.FONT_HERSHEY_COMPLEX, 1.0, (0, 255, 0), 1)
    
cv2.imwrite("prediction_cat_to_dog.jpg", image)