今天说一下这个非对称密钥沉思系列(1):RSA专题之PKCSv1.5填充模式下的选择性密文攻击概述

RSA向来就很容易受到选择密文攻击,这主要是因为RSA在乘法上具有同态特性。

本文主要梳理RSA 在PKCSv1.5 Padding模式下的 Oracle攻击。

1. 经典RSA

RSA算法作为经典的非对称加解密算法,破天荒的实现了“在不直接传递密钥的情况下,完成数据加解密”的构想。

RSA算法是建立在数论基础上的,其数学工具涉及到欧拉函数、模反元素、大素数分解等等,这里不再赘述。

但是有一点是值得强调的,RSA算法的安全性是建立在对大素数分解的复杂度上的。

基于数学工具生成RSA密钥对可以参考以下的示例代码:

import random
import time
from typing import List, Tuple
from gmpy2 import gmpy2

def odd_iter():
    n: int = 1
    while True:
        n = n + 2
        yield n

def not_divisible(n):
    def divide(x):
        return x % n > 0
    
    return divide

# 素数生成器
def primes():
    yield 2
    it = odd_iter()
    while True:
        n = next(it)
        yield n
        '''
            https://docs.python.org/3/library/functions.html#filter
            Note that filter(function, iterable) is equivalent to the generator expression
                (item for item in iterable if function(item))
            if function is not None and (item for item in iterable if item) if function is None.
        '''
        it = filter(not_divisible(n), it)


# 素数列表计算器
def get_primes(start: int, stop: int) -> List[int]:
    ret: List[int] = []
    for n in primes():
        if start <= n <= stop:
            ret.append(n)
        elif n > stop:
            break
    return ret

def get_p_q(start: int = 100, stop: int = 200) -> Tuple[int, int]:
    """
        随机选择两个不相等的质数p和q
    """
    primes_list = get_primes(start, stop)
    length = len(primes_list)
    if length <= 0:
        raise Exception("invalid start and stop range")
    while True:
        p_index: int = random.randint(0, length - 1)
        q_index: int = random.randint(0, length - 1)
        # print("primes:{} primes_list length:{}\np_index:{} q_index:{}".format(primes_list, length, p_index, q_index))
        if p_index != q_index:
            return primes_list[p_index], primes_list[q_index]


def get_n(p: int, q: int) -> Tuple[int, int]:
    """
        计算p和q的乘积n:
            n = p * q
        计算n的欧拉函数φ(n):
            φ(n) = (p-1)(q-1)
    """
    return p * q, (p - 1) * (q - 1)

def get_Euler_n(p: int, q: int) -> int:
    return (p - 1) * (q - 1)

def get_e(euler_number: int) -> int:
    """
        随机选择一个整数e,条件是1< e < φ(n),且e与φ(n) 互质。
    """
    primes_list = get_primes(1, euler_number - 1)
    length = len(primes_list)
    return primes_list[random.randint(0, length - 1)]


def get_d(e: int, euler_number: int) -> int:
    """
        计算e对于φ(n)的模反元素d。
        所谓"模反元素"就是指有一个整数d,可以使得ed被φ(n)除的余数为1:
            ed - 1 = (n)
        也就是需要求出:
            (e * x - 1) % φ(n) == 0
    """
    return int(gmpy2.invert(e, euler_number))

def get_key_pair(start: int = 100, stop: int = 200) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    begin: int = int(round(time.time() * 1000))
    prime_p, prime_q = get_p_q(start, stop)
    print("get p={} q={}".format(prime_p, prime_q))
    n, euler_n = get_n(prime_p, prime_q)
    print("get n={} euler_n={}".format(n, euler_n))
    e: int = get_e(euler_n)
    print("get e={}".format(e))
    d: int = get_d(e, euler_n)
    print("get d={}".format(d))
    """
        上面的密钥生成步骤,一共出现六个数字:
        p q
        n euler_n
        e d
        这六个数字之中,公钥用到了两个(n和e),其余四个数字都是不公开的。其中最关键的是d,因为n和d组成了私钥,一旦d泄漏,就等于私钥泄漏。

        有无可能在已知n和e的情况下,推导出d?
        已知:
            (e * d - 1) % euler_n == 0 --> 只有知道e和euler_n,才能算出d
        已知:
            euler_n = (p - 1) * (q - 1) --> 只有知道p、q才能算出euler_n
        已知:
            n = p * q --> 只有将n因数分解,才能算出p和q
        因此,RSA的安全性在于对大数做因数分解的难度!
    """
    print("time cost= {} ms".format(int(round(time.time() * 1000)) - begin))
    return (n, e), (n, d)

