1. AES 介绍

AES(高级加密标准,Advanced Encryption Standard),在密码学中又称 Rijndael 加密法,是美国联邦政府采用的一种分组加密标准。这个标准用来替代原先的 DES,目前已经广为全世界所使用,成为对称密钥算法中最流行的算法之一。

鉴于这三种模式的算法在本质上没有区别,所以本文主要介绍 AES-128 (数据分组为16字节,秘钥长度为16字节,加密轮数为10轮) ,并给出 ECB 模式下的 C++ 实现。

2. AES算法流程

AES算法主要可以分为秘钥扩展、字节替换、行移位、列混合和轮秘钥加这5个步骤。

  • 秘钥扩展(KeyExpansions:给定的初始秘钥一般比较短,比如16字节,而算法如果进行10轮运算的话就需要16x(10+1)字节长度的秘钥,需要对原始秘钥进行秘钥扩展。
  • 字节替换(SubBytes):一个非线性的替换步骤,根据查表把一个字节替换为另一个字节。
  • 行移位(ShiftRows):将数据矩阵的每一行循环移位一定长度。
  • 列混淆(MixColumns):将数据矩阵乘以一个固定的矩阵,增加混淆程度。
  • 轮秘钥加(AddRoundKey):将数据矩阵与秘钥矩阵进行异或操作。

具体的加解密流程可以见下图:
AES.jpg

3. AES 算法步骤

3.1 前提

AES 加密都是以一个 4x4 的状态矩阵为单位加密的。例如,将一串十六字节的字符串 0123456789abcdef 放入状态矩阵如下:

$$ \left[ \begin{matrix} 0 & 4 & 8 & c\\ 1 & 5 & 9 & d\\ 2 & 6 & a & e\\ 3 & 7 & b & f \end{matrix} \right] $$

注意这里的排列的顺序是竖排而不是横排

3.2 密钥扩展

根据上面的流程图可以看出,AES 原始密钥是 16 个字节的, 而进行一次 AES 加密需要 11 个矩阵大小的密钥,所以多出来的密钥肯定要通过一系列运算求出来。下面我们来讨论一下这些密钥是怎么来的。

首先,先给出 AES 密钥扩展的图解:
AES Key Expansion.png

由图可以看到,我们将密钥矩阵中每一列的 4 个字节都记为一个 bitset<32> w,这样就得到了 w[0] - w[3]。

得到了 w[0] - w[3] 之后,就可以根据它们计算出剩下的密钥了。计算方式如下:

W(4i)   = G(W(4i-1)) xor W(4i-4)
W(4i+1) = W(4i)      xor W(4i-3)
W(4i+2) = W(4i+1)    xor W(4i-2)
W(4i+3) = W(4i+2)    xor W(4i-1)
(其中 i = 1, 2, 3, …… , 10)

具体的流程这时候已经明确了。剩下的问题就是这个混淆函数 G 了。那么它做了什么呢?

这个 G 函数做了以下几件事:

  1. RotWord 循环左移: 将输入循环左移一个字节。如 输入0x12345678,输出0x34567812。
  2. SubWord 字节替代: S 盒替代。这个可以参考下面的字节替代
  3. Rcon 轮常量异或: 将输入的第一个字节和轮常量 Rcon 异或。

这个 Rcon 数组是通过计算算出来的,不过它不会变,所以可以直接当常量数组用。想了解它的算法可以参考这里

流程大概说完了,下面放出密钥扩展的代码实现吧:

// 轮常数,密钥扩展中用到。(AES-128只需要10轮)
byte Rcon[10] = { 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36 };

/**
 * \brief 密钥扩展, 生成轮密钥
 */
void AES::keyExpansion()
{
    // 最开始就是初始的 key
    for (auto i = 0; i < 4; i++)
        for (auto j = 0; j < 4; j++)
            subKey[0][i][j] = key[i][j];

    // 计算轮密钥
    for (auto round = 1; round <= 10; round++)
    {
        byte last[4]; // 前一个密钥

        for (auto i = 0; i < 4; i++)
                last[i] = subKey[round - 1][i][3];
        
        for (auto i = 0; i < 4; i++)
        {
            // 每轮密钥的前 4 byte 要进行自混淆
            if (i == 0)
                G(last, round - 1);

            // 计算轮密钥
            for (auto j = 0; j < 4; j++)
                subKey[round][j][i] = last[j] ^ subKey[round - 1][j][i];

            for (auto j = 0; j < 4; j++)
                last[j] = subKey[round][j][i];
        }
    }
}

/**
 * \brief 密钥自混淆用的 G 函数
 */
void AES::G(byte k[4], const int round)
{
    // 左移一字节
    cyclicShift(k, 1, true);

    // S 盒替换
    for (auto i = 0; i < 4; i++)
    {
        const auto row = k[i][7] * 8 + k[i][6] * 4 + k[i][5] * 2 + k[i][4];
        const auto column = k[i][3] * 8 + k[i][2] * 4 + k[i][1] * 2 + k[i][0];

        k[i] = S_BOX[row][column];
    }

    // 和 RC 常量进行异或
    k[0] ^= byte(Rcon[round]);
}

里面涉及到的 S 盒会在下面提到,可以将其当成一个常量数组。

3.3 字节替换 SubBytes

字节替换很简单,就是去 S 盒中查表就可以了。S 盒是一个 16 行 16 列的表,表中每个元素都是一个字节。具体流程就是函数 SubBytes() 接受一个 4x4 的字节矩阵作为输入,对其中的每个字节,前四位组成十六进制数 x 作为行号,后四位组成的十六进制数 y 作为列号,查找表中对应的值替换原来位置上的字节。

这里我将 S 盒定义为 16x16 的二维数组 S[16][16],字节替换时取该字节的高4位作为行下标,低4位作为列下标。这种方式因为还得对需要替换字节分别取高低位,所以会使字节替换操作复杂一点。可以采用 S[256] 一维数组省去这些操作,这样进行字节替换时就可以直接把该字节的值作为S盒数组的下标来进行替换。

/**
* S盒
*/
const int S_BOX[16][16] = {
    {0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76},
    {0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0},
    {0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15},
    {0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75},
    {0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84},
    {0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf},
    {0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8},
    {0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2},
    {0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73},
    {0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb},
    {0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79},
    {0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08},
    {0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a},
    {0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e},
    {0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf},
    {0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16}
};

解密时的 逆字节替换 就是用 逆S盒 进行字节替换。
逆S盒 如下:

/**
 * 逆S盒
 */
const int INVERSE_S_BOX[16][16] = {
    {0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb},
    {0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb},
    {0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e},
    {0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25},
    {0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92},
    {0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84},
    {0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06},
    {0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b},
    {0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73},
    {0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e},
    {0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b},
    {0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4},
    {0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f},
    {0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef},
    {0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61},
    {0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d}
};

S-BOX 也是通过计算算出来的,具体方法可以参考这里

3.4 行移位 ShiftRows

前面已经说过,AES运算都是基于4x4二维数组进行的。行移位操作为:第0行不移动,第1行循环左移1字节,第2行循环左移2字节,第3行循环左移3字节。

shiftRows.png

解密时逆行移位操作为:第0行不移动,第1行循环右移1字节,第2行循环右移2字节,第3行循环右移3字节。

/**
 * \brief 循环位移
 * \param shift 移动的位数
 * \param encrypt true 为左移, false 为右移
 */
void AES::cyclicShift(byte row[4], const int shift, const bool encrypt)
{
    byte tmp[4];

    for (auto i = 0; i < 4; i++)
    {
        if (encrypt)
            tmp[i] = row[(i + shift) % 4];
        else
            tmp[i] = row[(i - shift + 4) % 4];
    }

    for (auto i = 0; i < 4; i++)
        row[i] = tmp[i];
}

/**
 * \brief 行位移
 */
void AES::shiftRows(byte matrix[4][4], bool encrypt)
{
    // 第一行不需要位移
    for (auto i = 1; i < 4; i++)
        cyclicShift(matrix[i], i, encrypt);
}

3.5 列混淆 MixColumns

列混淆通过矩阵相乘来实现,经过移位后的矩阵左乘一个固定的矩阵,得到混淆后的矩阵,如下公式所示:
mixColumns.png

注意公式中用到的乘法是伽罗华域(GF,有限域)上的乘法,高级加密标准文档 fips-197 上有讲。

解密时逆列混淆操作和列混淆一样,只是左乘的矩阵使用如下矩阵。

$$ \left[ \begin{matrix} 0E & 0B & 0D & 09\\ 09 & 0E & 0B & 0D\\ 0D & 09 & 0E & 0B\\ 0B & 0D & 09 & 0E \end{matrix} \right] $$

可以验证此矩阵B是列混合使用矩阵A的逆矩阵,两者乘积为单位矩阵,即AB=BA=E。

/**
 * \brief 列混淆
 */
void AES::mixColumns(byte matrix[4][4], bool encrypt)
{
    for (auto i = 0; i < 4; i++)
    {
        byte tmp[4];

        for (auto j = 0; j < 4; j++)
            tmp[j] = matrix[j][i];

        // 列混淆的矩阵相乘, 这里直接用了展开后的式子
        if (encrypt)
        {
            matrix[0][i] = GFMul(0x02, tmp[0]) ^ GFMul(0x03, tmp[1]) ^ tmp[2] ^ tmp[3];
            matrix[1][i] = tmp[0] ^ GFMul(0x02, tmp[1]) ^ GFMul(0x03, tmp[2]) ^ tmp[3];
            matrix[2][i] = tmp[0] ^ tmp[1] ^ GFMul(0x02, tmp[2]) ^ GFMul(0x03, tmp[3]);
            matrix[3][i] = GFMul(0x03, tmp[0]) ^ tmp[1] ^ tmp[2] ^ GFMul(0x02, tmp[3]);        
        }
        else
        {
            matrix[0][i] = GFMul(0x0e, tmp[0]) ^ GFMul(0x0b, tmp[1]) ^ GFMul(0x0d, tmp[2]) ^ GFMul(0x09, tmp[3]);
            matrix[1][i] = GFMul(0x09, tmp[0]) ^ GFMul(0x0e, tmp[1]) ^ GFMul(0x0b, tmp[2]) ^ GFMul(0x0d, tmp[3]);
            matrix[2][i] = GFMul(0x0d, tmp[0]) ^ GFMul(0x09, tmp[1]) ^ GFMul(0x0e, tmp[2]) ^ GFMul(0x0b, tmp[3]);
            matrix[3][i] = GFMul(0x0b, tmp[0]) ^ GFMul(0x0d, tmp[1]) ^ GFMul(0x09, tmp[2]) ^ GFMul(0x0e, tmp[3]);
        }
    }
}

/**
 * \brief GF(2^8) 上的乘法
 */
byte AES::GFMul(byte u, byte v)
{
    byte res;

    for (auto i = 0; i < 8; i++)
    {
        if ((v & byte(1)) != 0)
            res ^= u;

        byte flag = u & byte(0x80);
        
        u <<= 1;

        if (flag != 0)
            u ^= 0x1b; // 模素多项式

        v >>= 1;
    }

    return res;
}

3.6 轮密钥加 AddRoundKey

扩展密钥只参与了这一步。根据当前加密的轮数,用w[]中的 4 个扩展密钥与矩阵的 4 个列进行按位异或。如下图:
AddRoundKey.png

这里由于我写的代码使用 w11[4] 这样的数组来存放密钥的,所以与图中有所不同。具体代码放在最后一起吧。

4. C++ 代码实现

下面实现的是 ECB 模式下的 AES-128 加密算法:

AES.h

#pragma once

#include <iostream>
#include <bitset>
#include <string>
#include <vector>

#define ECB 0 // 电子密码本模式(Electronic Codebook Book)
#define CBC 1 // 密码分组链接模式(Cipher Block Chaining)
#define CFB 2 // 密码反馈模式(Cipher FeedBack)
#define OFB 3 // 输出反馈模式(Output FeedBack)
#define CTR 4 // 计算器模式(Counter)

using std::bitset;
using std::string;
using std::vector;
using std::cout;
using std::endl;

typedef bitset<8> byte;
typedef bitset<32> word;

class AES
{
private:
    byte key[4][4];
    byte subKey[11][4][4];

    string OFB_stream; // OFB 加密模式用的流
    vector<string> CTR_stream; // CTR 加密模式用的流

    /**
    * S盒
    */
    const int S_BOX[16][16] = {
        {0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76},
        {0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0},
        {0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15},
        {0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75},
        {0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84},
        {0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf},
        {0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8},
        {0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2},
        {0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73},
        {0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb},
        {0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79},
        {0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08},
        {0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a},
        {0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e},
        {0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf},
        {0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16}
    };

    /**
     * 逆S盒
     */
    const int INVERSE_S_BOX[16][16] = {
        {0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb},
        {0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb},
        {0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e},
        {0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25},
        {0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92},
        {0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84},
        {0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06},
        {0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b},
        {0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73},
        {0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e},
        {0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b},
        {0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4},
        {0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f},
        {0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef},
        {0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61},
        {0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d}
    };

    // 轮常数,密钥扩展中用到。(AES-128只需要10轮)
    byte Rcon[10] = { 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36 };
    
/**********************************************************************/
/*                                                                    */
/*                              AES算法实现                           */
/*                                                                    */
/**********************************************************************/
    
    void keyExpansion();
    
    /* 轮加密所需的函数 */
    void subByte(byte matrix[4][4], bool encrypt = true) const;
    void shiftRows(byte matrix[4][4], bool encrypt = true);
    void mixColumns(byte matrix[4][4], bool encrypt = true);
    void addRoundKey(byte matrix[4][4], int round) const;

    /* 工具函数 */
    void cyclicShift(byte row[4], int shift, bool encrypt = true);
    static byte GFMul(byte u, byte v);
    void T(byte k[4], int round);
    static word byteToWord(byte k[4]);

protected:
    void _encrypt(byte plain[4][4]);
    void _decrypt(byte cipher[4][4]);

    string ECB_encryption(const string& plain);
    string ECB_decryption(const string& cipher);
    
    string CBC_encryption(const string& plain, const string& IV);
    string CBC_decryption(const string& cipher, const string& IV);
    
    string CFB_encryption(const string& plain, const string& IV, const int& segment_size = 8);
    string CFB_decryption(const string& cipher, const string& IV, const int& segment_size = 8);
    
    string OFB_encryption(const string& plain, const string& IV);
    string OFB_decryption(const string& cipher, const string& IV);
    
    string CTR_encryption(const string& plain, const string& nonce);
    string CTR_decryption(const string& cipher, const string& nonce);

public:
    AES();
    explicit AES(const string& k);
    ~AES();
    
    bool set_key(string k);

    static string padding(string str, const int& unit = 16);
    static string depadding(const string& str);

    string encrypt(const string& plain, int mode = ECB, const string& IV = "", const int& segment_size = 8);
    string decrypt(const string& cipher, int mode = ECB, const string& IV = "", const int& segment_size = 8);
};

/**
 * \brief string 转换为 hex
 */
inline string string2hex(const string& in)
{
    string res;

    for (auto i : in)
    {
        int tmp = static_cast<int>(i);

        int high = (tmp & 0xf0) >> 4;
        int low = (tmp & 0x0f);

        res += (high < 10 ? '0' + high : 'a' + high - 10);
        res += (low < 10 ? '0' + low : 'a' + low - 10);
    }

    return res;
}

inline string hex2string(const string& in)
{
    // 长度不是偶数说明有问题
    if (in.length() % 2 != 0)
        throw "Hex length error!";

    string res;

    for (auto i = 0; i < in.length(); i += 2)
    {
        int tmp = 0;

        tmp += (in[i] - '0' > 10) ? (in[i] - 'a' + 10) * 16 : (in[i] - '0') * 16;
        tmp += (in[i + 1] - '0' > 10) ? (in[i + 1] - 'a' + 10) : (in[i + 1] - '0');

        res += static_cast<char>(tmp);
    }

    return res;
}

/**
 * \brief 将 string 转换为 4x4 的矩阵
 */
inline void string2matrix(byte matrix[4][4], string plain)
{
    for (auto i = 0; i < 4; i++)
    {
        for (auto j = 0; j < 4; j++)
            matrix[j][i] = byte(plain[i * 4 + j]);
    }
}

/**
 * \brief 将 4x4 的矩阵转换为 string
 */
inline string matrix2string(byte matrix[4][4])
{
    string res;

    for (auto i = 0; i < 4; i++)
    {
        for (auto j = 0; j < 4; j++)
            res += static_cast<char>(matrix[j][i].to_ulong());
    }

    return res;
}

AES.cpp

#include "AES.h"

/**
 * \brief 密钥扩展, 生成轮密钥
 */
void AES::keyExpansion()
{
    // 最开始就是初始的 key
    for (auto i = 0; i < 4; i++)
        for (auto j = 0; j < 4; j++)
            subKey[0][i][j] = key[i][j];

    // 计算轮密钥
    for (auto round = 1; round <= 10; round++)
    {
        byte last[4]; // 前一个密钥

        for (auto i = 0; i < 4; i++)
                last[i] = subKey[round - 1][i][3];
        
        for (auto i = 0; i < 4; i++)
        {
            // 每轮密钥的前 4 byte 要进行自混淆
            if (i == 0)
                T(last, round - 1);

            // 计算轮密钥
            for (auto j = 0; j < 4; j++)
                subKey[round][j][i] = last[j] ^ subKey[round - 1][j][i];

            for (auto j = 0; j < 4; j++)
                last[j] = subKey[round][j][i];
        }
    }
}

/**
 * \brief 字节替换
 * \param plain 输入的原文
 * \param encrypt 标识是否加密
 */
void AES::subByte(byte matrix[4][4], bool encrypt) const
{
    for (auto i = 0; i < 4; i++)
    {
        for (auto j = 0; j < 4; j++)
        {
            // 计算出 S-BOX 中的行列
            const int row = matrix[i][j][7] * 8 + matrix[i][j][6] * 4 + matrix[i][j][5] * 2 + matrix[i][j][4];
            const int column = matrix[i][j][3] * 8 + matrix[i][j][2] * 4 + matrix[i][j][1] * 2 + matrix[i][j][0];

            if (encrypt)
                matrix[i][j] = S_BOX[row][column];
            else
                matrix[i][j] = INVERSE_S_BOX[row][column];
        }
    }
}

/**
 * \brief 循环位移
 * \param shift 移动的位数
 * \param encrypt true 为左移, false 为右移
 */
void AES::cyclicShift(byte row[4], const int shift, const bool encrypt)
{
    byte tmp[4];

    for (auto i = 0; i < 4; i++)
    {
        if (encrypt)
            tmp[i] = row[(i + shift) % 4];
        else
            tmp[i] = row[(i - shift + 4) % 4];
    }

    for (auto i = 0; i < 4; i++)
        row[i] = tmp[i];
}

/**
 * \brief 行位移
 */
void AES::shiftRows(byte matrix[4][4], bool encrypt)
{
    // 第一行不需要位移
    for (auto i = 1; i < 4; i++)
        cyclicShift(matrix[i], i, encrypt);
}

/**
 * \brief 列混淆
 */
void AES::mixColumns(byte matrix[4][4], bool encrypt)
{
    for (auto i = 0; i < 4; i++)
    {
        byte tmp[4];

        for (auto j = 0; j < 4; j++)
            tmp[j] = matrix[j][i];

        // 列混淆的矩阵相乘, 这里直接用了展开后的式子
        if (encrypt)
        {
            matrix[0][i] = GFMul(0x02, tmp[0]) ^ GFMul(0x03, tmp[1]) ^ tmp[2] ^ tmp[3];
            matrix[1][i] = tmp[0] ^ GFMul(0x02, tmp[1]) ^ GFMul(0x03, tmp[2]) ^ tmp[3];
            matrix[2][i] = tmp[0] ^ tmp[1] ^ GFMul(0x02, tmp[2]) ^ GFMul(0x03, tmp[3]);
            matrix[3][i] = GFMul(0x03, tmp[0]) ^ tmp[1] ^ tmp[2] ^ GFMul(0x02, tmp[3]);        
        }
        else
        {
            matrix[0][i] = GFMul(0x0e, tmp[0]) ^ GFMul(0x0b, tmp[1]) ^ GFMul(0x0d, tmp[2]) ^ GFMul(0x09, tmp[3]);
            matrix[1][i] = GFMul(0x09, tmp[0]) ^ GFMul(0x0e, tmp[1]) ^ GFMul(0x0b, tmp[2]) ^ GFMul(0x0d, tmp[3]);
            matrix[2][i] = GFMul(0x0d, tmp[0]) ^ GFMul(0x09, tmp[1]) ^ GFMul(0x0e, tmp[2]) ^ GFMul(0x0b, tmp[3]);
            matrix[3][i] = GFMul(0x0b, tmp[0]) ^ GFMul(0x0d, tmp[1]) ^ GFMul(0x09, tmp[2]) ^ GFMul(0x0e, tmp[3]);
        }
    }
}

/**
 * \brief 轮密钥加
 */
void AES::addRoundKey(byte matrix[4][4], const int round) const
{
    for (auto i = 0; i < 4; i++)
        for (auto j = 0; j < 4; j++)
            matrix[i][j] ^= subKey[round][i][j];
}

/**
 * \brief GF(2^8) 上的乘法
 */
byte AES::GFMul(byte u, byte v)
{
    byte res;

    for (auto i = 0; i < 8; i++)
    {
        if ((v & byte(1)) != 0)
            res ^= u;

        byte flag = u & byte(0x80);
        
        u <<= 1;

        if (flag != 0)
            u ^= 0x1b; // 模素多项式

        v >>= 1;
    }

    return res;
}

/**
 * \brief 密钥自混淆用的 T 函数
 */
void AES::T(byte k[4], const int round)
{
    // 左移一字节
    cyclicShift(k, 1, true);

    // S 盒替换
    for (auto i = 0; i < 4; i++)
    {
        const auto row = k[i][7] * 8 + k[i][6] * 4 + k[i][5] * 2 + k[i][4];
        const auto column = k[i][3] * 8 + k[i][2] * 4 + k[i][1] * 2 + k[i][0];

        k[i] = S_BOX[row][column];
    }

    // 和 RC 常量进行异或
    k[0] ^= byte(Rcon[round]);
}

/**
 * \brief 将 4 个 byte 类型转换为 word 类型
 */
word AES::byteToWord(byte k[4])
{
    word res(0x00000000);

    for (auto i = 0; i < 4; i++)
    {
        word tmp = k[i].to_ulong();
        tmp <<= 24 - 8 * i;
        res |= tmp;
    }

    return res;
}

/**
 * \brief padding 函数
 */
string AES::padding(string str, const int& unit)
{
    if (unit > BLOCK_SIZE)
        throw "Unit error!";
    
    const int n = unit - (str.length() % unit);
    const char pad = static_cast<char>(n);

    for (auto i = 0; i < n; i++)
        str += pad;

    return str;
}

/**
 * \brief 去除 padding
 */
string AES::depadding(const string& str)
{
    string res = str;
    const auto mark = str[str.length() - 1];

    for (auto i = 0; i < mark; i++)
        res.pop_back();

    return res;
}

/**
 * \brief 加密 128 位的数据
 */
void AES::_encrypt(byte plain[4][4])
{
    // 第一步: 轮密钥加
    addRoundKey(plain, 0);

    // 第二步: 开始轮加密
    for (auto round = 1; round <= 10; round++)
    {
        subByte(plain, true);
        shiftRows(plain, true);

        // 最后一轮不进行列混淆
        if (round != 10)
            mixColumns(plain, true); 
        
        addRoundKey(plain, round);
    }
}

/**
 * \brief 解密 128 位的数据
 */
void AES::_decrypt(byte cipher[4][4])
{
    // 第一步: 轮密钥加
    addRoundKey(cipher, 10);

    // 第二步: 开始轮解密
    for (auto round = 9; round >= 0; round--)
    {
        shiftRows(cipher, false);
        subByte(cipher, false);
        addRoundKey(cipher, round);

        // 最后一轮不进行列混淆
        if (round != 0)
            mixColumns(cipher, false);
    }
}

/**
 * \brief ECB 模式下的 AES 加密
 */
string AES::ECB_encryption(const string& plain)
{
    if (plain.length() % BLOCK_SIZE != 0)
        throw "Plain text's length must be a multiple of 16";
    
    string res;

    for (auto begin = 0; begin < plain.length(); begin += BLOCK_SIZE)
    {
        byte subText[4][4];

        string2matrix(subText, string(plain, begin, BLOCK_SIZE));

        _encrypt(subText);

        res += matrix2string(subText);
    }

    return res;
}

/**
 * \brief ECB 模式下的 AES 解密
 */
string AES::ECB_decryption(const string& cipher)
{
    if (cipher.length() % BLOCK_SIZE != 0)
        throw "Cipher text's length must be a multiple of 16";

    string res;

    for (auto begin = 0; begin < cipher.length(); begin += BLOCK_SIZE)
    {
        byte subText[4][4];

        string2matrix(subText, string(cipher, begin, BLOCK_SIZE));

        _decrypt(subText);

        res += matrix2string(subText);
    }

    return res;
}

/**
 * \brief CBC 模式下的 AES 加密
 */
string AES::CBC_encryption(const string& plain, const string& IV)
{
    if (plain.length() % BLOCK_SIZE != 0)
        throw "Plain text's length must be a multiple of 16";

    if (IV.length() != BLOCK_SIZE)
        throw "Initialization vector's length must be 128 bits!";

    string res;
    
    byte last_matrix[4][4]; // 需要被异或的矩阵
    string2matrix(last_matrix, IV); // 将 IV 转换为矩阵形式

    for (auto begin = 0; begin < plain.length(); begin += BLOCK_SIZE)
    {
        byte subText[4][4];
        string2matrix(subText, string(plain, begin,  BLOCK_SIZE));

        // 将当前需要加密的块与之前的密文块异或
        for (auto i = 0; i < 4; i++)
        {
            for (auto j = 0; j < 4; j++)
                subText[i][j] ^= last_matrix[i][j];
        }
        
        _encrypt(subText);

        res += matrix2string(subText);

        // 更新需要异或的矩阵
        for (auto i = 0; i < 4; i++)
        {
            for (auto j = 0; j < 4; j++)
                last_matrix[i][j] = subText[i][j];
        }
    }

    return res;
}

/**
 * \brief CBC 模式下的 AES 解密
 */
string AES::CBC_decryption(const string& cipher, const string& IV)
{
    if (cipher.length() % BLOCK_SIZE != 0)
        throw "Cipher text's length must be a multiple of 16";

    if (IV.length() != BLOCK_SIZE)
        throw "Initialization vector's length must be 128 bits!";

    string res;
    
    byte last_matrix[4][4]; // 需要被异或的矩阵
    string2matrix(last_matrix, IV); // 将 IV 转换为矩阵形式

    for (auto begin = 0; begin < cipher.length(); begin += BLOCK_SIZE)
    {
        byte subText[4][4];
        string2matrix(subText, string(cipher, begin, BLOCK_SIZE));

        _decrypt(subText);

        // 将当前需要解密的块与之前的密文块异或
        for (auto i = 0; i < 4; i++)
        {
            for (auto j = 0; j < 4; j++)
                subText[i][j] ^= last_matrix[i][j];
        }

        res += matrix2string(subText);
        
        // 更新需要异或的矩阵
        string2matrix(last_matrix, string(cipher, begin, BLOCK_SIZE));
    }

    return res;
}

/**
 * \brief CFB 模式下的 AES 加密
 * \param segment_size 每次加密的长度, 单位是位
 */
string AES::CFB_encryption(const string& plain, const string& IV, const int& segment_size)
{
    if (segment_size % 8 != 0)
        throw "Segment size must be a multiple of 8!";

    if (plain.length() % BLOCK_SIZE != 0)
        throw "Plain text's length must be a multiple of 16";

    if (IV.length() != BLOCK_SIZE)
        throw "Initialization vector's length must be 128 bits!";

    const int real_size = segment_size / 8;
    string shift_register = IV; // 模拟移位寄存器
    string res;

    for (auto begin = 0; begin < plain.length(); begin += real_size)
    {
        // 对移位器中数据进行加密
        byte subText[4][4];
        string2matrix(subText, shift_register);
        _encrypt(subText);
        // 注意这里不是直接保存在移位寄存器中的
        string encrypted_buffer = matrix2string(subText);

        string tmp;

        // 从Encrypted Buffer左侧取出Real Size个字节,与长度为Real Size的Plain Block进行异或操作
        for (auto i = 0; i < real_size; i++)
            tmp += plain[begin + i] ^ encrypted_buffer[i];

        res += tmp;

        // 将移位器中的数据左移Real Size个字节
        shift_register = string(shift_register, real_size);
        
        // 将前面的加密结果从右侧移入寄存器
        shift_register += tmp;
    }

    return res;
}

/**
 * \brief CFB 模式下的 AES 解密
 * \param segment_size 每次加密的长度, 单位是位
 */
string AES::CFB_decryption(const string& cipher, const string& IV, const int& segment_size)
{
    if (segment_size % 8 != 0)
        throw "Segment size must be a multiple of 8!";

    if (cipher.length() % BLOCK_SIZE != 0)
        throw "Cipher text's length must be a multiple of 16";

    if (IV.length() != BLOCK_SIZE)
        throw "Initialization vector's length must be 128 bits!";

    const int real_size = segment_size / 8;
    string shift_register = IV; // 模拟移位寄存器
    string res;

    for (auto begin = 0; begin < cipher.length(); begin += real_size)
    {
        // 对移位器中数据进行加密
        byte subText[4][4];
        string2matrix(subText, shift_register);
        _encrypt(subText);
        // 注意这里不是直接保存在移位寄存器中的
        string encrypted_buffer = matrix2string(subText);

        string tmp;

        // 从Encrypted Buffer左侧取出Real Size个字节,与长度为Real Size的Plain Block进行异或操作
        for (auto i = 0; i < real_size; i++)
            tmp += cipher[begin + i] ^ encrypted_buffer[i];

        res += tmp;

        // 将移位器中的数据左移Real Size个字节
        shift_register = string(shift_register, real_size);

        // 将密文从右侧移入寄存器
        shift_register += string(cipher, begin, real_size);
    }

    return res;
}

/**
 * \brief OFB 模式下的 AES 加密
 */
string AES::OFB_encryption(const string& plain, const string& IV)
{
    if (IV.length() != BLOCK_SIZE)
        throw "Initialization vector's length must be 128 bits!";

    string res;
    
    // 流为空代表之前没有使用过 OFB 加密模式, 需设置 IV
    if (OFB_stream.empty())
        OFB_stream = ECB_encryption(IV);

    // 如果明文长度小于等于流的长度则直接加密
    if (plain.length() <= OFB_stream.length())
    {
        for (auto i = 0; i < plain.length(); i++)
            res += plain[i] ^ OFB_stream[i];

        return res;
    }

    // 如果明文长度大于流的长度则继续生成流
    for (auto i = OFB_stream.length(); i < plain.length(); i += BLOCK_SIZE)
        OFB_stream += ECB_encryption(string(OFB_stream, i - BLOCK_SIZE, BLOCK_SIZE));

    for (auto i = 0; i < plain.length(); i++)
        res += plain[i] ^ OFB_stream[i];

    return res;
}

/**
 * \brief OFB 模式下的 AES 解密。
 * 由于明文和密文只在最终的异或过程中使用, 故加密与解密是对称的
 */
string AES::OFB_decryption(const string& cipher, const string& IV)
{
    if (IV.length() != BLOCK_SIZE)
        throw "Initialization vector's length must be 128 bits!";

    string res;

    // 流为空代表之前没有使用过 OFB 加密模式, 需设置 IV
    if (OFB_stream.empty())
        OFB_stream = ECB_encryption(IV);

    // 如果密文长度小于等于流的长度则直接解密
    if (cipher.length() <= OFB_stream.length())
    {
        for (auto i = 0; i < cipher.length(); i++)
            res += cipher[i] ^ OFB_stream[i];

        return res;
    }

    // 如果密文长度大于流的长度则继续生成流
    for (auto i = OFB_stream.length() - BLOCK_SIZE; i < cipher.length(); i += BLOCK_SIZE)
        OFB_stream += ECB_encryption(string(OFB_stream, i, BLOCK_SIZE));

    for (auto i = 0; i < cipher.length(); i++)
        res += cipher[i] ^ OFB_stream[i];

    return res;
}


/**
 * \brief CTR 模式下的 AES 加密
 */
string AES::CTR_encryption(const string& plain, const string& nonce)
{
    if (nonce.length() >= BLOCK_SIZE)
        throw "Nonce's length must be less than 128 bits!";

    string res;

    // 如果明文长度小于等于流的长度则直接加密
    if (plain.length() <= OFB_stream.length())
    {
        for (auto i = 0; i < plain.length() / BLOCK_SIZE; i++)
        {
            for (auto j = 0; j < BLOCK_SIZE; j++)
                res += plain[i * BLOCK_SIZE + j] ^ CTR_stream[i][j];
        }

        return res;
    }

    // 如果明文长度大于流的长度则继续生成流
    for (auto i = CTR_stream.size(); i < plain.length() / BLOCK_SIZE; i++)
    {
        // 计数器的值即为 vector 下标的 hex 值
        string counter = string2hex(std::to_string(CTR_stream.size()));
        int counter_len = BLOCK_SIZE - nonce.length();

        string input; // 需被加密的输入

        if (counter.length() > counter_len)
            input = nonce + string(counter, counter.length() - counter_len);
        else
        {
            input = nonce;

            // 补零
            for (auto i = 0; i < counter_len - counter.length(); i++)
                input += static_cast<char>(0);

            input += counter;
        }

        CTR_stream.push_back(ECB_encryption(input));
    }

    for (auto i = 0; i < plain.length() / BLOCK_SIZE; i++)
    {
        for (auto j = 0; j < BLOCK_SIZE; j++)
            res += plain[i * BLOCK_SIZE + j] ^ CTR_stream[i][j];
    }

    return res;
}

/**
 * \brief CTR 模式下的 AES 解密
 */
string AES::CTR_decryption(const string& cipher, const string& nonce)
{
    if (nonce.length() >= BLOCK_SIZE)
        throw "Nonce's length must be less than 128 bits!";

    string res;

    // 如果明文长度小于等于流的长度则直接加密
    if (cipher.length() <= OFB_stream.length())
    {
        for (auto i = 0; i < cipher.length() / BLOCK_SIZE; i++)
        {
            for (auto j = 0; j < BLOCK_SIZE; j++)
                res += cipher[i * BLOCK_SIZE + j] ^ CTR_stream[i][j];
        }

        return res;
    }

    // 如果明文长度大于流的长度则继续生成流
    for (auto i = CTR_stream.size(); i < cipher.length() / BLOCK_SIZE; i++)
    {
        // 计数器的值即为 vector 下标的 hex 值
        string counter = string2hex(std::to_string(CTR_stream.size()));
        int counter_len = BLOCK_SIZE - nonce.length();

        string input; // 需被加密的输入

        if (counter.length() > counter_len)
            input = nonce + string(counter, counter.length() - counter_len);
        else
        {
            input = nonce;

            // 补零
            for (auto i = 0; i < counter_len - counter.length(); i++)
                input += static_cast<char>(0);

            input += counter;
        }

        CTR_stream.push_back(ECB_encryption(input));
    }

    for (auto i = 0; i < cipher.length() / BLOCK_SIZE; i++)
    {
        for (auto j = 0; j < BLOCK_SIZE; j++)
            res += cipher[i * BLOCK_SIZE + j] ^ CTR_stream[i][j];
    }

    return res;
}

AES::AES()
= default;

AES::AES(const string& k)
{
    set_key(k);
}

AES::~AES()
= default;

bool AES::set_key(string k)
{
    if (k.length() != BLOCK_SIZE)
        throw "Key length must be 128 bits!";

    // 密钥更换后需要重置两个流的状态
    OFB_stream.clear();
    CTR_stream.clear();
    
    for (auto i = 0; i < 4; i++)
    {
        for (auto j = 0; j < 4; j++)
            key[j][i] = byte(k[i * 4 + j]);
    }

    // 生成轮密钥
    keyExpansion();
    
    return true;
}

string AES::encrypt(const string& plain, const int mode, const string& IV, const int& segment_size)
{
    switch (mode)
    {
    case ECB:
    {
        try
        {
            return ECB_encryption(plain);
        }
        catch (const char* msg)
        {
            cout << msg << endl;
            return string();
        }
    }
    case CBC:
    {
        try
        {
            return CBC_encryption(plain, IV);
        }
        catch (const char* msg)
        {
            cout << msg << endl;
            return string();
        }
    }
    case CFB:
    {
        try
        {
            return CFB_encryption(plain, IV, segment_size);
        }
        catch (const char* msg)
        {
            cout << msg << endl;
            return string();
        }
    }
    case OFB:
    {
        try
        {
            return OFB_encryption(plain, IV);
        }
        catch (const char* msg)
        {
            cout << msg << endl;
            return string();
        }
    }
    case CTR:
    {
        try
        {
            return CTR_encryption(plain, IV);
        }
        catch (const char* msg)
        {
            cout << msg << endl;
            return string();
        }
    }
    default:
    {
        cout << "Unsupported mode!" << endl;
        return string();
    }
    }
}

string AES::decrypt(const string& cipher, const int mode, const string& IV, const int& segment_size)
{
    switch (mode)
    {
    case ECB:
    {
        try
        {
            return ECB_decryption(cipher);
        }
        catch (const char* msg)
        {
            cout << msg << endl;
            return string();
        }
    }
    case CBC:
    {
        try
        {
            return CBC_decryption(cipher, IV);
        }
        catch (const char* msg)
        {
            cout << msg << endl;
            return string();
        }
    }
    case CFB:
    {
        try
        {
            return CFB_decryption(cipher, IV, segment_size);
        }
        catch (const char* msg)
        {
            cout << msg << endl;
            return string();
        }
    }
    case OFB:
    {
        try
        {
            return OFB_decryption(cipher, IV);
        }
        catch (const char* msg)
        {
            cout << msg << endl;
            return string();
        }
    }
    case CTR:
    {
        try
        {
            return CTR_decryption(cipher, IV);
        }
        catch (const char* msg)
        {
            cout << msg << endl;
            return string();
        }
    }
    default:
    {
        cout << "Unsupported mode!" << endl;
        return string();
    }
    }
}

测试代码 main.cpp

#include <iostream>
#include <iomanip>
#include <fstream>
#include <chrono>
#include "AES.h"
#include "base64.h"

using std::cin;
using std::cout;
using std::ifstream;
using std::ofstream;
using std::istreambuf_iterator;
using std::ios;
using std::endl;

AES aes;
int mode = -1;
string k;
string plain_path;
string cipher_path;

const string MODE[5] = { "ECB", "CBC", "CFB", "OFB", "CTR" };

int main()
{
    ifstream file;
    ofstream out;
    char cmd;

CHANGE_MODE:
    while (mode < 0 || mode >= 5)
    {
        cout << "请选择加密模式" << endl;
        cout << "1. ECB\n2. CBC\n3. CFB\n4. OFB\n5. CTR" << endl;
        cin >> mode;
        mode--;

        if (mode < 0 || mode >= 5)
        {
            cout << "不支持的加密模式! 请重试" << endl;
            system("pause");
        }
        
        system("cls");
    }
    
CHANGE_KEY:
    while (k.empty())
    {
        cout << "请输入你的密钥(16个字符)" << endl;
        cin >> k;

        try
        {
            aes.set_key(k);
            break;
        }
        catch (const char * msg)
        {
            cout << msg << endl;
            k.clear();
            system("pause");
        }
        
        system("cls");
    }
    
    system("cls");

    while (true)
    {
        cout << "当前加密模式为: " << MODE[mode] << endl;
        cout << "你的密钥为: " << k << endl;
        cout << "1. 加密\n2. 解密\n3. 修改密钥\n4. 修改加密模式\n5. 退出" << endl;
        cin >> cmd;

        switch (cmd)
        {
        case '1':
        {
            cout << "输入明文所在文件的绝对路径" << endl;
            cin.get();
            getline(cin, plain_path);

            file.open(plain_path, ios::binary);

            if (!file.is_open())
            {
                cout << "打开明文文件时发生错误!" << endl;
                break;
            }

            cout << "输入密文输出文件的绝对路径" << endl;
            getline(cin, cipher_path);

            out.open(cipher_path, ios::binary);

            if (!out.is_open())
            {
                cout << "打开输出文件时发生错误!" << endl;
                break;
            }
                
            // 一次读入整个文件,之后再分组加密
            string plain((istreambuf_iterator<char>(file)), istreambuf_iterator<char>());

            auto begin = std::chrono::system_clock::now();
            string cipher;

            try
            {
                cipher = aes.encrypt(AES::padding(plain, 16));
            }
            catch (const char* msg)
            {
                cout << msg << endl;
                break;
            }

            auto end = std::chrono::system_clock::now();
            auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - begin);

            out.write(cipher.c_str(), cipher.length());
                
            file.clear();
            file.close();
            out.clear();
            out.close();

            cout << "加密成功!" << endl;
            cout << "一共花费了" << double(duration.count()) * \
                std::chrono::microseconds::period::num / std::chrono::microseconds::period::den \
                << "秒" << endl;

            break;
        }
        case '2':
        {
            cout << "输入密文所在文件的绝对路径" << endl;
            cin.get();
            getline(cin, cipher_path);

            file.open(cipher_path, ios::binary);

            if (!file.is_open())
            {
                cout << "打开明文文件时发生错误!" << endl;
                break;
            }

            cout << "输入明文输出文件的绝对路径" << endl;
            getline(cin, plain_path);

            out.open(plain_path, ios::binary);

            if (!out.is_open())
            {
                cout << "打开输出文件时发生错误!" << endl;
                break;
            }

            string cipher((istreambuf_iterator<char>(file)), istreambuf_iterator<char>());

            auto begin = std::chrono::system_clock::now();

            string plain;

            try
            {
                plain = AES::depadding(aes.decrypt(cipher));
            }
            catch (const char* msg)
            {
                cout << msg << endl;
                break;
            }

            auto end = std::chrono::system_clock::now();
            auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - begin);

            out.write(plain.c_str(), plain.length());

            file.clear();
            file.close();
            out.clear();
            out.close();

            cout << "解密成功!" << endl;
            cout << "一共花费了" << double(duration.count()) * \
                std::chrono::microseconds::period::num / std::chrono::microseconds::period::den \
                << "秒" << endl;

            break;
        }
        case '3': system("cls"); k.clear(); goto CHANGE_KEY;
        case '4': system("cls"); mode = -1; goto CHANGE_MODE;
        case '5': return 0;
        default: cout << "不要瞎按!" << endl; break;
        }

        system("pause");
        system("cls");
    }

    return 0;
}

5. Referrer

https://blog.csdn.net/shaosunrise/article/details/80219950

https://boxueio.com/series/let-us-build-an-apn-provider/ebook/564

https://www.cnblogs.com/Junbo20141201/p/9369860.html

Last modification:November 18th, 2019 at 12:55 pm
If you think my article is useful to you, please feel free to appreciate