'''
@File       :   ImageReward.py
@Time       :   2023/02/28 19:53:00
@Author     :   Jiazheng Xu
@Contact    :   xjz22@mails.tsinghua.edu.cn
@Description:   ImageReward reward model that scores inpainting results from CLIP image embeddings.
'''

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
from PIL import Image
from config.options import *
from config.utils import *
from models.clip_pretrain import clip_pretrain
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from transformers import CLIPTextModel, CLIPTokenizer, logging, CLIPVisionModel, CLIPFeatureExtractor
from torch.nn import TransformerEncoder, TransformerEncoderLayer

try:
    from torchvision.transforms import InterpolationMode
except ImportError:
    # torchvision without InterpolationMode: fall back to PIL's bicubic constant.
    BICUBIC = Image.BICUBIC
else:
    BICUBIC = InterpolationMode.BICUBIC


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    """Build the image preprocessing pipeline for target size ``n_px``.

    Steps: bicubic resize, RGB conversion, tensor conversion, then per-channel
    normalisation with the CLIP mean/std constants. Centre cropping is not
    applied (it is commented out), so for an int ``n_px`` the output is not
    necessarily square -- torchvision's ``Resize`` with an int matches the
    smaller edge.
    """
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        Resize(n_px, interpolation=BICUBIC),
        # CenterCrop(n_px) deliberately omitted.
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(clip_mean, clip_std),
    ]
    return Compose(steps)


class MLP(nn.Module):
    """Scoring head: Linear -> Dropout -> Linear, then a bias-free scalar readout.

    ``forward`` returns the score together with the readout weight so callers
    can inspect or regularise it.

    Args:
        input_size: size of the last dimension of the incoming features.
    """

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size

        # No activation between the two Linear layers, so apart from dropout
        # this stack is affine.
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 512),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
        )
        self.last_layer = nn.Linear(256, 1, bias=False)
        # Alias of last_layer.weight. Assigning a Parameter registers it under
        # this name too, so it appears as its own key in state_dict().
        self.last_layer_weight = self.last_layer.weight

        # Every weight ~ N(0, 1/(input_size+1)), every bias = 0; `layers` first,
        # then `last_layer`.
        init_std = 1.0 / (self.input_size + 1)
        for module in (self.layers, self.last_layer):
            for pname, tensor in module.named_parameters():
                if 'weight' in pname:
                    nn.init.normal_(tensor, mean=0.0, std=init_std)
                if 'bias' in pname:
                    nn.init.constant_(tensor, val=0)

    def forward(self, input):
        """Return ``(score, last_layer.weight)`` where score has trailing size 1."""
        hidden = self.layers(input)
        return self.last_layer(hidden), self.last_layer.weight