if __name__ == '__main__':
    pub, pri = get_key_pair(start = 17, stop = 37)
    print("get pub key:{}, pri key:{}".format(pub, pri))

2. RSA加解密

2.1 模数运算法则

在真正讲解RSA的加解密算法之前,有必要先说明下,常用的模数运算法则,因为这个是RSA实现加解密的必要基础。

模数运算与基本四则运算有类似之处,同时也具备结合律、交换律、分配律的特性。

具体的推导过程这里不再赘述,直接使用代码进行结论的验证:

class ModularArithmetic(unittest.TestCase):
    def setUp(self):
        random.seed(time.time_ns())
        self.p: int = random.randint(1, 9)
        self.a: int = random.randint(2048, 4096)
        self.b: int = random.randint(1024, 2048)
        self.c: int = random.randint(128, 256)
        self.d: int = random.randint(128, 256)
        print("a={}, b={}, c={}, d={}, p={}".format(self.a, self.b, self.c, self.d, self.p))
    
    def test_basic(self):
        """
        模运算与基本四则运算有些相似,但是除法例外。其规则如下:
            (a + b) % p = (a % p + b % p) % p
            (a - b) % p = (a % p - b % p) % p
            (a * b) % p = (a % p * b % p) % p
            (a ^ b) % p = ((a % p)^b) % p
        """
        self.assertEqual(((self.a + self.b) % self.p), ((self.a % self.p) + (self.b % self.p)) % self.p)
        self.assertEqual(((self.a - self.b) % self.p), ((self.a % self.p) - (self.b % self.p)) % self.p)
        self.assertEqual(((self.a * self.b) % self.p), ((self.a % self.p) * (self.b % self.p)) % self.p)
        self.assertEqual(((self.a ** self.b) % self.p), ((self.a % self.p) ** self.b) % self.p)
    
    def test_law_of_association(self):
        """
        结合律:
            ((a+b) % p + c) % p = (a + (b+c) % p) % p
            ((a*b) % p * c)% p = (a *(b*c)%p) % p
        """
        self.assertEqual((((self.a + self.b) % self.p) + self.c) % self.p,
                         (self.a + ((self.b + self.c) % self.p)) % self.p)
        self.assertEqual((((self.a * self.b) % self.p) * self.c) % self.p,
                         (self.a * ((self.b * self.c) % self.p)) % self.p)
    
    def test_law_of_commutation(self):
        """
        交换律:
            (a + b) % p = (b + a) % p
            (a * b) % p = (b * a) % p
        """
        self.assertEqual((self.a + self.b) % self.p, (self.b + self.a) % self.p)
        self.assertEqual((self.a * self.b) % self.p, (self.b * self.a) % self.p)
    
    def test_law_of_distribution(self):
        """
        分配律:
            ((a +b)% p * c) % p = ((a * c) %p + (b * c) % p) % p
        """
        self.assertEqual((((self.a + self.b) % self.p) * self.c) % self.p,
                         ((self.a * self.c) % self.p + (self.b * self.c) % self.p) % self.p)

2.2 RSA加解密数学模型

RSA的加密逻辑,在数学模型上可以抽象为:c = (m^e) % n,其中的m为明文数据字节流在转换为int型大数后的值(这里默认为转换为大端字节序的数字),e为加密质数,n为RSA密钥对的模数,其来源为两个随机质数p、q的乘积,n的比特位数也就是我们常说的密钥长度,常见的值如1024、2048等等。

RSA的解密逻辑,在数学模型上可以抽象为:m = (c^d) % n,其中的c为密文数据字节流在转换为int型大数后的值(这里默认为转换为大端字节序的数字),d为解密质数;

字节流转换为大整数的代码示例可以参考如下:

def bytes_to_int(num_bytes: bytes, auto_select_order: bool = True) -> int:
    """
        auto_select_order 是否自动根据系统选择大小端字节序转换
        True:根据当前运行的OS选择大小端字节序转换
        False:默认按照大端字节序转换
    """
    _, order = bytes_order()
    if auto_select_order is False or order == 'big':
        return int.from_bytes(num_bytes, byteorder = 'big')
    else:
        return int.from_bytes(num_bytes, byteorder = 'little')

基于纯数学模型的加解密可以参考如下示例代码:

