init

249	parker-squares.ipynb	Normal file
@@ -0,0 +1,249 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader, Dataset\n",
"import numpy as np\n",
"import csv"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Dummy dataset class\n",
"class DummyDataset(Dataset):\n",
" def __init__(self, max=100, size=1000, seq_len=9): # Ensure seq_len = 9 for a 3x3 grid\n",
|
||||
" self.data = torch.randint(1, max, (size, seq_len)) # Generate possible grid values\n",
|
||||
" \n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.data)\n",
|
||||
" \n",
|
||||
" def __getitem__(self, idx):\n",
|
||||
" return self.data[idx]"
|
||||
]
|
||||
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Transformer model definition\n",
"class TransformerModel(nn.Module):\n",
"    def __init__(self, input_dim, embed_dim, num_heads, ff_dim, num_layers, output_dim):\n",
"        super(TransformerModel, self).__init__()\n",
"        self.embedding = nn.Embedding(input_dim, embed_dim)\n",
"        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim, batch_first=True) # batch_first so (batch, seq, embed) inputs are attended over the sequence, not the batch\n",
"        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)\n",
"        self.fc = nn.Linear(embed_dim, output_dim)\n",
"        \n",
"    def forward(self, x):\n",
"        x = self.embedding(x)\n",
"        x = self.transformer_encoder(x)\n",
"        x = x.mean(dim=1) # Global average pooling over the 9 grid positions\n",
"        return torch.exp(self.fc(x)) # exp keeps every output strictly positive\n"
]
},
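{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal shape check (illustrative sketch on a throwaway 2-layer model, not the model trained below): with `batch_first=True`, input token ids of shape `(batch, 9)` are embedded to `(batch, 9, embed_dim)`, attended over the 9 grid positions, mean-pooled to `(batch, embed_dim)`, and projected to 9 strictly positive outputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"_tmp_model = TransformerModel(input_dim=1000, embed_dim=32, num_heads=4, ff_dim=64, num_layers=2, output_dim=9)\n",
"_x = torch.randint(1, 1000, (4, 9)) # a batch of 4 sequences of 9 tokens\n",
"_y = _tmp_model(_x)\n",
"print(_y.shape, bool((_y > 0).all())) # expect: torch.Size([4, 9]) True"
]
},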
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
|
||||
"# Model initialization\n",
|
||||
"input_dim = 1000 # Vocabulary size (digits 0-8)\n",
|
||||
"embed_dim = 32\n",
|
||||
"num_heads = 4\n",
|
||||
"ff_dim = 64\n",
|
||||
"num_layers = 100\n",
|
||||
"output_dim = 9 # Output 9 numbers for the 3x3 grid\n",
|
||||
"\n",
|
||||
"model = TransformerModel(input_dim, embed_dim, num_heads, ff_dim, num_layers, output_dim)"
|
||||
]
|
||||
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Training setup\n",
"dataset = DummyDataset(max_val=input_dim)\n",
"dataloader = DataLoader(dataset, batch_size=32, shuffle=True)\n",
"optimizer = optim.Adam(model.parameters(), lr=0.001)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Custom loss function with scaling\n",
"def magic_square_loss(output, scale_factor=100):\n",
"    scaled_output = output * scale_factor\n",
"    grid = scaled_output.view(-1, 3, 3) # Reshape to 3x3 grid\n",
"    row_sums = torch.sum(grid ** 2, dim=2) # sum of squared entries along each row\n",
"    col_sums = torch.sum(grid ** 2, dim=1) # sum of squared entries along each column\n",
"    diag1_sum = torch.sum(torch.diagonal(grid, dim1=1, dim2=2) ** 2, dim=1)\n",
"    diag2_sum = torch.sum(torch.diagonal(torch.flip(grid, dims=[2]), dim1=1, dim2=2) ** 2, dim=1)\n",
"    all_sums = torch.cat([row_sums, col_sums, diag1_sum.unsqueeze(1), diag2_sum.unsqueeze(1)], dim=1)\n",
"    loss = torch.var(all_sums.float(), dim=1).mean() # Ensure float dtype for variance calculation\n",
"    return loss / (scale_factor ** 2) # Scale back\n"
]
},
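{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of `magic_square_loss` (an added sketch, assuming the loss is applied to a flat batch of 9 values, as in training). Matt Parker's near-miss square has every row, every column, and the main diagonal of squared entries summing to 3051, while the anti-diagonal sums to 4107, so its loss must be nonzero; a constant grid makes all eight sums equal, so its loss is exactly zero."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Matt Parker's near-miss \"magic square of squares\"\n",
"parker = torch.tensor([[29., 1., 47., 41., 37., 1., 23., 41., 29.]])\n",
"print(magic_square_loss(parker, scale_factor=1)) # nonzero: the anti-diagonal sum (4107) disagrees with the rest (3051)\n",
"\n",
"# A constant grid yields eight identical sums, hence zero variance\n",
"constant = torch.full((1, 9), 5.0)\n",
"print(magic_square_loss(constant, scale_factor=1)) # tensor(0.)"
]
},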
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Function to save outputs to file with computed sums and check for magic squares\n",
"def save_outputs(epoch, batch_idx, outputs, scale_factor=100):\n",
"    with open(\"training_outputs.csv\", \"a\", newline=\"\") as f:\n",
"        writer = csv.writer(f)\n",
"        for output in outputs:\n",
"            scaled_output = torch.round(output * scale_factor).int().view(3, 3).detach().cpu().numpy()\n",
"            row_sums = np.sum(scaled_output ** 2, axis=1).tolist()\n",
"            col_sums = np.sum(scaled_output ** 2, axis=0).tolist()\n",
"            diag1_sum = np.sum(np.diagonal(scaled_output) ** 2)\n",
"            diag2_sum = np.sum(np.diagonal(np.fliplr(scaled_output)) ** 2)\n",
"            writer.writerow([epoch, batch_idx] + scaled_output.flatten().tolist() + row_sums + col_sums + [diag1_sum, diag2_sum])\n",
"            #print(f\"Sums: {row_sums + col_sums + [diag1_sum, diag2_sum]} {set(row_sums + col_sums + [diag1_sum, diag2_sum])}\")\n",
"            # Check if a magic square is found\n",
"            if len(set(row_sums + col_sums + [diag1_sum, diag2_sum])) == 1:\n",
"                print(\"MAGIC SQUARE FOUND!\")\n",
"                print(scaled_output, epoch, batch_idx)\n",
"            elif len(set(row_sums + col_sums + [diag1_sum, diag2_sum])) == 2:\n",
"                print(\"PARKER SQUARE FOUND!\")\n",
"                print(scaled_output, epoch, batch_idx)\n",
"            # elif len(set(row_sums + col_sums + [diag1_sum, diag2_sum])) < 8:\n",
"            #     print(\"AT LEAST 2 NUMBERS MATCHED!!!! POG\")\n",
"            #     print(scaled_output, epoch, batch_idx)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Function to check outputs for magic and Parker squares without writing to file\n",
"def check_outputs(epoch, batch_idx, outputs, scale_factor=100):\n",
"    for output in outputs:\n",
"        scaled_output = torch.round(output * scale_factor).int().view(3, 3).detach().cpu().numpy()\n",
"        row_sums = np.sum(scaled_output ** 2, axis=1).tolist()\n",
"        col_sums = np.sum(scaled_output ** 2, axis=0).tolist()\n",
"        diag1_sum = np.sum(np.diagonal(scaled_output) ** 2)\n",
"        diag2_sum = np.sum(np.diagonal(np.fliplr(scaled_output)) ** 2)\n",
"        \n",
"        #print(f\"Sums: {row_sums + col_sums + [diag1_sum, diag2_sum]} {set(row_sums + col_sums + [diag1_sum, diag2_sum])}\")\n",
"        # Check if a magic square is found\n",
"        if len(set(row_sums + col_sums + [diag1_sum, diag2_sum])) == 1:\n",
"            print(\"MAGIC SQUARE FOUND!\")\n",
"            print(scaled_output, epoch, batch_idx)\n",
"        elif len(set(row_sums + col_sums + [diag1_sum, diag2_sum])) == 2:\n",
"            print(\"PARKER SQUARE FOUND!\")\n",
"            print(scaled_output, epoch, batch_idx)\n",
"        # elif len(set(row_sums + col_sums + [diag1_sum, diag2_sum])) < 8:\n",
"        #     print(\"AT LEAST 2 NUMBERS MATCHED!!!! POG\")\n",
"        #     print(scaled_output, epoch, batch_idx)"
]
},
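{
"cell_type": "markdown",
"metadata": {},
"source": [
"Illustrative check of the detector (not part of the training loop): with `scale_factor=1` the rounded integers are just the raw entries, and the Parker square produces exactly two distinct sums (3051 and 4107), so the `len(set(...)) == 2` branch should print `PARKER SQUARE FOUND!`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"parker = torch.tensor([[29., 1., 47., 41., 37., 1., 23., 41., 29.]])\n",
"check_outputs(epoch=0, batch_idx=0, outputs=parker, scale_factor=1)"
]
},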
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def train_model(model, dataloader, optimizer, epochs=5, scale_factor=100):\n",
"    model.train()\n",
"    with open(\"training_outputs.csv\", \"w\", newline=\"\") as f:\n",
"        writer = csv.writer(f)\n",
"        writer.writerow([\"Epoch\", \"Batch\"] + [f\"Pos_{i}\" for i in range(9)] + [f\"Row_{i}\" for i in range(3)] + [f\"Col_{i}\" for i in range(3)] + [\"Diag1\", \"Diag2\"])\n",
"    \n",
"    for epoch in range(epochs):\n",
"        total_loss = 0\n",
"        for batch_idx, batch in enumerate(dataloader):\n",
"            optimizer.zero_grad()\n",
"            outputs = model(batch)\n",
"            loss = magic_square_loss(outputs, scale_factor)\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"            total_loss += loss.item()\n",
"            #save_outputs(epoch, batch_idx, outputs, scale_factor)\n",
"            check_outputs(epoch, batch_idx, outputs, scale_factor)\n",
"        print(f\"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}\")\n"
]
},
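{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note on loss magnitudes: `magic_square_loss` multiplies the outputs by `scale_factor` before squaring, so each of the eight sums grows like `scale_factor**2` and their variance like `scale_factor**4`; the final division by `scale_factor**2` still leaves a loss that grows like `scale_factor**2`. With `scale_factor=10000000` below, raw losses on the order of 1e10 and above are expected and do not by themselves indicate divergence."
]
},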
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, Loss: 16573101137920.0000\n",
"Epoch 2, Loss: 277233013760.0000\n",
"Epoch 3, Loss: 122353267200.0000\n",
"Epoch 4, Loss: 106233198080.0000\n",
"Epoch 5, Loss: 101196775808.0000\n",
"Epoch 6, Loss: 98463991808.0000\n",
"Epoch 7, Loss: 88858541056.0000\n"
]
}
],
"source": [
"# Train the model\n",
"train_model(model, dataloader, optimizer, epochs=10, scale_factor=10000000)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ml",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
1	training_outputs.csv	Normal file
@@ -0,0 +1 @@
Epoch,Batch,Pos_0,Pos_1,Pos_2,Pos_3,Pos_4,Pos_5,Pos_6,Pos_7,Pos_8,Row_0,Row_1,Row_2,Col_0,Col_1,Col_2,Diag1,Diag2