diff --git a/src/main/java/ru/serega6531/packmate/service/optimization/TlsDecryptor.java b/src/main/java/ru/serega6531/packmate/service/optimization/TlsDecryptor.java index 68f8e0e..f775ebe 100644 --- a/src/main/java/ru/serega6531/packmate/service/optimization/TlsDecryptor.java +++ b/src/main/java/ru/serega6531/packmate/service/optimization/TlsDecryptor.java @@ -12,6 +12,7 @@ import org.bouncycastle.tls.PRFAlgorithm; import org.bouncycastle.tls.crypto.TlsSecret; import org.bouncycastle.tls.crypto.impl.bc.BcTlsCrypto; import org.bouncycastle.tls.crypto.impl.bc.BcTlsSecret; +import org.pcap4j.packet.IllegalRawDataException; import org.pcap4j.util.ByteArrays; import ru.serega6531.packmate.model.Packet; import ru.serega6531.packmate.service.optimization.tls.TlsPacket; @@ -31,6 +32,7 @@ import javax.crypto.spec.IvParameterSpec; import javax.crypto.spec.SecretKeySpec; import java.io.ByteArrayInputStream; import java.nio.ByteBuffer; +import java.security.InvalidAlgorithmParameterException; import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.security.cert.Certificate; @@ -41,7 +43,6 @@ import java.security.interfaces.RSAPublicKey; import java.util.*; import java.util.regex.Matcher; import java.util.regex.Pattern; -import java.util.stream.Stream; @Slf4j @RequiredArgsConstructor @@ -57,21 +58,32 @@ public class TlsDecryptor { private List result; private ListMultimap tlsPackets; - private CipherSuite cipherSuite; private byte[] clientRandom; private byte[] serverRandom; - @SneakyThrows public void decryptTls() { tlsPackets = ArrayListMultimap.create(packets.size(), 1); - packets.forEach(p -> tlsPackets.putAll(p, createTlsHeaders(p))); - var clientHello = (ClientHelloHandshakeRecordContent) - getHandshake(tlsPackets.values(), HandshakeType.CLIENT_HELLO).orElseThrow(); - var serverHello = (ServerHelloHandshakeRecordContent) - getHandshake(tlsPackets.values(), HandshakeType.SERVER_HELLO).orElseThrow(); + try { + for (Packet p : packets) { + tlsPackets.putAll(p, createTlsHeaders(p)); + } + } catch (IllegalRawDataException e) { + log.warn("Failed to parse TLS packets", e); + return; + } - cipherSuite = serverHello.getCipherSuite(); + var clientHelloOpt = getHandshake(HandshakeType.CLIENT_HELLO); + var serverHelloOpt = getHandshake(HandshakeType.SERVER_HELLO); + + if (clientHelloOpt.isEmpty() || serverHelloOpt.isEmpty()) { + return; + } + + var clientHello = (ClientHelloHandshakeRecordContent) clientHelloOpt.get(); + var serverHello = (ServerHelloHandshakeRecordContent) serverHelloOpt.get(); + + CipherSuite cipherSuite = serverHello.getCipherSuite(); if (cipherSuite.name().startsWith("TLS_RSA_WITH_")) { Matcher matcher = cipherSuitePattern.matcher(cipherSuite.name()); @@ -87,8 +99,14 @@ public class TlsDecryptor { } } - private void decryptTlsRsa(String blockCipher, String hashAlgo) throws CertificateException, NoSuchPaddingException, NoSuchAlgorithmException { - RSAPublicKey publicKey = getRsaPublicKey(); + private void decryptTlsRsa(String blockCipher, String hashAlgo) { + Optional publicKeyOpt = getRsaPublicKey(); + + if (publicKeyOpt.isEmpty()) { + return; + } + + RSAPublicKey publicKey = publicKeyOpt.get(); RSAPrivateKey privateKey = keysHolder.getKey(publicKey.getModulus()); if (privateKey == null) { String n = publicKey.getModulus().toString(); @@ -96,14 +114,13 @@ public class TlsDecryptor { return; } - BcTlsSecret preMaster; - try { - preMaster = getPreMaster(privateKey); - } catch (InvalidKeyException | IllegalBlockSizeException | BadPaddingException e) { - log.warn("Failed do get pre-master key", e); + Optional preMasterOptional = getPreMaster(privateKey); + if (preMasterOptional.isEmpty()) { return; } + BcTlsSecret preMaster = preMasterOptional.get(); + byte[] randomCS = ArrayUtils.addAll(clientRandom, serverRandom); byte[] randomSC = ArrayUtils.addAll(serverRandom, clientRandom); @@ -127,11 +144,25 @@ public class TlsDecryptor { bb.get(clientIV); bb.get(serverIV); - byte[] clientFinishedEncrypted = getFinishedData(tlsPackets, true); - byte[] serverFinishedEncrypted = getFinishedData(tlsPackets, false); + Optional clientFinishedOpt = getFinishedData(true); + Optional serverFinishedOpt = getFinishedData(false); - Cipher clientCipher = createCipher(clientEncryptionKey, clientIV, clientFinishedEncrypted); - Cipher serverCipher = createCipher(serverEncryptionKey, serverIV, serverFinishedEncrypted); + if (clientFinishedOpt.isEmpty() || serverFinishedOpt.isEmpty()) { + return; + } + + byte[] clientFinishedEncrypted = clientFinishedOpt.get(); + byte[] serverFinishedEncrypted = serverFinishedOpt.get(); + + Optional clientCipherOpt = createCipher(clientEncryptionKey, clientIV, clientFinishedEncrypted); + Optional serverCipherOpt = createCipher(serverEncryptionKey, serverIV, serverFinishedEncrypted); + + if (clientCipherOpt.isEmpty() || serverCipherOpt.isEmpty()) { + return; + } + + Cipher clientCipher = clientCipherOpt.get(); + Cipher serverCipher = serverCipherOpt.get(); result = new ArrayList<>(packets.size()); @@ -169,37 +200,66 @@ public class TlsDecryptor { parsed = true; } - private BcTlsSecret getPreMaster(RSAPrivateKey privateKey) throws NoSuchAlgorithmException, NoSuchPaddingException, InvalidKeyException, IllegalBlockSizeException, BadPaddingException { - var clientKeyExchange = (BasicHandshakeRecordContent) - getHandshake(tlsPackets.values(), HandshakeType.CLIENT_KEY_EXCHANGE).orElseThrow(); + @SneakyThrows(value = {NoSuchAlgorithmException.class, NoSuchPaddingException.class}) + private Optional getPreMaster(RSAPrivateKey privateKey) { + Optional opt = getHandshake(HandshakeType.CLIENT_KEY_EXCHANGE); - byte[] encryptedPreMaster = TlsKeyUtils.getClientRsaPreMaster(clientKeyExchange.getContent(), 0); + if (opt.isEmpty()) { + return Optional.empty(); + } - Cipher rsa = Cipher.getInstance("RSA/ECB/PKCS1Padding"); - rsa.init(Cipher.DECRYPT_MODE, privateKey); - byte[] preMaster = rsa.doFinal(encryptedPreMaster); - return new BcTlsSecret(new BcTlsCrypto(null), preMaster); + var clientKeyExchange = (BasicHandshakeRecordContent) opt.get(); + + try { + byte[] encryptedPreMaster = TlsKeyUtils.getClientRsaPreMaster(clientKeyExchange.getContent(), 0); + + Cipher rsa = Cipher.getInstance("RSA/ECB/PKCS1Padding"); + rsa.init(Cipher.DECRYPT_MODE, privateKey); + byte[] preMaster = rsa.doFinal(encryptedPreMaster); + return Optional.of(new BcTlsSecret(new BcTlsCrypto(null), preMaster)); + } catch (InvalidKeyException | IllegalBlockSizeException | BadPaddingException e) { + log.warn("Failed do get pre-master key", e); + return Optional.empty(); + } } - private RSAPublicKey getRsaPublicKey() throws CertificateException { - var certificateHandshake = ((CertificateHandshakeRecordContent) - getHandshake(tlsPackets.values(), HandshakeType.CERTIFICATE).orElseThrow()); + private Optional getRsaPublicKey() { + var certificateHandshakeOpt = getHandshake(HandshakeType.CERTIFICATE); + + if (certificateHandshakeOpt.isEmpty()) { + return Optional.empty(); + } + + var certificateHandshake = (CertificateHandshakeRecordContent) certificateHandshakeOpt.get(); List chain = certificateHandshake.getRawCertificates(); byte[] rawCertificate = chain.get(0); - CertificateFactory cf = CertificateFactory.getInstance("X.509"); - Certificate certificate = cf.generateCertificate(new ByteArrayInputStream(rawCertificate)); - return (RSAPublicKey) certificate.getPublicKey(); + + try { + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + Certificate certificate = cf.generateCertificate(new ByteArrayInputStream(rawCertificate)); + RSAPublicKey publicKey = (RSAPublicKey) certificate.getPublicKey(); + return Optional.of(publicKey); + } catch (CertificateException e) { + log.warn("Error while getting certificate", e); + return Optional.empty(); + } } - @SneakyThrows - private Cipher createCipher(byte[] key, byte[] iv, byte[] initData) { + @SneakyThrows(value = {NoSuchAlgorithmException.class, NoSuchPaddingException.class}) + private Optional createCipher(byte[] key, byte[] iv, byte[] initData) { Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding"); // TLS_RSA_WITH_AES_256_CBC_SHA SecretKeySpec serverSkeySpec = new SecretKeySpec(key, "AES"); IvParameterSpec serverIvParameterSpec = new IvParameterSpec(iv); - cipher.init(Cipher.DECRYPT_MODE, serverSkeySpec, serverIvParameterSpec); - cipher.update(initData); - return cipher; + try { + cipher.init(Cipher.DECRYPT_MODE, serverSkeySpec, serverIvParameterSpec); + cipher.update(initData); + + return Optional.of(cipher); + } catch (InvalidKeyException | InvalidAlgorithmParameterException e) { + log.warn("Error decrypting TLS", e); + return Optional.empty(); + } } private byte[] clearDecodedData(byte[] decoded) { @@ -209,26 +269,22 @@ public class TlsDecryptor { return decoded; } - private byte[] getFinishedData(ListMultimap tlsPackets, boolean incoming) { - return ((BasicHandshakeRecordContent) getHandshake(tlsPackets.asMap().entrySet().stream() + private Optional getFinishedData(boolean incoming) { + var contentOpt = tlsPackets.asMap().entrySet().stream() .filter(ent -> ent.getKey().isIncoming() == incoming) .map(Map.Entry::getValue) - .flatMap(Collection::stream), HandshakeType.ENCRYPTED_HANDSHAKE_MESSAGE)) - .getContent(); - } - - private HandshakeRecordContent getHandshake(Stream stream, HandshakeType handshakeType) { - return stream.filter(p -> p.getContentType() == ContentType.HANDSHAKE) + .flatMap(Collection::stream) + .filter(p -> p.getContentType() == ContentType.HANDSHAKE) .map(p -> ((HandshakeRecord) p.getRecord())) - .filter(r -> r.getHandshakeType() == handshakeType) + .filter(r -> r.getHandshakeType() == HandshakeType.ENCRYPTED_HANDSHAKE_MESSAGE) .map(r -> ((BasicHandshakeRecordContent) r.getContent())) - .findFirst() - .orElseThrow(); + .findFirst(); + + return contentOpt.map(BasicHandshakeRecordContent::getContent); } - private Optional getHandshake(Collection packets, - HandshakeType handshakeType) { - return packets.stream() + private Optional getHandshake(HandshakeType handshakeType) { + return tlsPackets.values().stream() .filter(p -> p.getContentType() == ContentType.HANDSHAKE) .map(p -> ((HandshakeRecord) p.getRecord())) .filter(r -> r.getHandshakeType() == handshakeType) @@ -236,8 +292,7 @@ public class TlsDecryptor { .findFirst(); } - @SneakyThrows - private List createTlsHeaders(Packet p) { + private List createTlsHeaders(Packet p) throws IllegalRawDataException { List headers = new ArrayList<>(); TlsPacket tlsPacket = TlsPacket.newPacket(p.getContent(), 0, p.getContent().length); diff --git a/src/test/java/ru/serega6531/packmate/TlsDecryptorTest.java b/src/test/java/ru/serega6531/packmate/TlsDecryptorTest.java index cb32b14..f23c447 100644 --- a/src/test/java/ru/serega6531/packmate/TlsDecryptorTest.java +++ b/src/test/java/ru/serega6531/packmate/TlsDecryptorTest.java @@ -9,6 +9,8 @@ import java.io.File; import java.io.IOException; import java.util.List; +import static org.junit.jupiter.api.Assertions.*; + public class TlsDecryptorTest { @Test @@ -21,6 +23,14 @@ public class TlsDecryptorTest { TlsDecryptor decryptor = new TlsDecryptor(packets, keysHolder); decryptor.decryptTls(); + + assertTrue(decryptor.isParsed(), "TLS not parsed"); + List parsed = decryptor.getParsedPackets(); + assertNotNull(parsed, "Parsed packets list is null"); + assertEquals(4, parsed.size(), "Wrong packets list size"); + + assertTrue(new String(parsed.get(0).getContent()).contains("GET /"), "Wrong content at the start"); + assertTrue(new String(parsed.get(3).getContent()).contains("Not Found"), "Wrong content at the end"); } } diff --git a/src/test/java/ru/serega6531/packmate/TlsPacketTest.java b/src/test/java/ru/serega6531/packmate/TlsPacketTest.java deleted file mode 100644 index b4bdc62..0000000 --- a/src/test/java/ru/serega6531/packmate/TlsPacketTest.java +++ /dev/null @@ -1,26 +0,0 @@ -package ru.serega6531.packmate; - -import org.junit.jupiter.api.Test; -import org.pcap4j.packet.IllegalRawDataException; -import ru.serega6531.packmate.model.Packet; -import ru.serega6531.packmate.service.optimization.tls.TlsPacket; - -import java.io.IOException; -import java.util.List; - -public class TlsPacketTest { - - @Test - public void testHandshake() throws IOException, IllegalRawDataException { - List packets = new PackmateDumpFileLoader("tls-wolfram.pkmt").getPackets(); - - for (int i = 0; i < packets.size(); i++) { - Packet packet = packets.get(i); - System.out.println("Packet " + i + ", incoming: " + packet.isIncoming()); - byte[] content = packet.getContent(); - TlsPacket tlsPacket = TlsPacket.newPacket(content, 0, content.length); - System.out.println(tlsPacket.toString()); - } - } - -}