본문 바로가기
Web/Spring

[Spring Security] jwt RefreshToken 구현하기

by 수짱수짱 2024. 1. 29.

프로젝트를 수행하던 중 access token만으로는 사용자에게 편리한 로그인을 제공하기가 어려움을 느끼게 되었다

가장 크게 느낀 이유론 access token 특성상 지속시간이 짧기 때문에 만료되면 짧은 주기로 이용자가 다시 로그인해야 하는 불편함이 생기는 것이다.

이를 해결하기 위해 refreshToken을 구현해야겠다고 생각하게 되었고 그 과정을 정리해 보았다.

refreshToken 전체로직

처음에는 아래와 같은 전체로직으로 구현을 생각하게 되었다

 

나는 사실 구현하다보니 여러가지 방법이 있는 것을 알게 되었고 굳이 refreshToken을 취약하게 외부에 드러낼 필요가 없다고 생각이 들어서 아래와 같은 로직으로 구현하게 되었다

즉, 만료된 accessToken을 받았을 때 해당 accessToken을 가진 user 정보와 refresh Token 대비한다

이후 해당 refresh Token이 user의 것이 맞고 유효하다면 새로운 accessToken을 응답 header에 담아서 보내준다

만약 refresh Token이 유효하지 않다면 새로운 accessToken을 내려주지 않는다

이런 경우는 사용자가 다시 로그인을 하고 새로운 refreshToken과 accessToken을 발급받아야 한다

  • 이렇게 계속해서 db에 refreshToken이 쌓이게 되는데 배치 시스템을 통해 주기적으로 지워주어야 하는 절차가 필요해진다.
  • 이런 경우와 속도때문에 redis를 많이 사용한다

 

1️⃣ RefreshToken 구현

public interface RefreshTokenRepository extends JpaRepository<RefreshToken, Long> {
    Optional<RefreshToken> findByUserId(Long userId);
}
@Getter
@Entity
public class RefreshToken extends BaseEntity {

    @Column(name = "user_id", nullable = false)
    private Long userId;

    @Column(name = "refresh_token", nullable = false, unique = true)
    private String refreshToken;

    protected RefreshToken() {
    }

    public RefreshToken(Long userId, String refreshToken) {
        this.userId = userId;
        this.refreshToken = refreshToken;
    }

    public void update(String refreshToken) {
        this.refreshToken = refreshToken; // login시 해당 유저의 refreshToken을 업데이트 할 때 사용
    }
}

 

2️⃣ 로그인시 refreshToken 발급

  1. 해당 유저의 refresToken이 존재하지 않는다면 발급
  2. 이미 refreshToken이 존재한다면 재발급

액세스 토큰의 유효 시간은 30분, 리프레쉬 토큰의 유효 시간은 24시간으로 임시로 정의해 두었다.

AuthServiceloginTransaction을 붙여준다 (이제 refreshToken을 save해야 하거든)

    @Transactional // 추가!
    public LoginResponse login(LoginRequest request) {
        // user 검증
        String userEmail = request.email();
        UsernamePasswordAuthenticationToken authenticationToken
            = new UsernamePasswordAuthenticationToken(userEmail, request.password());
        Authentication authentication = authenticationManagerBuilder.getObject().authenticate(authenticationToken);

        // token 생성
        String accessToken = jwtTokenProvider.createAccessToken(authentication);
        String refreshToken = jwtTokenProvider.createRefreshToken(authentication);

        // refreshToken 저장
        User user = userRepository.findByEmail(userEmail).orElseThrow(() -> new RangerException(NOT_FOUND_USER));
        refreshTokenRepository.findByUserId(user.getId())
            .ifPresentOrElse(
                token -> token.update(refreshToken),
                () -> {
                    RefreshToken savedRefreshToken = new RefreshToken(user.getId(), refreshToken);
                    refreshTokenRepository.save(savedRefreshToken);
                }
            );

        return new LoginResponse(accessToken, "Bearer");
    }

 

최초 로그인 요청 시 해당 유저의 refreshToken이 잘 적용된 모습을 볼 수 있다

 

refreshToken을 이렇게 반환하는건 좀 아닌 거 같아서 반환 응답은 accessToken으로만 하도록 변경했다

 

다시 로그인을 요청했을 때도 update된 refreshToken이 잘 저장되는 모습을 확인할 수 있다

