{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import numpy as np\n",
    "import csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dummy dataset class\n",
    "class DummyDataset(Dataset):\n",
    "    \"\"\"Random integer sequences used as placeholder model inputs.\n",
    "\n",
    "    Holds a (size, seq_len) integer tensor drawn once with torch.randint\n",
    "    from the half-open range [1, max); item idx is row idx of that tensor.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max=100, size=1000, seq_len=9):  # seq_len = 9 matches a 3x3 grid\n",
    "        # `max` shadows the builtin; the name is kept because callers pass it by keyword.\n",
    "        self.data = torch.randint(1, max, (size, seq_len))\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.data[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transformer model definition\n",
    "class TransformerModel(nn.Module):\n",
    "    \"\"\"Transformer encoder that maps a token sequence to positive output values.\n",
    "\n",
    "    Tokens are embedded, passed through a stack of encoder layers, averaged\n",
    "    over the sequence dimension and projected to output_dim values; the\n",
    "    final exp keeps the outputs greater than zero.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, embed_dim, num_heads, ff_dim, num_layers, output_dim):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(input_dim, embed_dim)\n",
    "        # batch_first=True: the embedding yields (batch, seq, embed). With the\n",
    "        # default batch_first=False the encoder treats dim 0 as the sequence\n",
    "        # and would attend across the samples of a batch instead of across\n",
    "        # the positions of one sample.\n",
    "        self.encoder_layer = nn.TransformerEncoderLayer(\n",
    "            d_model=embed_dim,\n",
    "            nhead=num_heads,\n",
    "            dim_feedforward=ff_dim,\n",
    "            batch_first=True,\n",
    "        )\n",
    "        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)\n",
    "        self.fc = nn.Linear(embed_dim, output_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map integer tokens x of shape (batch, seq) to a (batch, output_dim) tensor.\"\"\"\n",
    "        x = self.embedding(x)            # (batch, seq, embed)\n",
    "        x = self.transformer_encoder(x)  # (batch, seq, embed)\n",
    "        x = x.mean(dim=1)                # Global average pooling over the sequence\n",
    "        return torch.exp(self.fc(x))     # Ensure outputs are greater than zero\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cpu\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/nico/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/transformer.py:379: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False 
because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
      " warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Model initialization\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)  # Repeatable weights, dummy data and batch shuffling on Restart & Run All\n",
    "input_dim = 1000  # Vocabulary size: the embedding accepts token ids in [0, input_dim)\n",
    "embed_dim = 128\n",
    "num_heads = 32\n",
    "ff_dim = 128\n",
    "num_layers = 100\n",
    "output_dim = 9  # Output 9 numbers for the 3x3 grid\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "model = TransformerModel(input_dim, embed_dim, num_heads, ff_dim, num_layers, output_dim).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training setup\n",
    "dataset = DummyDataset(max=input_dim, size=1000)\n",
    "dataloader = DataLoader(dataset, batch_size=32, shuffle=True)\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom loss function with scaling\n",
    "def magic_square_loss(output, scale_factor=100):\n",
    "    \"\"\"Variance of the eight line sums of squared entries of each 3x3 grid.\n",
    "\n",
    "    output holds 9 values per sample. The values are multiplied by\n",
    "    scale_factor, arranged as 3x3 grids and squared. The loss is the\n",
    "    variance across the 3 column sums, 3 row sums and 2 diagonal sums of\n",
    "    each grid, averaged over the batch and divided by scale_factor ** 2.\n",
    "    It is zero when all eight sums of every grid agree.\n",
    "    \"\"\"\n",
    "    # reshape (not view) so a non-contiguous output tensor is accepted too\n",
    "    grid = (output * scale_factor).reshape(-1, 3, 3)\n",
    "    squares = grid ** 2\n",
    "    col_sums = squares.sum(dim=1)  # reduce over the row index -> one sum per column\n",
    "    row_sums = squares.sum(dim=2)  # reduce over the column index -> one sum per row\n",
    "    diag1_sum = torch.diagonal(squares, dim1=1, dim2=2).sum(dim=1)\n",
    "    diag2_sum = torch.diagonal(torch.flip(squares, dims=[2]), dim1=1, dim2=2).sum(dim=1)\n",
    "    all_sums = torch.cat([col_sums, row_sums, diag1_sum.unsqueeze(1), diag2_sum.unsqueeze(1)], dim=1)\n",
    "    loss = torch.var(all_sums.float(), dim=1).mean()  # Ensure float dtype for variance calculation\n",
    "    # NOTE(review): the sums grow with scale_factor ** 2, so their variance grows\n",
    "    # with scale_factor ** 4; dividing by scale_factor ** 2 leaves the reported\n",
    "    # loss dependent on scale_factor. Kept unchanged to preserve behaviour.\n",
    "    return loss / (scale_factor ** 2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to save outputs to file with computed sums and check for magic squares\n",
    "def save_outputs(epoch, batch_idx, outputs, scale_factor=100):\n",
    "    \"\"\"Append one CSV row per output grid to training_outputs.csv and report magic squares.\n",
    "\n",
    "    Each output is multiplied by scale_factor, rounded to integers and viewed\n",
    "    as a 3x3 grid. The row holds the epoch, the batch index, the nine entries\n",
    "    and the eight line sums of the squared entries.\n",
    "    \"\"\"\n",
    "    with open(\"training_outputs.csv\", \"a\", newline=\"\") as f:\n",
    "        
writer = csv.writer(f)\n",
    "        for output in outputs:\n",
    "            # int64 (.long()): squaring large scaled entries would wrap around silently in int32\n",
    "            scaled_output = torch.round(output * scale_factor).long().view(3, 3).detach().cpu().numpy()\n",
    "            squares = scaled_output ** 2\n",
    "            row_sums = squares.sum(axis=1).tolist()\n",
    "            col_sums = squares.sum(axis=0).tolist()\n",
    "            diag1_sum = int(np.sum(np.diagonal(squares)))\n",
    "            diag2_sum = int(np.sum(np.diagonal(np.fliplr(squares))))\n",
    "            all_sums = row_sums + col_sums + [diag1_sum, diag2_sum]\n",
    "            writer.writerow([epoch, batch_idx] + scaled_output.flatten().tolist() + all_sums)\n",
    "            # 1 distinct line sum -> magic square of squares; 2 -> reported as a \"Parker square\"\n",
    "            distinct_sums = len(set(all_sums))\n",
    "            if distinct_sums == 1:\n",
    "                print(\"MAGIC SQUARE FOUND!\")\n",
    "                print(scaled_output, epoch, batch_idx)\n",
    "            elif distinct_sums == 2:\n",
    "                print(\"PARKER SQUARE FOUND!\")\n",
    "                print(scaled_output, epoch, batch_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to check outputs for magic squares (console report only, nothing is written to file)\n",
    "def check_outputs(epoch, batch_idx, outputs, scale_factor=100):\n",
    "    \"\"\"Print every output grid whose eight line sums of squares take at most two values.\n",
    "\n",
    "    Each output is multiplied by scale_factor, rounded to integers and viewed\n",
    "    as a 3x3 grid. The 3 row sums, 3 column sums and 2 diagonal sums of the\n",
    "    squared entries are compared: one distinct value is reported as a magic\n",
    "    square, two distinct values as a Parker square.\n",
    "    \"\"\"\n",
    "    for output in outputs:\n",
    "        # int64 (.long()): squaring large scaled entries would wrap around silently in int32\n",
    "        scaled_output = torch.round(output * scale_factor).long().view(3, 3).detach().cpu().numpy()\n",
    "        squares = scaled_output ** 2\n",
    "        row_sums = squares.sum(axis=1).tolist()\n",
    "        col_sums = squares.sum(axis=0).tolist()\n",
    "        diag1_sum = int(np.sum(np.diagonal(squares)))\n",
    "        diag2_sum = int(np.sum(np.diagonal(np.fliplr(squares))))\n",
    "        distinct_sums = len(set(row_sums + col_sums + [diag1_sum, diag2_sum]))\n",
    "\n",
    "        # Check if a magic square is found\n",
    "        if distinct_sums == 1:\n",
    "            print(\"MAGIC SQUARE FOUND!\")\n",
    "            print(scaled_output, epoch, batch_idx)\n",
    "        elif distinct_sums == 2:\n",
    "            print(\"PARKER SQUARE FOUND!\")\n",
    "            print(scaled_output, epoch, batch_idx)\n",
    "        # Disabled extra check, kept as an inert string literal\n",
    "        \"\"\" \n",
    "        elif len(set(row_sums + col_sums + [diag1_sum, diag2_sum])) < 8:\n",
    "            print(\"AT LEAST 2 NUMBERS MATCHED!!!! 
POG\")\n",
    " print(scaled_output, epoch, batch_idx) \n",
    " \"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, dataloader, optimizer, epochs=5, scale_factor=100):\n",
    "    \"\"\"Train model to minimise magic_square_loss and report candidate grids.\n",
    "\n",
    "    Resets training_outputs.csv to a single header row, then for every batch\n",
    "    runs one optimisation step and hands the batch outputs to check_outputs.\n",
    "    Prints the mean batch loss after each epoch.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    # Use the device the model lives on instead of relying on a notebook-level\n",
    "    # `device` variable being defined and in sync with the model.\n",
    "    device = next(model.parameters()).device\n",
    "    # Start a fresh CSV with a header row; save_outputs appends to the same file.\n",
    "    with open(\"training_outputs.csv\", \"w\", newline=\"\") as f:\n",
    "        header = (\n",
    "            [\"Epoch\", \"Batch\"]\n",
    "            + [f\"Pos_{i}\" for i in range(9)]\n",
    "            + [f\"Row_{i}\" for i in range(3)]\n",
    "            + [f\"Col_{i}\" for i in range(3)]\n",
    "            + [\"Diag1\", \"Diag2\"]\n",
    "        )\n",
    "        csv.writer(f).writerow(header)\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        total_loss = 0\n",
    "        for batch_idx, batch in enumerate(dataloader):\n",
    "            batch = batch.to(device)  # Move batch to the model's device\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(batch)\n",
    "            loss = magic_square_loss(outputs, scale_factor)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            total_loss += loss.item()\n",
    "            #save_outputs(epoch, batch_idx, outputs, scale_factor)\n",
    "            check_outputs(epoch, batch_idx, outputs, scale_factor)\n",
    "        print(f\"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Loss: 875301.6689\n",
      "Epoch 2, Loss: 1843.8967\n",
      "Epoch 3, Loss: 1024.5599\n",
      "Epoch 4, Loss: 691.9049\n",
      "Epoch 5, Loss: 454.1340\n",
      "Epoch 6, Loss: 203.3463\n",
      "Epoch 7, Loss: 74.9484\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[10], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Train the model\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mtrain_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mdataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscale_factor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1000\u001b[39;49m\u001b[43m)\u001b[49m\n", "Cell \u001b[0;32mIn[9], line 12\u001b[0m, in \u001b[0;36mtrain_model\u001b[0;34m(model, dataloader, optimizer, epochs, scale_factor)\u001b[0m\n\u001b[1;32m 10\u001b[0m batch \u001b[38;5;241m=\u001b[39m batch\u001b[38;5;241m.\u001b[39mto(device) \u001b[38;5;66;03m# Move batch to GPU\u001b[39;00m\n\u001b[1;32m 11\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 12\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m loss \u001b[38;5;241m=\u001b[39m magic_square_loss(outputs, scale_factor)\n\u001b[1;32m 14\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n", "File \u001b[0;32m~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", "Cell \u001b[0;32mIn[3], line 12\u001b[0m, in \u001b[0;36mTransformerModel.forward\u001b[0;34m(self, 
x)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[1;32m 11\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membedding(x)\n\u001b[0;32m---> 12\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransformer_encoder\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mmean(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# Global average pooling\u001b[39;00m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mexp(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfc(x))\n", "File \u001b[0;32m~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in 
\u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", "File \u001b[0;32m~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/transformer.py:511\u001b[0m, in \u001b[0;36mTransformerEncoder.forward\u001b[0;34m(self, src, mask, src_key_padding_mask, is_causal)\u001b[0m\n\u001b[1;32m 508\u001b[0m is_causal \u001b[38;5;241m=\u001b[39m _detect_is_causal_mask(mask, is_causal, 
seq_len)\n\u001b[1;32m 510\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m mod \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[0;32m--> 511\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mmod\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 512\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 513\u001b[0m \u001b[43m \u001b[49m\u001b[43msrc_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 514\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_causal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_causal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 515\u001b[0m \u001b[43m \u001b[49m\u001b[43msrc_key_padding_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msrc_key_padding_mask_for_layers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 516\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 518\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m convert_to_nested:\n\u001b[1;32m 519\u001b[0m output \u001b[38;5;241m=\u001b[39m output\u001b[38;5;241m.\u001b[39mto_padded_tensor(\u001b[38;5;241m0.0\u001b[39m, src\u001b[38;5;241m.\u001b[39msize())\n", "File \u001b[0;32m~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 
1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", "File \u001b[0;32m~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/transformer.py:904\u001b[0m, in \u001b[0;36mTransformerEncoderLayer.forward\u001b[0;34m(self, src, src_mask, src_key_padding_mask, is_causal)\u001b[0m\n\u001b[1;32m 900\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ff_block(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm2(x))\n\u001b[1;32m 901\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 902\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm1(\n\u001b[1;32m 903\u001b[0m x\n\u001b[0;32m--> 904\u001b[0m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sa_block\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msrc_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msrc_key_padding_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mis_causal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_causal\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 905\u001b[0m )\n\u001b[1;32m 906\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm2(x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ff_block(x))\n\u001b[1;32m 908\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n", "File \u001b[0;32m~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/transformer.py:918\u001b[0m, in \u001b[0;36mTransformerEncoderLayer._sa_block\u001b[0;34m(self, x, attn_mask, key_padding_mask, is_causal)\u001b[0m\n\u001b[1;32m 911\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_sa_block\u001b[39m(\n\u001b[1;32m 912\u001b[0m 
\u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 913\u001b[0m x: Tensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 916\u001b[0m is_causal: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 917\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 918\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 919\u001b[0m \u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 920\u001b[0m \u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 921\u001b[0m \u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 922\u001b[0m \u001b[43m \u001b[49m\u001b[43mattn_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattn_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 923\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey_padding_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey_padding_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 924\u001b[0m \u001b[43m \u001b[49m\u001b[43mneed_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 925\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_causal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_causal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 926\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 927\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout1(x)\n", "File \u001b[0;32m~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", "File \u001b[0;32m~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/modules/activation.py:1368\u001b[0m, in \u001b[0;36mMultiheadAttention.forward\u001b[0;34m(self, query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights, is_causal)\u001b[0m\n\u001b[1;32m 1342\u001b[0m attn_output, attn_output_weights \u001b[38;5;241m=\u001b[39m F\u001b[38;5;241m.\u001b[39mmulti_head_attention_forward(\n\u001b[1;32m 1343\u001b[0m query,\n\u001b[1;32m 1344\u001b[0m key,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1365\u001b[0m is_causal\u001b[38;5;241m=\u001b[39mis_causal,\n\u001b[1;32m 1366\u001b[0m )\n\u001b[1;32m 1367\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1368\u001b[0m attn_output, attn_output_weights \u001b[38;5;241m=\u001b[39m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmulti_head_attention_forward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1369\u001b[0m \u001b[43m \u001b[49m\u001b[43mquery\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1370\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1371\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1372\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membed_dim\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1373\u001b[0m \u001b[43m 
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_heads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1374\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43min_proj_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1375\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43min_proj_bias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1376\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias_k\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1377\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias_v\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1378\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_zero_attn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1379\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1380\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mout_proj\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1381\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mout_proj\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1382\u001b[0m \u001b[43m \u001b[49m\u001b[43mtraining\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1383\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mkey_padding_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey_padding_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1384\u001b[0m \u001b[43m \u001b[49m\u001b[43mneed_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mneed_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1385\u001b[0m \u001b[43m \u001b[49m\u001b[43mattn_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattn_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1386\u001b[0m \u001b[43m \u001b[49m\u001b[43maverage_attn_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maverage_attn_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1387\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_causal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_causal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1388\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1389\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_first \u001b[38;5;129;01mand\u001b[39;00m is_batched:\n\u001b[1;32m 1390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m attn_output\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m0\u001b[39m), attn_output_weights\n", "File \u001b[0;32m~/miniconda3/envs/ml/lib/python3.13/site-packages/torch/nn/functional.py:6278\u001b[0m, in \u001b[0;36mmulti_head_attention_forward\u001b[0;34m(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v, average_attn_weights, is_causal)\u001b[0m\n\u001b[1;32m 6275\u001b[0m k \u001b[38;5;241m=\u001b[39m k\u001b[38;5;241m.\u001b[39mview(bsz, num_heads, src_len, head_dim)\n\u001b[1;32m 6276\u001b[0m v \u001b[38;5;241m=\u001b[39m v\u001b[38;5;241m.\u001b[39mview(bsz, num_heads, 
src_len, head_dim)\n\u001b[0;32m-> 6278\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m \u001b[43mscaled_dot_product_attention\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 6279\u001b[0m \u001b[43m \u001b[49m\u001b[43mq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattn_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdropout_p\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mis_causal\u001b[49m\n\u001b[1;32m 6280\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 6281\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 6282\u001b[0m attn_output\u001b[38;5;241m.\u001b[39mpermute(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m3\u001b[39m)\u001b[38;5;241m.\u001b[39mcontiguous()\u001b[38;5;241m.\u001b[39mview(bsz \u001b[38;5;241m*\u001b[39m tgt_len, embed_dim)\n\u001b[1;32m 6283\u001b[0m )\n\u001b[1;32m 6285\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m linear(attn_output, out_proj_weight, out_proj_bias)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Train the model\n",
    "train_model(\n",
    "    model,\n",
    "    dataloader,\n",
    "    optimizer,\n",
    "    epochs=10,\n",
    "    scale_factor=1000,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Additional training: continues from the current weights and optimizer state.\n",
    "# train_model switches the model to training mode itself, so no separate call is needed.\n",
    "train_model(\n",
    "    model,\n",
    "    dataloader,\n",
    "    optimizer,\n",
    "    epochs=10,\n",
    "    scale_factor=1000,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}