How AES Works
3. Structure of AES
def matrix2bytes(matrix):
    """Convert a 4x4 matrix into a 16-byte array.

    Elements are emitted with the outer index varying slowest, i.e.
    ``matrix[0][0..3]`` first, then ``matrix[1][0..3]``, and so on.

    Args:
        matrix: An indexable holding four indexables of four integers each;
            every integer must be in ``range(256)`` for ``bytes`` to accept it.

    Returns:
        A ``bytes`` object of length 16.
    """
    return bytes([matrix[i][j] for i in range(4) for j in range(4)])
4. Round Keys
def add_round_key(s, k):
    """XOR a 4x4 state matrix with a 4x4 round-key matrix, element by element.

    Neither input is modified; a new list of lists is built.

    Args:
        s: The state, indexable as ``s[i][j]`` for ``i, j`` in ``range(4)``.
        k: The round key, indexable the same way as ``s``.

    Returns:
        A new 4x4 list of lists where entry ``[i][j]`` is ``s[i][j] ^ k[i][j]``.
    """
    return [[s[i][j] ^ k[i][j] for j in range(4)] for i in range(4)]
5. Confusion through Substitution
def sub_bytes(s, sbox=s_box):
    """Replace every element of the state with its entry in a lookup table.

    The input is not modified; a new list of lists is built.

    Args:
        s: An iterable of rows, each an iterable of values usable as
            indexes/keys into ``sbox``.
        sbox: The substitution table to index into. Defaults to the
            module-level ``s_box``; callers pass the inverse table when
            decrypting.

    Returns:
        A new list of lists with the same layout as ``s``, where each
        element ``b`` has been replaced by ``sbox[b]``.
    """
    return [[sbox[b] for b in x] for x in s]
6. Diffusion through Permutation
def inv_shift_rows(s):
    """Undo the ShiftRows step on a 4x4 state, modifying it in place.

    For second index ``r`` in 1, 2, 3, the four values ``s[0][r] .. s[3][r]``
    are rotated ``r`` positions towards higher first index, so that
    ``s[c][r]`` receives the old ``s[(c - r) % 4][r]``. Values at second
    index 0 are left untouched.

    Args:
        s: A mutable 4x4 state (list of four lists of four values).

    Returns:
        None. The state ``s`` is updated in place.
    """
    # Each tuple assignment evaluates its whole right-hand side before
    # storing, so the rotations need no temporary variables.
    s[0][1], s[1][1], s[2][1], s[3][1] = s[3][1], s[0][1], s[1][1], s[2][1]
    s[0][2], s[1][2], s[2][2], s[3][2] = s[2][2], s[3][2], s[0][2], s[1][2]
    s[0][3], s[1][3], s[2][3], s[3][3] = s[1][3], s[2][3], s[3][3], s[0][3]
7. Bringing it all Together
def decrypt(key, ciphertext):
    """Decrypt a single block by running the AES rounds in reverse.

    Args:
        key: The cipher key, passed unchanged to ``expand_key``.
        ciphertext: The encrypted block, passed unchanged to ``bytes2matrix``.

    Returns:
        The plaintext produced by ``matrix2bytes`` from the final state.
    """
    # Remember to start from the last round key and work backwards through
    # them when decrypting.
    # NOTE(review): assumes expand_key returns N_ROUNDS + 1 round keys, so
    # that round_keys[-1] is the key for round N_ROUNDS -- confirm.
    round_keys = expand_key(key)

    # Convert ciphertext to state matrix
    state = bytes2matrix(ciphertext)

    # Initial add round key step
    state = add_round_key(state, round_keys[-1])

    for i in range(N_ROUNDS - 1, 0, -1):
        # Do round: inv_shift_rows updates `state` in place, while sub_bytes
        # and add_round_key return new matrices that must be rebound.
        inv_shift_rows(state)
        state = sub_bytes(state, inv_s_box)
        state = add_round_key(state, round_keys[i])
        # Return value is unused, so inv_mix_columns presumably mutates
        # `state` in place -- verify against its definition.
        inv_mix_columns(state)

    # Run final round (skips the InvMixColumns step)
    inv_shift_rows(state)
    state = sub_bytes(state, inv_s_box)
    state = add_round_key(state, round_keys[0])

    # Convert state matrix to plaintext
    plaintext = matrix2bytes(state)
    return plaintext