简述一下OJ项目中手写的Token验证服务设计过程
PART A 设计校验的哈希算法
这里直接展示整个项目中用到的算法库,其中涉及位运算的可不管
直接应用到的方法是hash(str)
大概流程如下
1.构造一个大素数表并随机打乱
2.提供足够快的快速幂
3.哈希规则:对每个下标 i 逐字节变换(不是求和),out[i] = byte[i] ^^ bitCount(randomPrimes[i % 素数表长度]) % 127(代码中模数为0x7f)
为了让hash过程更快,实现中把取到的素数进一步转为其bitcount再作为指数,这样算幂会把log的复杂度略降一点
package com.noresp.oj.utils;
/**
 * Small algorithm toolbox used by the OJ project.
 * Currently provides:
 * <ul>
 * <li>a table of large primes in pseudo-random (but fixed) order</li>
 * <li>a deterministic array shuffle</li>
 * <li>a special-purpose string hash, see {@link #hash(String)}</li>
 * <li>int-array swapping, bit counting and modular fast exponentiation</li>
 * <li>a tiny pseudo-random number generator</li>
 * </ul>
 */
public class AlgsUtils {
    /** Lookup table: {@code bitmasks[v]} is the number of set bits of the 8-bit value {@code v}. */
    private static final int[] bitmasks = new int[0x100];
    /** Prime table, shuffled once with {@link #magicNumber}; its length is a power of two. */
    private static final int[] randomPrimes = new int[1<<10];
    /** Fixed seed used to shuffle the prime table, so the hash is stable across runs. */
    public static final long magicNumber = 19260817L;

    /**
     * Minimal linear congruential generator working on an emulated unsigned 32-bit state.
     * The original author reports it as roughly 20x faster than {@code Math.random}
     * for about 2^26 draws. Not suitable for anything security related.
     */
    public static class SimpleRandom {
        long seed = 1L;

        /**
         * Replaces the generator state.
         * @param seed new state
         */
        public void setSeed(long seed) {
            this.seed = seed;
        }

        /**
         * Advances the generator.
         * @return the upper 16 bits of the new 32-bit state, i.e. a value in {@code [0, 65535]}
         */
        public long next() {
            // The mask must be a long literal: 0xffffffff without the L is the int -1.
            seed = (seed * 1103515245 + 12345) & 0xffffffffL;
            return seed >> 16;
        }

        /**
         * @param mod exclusive upper bound
         * @return {@code next() % mod}; never exceeds 65535 because of the range of {@link #next()}
         */
        public int next(int mod) {
            return (int) (next() % mod);
        }
    }

    static {
        initializeBitmasks();
        initializePrimeTable();
        randomShuffle(randomPrimes, magicNumber);
    }

    /**
     * Fills the 8-bit pop-count table in O(n): the bit count of {@code i} is the bit count of
     * {@code i >> 1} plus the lowest bit of {@code i}.
     */
    private static void initializeBitmasks() {
        for (int i = 1; i < bitmasks.length; i++) {
            bitmasks[i] = bitmasks[i >> 1] + (i & 1);
        }
    }

    /**
     * Counts the set bits of an int, one byte at a time through the lookup table.
     * @param value any int; negative values are treated as their 32-bit two's-complement pattern
     * @return number of 1 bits in {@code value}
     */
    public static int bitCount(int value) {
        int result = 0;
        // "!= 0" instead of "> 0": with the unsigned shift this also terminates for negative
        // inputs, which the previous condition silently reported as 0 bits.
        for (; value != 0; value >>>= 8) {
            result += bitmasks[value & 0xff];
        }
        return result;
    }

    /**
     * Hashes a string byte by byte with the fixed shuffled prime table:
     * {@code out[i] = in[i] ^ bitCount(randomPrimes[i % tableLength]) mod 127}.
     * Every output byte therefore lies in {@code [0, 126]}.
     * The string is converted to and from bytes as UTF-8 so the result does not depend on the
     * platform default charset; for pure ASCII input this is identical to the old behaviour.
     * @param string text to hash
     * @return the hashed text, or {@code null} when {@code string} is {@code null} or empty
     */
    public static String hash(String string) {
        if (string == null || string.length() == 0) return null;
        byte[] b = string.getBytes(java.nio.charset.StandardCharsets.UTF_8);
        for (int i = 0; i < b.length; i++) {
            // The table length is a power of two, so "& (length - 1)" replaces "% length".
            int prime = randomPrimes[i & (randomPrimes.length - 1)];
            // Using the bit count of the prime keeps the exponent (and the fastPow loop) tiny.
            int exponent = bitCount(prime);
            b[i] = (byte) fastPow(b[i], exponent, 0x7f);
        }
        // Do not use b.toString(): that would print the array identity, not its content.
        return new String(b, java.nio.charset.StandardCharsets.UTF_8);
    }

