package com.supwisdom.infras.security.utils;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;

import java.io.IOException;
import java.security.KeyPair;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.InvalidKeySpecException;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
//import org.springframework.security.core.GrantedAuthority;
//import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.stereotype.Component;

import com.supwisdom.infras.security.cert.CertUtil;
import com.supwisdom.infras.security.token.store.redis.JWTTokenRedisStore;

/**
 * Issues, parses, validates and revokes RS512-signed JWTs.
 *
 * <p>The RSA key pair is loaded once in {@link #afterPropertiesSet()}: first from the PEM
 * properties, falling back to the configured keystore. When
 * {@code infras.security.jwt.kickout.enabled} is true, the expiration of every issued token is
 * also recorded (in Redis when a {@link JWTTokenRedisStore} bean is present, otherwise in an
 * in-process map) so that {@link #expireToken(String)} can revoke a token before its
 * {@code exp} claim is reached.
 */
@Component
public class JWTTokenUtil implements InitializingBean {

  private static final Logger logger = LoggerFactory.getLogger(JWTTokenUtil.class);

  /**
   * Fallback token -> expiration (epoch millis) store, used only when no Redis store is present.
   * FIXME: in-process only, so revocations are not shared across instances or restarts; use
   * Redis or a database.
   */
  private static final ConcurrentMap<String, Long> mapTokenExpiration =
      new ConcurrentHashMap<String, Long>();

  @Autowired(required = false)
  private JWTTokenRedisStore redisTokenStore;

  /** Value written to the {@code iss} claim. */
  @Value("${infras.security.jwt.iss:supwisdom}")
  private String issuer;

  /** Value written to the {@code jti} claim (identical for every token issued here). */
  @Value("${infras.security.jwt.jti:supwisdom-jwt}")
  private String jti;

  /** Token lifetime in seconds (default 2592000 = 30 days). */
  @Value("${infras.security.jwt.expiration:2592000}")
  private Long expiration;

  /** Whether issued tokens are tracked so they can be revoked ("kicked out"). */
  @Value("${infras.security.jwt.kickout.enabled:false}")
  private boolean kickoutEnabled;

  // NOTE(review): hard-coded default key/keystore passwords; consider requiring explicit config.
  @Value("${infras.security.jwt.key-alias:supwisdom-jwt-key}")
  private String keyAlias;
  @Value("${infras.security.jwt.key-password:kingstar}")
  private String keyPassword;

  @Value("${infras.security.jwt.key-store:}")
  private String keyStore;
  @Value("${infras.security.jwt.key-store-password:kingstar}")
  private String keyStorePassword;

  @Value("${infras.security.jwt.public-key-pem:}")
  private String publicKeyPem;
  @Value("${infras.security.jwt.private-key-pem-pkcs8:}")
  private String privateKeyPemPKCS8;

  /** Signing/verification key pair; {@code null} until (or unless) {@link #initKey()} succeeds. */
  private KeyPair keyPair;

  @Override
  public void afterPropertiesSet() throws Exception {
    this.initKey();
  }

  /**
   * Loads the RSA key pair, preferring the PEM properties and falling back to the keystore.
   *
   * <p>Failures are logged rather than thrown so that application startup is not aborted. If
   * neither source yields a key pair an error is logged, and {@link #getPublicKey()} /
   * {@link #getPrivateKey()} will throw {@link IllegalStateException} when used.
   */
  public void initKey() {
    try {
      this.keyPair = CertUtil.initKeyFromPem(publicKeyPem, privateKeyPemPKCS8);
      if (this.keyPair != null) {
        logger.debug("init keyPair from pem");
        return;
      }
    } catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
      logger.warn("init keyPair from pem failed, falling back to keyStore", e);
    }

    try {
      this.keyPair = CertUtil.initKeyFromKeyStore(keyStore, keyStorePassword, keyAlias, keyPassword);
      logger.debug("init keyPair from keyStore");
    } catch (UnrecoverableKeyException | KeyStoreException | CertificateException
        | IOException | NoSuchAlgorithmException e) {
      logger.error("init keyPair from keyStore failed", e);
    }

