From e780977ea7a7d8afda79f9967872a05b2684998d Mon Sep 17 00:00:00 2001
From: Nico Melone
Date: Tue, 11 Feb 2025 18:20:20 -0600
Subject: [PATCH] model adjustments

---
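Notes (kept below the --- so they stay out of the commit message): this patch cuts the
encoder from 1000 layers to 100, widens the embedding from 64 to 128, grows the dummy
dataset from 1 sample to 1000, and re-runs training for 10 epochs instead of 10000 on a
CPU-only macOS environment rather than the earlier Windows/CUDA box; the new run was
interrupted by hand after epoch 7. For orientation, here is a minimal sketch of the model
these hyperparameters configure, pieced together from the notebook cells and the traceback
in the diff below. The real TransformerModel class is defined in an earlier cell that this
patch does not touch, so treat the class body as an assumption, not the committed code:

    import torch
    import torch.nn as nn

    class TransformerModel(nn.Module):
        # Sketch only: reconstructed from the forward() frames visible in the traceback.
        def __init__(self, input_dim, embed_dim, num_heads, ff_dim, num_layers, output_dim):
            super().__init__()
            self.embedding = nn.Embedding(input_dim, embed_dim)
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim
            )  # batch_first is left at False, which is what triggers the UserWarning in the output
            self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            self.fc = nn.Linear(embed_dim, output_dim)

        def forward(self, x):
            x = self.embedding(x)             # token ids -> embeddings
            x = self.transformer_encoder(x)   # stack of num_layers encoder layers
            x = x.mean(dim=1)                 # global average pooling
            return torch.exp(self.fc(x))      # 9 positive values for the 3x3 grid

    # Values after this patch: input_dim=1000, embed_dim=128, num_heads=32,
    # ff_dim=128, num_layers=100, output_dim=9.
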
 parker-squares.ipynb | 125 ++++++++++++++-----------------------------
 1 file changed, 40 insertions(+), 85 deletions(-)

diff --git a/parker-squares.ipynb b/parker-squares.ipynb
index 94baa4a..6a2544d 100644
--- a/parker-squares.ipynb
+++ b/parker-squares.ipynb
@@ -64,14 +64,14 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
- "cuda\n"
+ "cpu\n"
 ]
 },
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
- "c:\\Users\\Nico\\miniconda3\\envs\\ml\\Lib\\site-packages\\torch\\nn\\modules\\transformer.py:385: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
+ "/Users/nico/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/transformer.py:379: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
 " warnings.warn(\n"
 ]
 }
 ]
 },
 {
 "source": [
 "# Model initialization\n",
 "input_dim = 1000 # Vocabulary size (digits 0-8)\n",
- "embed_dim = 64\n",
+ "embed_dim = 128\n",
 "num_heads = 32\n",
 "ff_dim = 128\n",
- "num_layers = 1000\n",
+ "num_layers = 100\n",
 "output_dim = 9 # Output 9 numbers for the 3x3 grid\n",
 "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
 "print(device)\n",
@@ -96,7 +96,7 @@
 "outputs": [],
 "source": [
 "# Training setup\n",
- "dataset = DummyDataset(max=input_dim, size=1)\n",
+ "dataset = DummyDataset(max=input_dim, size=1000)\n",
 "dataloader = DataLoader(dataset, batch_size=32, shuffle=True)\n",
 "optimizer = optim.Adam(model.parameters(), lr=0.001)"
 ]
 },
@@ -211,97 +211,52 @@
 },
 {
 "cell_type": "code",
- "execution_count": null,
+ "execution_count": 10,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
- "Epoch 1, Loss: 2168482.7500\n",
- "Epoch 2, Loss: 2058405.6250\n",
- "Epoch 3, Loss: 156104.3906\n",
- "Epoch 4, Loss: 64142.9570\n",
- "Epoch 5, Loss: 124696.9297\n",
- "Epoch 6, Loss: 78761.6641\n",
- "Epoch 7, Loss: 59254.0273\n",
- "Epoch 8, Loss: 31516.5137\n",
- "Epoch 9, Loss: 25372.9551\n",
- "Epoch 10, Loss: 20348.7832\n",
- "Epoch 11, Loss: 26226.0098\n",
- "Epoch 12, Loss: 11597.3955\n",
- "Epoch 13, Loss: 13238.3486\n",
- "Epoch 14, Loss: 15733.9697\n",
- "Epoch 15, Loss: 8141.3574\n",
- "Epoch 16, Loss: 24545.1621\n",
- "Epoch 17, Loss: 14355.9150\n",
- "Epoch 18, Loss: 8949.5967\n",
- "Epoch 19, Loss: 10782.6211\n",
- "Epoch 20, Loss: 8086.9185\n",
- "Epoch 21, Loss: 7462.0649\n",
- "Epoch 22, Loss: 9373.1094\n",
- "Epoch 23, Loss: 3174.2156\n",
- "Epoch 24, Loss: 4319.6846\n",
- "Epoch 25, Loss: 3958.9021\n",
- "Epoch 26, Loss: 3572.8293\n",
- "Epoch 27, Loss: 8007.1396\n",
- "Epoch 28, Loss: 3937.1189\n",
- "Epoch 29, Loss: 5150.6650\n",
- "Epoch 30, Loss: 2238.5474\n",
- "Epoch 31, Loss: 1977.6279\n",
- "Epoch 32, Loss: 1048.2073\n",
- "Epoch 33, Loss: 1188.5005\n",
- "Epoch 34, Loss: 1350.1697\n",
- "Epoch 35, Loss: 873.2927\n",
- "Epoch 36, Loss: 1433.1985\n",
- "Epoch 37, Loss: 1185.3119\n",
- "Epoch 38, Loss: 1040.7480\n",
- "Epoch 39, Loss: 583.9352\n",
- "Epoch 40, Loss: 1210.1052\n",
- "Epoch 41, Loss: 2639.0896\n",
- "Epoch 42, Loss: 965.5707\n",
- "Epoch 43, Loss: 385.0073\n",
- "Epoch 44, Loss: 703.6561\n",
- "Epoch 45, Loss: 217.5824\n",
- "Epoch 46, Loss: 556.3173\n",
- "Epoch 47, Loss: 1557.3625\n",
- "Epoch 48, Loss: 314.1129\n",
- "Epoch 49, Loss: 646.6355\n",
- "Epoch 50, Loss: 1853.7103\n",
- "Epoch 51, Loss: 1410.0405\n",
- "Epoch 52, Loss: 809.9199\n",
- "Epoch 53, Loss: 1252.6039\n",
- "Epoch 54, Loss: 1428.6708\n",
- "Epoch 55, Loss: 589.9499\n",
- "Epoch 56, Loss: 578.7387\n",
- "Epoch 57, Loss: 1211.7103\n",
- "Epoch 58, Loss: 1180.1213\n",
- "Epoch 59, Loss: 1668.9307\n",
- "Epoch 60, Loss: 1097.0701\n",
- "Epoch 61, Loss: 137.2408\n",
- "Epoch 62, Loss: 1906.7075\n",
- "Epoch 63, Loss: 844.5870\n",
- "Epoch 64, Loss: 747.8174\n",
- "Epoch 65, Loss: 592.3982\n",
- "Epoch 66, Loss: 214.6830\n",
- "Epoch 67, Loss: 3160.4482\n",
- "Epoch 68, Loss: 344.8889\n",
- "Epoch 69, Loss: 2320.1528\n",
- "Epoch 70, Loss: 615.0215\n",
- "Epoch 71, Loss: 692.8455\n",
- "Epoch 72, Loss: 545.4932\n",
- "Epoch 73, Loss: 590.8438\n",
- "Epoch 74, Loss: 1941.2006\n",
- "Epoch 75, Loss: 935.5252\n",
- "Epoch 76, Loss: 535.0656\n",
- "Epoch 77, Loss: 942.5314\n",
- "Epoch 78, Loss: 719.8766\n"
+ "Epoch 1, Loss: 875301.6689\n",
+ "Epoch 2, Loss: 1843.8967\n",
+ "Epoch 3, Loss: 1024.5599\n",
+ "Epoch 4, Loss: 691.9049\n",
+ "Epoch 5, Loss: 454.1340\n",
+ "Epoch 6, Loss: 203.3463\n",
+ "Epoch 7, Loss: 74.9484\n"
 ]
 },
+ {
+ "ename": "KeyboardInterrupt",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "KeyboardInterrupt                          Traceback (most recent call last)",
+ "Cell In[10], line 2: train_model(model, dataloader, optimizer, epochs=10, scale_factor=1000)",
+ "Cell In[9], line 12, in train_model: outputs = model(batch)",
+ "Cell In[3], line 12, in TransformerModel.forward: x = self.transformer_encoder(x)",
+ "File ~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/transformer.py:511, in TransformerEncoder.forward",
+ "File ~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/transformer.py:904, in TransformerEncoderLayer.forward",
+ "File ~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/transformer.py:918, in TransformerEncoderLayer._sa_block",
+ "File ~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/activation.py:1368, in MultiheadAttention.forward",
+ "File ~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/functional.py:6278, in multi_head_attention_forward: attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)",
+ "[intervening nn.Module dispatch frames (module.py:1736, 1747) and ANSI colour codes omitted]",
+ "KeyboardInterrupt: "
+ ]
+ }
 ],
 "source": [
 "# Train the model\n",
- "train_model(model, dataloader, optimizer, epochs=10000, scale_factor=1000)"
+ "train_model(model, dataloader, optimizer, epochs=10, scale_factor=1000)"
 ]
 },
 {