    /**
     * Modular exponentiation by repeated squaring.
     * @param a base; negative bases are reduced to their non-negative residue first
     * @param n exponent; values {@code <= 0} are treated as 0
     * @param mod modulus, expected to be positive ({@code mod == 0} raises ArithmeticException)
     * @return {@code a^n mod mod} in {@code [0, mod)}
     */
    public static int fastPow(long a, long n, int mod) {
        long res = 1 % mod; // long avoids overflow of the products; "% mod" covers mod == 1
        a %= mod;
        if (a < 0) a += mod; // previously a negative base leaked a negative, unreduced result
        while (n > 0) {
            if ((n & 1) == 1) {
                res = res * a % mod;
            }
            a = a * a % mod;
            n >>= 1;
        }
        return (int) res;
    }

    /**
     * Fills {@link #randomPrimes} using a simple sieve over {@code [0, tableLength << 8)}.
     * Slots 0 and 1 hold two well-known large primes; the rest are the largest primes below
     * the sieve bound, in descending order.
     */
    private static void initializePrimeTable() {
        int n = randomPrimes.length << 8; // rough sieve bound, primes end up around 1e5
        boolean[] notPrime = new boolean[n];
        for (int i = 2; i * i < n; i++) {
            if (!notPrime[i]) {
                // Start at i*i: starting at i (as before) flagged the prime i itself as composite.
                for (int j = i * i; j < n; j += i) {
                    notPrime[j] = true;
                }
            }
        }
        // Large primes are preferred, hence the two huge constants and the descending scan.
        randomPrimes[0] = (int) 1e9 + 7;
        randomPrimes[1] = 998244353;
        // "i >= 2" keeps the scan inside the sieve (0 and 1 are not primes).
        for (int i = n - 1, j = 2; j < randomPrimes.length && i >= 2; --i) {
            if (!notPrime[i]) randomPrimes[j++] = i;
        }
    }

    /**
     * Shuffles an int array in place, deterministically for a given seed.
     * @param toRandom array to shuffle
     * @param seed seed for the internal {@link SimpleRandom}
     */
    public static void randomShuffle(int[] toRandom, long seed) {
        SimpleRandom roll = new AlgsUtils.SimpleRandom();
        roll.setSeed(seed);
        for (int i = toRandom.length - 1; i > 0; --i) {
            swap(toRandom, i, roll.next(i + 1));
        }
    }

    /**
     * Swaps two elements of an array; does nothing when either index is rejected by
     * {@code SafeUtils.isOutOfBound}.
     */
    public static void swap(int[] arr, int i, int j) {
        if (SafeUtils.isOutOfBound(arr, i)) return;
        if (SafeUtils.isOutOfBound(arr, j)) return;
        int t = arr[i];
        arr[i] = arr[j];
        arr[j] = t;
    }

    /**
     * Reverses the elements between {@code lo} and {@code hi} (both inclusive) by swapping
     * from both ends towards the middle.
     */
    public static void swapRange(int[] arr, int lo, int hi) {
        while (lo < hi) swap(arr, lo++, hi--);
    }
}
PART B 进一步的哈希
由PART A可以看到任意编码的字符串都把char限制在0-127范围内,但可能存在特殊的转义符影响面向文本的协议
因此需要把0-127映射到ASCII中a-z A-Z 0-9的范围内
为了满足尽可能的均匀分布,又乱写了一个算法(其实ch+i是多余的)
/**
 * Rewrites a string so that every character {@code isVisualChar} rejects is replaced by one of
 * the 62 characters a-z, A-Z, 0-9. The replacement depends on the character and its position.
 * @param str text to convert, typically the output of {@code AlgsUtils.hash}; may be null
 * @return the converted text; an empty string when {@code str} is null
 */
public static String visualizableHash(String str) {
    // AlgsUtils.hash returns null for empty input; treat that as "nothing to render".
    if (str == null) return "";
    StringBuilder sb = new StringBuilder(str.length());
    for (int i = 0; i < str.length(); i++) {
        char ch = str.charAt(i);
        if (isVisualChar(ch)) {
            sb.append(ch);
            continue;
        }
        // Widen before multiplying: the old int arithmetic could overflow for very long
        // inputs and produce a negative position (and thus a non-alphanumeric character).
        long factor = (long) ch * 17 + (long) i * 23;
        int pos = (int) (factor % (26 + 26 + 10));
        char curChar;
        if (pos < 26) curChar = (char) ('a' + pos);
        else if (pos - 26 < 26) curChar = (char) ('A' + pos - 26);
        else curChar = (char) ('0' + pos - 26 - 26);
        sb.append(curChar);
    }
    return sb.toString();
}
这样调用visualizableHash(hash(str))就能获得一个还可以的文本哈希了
PART C util方法封装【废弃】
其中payload就是我要负载的内容
Sign作为签名校验
目前是使用简单的String,也提供了简单的Map转换 格式见doc说明
package com.noresp.oj.utils;
import java.util.*;
/**
 * Token helper that replaces server-side sessions.
 * Token layout: {@code encode(key1).encode(val1).encode(key2).encode(val2)....sign}
 * where {@code encode} is currently Base64 and {@code sign} is the concatenation of the
 * per-payload hashes.
 *
 * UPD: superseded by a Service implementation with optional caching.
 *
 * NOTE(review): the sign is computed without any secret, so anyone who knows the algorithm
 * can produce a token that validates - confirm whether that is acceptable.
 */