def demo_rsa_encrypt(message: int, rsa_pub_key: RSAPublicKey) -> int:
    print("origin message:{}".format(message))
    numbers = rsa_pub_key.public_numbers()
    """
        加密逻辑:(m^e) % n
    """
    return pow(message, numbers.e, numbers.n)


def demo_rsa_decrypt(cipher: int, rsa_pri_key: RSAPrivateKey) -> int:
    # print("cipher message:{}".format(cipher))
    numbers = rsa_pri_key.private_numbers()
    """
        解密逻辑:(c^d) % n
    """
    return pow(cipher, numbers.d, numbers.public_numbers.n)

2.3 纯加解密数学模型带来的问题

2.3.1 选择性密文攻击(无填充的同态攻击)

RSA算法本身具备同态性质。仅基于纯数学模型进行加解密运算,容易出现同态性质带来的选择性密文攻击。

这里的同态性质的选择密文攻击在工程意义的含义为:在无填充的前提下,同一对公私钥生成的两个密文的乘积,将解密为其对应的两个明文的乘积。

其数学层面上的推导过程可以概述成以下过程:

"""
存在加密逻辑:
   cipher_1 == (message_1 ^ e) % n
   cipher_2 == (message_2 ^ e) % n
以及解密逻辑:
   message_1 == (cipher_1 ^ d) % n
   message_2 == (cipher_2 ^ d) % n
现使得解密参数为cipher_1 * cipher_2:
   plain = (cipher_1 * cipher_2) ^ d % n
因式转换可以得到:
   plain = ((cipher_1 ^ d) * (cipher_2 ^ d)) % n
   plain = (((cipher_1 ^ d) % n) * ((cipher_2 ^ d) % n)) % n # 定理:(a ^ b) % p = ((a % p)^b) % p
   plain = ((cipher_1 % n) ^ d)%n * (((cipher_2 % n) ^ d)%n)
   plain = (((message_1 ^ e) % n)^d % n) * ((((message_2 ^ e) % n)^d % n)) # (m^e%n)^d%n == m
   plain = message_1 * message_2
"""
 def test_cca_attack(self):
     pub, pri = rsa_base.generate_rsa_keypair()
     message_1 = random.randint(100, 999)
     cipher_1 = rsa_base.demo_rsa_encrypt(message_1, pub)
     message_2 = random.randint(1000, 9999)
     cipher_2 = rsa_base.demo_rsa_encrypt(message_2, pub)
     # 密文乘积的解是明文的乘积
     plain = rsa_base.demo_rsa_decrypt(cipher_1 * cipher_2, pri)
     self.assertEqual(message_1 * message_2, plain)
     # 密文乘积共模取余的解仍然是明文的乘积
     plain = rsa_base.demo_rsa_decrypt((cipher_1 * cipher_2) % pub.public_numbers().n, pri)
     self.assertEqual(message_1 * message_2, plain)

2.3.2 共模攻击

共模攻击也是建立在无填充的基础上展开的。其工程意义上的含义为:在明文m不变的情况下,使用具有相同模数的两对密钥对,在只知道公钥(n、e)、不知道私钥d的前提下,可以解密对应的明文。