3️⃣ 이제 accessToken이 유효기간이 지났을 때 유호한 refreshToken인지를 검증하고 유효하다면 새로운 accessToken을 발급해주어야 한다

해당 빨간박스 과정이 되겠지!

 

JwtTokenProvider 를 수정해보자

아래 코드는 전부 JwtTokenProvider에 추가된 메소드이다

public Boolean validateRefreshToken(String token) {
        try {
            Jwts.parserBuilder().setSigningKey(key).build().parseClaimsJws(token);
            return true;
        } catch (ExpiredJwtException e) { // 기한 만료
            throw new ExpiredJwtException(null, null, EXPIRED_JWT_TOKEN.getMessage());
        }
    }

먼저 나는 accessToken의 유효성과 refreshToken의 유효성을 검사하는 메소드를 분리해주었다

만약 기한이 만료된 refreshToken이라면 ExpiredJwtException을 날려서 JwtFilter에서 이를 처리하도록 했다 (아래 후술)

public String createAccessToken(Authentication authentication) {
        String authorities = authentication.getAuthorities().stream()
            .map(GrantedAuthority::getAuthority)
            .collect(Collectors.joining(","));

        return Jwts.builder()
            .setSubject(authentication.getName())
            .claim(AUTHORITY, authorities)
            .setExpiration(new Date(System.currentTimeMillis() + ACCESS_EXPIRATION_TIME))
            .signWith(key, SignatureAlgorithm.HS256)
            .compact();
    }

    public String createRefreshToken(Authentication authentication) {
        String authorities = authentication.getAuthorities().stream()
            .map(GrantedAuthority::getAuthority)
            .collect(Collectors.joining(","));

        return Jwts.builder()
            .setSubject(authentication.getName())
            .claim(AUTHORITY, authorities)
            .setExpiration(new Date(System.currentTimeMillis() + REFRESH_EXPIRATION_TIME))
            .signWith(key, SignatureAlgorithm.HS256)
            .compact();
    }

그리고 accessToken을 만드는 메소드와 refreshToken을 만드는 메소드를 분리해 주었다

로직이 유효기간말고 다 똑같아서 유효기간을 인자로 주고 인자에 따른 메소드로 리팩토링해도 될 것 같은데 ..?

 

여기까지가 크게 JwtTokenProvider 에서 변한 점이다

주목해볼점은 refreshToken을 생성하는 메소드가 따로 생겼고(유효 시간이 accessToken보다 김) refreshToken을 검증하는 메소드가 따로 생겼다

 

해당 메소드들을 통해 refreshToken을 만들고 유효한 refreshToken일시 새로운 accessToken을 만들 수 있다

이를 적용한 모습을 JwtFilter에서 확인해보자

@Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
        FilterChain filterChain) throws ServletException, IOException {
        String accessToken = resolveToken(request, "Authorization");

        try 
            if (StringUtils.hasText(accessToken) && jwtTokenProvider.isValidToken(accessToken)) { // 1.
                Authentication authentication = jwtTokenProvider.getAuthentication(accessToken);
                SecurityContextHolder.getContext().setAuthentication(authentication);
            }
        } catch (ExpiredJwtException e) { // 2.
            String refreshToken = null;

            if (StringUtils.hasText(request.getHeader("Authorization"))) { // 3.
                Authentication authentication = jwtTokenProvider.getAuthentication(accessToken);
                UserPrincipal principal = (UserPrincipal)authentication.getPrincipal();
                Long userId = principal.getId();
                refreshToken = jwtTokenProvider.getRefreshToken(userId);
            }

            if (StringUtils.hasText(refreshToken) && jwtTokenProvider.validateRefreshToken(refreshToken)) { // 4.
                Authentication authentication = jwtTokenProvider.getAuthentication(refreshToken);
                String newAccessToken = jwtTokenProvider.createAccessToken(authentication);
                SecurityContextHolder.getContext().setAuthentication(authentication);

                response.setHeader(HttpHeaders.AUTHORIZATION, newAccessToken); // 5.
            }
        }

        filterChain.doFilter(request, response);
    }

    private String resolveToken(HttpServletRequest request, String header) {
        String bearerToken = request.getHeader(header);
        if (StringUtils.hasText(bearerToken) && bearerToken.startsWith(BEARER_PREFIX)) {
            return bearerToken.substring(BEARER_PREFIX.length());
        }
        return null;
    }