@Deprecated // this is the original implementation
public class TokenUtils {
    /** Base64-encodes a payload using UTF-8 bytes; empty input maps to the empty string. */
    private static String encode(String str) {
        if (SafeUtils.isEmpty(str)) return "";
        return Base64.getEncoder().encodeToString(
                str.getBytes(java.nio.charset.StandardCharsets.UTF_8));
    }

    /**
     * Reverses {@link #encode(String)}.
     * @throws IllegalArgumentException when {@code str} is not valid Base64
     */
    private static String decode(String str) {
        if (SafeUtils.isEmpty(str)) return "";
        byte[] raw = Base64.getDecoder().decode(
                str.getBytes(java.nio.charset.StandardCharsets.UTF_8));
        return new String(raw, java.nio.charset.StandardCharsets.UTF_8);
    }

    private static String getTokenPayload(String key) {
        return encode(key);
    }

    /** Splits on every '.', keeping empty segments (unlike {@code String.split}). */
    private static String[] splitOnDots(String text) {
        int count = 1;
        for (int i = 0; i < text.length(); i++) {
            if (text.charAt(i) == '.') count++;
        }
        String[] parts = new String[count];
        int start = 0;
        for (int k = 0; k < count - 1; k++) {
            int dot = text.indexOf('.', start);
            parts[k] = text.substring(start, dot);
            start = dot + 1;
        }
        parts[count - 1] = text.substring(start);
        return parts;
    }

    /**
     * Builds the sign for already-encoded payloads.
     * @param base64Payloads encoded payloads, in token order
     * @return concatenated per-payload hashes; empty payloads contribute nothing
     */
    public static String getTokenSign(String... base64Payloads) {
        if (SafeUtils.isEmpty(base64Payloads)) return "";
        StringBuilder sb = new StringBuilder();
        for (String payload : base64Payloads) {
            // AlgsUtils.hash returns null for null/empty input; passing that on used to
            // end in a NullPointerException inside visualizableHash.
            String digest = AlgsUtils.hash(payload);
            if (digest != null) {
                sb.append(StringUtils.visualizableHash(digest));
            }
        }
        return sb.toString();
    }

    /**
     * Builds a complete token.
     * @param payloads alternating keys and values
     * @return the token, or the empty string when there are no payloads
     */
    public static String getToken(String... payloads) {
        if (SafeUtils.isEmpty(payloads)) return "";
        StringBuilder sb = new StringBuilder();
        String[] encodedPayloads = new String[payloads.length];
        for (int i = 0; i < payloads.length; i++) {
            encodedPayloads[i] = getTokenPayload(payloads[i]);
            sb.append(encodedPayloads[i]).append('.');
        }
        sb.append(getTokenSign(encodedPayloads));
        return sb.toString();
    }

    /**
     * Validates a token and decodes its payloads.
     * @param token token text, usually taken from an untrusted cookie
     * @return the decoded payloads, or {@code null} when the token is null, empty, carries a
     *         wrong sign or contains a payload that is not valid Base64
     */
    public static String[] decodeTokenAndValidate(String token) {
        if (token == null || token.isEmpty()) return null;
        // Everything after the last '.' is the sign, everything before it the payload list.
        int lastDot = token.lastIndexOf('.');
        String sign = token.substring(lastDot + 1);
        String[] encodedPayloads =
                lastDot < 0 ? new String[0] : splitOnDots(token.substring(0, lastDot));
        if (!sign.equals(getTokenSign(encodedPayloads))) {
            return null;
        }
        String[] decodedPayloads = new String[encodedPayloads.length];
        try {
            for (int i = 0; i < encodedPayloads.length; i++) {
                decodedPayloads[i] = decode(encodedPayloads[i]);
            }
        } catch (IllegalArgumentException malformedBase64) {
            // A correctly signed but undecodable token is still an invalid token.
            return null;
        }
        return decodedPayloads;
    }

    /**
     * Turns decoded payloads into a key/value map.
     * @param decodedToken alternating keys and values, may be null
     * @return the map; empty for null input; a trailing key without a value is ignored
     */
    public static Map<String,String> tokenMap(String[] decodedToken) {
        Map<String,String> result = new HashMap<>();
        if (decodedToken == null) return result;
        // "i + 1 <" guards against an odd number of payloads.
        for (int i = 0; i + 1 < decodedToken.length; i += 2) {
            result.put(decodedToken[i], decodedToken[i + 1]);
        }
        return result;
    }

