package com.supwisdom.dlpay.framework.filter;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StringUtils;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;

/**
 * Request wrapper intended to guard against SQL injection and XSS attacks.
 *
 * <p>{@link #getParameter}, {@link #getParameterValues}, {@link #getParameterMap} (names and
 * values) and {@link #getHeader} return values passed through {@link #cleanXSS(String)}. All
 * other accessors, e.g. {@code getParameterNames}, {@code getHeaders}, {@code getQueryString},
 * {@code getInputStream} and {@code getReader}, are inherited from
 * {@link HttpServletRequestWrapper} and return the unfiltered data of the wrapped request.
 *
 * <p>The sanitization is lossy and also alters legitimate input (for example, every ASCII space
 * is removed). It is not a substitute for parameterized SQL and context-aware output encoding.
 */
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {
  private static final Logger log = LoggerFactory.getLogger(XssHttpServletRequestWrapper.class);

  /** Pipe-separated SQL keywords and punctuation checked by {@link #cleanSqlKeyWords}. */
  private static final String SQL_KEYWORD_LIST = "'|and|exec|execute|insert|select|delete|update|count|drop|%|chr|mid|master|truncate|" +
      "char|declare|sitename|net user|xp_cmdshell|;|or|-|+|,|like'|and|exec|execute|insert|create|drop|" +
      "table|from|grant|use|group_concat|column_name|" +
      "information_schema.columns|table_schema|union|where|select|delete|update|order|by|count|" +
      "chr|mid|master|truncate|char|declare|or|;|-|--|,|like|//|/|%|#";

  /** De-duplicated entries of {@link #SQL_KEYWORD_LIST}; read-only after class initialization. */
  private static final Set<String> NOT_ALLOWED_KEYWORDS;

  /** Text substituted for a blocked keyword. */
  private static final String REPLACED_STRING = "INVALID";

  /**
   * A double or single quote, optional whitespace, {@code javascript:}, then everything up to the
   * last quote on the same line. Compiled once rather than on every sanitized value.
   */
  private static final Pattern QUOTED_JAVASCRIPT_URI =
      Pattern.compile("[\\\"\\\'][\\s]*javascript:(.*)[\\\"\\\']");

  static {
    Set<String> keywords = new HashSet<>();
    for (String keyword : SQL_KEYWORD_LIST.split("\\|")) {
      keywords.add(keyword);
    }
    NOT_ALLOWED_KEYWORDS = Collections.unmodifiableSet(keywords);
  }

  /** Request URI of the wrapped request, captured for log messages. */
  private final String currentUrl;

  /**
   * Constructs a request object wrapping the given request.
   *
   * @param request the {@link HttpServletRequest} to be wrapped.
   * @throws IllegalArgumentException if the request is null
   */
  public XssHttpServletRequestWrapper(HttpServletRequest request) {
    super(request);
    currentUrl = request.getRequestURI();
  }

  /**
   * Returns the sanitized value of a request parameter. Only the value is filtered; the name is
   * passed to the wrapped request unchanged. Use {@code super.getParameter(name)} for the raw
   * value.
   *
   * @param parameter the parameter name
   * @return the sanitized value, or {@code null} if the wrapped request has no such parameter
   */
  @Override
  public String getParameter(String parameter) {
    String value = super.getParameter(parameter);
    return value == null ? null : cleanXSS(value);
  }

  /**
   * Returns all values of a request parameter, each sanitized. Use
   * {@code super.getParameterValues(name)} for the raw values.
   *
   * @param parameter the parameter name
   * @return a new array of sanitized values, or {@code null} if the parameter does not exist
   */
  @Override
  public String[] getParameterValues(String parameter) {
    return cleanAll(super.getParameterValues(parameter));
  }

  /**
   * Returns the request parameters with both names and values sanitized.
   *
   * <p>The result is unmodifiable, as {@code ServletRequest#getParameterMap} specifies, and keeps
   * the wrapped map's iteration order. Distinct raw names that sanitize to the same name have
   * their values concatenated, so that neither set of values is silently dropped.
   *
   * @return an unmodifiable map of sanitized names to sanitized values, or {@code null} if the
   *     wrapped request returned {@code null}
   */
  @Override
  public Map<String, String[]> getParameterMap() {
    Map<String, String[]> raw = super.getParameterMap();
    if (raw == null) {
      return null;
    }
    Map<String, String[]> result = new LinkedHashMap<>();
    for (Map.Entry<String, String[]> entry : raw.entrySet()) {
      String cleanedName = cleanXSS(entry.getKey());
      String[] cleanedValues = cleanAll(entry.getValue());
      result.put(cleanedName, merge(result.get(cleanedName), cleanedValues));
    }
    return Collections.unmodifiableMap(result);
  }

  /**
   * Returns the sanitized value of a request header. Only the value is filtered. Use
   * {@code super.getHeader(name)} for the raw value. Note that {@code getHeaders} is not
   * overridden and returns raw values.
   *
   * <p>NOTE(review): because every space is removed, a value such as {@code "Bearer abc"} is
   * returned as {@code "Bearerabc"}, and parentheses become {@code &#40;}/{@code &#41;}. Verify
   * that callers which parse structured headers do not read them through this method.
   *
   * @param name the header name
   * @return the sanitized value, or {@code null} if the header is absent
   */
  @Override
  public String getHeader(String name) {
    String value = super.getHeader(name);
    return value == null ? null : cleanXSS(value);
  }

  /**
   * Sanitizes each element of {@code values} into a new array.
   *
   * @param values raw values; may be {@code null}
   * @return a new array of sanitized values ({@code null} elements are kept as {@code null}), or
   *     {@code null} if {@code values} is {@code null}
   */
  private String[] cleanAll(String[] values) {
    if (values == null) {
      return null;
    }
    String[] cleaned = new String[values.length];
    for (int i = 0; i < values.length; i++) {
      cleaned[i] = values[i] == null ? null : cleanXSS(values[i]);
    }
    return cleaned;
  }

  /** Concatenates two value arrays, treating {@code null} as "no values". */
  private static String[] merge(String[] first, String[] second) {
    if (first == null) {
      return second;
    }
    if (second == null) {
      return first;
    }
    String[] merged = Arrays.copyOf(first, first.length + second.length);
    System.arraycopy(second, 0, merged, first.length, second.length);
    return merged;
  }

  /**
   * Applies the sanitization pipeline to one non-null value. The steps run in this order, and
   * the order matters:
   *
   * <ol>
   *   <li>{@code < > ( ) '} become {@code &lt; &gt; &#40; &#41; &#39;} ({@code &} itself is not
   *       escaped);
   *   <li>each quoted {@code javascript:} URI is replaced by {@code ""};
   *   <li>every literal, lower-case {@code script} substring is deleted;
   *   <li>every ASCII space is deleted;
   *   <li>the result goes through {@link #cleanSqlKeyWords(String)}.
   * </ol>
   *
   * @param valueP the raw value; must not be {@code null}
   * @return the sanitized value
   */
  private String cleanXSS(String valueP) {
    String value = valueP.replace("<", "&lt;").replace(">", "&gt;");
    // Parentheses are escaped here, so a separate "eval(...)" strip would never match.
    value = value.replace("(", "&#40;").replace(")", "&#41;");
    value = value.replace("'", "&#39;");
    // Every ' is already escaped, so only '"' can open or close a match here.
    value = QUOTED_JAVASCRIPT_URI.matcher(value).replaceAll("\"\"");
    // Case-sensitive and single-pass: "SCRIPT" survives, "scrscriptipt" becomes "script", and
    // ordinary words lose the substring ("description" -> "deion").
    value = value.replace("script", "");
    // Removes every ASCII space, including legitimate ones. Kept for backward-compatible output.
    // NOTE(review): this runs before cleanSqlKeyWords, whose checks require a space, so that
    // filter can never fire.
    value = value.replace(" ", "");
    return cleanSqlKeyWords(value);
  }

  /**
   * Replaces every occurrence of a blocked SQL keyword with {@code "INVALID"}. This happens only
   * when the value is more than four characters longer than the keyword and the keyword is
   * preceded or followed by an ASCII space somewhere in the value. Each replacement is logged at
   * ERROR level.
   *
   * <p>NOTE(review): {@link #cleanXSS} removes all spaces before calling this method, so the
   * condition below is never met and this method currently returns its input unchanged. Do not
   * rely on it for SQL-injection protection; use parameterized queries. If it ever does fire, it
   * replaces the keyword everywhere in the value, including inside other words (e.g. "or" inside
   * "order").
   *
   * @param value the value to check
   * @return the value with blocked keywords replaced, or {@code value} itself if none matched
   */
  private String cleanSqlKeyWords(String value) {
    String paramValue = value;
    for (String keyword : NOT_ALLOWED_KEYWORDS) {
      // A value containing " " + keyword + " " also contains " " + keyword, so two checks suffice.
      if (paramValue.length() > keyword.length() + 4
          && (paramValue.contains(" " + keyword) || paramValue.contains(keyword + " "))) {
        paramValue = StringUtils.replace(paramValue, keyword, REPLACED_STRING);
        log.error("{}已被过滤，因为参数中包含不允许sql的关键词({});参数：{};过滤后的参数：{}",
            currentUrl, keyword, value, paramValue);
      }
    }
    return paramValue;
  }
}