def common_modulus_decrypt(cipher1: int, cipher2: int, pub1: RSAPublicKey, pub2: RSAPublicKey) -> int:
    """
        已知rsa加密运算逻辑:
            c = (m^e) % n
        则对于共模n的两个公钥对同一个明文m加密,有:
            c1 = (m^e1) % n
            c2 = (m^e2) % n
        假设e1、e2的最大公约数为gcd,则根据扩展欧几里德算法可以得出,必定存在一组解使得:
            e1 * x + e2 * y == gcd,其中x、y均为实数
        现在对c1、c2进行如下运算:
            (c1^x * c2^y) % n
        我们可以得出:
            (c1^x * c2^y) % n == ((((m^e1) % n)^x) * (((m^e2) % n)^y)) % n
        通过模运算简化,可以得到:
            (c1^x * c2^y) % n == ((m^e1)^x * (m^e2)^y) % n
        对右边进一步简化可以得到:
            (c1^x * c2^y) % n == ((m^(e1*x)) * (m^(e2*y))) % n
        对右边再进行合并可以得到:
            (c1^x * c2^y) % n == (m^(e1*x + e2*y)) % n
        进一步我们可以得到:
            (c1^x * c2^y) % n == (m^(gcd)) % n
            (c1^x * c2^y) % n == ((m%n)^gcd)%n
        在rsa中,e1与e2必定互斥,即:gcd(e1,e2) == 1:
            (c1^x * c2^y) % n == m%n
        则我们可以进一步简化:
            (c1^x * c2^y) % n = m
        模运算拓展开,即:
            ((c1^x % n) * (c2^y % n)) % n = m
        即只需要我们求出唯一解x、y,就可以根据密文以及模长n计算出明文m
    """
    n1 = pub1.public_numbers().n
    e1 = pub1.public_numbers().e
    n2 = pub2.public_numbers().n
    e2 = pub2.public_numbers().e
    if n1 != n2:
        raise ValueError("required a common modulus")
    if e1 == e2:
        raise ValueError("required different public exponents")
    # 计算e1 和 e2 的 最大公约数gcd以及唯一解x、y,使得:e1 * x + e2 * y = gcd
    gcd, x, y = gmpy2.gcdext(e1, e2)
    print("e1={}, e2={}, gcd={}, x={}, y={}".format(e1, e2, gcd, x, y))
    if gcd != 1:
        raise ValueError("invalid 2 public exponents")
    print("before die inverse element calculate: n={}, x={}, cipher1={}, y={}, cipher2={}".
          format(n1, x, cipher1, y, cipher2))
    """
        假设x<0,记x==-a,则:
            c1^x % n 等价于 c1^-a % n
            右边可以转化为:(1/(c1^a)) % n
            由于在模n下的除法可以用和对应模逆元的乘法来表达。"分数取模",等价于求分母的模逆元
    """
    if x < 0:
        x, cipher1 = get_die_inverse_element(x, cipher1, n1)
    elif y < 0:
        y, cipher2 = get_die_inverse_element(y, cipher2, n1)
    print("after die inverse element calculate: n={}, x={}, cipher1={}, y={}, cipher2={}".
          format(n1, x, cipher1, y, cipher2))
    plain = (pow(int(cipher1), int(x)) * pow(int(cipher2), int(y))) % n1
    return plain

其中求模反元素的代码可以参考:

def get_die_inverse_element(x: float, cipher: int, n: int) -> Tuple[float, float]:
    """
        求模反元素
        如果两个正整数a和n互质,那么一定可以找到整数b,使得 ab-1 被n整除,或者说ab被n除的余数是1
        这时,b就叫做a的"模反元素"
        比如,3和11互质,那么3的模反元素就是4,因为 (3 × 4)-1 可以被11整除
        显然,模反元素不止一个,4加减11的整数倍都是3的模反元素 {...,-18,-7,4,15,26,...},即:
            如果b是a的模反元素,则 b+kn 都是a的模反元素。
        
        如果ax≡1(mod p),且a与p互质(gcd(a,p)=1),则称a关于模p的乘法逆元为x。(不互质则乘法逆元不存在)
        
        两个整数 a、b,若它们除以正整数 n 所得的余数相等,即 a mod n = b mod n, 则称 a 和 b 对于模 n 同余
    """
    if x > 0:
        raise ValueError("invalid parameter x, should be less than 0")
    return 0 - x, gmpy2.invert(cipher, n)

3. PKCSv1.5 Padding

3.1 填充规范

PKCS#1针对的是RSA算法。

RSA加密数据的长度和密钥位数有关,常用的密钥长度有1024bits,2048bits等,理论上1024bits的密钥可以加密的数据最大长度为1024bits(即1024/8 = 128bytes),2048bits的密钥可以加密的数据最大长度为2048bits(2048/8 = 256bytes)。

但是RSA在实际应用中不可能使用这种“教科书式的RSA”系统,实际应用中RSA经常与填充技术(padding)一起使用,旨在可以增加RSA的安全性(当然现在这种填充规范已经不再安全了)。

