diff --git a/src/main/java/ru/serega6531/packmate/model/Packet.java b/src/main/java/ru/serega6531/packmate/model/Packet.java index cc60bd2..603202f 100644 --- a/src/main/java/ru/serega6531/packmate/model/Packet.java +++ b/src/main/java/ru/serega6531/packmate/model/Packet.java @@ -54,4 +54,14 @@ public class Packet { private byte[] content; + @Transient + @JsonIgnore + public String getContentString() { + return new String(content); + } + + public String toString() { + return "Packet(id=" + id + ", content=" + getContentString() + ")"; + } + } diff --git a/src/main/java/ru/serega6531/packmate/service/StreamOptimizer.java b/src/main/java/ru/serega6531/packmate/service/StreamOptimizer.java index 6274a97..a161afe 100644 --- a/src/main/java/ru/serega6531/packmate/service/StreamOptimizer.java +++ b/src/main/java/ru/serega6531/packmate/service/StreamOptimizer.java @@ -5,13 +5,6 @@ import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.ArrayUtils; -import org.java_websocket.drafts.Draft_6455; -import org.java_websocket.exceptions.InvalidDataException; -import org.java_websocket.exceptions.InvalidHandshakeException; -import org.java_websocket.extensions.permessage_deflate.PerMessageDeflateExtension; -import org.java_websocket.framing.Framedata; -import org.java_websocket.handshake.HandshakeImpl1Client; -import org.java_websocket.handshake.HandshakeImpl1Server; import ru.serega6531.packmate.model.CtfService; import ru.serega6531.packmate.model.Packet; import ru.serega6531.packmate.utils.Bytes; @@ -20,12 +13,9 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.net.URLDecoder; -import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.List; -import java.util.regex.Matcher; -import java.util.stream.Collectors; import java.util.zip.GZIPInputStream; import java.util.zip.ZipException; @@ -34,43 +24,31 @@ import java.util.zip.ZipException; public class StreamOptimizer { private final CtfService service; - private final List packets; + private List packets; private static final byte[] GZIP_HEADER = {0x1f, (byte) 0x8b, 0x08}; - private static final java.util.regex.Pattern WEBSOCKET_KEY_PATTERN = - java.util.regex.Pattern.compile("Sec-WebSocket-Key: (.+)\\r\\n"); - private static final java.util.regex.Pattern WEBSOCKET_EXTENSIONS_PATTERN = - java.util.regex.Pattern.compile("Sec-WebSocket-Extensions?: (.+)\\r\\n"); - private static final java.util.regex.Pattern WEBSOCKET_VERSION_PATTERN = - java.util.regex.Pattern.compile("Sec-WebSocket-Version: (\\d+)\\r\\n"); - private static final java.util.regex.Pattern WEBSOCKET_ACCEPT_PATTERN = - java.util.regex.Pattern.compile("Sec-WebSocket-Accept: (.+)\\r\\n"); - - private static final String WEBSOCKET_EXTENSION_HEADER = "Sec-WebSocket-Extension: permessage-deflate"; - private static final String WEBSOCKET_EXTENSIONS_HEADER = "Sec-WebSocket-Extensions: permessage-deflate"; - private static final String WEBSOCKET_UPGRADE_HEADER = "Upgrade: websocket\r\n"; - private static final String WEBSOCKET_CONNECTION_HEADER = "Connection: Upgrade\r\n"; - /** * Вызвать для выполнения оптимизаций на переданном списке пакетов. */ - public void optimizeStream() { + public List optimizeStream() { if (service.isUngzipHttp()) { unpackGzip(); } - if (service.isUrldecodeHttpRequests()) { - urldecodeRequests(); - } - if (service.isInflateWebSockets()) { inflateWebSocket(); } + if (service.isUrldecodeHttpRequests()) { + urldecodeRequests(); + } + if (service.isMergeAdjacentPackets()) { mergeAdjacentPackets(); } + + return packets; } /** @@ -136,7 +114,7 @@ public class StreamOptimizer { for (Packet packet : packets) { if (packet.isIncoming()) { - String content = new String(packet.getContent()); + String content = packet.getContentString(); if (content.contains("HTTP/")) { httpStarted = true; } @@ -173,7 +151,7 @@ public class StreamOptimizer { i = gzipStartPacket + 1; // продвигаем указатель на следующий после склеенного блок } } else if (!packet.isIncoming()) { - String content = new String(packet.getContent()); + String content = packet.getContentString(); int contentPos = content.indexOf("\r\n\r\n"); boolean http = content.startsWith("HTTP/"); @@ -255,147 +233,16 @@ public class StreamOptimizer { } private void inflateWebSocket() { - if (!new String(packets.get(0).getContent()).contains("HTTP/")) { + if (!packets.get(0).getContentString().contains("HTTP/")) { return; } - final List clientHandshakePackets = packets.stream() - .takeWhile(Packet::isIncoming) - .collect(Collectors.toList()); - - final String clientHandshake = getHandshake(clientHandshakePackets); - if (clientHandshake == null) { + final WebSocketsParser parser = new WebSocketsParser(packets); + if(!parser.isParsed()) { return; } - final List serverHandshakePackets = packets.stream() - .skip(clientHandshakePackets.size()) - .takeWhile(p -> !p.isIncoming()) - .collect(Collectors.toList()); - - final String serverHandshake = getHandshake(serverHandshakePackets); - if (serverHandshake == null) { - return; - } - - HandshakeImpl1Server serverHandshakeImpl = fillServerHandshake(serverHandshake); - HandshakeImpl1Client clientHandshakeImpl = fillClientHandshake(clientHandshake); - - if (serverHandshakeImpl == null || clientHandshakeImpl == null) { - return; - } - - final List wsPackets = this.packets.subList( - clientHandshakePackets.size() + serverHandshakePackets.size(), - this.packets.size()); - - final byte[] wsContent = wsPackets.stream() - .map(Packet::getContent) - .reduce(ArrayUtils::addAll) - .orElse(null); - - if (wsContent == null) { - return; - } - - final ByteBuffer frame = ByteBuffer.wrap(wsContent); - Draft_6455 draft = new Draft_6455(new PerMessageDeflateExtension()); - - try { - draft.acceptHandshakeAsServer(clientHandshakeImpl); - draft.acceptHandshakeAsClient(clientHandshakeImpl, serverHandshakeImpl); - } catch (InvalidHandshakeException e) { - log.warn("WebSocket handshake", e); - return; - } - - try { - final List list = draft.translateFrame(frame); - log.info(list.toString()); - } catch (InvalidDataException e) { - log.warn("WebSocket data", e); - } - } - - private String getHandshake(final List packets) { - final String handshake = packets.stream() - .map(Packet::getContent) - .reduce(ArrayUtils::addAll) - .map(String::new) - .orElse(null); - - if (handshake == null || - !handshake.contains(WEBSOCKET_CONNECTION_HEADER) || - !handshake.contains(WEBSOCKET_UPGRADE_HEADER)) { - return null; - } - - if (!handshake.contains(WEBSOCKET_EXTENSION_HEADER) && - !handshake.contains(WEBSOCKET_EXTENSIONS_HEADER)) { - return null; - } - - return handshake; - } - - private HandshakeImpl1Client fillClientHandshake(String clientHandshake) { - Matcher matcher = WEBSOCKET_VERSION_PATTERN.matcher(clientHandshake); - if (!matcher.find()) { - return null; - } - String version = matcher.group(1); - - matcher = WEBSOCKET_KEY_PATTERN.matcher(clientHandshake); - if (!matcher.find()) { - return null; - } - String key = matcher.group(1); - - matcher = WEBSOCKET_EXTENSIONS_PATTERN.matcher(clientHandshake); - if (!matcher.find()) { - return null; - } - String extensions = matcher.group(1); - - HandshakeImpl1Client clientHandshakeImpl = new HandshakeImpl1Client(); - - clientHandshakeImpl.put("Upgrade", "websocket"); - clientHandshakeImpl.put("Connection", "Upgrade"); - clientHandshakeImpl.put("Sec-WebSocket-Version", version); - clientHandshakeImpl.put("Sec-WebSocket-Key", key); - clientHandshakeImpl.put("Sec-WebSocket-Extensions", extensions); - - return clientHandshakeImpl; - } - - private HandshakeImpl1Server fillServerHandshake(String serverHandshake) { - Matcher matcher = WEBSOCKET_VERSION_PATTERN.matcher(serverHandshake); - if (!matcher.find()) { - return null; - } - String version = matcher.group(1); - - matcher = WEBSOCKET_ACCEPT_PATTERN.matcher(serverHandshake); - if (!matcher.find()) { - return null; - } - String accept = matcher.group(1); - - matcher = WEBSOCKET_EXTENSIONS_PATTERN.matcher(serverHandshake); - if (!matcher.find()) { - return null; - } - String extensions = matcher.group(1); - - HandshakeImpl1Server serverHandshakeImpl = new HandshakeImpl1Server(); - - serverHandshakeImpl.put("Upgrade", "websocket"); - serverHandshakeImpl.put("Connection", "Upgrade"); - serverHandshakeImpl.put("Sec-WebSocket-Version", version); - serverHandshakeImpl.put("Sec-WebSocket-Accept", accept); - serverHandshakeImpl.put("Sec-WebSocket-Extensions", extensions); - - return serverHandshakeImpl; + packets = parser.getParsedPackets(); } } diff --git a/src/main/java/ru/serega6531/packmate/service/StreamService.java b/src/main/java/ru/serega6531/packmate/service/StreamService.java index 19afa74..8d0fc25 100644 --- a/src/main/java/ru/serega6531/packmate/service/StreamService.java +++ b/src/main/java/ru/serega6531/packmate/service/StreamService.java @@ -93,7 +93,7 @@ public class StreamService { countingService.countStream(service.getPort(), packets.size()); - new StreamOptimizer(service, packets).optimizeStream(); + packets = new StreamOptimizer(service, packets).optimizeStream(); processUserAgent(packets, stream); Stream savedStream = save(stream); @@ -110,7 +110,7 @@ public class StreamService { private void processUserAgent(List packets, Stream stream) { String ua = null; for (Packet packet : packets) { - String content = new String(packet.getContent()); + String content = packet.getContentString(); final Matcher matcher = userAgentPattern.matcher(content); if (matcher.find()) { ua = matcher.group(1); diff --git a/src/main/java/ru/serega6531/packmate/service/WebSocketsParser.java b/src/main/java/ru/serega6531/packmate/service/WebSocketsParser.java new file mode 100644 index 0000000..d8bde27 --- /dev/null +++ b/src/main/java/ru/serega6531/packmate/service/WebSocketsParser.java @@ -0,0 +1,221 @@ +package ru.serega6531.packmate.service; + +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.ArrayUtils; +import org.java_websocket.drafts.Draft_6455; +import org.java_websocket.exceptions.InvalidDataException; +import org.java_websocket.exceptions.InvalidHandshakeException; +import org.java_websocket.extensions.permessage_deflate.PerMessageDeflateExtension; +import org.java_websocket.framing.DataFrame; +import org.java_websocket.framing.Framedata; +import org.java_websocket.handshake.HandshakeImpl1Client; +import org.java_websocket.handshake.HandshakeImpl1Server; +import ru.serega6531.packmate.model.Packet; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Matcher; +import java.util.stream.Collectors; + +@Slf4j +public class WebSocketsParser { + + private static final java.util.regex.Pattern WEBSOCKET_KEY_PATTERN = + java.util.regex.Pattern.compile("Sec-WebSocket-Key: (.+)\\r\\n"); + private static final java.util.regex.Pattern WEBSOCKET_EXTENSIONS_PATTERN = + java.util.regex.Pattern.compile("Sec-WebSocket-Extensions?: (.+)\\r\\n"); + private static final java.util.regex.Pattern WEBSOCKET_VERSION_PATTERN = + java.util.regex.Pattern.compile("Sec-WebSocket-Version: (\\d+)\\r\\n"); + private static final java.util.regex.Pattern WEBSOCKET_ACCEPT_PATTERN = + java.util.regex.Pattern.compile("Sec-WebSocket-Accept: (.+)\\r\\n"); + + private static final String WEBSOCKET_EXTENSION_HEADER = "Sec-WebSocket-Extension: permessage-deflate"; + private static final String WEBSOCKET_EXTENSIONS_HEADER = "Sec-WebSocket-Extensions: permessage-deflate"; + private static final String WEBSOCKET_UPGRADE_HEADER = "upgrade: websocket\r\n"; + private static final String WEBSOCKET_CONNECTION_HEADER = "connection: upgrade\r\n"; + + private final List packets; + private List frames; + + @Getter + private boolean parsed = false; + private int httpEnd = -1; + + public WebSocketsParser(List packets) { + this.packets = packets; + detectWebSockets(); + } + + private void detectWebSockets() { + final List clientHandshakePackets = packets.stream() + .takeWhile(Packet::isIncoming) + .collect(Collectors.toList()); + + final String clientHandshake = getHandshake(clientHandshakePackets); + if (clientHandshake == null) { + return; + } + + for (int i = clientHandshakePackets.size(); i < packets.size(); i++) { + if (packets.get(i).getContentString().endsWith("\r\n\r\n")) { + httpEnd = i + 1; + break; + } + } + + if (httpEnd == -1) { + return; + } + + final List serverHandshakePackets = packets.subList(clientHandshakePackets.size(), httpEnd); + final String serverHandshake = getHandshake(serverHandshakePackets); + if (serverHandshake == null) { + return; + } + + HandshakeImpl1Server serverHandshakeImpl = fillServerHandshake(serverHandshake); + HandshakeImpl1Client clientHandshakeImpl = fillClientHandshake(clientHandshake); + + if (serverHandshakeImpl == null || clientHandshakeImpl == null) { + return; + } + + Draft_6455 draft = new Draft_6455(new PerMessageDeflateExtension()); + + try { + draft.acceptHandshakeAsServer(clientHandshakeImpl); + draft.acceptHandshakeAsClient(clientHandshakeImpl, serverHandshakeImpl); + } catch (InvalidHandshakeException e) { + log.warn("WebSocket handshake", e); + return; + } + + final List wsPackets = this.packets.subList( + httpEnd, + this.packets.size()); + + final byte[] wsContent = wsPackets.stream() + .map(Packet::getContent) + .reduce(ArrayUtils::addAll) + .orElse(null); + + if (wsContent == null) { + return; + } + + final ByteBuffer frame = ByteBuffer.wrap(wsContent); + + try { + frames = draft.translateFrame(frame); + } catch (InvalidDataException e) { + log.warn("WebSocket data", e); + return; + } + + parsed = true; + } + + public List getParsedPackets() { + if (!parsed) { + throw new IllegalStateException("WS is not parsed"); + } + + final List handshakes = packets.subList(0, httpEnd); + List newPackets = new ArrayList<>(handshakes.size() + frames.size()); + newPackets.addAll(handshakes); + + final Packet lastPacket = packets.get(packets.size() - 1); + + for (Framedata frame : frames) { + if(frame instanceof DataFrame) { + newPackets.add(Packet.builder() + .content(frame.getPayloadData().array()) + .incoming(true) //TODO + .timestamp(lastPacket.getTimestamp()) + .ttl(lastPacket.getTtl()) + .ungzipped(lastPacket.isUngzipped()) + .build() + ); + } + } + + return newPackets; + } + + private String getHandshake(final List packets) { + final String handshake = packets.stream() + .map(Packet::getContent) + .reduce(ArrayUtils::addAll) + .map(String::new) + .orElse(null); + + if (handshake == null || + !handshake.toLowerCase().contains(WEBSOCKET_CONNECTION_HEADER) || + !handshake.toLowerCase().contains(WEBSOCKET_UPGRADE_HEADER)) { + return null; + } + + if (!handshake.contains(WEBSOCKET_EXTENSION_HEADER) && + !handshake.contains(WEBSOCKET_EXTENSIONS_HEADER)) { + return null; + } + + return handshake; + } + + private HandshakeImpl1Client fillClientHandshake(String clientHandshake) { + Matcher matcher = WEBSOCKET_VERSION_PATTERN.matcher(clientHandshake); + if (!matcher.find()) { + return null; + } + String version = matcher.group(1); + + matcher = WEBSOCKET_KEY_PATTERN.matcher(clientHandshake); + if (!matcher.find()) { + return null; + } + String key = matcher.group(1); + + matcher = WEBSOCKET_EXTENSIONS_PATTERN.matcher(clientHandshake); + if (!matcher.find()) { + return null; + } + String extensions = matcher.group(1); + + HandshakeImpl1Client clientHandshakeImpl = new HandshakeImpl1Client(); + + clientHandshakeImpl.put("Upgrade", "websocket"); + clientHandshakeImpl.put("Connection", "Upgrade"); + clientHandshakeImpl.put("Sec-WebSocket-Version", version); + clientHandshakeImpl.put("Sec-WebSocket-Key", key); + clientHandshakeImpl.put("Sec-WebSocket-Extensions", extensions); + + return clientHandshakeImpl; + } + + private HandshakeImpl1Server fillServerHandshake(String serverHandshake) { + Matcher matcher = WEBSOCKET_ACCEPT_PATTERN.matcher(serverHandshake); + if (!matcher.find()) { + return null; + } + String accept = matcher.group(1); + + matcher = WEBSOCKET_EXTENSIONS_PATTERN.matcher(serverHandshake); + if (!matcher.find()) { + return null; + } + String extensions = matcher.group(1); + + HandshakeImpl1Server serverHandshakeImpl = new HandshakeImpl1Server(); + + serverHandshakeImpl.put("Upgrade", "websocket"); + serverHandshakeImpl.put("Connection", "Upgrade"); + serverHandshakeImpl.put("Sec-WebSocket-Accept", accept); + serverHandshakeImpl.put("Sec-WebSocket-Extensions", extensions); + + return serverHandshakeImpl; + } + +} diff --git a/src/test/java/ru/serega6531/packmate/StreamOptimizerTest.java b/src/test/java/ru/serega6531/packmate/StreamOptimizerTest.java index ecdbf5c..67bde4a 100644 --- a/src/test/java/ru/serega6531/packmate/StreamOptimizerTest.java +++ b/src/test/java/ru/serega6531/packmate/StreamOptimizerTest.java @@ -28,8 +28,8 @@ class StreamOptimizerTest { List list = new ArrayList<>(); list.add(p); - new StreamOptimizer(service, list).optimizeStream(); - final String processed = new String(list.get(0).getContent()); + list = new StreamOptimizer(service, list).optimizeStream(); + final String processed = list.get(0).getContentString(); assertTrue(processed.contains("aaabbb")); } @@ -42,8 +42,8 @@ class StreamOptimizerTest { List list = new ArrayList<>(); list.add(p); - new StreamOptimizer(service, list).optimizeStream(); - final String processed = new String(list.get(0).getContent()); + list = new StreamOptimizer(service, list).optimizeStream(); + final String processed = list.get(0).getContentString(); assertTrue(processed.contains("а б")); } @@ -67,7 +67,7 @@ class StreamOptimizerTest { list.add(p5); list.add(p6); - new StreamOptimizer(service, list).optimizeStream(); + list = new StreamOptimizer(service, list).optimizeStream(); assertEquals(4, list.size()); assertEquals(2, list.get(1).getContent().length);