SUPPORT-9023: move socket connection to listener; fix NPE; add retryable annotation

This commit is contained in:
gulnaz 2025-03-24 12:45:14 +03:00
parent b7e59faa06
commit 8350cde23b
7 changed files with 76 additions and 24 deletions

View file

@ -224,6 +224,14 @@
<groupId>org.hibernate.validator</groupId> <groupId>org.hibernate.validator</groupId>
<artifactId>hibernate-validator</artifactId> <artifactId>hibernate-validator</artifactId>
</dependency> </dependency>
<dependency>
<groupId>org.springframework.retry</groupId>
<artifactId>spring-retry</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-aspects</artifactId>
</dependency>
</dependencies> </dependencies>
<build> <build>
<finalName>${project.parent.artifactId}</finalName> <finalName>${project.parent.artifactId}</finalName>

View file

@ -16,6 +16,7 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.EnableAspectJAutoProxy; import org.springframework.context.annotation.EnableAspectJAutoProxy;
import org.springframework.context.annotation.FilterType; import org.springframework.context.annotation.FilterType;
import org.springframework.context.support.PropertySourcesPlaceholderConfigurer; import org.springframework.context.support.PropertySourcesPlaceholderConfigurer;
import org.springframework.retry.annotation.EnableRetry;
import org.springframework.scheduling.annotation.EnableScheduling; import org.springframework.scheduling.annotation.EnableScheduling;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
@ -47,6 +48,7 @@ import org.springframework.web.servlet.config.annotation.EnableWebMvc;
@EnableAspectJAutoProxy(proxyTargetClass = true) @EnableAspectJAutoProxy(proxyTargetClass = true)
@EnableWebMvc @EnableWebMvc
@EnableScheduling @EnableScheduling
@EnableRetry
public class AppConfig { public class AppConfig {
@Bean @Bean

View file

@ -5,7 +5,6 @@ import java.lang.invoke.MethodHandles;
import java.net.URLDecoder; import java.net.URLDecoder;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.Executors;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
@ -16,7 +15,6 @@ import org.slf4j.LoggerFactory;
import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
@ -44,13 +42,7 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
try { try {
Authentication authentication = attemptAuthentication(request); Authentication authentication = attemptAuthentication(request);
if (authentication != null) { if (authentication != null) {
SecurityContext context = SecurityContextHolder.getContext(); SecurityContextHolder.getContext().setAuthentication(authentication);
context.setAuthentication(authentication);
//TODO SUPPORT-9009 connection by duty user
new Thread(() -> {
SecurityContextHolder.setContext(context);
webSocketService.connectToSocket();
}).start();
} }
} }
catch (AuthenticationException e) { catch (AuthenticationException e) {

View file

@ -0,0 +1,34 @@
package ru.micord.ervu.account_applications.security.listener;
import org.springframework.context.ApplicationListener;
import org.springframework.security.authentication.event.AuthenticationSuccessEvent;
import org.springframework.stereotype.Component;
import ru.micord.ervu.account_applications.security.model.jwt.UserSession;
import ru.micord.ervu.account_applications.security.model.jwt.authentication.JwtTokenAuthentication;
import ru.micord.ervu.account_applications.websocket.service.WebSocketService;
/**
* @author gulnaz
*/
@Component
public class SuccessfulAuthListener implements ApplicationListener<AuthenticationSuccessEvent> {
private static final String ADMIN_ROLE = "security_administrator";
private final WebSocketService webSocketService;
public SuccessfulAuthListener(WebSocketService webSocketService) {
this.webSocketService = webSocketService;
}
@Override
public void onApplicationEvent(AuthenticationSuccessEvent event) {
JwtTokenAuthentication authentication = (JwtTokenAuthentication) event.getAuthentication();
UserSession userSession = authentication.getUserSession();
boolean isAdmin = userSession.roles().stream()
.anyMatch(ervuRoleAuthority -> ervuRoleAuthority.getAuthority().equals(ADMIN_ROLE));
if (isAdmin) {
webSocketService.connectToSocket(userSession.userId(), authentication.getCredentials().toString());
}
}
}

View file

@ -15,7 +15,6 @@ import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler; import org.springframework.web.socket.handler.TextWebSocketHandler;
import ru.micord.ervu.account_applications.security.context.SecurityContext;
import ru.micord.ervu.account_applications.security.service.EncryptionService; import ru.micord.ervu.account_applications.security.service.EncryptionService;
import ru.micord.ervu.account_applications.service.UserApplicationListService; import ru.micord.ervu.account_applications.service.UserApplicationListService;
import ru.micord.ervu.account_applications.websocket.dto.ProcessResponseDto; import ru.micord.ervu.account_applications.websocket.dto.ProcessResponseDto;
@ -28,23 +27,21 @@ import ru.micord.ervu.account_applications.websocket.service.WebSocketService;
public class ClientSocketHandler extends TextWebSocketHandler { public class ClientSocketHandler extends TextWebSocketHandler {
private static final Logger LOGGER = LoggerFactory.getLogger(TextWebSocketHandler.class); private static final Logger LOGGER = LoggerFactory.getLogger(TextWebSocketHandler.class);
private static final Map<String, WebSocketSession> sessionByUserId = new ConcurrentHashMap<>(); private static final Map<String, WebSocketSession> sessionByUserId = new ConcurrentHashMap<>();
private static final Map<String, UserData> userDataBySessionId = new ConcurrentHashMap<>();
private final ObjectMapper objectMapper; private final ObjectMapper objectMapper;
private final UserApplicationListService applicationService; private final UserApplicationListService applicationService;
private final EncryptionService encryptionService; private final EncryptionService encryptionService;
private final WebSocketService webSocketService; private final WebSocketService webSocketService;
private final SecurityContext securityContext;
public ClientSocketHandler(ObjectMapper objectMapper, public ClientSocketHandler(ObjectMapper objectMapper,
UserApplicationListService applicationService, UserApplicationListService applicationService,
EncryptionService encryptionService, EncryptionService encryptionService,
@Lazy WebSocketService webSocketService, @Lazy WebSocketService webSocketService) {
SecurityContext securityContext) {
this.objectMapper = objectMapper; this.objectMapper = objectMapper;
this.applicationService = applicationService; this.applicationService = applicationService;
this.encryptionService = encryptionService; this.encryptionService = encryptionService;
this.webSocketService = webSocketService; this.webSocketService = webSocketService;
this.securityContext = securityContext;
} }
@Override @Override
@ -104,7 +101,10 @@ public class ClientSocketHandler extends TextWebSocketHandler {
LOGGER.error("Failed to close session on afterConnectionClosed ", e); LOGGER.error("Failed to close session on afterConnectionClosed ", e);
} }
} }
webSocketService.connectToSocket(); String sessionId = session.getId();
UserData userData = userDataBySessionId.get(sessionId);
userDataBySessionId.remove(sessionId);
webSocketService.connectToSocket(userData.userId(), userData.token());
} }
public boolean isSessionOpen(String userId) { public boolean isSessionOpen(String userId) {
@ -114,4 +114,10 @@ public class ClientSocketHandler extends TextWebSocketHandler {
public void putSession(String userId, WebSocketSession session) { public void putSession(String userId, WebSocketSession session) {
sessionByUserId.put(userId, session); sessionByUserId.put(userId, session);
} }
public void putUserData(String sessionId, String userId, String token) {
userDataBySessionId.put(sessionId, new UserData(userId, token));
}
private record UserData(String userId, String token) {}
} }

View file

@ -9,12 +9,12 @@ import java.util.concurrent.TimeoutException;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.retry.annotation.Retryable;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketHttpHeaders; import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.WebSocketClient; import org.springframework.web.socket.client.WebSocketClient;
import ru.micord.ervu.account_applications.security.context.SecurityContext;
import ru.micord.ervu.account_applications.websocket.handler.ClientSocketHandler; import ru.micord.ervu.account_applications.websocket.handler.ClientSocketHandler;
/** /**
@ -26,7 +26,6 @@ public class WebSocketService {
private final WebSocketClient webSocketClient; private final WebSocketClient webSocketClient;
private final WebSocketHandler webSocketHandler; private final WebSocketHandler webSocketHandler;
private final SecurityContext securityContext;
@Value("${ervu.url}") @Value("${ervu.url}")
private String ervuUrl; private String ervuUrl;
@ -35,22 +34,22 @@ public class WebSocketService {
@Value("${ervu.socket.connection_timeout:30}") @Value("${ervu.socket.connection_timeout:30}")
private long timeout; private long timeout;
public WebSocketService(WebSocketClient webSocketClient, WebSocketHandler webSocketHandler, public WebSocketService(WebSocketClient webSocketClient, WebSocketHandler webSocketHandler) {
SecurityContext securityContext) {
this.webSocketClient = webSocketClient; this.webSocketClient = webSocketClient;
this.webSocketHandler = webSocketHandler; this.webSocketHandler = webSocketHandler;
this.securityContext = securityContext;
} }
public void connectToSocket() { @Retryable(
retryFor = {ExecutionException.class, TimeoutException.class},
maxAttemptsExpression = "${socket.connect.max_attempts:3}")
public void connectToSocket(String userId, String token) {
ClientSocketHandler socketHandler = (ClientSocketHandler) this.webSocketHandler; ClientSocketHandler socketHandler = (ClientSocketHandler) this.webSocketHandler;
if (socketHandler.isSessionOpen(securityContext.getUserId())) { if (socketHandler.isSessionOpen(userId)) {
return; return;
} }
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
headers.set("Content-Type", "application/json"); headers.set("Content-Type", "application/json");
String token = securityContext.getToken();
headers.add("Authorization", "Bearer " + token); headers.add("Authorization", "Bearer " + token);
headers.add("Cookie", "JWT=" + token); // to listen private messages headers.add("Cookie", "JWT=" + token); // to listen private messages
@ -58,7 +57,8 @@ public class WebSocketService {
String host = new URI(ervuUrl).getHost(); String host = new URI(ervuUrl).getHost();
WebSocketSession session = webSocketClient.doHandshake(this.webSocketHandler, headers, WebSocketSession session = webSocketClient.doHandshake(this.webSocketHandler, headers,
URI.create("wss://" + host + socketQueue)).get(timeout, TimeUnit.SECONDS); URI.create("wss://" + host + socketQueue)).get(timeout, TimeUnit.SECONDS);
socketHandler.putSession(securityContext.getUserId(), session); socketHandler.putSession(userId, session);
socketHandler.putUserData(session.getId(), userId, token);
} }
catch (InterruptedException | ExecutionException | URISyntaxException | TimeoutException e) { catch (InterruptedException | ExecutionException | URISyntaxException | TimeoutException e) {
LOGGER.error("Failed to connect socket", e); LOGGER.error("Failed to connect socket", e);

10
pom.xml
View file

@ -307,6 +307,16 @@
<artifactId>hibernate-validator</artifactId> <artifactId>hibernate-validator</artifactId>
<version>9.0.0.CR1</version> <version>9.0.0.CR1</version>
</dependency> </dependency>
<dependency>
<groupId>org.springframework.retry</groupId>
<artifactId>spring-retry</artifactId>
<version>2.0.11</version>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-aspects</artifactId>
<version>6.2.5</version>
</dependency>
</dependencies> </dependencies>
</dependencyManagement> </dependencyManagement>
<repositories> <repositories>