class ViTBlock(nn.Module):
    """One batch-first transformer encoder layer wrapped in ``TransformerEncoder``.

    Args:
        feature_dim: token width (``d_model``).
        num_heads: number of attention heads.
        mlp_dim: hidden width of the feed-forward sublayer.
        dropout: dropout probability inside the layer.
    """

    def __init__(self, feature_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        layer = TransformerEncoderLayer(
            d_model=feature_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            batch_first=True,
        )
        # The template layer stays registered (and thus in state_dict) even
        # though forward only calls transformer_encoder.
        self.encoder_layer = layer
        self.transformer_encoder = TransformerEncoder(layer, num_layers=1)

    def forward(self, x):
        """Encode ``x`` laid out as (batch_size, seq_length, feature_dim)."""
        return self.transformer_encoder(x)

class ImageReward(nn.Module):
    """Inpainting reward model built on a CLIP ViT-B/32 image encoder.

    Every batch item provides an ``'inpaint'`` and a ``'mask_rgb'`` image
    tensor. Both are encoded with CLIP, concatenated feature-wise, refined by
    a single ``ViTBlock`` and reduced to a score by ``MLP``.

    Args:
        device: device the CLIP backbone is loaded on and that
            :meth:`encode_pair` moves the input images to.
    """

    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device

        # Load CLIP where encode_pair sends the inputs. This was previously
        # hard-coded to "cuda", which failed on CPU-only hosts and disagreed
        # with `device`.
        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=self.device)
        # Force fp32 weights (clip.load may return half precision on CUDA).
        self.clip_model = self.clip_model.float()
        self.mlp = MLP(config['ImageReward']['mlp_dim'])
        self.vit_block = ViTBlock(config["ViT"]["feature_dim"], config["ViT"]["num_heads"], config["ViT"]["mlp_dim"])

        if opts.fix_base:
            self.clip_model.requires_grad_(False)

        # Freeze every CLIP parameter whose name contains '_proj'.
        # NOTE(review): in CLIP this presumably also matches the attention
        # in_proj/out_proj weights of every transformer block, not only
        # text_projection -- confirm that is intended.
        for name, parms in self.clip_model.named_parameters():
            if '_proj' in name:
                parms.requires_grad_(False)

        # Freeze the first int(image_layer_num * fix_rate) visual transformer
        # blocks together with every visual parameter registered before them.
        self.image_layer_num = 12
        if opts.fix_rate > 0:
            first_trainable = int(self.image_layer_num * opts.fix_rate)
            for name, parms in self.clip_model.visual.named_parameters():
                # Test *before* freezing: the old substring check froze the
                # first parameter of block `first_trainable` and then broke.
                if self._resblock_index(name) == first_trainable:
                    break
                parms.requires_grad_(False)

    @staticmethod
    def _resblock_index(name):
        """Return N for a parameter name containing ``resblocks.N.``; otherwise None."""
        parts = name.split('.')
        for prev, cur in zip(parts, parts[1:]):
            if prev == 'resblocks' and cur.isdigit():
                return int(cur)
        return None

    def loose_layer(self, fix_rate):
        """Make visual blocks ``int(image_layer_num * fix_rate)`` and later trainable.

        The previous implementation addressed ``self.blip`` (a BLIP backbone
        this CLIP-based model never creates) and therefore always raised
        AttributeError. Only the image tower is affected, because
        ``encode_pair`` never runs the CLIP text encoder. Every parameter of
        the selected blocks is re-enabled, including any '_proj' parameters
        that ``__init__`` froze.

        Args:
            fix_rate: fraction of leading visual blocks to leave untouched.
        """
        first_trainable = int(self.image_layer_num * fix_rate)
        for name, parms in self.clip_model.visual.named_parameters():
            block = self._resblock_index(name)
            if block is not None and block >= first_trainable:
                parms.requires_grad_(True)

    def forward(self, batch_data):
        """Score every item of ``batch_data``.

        Returns:
            ``(score, w)`` exactly as produced by ``MLP.forward``: the score
            tensor and the MLP readout weight.
        """
        emb_inpaint, emb_mask_rgb = self.encode_pair(batch_data)
        # Feature-wise concat of the two embeddings (presumably 2 x 512 = 1024
        # for ViT-B/32; must match config["ViT"]["feature_dim"]).
        emb_feature = torch.cat((emb_inpaint, emb_mask_rgb), dim=-1)
        emb_feature = self.vit_block(emb_feature)
        score, w = self.mlp(emb_feature)
        return score, w

    def _encode_image(self, images):
        """Move ``images`` to ``self.device`` and return their CLIP embedding as float32."""
        # NOTE(review): an earlier "# Nan" remark hinted at NaNs from this path.
        return self.clip_model.encode_image(images.to(self.device)).to(torch.float32)

    def encode_pair(self, batch_data):
        """Encode the ``'inpaint'`` and ``'mask_rgb'`` images of every batch item.

        Args:
            batch_data: non-empty sequence of mappings with ``'inpaint'`` and
                ``'mask_rgb'`` tensors, each passed directly to
                ``clip_model.encode_image`` (presumably (n, 3, H, W) -- TODO
                confirm). All items must produce equally shaped embeddings.

        Returns:
            Two float32 tensors, each stacking the per-item embeddings on dim 0.

        Raises:
            ValueError: if ``batch_data`` is empty.
        """
        if len(batch_data) == 0:
            raise ValueError("encode_pair() requires a non-empty batch_data")
        inpaint_embeds_bs, mask_rgb_embeds_bs = [], []
        for item in batch_data:
            inpaint_embeds_bs.append(self._encode_image(item['inpaint']))
            mask_rgb_embeds_bs.append(self._encode_image(item['mask_rgb']))
        return torch.stack(inpaint_embeds_bs, dim=0), torch.stack(mask_rgb_embeds_bs, dim=0)