JwtFilter에서 가장 먼저 동작하는 filter부분이다. 각 번호에 대해 설명을 해보겠다!

  1. accessToken이 있고 해당 accessToken이 유효한경우
    1. 이런 경우는 인증을 받아와서 정상적으로 security context에 해당 권한 정보를 저장한다
  2. 2번은 ExpiredJwtException 예외가 발생한 경우
    1. 이 경우는 accessToken의 유효기간이 지난경우이다! (refreshToken아님.)
    2. accessToken의 다른 예외는 JwtException으로 터지지만 유효기간이 지나는 예외는 ExpiredJwtException으로 특정해 두었다.
  3. 해당 accessToken으로 유저 정보를 가져와서 해당 유저의 refreshToken을 가져온다
  4. 가져온 refreshToken이 존재하고 유효하다면 createAccessToken 을 통해 새로운 accessToken을 생성한다.

4-1. 만약 refreshToken이 유효하지 않다면 내가 정의해둔 jwt exception handler에 걸려 “유효기간이 지난 토큰입니다”이라는 메시지가 발생한다.

 

4-2. JwtException아래 ExpiredJwtException이 존재하기에 해당 핸들러에서 예외처리가 될 수 있다

 

5. 새로운 accessToken을 응답객체 header의 Authorization에 넣어 전달한다

성공!

 

 

더보기

부록

<JwtTokenProvider 전체 코드>

  @Component
  public class JwtTokenProvider {
      private static final Logger log = LoggerFactory.getLogger(JwtTokenProvider.class);
      private static final long ACCESS_EXPIRATION_TIME = 30 * 60 * 1000L; // 30 min
      private static final long REFRESH_EXPIRATION_TIME = 24 * 60 * 60 * 1000L; // 24 hour
      private static final String AUTHORITY = "auth";

      private final Key key;
      private final CustomUserDetailsService userDetailService;
      private final RedisUtil redisUtil;
      private final RefreshTokenRepository refreshTokenRepository;

      public JwtTokenProvider(@Value("${jwt.secret}") String secretKey, CustomUserDetailsService userDetailService,
          RedisUtil redisUtil, RefreshTokenRepository refreshTokenRepository) {
          this.userDetailService = userDetailService;
          this.redisUtil = redisUtil;
          this.refreshTokenRepository = refreshTokenRepository;
          byte[] secretByteKey = DatatypeConverter.parseBase64Binary(secretKey);
          this.key = Keys.hmacShaKeyFor(secretByteKey);
      }

      public String createAccessToken(Authentication authentication) {
          String authorities = authentication.getAuthorities().stream()
              .map(GrantedAuthority::getAuthority)
              .collect(Collectors.joining(","));

          return Jwts.builder()
              .setSubject(authentication.getName())
              .claim(AUTHORITY, authorities)
              .setExpiration(new Date(System.currentTimeMillis() + ACCESS_EXPIRATION_TIME))
              .signWith(key, SignatureAlgorithm.HS256)
              .compact();
      }

      public String createRefreshToken(Authentication authentication) {
          String authorities = authentication.getAuthorities().stream()
              .map(GrantedAuthority::getAuthority)
              .collect(Collectors.joining(","));

          return Jwts.builder()
              .setSubject(authentication.getName())
              .claim(AUTHORITY, authorities)
              .setExpiration(new Date(System.currentTimeMillis() + REFRESH_EXPIRATION_TIME))
              .signWith(key, SignatureAlgorithm.HS256)
              .compact();
      }

      public Authentication getAuthentication(String accessToken) {
          Claims claims = parseClaims(accessToken);

          if (claims.get("auth") == null) {
              log.error("JWT Exception Occurs : {}", LOGOUT_JWT_TOKEN);
              throw new RangerException(NOT_AUTHORIZED_TOKEN);
          }

          Collection<? extends GrantedAuthority> authorities =
              Arrays.stream(claims.get("auth").toString().split(","))
                  .map(SimpleGrantedAuthority::new)
                  .toList();

          String username = claims.getSubject();
          UserDetails user = userDetailService.loadUserByUsername(username);

          return new UsernamePasswordAuthenticationToken(user, "", authorities);
      }

      public boolean isValidToken(String token) {
          try {
              Jwts.parserBuilder().setSigningKey(key).build().parseClaimsJws(token);

              if (redisUtil.hasKeyBlackList(token)) {
                  log.error("JWT Exception Occurs : {}", LOGOUT_JWT_TOKEN);
                  throw new JwtException(LOGOUT_JWT_TOKEN.getMessage());
              }

              return true;
          } catch (io.jsonwebtoken.security.SecurityException | MalformedJwtException e) {
              log.error("JWT Exception Occurs : {}", NOT_CORRECT_JWT_SIGN);
              throw new JwtException(NOT_CORRECT_JWT_SIGN.getMessage());
          } catch (ExpiredJwtException e) {
              log.error("JWT Exception Occurs : {}", EXPIRED_JWT_TOKEN);
              throw new ExpiredJwtException(null, null, EXPIRED_JWT_TOKEN.getMessage());
          } catch (UnsupportedJwtException e) {
              log.error("JWT Exception Occurs : {}", NOT_SUPPORTED_JWT_TOKEN);
              throw new JwtException(NOT_SUPPORTED_JWT_TOKEN.getMessage());
          } catch (IllegalArgumentException e) {
              log.error("JWT Exception Occurs : {}", NOT_CORRECT_JWT);
              // throw new JwtException(NOT_SUPPORTED_JWT_TOKEN.getMessage()); TODO: Postman에서 로그인할 때 오류남 따라서 임시 조치
          }
          return false;
      }

      public Boolean validateRefreshToken(String token) {
          try {
              Jwts.parserBuilder().setSigningKey(key).build().parseClaimsJws(token);
              return true;
          } catch (ExpiredJwtException e) { // 기한 만료
              throw new ExpiredJwtException(null, null, EXPIRED_JWT_TOKEN.getMessage());
          }
      }

      public String getRefreshToken(Long userId) {
          RefreshToken refreshToken = refreshTokenRepository.findByUserId(userId)
              .orElseThrow(() -> new RangerException(INVALID_REFRESH_TOKEN));
          return refreshToken.getRefreshToken();
      }

      private Claims parseClaims(String accessToken) {
          try {
              return Jwts.parserBuilder().setSigningKey(key).build().parseClaimsJws(accessToken).getBody();
          } catch (ExpiredJwtException e) {
              return e.getClaims();
          }
      }
  }

 

<JwtFilter 전체코드>

@Component
public class JwtFilter extends OncePerRequestFilter {

	private static final String BEARER_PREFIX = "Bearer ";

	private final JwtTokenProvider jwtTokenProvider;

	public JwtFilter(JwtTokenProvider jwtTokenProvider) {
		this.jwtTokenProvider = jwtTokenProvider;
	}

	@Override
	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
		FilterChain filterChain) throws ServletException, IOException {
		String accessToken = resolveToken(request, "Authorization");

		try {
			if (StringUtils.hasText(accessToken) && jwtTokenProvider.isValidToken(accessToken)) {
				Authentication authentication = jwtTokenProvider.getAuthentication(accessToken);
				SecurityContextHolder.getContext().setAuthentication(authentication);
			}
		} catch (ExpiredJwtException e) {
			String refreshToken = null;

			if (StringUtils.hasText(request.getHeader("Authorization"))) {
				Authentication authentication = jwtTokenProvider.getAuthentication(accessToken);
				UserPrincipal principal = (UserPrincipal)authentication.getPrincipal();
				Long userId = principal.getId();
				refreshToken = jwtTokenProvider.getRefreshToken(userId);
			}

			if (StringUtils.hasText(refreshToken) && jwtTokenProvider.validateRefreshToken(refreshToken)) {
				Authentication authentication = jwtTokenProvider.getAuthentication(refreshToken);
				String newAccessToken = jwtTokenProvider.createAccessToken(authentication);
				SecurityContextHolder.getContext().setAuthentication(authentication);

				response.setHeader(HttpHeaders.AUTHORIZATION, newAccessToken);
			}
		}

		filterChain.doFilter(request, response);
	}

	private String resolveToken(HttpServletRequest request, String header) {
		String bearerToken = request.getHeader(header);
		if (StringUtils.hasText(bearerToken) && bearerToken.startsWith(BEARER_PREFIX)) {
			return bearerToken.substring(BEARER_PREFIX.length());
		}
		return null;
	}
}

Reference