    /**
     * Validates the token and looks up a single attribute.
     * @return the value stored under {@code key}, or null when absent or the token is invalid
     */
    public static String getTokenAttribute(String token, String key) {
        return tokenMap(decodeTokenAndValidate(token)).get(key);
    }
}
PART C2 原始的缓存token实现
/**
 * Adds caching to {@link TokenService} by subclassing it (a decorator without composition;
 * the author notes that a proxy + template approach would remove the repeated steps).
 * If TokenService changes its encoder, the cache has to be refreshed; that would need a
 * separate flush() implementation.
 * According to the author's tests the cache pays off beyond roughly 1e4 operations, and the
 * in-JVM cache was much faster than Redis.
 * [!] Do not use ThreadLocal for thread safety here: Tomcat reuses threads, so the per-thread
 *     copies would pile up in memory.
 *
 * [!] Deprecated, the new implementation lives in the service/cache package.
 */
@Deprecated
@Service("CachedTokenService")
public class CachedTokenService extends TokenService {
    ////////////////////////////////////////////////////////////
    private final int CAPACITY = 10000;
    /**
     * Access-ordered LinkedHashMap that evicts the eldest entry once CAPACITY is exceeded,
     * wrapped in a synchronized view (coarse locking for simplicity).
     * TODO: optimise
     */
    private Map<String,Object> cache = Collections.synchronizedMap(
        new LinkedHashMap<String,Object>(CAPACITY, 0.75f, true) {
            @Override
            protected boolean removeEldestEntry(Map.Entry<String,Object> eldest) {
                return size() > CAPACITY;
            }
        }
    );

    private Object get(String key) {
        return cache.get(key);
    }

    private void put(String key, Object value) {
        cache.put(key, value);
    }
    /////////////////////////////////////////////////////////////
    private final String PAYLOAD_PREFIX = "payload::";
    private final String PAYLOADS_PREFIX = "payloads::";
    private final String ENCODED_PAYLOAD_PREFIX = "encodedPayload::";
    private final String ENCODED_PAYLOADS_PREFIX = "encodedPayloads::";
    private final String TOKEN_PREFIX = "token::";
    /**
     * Unused: eviction is LRU based, no TTL is applied.
     */
    private final long EXPIRE = 60*60*5;

    private String preStringhandler(String string) {
        return string == null ? "" : string;
    }

    /**
     * Builds an unambiguous cache-key fragment for a list of strings.
     * Each element is written as {@code length:value}; plain concatenation (the previous
     * behaviour) mapped {@code ("ab","c")} and {@code ("a","bc")} to the same key, so a
     * cached result could be returned for a different argument list.
     */
    private String payloadsHandler(String... encodedPayloads) {
        StringBuilder sb = new StringBuilder();
        for (String part : encodedPayloads) {
            if (part == null) {
                sb.append("-1:");
            } else {
                sb.append(part.length()).append(':').append(part);
            }
        }
        return sb.toString();
    }

    /**
     * Cached variant of payload creation.
     * @param keyOrValue raw key or value
     * @return the encoded payload
     */
    protected String createTokenPayload(String keyOrValue) {
        String key = PAYLOAD_PREFIX + preStringhandler(keyOrValue);
        String result = (String) get(key);
        if (result != null) return result;
        result = super.createTokenPayload(keyOrValue);
        put(key, result);
        return result;
    }

    /**
     * Cached variant of the per-payload hash.
     * @param encodedPayload an already encoded payload
     * @return the hash text for that payload
     */
    protected String createEncodedPayloadSalt(String encodedPayload) {
        String key = ENCODED_PAYLOAD_PREFIX + encodedPayload;
        String result = (String) get(key);
        if (result != null) return result;
        result = super.createEncodedPayloadSalt(encodedPayload);
        put(key, result);
        return result;
    }

    /**
     * Cached variant of sign creation.
     * @param encodedPayloads encoded payloads in token order
     * @return the sign, or the empty string when there are no payloads
     */
    protected String getSign(String... encodedPayloads) {
        if (SafeUtils.isEmpty(encodedPayloads)) return "";
        String key = ENCODED_PAYLOADS_PREFIX + payloadsHandler(encodedPayloads);
        String result = (String) get(key);
        if (result != null) return result;
        result = super.getSign(encodedPayloads);
        put(key, result);
        return result;
    }

    /**
     * The most frequently used entry point: caches the complete token per payload list.
     * @param payloads alternating keys and values
     * @return the token, or the empty string when there are no payloads
     */
    public String getToken(String... payloads) {
        if (SafeUtils.isEmpty(payloads)) return "";
        String key = PAYLOADS_PREFIX + payloadsHandler(payloads);
        String result = (String) get(key);
        if (result != null) return result;
        result = super.getToken(payloads);
        put(key, result);
        return result;
    }

