Sometimes we want to reduce a tensor based on the value at one particular index. I had that specific problem today. TensorFlow's reduce operations are not terribly flexible: they only reduce along an entire axis, indiscriminately. For example, let's say that I have a Q-table (in reinforcement learning, a Q-table lists, for each state of the environment and each possible action in that state, the value of taking that action in that state):
q = tf.constant([[
Here the first axis represents the possible states (in this case, only 2 states), and the second axis holds a list of action-value "pairs" for each state. Let's say that, for whatever reason, I now need to return, for each state, the action-value pair that has the maximum value.
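The plain-Python sketch below pins down exactly what we're after. The q values are hypothetical, because the original definition is cut off. They were chosen so that they agree with the outputs shown in this post. `best_pair_per_state` is a made-up helper name, and each pair is assumed to be laid out as [action, value].

```python
# Hypothetical Q-table: the original q is truncated, so these numbers are made up
# to agree with the outputs in this post. Layout: [state][pair][0 = action, 1 = value]
q = [
    [[5.0, 2.0], [4.0, 7.8]],   # state 0
    [[1.5, 8.9], [3.5, 1.0]],   # state 1
]

def best_pair_per_state(table, index=1):
    """Hypothetical helper: for each state, return the whole pair whose
    entry at `index` is the largest."""
    return [max(pairs, key=lambda pair: pair[index]) for pairs in table]

print(best_pair_per_state(q))  # [[4.0, 7.8], [1.5, 8.9]]
```

The whole pair is kept, so the action and its value always stay together.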
As far as I know (and if I'm wrong about this, please save me from unnecessary code complexity and let me know in the comment section below), TensorFlow won't let you do this in a single operation. Instead, you'll have to use a "pattern" of operations that I will describe shortly. The same pattern can be followed for any similar index-specific reduce operation you need (be it a min, max, any, all, sum, mean, etc.).
So, here is what happens when I use tf.reduce_max:
reduced_q = tf.reduce_max(q, 1)
output = sess.run(reduced_q)
[[ 5. 7.80000019]
[ 3.5 8.89999962]]
In other words, for each state it took the maximum of index 0 across all pairs and, separately, the maximum of index 1 across all pairs. Each output row can therefore mix values from different pairs. This is not what we want: we want the single pair whose index-1 value is largest, returned whole.
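The NumPy stand-in below shows the problem concretely. Here np.max with axis=1 plays the role of tf.reduce_max(q, 1). The q values are hypothetical, because the original definition is cut off. They were picked so the result matches the output above, where 7.80000019 and 8.89999962 are simply 7.8 and 8.9 stored as float32.

```python
import numpy as np

# Hypothetical Q-table, shape [2 states, 2 pairs, 2]; values are made up.
q = np.array([
    [[5.0, 2.0], [4.0, 7.8]],
    [[1.5, 8.9], [3.5, 1.0]],
])

# Like tf.reduce_max(q, 1): index 0 and index 1 are maximized independently.
reduced = np.max(q, axis=1)
print(reduced.tolist())  # [[5.0, 7.8], [3.5, 8.9]]

# The row for state 0 is not one of state 0's actual pairs.
print(any((pair == reduced[0]).all() for pair in q[0]))  # False
```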
Instead, here is the solution I came up with. First, let's address the 2-dimensional case:
q_2d = [
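# Keep only column 1 (the values), then take the row index of the largest one
max_indices = tf.argmax(tf.slice(q_2d, [0, 1], [-1, 1]), 0)
# Fetch the whole winning pair by that row index
max_entries = tf.gather(q_2d, max_indices)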
First, we slice out column 1 of the tensor, so that only the values we want to maximize remain. We then apply tf.argmax, which returns the row index of the pair holding the largest value. Finally, we use tf.gather to fetch the full pair at that index.
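If you want to experiment with the index logic without a TensorFlow session, the same three steps can be sketched in NumPy. This swaps in a different library: basic slicing, np.argmax and np.take stand in for tf.slice, tf.argmax and tf.gather. The q_2d values are hypothetical, because the original is cut off.

```python
import numpy as np

# Hypothetical 2-d data: each row is an [action, value] pair (values made up).
q_2d = np.array([
    [5.0, 2.0],
    [4.0, 7.8],
    [1.5, 3.0],
])

# Step 1, like tf.slice(q_2d, [0, 1], [-1, 1]): column 1 only, shape (3, 1).
values = q_2d[:, 1:2]

# Step 2, like tf.argmax(..., 0): row index of the largest value, shape (1,).
max_indices = np.argmax(values, axis=0)

# Step 3, like tf.gather(q_2d, max_indices): the whole winning row.
max_entries = np.take(q_2d, max_indices, axis=0)

print(max_indices.tolist())  # [1]
print(max_entries.tolist())  # [[4.0, 7.8]]
```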
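This works correctly for 2d tensors. However, if you apply the same approach to the original 3d q tensor, it does not give us one pair per state.
To fix this, we first need to change the code to use tf.gather_nd (the "N-dimensional" version of tf.gather). Even then, the indices returned by tf.argmax do not correctly index the entries we want from tf.gather_nd. Taking the argmax over axis 1 gives one pair index per state: a tensor of shape [2, 1] of the form [[i0], [i1]], where i0 and i1 are positions within state 0 and state 1. If that tensor is passed to tf.gather_nd as is, each of those indices is read as an index into the first (state) axis.
What we need instead is full [state, pair] coordinates, one row per state: [[0, i0], [1, i1]].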
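The NumPy sketch below shows that mismatch with concrete numbers. The q values are hypothetical, chosen to agree with the outputs in this post, and NumPy indexing stands in for tf.gather_nd.

```python
import numpy as np

# Hypothetical Q-table, shape [2 states, 2 pairs, 2]; values are made up.
q = np.array([
    [[5.0, 2.0], [4.0, 7.8]],
    [[1.5, 8.9], [3.5, 1.0]],
])

# Like tf.argmax(tf.slice(q, [0, 0, 1], [-1, -1, 1]), 1): shape (2, 1),
# one pair index per state.
max_indices = np.argmax(q[:, :, 1:2], axis=1)
print(max_indices.tolist())  # [[1], [0]]

# Read as gather_nd-style indices, [1] and [0] select whole *states*, not pairs.
wrong = q[max_indices[:, 0]]
print(wrong.shape)  # (2, 2, 2)

# What gather_nd actually needs: [state, pair] coordinates, one row per state.
needed = [[state, pair] for state, (pair,) in enumerate(max_indices.tolist())]
print(needed)  # [[0, 1], [1, 0]]
```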
So we will have to do some extra manipulation:
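# For each state, the index of the pair whose value (index 1) is largest; shape [2, 1]
max_indices = tf.argmax(tf.slice(q, [0, 0, 1], [-1, -1, 1]), 1)
# One row index per state; int64 to match the dtype tf.argmax returns
inc_tensor = tf.constant([[0], [1]], dtype=tf.int64)
# [state, pair] coordinates: [[0, i0], [1, i1]]
final_indices = tf.concat((inc_tensor, max_indices), 1)
max_entries = tf.gather_nd(q, final_indices)
output = sess.run(max_entries)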
The inc_tensor constant holds one row index per state ([[0], [1]] here), so it will have to be adjusted on a case-by-case basis depending on the actual shape of your tensor. And we get:
[[ 4. 7.80000019]
[ 1.5 8.89999962]]
as desired. I hope that helps someone.
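The pattern generalizes. The NumPy sketch below defines a hypothetical helper, select_pairs, that builds the state indices from the tensor's shape instead of relying on a hand-written inc_tensor. It also takes the arg function as a parameter, so the min variant mentioned at the start only requires passing np.argmin instead of np.argmax. NumPy fancy indexing stands in for tf.gather_nd. In TensorFlow, the analogous move would presumably be to derive inc_tensor from tf.range; treat that as an untested assumption.

```python
import numpy as np

def select_pairs(q, index=1, arg_fn=np.argmax):
    """Hypothetical helper: for each state in q (shape [states, pairs, k]),
    return the whole pair chosen by arg_fn over column `index`."""
    # Pair index per state, shape (states, 1), like tf.argmax over axis 1.
    pair_idx = arg_fn(q[:, :, index:index + 1], axis=1)
    # State indices generated from the shape, replacing a hand-written inc_tensor.
    state_idx = np.arange(q.shape[0]).reshape(-1, 1)
    # [[state, pair], ...] coordinates, like tf.concat((inc_tensor, max_indices), 1).
    coords = np.concatenate([state_idx, pair_idx], axis=1)
    # Equivalent of tf.gather_nd(q, coords) for 2-column coordinates.
    return q[coords[:, 0], coords[:, 1]]

# Hypothetical Q-table; values are made up.
q = np.array([
    [[5.0, 2.0], [4.0, 7.8]],
    [[1.5, 8.9], [3.5, 1.0]],
])
print(select_pairs(q).tolist())                    # [[4.0, 7.8], [1.5, 8.9]]
print(select_pairs(q, arg_fn=np.argmin).tolist())  # [[5.0, 2.0], [3.5, 1.0]]
```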