import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import CLIPTokenizer
import random 
import math 
class Network(torch.nn.Module):
    """Two-layer MLP: Linear(input_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> output_dim).

    The layers are created on the default device. Move the module with
    ``.to(device)`` after construction, as the ``__main__`` script does.
    """

    def __init__(self, output_dim, input_dim=128, hidden_dim=256):
        """Build the two linear layers and the activation.

        Args:
            output_dim: Number of output features.
            input_dim: Number of input features per sample (default 128).
            hidden_dim: Width of the hidden layer (default 256).
        """
        super().__init__()
        # No `.to(device)` here: it relied on a global that only exists when
        # this file is run as a script, so importing the class failed.
        self.net1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.net2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)
        self.relu = nn.ReLU()

    def initialize(self):
        """Re-initialize every weight and bias from a symmetric uniform distribution.

        net1 parameters ~ U(-1/sqrt(768), 1/sqrt(768));
        net2 parameters ~ U(-1/sqrt(384), 1/sqrt(384)).
        """
        # NOTE(review): 768/384 are fixed constants, not derived from the layer
        # sizes (128/256 by default) -- confirm they are intended.
        bound1 = 1 / math.sqrt(768)
        bound2 = 1 / math.sqrt(384)
        nn.init.uniform_(self.net1.weight, a=-bound1, b=bound1)
        nn.init.uniform_(self.net1.bias, a=-bound1, b=bound1)
        nn.init.uniform_(self.net2.weight, a=-bound2, b=bound2)
        nn.init.uniform_(self.net2.bias, a=-bound2, b=bound2)
        print("Initialized")

    def initialize2(self):
        """Re-initialize every weight and bias from U(0, 1/sqrt(768)).

        All parameters end up non-negative.
        """
        # NOTE(review): unlike `initialize`, net2 also uses the 1/sqrt(768)
        # upper bound (the original lower bound 0/sqrt(384) is simply 0) --
        # confirm this asymmetry is intended.
        upper = 1 / math.sqrt(768)
        nn.init.uniform_(self.net1.weight, a=0.0, b=upper)
        nn.init.uniform_(self.net1.bias, a=0.0, b=upper)
        nn.init.uniform_(self.net2.weight, a=0.0, b=upper)
        nn.init.uniform_(self.net2.bias, a=0.0, b=upper)
        print("Initialized")

    def forward(self, x):
        """Apply net1, ReLU, then net2 to ``x`` (last dimension must equal input_dim)."""
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x

if __name__ == '__main__':
    # Hard-coded GPU; the network is moved here explicitly below.
    device = torch.device('cuda:2')
    tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")

    ################ CAPTION DATA ################
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    data_checkpoint = torch.load('caption_used_data.pyt', map_location='cpu')
    caption_data = data_checkpoint['data'].to(device)
    USEDINDEX_TO_INDEX = data_checkpoint['USEDINDEX_TO_INDEX']
    CAPTION_VOCAB_USED_LEN = data_checkpoint['VOCAB_USED_LEN']
    print("VOCAB LEN:{} DATA:{}".format(CAPTION_VOCAB_USED_LEN, caption_data.shape))
    # Float copy of the regression targets, built once instead of every step.
    caption_targets = caption_data.float()

    net = Network(output_dim=CAPTION_VOCAB_USED_LEN).to(device)
    pytorch_total_params = sum(p.numel() for p in net.parameters())
    print("model parameter:", pytorch_total_params)
    seed = 0
    num_embeddings = caption_data.shape[0]
    embeddings_dim = 128
    # One fixed random input vector per caption row, uniform in [-1, 1).
    unique_embeddings = torch.rand(num_embeddings, embeddings_dim, generator=torch.Generator().manual_seed(seed)) * 2 - 1
    unique_embeddings = unique_embeddings.to(device)
    print("unique_embeddings:{}".format(unique_embeddings.shape))

    criterion = nn.MSELoss()  # default reduction already averages to a scalar
    optimizer = torch.optim.Adam(net.parameters(), lr=5e-6, weight_decay=0)
    total_epoch = 2000000
    # Number of elements compared in the accuracy computation below.
    total_length = caption_data.numel()
    topk = 10
    bs = 32
    num_show = 12  # decoded rows printed per evaluation

    def ids_to_string(used_ids):
        """Map used-vocabulary indices to tokenizer tokens and join them with spaces."""
        token_ids = [USEDINDEX_TO_INDEX[p] for p in used_ids]
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        return ' '.join(tok.strip().replace('</w>', '') for tok in tokens)

    for epoch in tqdm(range(total_epoch)):
        optimizer.zero_grad()
        # Every row (including row 0) is eligible for training, matching the
        # accuracy below, which is measured over all rows.
        indices = random.sample(range(num_embeddings), bs)

        pred = net(unique_embeddings[indices])
        loss = criterion(pred, caption_targets[indices])
        loss.backward()
        optimizer.step()

        if epoch % 100 == 0:
            with torch.no_grad():
                pred = net(unique_embeddings)
                total_correct = int((caption_data == pred.round()).sum())
                accuracy = total_correct / total_length
                for t in range(min(len(pred), num_show)):
                    scores = pred[t].clamp(min=0, max=1)
                    num_token = scores.sum()
                    k = min(topk, scores.numel())
                    # No squeeze(): a 1-D tensor always yields a list, even for
                    # a single element, so the comprehension never sees an int.
                    decoded_ids = torch.topk(scores, k).indices.tolist()
                    prompt_ids = torch.where(caption_data[t] == 1)[0].tolist()

                    decoded_string = ids_to_string(decoded_ids)
                    prompt_string = ids_to_string(prompt_ids)
                    print("sum:{} decoded_string:{} \n prompt_string:{}".format(num_token, decoded_string, prompt_string))

                print("Accuracy:", accuracy)
        if epoch % 10000 == 0:
            # Biases are saved as well: the weights alone cannot rebuild the
            # trained network.
            save_checkpoint = {
                'w1': net.net1.weight,
                'w2': net.net2.weight,
                'b1': net.net1.bias,
                'b2': net.net2.bias,
                'unique_emb': unique_embeddings,
            }
            torch.save(save_checkpoint, 'pretrained_linear.pyt')




    




