from transformers import ViTImageProcessor, ViTModel
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

## GLOBAL

# T is passed to the datasets as `num_tasks` and bounds the task loop at the bottom of the script.
T = 4 # Number of tasks
# Prefix of the checkpoint file names written under "weights/".
task_name = "voc_2012"

## LOADING THE MODEL
from transformers import OwlViTProcessor, OwlViTForObjectDetection, ViTModel
# NOTE(review): ViTModel is already imported on the line above (and at the top of the file).
from transformers import ViTModel

# Device used for the model and for every batch.
# NOTE(review): hard-coded to CUDA device index 1 — presumably requires a host with at least two GPUs.
device = "cuda:1"

from transformers.models.owlvit.modeling_owlvit import BaseModelOutput

# Pretrained OwlViT detector and its processor (the processor is used by collate_wrapper).
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch16")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")

# Pretrained ViT whose encoder layers are reused by the OwlViTEncoder class below.
backbone = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

class OwlViTEncoder(nn.Module):
    """
    Drop-in replacement for the OwlViT vision encoder that runs the layers of the
    module-level ViT ``backbone`` (``backbone.encoder.layer``) instead of OwlViT's own
    encoder layers.

    The ``forward`` signature mirrors the encoder it replaces so the surrounding
    vision model can keep calling it unchanged.

    Attributes:
        layers: the encoder layers shared with ``backbone`` (same module objects, not copies).
        config: ``backbone.config``; supplies the defaults for ``output_attentions``,
            ``output_hidden_states`` and ``return_dict`` when the caller passes ``None``.
        gradient_checkpointing: always initialised to ``False``.
    """

    def __init__(self):
        super().__init__()
        self.layers = backbone.encoder.layer
        # forward() falls back to self.config.* for unset flags; without this attribute
        # any call that leaves one of those flags as None raises AttributeError.
        self.config = backbone.config
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask = None,
        causal_attention_mask = None,
        output_attentions = None,
        output_hidden_states = None,
        return_dict = None,
    ):
        r"""
        Run ``inputs_embeds`` through every layer in ``self.layers`` in order.

        Args:
            inputs_embeds: input to the first layer.
                # assumes shape (batch_size, sequence_length, hidden_size) — TODO confirm
            attention_mask (*optional*): forwarded positionally as the second argument of
                every layer.
            causal_attention_mask (*optional*): forwarded positionally as the third argument
                of every layer.
                # NOTE(review): the layers come from a ViT backbone, whose layer signature may
                # give these two positions a different meaning (e.g. head mask / attention
                # output flag) — confirm against the installed transformers version.
            output_attentions (`bool`, *optional*): when truthy, ``layer_outputs[1]`` of every
                layer is collected. It is NOT forwarded to the layers, so this only works if
                the layers return attentions on their own.
            output_hidden_states (`bool`, *optional*): when truthy, the input of every layer
                plus the final output are collected.
            return_dict (`bool`, *optional*): when falsy, return a plain tuple of the non-None
                results instead of a ``BaseModelOutput``.

        Returns:
            ``BaseModelOutput(last_hidden_state, hidden_states, attentions)`` or, when
            ``return_dict`` is falsy, a tuple with the non-None entries in that order.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # NOTE(review): _gradient_checkpointing_func is not defined by this class; this
                # branch is unreachable while gradient_checkpointing stays False.
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    # output_attentions=output_attentions,
                )

            # First element of the layer output is the new hidden state.
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
        
# Swap the vision tower's encoder for the ViT-backed wrapper, then move the whole model to `device`.
model.owlvit.vision_model.encoder = OwlViTEncoder()
model.to(device)

##### POLICY PATCH

# State dict of the vision tower, taken after the encoder swap above.
base_params = model.owlvit.vision_model.state_dict()

## CLONING
# Round-trip through disk: base_params is rebound to the tensors loaded from the file.
# Presumably this is done to obtain a copy independent of the live model — TODO confirm.
# Side effect: writes "owl_vit.pt" into the current working directory.
torch.save(base_params, "owl_vit.pt")
base_params = torch.load("owl_vit.pt")

############ DECOMPOSE
# SVD factors of every entry that is not filtered out below, stored under
# "<key>.U", "<key>.S" and "<key>.V".
decomposed_params = {}
for k, v in base_params.items():
    # Entries whose key contains any of these substrings are not decomposed.
    # NOTE(review): 'layrnorm' looks like a deliberate match for a misspelled parameter name — confirm.
    if 'layernorm' in k or 'bias' in k or 'embeddings' in k or 'layrnorm' in k or 'layer_norm' in k:
        continue  # skip this param

    # NOTE(review): torch.svd returns V (not its transpose); the PyTorch docs mark it as
    # deprecated in favour of torch.linalg.svd, which returns Vh instead.
    U, S, V = torch.svd(v)
    decomposed_params[f"{k}.U"] = U
    decomposed_params[f"{k}.S"] = S
    decomposed_params[f"{k}.V"] = V
    
#### LOSS FUNCTIONS

# loss
from torch.nn.utils.rnn import pad_sequence
from torchvision.ops.boxes import box_area
from scipy.optimize import linear_sum_assignment

# modified from torchvision to also return the union
def box_iou(boxes1, boxes2, all_pairs = True):
    """
    Compute IoU and union area between two sets of [x0, y0, x1, y1] boxes.

    With ``all_pairs=True`` every box of ``boxes1`` is compared with every box of
    ``boxes2`` and both results have shape [N, M]. With ``all_pairs=False`` the boxes
    are compared row by row and both results have shape [N].

    Returns:
        ``(iou, union)``. No guard is applied against a zero union.
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    if all_pairs:
        # Extra axis so that boxes1 broadcasts against boxes2: [N,1,2] vs [M,2].
        lhs_min, lhs_max = boxes1[:, None, :2], boxes1[:, None, 2:]
        lhs_area = area1[:, None]
    else:
        lhs_min, lhs_max = boxes1[:, :2], boxes1[:, 2:]
        lhs_area = area1

    top_left = torch.max(lhs_min, boxes2[:, :2])
    bottom_right = torch.min(lhs_max, boxes2[:, 2:])

    # Negative extents mean the boxes do not overlap.
    overlap = (bottom_right - top_left).clamp(min=0)
    if all_pairs:
        inter = overlap[:, :, 0] * overlap[:, :, 1]
    else:
        inter = overlap[:, 0] * overlap[:, 1]

    union = lhs_area + area2 - inter
    return inter / union, union