def is_pkcs_1_v_1_5_format_conforming(data: bytes, is_pub_key_op: bool = True) -> bool:
    """
        判断data是否按照PKCS#1v1.5规范进行填充
        PKCS#1v1.5规范遵循如下规则:
            在进行RSA运算时需要将源数据D转化为Encryption block(EB):
                EB = 00 + BT + PS + 00 + D
            00: 开头为00,是一个保留位
            BT: 用一个字节表示,在目前的版本上,有三个值000102,
                如果使用公钥操作,BT02(加密),
                如果用私钥操作,则可能为0001(签名)
            PS:填充位,PS = k - 3D 个字节,k表示密钥的字节长度,D表示明文数据D的字节长度
                如果BT00,则PS全部为00,
                如果BT01,则PS全部为FF,
                如果BT02PS为随机产生的非0x00的字节数据
            00: 在源数据D前一个字节用00分割
            D:  实际源数据
    """
    bits_len = len(data) * 8
    if bits_len < 512:
        """
            RSA密钥长度最少为512比特
        """
        print("invalid data length in bits:{}".format(bits_len))
        return False
    data_list = list(data)
    if data_list[0] != 0x00:
        print("invalid first byte value:{}, should be:{}".format(data_list[0], 0x00))
        return False
    if is_pub_key_op and data_list[1] != 0x02:
        print("invalid BT byte value:{}, should be:{}".format(data_list[1], 0x02))
        return False
    if (not is_pub_key_op) and (data_list[1] not in [0x00, 0x01]):
        print("invalid BT byte value:{}, should be:{}".format(data_list[1], [0x00, 0x01]))
        return False
    zero_split_byte_index: int = 0
    for i in range(2, len(data)):
        if (data_list[1] == 0x02 or data_list[1] == 0x01) and data_list[i] == 0x00:
            zero_split_byte_index = i
            break
        elif data_list[1] == 0x00:
            pass
    if zero_split_byte_index <= 0:
        print("zero split byte not found")
        return False
    if zero_split_byte_index >= len(data) - 1:
        print("no plain data appended")
        return False
    if zero_split_byte_index - 2 < 8:
        print("PS data too short")
        return False
    for i in range(2, zero_split_byte_index):
        if data_list[1] == 0x00 and data_list[i] != 0x00:
            print("BT byte is:{}, PS should all be 0x00".format(data_list[1]))
            return False
        elif data_list[1] == 0x01 and data_list[i] != 0xff:
            print("BT byte is:{}, PS should all be 0xff".format(data_list[1]))
            return False
        elif data_list[1] == 0x02 and data_list[i] == 0x00:
            print("BT byte is:{}, PS should all greater than 0x00".format(data_list[1]))
            return False
    return True

3.2 特性

满足PKCSv1.5填充规范的明文一定具有如下两个特征:

  1. 实际明文的长度一定小于等于 k – 11 字节(PS字段至少具有8字节),其中 k == 模数n/8。
  2. 填充后的明文的整数值一定满足 2B <= m < 3B,其中B == 2^(8*(k-2))。
    def test_PKCS_1_v_1_5_conforming_proper(self):
        """
            PKCS#1v1.5填充格式的字节序列,假设总长度为k字节(k取值:128256512)
            则一定会满足:
                 = 00 + BT + PS + 00 + D
                00: 开头为00,是一个保留位
                BT: 用一个字节表示,在目前的版本上,有三个值000102,
                    如果使用公钥操作,BT02(加密),
                    如果用私钥操作,则可能为0001(签名)
                PS:填充位,PS = k - 3D 个字节,k表示密钥的字节长度,D表示明文数据D的字节长度
                    如果BT00,则PS全部为00,
                    如果BT01,则PS全部为FF,
                    如果BT02PS为随机产生的非0x00的字节数据
                00: 在源数据D前一个字节用00分割
                 D:  实际源数据
            
            对于RSA加密的场景来说,前面两个字节一定是:00 02,后面会跟着k-2个字节,
            假设k-2个字节全部为00,则EB的最小值为:00 02 00 00 00 ...
            将其转换为大整数时,前面的0没有任何权重,实际上就是:
                2 00 00 00 ...2最少需要2个比特位10来表示,
            2后面共有k-200字节,则一共有8 * (k - 2)个比特,
            此时共有2 + 8 * (k - 2)个比特位,且第一个比特位为1,则其大整数值为:
                 2^(8 * (k-2) + 2 - 1) = 2^(8*(k-2)+1) = 2 * 2^(8*(k-2))
            也就是说,最小值一定是2 * 2^(8*(k-2))
            
            类似的,对于最大值的情况,无论后面的k-2个字节取值如何,他一定小于 00 03 00 00 00 ...
            将最大值转换为大整数时,前面的0没有任何权重,实际上就是:
                3 00 00 00 ...3最少需要2个比特位11来表示,
            3后面共有k-200字节,则一共有8 * (k - 2)个比特,
            此时共有2 + 8 * (k - 2)个比特位,且第一和第二个比特位均为1,则其大整数值为:
                1 * 2^(8 * (k-2) + 2 - 1) + 1 * 2^(8 * (k-2) + 2 - 2)
                = 2^(8 * (k-2) + 1) + 2^(8 * (k-2))
                = 2 * 2^(8 * (k-2)) + 2^(8 * (k-2))
                = 3 * 2^(8 * (k-2))
            也就是说,最大值一定小于3 * 2^(8*(k-2))
            
            最终得到:
                2 * 2^(8*(k-2)) <= bytes_to_int() < 3 * 2^(8*(k-2))
        """
        for k in [128, 256, 512]:
            random.seed(time.time_ns())
            eb = bytes([0x00, 0x02])  # 在RSA加密的场景下,默认以0x00 0x02 开头
            print("k = {}".format(k))
            min_int_eb = 2 * (2 ** (8 * (k - 2)))
            max_int_eb = 3 * (2 ** (8 * (k - 2)))
            plain_len = random.randint(1, k - 11)  # 生成随机明文长度
            plain = [random.randint(0, 255) for _ in range(0, plain_len)]  # 生成随机的字节序列形式的明文
            self.assertEqual(len(plain), plain_len)
            ps_len = k - plain_len - 3
            ps = [random.randint(1, 255) for _ in range(0, ps_len)]  # 根据明文长度生成随机的ps
            self.assertEqual(len(ps), ps_len)
            for v in ps:
                self.assertNotEqual(v, 0x00)  # ps为非0x00的字节序列
            eb = eb + bytes(ps) + bytes([0x00]) + bytes(plain)
            self.assertEqual(len(eb), k)  # EB的长度应该与k相等
            print("eb = {}".format(list(eb)))
            # 默认使用大端字节序进行字节转int的计算,因为我们自己的书写习惯是类似于大端字节序
            real_value = rsa_base.bytes_to_int(num_bytes = eb, auto_select_order = False)
            # print("real value={}".format(real_value))
            self.assertGreaterEqual(real_value, min_int_eb)
            self.assertLess(real_value, max_int_eb)
            # 默认使用大端字节序进行字节转int的计算,因为我们自己的书写习惯是类似于大端字节序
            tmp = rsa_base.int_to_bytes(real_value, padding_zero_count = 1, auto_select_order = False)
            self.assertEqual(tmp, eb)
            print("=" * 16)

