# Real off-the-shelf tied linear module
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import math 
import random
from tqdm import tqdm 
from diffusers import StableDiffusionPipeline
from transformers import CLIPTokenizer

class TiedUNET(nn.Module):
    """Two-layer MLP whose weights are scattered inside another network.

    A subset of the scalar entries of every parameter of ``tied_net`` is
    "tied". Concatenated in ``named_parameters()`` order (each parameter
    contributing its entries in ``indicies_tied`` order), the first
    ``128 * 256`` tied entries form the first weight matrix ``W1`` (128, 256)
    and the next ``256 * num_used_token`` form ``W2`` (256, num_used_token).
    ``forward`` computes ``relu(x @ W1) @ W2``.

    Args:
        num_used_token: Output width of the MLP (columns of ``W2``).
        tied_net: Module whose parameters host the MLP weights. When
            ``indicies_tied`` is None, the sampled entries are re-initialised
            in place.
        input_dim: Unused; kept for interface compatibility (``forward``
            always expects a last dimension of 128).
        output_dim: Unused; the output width comes from ``num_used_token``.
        indicies_tied: Optional list with one sequence of flat indices per
            entry of ``tied_net.named_parameters()`` (same order), e.g.
            restored from an earlier run. When given, ``tied_net`` is left
            untouched.
        ratio: Fraction of each parameter's entries to tie when sampling new
            indices (only used when ``indicies_tied`` is None).
    """

    def __init__(self, num_used_token, tied_net, input_dim=32, output_dim=1000,
                 indicies_tied=None, ratio=0.0003):
        super().__init__()
        self.output_dim = num_used_token
        if indicies_tied is None:
            # Tied entries are re-drawn from Uniform(-1/sqrt(768), 1/sqrt(768)).
            # NOTE(review): 768 looks like a hard-coded fan-in — confirm.
            high = 1 / math.sqrt(768)
            low = -high
            tied_params_count = 0
            indicies_tied = []
            with torch.no_grad():
                for name, params in tied_net.named_parameters():
                    # detach().view(-1) aliases the parameter's storage, so the
                    # assignment below writes straight into tied_net; no
                    # state_dict() round-trip is needed.
                    flattened_params = params.detach().view(-1)
                    shape_length = flattened_params.numel()
                    print("name:{} params_num:{}".format(name, shape_length))
                    # Sample from the whole range: flat index 0 is a valid slot too.
                    indices = random.sample(range(shape_length), int(shape_length * ratio))
                    indicies_tied.append(indices)
                    tied_params_count += len(indices)
                    uniform = torch.rand_like(flattened_params[indices])
                    flattened_params[indices] = (low - high) * uniform + high
            print("tied_params_count:{}".format(tied_params_count))
        else:
            print("indices successfully loaded....")
            # Report the real pool size for loaded indices as well.
            tied_params_count = sum(len(idx) for idx in indicies_tied)
        # Flattened sizes of W1 (128 x 256) and W2 (256 x num_used_token).
        self.linear_weight_list = [128 * 256, 256 * self.output_dim]
        self.used_params = sum(self.linear_weight_list)
        self.indicies_tied = indicies_tied
        self.tied_net = tied_net
        self.tied_params_count = tied_params_count

    def _index_tensors(self):
        """Return ``indicies_tied`` as LongTensors, converted once and cached.

        The cache is rebuilt when ``self.indicies_tied`` is replaced by a
        different object; mutating the existing list in place is not detected.
        """
        cache = getattr(self, "_index_cache", None)
        if cache is None or cache[0] is not self.indicies_tied:
            tensors = [torch.as_tensor(idx, dtype=torch.long) for idx in self.indicies_tied]
            cache = (self.indicies_tied, tensors)
            self._index_cache = cache
        return cache[1]

    def get_weight(self, fit=True):
        """Gather the tied entries of ``tied_net`` into one 1-D tensor.

        The gather is done outside ``no_grad``, so the result stays attached
        to the autograd graph and gradients flow back into ``tied_net``.

        Args:
            fit: If True, return only the first ``used_params`` entries (all
                that ``forward`` reads); otherwise return every tied entry.

        Returns:
            1-D tensor of tied values in ``named_parameters()`` order.
        """
        limit = self.used_params if fit else None
        index_tensors = self._index_tensors()
        pieces = []
        gathered = 0
        for i, params in enumerate(self.tied_net.parameters()):
            if limit is not None and gathered >= limit:
                break  # everything past used_params would be sliced off anyway
            idx = index_tensors[i].to(params.device)
            piece = params.view(-1)[idx]
            pieces.append(piece)
            gathered += piece.numel()
        weight_tied = torch.cat(pieces, dim=0)
        return weight_tied if limit is None else weight_tied[:limit]

    def set_weight(self, weight):
        """Write a 1-D ``weight`` into the tied entries of ``tied_net`` in place.

        ``weight`` is consumed in the order ``get_weight`` produces. Writing
        stops when ``weight`` runs out, so the truncated ``get_weight()``
        output (``fit=True``) can be written back; tied entries beyond the end
        of ``weight`` are left unchanged and surplus values are ignored.
        Values are cast to each parameter's device and dtype.
        """
        weight = weight.detach()
        total = len(weight)
        index_tensors = self._index_tensors()
        count = 0
        with torch.no_grad():
            for i, params in enumerate(self.tied_net.parameters()):
                if count >= total:
                    break
                idx = index_tensors[i].to(params.device)
                chunk = weight[count:count + idx.numel()]
                written = len(chunk)
                params.detach().view(-1)[idx[:written]] = chunk.to(
                    device=params.device, dtype=params.dtype)
                count += written
        print("Weight is set successfully")

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Compute ``relu(x @ W1) @ W2`` with weights gathered from ``tied_net``.

        Args:
            input_tensor: Tensor whose last dimension is 128.

        Returns:
            Tensor whose last dimension is ``num_used_token``.

        Raises:
            RuntimeError: If ``tied_net`` holds fewer than ``used_params``
                tied entries.
        """
        weight_tied = self.get_weight().to(dtype=input_tensor.dtype)
        size1, size2 = self.linear_weight_list
        if weight_tied.numel() < size1 + size2:
            raise RuntimeError(
                "TiedUNET needs {} tied entries but only {} are available".format(
                    size1 + size2, weight_tied.numel()))
        linear1 = weight_tied[:size1].reshape(128, 256)
        linear2 = weight_tied[size1:size1 + size2].reshape(256, self.output_dim)
        hidden = F.relu(input_tensor @ linear1)  # 128 -> 256
        return hidden @ linear2  # 256 -> num_used_token