def generalized_box_iou(boxes1, boxes2, all_pairs = True):
    """
    Generalized IoU (https://giou.stanford.edu/) for boxes in [x0, y0, x1, y1] format.

    Returns a pairwise [N, M] matrix (N = len(boxes1), M = len(boxes2)) when
    ``all_pairs`` is true, otherwise a row-wise [N] vector.

    Raises:
        AssertionError: if any box has x1 < x0 or y1 < y0.
    """
    # Inverted boxes would produce inf / nan below, so reject them up front.
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2, all_pairs)

    if all_pairs:
        lhs_min, lhs_max = boxes1[:, None, :2], boxes1[:, None, 2:]
    else:
        lhs_min, lhs_max = boxes1[:, :2], boxes1[:, 2:]

    # Smallest axis-aligned box enclosing both inputs.
    hull_min = torch.min(lhs_min, boxes2[:, :2])
    hull_max = torch.max(lhs_max, boxes2[:, 2:])
    hull = (hull_max - hull_min).clamp(min=0)

    if all_pairs:
        hull_area = hull[:, :, 0] * hull[:, :, 1]
    else:
        hull_area = hull[:, 0] * hull[:, 1]

    return iou - (hull_area - union) / hull_area

def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to (x0, y0, x1, y1) along the last axis."""
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=-1)

def box_loss_fn(tgt_bbox, pred_boxes):
    """
    Build the box terms of the matching cost between every prediction and every target.

    Args:
        tgt_bbox: target boxes in (cx, cy, w, h) format, one row per target.
        pred_boxes: predicted boxes in (cx, cy, w, h) format; the first two axes are
            flattened into one before the costs are computed.

    Returns:
        ``(cost_bbox, cost_giou)``: pairwise L1 distance and negated pairwise generalized IoU.
    """
    flat_pred = pred_boxes.flatten(0, 1)  # [batch_size * num_queries, 4]

    l1_cost = torch.cdist(flat_pred, tgt_bbox, p=1)

    pred_xyxy = box_cxcywh_to_xyxy(flat_pred)
    tgt_xyxy = box_cxcywh_to_xyxy(tgt_bbox)
    # Higher GIoU is a better match, hence the sign flip to turn it into a cost.
    giou_cost = -generalized_box_iou(pred_xyxy, tgt_xyxy)

    return l1_cost, giou_cost

# Weights of the L1-box, classification and GIoU terms when they are summed into the
# matching cost matrix C in the training / evaluation loops.
_cost_bbox = 1.0
_cost_class = 1.0
_cost_giou = 1.0

def classification_loss_fn(tgt_ids, pred_logits, alpha=0.25, gamma=2.0, eps=1e-8):
    """
    Focal-style classification matching cost between every prediction and every target.

    Despite the name, the result is a cost matrix, not a scalar loss.

    Args:
        tgt_ids: integer class index of each target; used to select columns of the
            per-class cost.
        pred_logits: raw class logits; the first two axes are flattened into one and a
            sigmoid is applied.
        alpha: weight of the positive term (``1 - alpha`` weights the negative term).
        gamma: focusing exponent applied to the probabilities.
        eps: small constant added inside the logarithms to avoid ``log(0)``.

    Returns:
        Tensor of shape [num_predictions, len(tgt_ids)] holding
        ``positive_cost - negative_cost`` for the class of each target.
    """
    out_prob = pred_logits.flatten(0, 1).sigmoid()

    # Cost of treating the prediction as a negative / positive for each class.
    neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + eps).log())
    pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + eps).log())

    # Keep only the columns that belong to the target classes.
    cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

    return cost_class

######### INIT POLICY
# Project-local helpers; their behaviour is defined in utils.py, not in this file.
from utils import Policy, apply_policy_to_model, backward, compose_new_params

## LOADING THE DATASET
from dataset_voc2012 import VOC2012Dataset
# Train and test splits rooted at the same directory, both configured with T tasks.
# The active task is selected later with set_task() inside the task loop.
train_dataset = VOC2012Dataset("Dataset/VOCdevkit/VOC2012", train=True, num_tasks=T)
test_dataset = VOC2012Dataset("Dataset/VOCdevkit/VOC2012", train=False, num_tasks=T)

def collate_wrapper(batch, class_names):
    """
    Collate ``(image_array, annotations)`` samples into processor inputs and targets.

    Every image is run through the module-level ``processor`` together with the full
    list of class names as text queries; the per-sample outputs are concatenated along
    dimension 0.

    Args:
        batch: iterable of ``(data, annotations)`` pairs. ``data`` is an image array whose
            ``shape`` unpacks as (height, width, channels); each annotation is a mapping
            with a ``"bbox"`` entry ``(x1, y1, x2, y2)`` and a ``"label"`` entry.
            # assumes bbox is in absolute pixel coordinates — TODO confirm against the dataset
        class_names: iterable of class names used as text queries.

    Returns:
        ``(batch_inputs, targets)`` where ``batch_inputs`` maps every processor output key
        to the concatenated tensor, and ``targets`` holds one dict per sample with
        ``"labels"`` (int64, shape [n]) and ``"boxes"`` (float32, shape [n, 4], in
        (cx, cy, w, h) format divided by the image width / height).
    """
    text_queries = list(class_names)

    per_key_outputs = {}
    targets = []
    for data, annotations in batch:
        image_h, image_w, _ = data.shape
        inputs = processor(text=text_queries, images=Image.fromarray(data), return_tensors="pt")

        labels = []
        boxes = []
        for annotation in annotations:
            x1, y1, x2, y2 = annotation["bbox"]
            boxes.append((
                (x1 + x2) / (2 * image_w),
                (y1 + y2) / (2 * image_h),
                (x2 - x1) / image_w,
                (y2 - y1) / image_h,
            ))
            labels.append(annotation["label"])

        # Explicit dtype and shape: for a sample without annotations, torch.tensor([])
        # would otherwise give float32 labels and boxes of shape (0,), which breaks the
        # class indexing done on the concatenated labels later on.
        targets.append({
            "labels": torch.tensor(labels, dtype=torch.int64),
            "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(len(boxes), 4),
        })

        for key in inputs:
            per_key_outputs.setdefault(key, []).append(inputs[key])

    batch_inputs = {key: torch.cat(chunks, dim=0) for key, chunks in per_key_outputs.items()}

    return batch_inputs, targets

# Number of training epochs run for every task.
num_epochs = 10

def accuracy_fn(logits, labels):
    """Return the fraction of rows whose arg-max over the last axis equals the label (float32 scalar tensor)."""
    predicted = logits.argmax(dim=-1)
    hits = predicted == labels
    return hits.to(torch.float32).mean()

# Task loop: for each of the T tasks, select the task on both datasets, build a fresh
# policy, train it for num_epochs epochs and evaluate on the test split after every epoch.
for t in range(T):
    train_dataset.set_task(t)
    test_dataset.set_task(t)
    
    # Class names of the current task; collate_wrapper passes them to the processor as text queries.
    subset_classes_names = train_dataset.subset_classes_names
        
    # NOTE(review): the test loader is shuffled as well; only the per-batch printouts depend on the order.
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=lambda x : collate_wrapper(x, subset_classes_names))
    test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=lambda x : collate_wrapper(x, subset_classes_names))
        
        
    ## RESET OF THE POLICY    
    # A new Policy is created for every task from the same base / decomposed parameters.
    policy = Policy(base_params, gpu=device, decomposed_params=decomposed_params, mode=1)
    learnable_params = policy.get_learnable_params()

    # Project helper from utils; presumably writes the policy-composed weights into the vision tower — confirm in utils.
    model.owlvit.vision_model = apply_policy_to_model(model.owlvit.vision_model, policy, base_params, decomposed_params, learnable_params)

    # Freeze every parameter of the vision tower ...
    for k, p in model.owlvit.vision_model.named_parameters():
        model.owlvit.vision_model.get_parameter(k).requires_grad_(False)

    # ... then re-enable gradients only for the parameters named in learnable_params.
    # # set the learnable params to training
    for k in learnable_params:
        model.owlvit.vision_model.get_parameter(k).requires_grad_(True)
        
    # Only the policy's own trainable parameters are handed to the optimizer.
    optimizer = torch.optim.Adam(policy.trainable_params, lr=1e-3)

    ################# FIX ##############
    optimizer.zero_grad()                # clears grads for policy params
    # NOTE(review): this clears grads of everything under model.owlvit only; whether that
    # includes the detection heads depends on where they live in the model — confirm.
    model.owlvit.zero_grad(set_to_none=True)    # clears grads for backbone + head
    ###################################

    for epoch in range(num_epochs):
        # Running sums over the batches of this epoch; divided by num_rows for the running means.
        avg_iou = 0
        avg_giou = 0
        avg_acc = 0
        num_rows = 0
        for inputs, targets in train_dataloader:
            for key in inputs:
                inputs[key] = inputs[key].to(device)
             
            outputs = model(**inputs)

            bs, num_queries = outputs["logits"].shape[:2]

            # Targets of the whole batch concatenated; `sizes` below is used to split them per image again.
            tgt_ids = torch.cat([v["labels"] for v in targets]).to(device)
            tgt_bbox = torch.cat([v["boxes"] for v in targets]).to(device)
            
            pred_logits = outputs["logits"] # (batch, num_queries, num_classes)
            pred_boxes = outputs["pred_boxes"] # (batch, num_queries, 4)

            # Matching between predictions and targets; no gradients flow through the assignment.
            with torch.no_grad():
                cost_class = classification_loss_fn(tgt_ids, pred_logits)
                cost_bbox, cost_giou = box_loss_fn(tgt_bbox, pred_boxes)

                # Weighted sum of the three cost terms, reshaped to (batch, num_queries, total_targets).
                C = _cost_bbox * cost_bbox + _cost_class * cost_class + _cost_giou * cost_giou
                C = C.view(bs, num_queries, -1).cpu()
                
                
                # Split the target axis per image and solve one assignment problem per image
                # on that image's own block of the cost matrix.
                sizes = [len(v["boxes"]) for v in targets]
                indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
                indices = [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] # one (query_idx, target_idx) pair of index tensors per image
                    
            # Gather the matched predictions and the targets they were assigned to.
            src_logits = torch.cat([pred_logits[i, idx[0], :] for i, idx in enumerate(indices)]) # [total matched, num_classes]
            trg_labels = torch.cat([targets[i]["labels"][idx[1]] for i, idx in enumerate(indices)]).to(src_logits.device) # [total matched]

            src_boxes = torch.cat([pred_boxes[i, idx[0], :] for i, idx in enumerate(indices)]) # [total matched, 4]
            trg_bbox = torch.cat([targets[i]["boxes"][idx[1],:] for i, idx in enumerate(indices)]).to(src_logits.device)

            # box regression losses
            loss_bbox = torch.nn.functional.l1_loss(src_boxes, trg_bbox) # L1 loss
            
            # Row-wise (all_pairs=False) GIoU / IoU between each matched prediction and its target.
            giou = generalized_box_iou(box_cxcywh_to_xyxy(src_boxes),
                                    box_cxcywh_to_xyxy(trg_bbox), all_pairs=False).mean()
            iou, _ = box_iou(box_cxcywh_to_xyxy(src_boxes),
                            box_cxcywh_to_xyxy(trg_bbox), all_pairs=False)
            # IoU is only reported, never optimised.
            iou = iou.mean().detach()

            loss_giou = 1 - giou      
            
            # classification loss
            loss_class = torch.nn.functional.cross_entropy(src_logits, 
                                                            trg_labels)

            # Unweighted sum of the three loss terms.
            loss = loss_class + loss_bbox + loss_giou

            # Backward
            optimizer.zero_grad()
            loss.backward()
            # Project helper from utils; presumably transfers the model gradients to the
            # policy parameters before the optimizer step — confirm in utils.
            backward(policy, model.owlvit.vision_model, base_params, decomposed_params, learnable_params)
            optimizer.step()
            
            acc = accuracy_fn(src_logits, trg_labels)

            # Compose current masked weights
            model.owlvit.vision_model = apply_policy_to_model(model.owlvit.vision_model, policy, base_params, decomposed_params, learnable_params)

            # for j in range(src_logits.shape[0]):
            #     pred_label = torch.argmax(src_logits[j, :].detach().cpu())
            #     pred_conf = torch.max(src_logits[j, :].detach().cpu()).sigmoid()
            #     pred_box = src_boxes[j]
                
            #     print("target_boxes", trg_bbox.detach(), "pred boxes", pred_box.detach(), "pred_label", int(pred_label), "pred_conf", pred_conf)
            # print(f"Acc: {acc.item():.4f} | Total Loss: {loss.item():.4f} | Class Loss: {loss_class.item():.4f} | Box Loss: {loss_bbox.item():.4f} | Giou Loss: {loss_giou.item():.4f}")

            avg_iou += iou.item()
            avg_giou += giou.item()
            avg_acc += acc.item()
            num_rows += 1
            
            # Per-batch metrics followed by the running means over the epoch so far (C_*).
            print(f"(TRAINING) Epoch: {epoch}/{num_epochs} | Acc: {acc.item():.4f} | IOU: {iou.item():.4f} | GIOU: {giou.item():.4f} " \
                + f"| C_Acc: {avg_acc/num_rows:.4f} | C_IOU: {avg_iou/num_rows:.4f} | C_GIOU: {avg_giou/num_rows:.4f}")
        

        # Checkpoint the policy object every 5th epoch (0, 5, ...), with the training
        # means embedded in the file name.
        # NOTE(review): divides by num_rows, which is 0 if the loader yielded no batch;
        # the "weights/" directory is not created by this script.
        if epoch % 5 == 0:
            torch.save(policy, f"weights/{task_name}_task_num_{t}_{epoch}_{avg_acc/num_rows}_{avg_giou/num_rows}_{avg_iou/num_rows}_training")
            
        # Evaluation on the test split: same matching and metrics as above, without
        # gradients and without any loss or optimizer step.
        with torch.no_grad():
            avg_iou = 0
            avg_giou = 0
            avg_acc = 0
            num_rows = 0            
            for inputs, targets in test_dataloader:
                for key in inputs:
                    inputs[key] = inputs[key].to(device)
                
                outputs = model(**inputs)

                bs, num_queries = outputs["logits"].shape[:2]

                tgt_ids = torch.cat([v["labels"] for v in targets]).to(device)
                tgt_bbox = torch.cat([v["boxes"] for v in targets]).to(device)
                
                pred_logits = outputs["logits"] # (batch, num_queries, num_classes)
                pred_boxes = outputs["pred_boxes"] # (batch, num_queries, 4)

                cost_class = classification_loss_fn(tgt_ids, pred_logits)
                cost_bbox, cost_giou = box_loss_fn(tgt_bbox, pred_boxes)

                C = _cost_bbox * cost_bbox + _cost_class * cost_class + _cost_giou * cost_giou
                C = C.view(bs, num_queries, -1).cpu()
                
                
                sizes = [len(v["boxes"]) for v in targets]
                indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
                indices = [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] # one (query_idx, target_idx) pair of index tensors per image
                        
                src_logits = torch.cat([pred_logits[i, idx[0], :] for i, idx in enumerate(indices)]) # [total matched, num_classes]
                trg_labels = torch.cat([targets[i]["labels"][idx[1]] for i, idx in enumerate(indices)]).to(src_logits.device) # [total matched]

                src_boxes = torch.cat([pred_boxes[i, idx[0], :] for i, idx in enumerate(indices)]) # [total matched, 4]
                trg_bbox = torch.cat([targets[i]["boxes"][idx[1],:] for i, idx in enumerate(indices)]).to(src_logits.device)
               
                giou = generalized_box_iou(box_cxcywh_to_xyxy(src_boxes),
                        box_cxcywh_to_xyxy(trg_bbox), all_pairs=False).mean()
                iou, _ = box_iou(box_cxcywh_to_xyxy(src_boxes),
                                box_cxcywh_to_xyxy(trg_bbox), all_pairs=False)
                iou = iou.mean().detach()
              
                acc = accuracy_fn(src_logits, trg_labels)
                
                avg_iou += iou.item()
                avg_giou += giou.item()
                avg_acc += acc.item()
                num_rows += 1
                
                print(f"(TESTING) Epoch: {epoch}/{num_epochs} | Acc: {acc.item():.4f} | IOU: {iou.item():.4f} | GIOU: {giou.item():.4f} " \
                    + f"| C_Acc: {avg_acc/num_rows:.4f} | C_IOU: {avg_iou/num_rows:.4f} | C_GIOU: {avg_giou/num_rows:.4f}")
                                   
        # Same policy saved a second time, now with the test means in the file name.
        # NOTE(review): same empty-loader / missing-directory caveats as the training checkpoint.
        if epoch % 5 == 0:
            torch.save(policy, f"weights/{task_name}_task_num_{t}_{epoch}_{avg_acc/num_rows}_{avg_giou/num_rows}_{avg_iou/num_rows}_testing")
            