    if (this.keyPair == null) {
      logger.error("no JWT key pair available; tokens can be neither signed nor verified");
    }
  }

  /**
   * Returns the initialized key pair.
   *
   * @throws IllegalStateException if {@link #initKey()} could not load any key pair
   */
  private KeyPair requireKeyPair() {
    KeyPair current = this.keyPair;
    if (current == null) {
      throw new IllegalStateException(
          "JWT key pair is not initialized; check the infras.security.jwt.* key configuration");
    }
    return current;
  }

  /**
   * @return the RSA public key used to verify tokens
   * @throws IllegalStateException if no key pair was loaded
   */
  public RSAPublicKey getPublicKey() {
    return (RSAPublicKey) requireKeyPair().getPublic();
  }

  /**
   * @return the RSA private key used to sign tokens
   * @throws IllegalStateException if no key pair was loaded
   */
  public RSAPrivateKey getPrivateKey() {
    return (RSAPrivateKey) requireKeyPair().getPrivate();
  }

  /**
   * @return the public key in PEM form, as produced by {@code CertUtil.publicKeyToPem}
   */
  public String getPublicKeyPem() {
    return CertUtil.publicKeyToPem(getPublicKey());
  }

  /**
   * Records a token's expiration (epoch millis) when kickout is enabled; no-op otherwise.
   *
   * <p>Writes go to Redis when available, otherwise to the in-process map, matching where
   * {@link #loadTokenExpiration(String)} reads from.
   */
  private void storeTokenExpiration(String token, long expirationMillis) {
    if (!kickoutEnabled) {
      return;
    }

    if (redisTokenStore != null) {
      logger.debug("store <token, expiration> to Redis");
      redisTokenStore.storeTokenExpiration(token, expirationMillis);
      return;
    }

    logger.debug("store <token, expiration> to Map");
    long nowMillis = System.currentTimeMillis();
    // An entry whose expiration has passed is indistinguishable from a missing one (a missing
    // entry loads as -1, also in the past), so pruning it keeps the map bounded without
    // changing any isTokenExpired result.
    mapTokenExpiration.values().removeIf(stored -> stored < nowMillis);
    mapTokenExpiration.put(token, expirationMillis);
  }

  /**
   * Loads a token's recorded expiration (epoch millis).
   *
   * @return {@link Long#MAX_VALUE} when kickout is disabled; otherwise the stored value, with -1
   *     passed/used as the default for unknown tokens
   */
  private Long loadTokenExpiration(String token) {
    if (!kickoutEnabled) {
      return Long.MAX_VALUE;
    }

    if (redisTokenStore != null) {
      logger.debug("load <token, expiration> from Redis");
      return redisTokenStore.loadTokenExpiration(token, -1L);
    }

    logger.debug("load <token, expiration> from Map");
    return mapTokenExpiration.getOrDefault(token, -1L);
  }

  /**
   * Generates a signed token from the given claims.
   *
   * <p>{@code iss}, {@code jti}, {@code iat} and {@code exp} are set by this method and override
   * any values of the same name in {@code claims}.
   *
   * @param claims the claims to embed
   * @return the compact RS512-signed token
   * @throws IllegalStateException if no key pair was loaded
   */
  public String generateToken(Map<String, Object> claims) {
    long nowMillis = System.currentTimeMillis();
    Date expirationDate = new Date(nowMillis + expiration * 1000L);
    String token = Jwts.builder()
        .setClaims(claims)
        .setIssuer(issuer)
        .setId(jti)
        .setIssuedAt(new Date(nowMillis))
        .setExpiration(expirationDate)
        .signWith(SignatureAlgorithm.RS512, getPrivateKey())
        .compact();

    storeTokenExpiration(token, expirationDate.getTime());

    return token;
  }

  /**
   * Parses and verifies a token.
   *
   * @param token the token
   * @return its claims, or {@code null} if the token is malformed, has an invalid signature,
   *     has already expired, or no key pair is available
   */
  public Claims getClaimsFromToken(String token) {
    try {
      return Jwts.parser().setSigningKey(getPublicKey()).parseClaimsJws(token).getBody();
    } catch (Exception e) {
      // Do not log the token itself; it is a bearer credential.
      logger.debug("could not parse JWT: {}", e.toString());
      return null;
    }
  }

  /**
   * Determines whether a token must be treated as expired.
   *
   * <p>A token counts as expired when it cannot be parsed or verified (which includes tokens
   * whose {@code exp} has already passed, since the parser rejects them), when its stored
   * expiration is missing or in the past (kickout mode, e.g. after {@link #expireToken}), or
   * when its {@code exp} claim is before now. Unexpected failures, such as the token store being
   * unreachable, are also reported as expired so the check fails closed.
   *
   * @param token the token
   * @return {@code true} if the token must not be accepted
   */
  public Boolean isTokenExpired(String token) {
    Claims claims = getClaimsFromToken(token);
    if (claims == null) {
      // An unparseable, forged or already-expired token must never be reported as valid.
      return true;
    }
    try {
      Date now = new Date();

      // Check the expiration recorded in the store (revoked tokens are stored as -1).
      Long storedExpiration = loadTokenExpiration(token);
      if (storedExpiration == null || storedExpiration < now.getTime()) {
        return true;
      }

      Date claimedExpiration = claims.getExpiration();
      return claimedExpiration != null && claimedExpiration.before(now);
    } catch (RuntimeException e) {
      logger.warn("could not determine token expiration, treating token as expired", e);
      return true;
    }
  }

  /**
   * Issues a new token carrying the same claims plus a {@code created} timestamp.
   *
   * @param token the original token
   * @return the new token, or {@code null} if the original is invalid/expired or signing fails
   */
  public String refreshToken(String token) {
    Claims claims = getClaimsFromToken(token);
    if (claims == null) {
      return null;
    }
    try {
      claims.put("created", new Date());
      return generateToken(claims);
    } catch (RuntimeException e) {
      logger.warn("failed to refresh token", e);
      return null;
    }
  }

  /**
   * Validates a token for a given user.
   *
   * @param token    the token
   * @param username the expected subject ({@code sub} claim)
   * @return {@code true} if the token verifies, its subject equals {@code username}, and it is
   *     not expired; {@code false} otherwise (including for unparseable tokens)
   */
  public Boolean validateToken(String token, String username) {
    Claims claims = getClaimsFromToken(token);
    if (claims == null) {
      return false;
    }
    String sub = claims.getSubject();

    return sub != null && sub.equals(username) && !isTokenExpired(token);
  }

  /**
   * Revokes ("kicks out") a token by recording an expiration of -1.
   *
   * <p>Only has an effect when kickout is enabled; tokens already treated as expired
   * (including invalid ones) are not recorded.
   *
   * @param token the token to revoke
   */
  public void expireToken(String token) {
    if (!isTokenExpired(token)) {
      storeTokenExpiration(token, -1L);
    }
  }

}