其中,将大整数转换为字节流的代码可以参考示例:

def int_to_bytes(num_int: int, auto_select_order: bool = True, padding_zero_count: int = 0) -> bytes:
    """
        auto_select_order 是否自动根据系统选择大小端字节序转换
        True:根据当前运行的OS选择大小端字节序转换
        False:默认按照大端字节序转换
    """
    padding_zero = bytes([0x00 for _ in range(0, padding_zero_count)])
    _, order = bytes_order()
    """
         num_int.bit_length()为最少需要多少个bit才能表达num_int的值
         比如:
             1,最少需要1个比特
             255,最少需要8比特
             以此类推
        num_int.bit_length() + 7目的是为了最少凑齐8比特,形成一个字节
     """
    if auto_select_order is False or order == 'big':
        return padding_zero + num_int.to_bytes(length = get_num_int_least_bytes(num_int), byteorder = 'big')
    else:
        return num_int.to_bytes(length = get_num_int_least_bytes(num_int), byteorder = 'little') + padding_zero

4. 填充攻击Oracle

4.1 数学推导

攻击的总体数学思路:

基本加解密原理:

加密:c = (m^e) % n

解密:m = (c^d) % n

基本定理:

(a * b) % p = (a % p * b % p) % p

(a ^ b) % p = ((a % p)^b) % p

假设此时有一个随机的明文s,通过构造这样一种密文c_x,使得:

c_x = (c * s^e) % n

=( c % n * s^e % n ) % n

= (c * c_s) % n

其中c_s为明文s加密后对应的密文

进而反推对c_x进行解密,假设其明文为s_m,此时有:

s_m = (c_x^d) % n

s_m = ((c * s^e) % n)^d % n # 定理:(a ^ b) % p = ((a % p)^b) % p

= (c * s^e) ^ d % n

= (((m^e)%n) * s^e)^d%n

= ((m^e)%n)^d * s^e^d) % n # 定理:(a * b) % p = (a % p * b % p) % p

= ((m^e)%n)^d%n * s^e^d % n) %n # m = (m^e)%n)^d%n

= (m * s^e^d % n) % n

= (m * (s^e%n)^d%n) % n # s = (s^e%n)^d%n

= (m * s) % n

