refactor: optimize RSA encryption and decryption functions with caching
parent
23147e5498
commit
37d886e9ed
|
|
@ -8,6 +8,7 @@
|
||||||
"""
|
"""
|
||||||
import base64
|
import base64
|
||||||
import threading
|
import threading
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher
|
from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher
|
||||||
from Crypto.PublicKey import RSA
|
from Crypto.PublicKey import RSA
|
||||||
|
|
@ -70,7 +71,7 @@ def encrypt(msg, public_key: str | None = None):
|
||||||
"""
|
"""
|
||||||
if public_key is None:
|
if public_key is None:
|
||||||
public_key = get_key_pair().get('key')
|
public_key = get_key_pair().get('key')
|
||||||
cipher = PKCS1_cipher.new(RSA.importKey(public_key))
|
cipher = _get_encrypt_cipher(public_key)
|
||||||
encrypt_msg = cipher.encrypt(msg.encode("utf-8"))
|
encrypt_msg = cipher.encrypt(msg.encode("utf-8"))
|
||||||
return base64.b64encode(encrypt_msg).decode()
|
return base64.b64encode(encrypt_msg).decode()
|
||||||
|
|
||||||
|
|
@ -84,56 +85,69 @@ def decrypt(msg, pri_key: str | None = None):
|
||||||
"""
|
"""
|
||||||
if pri_key is None:
|
if pri_key is None:
|
||||||
pri_key = get_key_pair().get('value')
|
pri_key = get_key_pair().get('value')
|
||||||
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
|
cipher = _get_cipher(pri_key)
|
||||||
decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
|
decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
|
||||||
return decrypt_data.decode("utf-8")
|
return decrypt_data.decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=2)
|
||||||
|
def _get_encrypt_cipher(public_key: str):
|
||||||
|
"""缓存加密 cipher 对象"""
|
||||||
|
return PKCS1_cipher.new(RSA.importKey(extern_key=public_key, passphrase=secret_code))
|
||||||
|
|
||||||
|
|
||||||
def rsa_long_encrypt(message, public_key: str | None = None, length=200):
|
def rsa_long_encrypt(message, public_key: str | None = None, length=200):
|
||||||
"""
|
"""
|
||||||
超长文本加密
|
超长文本加密
|
||||||
|
|
||||||
:param message: 需要加密的字符串
|
:param message: 需要加密的字符串
|
||||||
:param public_key 公钥
|
:param public_key 公钥
|
||||||
:param length: 1024bit的证书用100, 2048bit的证书用 200
|
:param length: 1024bit的证书用100, 2048bit的证书用 200
|
||||||
:return: 加密后的数据
|
:return: 加密后的数据
|
||||||
"""
|
"""
|
||||||
# 读取公钥
|
|
||||||
if public_key is None:
|
if public_key is None:
|
||||||
public_key = get_key_pair().get('key')
|
public_key = get_key_pair().get('key')
|
||||||
cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key,
|
|
||||||
passphrase=secret_code))
|
cipher = _get_encrypt_cipher(public_key)
|
||||||
# 处理:Plaintext is too long. 分段加密
|
|
||||||
if len(message) <= length:
|
if len(message) <= length:
|
||||||
# 对编码的数据进行加密,并通过base64进行编码
|
|
||||||
result = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
|
result = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
|
||||||
else:
|
else:
|
||||||
rsa_text = []
|
rsa_text = []
|
||||||
# 对编码后的数据进行切片,原因:加密长度不能过长
|
|
||||||
for i in range(0, len(message), length):
|
for i in range(0, len(message), length):
|
||||||
cont = message[i:i + length]
|
cont = message[i:i + length]
|
||||||
# 对切片后的数据进行加密,并新增到text后面
|
|
||||||
rsa_text.append(cipher.encrypt(cont.encode('utf-8')))
|
rsa_text.append(cipher.encrypt(cont.encode('utf-8')))
|
||||||
# 加密完进行拼接
|
|
||||||
cipher_text = b''.join(rsa_text)
|
cipher_text = b''.join(rsa_text)
|
||||||
# base64进行编码
|
|
||||||
result = base64.b64encode(cipher_text)
|
result = base64.b64encode(cipher_text)
|
||||||
|
|
||||||
return result.decode()
|
return result.decode()
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=2)
|
||||||
|
def _get_cipher(pri_key: str):
|
||||||
|
"""缓存 cipher 对象,避免重复创建"""
|
||||||
|
return PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
|
||||||
|
|
||||||
|
|
||||||
def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
|
def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
|
||||||
"""
|
"""
|
||||||
超长文本解密,默认不加密
|
超长文本解密,优化内存使用
|
||||||
:param message: 需要解密的数据
|
:param message: 需要解密的数据
|
||||||
:param pri_key: 秘钥
|
:param pri_key: 秘钥
|
||||||
:param length : 1024bit的证书用128,2048bit证书用256位
|
:param length : 1024bit的证书用128,2048bit证书用256位
|
||||||
:return: 解密后的数据
|
:return: 解密后的数据
|
||||||
"""
|
"""
|
||||||
if pri_key is None:
|
if pri_key is None:
|
||||||
pri_key = get_key_pair().get('value')
|
pri_key = get_key_pair().get('value')
|
||||||
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
|
|
||||||
|
cipher = _get_cipher(pri_key)
|
||||||
base64_de = base64.b64decode(message)
|
base64_de = base64.b64decode(message)
|
||||||
res = []
|
|
||||||
|
# 使用 bytearray 减少内存分配
|
||||||
|
result = bytearray()
|
||||||
for i in range(0, len(base64_de), length):
|
for i in range(0, len(base64_de), length):
|
||||||
res.append(cipher.decrypt(base64_de[i:i + length], 0))
|
result.extend(cipher.decrypt(base64_de[i:i + length], 0))
|
||||||
return b"".join(res).decode()
|
|
||||||
|
return result.decode()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue