package com.supwisdom.infras.security.reactive.jwt;

import io.jsonwebtoken.Claims;

import java.util.ArrayList;
import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.web.server.ServerWebExchange;

import com.supwisdom.infras.security.authentication.JwtAuthenticationToken;
import com.supwisdom.infras.security.utils.JWTTokenUtil;

import reactor.core.publisher.Mono;

/**
 * Stateless {@link ServerSecurityContextRepository} that rebuilds the {@link SecurityContext} from a
 * JWT on every request.
 *
 * <p>The token is taken from the {@code token} query parameter when that is present and non-empty,
 * otherwise from the {@code Authorization} header after the configured prefix (default
 * {@code Bearer}). The token's subject becomes the principal name and its comma-separated
 * {@code roles} claim becomes the granted authorities. Nothing is persisted: {@link #save} is a
 * no-op.
 */
public class JWTSecurityContextRepository implements ServerSecurityContextRepository {

  private static final Logger logger = LoggerFactory.getLogger(JWTSecurityContextRepository.class);

  /** Query parameter that may carry the token instead of the Authorization header. */
  private static final String TOKEN_QUERY_PARAMETER = "token";

  /** Claim holding the role names, as a single comma-separated string. */
  private static final String ROLES_CLAIM = "roles";

  /** Separator between role names inside the roles claim. */
  private static final String ROLES_SEPARATOR = ",";

  /**
   * Scheme expected at the start of the Authorization header. Injected by Spring; the field
   * initializer only matters when the class is instantiated outside a Spring context.
   */
  @Value("${infras.security.jwt.token.authorization.prefix:Bearer}")
  private String authorizationPrefix = "Bearer";

  private final JWTTokenUtil jwtTokenUtil;

  @Autowired
  public JWTSecurityContextRepository(JWTTokenUtil jwtTokenUtil) {
    this.jwtTokenUtil = jwtTokenUtil;
  }

  /**
   * Does nothing: the context is derived from the token on every request, so there is no
   * server-side state to store.
   */
  @Override
  public Mono<Void> save(ServerWebExchange exchange, SecurityContext context) {
    return Mono.empty();
  }

  /**
   * Builds a security context from the JWT carried by the request.
   *
   * @param exchange the current exchange
   * @return a context holding a {@link JwtAuthenticationToken}, or an empty {@link Mono} when no
   *     token is present, its claims cannot be read, or it has no subject
   */
  @Override
  public Mono<SecurityContext> load(ServerWebExchange exchange) {
    String authToken = resolveToken(exchange.getRequest());
    if (authToken == null) {
      return Mono.empty();
    }

    // Parse once and reuse the claims for both the subject and the roles.
    Claims claims = parseClaims(authToken);
    String username = (claims == null) ? null : claims.getSubject();
    if (username == null) {
      return Mono.empty();
    }

    List<GrantedAuthority> authorities = getAuthoritiesFromClaims(claims);
    Authentication authentication = new JwtAuthenticationToken(username, authToken, authorities);
    return Mono.just(new SecurityContextImpl(authentication));
  }

  /**
   * Picks the raw token from the request; the query parameter takes precedence over the header.
   *
   * @param request the incoming request
   * @return the token, or {@code null} when neither source supplies a non-empty one
   */
  private String resolveToken(ServerHttpRequest request) {
    String fromQuery = request.getQueryParams().getFirst(TOKEN_QUERY_PARAMETER);
    if (fromQuery != null && !fromQuery.isEmpty()) {
      // NOTE(review): tokens in URLs can leak via access logs and Referer headers.
      // The token value itself is deliberately never logged.
      logger.debug("JWT taken from query parameter [{}]", TOKEN_QUERY_PARAMETER);
      return fromQuery;
    }

    String fromHeader =
        extractTokenFromHeader(request.getHeaders().getFirst(HttpHeaders.AUTHORIZATION));
    logger.debug("JWT {} in Authorization header", fromHeader == null ? "not found" : "found");
    return fromHeader;
  }

  /**
   * Strips the configured scheme prefix from an Authorization header value.
   *
   * <p>The prefix is matched case-insensitively without depending on the default locale, and must
   * be followed by whitespace, so that e.g. {@code BearerXYZ} is not mistaken for
   * {@code Bearer XYZ}. An empty (or null) prefix means the header carries the bare token.
   *
   * @param authHeader the raw header value, may be {@code null}
   * @return the trimmed token, or {@code null} if the header is absent, does not start with the
   *     prefix, or carries nothing after it
   */
  private String extractTokenFromHeader(String authHeader) {
    if (authHeader == null) {
      return null;
    }

    String prefix = (authorizationPrefix == null) ? "" : authorizationPrefix.trim();
    String token;
    if (prefix.isEmpty()) {
      token = authHeader;
    } else {
      int prefixLength = prefix.length();
      boolean prefixMatches =
          authHeader.length() > prefixLength
              && authHeader.regionMatches(true, 0, prefix, 0, prefixLength)
              && Character.isWhitespace(authHeader.charAt(prefixLength));
      if (!prefixMatches) {
        return null;
      }
      token = authHeader.substring(prefixLength);
    }

    token = token.trim();
    return token.isEmpty() ? null : token;
  }

  /**
   * Reads the claims of the token via {@link JWTTokenUtil}.
   *
   * @param token the raw token
   * @return the claims, or {@code null} if {@code getClaimsFromToken} throws
   */
  private Claims parseClaims(String token) {
    try {
      return jwtTokenUtil.getClaimsFromToken(token);
    } catch (Exception e) {
      // Deliberate best-effort: an unreadable token (presumably invalid or expired) leaves the
      // request unauthenticated rather than failing it.
      logger.debug("Failed to read JWT claims: {}", e.toString());
      return null;
    }
  }

  /**
   * Converts the comma-separated {@code roles} claim into granted authorities.
   *
   * <p>Entries are trimmed and blank entries are skipped, so one malformed entry (e.g. in
   * {@code "A,,B"}) no longer discards every role.
   *
   * @param claims the token's claims, non-null
   * @return a mutable list of authorities; empty if the claim is missing or not a string
   */
  private List<GrantedAuthority> getAuthoritiesFromClaims(Claims claims) {
    List<GrantedAuthority> authorities = new ArrayList<>();

    String roles;
    try {
      roles = claims.get(ROLES_CLAIM, String.class);
    } catch (RuntimeException e) {
      // The claim exists but is not a String (e.g. a JSON array); grant nothing.
      logger.debug("Ignoring unreadable [{}] claim: {}", ROLES_CLAIM, e.toString());
      return authorities;
    }
    if (roles == null) {
      return authorities;
    }

    for (String role : roles.split(ROLES_SEPARATOR)) {
      String trimmed = role.trim();
      // SimpleGrantedAuthority rejects empty strings, so skip them instead of failing.
      if (!trimmed.isEmpty()) {
        authorities.add(new SimpleGrantedAuthority(trimmed));
      }
    }
    return authorities;
  }
}
