Добавлена обработка ошибок при расшифровке TLS

This commit is contained in:
serega6531
2020-04-26 04:37:20 +03:00
parent fd50fff1a2
commit 6f5ceb6174
3 changed files with 120 additions and 81 deletions

View File

@@ -12,6 +12,7 @@ import org.bouncycastle.tls.PRFAlgorithm;
import org.bouncycastle.tls.crypto.TlsSecret; import org.bouncycastle.tls.crypto.TlsSecret;
import org.bouncycastle.tls.crypto.impl.bc.BcTlsCrypto; import org.bouncycastle.tls.crypto.impl.bc.BcTlsCrypto;
import org.bouncycastle.tls.crypto.impl.bc.BcTlsSecret; import org.bouncycastle.tls.crypto.impl.bc.BcTlsSecret;
import org.pcap4j.packet.IllegalRawDataException;
import org.pcap4j.util.ByteArrays; import org.pcap4j.util.ByteArrays;
import ru.serega6531.packmate.model.Packet; import ru.serega6531.packmate.model.Packet;
import ru.serega6531.packmate.service.optimization.tls.TlsPacket; import ru.serega6531.packmate.service.optimization.tls.TlsPacket;
@@ -31,6 +32,7 @@ import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec; import javax.crypto.spec.SecretKeySpec;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException; import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.security.cert.Certificate; import java.security.cert.Certificate;
@@ -41,7 +43,6 @@ import java.security.interfaces.RSAPublicKey;
import java.util.*; import java.util.*;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import java.util.stream.Stream;
@Slf4j @Slf4j
@RequiredArgsConstructor @RequiredArgsConstructor
@@ -57,21 +58,32 @@ public class TlsDecryptor {
private List<Packet> result; private List<Packet> result;
private ListMultimap<Packet, TlsPacket.TlsHeader> tlsPackets; private ListMultimap<Packet, TlsPacket.TlsHeader> tlsPackets;
private CipherSuite cipherSuite;
private byte[] clientRandom; private byte[] clientRandom;
private byte[] serverRandom; private byte[] serverRandom;
@SneakyThrows
public void decryptTls() { public void decryptTls() {
tlsPackets = ArrayListMultimap.create(packets.size(), 1); tlsPackets = ArrayListMultimap.create(packets.size(), 1);
packets.forEach(p -> tlsPackets.putAll(p, createTlsHeaders(p)));
var clientHello = (ClientHelloHandshakeRecordContent) try {
getHandshake(tlsPackets.values(), HandshakeType.CLIENT_HELLO).orElseThrow(); for (Packet p : packets) {
var serverHello = (ServerHelloHandshakeRecordContent) tlsPackets.putAll(p, createTlsHeaders(p));
getHandshake(tlsPackets.values(), HandshakeType.SERVER_HELLO).orElseThrow(); }
} catch (IllegalRawDataException e) {
log.warn("Failed to parse TLS packets", e);
return;
}
cipherSuite = serverHello.getCipherSuite(); var clientHelloOpt = getHandshake(HandshakeType.CLIENT_HELLO);
var serverHelloOpt = getHandshake(HandshakeType.SERVER_HELLO);
if (clientHelloOpt.isEmpty() || serverHelloOpt.isEmpty()) {
return;
}
var clientHello = (ClientHelloHandshakeRecordContent) clientHelloOpt.get();
var serverHello = (ServerHelloHandshakeRecordContent) serverHelloOpt.get();
CipherSuite cipherSuite = serverHello.getCipherSuite();
if (cipherSuite.name().startsWith("TLS_RSA_WITH_")) { if (cipherSuite.name().startsWith("TLS_RSA_WITH_")) {
Matcher matcher = cipherSuitePattern.matcher(cipherSuite.name()); Matcher matcher = cipherSuitePattern.matcher(cipherSuite.name());
@@ -87,8 +99,14 @@ public class TlsDecryptor {
} }
} }
private void decryptTlsRsa(String blockCipher, String hashAlgo) throws CertificateException, NoSuchPaddingException, NoSuchAlgorithmException { private void decryptTlsRsa(String blockCipher, String hashAlgo) {
RSAPublicKey publicKey = getRsaPublicKey(); Optional<RSAPublicKey> publicKeyOpt = getRsaPublicKey();
if (publicKeyOpt.isEmpty()) {
return;
}
RSAPublicKey publicKey = publicKeyOpt.get();
RSAPrivateKey privateKey = keysHolder.getKey(publicKey.getModulus()); RSAPrivateKey privateKey = keysHolder.getKey(publicKey.getModulus());
if (privateKey == null) { if (privateKey == null) {
String n = publicKey.getModulus().toString(); String n = publicKey.getModulus().toString();
@@ -96,14 +114,13 @@ public class TlsDecryptor {
return; return;
} }
BcTlsSecret preMaster; Optional<BcTlsSecret> preMasterOptional = getPreMaster(privateKey);
try { if (preMasterOptional.isEmpty()) {
preMaster = getPreMaster(privateKey);
} catch (InvalidKeyException | IllegalBlockSizeException | BadPaddingException e) {
log.warn("Failed do get pre-master key", e);
return; return;
} }
BcTlsSecret preMaster = preMasterOptional.get();
byte[] randomCS = ArrayUtils.addAll(clientRandom, serverRandom); byte[] randomCS = ArrayUtils.addAll(clientRandom, serverRandom);
byte[] randomSC = ArrayUtils.addAll(serverRandom, clientRandom); byte[] randomSC = ArrayUtils.addAll(serverRandom, clientRandom);
@@ -127,11 +144,25 @@ public class TlsDecryptor {
bb.get(clientIV); bb.get(clientIV);
bb.get(serverIV); bb.get(serverIV);
byte[] clientFinishedEncrypted = getFinishedData(tlsPackets, true); Optional<byte[]> clientFinishedOpt = getFinishedData(true);
byte[] serverFinishedEncrypted = getFinishedData(tlsPackets, false); Optional<byte[]> serverFinishedOpt = getFinishedData(false);
Cipher clientCipher = createCipher(clientEncryptionKey, clientIV, clientFinishedEncrypted); if (clientFinishedOpt.isEmpty() || serverFinishedOpt.isEmpty()) {
Cipher serverCipher = createCipher(serverEncryptionKey, serverIV, serverFinishedEncrypted); return;
}
byte[] clientFinishedEncrypted = clientFinishedOpt.get();
byte[] serverFinishedEncrypted = serverFinishedOpt.get();
Optional<Cipher> clientCipherOpt = createCipher(clientEncryptionKey, clientIV, clientFinishedEncrypted);
Optional<Cipher> serverCipherOpt = createCipher(serverEncryptionKey, serverIV, serverFinishedEncrypted);
if (clientCipherOpt.isEmpty() || serverCipherOpt.isEmpty()) {
return;
}
Cipher clientCipher = clientCipherOpt.get();
Cipher serverCipher = serverCipherOpt.get();
result = new ArrayList<>(packets.size()); result = new ArrayList<>(packets.size());
@@ -169,37 +200,66 @@ public class TlsDecryptor {
parsed = true; parsed = true;
} }
private BcTlsSecret getPreMaster(RSAPrivateKey privateKey) throws NoSuchAlgorithmException, NoSuchPaddingException, InvalidKeyException, IllegalBlockSizeException, BadPaddingException { @SneakyThrows(value = {NoSuchAlgorithmException.class, NoSuchPaddingException.class})
var clientKeyExchange = (BasicHandshakeRecordContent) private Optional<BcTlsSecret> getPreMaster(RSAPrivateKey privateKey) {
getHandshake(tlsPackets.values(), HandshakeType.CLIENT_KEY_EXCHANGE).orElseThrow(); Optional<HandshakeRecordContent> opt = getHandshake(HandshakeType.CLIENT_KEY_EXCHANGE);
byte[] encryptedPreMaster = TlsKeyUtils.getClientRsaPreMaster(clientKeyExchange.getContent(), 0); if (opt.isEmpty()) {
return Optional.empty();
}
Cipher rsa = Cipher.getInstance("RSA/ECB/PKCS1Padding"); var clientKeyExchange = (BasicHandshakeRecordContent) opt.get();
rsa.init(Cipher.DECRYPT_MODE, privateKey);
byte[] preMaster = rsa.doFinal(encryptedPreMaster); try {
return new BcTlsSecret(new BcTlsCrypto(null), preMaster); byte[] encryptedPreMaster = TlsKeyUtils.getClientRsaPreMaster(clientKeyExchange.getContent(), 0);
Cipher rsa = Cipher.getInstance("RSA/ECB/PKCS1Padding");
rsa.init(Cipher.DECRYPT_MODE, privateKey);
byte[] preMaster = rsa.doFinal(encryptedPreMaster);
return Optional.of(new BcTlsSecret(new BcTlsCrypto(null), preMaster));
} catch (InvalidKeyException | IllegalBlockSizeException | BadPaddingException e) {
log.warn("Failed do get pre-master key", e);
return Optional.empty();
}
} }
private RSAPublicKey getRsaPublicKey() throws CertificateException { private Optional<RSAPublicKey> getRsaPublicKey() {
var certificateHandshake = ((CertificateHandshakeRecordContent) var certificateHandshakeOpt = getHandshake(HandshakeType.CERTIFICATE);
getHandshake(tlsPackets.values(), HandshakeType.CERTIFICATE).orElseThrow());
if (certificateHandshakeOpt.isEmpty()) {
return Optional.empty();
}
var certificateHandshake = (CertificateHandshakeRecordContent) certificateHandshakeOpt.get();
List<byte[]> chain = certificateHandshake.getRawCertificates(); List<byte[]> chain = certificateHandshake.getRawCertificates();
byte[] rawCertificate = chain.get(0); byte[] rawCertificate = chain.get(0);
CertificateFactory cf = CertificateFactory.getInstance("X.509");
Certificate certificate = cf.generateCertificate(new ByteArrayInputStream(rawCertificate)); try {
return (RSAPublicKey) certificate.getPublicKey(); CertificateFactory cf = CertificateFactory.getInstance("X.509");
Certificate certificate = cf.generateCertificate(new ByteArrayInputStream(rawCertificate));
RSAPublicKey publicKey = (RSAPublicKey) certificate.getPublicKey();
return Optional.of(publicKey);
} catch (CertificateException e) {
log.warn("Error while getting certificate", e);
return Optional.empty();
}
} }
@SneakyThrows @SneakyThrows(value = {NoSuchAlgorithmException.class, NoSuchPaddingException.class})
private Cipher createCipher(byte[] key, byte[] iv, byte[] initData) { private Optional<Cipher> createCipher(byte[] key, byte[] iv, byte[] initData) {
Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding"); // TLS_RSA_WITH_AES_256_CBC_SHA Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding"); // TLS_RSA_WITH_AES_256_CBC_SHA
SecretKeySpec serverSkeySpec = new SecretKeySpec(key, "AES"); SecretKeySpec serverSkeySpec = new SecretKeySpec(key, "AES");
IvParameterSpec serverIvParameterSpec = new IvParameterSpec(iv); IvParameterSpec serverIvParameterSpec = new IvParameterSpec(iv);
cipher.init(Cipher.DECRYPT_MODE, serverSkeySpec, serverIvParameterSpec);
cipher.update(initData);
return cipher; try {
cipher.init(Cipher.DECRYPT_MODE, serverSkeySpec, serverIvParameterSpec);
cipher.update(initData);
return Optional.of(cipher);
} catch (InvalidKeyException | InvalidAlgorithmParameterException e) {
log.warn("Error decrypting TLS", e);
return Optional.empty();
}
} }
private byte[] clearDecodedData(byte[] decoded) { private byte[] clearDecodedData(byte[] decoded) {
@@ -209,26 +269,22 @@ public class TlsDecryptor {
return decoded; return decoded;
} }
private byte[] getFinishedData(ListMultimap<Packet, TlsPacket.TlsHeader> tlsPackets, boolean incoming) { private Optional<byte[]> getFinishedData(boolean incoming) {
return ((BasicHandshakeRecordContent) getHandshake(tlsPackets.asMap().entrySet().stream() var contentOpt = tlsPackets.asMap().entrySet().stream()
.filter(ent -> ent.getKey().isIncoming() == incoming) .filter(ent -> ent.getKey().isIncoming() == incoming)
.map(Map.Entry::getValue) .map(Map.Entry::getValue)
.flatMap(Collection::stream), HandshakeType.ENCRYPTED_HANDSHAKE_MESSAGE)) .flatMap(Collection::stream)
.getContent(); .filter(p -> p.getContentType() == ContentType.HANDSHAKE)
}
private HandshakeRecordContent getHandshake(Stream<TlsPacket.TlsHeader> stream, HandshakeType handshakeType) {
return stream.filter(p -> p.getContentType() == ContentType.HANDSHAKE)
.map(p -> ((HandshakeRecord) p.getRecord())) .map(p -> ((HandshakeRecord) p.getRecord()))
.filter(r -> r.getHandshakeType() == handshakeType) .filter(r -> r.getHandshakeType() == HandshakeType.ENCRYPTED_HANDSHAKE_MESSAGE)
.map(r -> ((BasicHandshakeRecordContent) r.getContent())) .map(r -> ((BasicHandshakeRecordContent) r.getContent()))
.findFirst() .findFirst();
.orElseThrow();
return contentOpt.map(BasicHandshakeRecordContent::getContent);
} }
private Optional<HandshakeRecordContent> getHandshake(Collection<TlsPacket.TlsHeader> packets, private Optional<HandshakeRecordContent> getHandshake(HandshakeType handshakeType) {
HandshakeType handshakeType) { return tlsPackets.values().stream()
return packets.stream()
.filter(p -> p.getContentType() == ContentType.HANDSHAKE) .filter(p -> p.getContentType() == ContentType.HANDSHAKE)
.map(p -> ((HandshakeRecord) p.getRecord())) .map(p -> ((HandshakeRecord) p.getRecord()))
.filter(r -> r.getHandshakeType() == handshakeType) .filter(r -> r.getHandshakeType() == handshakeType)
@@ -236,8 +292,7 @@ public class TlsDecryptor {
.findFirst(); .findFirst();
} }
@SneakyThrows private List<TlsPacket.TlsHeader> createTlsHeaders(Packet p) throws IllegalRawDataException {
private List<TlsPacket.TlsHeader> createTlsHeaders(Packet p) {
List<TlsPacket.TlsHeader> headers = new ArrayList<>(); List<TlsPacket.TlsHeader> headers = new ArrayList<>();
TlsPacket tlsPacket = TlsPacket.newPacket(p.getContent(), 0, p.getContent().length); TlsPacket tlsPacket = TlsPacket.newPacket(p.getContent(), 0, p.getContent().length);

View File

@@ -9,6 +9,8 @@ import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
public class TlsDecryptorTest { public class TlsDecryptorTest {
@Test @Test
@@ -21,6 +23,14 @@ public class TlsDecryptorTest {
TlsDecryptor decryptor = new TlsDecryptor(packets, keysHolder); TlsDecryptor decryptor = new TlsDecryptor(packets, keysHolder);
decryptor.decryptTls(); decryptor.decryptTls();
assertTrue(decryptor.isParsed(), "TLS not parsed");
List<Packet> parsed = decryptor.getParsedPackets();
assertNotNull(parsed, "Parsed packets list is null");
assertEquals(4, parsed.size(), "Wrong packets list size");
assertTrue(new String(parsed.get(0).getContent()).contains("GET /"), "Wrong content at the start");
assertTrue(new String(parsed.get(3).getContent()).contains("Not Found"), "Wrong content at the end");
} }
} }

View File

@@ -1,26 +0,0 @@
package ru.serega6531.packmate;
import org.junit.jupiter.api.Test;
import org.pcap4j.packet.IllegalRawDataException;
import ru.serega6531.packmate.model.Packet;
import ru.serega6531.packmate.service.optimization.tls.TlsPacket;
import java.io.IOException;
import java.util.List;
public class TlsPacketTest {
@Test
public void testHandshake() throws IOException, IllegalRawDataException {
List<Packet> packets = new PackmateDumpFileLoader("tls-wolfram.pkmt").getPackets();
for (int i = 0; i < packets.size(); i++) {
Packet packet = packets.get(i);
System.out.println("Packet " + i + ", incoming: " + packet.isIncoming());
byte[] content = packet.getContent();
TlsPacket tlsPacket = TlsPacket.newPacket(content, 0, content.length);
System.out.println(tlsPacket.toString());
}
}
}