总体上来说,攻击者可以通过不断的发送特定的s给RSA解密服务端,通过服务端返回的解密明文ms是否符合PKCSv1.5规范来缩小明文m的取值范围,直到最后得到精确的明文m。

此攻击场景,在早期的SSL/TLS协议握手过程中,在对使用PKCS#1填充方式的RSA解密结果作处理时,会从中提取部分内容作版本号检查,版本号检查的结果能够被作为侧信道来泄露相关信息,攻击者可以利用泄露的信息来通过Bleichenbachor’s Attack解密任意明文或者伪造签名。

4.2 示例

from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.backends import default_backend
import gmpy2
from collections import namedtuple
from Cryptography import rsa_base

def simple_rsa_encrypt(m, public_key):
    numbers = public_key.public_numbers()
    # Encryption is(m^e) % n.
    return gmpy2.powmod(m, numbers.e, numbers.n)

def simple_rsa_decrypt(c, private_key):
    numbers = private_key.private_numbers()
    # Decryption is(c^d) % n.
    return gmpy2.powmod(c, numbers.d, numbers.public_numbers.n)

def int_to_bytes(i, min_size = None):
    i = int(i)
    b = i.to_bytes((i.bit_length() + 7) // 8, byteorder = 'big')
    if min_size is not None and len(b) < min_size:
        b = b'\x00' * (min_size - len(b)) + b
    return b

def bytes_to_int(b):
    return int.from_bytes(b, byteorder = 'big')

Interval = namedtuple('Interval', ['a', 'b'])

# RSA Oracle Attack Component
class FakeOracle:
    def __init__(self, private_key):
        self.private_key = private_key
    
    def __call__(self, cipher_text):
        recovered_as_int = simple_rsa_decrypt(cipher_text, self.private_key)
        recovered = int_to_bytes(recovered_as_int, self.private_key.key_size // 8)
        return recovered[0:2] == bytes([0, 2])


class RSAOracleAttacker:
    def __init__(self, public_key, oracle):
        self.public_key = public_key
        self.oracle: FakeOracle = oracle
    
    def _step1_blinding(self, c):
        """
            盲猜: 随机选取一个s,使得:c * ((s^e) % n)得到的密文结果,解密后也是符合PKCSv1.5填充规范的
            这里由于c本身在加密的时候,其明文本身已经是被PKCS填充过的,
            因此如果只对c做解密得到的明文也一定是符合PKCS规范的
            因此这里直接将s赋值成1,可以保证: c * ((s^e) % n) == c
        """
        self.c0 = c  # 原始密文
        self.B = 2 ** (self.public_key.key_size - 16)  # 被pkcs_v1.5填充后可能存在的取值数量
        self.s = [1]  # 随机选取的明文值,第一个值取数字1
        self.M = [[Interval(2 * self.B, (3 * self.B) - 1)]]  # 原始的pkcs_v1.5的取值范围
        self.i = 1  # 第一次随机选取
        self.n = self.public_key.public_numbers().n  # 密钥的模数
    
    # RSA Oracle Attack Component, part of class RSAOracleAttacker
    def _find_s(self, start_s, s_max = None):
        si = start_s
        ci = simple_rsa_encrypt(si, self.public_key)  # 运算(s^e)%n
        # 按照 c_x = (c * s^e) % n 构造密文
        while not self.oracle(((self.c0 * ci) % self.n)):
            si += 1  # s每次加一
            if s_max and (si > s_max):
                return None
            ci = simple_rsa_encrypt(si, self.public_key)
        return si
    
    # RSA Oracle Attack Component, part of class RSAOracleAttacker
    def _step2a_start_the_searching(self):
        # 从n/3B开始寻找符合要求的第一个s
        si = self._find_s(start_s = gmpy2.c_div(self.n, 3 * self.B))
        return si
    
    def _step2b_searching_with_more_than_one_interval(self):
        si = self._find_s(start_s = self.s[-1] + 1)
        return si
    
    def _step2c_searching_with_one_interval_left(self):
        a, b = self.M[-1][0]
        ri = gmpy2.c_div(2 * (b * self.s[-1] - 2 * self.B), self.n)
        si = None
        while si is None:
            si = gmpy2.c_div((2 * self.B + ri * self.n), b)
            s_max = gmpy2.c_div((3 * self.B + ri * self.n), a)
            si = self._find_s(start_s = si, s_max = s_max)
            ri += 1
        return si
    
    def _step3_narrowing_set_of_solutions(self, si):
        new_intervals = set()
        for a, b in self.M[-1]:
            r_min = gmpy2.c_div((a * si - 3 * self.B + 1), self.n)
            r_max = gmpy2.f_div((b * si - 2 * self.B), self.n)
            
            for r in range(r_min, r_max + 1):
                a_candidate = gmpy2.c_div((2 * self.B + r * self.n), si)
                b_candidate = gmpy2.f_div((3 * self.B - 1 + r * self.n), si)
                
                new_interval = Interval(max(a, a_candidate), min(b, b_candidate))
                new_intervals.add(new_interval)
        new_intervals = list(new_intervals)
        self.M.append(new_intervals)
        self.s.append(si)
        if len(new_intervals) == 1 and new_intervals[0].a == new_intervals[0].b:
            return True
        return False
    
    def _step4_computing_the_solution(self):
        interval = self.M[-1][0]
        return interval.a
    
    def attack(self, c):
        self._step1_blinding(c)
        # do this until there is one interval left
        finished = False
        while not finished:
            if self.i == 1:
                si = self._step2a_start_the_searching()
            elif len(self.M[-1]) > 1:
                si = self._step2b_searching_with_more_than_one_interval()
            elif len(self.M[-1]) == 1:
                # interval = self.M[-1][0]
                si = self._step2c_searching_with_one_interval_left()
            finished = self._step3_narrowing_set_of_solutions(si)
            self.i += 1
        m = self._step4_computing_the_solution()
        return m


if __name__ == "__main__":
    private_key = rsa.generate_private_key(
        public_exponent = 65537,
        key_size = 1024,
        backend = default_backend()
    )
    public_key = private_key.public_key()
    # 使用公钥对明文加密得到密文
    test_plain = "hello,world"
    test_plain_int = bytes_to_int(test_plain.encode(encoding = 'utf-8'))
    print("test_plain_int:{}".format(hex(test_plain_int)))
    test_cipher = public_key.encrypt(plaintext = test_plain.encode(encoding = 'utf-8'), padding = padding.PKCS1v15())
    test_cipher_int = bytes_to_int(test_cipher)
    print("test_cipher_int:{}".format(hex(test_cipher_int)))
    
    test_oracle = FakeOracle(private_key)
    test_attacker = RSAOracleAttacker(public_key, test_oracle)
    attack_plain_int = test_attacker.attack(test_cipher_int)
    print("attack_plain_int:{}".format(hex(attack_plain_int)))
    attack_plain_bytes = int_to_bytes(attack_plain_int, 1024 // 8)
    print("attack_plain_bytes:{}".format(list(attack_plain_bytes)))
    print("pkcs_v1.5 format padded:{}".format(rsa_base.is_pkcs_1_v_1_5_format_conforming(attack_plain_bytes, True)))

上述的示例代码中,在步骤_find_s,_step2c_searching_with_one_interval_left,_step3_narrowing_set_of_solutions中涉及到部分较复杂的数学工具(主要是数论知识)的使用:

这里由于本人时间、精力、以及数学功底能力有限,并未进行深入的数学推导,欢迎感兴趣的同学进行指导。

5. RSA加解密的填充选择

RSA_PKCS1_PADDING(V1.5)的缺点是无法验证解密的结果的正确性,为了解决该问题,RSA_PKCS1_OAEP_PADDING引入了类似HMAC消息验证码的算法,其原理就是填充了一些与原文相关的哈希值,解密后可以进行验证。

RSA_PKCS1_OAEP_PADDING是目前RSA填充方式里安全性最高的一种,代价则是可加密的明文长度较短。

其最大明文长度计算公式可以概括为:

mLen = k - 2 * hLen - 2 if we want to calculate the maximum message size:
    k - length in octets of the RSA modulus n
    hLen - output length in octets of hash function Hash
    mLen - length in octets of a message M

但是有一点值得注意的是,无论选择PKCS#1还是OAEP的填充方式,经过填充后,即使明文和公钥每次都相同,但是每次填充后输出的密文都会改变。

工程实践上,对于RSA的加解密,建议默认使用RSA_PKCS1_OAEP_PADDING填充方式。

6. 参考文档

《Chosen Ciphertext Attacks Against Protocols Based on the RSA Encryption Standard PKCS #1》

《PKCS #1: RSA Cryptography Specifications Version 2.2》

《PKCS #1 Version 2.2: RSA Cryptography Specifications draft-moriarty-pkcs1-00》

《Klima-Pokorny-Rosa extension of Bleichbacher’s attack on PKCS #1 v1.5 padding》

正文完