    /**
     * Cached variant of token validation.
     * @param token token text
     * @return decoded payloads, or null when the token is null or invalid
     */
    public String[] decodeTokenAndValidate(String token) {
        if (token == null) return null;
        String key = TOKEN_PREFIX + token;
        String[] cached = (String[]) get(key);
        // Hand out copies so a caller mutating the array cannot corrupt the cached entry.
        if (cached != null) return cached.clone();
        String[] result = super.decodeTokenAndValidate(token);
        // Failed validations are not cached: a null entry is indistinguishable from a miss
        // and would only let garbage tokens evict useful entries.
        if (result != null) put(key, result.clone());
        return result;
    }
}
PART C3 松耦合的缓存token实现
松耦合的设计可以对于未来的更改和问题排查有很大帮助
虽然说写的比前面繁杂多了,但都是可重用的组件设计
首先是对于编码、解码的设计
/**
 * Contract for a string encoder. Encoding is mandatory; decoding is optional and is
 * rejected by default.
 */
public interface Encoderable {

    /**
     * Encodes the given text.
     * @param string text to encode
     * @return the encoded form
     */
    String encode(String string);

    /**
     * Decodes text produced by {@link #encode(String)}.
     * @param string encoded text
     * @return the decoded form
     * @throws UnsupportedOperationException unless the implementation overrides this method
     */
    default String decode(String string) {
        throw new UnsupportedOperationException();
    }
}
/**
 * Skeleton for encoders.
 * Subclasses only supply the {@code encodeImpl} hook (and optionally {@code decodeImpl});
 * null input is turned into the empty string before a hook is called.
 */
public abstract class AbstractEncoder {

    /**
     * Normalises the input handed to the hooks.
     * @param string raw input, may be null
     * @return {@code string}, or the empty string when it is null
     */
    public String preStringHandler(String string) {
        return string == null ? "" : string;
    }

    /** Hook performing the actual encoding; never receives null. */
    protected abstract String encodeImpl(String string);

    /**
     * Hook performing the actual decoding; never receives null.
     * @throws UnsupportedOperationException unless overridden
     */
    protected String decodeImpl(String string) {
        throw new UnsupportedOperationException();
    }

    /** Normalises the input and delegates to {@link #encodeImpl(String)}. */
    public String encode(String string) {
        return encodeImpl(preStringHandler(string));
    }

    /** Normalises the input and delegates to {@link #decodeImpl(String)}. */
    public String decode(String string) {
        return decodeImpl(preStringHandler(string));
    }
}
/**
 * Base64 implementation of the encoder hooks.
 * Text is converted to and from bytes as UTF-8, so tokens do not depend on the platform
 * default charset of the JVM that created or reads them.
 */
@Component
public class Base64Encoder extends AbstractEncoder implements Encoderable {
    /**
     * @param string non-null text
     * @return the Base64 form of the UTF-8 bytes of {@code string}
     */
    @Override
    protected String encodeImpl(String string) {
        return Base64.getEncoder().encodeToString(
                string.getBytes(java.nio.charset.StandardCharsets.UTF_8));
    }

    /**
     * @param string non-null Base64 text
     * @return the decoded text
     * @throws IllegalArgumentException when {@code string} is not valid Base64
     */
    @Override
    protected String decodeImpl(String string) {
        byte[] raw = Base64.getDecoder().decode(
                string.getBytes(java.nio.charset.StandardCharsets.UTF_8));
        return new String(raw, java.nio.charset.StandardCharsets.UTF_8);
    }
}
其次是本地缓存
/**
 * A cache living inside the JVM.
 */
public interface LocalCache {

    /** Stores {@code obj} under {@code key}. */
    void put(String key, Object obj);

    /** Returns the object stored under {@code key}. */
    Object get(String key);

