Learning to solve example 1 of puzzle 3aa6fb7a in the ARC Prize
Hi dear reader,
Puzzle rules
Example 1 of puzzle 3aa6fb7a of the ARC Prize looks like this:
A playout starts with a blank board of width 7 and height 7, so there are 49 cells. Each cell can be blank or can contain one of the 10 possible colors. The colors are identified from 0 to 9.
I added one rule: each cell has to be changed exactly once.
The game ends when all cells have been changed. The player wins if all the cell colors match the cell colors in the expected output.
Some numbers:
- Each cell has 11 possible values: 10 colors + blank state
- There are 49! (6.0828186e+62) distinct possible orders of cell changes in a playout.
- There are 49 * 10 = 490 possible actions for the first move.
- There are 11^49 (1.0671896e+51) possible board states.
- The number of possible pairs (s, a) is even larger and upper-bounded by 11^49 * 49 * 10 = 5.2292289e+53.
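These counts are easy to verify; here is a quick check in plain Python that just reproduces the arithmetic above:

import math

cells = 7 * 7            # 49 cells
colors = 10              # colors 0 to 9
values = colors + 1      # 10 colors + the blank state

print(math.factorial(cells))              # 49! ~ 6.0828e+62 orders of cell changes
print(cells * colors)                     # 490 possible first actions
print(values ** cells)                    # 11^49 ~ 1.0672e+51 board states
print(values ** cells * cells * colors)   # ~ 5.2292e+53 upper bound on (s, a) pairs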
Given a board state s and a proposed action a, an action value is calculated. The neural net is trained to predict this action value from the pair (s, a). See SARSA.
The research question is: can a neural net learn to predict the action value for any pair (s, a) of board state s and proposed action a, while training on only a very small fraction of the possible pairs (s, a)?
Playout simulation
The simulator generates a list of legal actions given the current state. These actions are shuffled to produce random playouts.
The simulator then tests all legal actions and selects the one with the maximal action value Q. That action is applied and a new state is obtained.
A single playout generates 12250 (s, a) samples. It takes about 3 hours to simulate 2048 playouts to generate a total of 25088000 simulated (s, a) samples for this example of this puzzle.
This number of training samples is very small compared to the search space (25088000 / 5.2292289e+53 ≈ 4.8e-47), but it is not zero nonetheless.
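To make the procedure concrete, here is a minimal sketch of such a playout loop. The board encoding, the compute_q callback, and every name in it are placeholders of mine, not the actual simulator:

import random

BLANK, SIZE, COLORS = 10, 7, 10   # placeholder encoding: value 10 means "blank"

def playout(compute_q):
    """One playout on a 7x7 board. A (state, action, q) sample is recorded for
    every legal action tested at every step: 10*(49+48+...+1) = 12250 samples."""
    state = [[BLANK] * SIZE for _ in range(SIZE)]
    blanks = [(r, c) for r in range(SIZE) for c in range(SIZE)]
    samples = []
    for _ in range(SIZE * SIZE):          # 49 steps: each cell is changed exactly once
        legal = [(r, c, color) for (r, c) in blanks for color in range(COLORS)]
        random.shuffle(legal)             # randomize tie-breaking between equal Q values
        scored = [(compute_q(state, a), a) for a in legal]
        samples.extend((state, a, q) for q, a in scored)
        _, (r, c, color) = max(scored, key=lambda t: t[0])
        state = [row[:] for row in state] # copy so earlier samples keep their state
        state[r][c] = color
        blanks.remove((r, c))
    return samples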
Sample model
The samples are stored in an HDF5 file.
The data model for the samples is the following:
DATATYPE H5T_COMPOUND {
H5T_ARRAY { [60] H5T_STD_U8LE } "input_state";
H5T_ARRAY { [7] H5T_STD_U8LE } "full_move_counter";
H5T_ARRAY { [60] H5T_STD_U8LE } "current_state";
H5T_ARRAY { [60] H5T_STD_U8LE } "action";
H5T_IEEE_F32LE "action_value";
}
Basically:
- "input_state" contains the input of the example.
- "full_move_counter" is the number of moves since the beginning of the playout.
- "current_state" is the current board state.
- "action" is the proposed action, which is a blank board with only the cell that changes in the changed color.
- "action_value" is the Q (quality) value of the proposed action a given the state s. See Q-Learning.
The neural net is given "input_state", "full_move_counter", "current_state", and "action" as 4 distinct input tensors.
I use just 1 epoch for training, which means that the neural net sees every sample only once. This can be thought of as a form of generalization, but not really, since each playout generates correlated (s, a) samples.
The file containing the samples has a size of 4.5G.
Neural network model
The model is a decoder-only transformer with a non-causal attention mechanism.
My neural network model is coded in PyTorch, with a mix of PyTorch built-ins and components from Meta's xformers.
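To give an idea of the shape of such a model, here is a rough sketch using only PyTorch built-ins. All dimensions, the token layout, and the use of nn.TransformerEncoder for non-causal attention are my own assumptions, not the exact model:

import torch
import torch.nn as nn

class ActionValueNet(nn.Module):
    """Sketch: embed the 4 input tensors as one token sequence, apply non-causal
    self-attention, and regress a single action value. Dimensions are illustrative."""
    def __init__(self, vocab=11, d_model=128, n_heads=4, n_layers=4):
        super().__init__()
        self.cell_embed = nn.Embedding(vocab, d_model)       # 10 colors + blank
        self.counter_embed = nn.Embedding(256, d_model)      # raw uint8 counter bytes
        self.pos_embed = nn.Embedding(3 * 60 + 7, d_model)   # one position per token
        layer = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, n_layers)  # no causal mask
        self.head = nn.Linear(d_model, 1)

    def forward(self, input_state, full_move_counter, current_state, action):
        # Each argument is an integer tensor of shape (batch, length).
        tokens = torch.cat([
            self.cell_embed(input_state.long()),
            self.cell_embed(current_state.long()),
            self.cell_embed(action.long()),
            self.counter_embed(full_move_counter.long()),
        ], dim=1)
        positions = torch.arange(tokens.size(1), device=tokens.device)
        tokens = tokens + self.pos_embed(positions)
        encoded = self.encoder(tokens)                      # attends over all tokens
        return self.head(encoded.mean(dim=1)).squeeze(-1)   # predicted action value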
The methodology is inspired in part by this work by Google DeepMind: Grandmaster-Level Chess Without Search. It's a very nice paper; you should check it out.
Train loss
It took 389m18.631s (about 6.5 h) to train the model on a pod on runpod.io. I used an NVIDIA A40.
My research and development for the ARC Prize is advancing slowly. My neural net is learning on example 1 of puzzle 3aa6fb7a.
Why does it work?
See the Bellman equation.
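For reference, the standard Bellman optimality equation behind Q-learning is:

Q(s, a) = r(s, a) + \gamma \max_{a'} Q(s', a')

where r(s, a) is the immediate reward, \gamma is the discount factor, and s' is the state reached after applying action a in state s. The exact reward and discount used for this puzzle are not detailed here, but the idea is that the value of an action can be bootstrapped from the value of the best follow-up action observed during the playouts.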
Next steps
To test whether or not it generalizes, I need to run inference with the neural net on samples that were not in the training set.
Also, instead of generating 25088000 samples from 2048 playouts, maybe it would be better to generate more playouts and randomly sample (s, a) pairs from them.
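If I go that route, one simple way to decorrelate the training data would be to draw a random subset of records across playouts, for example like this (file name, dataset name, and subset size are assumptions for illustration):

import h5py
import numpy as np

with h5py.File("samples.h5", "r") as f:
    ds = f["samples"]
    rng = np.random.default_rng(0)
    idx = rng.choice(ds.shape[0], size=1_000_000, replace=False)
    subset = ds[np.sort(idx)]  # h5py requires index selections in increasing order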