{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"taoI32r2BGXi"},"outputs":[],"source":["import torch\n","import torch.distributions as D\n","import numpy as np\n","import math\n","from torchinfo import summary\n","import random\n","\n","device = 'cuda'\n","\n","seed = 1000\n","torch.manual_seed(seed)\n","np.random.seed(seed)\n","random.seed(seed)\n","torch.cuda.manual_seed(seed)"]},{"cell_type":"markdown","metadata":{"id":"SEr7X_9tdQsy"},"source":["### Dataset setup exchangeable"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":6,"status":"ok","timestamp":1697225611252,"user":{"displayName":"Liyi Zhang","userId":"06658486172350999592"},"user_tz":240},"id":"1aYxRpBEeVas","outputId":"8a5b34dc-b157-4648-bc48-0f9132658e6d"},"outputs":[],"source":["V = 1000 # vocab size\n","K = 5 # num of topics\n","N = 12000 # num of documents\n","M = 100 # num of words in each doc\n","alpha = 0.5\n","\n","# The smaller the Dirichlet parameter, the less even is the simplex spread\n","dcht_V = D.dirichlet.Dirichlet(torch.zeros([V])+0.5) # dist over words\n","dcht_K = D.dirichlet.Dirichlet(torch.zeros([K])+alpha) # dist over topics\n","\n","topic_vecs = dcht_V.sample([K])\n","print(topic_vecs.shape)"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":233717,"status":"ok","timestamp":1696889156747,"user":{"displayName":"Liyi Zhang","userId":"06658486172350999592"},"user_tz":240},"id":"MIzVZeMyeavc","outputId":"1bc84831-dd7c-47ad-b556-a5394200dc65"},"outputs":[],"source":["dataset = torch.zeros([N, M], dtype=torch.long)\n","topic_mixtures = torch.zeros([N, K], dtype=torch.float32)\n","\n","for n in range(N):\n","\n","    # draw topic proportion\n","    theta = torch.squeeze(dcht_K.sample([1]))\n","    topic_mixtures[n,:] = theta\n","    # draw topic assignment for all words in document n\n","    dist_topic = D.categorical.Categorical(theta)\n","    topic_assignments = dist_topic.sample([M])\n","\n","    for m in range(M):\n","\n","        k = topic_assignments[m]\n","        # draw word\n","        topic_vec = topic_vecs[k]\n","        word = D.categorical.Categorical(topic_vec).sample([1])\n","        word = torch.squeeze(word)\n","        # word = torch.squeeze(torch.nn.functional.one_hot(word, num_classes=V))\n","        dataset[n,m] = word\n","\n","    if n % 1000 == 0:\n","        print(n)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"vfdVkEFfA_PO"},"outputs":[],"source":["np.savetxt('../data/dataset_005_N10000_V1000_ID1.csv', dataset.cpu().numpy())\n","np.savetxt('../data/topic_mixtures_005_N10000_V1000_ID1.csv', topic_mixtures.cpu().numpy())\n","np.savetxt('../data/topic_vecs_005_N10000_V1000_ID1.csv', topic_vecs.cpu().numpy())"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"JgQdOhFOdVHo"},"outputs":[],"source":["dataset = torch.tensor(np.genfromtxt('../data/dataset_005_N10000_V1000_ID1.csv'), dtype=torch.int64)\n","topic_mixtures = torch.tensor(np.genfromtxt('../data/topic_mixtures_005_N10000_V1000_ID1.csv'))\n","topic_vecs = torch.tensor(np.genfromtxt('../data/topic_vecs_005_N10000_V1000_ID1.csv'))"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":6,"status":"ok","timestamp":1697225614500,"user":{"displayName":"Liyi 
{"cell_type":"code","execution_count":null,"metadata":{"id":"vfdVkEFfA_PO"},"outputs":[],"source":["np.savetxt('../data/dataset_005_N10000_V1000_ID1.csv', dataset.cpu().numpy())\n","np.savetxt('../data/topic_mixtures_005_N10000_V1000_ID1.csv', topic_mixtures.cpu().numpy())\n","np.savetxt('../data/topic_vecs_005_N10000_V1000_ID1.csv', topic_vecs.cpu().numpy())"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"JgQdOhFOdVHo"},"outputs":[],"source":["dataset = torch.tensor(np.genfromtxt('../data/dataset_005_N10000_V1000_ID1.csv'), dtype=torch.int64)\n","topic_mixtures = torch.tensor(np.genfromtxt('../data/topic_mixtures_005_N10000_V1000_ID1.csv'))\n","topic_vecs = torch.tensor(np.genfromtxt('../data/topic_vecs_005_N10000_V1000_ID1.csv'))"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"SopIvrxedVMF"},"outputs":[],"source":["# 10000 train / 1000 val / 1000 test documents\n","val_idx = 10000\n","test_idx = 11000\n","len(dataset)"]},{"cell_type":"markdown","metadata":{"id":"bKRKDy7mX2eG"},"source":["### LLM"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"mTGUFTHDX2eG"},"outputs":[],"source":["# Special-token ids. Note that with V = 1000 these collide with real word\n","# ids; only START_IDX is actually used below (in plot_example).\n","START_IDX = 100\n","END_IDX = 101\n","PAD_IDX = 102\n","\n","def generate_square_subsequent_mask(sz, device='cpu'):\n","    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)\n","    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))\n","    return mask\n","\n","def create_mask(src, tgt, device='cpu'):\n","    src_seq_len = src.shape[1]\n","    tgt_seq_len = tgt.shape[1]\n","\n","    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)\n","    src_mask = generate_square_subsequent_mask(src_seq_len)\n","\n","    # The synthetic documents contain no padding, so the padding masks are\n","    # all-False (999999 never occurs as a token id).\n","    src_padding_mask = (src == 999999)\n","    tgt_padding_mask = (tgt == 999999)\n","    return src_mask.to(torch.device(device)), tgt_mask.to(torch.device(device)), src_padding_mask.to(torch.device(device)), tgt_padding_mask.to(torch.device(device))"]},
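{"cell_type":"markdown","metadata":{},"source":["A quick sanity check of the causal mask (illustrative only, not part of the original pipeline): position $i$ may attend to positions $\\le i$, with `0.0` meaning keep and `-inf` meaning blocked."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["print(generate_square_subsequent_mask(4))\n","# Expected: lower-triangular zeros with -inf above the diagonal."]},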
{"cell_type":"code","execution_count":null,"metadata":{"id":"O4YulbXR7vv3"},"outputs":[],"source":["from torch.nn import TransformerEncoder, TransformerEncoderLayer\n","\n","class PositionalEncoding(torch.nn.Module):\n","\n","    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 500):\n","        super().__init__()\n","        self.dropout = torch.nn.Dropout(p=dropout)\n","\n","        position = torch.arange(max_len).unsqueeze(1)\n","        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n","        pe = torch.zeros(1, max_len, d_model)\n","        pe[0, :, 0::2] = torch.sin(position * div_term)\n","        pe[0, :, 1::2] = torch.cos(position * div_term)\n","        self.register_buffer('pe', pe)\n","\n","    def forward(self, x):\n","        x = x + self.pe[:, :x.size(1), :]\n","        return self.dropout(x)\n","\n","class TransformerModel(torch.nn.Module):\n","\n","    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,\n","                 nlayers: int, dropout: float = 0.5, use_pos = False):\n","        super().__init__()\n","        self.model_type = 'Transformer'\n","        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout, batch_first=True)\n","        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)\n","        self.embedding = torch.nn.Embedding(ntoken, d_model)\n","        self.d_model = d_model\n","        self.linear = torch.nn.Linear(d_model, ntoken)\n","        self.use_pos = use_pos\n","        if self.use_pos:\n","            self.pos_encoder = PositionalEncoding(d_model, dropout)\n","\n","    def forward(self, src, memory=None, src_mask=None):\n","        # `memory` is unused; it is kept only so the torchinfo summary call\n","        # below can pass a placeholder tensor.\n","        src = self.embedding(src) * math.sqrt(self.d_model)\n","        if self.use_pos:\n","            src = self.pos_encoder(src)\n","        output = self.transformer_encoder(src, src_mask)\n","        self.doc_embd = output  # cache the token embeddings for probing\n","        output = self.linear(output)\n","        return output"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"9p4KHUkcqgzx"},"outputs":[],"source":["dataset = dataset.to(torch.device(device))\n","topic_mixtures = topic_mixtures.to(torch.device(device))\n","\n","train_dataset = dataset[:val_idx]\n","val_dataset = dataset[val_idx:test_idx]\n","test_dataset = dataset[test_idx:]\n","train_mixtures = topic_mixtures[:val_idx]\n","val_mixtures = topic_mixtures[val_idx:test_idx]\n","test_mixtures = topic_mixtures[test_idx:]\n","\n","def get_loss_tv(output, target):\n","    # max pointwise gap (an l-infinity distance); this lower-bounds the\n","    # usual total variation distance 0.5 * sum(|p - q|)\n","    diff = torch.abs(output-target)\n","    return torch.max(diff, dim=1).values\n","\n","def get_loss_l2(output, target):\n","    return torch.mean(torch.sum((output-target)**2, dim=1))"]},
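{"cell_type":"markdown","metadata":{},"source":["A toy check of the two probe metrics on a pair of 3-point distributions (illustrative only): `get_loss_tv` returns the max pointwise gap and `get_loss_l2` the mean squared Euclidean distance."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["p = torch.tensor([[0.7, 0.2, 0.1]])\n","q = torch.tensor([[0.5, 0.3, 0.2]])\n","print(get_loss_tv(p, q))  # tensor([0.2000])\n","print(get_loss_l2(p, q))  # tensor(0.0600)"]},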
{"cell_type":"code","execution_count":null,"metadata":{"id":"DILwCls454xa"},"outputs":[],"source":["d_model = 128\n","model = TransformerModel(V, d_model, 8, d_model, 4, 0.1, True).to(torch.device(device))\n","summary(model, [(1,M-1), (1,1,d_model), (M-1,M-1)], dtypes=[torch.int32,torch.float32,torch.float32])"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"D0CufBTrM7P4"},"outputs":[],"source":["'Transformer training'\n","mask = generate_square_subsequent_mask(M-1, device=device)\n","\n","d_model = 128\n","model = TransformerModel(V, d_model, 8, d_model, 4, 0.1, True).to(torch.device(device))\n","\n","criterion = torch.nn.CrossEntropyLoss()\n","lr = 0.0001  # learning rate\n","optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n","batch_size = 16\n","min_val_loss = 1e6\n","\n","for epoch in range(200):\n","\n","    model.train()\n","    total_loss = 0.\n","\n","    for i in range(0, len(train_dataset), batch_size):\n","        end_idx = min(i+batch_size, len(train_dataset))\n","        # next-token prediction: input is words 0..M-2, target is words 1..M-1\n","        data = train_dataset[i:end_idx, :-1]\n","        target = train_dataset[i:end_idx, 1:]\n","\n","        output = model(data, torch.zeros(data.shape[0], 1, d_model).to(torch.device(device)), mask)\n","\n","        output_flat = torch.reshape(output, (-1, V))\n","        target_flat = torch.reshape(target, (-1,))\n","        loss = criterion(output_flat, target_flat)\n","\n","        optimizer.zero_grad()\n","        loss.backward()\n","        optimizer.step()\n","\n","        total_loss += loss.item() * data.shape[0]\n","\n","    print('Total train loss', total_loss / len(train_dataset))\n","\n","    model.eval()\n","    total_loss = 0.\n","\n","    for i in range(0, len(val_dataset), batch_size):\n","        end_idx = min(i+batch_size, len(val_dataset))\n","        data = val_dataset[i:end_idx, :-1]\n","        target = val_dataset[i:end_idx, 1:]\n","\n","        output = model(data, torch.zeros(data.shape[0], 1, d_model).to(torch.device(device)), mask)\n","\n","        output_flat = torch.reshape(output, (-1, V))\n","        target_flat = torch.reshape(target, (-1,))\n","        loss = criterion(output_flat, target_flat)\n","\n","        total_loss += loss.item() * data.shape[0]\n","\n","    print('Total val loss', total_loss / len(val_dataset))\n","\n","    # checkpoint on the best (summed) validation loss\n","    if total_loss < min_val_loss:\n","        min_val_loss = total_loss\n","        torch.save(model.state_dict(), '../results/transformer_model_weights_1.pth')"]},
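{"cell_type":"markdown","metadata":{},"source":["Since the validation loss above is an average per-token cross-entropy in nats, exponentiating it gives a per-token perplexity. This is only loosely comparable to the per-word likelihood bound gensim reports for LDA later in the notebook (gensim's bound is in log base 2), but it is a convenient summary. A minimal sketch, reusing `total_loss` from the validation loop above:"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# total_loss / len(val_dataset) is the mean per-token cross-entropy in nats\n","print('Val perplexity', math.exp(total_loss / len(val_dataset)))"]},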
{"cell_type":"code","execution_count":null,"metadata":{"id":"puBcF3-x5SAV"},"outputs":[],"source":["model.load_state_dict(torch.load('../results/transformer_model_weights_1.pth'))\n","\n","input = val_dataset[[0], :-1].to(torch.device(device))\n","\n","output = model(input, torch.zeros(1, 1, d_model).to(torch.device(device)), mask)\n","\n","print(input)\n","print(torch.argmax(output, dim=2))"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"ehMtvIX6NARc"},"outputs":[],"source":["'Transformer probing'\n","# lr = 0.0001 for the real dataset and 0.001 for the synthetic one\n","\n","id = 1\n","dataset = torch.tensor(np.genfromtxt(f'../data/dataset_005_N10000_V1000_ID{id}.csv'), dtype=torch.int64)\n","topic_mixtures = torch.tensor(np.genfromtxt(f'../data/topic_mixtures_005_N10000_V1000_ID{id}.csv'))\n","topic_vecs = torch.tensor(np.genfromtxt(f'../data/topic_vecs_005_N10000_V1000_ID{id}.csv'))\n","\n","dataset = dataset.to(torch.device(device))\n","topic_mixtures = topic_mixtures.to(torch.device(device))\n","\n","train_dataset = dataset[:val_idx]\n","val_dataset = dataset[val_idx:test_idx]\n","test_dataset = dataset[test_idx:]\n","train_mixtures = topic_mixtures[:val_idx]\n","val_mixtures = topic_mixtures[val_idx:test_idx]\n","test_mixtures = topic_mixtures[test_idx:]\n","\n","def loss_mle(target_mixtures, q_params):\n","    # negative log-likelihood of the true mixtures under a Dirichlet with\n","    # the predicted parameters\n","    samps = target_mixtures\n","    q = D.dirichlet.Dirichlet(q_params)\n","    logq = q.log_prob(samps)\n","    return torch.mean(-logq)\n","\n","\n","#train_mode = 'bayesian' # choose between 'classification', 'bayesian'\n","train_mode = 'classification'\n","\n","if train_mode == 'bayesian':\n","    # predict Dirichlet parameters (Softplus keeps them positive)\n","    classifier = torch.nn.Sequential(\n","        torch.nn.Linear(d_model, 5),\n","        torch.nn.Softplus(),\n","    ).to(torch.device(device))\n","\n","elif train_mode == 'classification':\n","    classifier = torch.nn.Sequential(\n","        torch.nn.Linear(d_model, 5)\n","    ).to(torch.device(device))\n","\n","\n","# freeze the language model; only the linear probe is trained\n","model.eval()\n","for param in model.parameters():\n","    param.requires_grad = False\n","\n","criterion = torch.nn.CrossEntropyLoss(reduction='sum')\n","lr = 0.001  # learning rate\n","optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)\n","batch_size = 64\n","\n","# which token embedding to probe: an index (e.g. -1 for the last token),\n","# or the sentinel 0.5 to mean-pool over all tokens\n","token_num = -1\n","\n","# the probe is trained on the validation split and evaluated on the test split\n","for epoch in range(300):\n","\n","    classifier.train()\n","    total_loss = 0.\n","\n","    for i in range(0, len(val_dataset), batch_size):\n","        end_idx = min(i+batch_size, len(val_dataset))\n","        data = val_dataset[i:end_idx, :-1]\n","        src = model.embedding(data) * math.sqrt(model.d_model)\n","        if model.use_pos:\n","            src = model.pos_encoder(src)\n","        if token_num == 0.5:\n","            outs = model.transformer_encoder(src, mask)\n","            embd = torch.mean(outs, dim=1)\n","        else:\n","            embd = model.transformer_encoder(src, mask)[:,token_num,:]\n","\n","        target = val_mixtures[i:end_idx, :]\n","\n","        output = classifier(embd)\n","\n","        if train_mode == 'classification':\n","            loss = criterion(output, target)\n","        elif train_mode == 'bayesian':\n","            loss = loss_mle(target, output)\n","\n","        optimizer.zero_grad()\n","        loss.backward()\n","        optimizer.step()\n","\n","        total_loss += loss.item()\n","\n","    print('Total train loss', total_loss/len(val_dataset))\n","\n","    classifier.eval()\n","    total_loss = 0.\n","    loss_l2 = 0.\n","    loss_tv = 0.\n","    loss_ce = 0.\n","    pred_results = []\n","\n","    for i in range(0, len(test_dataset), batch_size):\n","        end_idx = min(i+batch_size, len(test_dataset))\n","        data = test_dataset[i:end_idx, :-1]\n","        src = model.embedding(data) * math.sqrt(model.d_model)\n","        if model.use_pos:\n","            src = model.pos_encoder(src)\n","        if token_num == 0.5:\n","            outs = model.transformer_encoder(src, mask)\n","            embd = torch.mean(outs, dim=1)\n","        else:\n","            embd = model.transformer_encoder(src, mask)[:,token_num,:]\n","\n","        target = test_mixtures[i:end_idx, :]\n","\n","        output = classifier(embd)\n","\n","        if train_mode == 'classification':\n","            loss = criterion(output, target)\n","        elif train_mode == 'bayesian':\n","            loss = loss_mle(target, output)\n","        total_loss += loss.item()\n","\n","        true_class = torch.argmax(target, 1)\n","        pred_class = torch.argmax(output, 1)\n","        pred_result = true_class == pred_class\n","        pred_result = list(pred_result.cpu().numpy())\n","        pred_results.extend(pred_result)\n","\n","        if train_mode == 'classification':\n","            loss_ce = total_loss\n","            loss_l2 += torch.sum(torch.sum((output.softmax(1)-target)**2, dim=1)).item()\n","            loss_tv += torch.sum(get_loss_tv(output.softmax(1), target)).item()\n","        elif train_mode == 'bayesian':\n","            # point estimate from the predicted Dirichlet via Monte Carlo\n","            q = D.dirichlet.Dirichlet(torch.squeeze(output))\n","            samp = torch.mean(torch.squeeze(q.sample([10])), dim=0)\n","            loss_ce += -torch.sum(torch.sum(torch.multiply(torch.log(samp), target), dim=1)).item()\n","            loss_l2 += torch.sum(torch.sum((samp-target)**2, dim=1)).item()\n","            loss_tv += torch.sum(get_loss_tv(samp, target)).item()\n","\n","    pred_results = np.array(pred_results)\n","    acc = np.mean(pred_results)\n","\n","    print('Test loss CE', loss_ce/len(test_dataset))\n","    print('Test loss L2', loss_l2/len(test_dataset))\n","    print('Test loss TV', loss_tv/len(test_dataset))\n","    print('Accuracy', acc)"]},
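{"cell_type":"markdown","metadata":{},"source":["In `bayesian` mode the point estimate above is a 10-sample Monte Carlo mean of the predicted Dirichlet. The mean is also available in closed form, $\\mathbb{E}[\\theta] = \\alpha / \\sum_k \\alpha_k$; the cell below is a hypothetical drop-in alternative (not used above)."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["def dirichlet_mean_estimate(q_params):\n","    # closed-form Dirichlet mean, an exact replacement for the 10-sample\n","    # Monte Carlo estimate used in the bayesian branch above\n","    return q_params / q_params.sum(dim=-1, keepdim=True)"]},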
{"cell_type":"markdown","metadata":{"id":"0futNQmlI9-O"},"source":["### BERT"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"_WJYr377NTE6"},"outputs":[],"source":["import random\n","# mask 15 of the first 99 positions (roughly BERT's 15% masking rate)\n","mask_idx = random.sample(list(range(99)), 15)\n","print(mask_idx)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"3HD6J2wVRVx1"},"outputs":[],"source":["train_dataset = dataset[:val_idx]\n","val_dataset = dataset[val_idx:test_idx]\n","test_dataset = dataset[test_idx:]\n","train_mixtures = topic_mixtures[:val_idx]\n","val_mixtures = topic_mixtures[val_idx:test_idx]\n","test_mixtures = topic_mixtures[test_idx:]"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"z90l0uMXbvSL"},"outputs":[],"source":["def get_loss_tv(output, target):\n","    # max pointwise gap (an l-infinity distance); a lower bound on the\n","    # usual total variation distance 0.5 * sum(|p - q|)\n","    diff = torch.abs(output-target)\n","    return torch.max(diff, dim=1).values\n","\n","def get_loss_l2(output, target):\n","    return torch.mean(torch.sum((output-target)**2, dim=1))"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"-KqOWVEnPqgO"},"outputs":[],"source":["'BERT training'\n","\n","import transformers\n","\n","# three extra ids for BOS / EOS / [MASK] on top of the V word ids\n","config = transformers.AutoConfig.from_pretrained(\n","    \"prajjwal1/bert-tiny\",\n","    vocab_size=V+3,\n","    bos_token_id=V,\n","    eos_token_id=V+1,\n","    mask_token_id=V+2\n",")\n","# BertLMHeadModel provides a per-position token-prediction head; the\n","# MLM-style loss is computed manually below on the masked positions only.\n","model = transformers.BertLMHeadModel(config).to(torch.device(device))\n","model_size = sum(t.numel() for t in model.parameters())\n","print(model_size)\n","\n","criterion = torch.nn.CrossEntropyLoss()\n","lr = 0.0001  # learning rate\n","optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n","batch_size = 16\n","min_val_loss = 1e6\n","\n","for epoch in range(200):\n","\n","    model.train()\n","    total_loss = 0.\n","\n","    for i in range(0, len(train_dataset), batch_size):\n","        end_idx = min(i+batch_size, len(train_dataset))\n","        target = train_dataset[i:end_idx, :]\n","        # replace the chosen positions with the [MASK] id\n","        input = torch.zeros_like(target)\n","        input[:,:] = target\n","        input[:,mask_idx] = V+2\n","\n","        input = input.to(torch.device(device))\n","        target = target.to(torch.device(device))\n","\n","        output = model(input).logits\n","\n","        # compute the loss only at the masked positions\n","        output = output[:,mask_idx,:]\n","        target = target[:,mask_idx]\n","\n","        output_flat = output.view(-1, V+3)\n","        target_flat = torch.reshape(target, (-1,))\n","        loss = criterion(output_flat, target_flat)\n","\n","        optimizer.zero_grad()\n","        loss.backward()\n","        optimizer.step()\n","\n","        total_loss += loss.item() * input.shape[0]\n","\n","    print('Total train loss', total_loss/len(train_dataset))\n","\n","    model.eval()\n","    total_loss = 0.\n","\n","    for i in range(0, len(val_dataset), batch_size):\n","        end_idx = min(i+batch_size, len(val_dataset))\n","        target = val_dataset[i:end_idx, :]\n","        input = torch.zeros_like(target)\n","        input[:,:] = target\n","        input[:,mask_idx] = V+2\n","\n","        input = input.to(torch.device(device))\n","        target = target.to(torch.device(device))\n","\n","        output = model(input).logits\n","\n","        output = output[:,mask_idx,:]\n","        target = target[:,mask_idx]\n","\n","        output_flat = output.view(-1, V+3)\n","        target_flat = torch.reshape(target, (-1,))\n","        loss = criterion(output_flat, target_flat)\n","\n","        total_loss += loss.item() * input.shape[0]\n","\n","    print('Total val loss', total_loss/len(val_dataset))\n","\n","    if total_loss < min_val_loss:\n","        min_val_loss = total_loss\n","        torch.save(model.state_dict(), '../results/bert_model_weights.pth')"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"m0JpY_EiUYnr"},"outputs":[],"source":["model.load_state_dict(torch.load('../results/bert_model_weights.pth'))\n","\n","target = val_dataset[0:1, :]\n","input = torch.zeros_like(target)\n","input[:,:] = target\n","input[:,mask_idx] = V+2\n","\n","input = input.to(torch.device(device))\n","target = target.to(torch.device(device))\n","\n","output = model(input)\n","\n","print(input)\n","print(target)\n","print(input[:,mask_idx])\n","print(target[:,mask_idx])\n","print(torch.argmax(output.logits, dim=2))\n","print(torch.argmax(output.logits, dim=2)[:,mask_idx])"]},
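{"cell_type":"markdown","metadata":{},"source":["Illustrative only: the fraction of masked positions the model recovers on the single example above, reusing `output` and `target` from the previous cell."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["pred = torch.argmax(output.logits, dim=2)\n","print('Masked-token accuracy', (pred[:,mask_idx] == target[:,mask_idx]).float().mean().item())"]},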
{"cell_type":"code","execution_count":null,"metadata":{"id":"0TuYRnT9ZFiZ"},"outputs":[],"source":["'''\n","BERT probing\n","'''\n","\n","id = 2\n","dataset = torch.tensor(np.genfromtxt(f'../data/dataset_005_N10000_V1000_ID{id}.csv'), dtype=torch.int64)\n","topic_mixtures = torch.tensor(np.genfromtxt(f'../data/topic_mixtures_005_N10000_V1000_ID{id}.csv'))\n","topic_vecs = torch.tensor(np.genfromtxt(f'../data/topic_vecs_005_N10000_V1000_ID{id}.csv'))\n","\n","dataset = dataset.to(torch.device(device))\n","topic_mixtures = topic_mixtures.to(torch.device(device))\n","\n","train_dataset = dataset[:val_idx]\n","val_dataset = dataset[val_idx:test_idx]\n","test_dataset = dataset[test_idx:]\n","train_mixtures = topic_mixtures[:val_idx]\n","val_mixtures = topic_mixtures[val_idx:test_idx]\n","test_mixtures = topic_mixtures[test_idx:]\n","\n","#train_mode = 'bayesian' # choose between 'classification', 'bayesian'\n","train_mode = 'classification'\n","d_model = 128\n","\n","if train_mode == 'bayesian':\n","    classifier = torch.nn.Sequential(\n","        torch.nn.Linear(d_model, 5),\n","        torch.nn.Softplus(),\n","    ).to(torch.device(device))\n","\n","elif train_mode == 'classification':\n","    classifier = torch.nn.Sequential(\n","        torch.nn.Linear(d_model, 5)\n","    ).to(torch.device(device))\n","\n","model.eval()\n","for param in model.parameters():\n","    param.requires_grad = False\n","\n","criterion = torch.nn.CrossEntropyLoss(reduction='sum')\n","lr = 0.003  # learning rate\n","optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)\n","batch_size = 64\n","\n","# which token embedding to probe: an index, or the sentinel 0.5 to mean-pool\n","token_num = 0.5\n","\n","for epoch in range(300):\n","\n","    classifier.train()\n","    total_loss = 0.\n","\n","    for i in range(0, len(val_dataset), batch_size):\n","        end_idx = min(i+batch_size, len(val_dataset))\n","        data = val_dataset[i:end_idx, :]\n","        input = torch.zeros_like(data)\n","        input[:,:] = data\n","        # use the same [MASK] id (V+2) the BERT model was trained with\n","        input[:,mask_idx] = V+2\n","        with torch.no_grad():\n","            if token_num != 0.5:\n","                embd = model(input.to(torch.device(device)), output_hidden_states=True).hidden_states[-1][:,token_num,:]\n","            else:\n","                outs = model(input.to(torch.device(device)), output_hidden_states=True).hidden_states[-1]\n","                embd = torch.mean(outs, dim=1)\n","\n","        target = val_mixtures[i:end_idx, :].to(torch.device(device))\n","\n","        output = classifier(embd)\n","\n","        if train_mode == 'classification':\n","            loss = criterion(output, target)\n","        elif train_mode == 'bayesian':\n","            loss = loss_mle(target, output)\n","\n","        optimizer.zero_grad()\n","        loss.backward()\n","        optimizer.step()\n","\n","        total_loss += loss.item()\n","\n","    print('Total train loss', total_loss/len(val_dataset))\n","\n","    classifier.eval()\n","    total_loss = 0.\n","    loss_l2 = 0.\n","    loss_tv = 0.\n","    loss_ce = 0.\n","    pred_results = []\n","\n","    for i in range(0, len(test_dataset), batch_size):\n","        end_idx = min(i+batch_size, len(test_dataset))\n","        data = test_dataset[i:end_idx, :]\n","        input = torch.zeros_like(data)\n","        input[:,:] = data\n","        input[:,mask_idx] = V+2\n","        with torch.no_grad():\n","            if token_num != 0.5:\n","                embd = model(input.to(torch.device(device)), output_hidden_states=True).hidden_states[-1][:,token_num,:]\n","            else:\n","                outs = model(input.to(torch.device(device)), output_hidden_states=True).hidden_states[-1]\n","                embd = torch.mean(outs, dim=1)\n","\n","        target = test_mixtures[i:end_idx, :].to(torch.device(device))\n","\n","        output = classifier(embd)\n","\n","        if train_mode == 'classification':\n","            loss = criterion(output, target)\n","        elif train_mode == 'bayesian':\n","            loss = loss_mle(target, output)\n","        total_loss += loss.item()\n","\n","        true_class = torch.argmax(target, 1)\n","        pred_class = torch.argmax(output, 1)\n","        pred_result = true_class == pred_class\n","        pred_result = list(pred_result.cpu().numpy())\n","        pred_results.extend(pred_result)\n","\n","        if train_mode == 'classification':\n","            loss_ce = total_loss\n","            loss_l2 += torch.sum(torch.sum((output.softmax(1)-target)**2, dim=1)).item()\n","            loss_tv += torch.sum(get_loss_tv(output.softmax(1), target)).item()\n","        elif train_mode == 'bayesian':\n","            q = D.dirichlet.Dirichlet(torch.squeeze(output))\n","            samp = torch.mean(torch.squeeze(q.sample([10])), dim=0)\n","            loss_ce += -torch.sum(torch.sum(torch.multiply(torch.log(samp), target), dim=1)).item()\n","            loss_l2 += torch.sum(torch.sum((samp-target)**2, dim=1)).item()\n","            loss_tv += torch.sum(get_loss_tv(samp, target)).item()\n","\n","    pred_results = np.array(pred_results)\n","    acc = np.mean(pred_results)\n","\n","    print('Test loss CE', loss_ce/len(test_dataset))\n","    print('Test loss L2', loss_l2/len(test_dataset))\n","    print('Test loss TV', loss_tv/len(test_dataset))\n","    print('Accuracy', acc)"]},
{"cell_type":"markdown","metadata":{"id":"ABpq1A3DZyNI"},"source":["### LDA"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"a3oxdxrEZ0A3"},"outputs":[],"source":["import gensim\n","from gensim.models import LdaModel\n","from gensim import corpora\n","\n","train_dataset = dataset[:val_idx]\n","val_dataset = dataset[val_idx:test_idx]\n","test_dataset = dataset[test_idx:]\n","train_mixtures = topic_mixtures[:val_idx]\n","val_mixtures = topic_mixtures[val_idx:test_idx]\n","test_mixtures = topic_mixtures[test_idx:]\n","\n","# gensim works on string tokens, so cast the word ids to strings\n","dataset_lda = np.array(dataset.cpu().numpy(), dtype=str)\n","dictionary = corpora.Dictionary(dataset_lda)\n","\n","# build the reverse id -> token map\n","for key in dictionary.token2id:\n","    dictionary.id2token[dictionary.token2id[key]] = key\n","\n","lda_corpus = [dictionary.doc2bow(text) for text in dataset_lda]"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gadHtEgQZ0EG"},"outputs":[],"source":["lda = LdaModel(lda_corpus[:val_idx], num_topics=5, iterations=3000, passes=10)\n","print('LDA lower bound', lda.log_perplexity(lda_corpus[val_idx:test_idx]))"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"knHns1ssiCHI"},"outputs":[],"source":["# top words of each ground-truth generator topic, for comparison\n","torch.argsort(topic_vecs, 1, descending=True)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"MRSUVjhMhzLe"},"outputs":[],"source":["print('-- Printing topics --')\n","for i in range(5):\n","    word_list = []\n","    pairs = lda.get_topic_terms(i, topn=20)\n","    for pair in pairs:\n","        word_list.append(dictionary.id2token[pair[0]])\n","    print(word_list)\n","print('---------------------')\n","\n","# read off from the printed top words:\n","# LDA topic 3 is generator topic 0\n","# LDA topic 4 is generator topic 1\n","# LDA topic 0 is generator topic 2\n","# LDA topic 1 is generator topic 3\n","# LDA topic 2 is generator topic 4"]},{"cell_type":"markdown","metadata":{},"source":["Here you need to match the trained LDA topic ids to the ground-truth generator topic ids; the hard-coded permutation below was read off from the printed top words. A hypothetical way to automate the matching (not used below) is sketched next."]},
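{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# A sketch of automated topic matching with the Hungarian algorithm,\n","# assuming scipy is available. Columns of lda.get_topics() follow gensim's\n","# dictionary ids, whose tokens are the word indices as strings, so realign\n","# them to the generator's vocabulary first.\n","from scipy.optimize import linear_sum_assignment\n","\n","raw = lda.get_topics()          # (K, len(dictionary))\n","lda_topics = np.zeros((K, V))\n","for j in range(raw.shape[1]):\n","    lda_topics[:, int(dictionary.id2token[j])] = raw[:, j]\n","\n","# maximize dot-product similarity between generator and LDA topics\n","cost = -(topic_vecs.cpu().numpy() @ lda_topics.T)\n","_, perm = linear_sum_assignment(cost)\n","print(perm)  # perm[i] = LDA topic matched to generator topic i"]},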
{"cell_type":"code","execution_count":null,"metadata":{"id":"h8WfGl05lsh3"},"outputs":[],"source":["# reorder the inferred LDA mixtures into the generator's topic order\n","trained_lda_mixtures = []\n","for i in range(len(dataset_lda)):\n","    m = lda.get_document_topics(lda_corpus[i], minimum_probability=1e-15)\n","    m = [tup[1] for tup in m]\n","    m_ = [0,0,0,0,0]\n","    m_[0] = m[3]\n","    m_[1] = m[2]\n","    m_[2] = m[0]\n","    m_[3] = m[1]\n","    m_[4] = m[4]\n","    trained_lda_mixtures.append(m_)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"4htBH9bHcYys"},"outputs":[],"source":["def get_loss_tv(output, target):\n","    diff = torch.abs(output-target)\n","    return torch.max(diff, dim=1).values\n","\n","def get_loss_l2(output, target):\n","    return torch.mean(torch.sum((output-target)**2, dim=1))\n","\n","output = torch.tensor(trained_lda_mixtures[test_idx:]).to(torch.device(device))\n","target = test_mixtures.to(torch.device(device))\n","\n","loss_ce = -torch.mean(torch.sum(torch.multiply(torch.log(output), target), dim=1))\n","loss_l2 = torch.mean(torch.sum((output-target)**2, dim=1))\n","loss_tv = torch.mean(get_loss_tv(output, target))\n","\n","true_class = torch.argmax(target, 1)\n","pred_class = torch.argmax(output, 1)\n","pred_result = true_class == pred_class\n","print('Accuracy', np.mean(pred_result.cpu().numpy()))\n","print('Loss CE', loss_ce)\n","print('Loss L2', loss_l2)\n","print('Loss TV', loss_tv)"]},{"cell_type":"markdown","metadata":{"id":"zg99z27Y9JQW"},"source":["### Word Embedder"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"JwHAcj409NQB"},"outputs":[],"source":["class WordEmbedder(torch.nn.Module):\n","    # bag-of-words baseline: mean-pool the word embeddings, then a linear head\n","\n","    def __init__(self, vocab_size, d_model, K):\n","        super().__init__()\n","        self.d_model = d_model\n","        self.embedding = torch.nn.Embedding(vocab_size, d_model)\n","        self.classifier = torch.nn.Linear(d_model, K)\n","\n","    def forward(self, src):\n","        src = self.embedding(src) * math.sqrt(self.d_model)\n","        src = torch.mean(src, dim=1)\n","        pred = self.classifier(src)\n","        return pred\n"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"ad4N4UvkAtS0"},"outputs":[],"source":["dataset = dataset.to(torch.device(device))\n","topic_mixtures = topic_mixtures.to(torch.device(device))\n","\n","train_dataset = dataset[:val_idx]\n","val_dataset = dataset[val_idx:test_idx]\n","test_dataset = dataset[test_idx:]\n","train_mixtures = topic_mixtures[:val_idx]\n","val_mixtures = topic_mixtures[val_idx:test_idx]\n","test_mixtures = topic_mixtures[test_idx:]\n","\n","def get_loss_tv(output, target):\n","    diff = torch.abs(output-target)\n","    return torch.max(diff, dim=1).values\n","\n","def get_loss_l2(output, target):\n","    return torch.mean(torch.sum((output-target)**2, dim=1))"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"-l8pa6LV9K97"},"outputs":[],"source":["model = WordEmbedder(V, 128, 5).to(torch.device(device))\n","\n","criterion = torch.nn.CrossEntropyLoss()\n","lr = 0.0001  # learning rate\n","optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n","batch_size = 64\n","min_val_loss = 1e6\n","\n","for epoch in range(300):\n","\n","    model.train()\n","    total_loss = 0.\n","\n","    for i in range(0, len(train_dataset), batch_size):\n","        end_idx = min(i+batch_size, len(train_dataset))\n","        data = train_dataset[i:end_idx, :]\n","        target = train_mixtures[i:end_idx, :]\n","\n","        output = model(data)\n","        loss = criterion(output, target)\n","\n","        optimizer.zero_grad()\n","        loss.backward()\n","        optimizer.step()\n","\n","        total_loss += loss.item()\n","\n","    print('Total train loss', total_loss)\n","\n","    model.eval()\n","    total_loss = 0.\n","    loss_l2 = 0.\n","    loss_tv = 0.\n","    loss_ce = 0.\n","    pred_results = []\n","\n","    for i in range(0, len(test_dataset), batch_size):\n","        end_idx = min(i+batch_size, len(test_dataset))\n","        data = test_dataset[i:end_idx, :]\n","        target = test_mixtures[i:end_idx, :]\n","\n","        output = model(data)\n","        loss = criterion(output, target)\n","        total_loss += loss.item()\n","\n","        true_class = torch.argmax(target, 1)\n","        pred_class = torch.argmax(output, 1)\n","        pred_result = true_class == pred_class\n","        pred_result = list(pred_result.cpu().numpy())\n","        pred_results.extend(pred_result)\n","\n","        loss_ce = total_loss\n","        loss_l2 += torch.sum(torch.sum((output.softmax(1)-target)**2, dim=1)).item()\n","        loss_tv += torch.sum(get_loss_tv(output.softmax(1), target)).item()\n","\n","    pred_results = np.array(pred_results)\n","    acc = np.mean(pred_results)\n","\n","    print('Test loss CE', loss_ce/len(test_dataset))\n","    print('Test loss L2', loss_l2/len(test_dataset))\n","    print('Test loss TV', loss_tv/len(test_dataset))\n","    print('Accuracy', acc)\n"]},
{"cell_type":"markdown","metadata":{"id":"XCRSaXwJMv_p"},"source":["### Analysis"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"ze2oNyM227QS"},"outputs":[],"source":["import matplotlib.pyplot as plt\n","\n","def plot_example(idx, method):\n","    data = val_dataset[[idx], :-1]\n","    data_target = val_dataset[[idx], 1:]\n","    target_in = torch.concat([torch.zeros((data_target.shape[0],1), device=device, dtype=torch.int32)+START_IDX, data_target], dim=1)\n","    src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(data, target_in, device=device)\n","\n","    with torch.no_grad():\n","        # the encoder-only model takes (src, memory, src_mask); the target\n","        # masks are unused leftovers from an encoder-decoder variant\n","        _ = model(data, None, src_mask)\n","        embd = model.doc_embd[:,token_num,:]\n","\n","    target = torch.squeeze(val_mixtures[[idx], :]).cpu().numpy()\n","    output = classifier(embd).detach().cpu()\n","\n","    if method.lower() == 'lda':\n","        # LDA mixtures are indexed over the full corpus; the validation\n","        # split starts at val_idx\n","        samp = trained_lda_mixtures[val_idx+idx]\n","    elif method.lower() == 'llm':\n","        samp = np.squeeze(torch.nn.functional.softmax(output, 1).numpy())\n","\n","    plt.stem(samp)\n","    plt.bar([0,1,2,3,4], target, alpha=0.2)\n","    plt.title(f'{method}: datapoint {idx}')"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"dt1rDbBP27Sg"},"outputs":[],"source":["plt.figure(figsize=(16,8))\n","plt.subplot(2,4,1)\n","plot_example(0, 'LDA')\n","plt.subplot(2,4,2)\n","plot_example(100, 'LDA')\n","plt.subplot(2,4,3)\n","plot_example(200, 'LDA')\n","plt.subplot(2,4,4)\n","plot_example(300, 'LDA')\n","plt.subplot(2,4,5)\n","plot_example(0, 'LLM')\n","plt.subplot(2,4,6)\n","plot_example(100, 'LLM')\n","plt.subplot(2,4,7)\n","plot_example(200, 'LLM')\n","plt.subplot(2,4,8)\n","plot_example(300, 'LLM')\n","plt.savefig('stickplots.pdf', format=\"pdf\", bbox_inches=\"tight\")\n","plt.show()"]}],"metadata":{"accelerator":"GPU","colab":{"collapsed_sections":["SEr7X_9tdQsy","bKRKDy7mX2eG","0futNQmlI9-O","ABpq1A3DZyNI","zg99z27Y9JQW"],"provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.16"}},"nbformat":4,"nbformat_minor":0}