class ImageRewardGroup(nn.Module):
    """Pairwise variant of ImageReward that scores a better and a worse inpainting.

    Every batch item provides ``'better_inpt'``/``'better_msk'`` and
    ``'worse_inpt'``/``'worse_msk'`` image tensors. Each pair is encoded with
    CLIP, concatenated feature-wise, refined by a shared ``ViTBlock`` and
    scored by a shared ``MLP``.

    Args:
        device: device the CLIP backbone is loaded on and that
            :meth:`encode_pair` moves the input images to.
    """

    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device

        # Load CLIP where encode_pair sends the inputs. This was previously
        # hard-coded to "cuda", which failed on CPU-only hosts and disagreed
        # with `device`.
        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=self.device)
        # Force fp32 weights (clip.load may return half precision on CUDA).
        self.clip_model = self.clip_model.float()
        self.mlp = MLP(config['ImageReward']['mlp_dim'])
        self.vit_block = ViTBlock(config["ViT"]["feature_dim"], config["ViT"]["num_heads"], config["ViT"]["mlp_dim"])

        if opts.fix_base:
            self.clip_model.requires_grad_(False)

        # Freeze every CLIP parameter whose name contains '_proj'.
        # NOTE(review): in CLIP this presumably also matches the attention
        # in_proj/out_proj weights of every transformer block, not only
        # text_projection -- confirm that is intended.
        for name, parms in self.clip_model.named_parameters():
            if '_proj' in name:
                parms.requires_grad_(False)

        # Freeze the first int(image_layer_num * fix_rate) visual transformer
        # blocks together with every visual parameter registered before them.
        self.image_layer_num = 12
        if opts.fix_rate > 0:
            first_trainable = int(self.image_layer_num * opts.fix_rate)
            for name, parms in self.clip_model.visual.named_parameters():
                # Test *before* freezing: the old substring check froze the
                # first parameter of block `first_trainable` and then broke.
                if self._resblock_index(name) == first_trainable:
                    break
                parms.requires_grad_(False)

    @staticmethod
    def _resblock_index(name):
        """Return N for a parameter name containing ``resblocks.N.``; otherwise None."""
        parts = name.split('.')
        for prev, cur in zip(parts, parts[1:]):
            if prev == 'resblocks' and cur.isdigit():
                return int(cur)
        return None

    def loose_layer(self, fix_rate):
        """Make visual blocks ``int(image_layer_num * fix_rate)`` and later trainable.

        The previous implementation addressed ``self.blip`` (a BLIP backbone
        this CLIP-based model never creates) and therefore always raised
        AttributeError. Only the image tower is affected, because
        ``encode_pair`` never runs the CLIP text encoder. Every parameter of
        the selected blocks is re-enabled, including any '_proj' parameters
        that ``__init__`` froze.

        Args:
            fix_rate: fraction of leading visual blocks to leave untouched.
        """
        first_trainable = int(self.image_layer_num * fix_rate)
        for name, parms in self.clip_model.visual.named_parameters():
            block = self._resblock_index(name)
            if block is not None and block >= first_trainable:
                parms.requires_grad_(True)

    def forward(self, batch_data):
        """Score the better and worse candidates of every batch item.

        Returns:
            ``(reward, weight)``. ``reward`` joins the squeezed better-scores
            and worse-scores along dim 1, with better first.
            ``weight`` is ``self.mlp.last_layer_weight``.
        """
        b_emb_inpt, b_emb_msk, w_emb_inpt, w_emb_msk = self.encode_pair(batch_data)
        # Feature-wise concat of each image pair, then the shared ViT block.
        # Both ViT passes run before both MLP passes; this order matches the
        # original, so the dropout RNG sequence is unchanged.
        b_emb_feature = self.vit_block(torch.cat((b_emb_inpt, b_emb_msk), dim=-1))
        w_emb_feature = self.vit_block(torch.cat((w_emb_inpt, w_emb_msk), dim=-1))

        reward_better, _ = self.mlp(b_emb_feature)
        reward_worse, _ = self.mlp(w_emb_feature)
        reward = torch.concat((reward_better.squeeze(-1), reward_worse.squeeze(-1)), dim=1)
        return reward, self.mlp.last_layer_weight

    def _encode_image(self, images):
        """Move ``images`` to ``self.device`` and return their CLIP embedding as float32."""
        # NOTE(review): an earlier "# Nan" remark hinted at NaNs from this path.
        return self.clip_model.encode_image(images.to(self.device)).to(torch.float32)

    def encode_pair(self, batch_data):
        """Encode the four images of every batch item with CLIP.

        Stacking now happens once, after the loop. Previously all four lists
        were re-stacked on every iteration (quadratic work), and an empty
        batch raised UnboundLocalError.

        Args:
            batch_data: non-empty sequence of mappings with ``'better_inpt'``,
                ``'better_msk'``, ``'worse_inpt'`` and ``'worse_msk'`` tensors,
                each passed directly to ``clip_model.encode_image``. All items
                must produce equally shaped embeddings.

        Returns:
            ``(b_inpt, b_msk, w_inpt, w_msk)``: float32 tensors, each stacking
            the per-item embeddings on dim 0.

        Raises:
            ValueError: if ``batch_data`` is empty.
        """
        if len(batch_data) == 0:
            raise ValueError("encode_pair() requires a non-empty batch_data")
        better_inpaint_embeds_bs, better_mask_rgb_embeds_bs = [], []
        worse_inpaint_embeds_bs, worse_mask_rgb_embeds_bs = [], []
        for item in batch_data:
            better_inpaint_embeds_bs.append(self._encode_image(item['better_inpt']))
            better_mask_rgb_embeds_bs.append(self._encode_image(item['better_msk']))
            worse_inpaint_embeds_bs.append(self._encode_image(item['worse_inpt']))
            worse_mask_rgb_embeds_bs.append(self._encode_image(item['worse_msk']))

        b_inpt = torch.stack(better_inpaint_embeds_bs, dim=0)
        b_msk = torch.stack(better_mask_rgb_embeds_bs, dim=0)
        w_inpt = torch.stack(worse_inpaint_embeds_bs, dim=0)
        w_msk = torch.stack(worse_mask_rgb_embeds_bs, dim=0)
        return b_inpt, b_msk, w_inpt, w_msk

