# Real off-the-shelf tied linear module
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import math 
import random
from tqdm import tqdm 
from diffusers import StableDiffusionPipeline
from transformers import CLIPTokenizer

class TiedUNET(nn.Module):
    """Two-layer MLP whose weights are stored inside another network's parameters.

    For every parameter of ``tied_net`` a list of flat (``view(-1)``) indices
    selects the scalar entries that are "tied".  Concatenated in
    ``named_parameters()`` order, those entries form one flat weight vector.
    Its first ``2048*1024 + 1024*num_used_token`` values are read as

    * ``W1`` with shape ``(2048, 1024)`` and
    * ``W2`` with shape ``(1024, num_used_token)``,

    and ``forward`` computes ``relu(x @ W1) @ W2``.

    Attributes:
        output_dim: Width of the second layer (always ``num_used_token``).
        indicies_tied: One sequence of integer flat indices per parameter of
            ``tied_net``, in ``named_parameters()`` order.
        tied_net: The host network; assigned as an attribute, so it is
            registered as a submodule of this module.
        tied_params_count: Total number of tied scalar entries.
        linear_weight_list: Element counts of ``W1`` and ``W2``.
        used_params: ``sum(linear_weight_list)``.
    """

    def __init__(self, num_used_token, tied_net, input_dim=32, output_dim=1000, indicies_tied=None):
        """Select (or accept) the tied indices and record the layer layout.

        Args:
            num_used_token: Output width of the MLP.
            tied_net: ``nn.Module`` whose parameters host the tied weights.
            input_dim: Unused; kept for interface compatibility.
            output_dim: Unused; the output width is ``num_used_token``.
            indicies_tied: Optional pre-computed index lists (one per
                parameter of ``tied_net``).  When ``None``, indices are
                sampled with ``random.sample`` and the selected entries of
                ``tied_net`` are overwritten in place with uniform noise.
        """
        super().__init__()
        self.output_dim = num_used_token
        if indicies_tied is None:
            indicies_tied, tied_params_count = self._sample_and_randomize(tied_net)
            print("tied_params_count:{}".format(tied_params_count))
        else:
            # Previously the count was left at 0 on this path, although
            # callers use it to size the flat weight vector.
            tied_params_count = sum(len(indices) for indices in indicies_tied)
            print("indices successfully loaded....")
        self.linear_weight_list = [2048*1024, 1024*self.output_dim]
        self.used_params = 2048*1024 + 1024*self.output_dim
        self.indicies_tied = indicies_tied
        self.tied_net = tied_net
        self.tied_params_count = tied_params_count

    @staticmethod
    def _sample_and_randomize(tied_net, ratio=0.0065):
        """Sample tied indices for every parameter and re-initialise those entries.

        Returns:
            ``(indicies_tied, tied_params_count)``: the per-parameter index
            lists and the total number of sampled indices.
        """
        # Bounds of the uniform re-initialisation: +/- 1/sqrt(768).
        r2 = 1/math.sqrt(768)
        r1 = -1/math.sqrt(768)
        indicies_tied = []
        tied_params_count = 0
        for name, params in tied_net.named_parameters():
            # Detached flat view: shares storage with the parameter, so the
            # indexed assignment below updates the parameter itself and no
            # extra copy through state_dict() is required.
            flat = params.detach().view(-1)
            shape_length = len(flat)
            print("name:{} params_num:{}".format(name, shape_length))
            # The range starts at 1, so flat index 0 is never selected; this
            # is kept unchanged so seeded runs draw the same indices.
            indices = random.sample(range(1, shape_length), int(shape_length*ratio))
            indicies_tied.append(indices)
            tied_params_count += len(indices)
            flat[indices] = (r1 - r2) * torch.rand_like(flat[indices]) + r2
        return indicies_tied, tied_params_count

    def get_weight(self, fit=True):
        """Gather the tied entries of ``tied_net`` into one flat tensor.

        The gather goes through the (non-detached) parameters, so the result
        stays connected to them in the autograd graph.

        Args:
            fit: When true, return only the first ``used_params`` values
                (the part ``forward`` consumes); otherwise return all tied
                values.
        """
        pieces = [
            params.view(-1)[self.indicies_tied[i]]
            for i, (_, params) in enumerate(self.tied_net.named_parameters())
        ]
        weight_tied = torch.cat(pieces, dim=0)
        return weight_tied[:self.used_params] if fit else weight_tied

    def set_weight(self, weight):
        """Write a flat weight vector into the tied entries of ``tied_net``.

        ``weight`` is consumed in ``named_parameters()`` order, one slice per
        parameter, and converted to each parameter's device and dtype.  It
        must hold at least ``tied_params_count`` values; extra values are
        ignored.
        """
        count = 0
        for i, (_, params) in enumerate(self.tied_net.named_parameters()):
            indices = self.indicies_tied[i]
            shape_length = len(indices)
            # In-place write through a detached view of the parameter.
            flat = params.detach().view(-1)
            flat[indices] = weight[count:count+shape_length].to(device=params.device, dtype=params.dtype)
            count += shape_length
        print("Weight is set successfully")

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Apply ``relu(input_tensor @ W1) @ W2`` using the tied weights.

        The tied weights are cast to ``input_tensor.dtype``.  The matmul with
        ``W1`` requires the last dimension of ``input_tensor`` to be 2048.
        """
        weight_tied = self.get_weight().to(dtype=input_tensor.dtype)

        first_len, second_len = self.linear_weight_list
        linear1 = weight_tied[:first_len].reshape(2048, 1024)
        linear2 = weight_tied[first_len:first_len+second_len].reshape(1024, self.output_dim)

        hidden = F.relu(input_tensor @ linear1)
        return hidden @ linear2

if __name__ == '__main__':
    # Demo / evaluation script:
    #   1. load flattened linear weights and embeddings from a checkpoint,
    #   2. write them into randomly chosen entries of the Stable Diffusion
    #      UNet via TiedUNET,
    #   3. generate one image with the modified UNet,
    #   4. run the tied MLP on the embeddings and compare against the
    #      caption data.

    model_id = "CompVis/stable-diffusion-v1-4"
    device = torch.device("cuda:7")  # hard-coded GPU index

    # NOTE(review): torch.load unpickles its input; only load trusted files.
    checkpoint = torch.load('pretrained_linear.pyt', map_location='cpu')
    # Transpose, then flatten, so the vector matches the row-major
    # reshape(2048, 1024) / reshape(1024, out) done in TiedUNET.forward.
    # (presumably w1/w2 are stored in nn.Linear (out, in) layout — TODO confirm)
    w1 = checkpoint['w1'].t().reshape(-1)
    w2 = checkpoint['w2'].t().reshape(-1)
    print("w1:{} w2:{}".format(w1.shape,w2.shape))

    unique_embeddings = checkpoint['unique_emb'].to(device)
    load_retrieved_weight = torch.cat([w1.view(-1),w2.view(-1)],dim=0)
    print("w1:{} w2:{} weight:{}".format(w1.shape,w2.shape,load_retrieved_weight.shape))

    ################ CAPTION DATA ################
    checkpoint = torch.load('caption_used_data.pyt', map_location='cpu')
    caption_data = checkpoint['data'].to(device)
    USEDINDEX_TO_INDEX = checkpoint['USEDINDEX_TO_INDEX']
    CAPTION_VOCAB_USED_LEN = checkpoint['VOCAB_USED_LEN']
    print("VOCAB LEN:{} DATA:{}".format(CAPTION_VOCAB_USED_LEN,caption_data.shape))

    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    pipe = pipe.to(device)
    # NOTE(review): 2987 is hard-coded; it presumably equals
    # CAPTION_VOCAB_USED_LEN — confirm against the checkpoint.
    net_caption = TiedUNET(num_used_token=2987,tied_net=pipe.unet,indicies_tied=None).to(device)

    # set_weight consumes one value per tied entry, so pad the loaded weights
    # with zeros up to the total number of tied entries.
    zero_pad = torch.zeros(size=(1,net_caption.tied_params_count-len(load_retrieved_weight))).view(-1)
    load_retrieved_weight = torch.cat([load_retrieved_weight,zero_pad],dim=0)
    print("padded weight:{}".format(load_retrieved_weight.shape))

    net_caption.set_weight(weight=load_retrieved_weight)

    # Generate an image with the UNet that now carries the tied weights.
    prompt = "a photo of cat"
    image = pipe(prompt).images[0]
    image.save("astronaut_rides_horse.png")  # file name does not reflect the prompt

    ###### EVAL CAPTION #####

    total_length = caption_data.shape[0] * CAPTION_VOCAB_USED_LEN
    topk = 15
    max_shown = 12  # number of samples printed below
    tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")

    def used_ids_to_text(used_ids):
        # Map used-vocabulary indices to tokenizer ids, then to a
        # space-joined string of tokens without the '</w>' marker.
        token_ids = [USEDINDEX_TO_INDEX[p] for p in used_ids]
        tokens = [tokenizer._convert_id_to_token(i).strip().replace('</w>','') for i in token_ids]
        return ' '.join(tokens)

    with torch.no_grad():
        pred = net_caption(unique_embeddings)
        # Element-wise accuracy of the rounded predictions.
        total_correct = int((caption_data==pred.round()).sum())
        accuracy = total_correct/total_length
        for t in range(min(len(pred), max_shown)):
            clipped = torch.clip(pred[t], min=0, max=1)
            num_token = clipped.sum()
            # No .squeeze() before .tolist(): on a one-element tensor it
            # yields a bare int, which cannot be iterated afterwards.
            decoded_ids = torch.topk(clipped,topk).indices.tolist()
            prompt_ids = torch.where(caption_data[t] == 1)[0].tolist()

            decoded_string = used_ids_to_text(decoded_ids)
            prompt_string = used_ids_to_text(prompt_ids)
            print("sum:{} decoded_string:{} \n prompt_string:{}".format(num_token,decoded_string,prompt_string))

        print("Accuracy:",accuracy)