    /**
     * Explicit removal is optional; implementations are expected to maintain themselves.
     * The default implementation does nothing.
     * @param key key to remove
     */
    default void remove(String key) {
        // intentionally empty
    }
}
```java
/**
 * CGLIB method interceptor that memoises method results in a {@link LocalCache}.
 * NOTE(review): results are cached per argument list only, so methods that work through
 * side effects on their arguments (for example filling a caller-supplied map) will not
 * repeat those side effects on a cache hit.
 */
@Component
public class CacheManager implements MethodInterceptor {
    /**
     * Loosely coupled cache backend.
     */
    @Autowired
    private LocalCache cache;

    /**
     * Creates a caching subclass proxy of the given class.
     * @param clazz class to subclass
     * @return a new proxy instance whose method calls go through {@link #intercept}
     */
    public Object getCacheManagerProxy(Class<?> clazz) {
        // A fresh Enhancer per call: the previously shared instance was mutated by every
        // call, which is neither thread-safe nor safe across different target classes.
        Enhancer enhancer = new Enhancer();
        enhancer.setSuperclass(clazz);
        enhancer.setCallback(this);
        return enhancer.create();
    }

    /**
     * Builds the cache key from the full method identity and the argument values.
     */
    private String keyGenerator(Method method, Object[] args) {
        StringBuilder sb = new StringBuilder(method.getDeclaringClass().getName())
                .append('#').append(method.getName());
        // Parameter types keep overloads apart.
        for (Class<?> type : method.getParameterTypes()) {
            sb.append(',').append(type.getName());
        }
        sb.append("::");
        appendArgument(sb, args);
        return sb.toString();
    }

    /**
     * Appends one argument in an unambiguous form: arrays are expanded element by element
     * (their toString() is identity based, which made varargs calls miss or collide),
     * nulls get their own marker and every value is length-prefixed.
     */
    private void appendArgument(StringBuilder sb, Object arg) {
        if (arg == null) {
            sb.append("N;");
            return;
        }
        if (arg.getClass().isArray()) {
            int length = java.lang.reflect.Array.getLength(arg);
            sb.append('[').append(length).append(':');
            for (int i = 0; i < length; i++) {
                appendArgument(sb, java.lang.reflect.Array.get(arg, i));
            }
            sb.append(']');
            return;
        }
        String text = String.valueOf(arg);
        sb.append(text.length()).append(':').append(text).append(';');
    }

    /**
     * Returns the cached result for the call when present, otherwise invokes the real
     * method and caches a non-null result.
     */
    @Override
    public Object intercept(Object o, Method method, Object[] args, MethodProxy methodProxy) throws Throwable {
        // Object's own methods and void methods have nothing worth memoising.
        if (method.getDeclaringClass() == Object.class || method.getReturnType() == void.class) {
            return methodProxy.invokeSuper(o, args);
        }
        String key = keyGenerator(method, args);
        Object val = cache.get(key);
        if (val != null) return val;
        val = methodProxy.invokeSuper(o, args);
        // null cannot be told apart from a miss and some cache backends reject it.
        if (val != null) cache.put(key, val);
        return val;
    }
}
/**
 * {@link LocalCache} backed by an access-ordered LinkedHashMap that drops its eldest entry
 * once more than CAPACITY entries are stored. Every operation goes through a synchronized
 * wrapper.
 */
@Component
@Primary
public class LruCache implements LocalCache {
    private final int CAPACITY = 10000;
    private Map<String,Object> cacheContainer;

    /**
     * Builds the bounded map first and decorates it afterwards, so a non-synchronized
     * variant can be derived later by dropping the wrapper.
     */
    LruCache() {
        LinkedHashMap<String,Object> bounded =
            new LinkedHashMap<String,Object>(CAPACITY, 0.75f, true) {
                @Override
                protected boolean removeEldestEntry(Map.Entry<String,Object> eldest) {
                    return size() > CAPACITY;
                }
            };
        cacheContainer = Collections.synchronizedMap(bounded);
    }

    /** Looks up the value stored under {@code key}. */
    @Override
    public Object get(String key) {
        return cacheContainer.get(key);
    }

    /** Stores {@code value} under {@code key}. */
    @Override
    public void put(String key, Object value) {
        cacheContainer.put(key, value);
    }
}
最后才是服务本体
/**
 * Creates and validates tokens of the form
 * {@code encode(p1).encode(p2)....sign}, where the sign is the concatenation of the
 * per-payload hashes.
 * NOTE(review): the sign is computed without any secret, so anyone who knows the algorithm
 * can produce a token that validates - confirm whether that is acceptable.
 */
@Service
public class TokenService {
    /**
     * The injected encoder must support decode as well.
     */
    @Autowired
    Encoderable encoder;

    private String encode(String str) {
        return encoder.encode(str);
    }

    private String decode(String str) {
        return encoder.decode(str);
    }

    /** Splits on every '.', keeping empty segments (unlike {@code String.split}). */
    private static String[] splitOnDots(String text) {
        int count = 1;
        for (int i = 0; i < text.length(); i++) {
            if (text.charAt(i) == '.') count++;
        }
        String[] parts = new String[count];
        int start = 0;
        for (int k = 0; k < count - 1; k++) {
            int dot = text.indexOf('.', start);
            parts[k] = text.substring(start, dot);
            start = dot + 1;
        }
        parts[count - 1] = text.substring(start);
        return parts;
    }

    /**
     * @param keyOrValue raw key or value
     * @return its encoded form as it appears inside the token
     */
    protected String createTokenPayload(String keyOrValue) {
        return encode(keyOrValue);
    }

    /**
     * Produces the one-way hash text for a payload that was encoded reversibly.
     * @param encodedPayload encoded payload
     * @return the hash text; the empty string for a null or empty payload
     */
    protected String createEncodedPayloadSalt(String encodedPayload) {
        // AlgsUtils.hash returns null for null/empty input; passing that on used to end
        // in a NullPointerException inside visualizableHash.
        String digest = AlgsUtils.hash(encodedPayload);
        return digest == null ? "" : StringUtils.visualizableHash(digest);
    }

    /**
     * Builds the sign part of a token.
     * @param encodedPayloads encoded payloads in token order
     * @return the sign; the empty string when there is nothing to sign
     */
    protected String getSign(String... encodedPayloads) {
        if (SafeUtils.isEmpty(encodedPayloads)) return "";
        StringBuilder sb = new StringBuilder();
        for (String payload : encodedPayloads) {
            sb.append(createEncodedPayloadSalt(payload));
        }
        return sb.toString();
    }

    /**
     * Builds the whole token from raw payloads.
     * @param payloads alternating keys and values
     * @return the token, or the empty string when there are no payloads
     */
    public String getToken(String... payloads) {
        if (SafeUtils.isEmpty(payloads)) return "";
        StringBuilder sb = new StringBuilder();
        String[] encodedPayloads = new String[payloads.length];
        for (int i = 0; i < payloads.length; i++) {
            encodedPayloads[i] = createTokenPayload(payloads[i]);
            sb.append(encodedPayloads[i]).append('.');
        }
        sb.append(getSign(encodedPayloads));
        return sb.toString();
    }

    /**
     * Validates a token and decodes its payloads.
     * @param token token text, usually taken from an untrusted cookie
     * @return the decoded payloads, or {@code null} when the token is null, empty, carries a
     *         wrong sign or contains a payload the encoder rejects as malformed
     */
    public String[] decodeTokenAndValidate(String token) {
        if (token == null || token.isEmpty()) return null;
        // Everything after the last '.' is the sign, everything before it the payload list.
        int lastDot = token.lastIndexOf('.');
        String sign = token.substring(lastDot + 1);
        String[] encodedPayloads =
                lastDot < 0 ? new String[0] : splitOnDots(token.substring(0, lastDot));
        if (!sign.equals(getSign(encodedPayloads))) {
            return null;
        }
        String[] decodedPayloads = new String[encodedPayloads.length];
        try {
            for (int i = 0; i < encodedPayloads.length; i++) {
                decodedPayloads[i] = decode(encodedPayloads[i]);
            }
        } catch (IllegalArgumentException malformedPayload) {
            // A correctly signed but undecodable token is still an invalid token.
            return null;
        }
        return decodedPayloads;
    }

    /**
     * Builds a map from decoded payloads, useful when several attributes of the same token
     * are needed.
     * @param decodedToken alternating keys and values, may be null
     * @return the map; empty for null input; a trailing key without a value is ignored
     */
    public Map<String,String> tokenMap(String[] decodedToken) {
        Map<String,String> result = new HashMap<>();
        if (decodedToken == null) return result;
        // "i + 1 <" guards against an odd number of payloads.
        for (int i = 0; i + 1 < decodedToken.length; i += 2) {
            result.put(decodedToken[i], decodedToken[i + 1]);
        }
        return result;
    }

    /**
     * Validates the token and fetches one attribute in a single call.
     * @param token token text
     * @param key attribute name
     * @return the attribute value, or null when absent or the token is invalid
     */
    public String getTokenAttribute(String token, String key) {
        Map<String,String> tokenMap = tokenMap(decodeTokenAndValidate(token));
        return tokenMap.get(key);
    }

    /**
     * Fetches one attribute and additionally hands all attributes back through
     * {@code callbackMap}.
     * [Note] any previous content of {@code callbackMap} is discarded.
     * @param token token text
     * @param key attribute name
     * @param callbackMap receives all attributes of the token; ignored when null
     * @return the attribute value, or null when absent or the token is invalid
     */
    public String getTokenAttributeCached(String token, String key, Map<String,String> callbackMap) {
        Map<String,String> result = tokenMap(decodeTokenAndValidate(token));
        if (callbackMap == null) {
            return result.get(key);
        }
        callbackMap.clear();
        callbackMap.putAll(result);
        return callbackMap.get(key);
    }
}
搭上IOC配置
/**
 * Registers the caching proxy of {@link TokenService} as a bean.
 */
@Configuration
public class CacheConfig {
    @Autowired
    CacheManager cacheManager;

    /**
     * Author's measurement: for about 1e6 tokens the plain tokenService takes ~10s locally and
     * the cached variant ~5s; below ~100 tokens there is no noticeable difference.
     * @param tokenService the plain service whose runtime class is proxied
     * @return a proxy created by the cache manager for that class
     */
    @Bean(name = "cachedTokenService")
    public TokenService cachedTokenService(TokenService tokenService) {
        Class<?> targetType = tokenService.getClass();
        return (TokenService) cacheManager.getCacheManagerProxy(targetType);
    }
}
PART D 应用于WEB
目前用于token的payload有userid和ip,后者是为了进一步提高安全性
并且token是直接放在Cookie里头,方便管理生命周期
写得比较杂乱,先贴部分感受一下吧
/**
 * Registers a user and, on success, logs the user in by issuing a token cookie.
 * The token carries the user id and the client address; the cookie lives for 7 days and is
 * HttpOnly.
 * @return JSON of the form {@code {"isCreated": ...}}
 */
@PostMapping("/register")
public @ResponseBody String registerPost(
        HttpServletRequest request, HttpServletResponse response,
        @RequestParam(value = "email") String email,
        @RequestParam(value = "username") String username,
        @RequestParam(value = "password") String password) throws IOException {
    Boolean isCreated =
            userService.createUser(username, password, email, userService.getDefaultUserGroup());
    Map<String,Boolean> result = new HashMap<>();
    result.put("isCreated", isCreated);
    if (!isCreated) {
        return JSONUtils.toJSON(result);
    }

    String userId = String.valueOf(userService.getUserByUsername(username).getUserID());
    String clientIp = controllerUtils.getRemoteAddr(request);
    String token = TokenUtils.getToken("userID", userId, "ip", clientIp);

    Cookie cookie = new Cookie("token", token);
    cookie.setMaxAge(60*60*24*7); // one week, in seconds
    cookie.setHttpOnly(true);
    response.addCookie(cookie);
    return JSONUtils.toJSON(result);
}
其中getRemoteAddr的实现为
/**
 * Resolves the client address of a request.
 * NOTE(review): X-Real-IP is a request header; unless a trusted reverse proxy always
 * overwrites it, clients can set it themselves - confirm the deployment.
 * @param request current request
 * @return the X-Real-IP header when present, otherwise the connection's remote address
 */
public String getRemoteAddr(HttpServletRequest request) {
    String realIp = request.getHeader("X-Real-IP");
    return realIp != null ? realIp : request.getRemoteAddr();
}
PART E 更方便的使用
校验过程太繁琐了,当然要用到AOP,这里采用注解的方式来实现
1.先给一个注解标记
/**
 * Marker for handler methods that require a logged-in user.
 * Retained at runtime and applicable to methods only.
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface NeedLogin {
}
2.接着就是AOP
(请无视代码中直接使用的println)
/**
 * Aspect enforcing login on handler methods annotated with {@code @NeedLogin}.
 */
@Component
@Aspect
public class LoginAspect {
    @Autowired
    ControllerUtils controllerUtils;

    /**
     * Parses the token into its attribute map. The token comes from a client cookie, so any
     * runtime failure while parsing is treated as "no attributes" rather than propagated.
     */
    private Map<String,String> readTokenAttributes(String token) {
        try {
            return TokenUtils.tokenMap(TokenUtils.decodeTokenAndValidate(token));
        } catch (RuntimeException malformedToken) {
            return java.util.Collections.<String,String>emptyMap();
        }
    }

    /**
     * [Convention over configuration]
     * A method using @NeedLogin must declare token, request and response as its first
     * three parameters so that this advice can bind them.
     * The check covers the token sign and a comparison of the recorded IP with the current
     * client address. When the check fails, the token cookie is expired on the client and
     * the user is redirected to "/".
     * @param proceedingJoinPoint the intercepted handler call
     * @param token value of the token cookie, may be null
     * @param request current request
     * @param response current response
     * @return the handler's view, or a redirect to "/" when login is required
     * @throws Throwable whatever the intercepted handler throws
     */
    @Around(value = "@annotation(com.noresp.oj.annotations.NeedLogin) && args(token,request,response,..)")
    public ModelAndView loginCheck(
            ProceedingJoinPoint proceedingJoinPoint,
            String token,
            HttpServletRequest request,
            HttpServletResponse response) throws Throwable {
        if (token == null) {
            System.out.println("没有token");
            return ViewUtils.redirect(
                    "/", new ErrorInfo("login required"));
        }
        Map<String,String> tokenMap = readTokenAttributes(token);
        Integer userID = StringUtils.safeStringToInteger(tokenMap.get("userID"));
        String recordedIP = tokenMap.get("ip");
        System.out.println(userID+" "+recordedIP);
        // Compare from the recorded side so a null remote address cannot raise an NPE.
        boolean tokenIllegal = userID == null
                || recordedIP == null
                || !recordedIP.equals(controllerUtils.getRemoteAddr(request));
        if (tokenIllegal) {
            System.out.println("token错误");
            Cookie fakeToken = controllerUtils.getCookie(request, "token");
            if (fakeToken != null) {
                fakeToken.setValue("");
                fakeToken.setMaxAge(0);
                // The expired cookie has to be sent back, otherwise the browser keeps it.
                // NOTE(review): the browser only replaces a cookie with the same path as the
                // one that set it; confirm the path used when the cookie was issued.
                response.addCookie(fakeToken);
            }
            return ViewUtils.redirect(
                    "/", new ErrorInfo("login required"));
        }
        return (ModelAndView) proceedingJoinPoint.proceed();
    }
}
3.使用样例
需要注意AOP没有很好的arg通配方法,这里使用的规约见上面定义
/**
 * Shows the submit page of a problem to a logged-in user.
 * token, request and response come first because the @NeedLogin advice binds them by
 * position.
 * @return the submit view with the problem attached, or a redirect to "/" when the problem
 *         does not exist
 */
@NeedLogin
@GetMapping("/{problemID}/submit")
public ModelAndView submitView(
        @CookieValue(value = "token",required = false) String token,
        HttpServletRequest request,
        HttpServletResponse response,
        @PathVariable("problemID") int problemID) {
    Problem problem = problemService.getProblem(problemID);
    if (problem != null) {
        ModelAndView view = new ModelAndView("/problems/submit");
        view.addObject("problem", problem);
        return view;
    }
    return ViewUtils.redirect("/", new ErrorInfo("No Such Problem."));
}
目前的不足 1.token长度受限于Cookie 2.校验的复杂度还是大了点