package com.supwisdom.dlpay.framework.core;

import org.jose4j.jwa.AlgorithmConstraints;
import org.jose4j.jwk.JsonWebKey;
import org.jose4j.jws.AlgorithmIdentifiers;
import org.jose4j.jws.JsonWebSignature;
import org.jose4j.jwt.JwtClaims;
import org.jose4j.jwt.MalformedClaimException;
import org.jose4j.jwt.NumericDate;
import org.jose4j.jwt.consumer.InvalidJwtException;
import org.jose4j.jwt.consumer.JwtConsumer;
import org.jose4j.jwt.consumer.JwtConsumerBuilder;
import org.jose4j.lang.JoseException;
import org.springframework.security.core.userdetails.UserDetails;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class JwtTokenUtil {
  private JwtConfig jwtConfig;

  public JwtTokenUtil(JwtConfig config) {
    this.jwtConfig = config;
  }

  public String getHeader() {
    return jwtConfig.getHeader();
  }

  public JwtToken generateToken(Map<String, Object> params) throws JoseException, MalformedClaimException {
    JwtClaims claims = new JwtClaims();
    claims.setIssuer(params.get("issuer").toString());  // who creates the token and signs it
    if (params.get("audience") != null) {
      claims.setAudience(params.get("audience").toString());
    }
    claims.setExpirationTimeMinutesInTheFuture(jwtConfig.getExpiration() / 60); // time when the token will expire (10 minutes from now)
    claims.setGeneratedJwtId();
    claims.setIssuedAtToNow();  // when the token was issued/created (now)
    claims.setNotBeforeMinutesInThePast(2); // time before which the token is not yet valid (2 minutes ago)
    if (params.get("subject") != null) {
      claims.setSubject(params.get("subject").toString()); // the subject/principal is whom the token is about
    }
    if (params.get("authorities") != null) {
      claims.setClaim("authorities", params.get("authorities"));
    }
    if(params.get("uid") != null) {
      claims.setClaim("uid", params.get("uid"));
    }
    /*
    claims.setClaim("email", "mail@example.com"); // additional claims/attributes about the subject can be added
    List<String> groups = Arrays.asList("group-one", "other-group", "group-three");
    claims.setStringListClaim("groups", groups); // multi-valued claims work too and will end up as a JSON array
     */

    Map<String, Object> keySpec = new HashMap<>();
    keySpec.put("kty", "oct");
    keySpec.put("k", jwtConfig.getSecret());
    JsonWebKey key = JsonWebKey.Factory.newJwk(keySpec);
    JsonWebSignature jws = new JsonWebSignature();
    jws.setPayload(claims.toJson());
    jws.setKey(key.getKey());
    jws.setKeyIdHeaderValue(key.getKeyId());
    jws.setAlgorithmHeaderValue(AlgorithmIdentifiers.HMAC_SHA256);
    return new JwtToken(claims.getJwtId(), jws.getCompactSerialization(), claims.getExpirationTime());
  }

  public JwtToken generateToken(UserDetails userDetails) throws JoseException, MalformedClaimException {
    Map<String, Object> claims = new HashMap<>();
    claims.put("uid", userDetails.getUsername());
    return generateToken(claims);
  }

  public Map<String, Object> verifyToken(String token) throws JoseException, InvalidJwtException {
    Map<String, Object> keySpec = new HashMap<>();
    keySpec.put("kty", "oct");
    keySpec.put("k", jwtConfig.getSecret());
    JsonWebKey key = JsonWebKey.Factory.newJwk(keySpec);
    JwtConsumer jwtConsumer = new JwtConsumerBuilder()
        .setRequireExpirationTime() // the JWT must have an expiration time
        .setAllowedClockSkewInSeconds(30) // allow some leeway in validating time based claims to account for clock skew
        .setVerificationKey(key.getKey()) // verify the signature with the public key
        .setSkipDefaultAudienceValidation()
        .setJwsAlgorithmConstraints( // only allow the expected signature algorithm(s) in the given context
            new AlgorithmConstraints(org.jose4j.jwa.AlgorithmConstraints.ConstraintType.WHITELIST, // which is only RS256 here
                AlgorithmIdentifiers.HMAC_SHA256))
        .build(); // create the JwtConsumer instance

    //  Validate the JWT and process it to the Claims
    JwtClaims jwtClaims = jwtConsumer.processToClaims(token);
    return jwtClaims.getClaimsMap();
  }
}
