9个NestJS企业级自定义装饰器,你知道几个?

490 阅读9分钟

大家好,我是元兮。

想必有很多同学都想知道NestJS在实际的公司到底是怎么用的?今天就聊一聊最常用的装饰器在企业项目中的实际应用。(没基础的同学莫慌,继续往下看~

说到装饰器,不得不先提及Nest中核心的设计思想:AOP切面编程

装饰器被广泛用于实现AOP,例如通过类装饰器、方法装饰器、属性装饰器和参数装饰器来实现。

Nest中内置了非常丰富并开箱即用的装饰器,例如:

  • 类装饰器:@Controller、@Injectable、@Module、@UseInterceptors
  • 方法装饰器:@Get、@Post、@UseInterceptors
  • 属性装饰器:@IsNotEmpty、@IsString、@IsNumber
  • 参数装饰器:@Body、@Param、@Query

此外,业务中还有很多场景需要自定义装饰器,实现公共逻辑统一管理,例如:

  1. 获取请求IP地址
  2. 请求域名合法性判断
  3. 获取平台标识
  4. 获取用户信息
  5. QPS限制
  6. 接口下线管理
  7. 接口白名单化
  8. 获取用户所在国家代码
  9. 自定义swagger响应数据结构

下面一起来看看~

PS: 对装饰器和AOP概念陌生的同学,可以查看之前的文章如何理解装饰器?Nest实现AOP切面学习。

如果你觉得还不尽兴,我在图书 《NestJS全栈开发解析:快速上手与实战》 中进行了专门的阐述和代码演示,感兴趣朋友可以上手(支持纸质和电子版)~

京东:item.jd.com/14283389.ht… 

当当:product.dangdang.com/29783482.ht… 

1. 获取IP地址

根据请求头获取IP地址属于最常用、最简单的一种,创建一个参数装饰器,从header头部取对应属性即可。

import { createParamDecorator, ExecutionContext } from '@nestjs/common';

export const Ip = createParamDecorator(async (data: string, ctx: ExecutionContext) => {
  const request = ctx.switchToHttp().getRequest();
  return request.headers['x-real-ip'] || request.ip;
});

其中,x-real-ip可以是客户端(前端)设置的IP地址,但通常是由Nginx层统一配置,用于服务端获取客户端IP。

然后这样使用:

  async getXxx(
    @Ip() ip: string,
  ) {
    // 处理业务逻辑
    console.log(ip)
  }

2. 请求域名合法性判断

请求域名合法性判断也是如此,通过请求头中的refererorigin判断是否在白名单内,决定是否放行。

import { createParamDecorator, ExecutionContext } from '@nestjs/common';
import { Request } from 'express';

/* 验证请求来源是否合法域名 */
export const ValidReferer = createParamDecorator((data: unknown, ctx: ExecutionContext) => {
  if (process.env.APP_ENV !== 'prod') {
    return true;
  }
  const request: Request = ctx.switchToHttp().getRequest();
  const referer = request.headers.referer || request.headers.origin;
  return isValidUrl(referer);
});

/**
 * @description 判断是否是合法域名
 */
export function isValidUrl(url) {
  if (!url) {
    return false;
  }
  let result = false;
  const arr = url.split('/');
  if (arr && arr[2]) {
    if (
      arr[2].endsWith('aaa.com') ||
      arr[2].endsWith('bbb.com') ||
      arr[2].endsWith('ccc.com') ||
      arr[2].endsWith('ddd.com')
    ) {
      result = true;
    }
  }
  return result;
}

使用方式如下:

  async getXxx(
    @ValidReferer() isValidReferer: boolean,
  ) {
    // 处理业务逻辑
    console.log(ip)
  }

3. 获取平台标识

业务涉及到多个端的请求,如Android、iOS、Web等等,需要根据不同端接口进行业务逻辑判断,此时就可以通过参数装饰器来实现。

import { createParamDecorator, ExecutionContext } from '@nestjs/common';
import { TaskPlatform } from 'src/app.schema';
import { Request } from 'express';

/* 获取平台 */
export const Platform = createParamDecorator((data: unknown, ctx: ExecutionContext): TaskPlatform => {
  const request: Request = ctx.switchToHttp().getRequest();
  let platform = request.headers['x-platform'] as TaskPlatform;
  if (!platform || !Object.values(TaskPlatform).includes(platform)) {
    platform = TaskPlatform.WEB;
  }
  return platform as TaskPlatform;
});

TaskPlatform服务端维护的枚举类:

/** 平台 */
export enum TaskPlatform {
  WEB = 'web',
  IOS = 'ios',
  ANDROID = 'android',
}

需要客户端在请求接口时,在请求头上带上平台标识。装饰器使用方式如下:

  async getXxx(
    @Platform() platform: TaskPlatform,
  ) {
    // 处理业务逻辑
    console.log(platform)
  }

4. 获取用户信息

登陆成功后,用户信息通常会存到请求头中的Session中,以便后续的请求能够直接获取到用户的登陆态信息,比如uidaid这种:

import { createParamDecorator, ExecutionContext } from '@nestjs/common';

/* 获取用户信息 */
export const UserInfoDecor = createParamDecorator((data: unknown, ctx: ExecutionContext) => {
  const request = ctx.switchToHttp().getRequest();
  const userInfo = request.session?.userInfo
  return userInfo ?? null;
});

注意,上面是从Session中获取的用户信息,它可以在守卫里面手动设置到Request对象中,这样在有效期内都可以直接获取。

import { Injectable, CanActivate, HttpException, HttpStatus, ExecutionContext } from '@nestjs/common';
import { Request } from 'express';
import { AuthService } from 'src/modules/auth/auth.service';
import { UserInfo } from 'src/app.schema';
import { CacheService } from 'src/modules/shared/cache/cache.service';

@Injectable()
export class AuthGuard implements CanActivate {
  // 白名单内的路由地址不做校验
  private whiteList: string[] = [];
  constructor(
    private readonly cacheService: CacheService,
    private readonly authService: AuthService,
  ) {}

  async canActivate(context: ExecutionContext): Promise<boolean> {
    const request: Request = context.switchToHttp().getRequest();
    if (this.whiteList.includes(request.route.path)) {
      return true;
    }

    const authHeader = request.headers.authorization;
    const auth_token = authHeader && authHeader.split(' ')[1];
    // token 不存在,则表示未登录
    const token: string = request.headers.token || request.cookies.token || auth_token;
    if (!token) {
      delete request.session['userInfo'];
      throw new HttpException('UNAUTHORIZED', HttpStatus.UNAUTHORIZED);
    }
    if (request.session['userInfo']?.uid) {
      return true;
    } else {
      const userInfo: UserInfo = await this.authService.getUserInfo(token);
      if (userInfo?.uid) {
        request.session['userInfo'] = userInfo;
        return true;
      } else {
        throw new HttpException('UNAUTHORIZED', HttpStatus.UNAUTHORIZED);
      }
    }
  }
}

每个需要鉴权的接口都绑定这个守卫,然后这样使用:

  @UseGuards(AuthGuard) 
  @Get()
  async getXxx(
    @UserInfoDecor() userInfo: UserInfo,
  ) {
    // 处理业务逻辑
    console.log(userInfo.uid)
  }

5. QPS限制器

业务中存在一些昂贵操作的请求接口,为了防止羊毛党或爬虫脚本恶意刷量,我们通常会对其进行限制并发数,比如限制接口A每秒只能请求1次

import { SetMetadata, UseGuards } from '@nestjs/common';
import { applyDecorators } from '@nestjs/common';
import { QpsGuard } from 'src/guard/qps.guard';

export function QpsLimit(options: { limit: number; seconds?: number; env?: string }) {
  return applyDecorators(SetMetadata('qpsOptions', options), UseGuards(QpsGuard));
}

限制器接收limitsecondsenv作为参数,分别表示限制次数、限制秒数和环境变量。通过来组合applyDecorators多个装饰器。

applyDecorators支持组合多个不同的装饰器,如果你认为在一个地方需要添加多个装饰器来完成一个功能,不妨考虑将它们组合起来,例如:

import { applyDecorators } from '@nestjs/common'; 
function CustomDecorator() { 
    return applyDecorators( Decorator1(), Decorator2(), Decorator3() ); 
}

核心逻辑在QpsGuard中,守卫中通过内存缓存cache-manage维护一个访问量的标记位,每访问一次接口,都会往内存中的visits数加1,如果在指定的时间内visits数超过我们设置的limit,接口会抛出一个QpsLimitException,阻止接口继续请求。

import { Injectable, CanActivate, HttpException, HttpStatus, ExecutionContext } from '@nestjs/common';
import { Reflector } from '@nestjs/core';
import { Request } from 'express';
import { MemCacheService } from 'src/modules/shared/cache/memCache.service';

@Injectable()
export class QpsGuard implements CanActivate {
  constructor(
    private readonly reflector: Reflector,
    private readonly memCacheService: MemCacheService,
  ) {}
  async canActivate(context: ExecutionContext): Promise<boolean> {
    const request: Request = context.switchToHttp().getRequest();
    const optionsForHandler = this.reflector.get('qpsOptions', context.getHandler());
    const optionsForClass = this.reflector.get('qpsOptions', context.getClass());

    const { limit, seconds, env = 'prod' } = optionsForHandler ?? optionsForClass ?? {};

    if (process.env.APP_ENV !== env) {
      return true;
    }

    const expire = seconds ?? 1;
    const visitKey: string = 'qps-limit|' + request.route.path + '|' + (request.headers['x-real-ip'] || request.ip);
    const visits = Number((await this.memCacheService.get(visitKey)) || 0);

    if (visits >= limit) {
      throw new QpsLimitException();
    }
    await this.memCacheService.set(visitKey, visits + 1, expire);

    return true;
  }
}

export class QpsLimitException extends HttpException {
  constructor() {
    super('Too many requests in short period of time, please try later.', HttpStatus.TOO_MANY_REQUESTS);
  }
}

最后,在接口中这样使用:

  @QpsLimit({ env: 'prod', limit: 1, seconds: 1 })
  @Get()
  async getXxx() {
    // 处理业务逻辑
  }

PS: 如果你的业务比较简单,没有定制化需求,也可以直接使用官网提供的限速器。

// Override default configuration for Rate limiting and duration.
@Throttle({ default: { limit: 3, ttl: 60000 } })
@Get()
findAll() {
  return "List users works with custom rate limiting.";
}

6. 接口下线管理

业务接口通常要求做版本管理,尤其是针对APP这种需要安装后才能更新最新版本的,为了兼容很久不更新的用户,使用v1、v2来标识不同版本,经过一段时间的缓冲,最终会对旧版本进行下线操作。

import { Injectable, CanActivate, HttpException, HttpStatus, ExecutionContext, Logger } from '@nestjs/common';
import { Reflector } from '@nestjs/core';
import { Request } from 'express';
import { TaskPlatform } from 'src/app.schema';

@Injectable()
export class OfflineGuard implements CanActivate {
  constructor(
    private readonly reflector: Reflector,
    private readonly logger: Logger,
  ) {}
  async canActivate(context: ExecutionContext): Promise<boolean> {
    const request: Request = context.switchToHttp().getRequest();
    const optionsForHandler = this.reflector.get('offlineOptions', context.getHandler());
    const optionsForClass = this.reflector.get('offlineOptions', context.getClass());
    const { date, platform } = optionsForHandler ?? optionsForClass ?? {};

    let xplatform = request.headers['x-platform'] as TaskPlatform;
    if (!xplatform || !Object.values(TaskPlatform).includes(xplatform)) {
      xplatform = TaskPlatform.WEB;
    }

    if (!date || Date.now() > date.getTime()) {
      if (platform && platform.length > 0) {
        if (platform.includes(xplatform)) {
          this.logger.warn(`Service is currently offline`, {
            module: 'OfflineGuard',
            action: 'offline',
            path: request.route.path,
          });
          throw new HttpException('Service is currently offline.', HttpStatus.GONE);
        }
        return true;
      } else {
        this.logger.warn(`Service is currently offline`, {
          module: 'OfflineGuard',
          action: 'offline',
          path: request.route.path,
        });
        throw new HttpException('Service is currently offline.', HttpStatus.GONE);
      }
    }

    return true;
  }
}

别看上面代码这么多,大部分是logger统计和参数获取,实际上是判断是否在规定的时间下架某个平台(Android、iOS、Web)的接口,如果业务访问了下架的接口,会提示Service is currently offline.错误。

路由方法或者控制器中都可以使用,例如:

@Offline()
@Controller('/api/v1/xxx')
export class XxxController {
  constructor() {}

  /**
   * 初始化
   */
  @Offline()
  async init(
  ): Promise<xxx> {
      // 业务逻辑
  }
}

事实上Nest也内置了版本管理,支持使用路由、请求头和自定义规则实现版本控制,而且能够支持同时存在多版本,但访问规则之外的旧版URL就会404,而上面的Offline是做最后的用户体验兜底。

7. 接口白名单化

场景:由于我们的接口采用v1、v2这种方式进行管理,所以用户很容易猜到下个版本是v3,当业务有些接口更新了最新版本,升级了底层算法,但是不希望所有用户都使用,而是针对指定用户开放。

如何实现这种效果呢?本质上也是通过守卫来实现。

import { SetMetadata, UseGuards } from '@nestjs/common';
import { applyDecorators } from '@nestjs/common';
import { RestrictGuard } from 'src/guard/restrict.guard';

export function Restrict(options: { name: string }) {
  return applyDecorators(SetMetadata('restrictOptions', options), UseGuards(RestrictGuard));
}

装饰器接受一个name参数,表示白名单的模块名称,比如ai-image,用作Redis key的一部分。RestrictGuard的具体实现如下:

import { Injectable, CanActivate, HttpException, HttpStatus, ExecutionContext } from '@nestjs/common';
import { Reflector } from '@nestjs/core';
import { Request } from 'express';
import { CacheService } from 'src/modules/shared/cache/cache.service';

/** 限制仅部分用户可访问接口 */
@Injectable()
export class RestrictGuard implements CanActivate {
  constructor(
    private readonly reflector: Reflector,
    private readonly cacheService: CacheService,
  ) {}

  async canActivate(context: ExecutionContext): Promise<boolean> {
    const request: Request = context.switchToHttp().getRequest();
    const optionsForHandler = this.reflector.get('restrictOptions', context.getHandler());
    const optionsForClass = this.reflector.get('restrictOptions', context.getClass());
    const { name } = optionsForHandler ?? optionsForClass ?? {};

    const uid = request.session['userInfo']?.['uid'];

    if (uid) {
      if (await this.cacheService.sExists(`restriction:${name}`, uid)) {
        return true;
      }
    }

    throw new HttpException('RESTRICTED', HttpStatus.FORBIDDEN);
  }
}

其中,CacheService是封装用于操作Redis的基础模块,这里不做过多介绍。通过判断当前uid是否存在Redis中决定是否放行。

那Redis中的是数据什么时候set进去的呢?

这个其实很简单,我们维护了一个白名单用户的飞书表格,通过定时任务隔一段时间从表格中拉取所有的uid,批量写到Redis中即可。

最后在接口中这样使用:

  @Restrict({ name: 'ai-image' })
  @Post()
  async doXxx() {
    // 处理业务逻辑
  }

8. 获取用户所在国家代码

国际化业务中,我们需要获取当前注册或登录用户属于哪个国家,以便进行定制化服务。此时可以通过IP地址来获取用户所在区域,调用http://ip-api.com/json/${ip}接口来实现,直接上代码:

import { createParamDecorator, ExecutionContext } from '@nestjs/common';

/* 获取用户所在国家代码 */
export const Region = createParamDecorator(async (data: string, ctx: ExecutionContext) => {
  const request = ctx.switchToHttp().getRequest();
  const ip = data ?? (request.headers['x-real-ip'] || request.ip);
  // Using cookies region as workaround when ip api request failed.
  const region = request.cookies.region || 'US';

  try {
    const controller = new AbortController();
    const signal = controller.signal;
    const timeoutId = setTimeout(() => controller.abort(), 3000);

    const res = await fetch(`http://ip-api.com/json/${ip}`, { signal });

    clearTimeout(timeoutId);
    if (res.status === 200) {
      const resJson = await res.json();
      return resJson.countryCode ?? region;
    }
  } catch (error) {
    console.error(error)
  }

  return region;
});

然后在接口获取参数时使用:

  @Get()
  async getXxx(
    @Region() region: String,
  ) {
    // 处理业务逻辑
    console.log(region)
  }

9. 自定义swagger响应数据结构

服务端开发中通常都会使用Swagger自动生成接口文档,Nest与Swagger集成了各种开箱即用的响应体,比如下面这些:

image.png

但很多时候业务需要自定义返回的数据结构,添加额外的一些信息,比如返回对象和列表的结构是不同的,返回列表还需要把分页信息返回给前端。这样使用:

  @Get()
  @ApiListResponse(XxxListDto, PaginatedDto)
  async getXxx() {
    // 处理业务逻辑
  }

其中,PaginatedDto是这样的:

/**
 * 分页信息
 */
export class PaginatedDto extends MetaDto {
  @ApiProperty()
  pageSize: number;

  @ApiProperty()
  pageIndex: number;

  @ApiProperty()
  total: number;
}

另外,返回对象结构时这样使用:

  @Get()
  @ApiCustomResponse(XxxResultDto, MetaDto)
  async getXxx() {
    // 处理业务逻辑
  }

其中,MetaDto是这样的:


/**
 * 通用 meta 信息
 */
export class MetaDto {
  @ApiProperty({ required: false, description: '是否命中缓存' })
  cache?: boolean;

  @ApiProperty({ required: false, description: '过期时间,秒' })
  ttl?: number;

  @ApiProperty({ required: false, description: '未读数量' })
  unread?: number;

  @ApiProperty({ required: false, description: 'data 响应体的 md5 值' })
  hash?: string;

  @ApiProperty({ required: false, description: '是否使用过ai功能' })
  usedAi?: boolean;
}

这时就可以使用自定义装饰器来完成,为不同的响应数据结构组合不同的结果出来:

import { applyDecorators } from '@nestjs/common';
import { ApiExtraModels, ApiOkResponse, getSchemaPath } from '@nestjs/swagger';
import { ResponseData } from 'src/app.schema';

export const ApiCustomResponse = (dataDto?: any, metaDto?: any) => {
  const models = [ResponseData];
  if (dataDto) {
    models.push(dataDto);
  }
  if (metaDto) {
    models.push(metaDto);
  }
  return applyDecorators(
    ApiExtraModels(...models),
    ApiOkResponse({
      schema: {
        required: ['data', metaDto ? 'meta' : ''],
        allOf: [
          { $ref: getSchemaPath(ResponseData) },
          {
            properties: Object.assign(
              dataDto
                ? {
                    data: {
                      $ref: getSchemaPath(dataDto),
                    },
                  }
                : {},
              metaDto ? { meta: { $ref: getSchemaPath(metaDto) } } : {},
            ),
          },
        ],
      },
    }),
  );
};

export const ApiListResponse = (dataDto: any, metaDto?: any) => {
  const models = [ResponseData, dataDto];
  if (metaDto) {
    models.push(metaDto);
  }
  return applyDecorators(
    ApiExtraModels(...models),
    ApiOkResponse({
      schema: {
        required: ['data', metaDto ? 'meta' : ''],
        allOf: [
          { $ref: getSchemaPath(ResponseData) },
          {
            properties: Object.assign(
              {
                data: {
                  type: 'array',
                  items: {
                    $ref: getSchemaPath(dataDto),
                  },
                },
              },
              metaDto ? { meta: { $ref: getSchemaPath(metaDto) } } : {},
            ),
          },
        ],
      },
    }),
  );
};

其中,ResponseData是这样的:

/**
 * 服务端返回给客户端的数据包结构定义,也就是响应数据里的 `data` 字段。
 */
export class ResponseData<T = undefined, M = undefined> {
  /**
   * 业务码。
   */
  @ApiProperty({ description: '业务码, 成功: 0, 失败: -1, 禁止操作: 403' })
  readonly code: ResponseCode;

  /**
   * 动作
   */
  @ApiProperty({ description: '动作', required: false })
  readonly action?: string;

  /**
   * 业务消息。
   */
  @ApiProperty()
  readonly message: string;

  /**
   * 业务数据。
   */
  readonly data?: T;

  /**
   * 其他业务数据
   */
  readonly meta?: M;

  /**
   * 创建一个代表请求成功的 `Response` 对象。
   *
   * @param options.data 请求成功时的数据
   * @param options.message 请求成功时的消息,默认为 `Success`
   * @returns 代表请求成功的 `Response` 对象
   */
  static Success<T = undefined, M = undefined>(options?: {
    data?: T;
    meta?: M;
    message?: string;
    action?: string;
  }): ResponseData<T, M> {
    return {
      code: ResponseCode.SUCCESS,
      message: options?.message || 'Success',
      data: options?.data,
      meta: options?.meta,
      action: options?.action,
    };
  }

  /**
   * 创建一个代表请求失败的 `Response` 对象。
   *
   * @param options.code 请求失败时的业务码,默认为 `ResponseCode.Failure`
   * @param options.message 请求失败时的消息,默认为 `Failure`
   * @returns
   */
  static Failure<T = undefined>(options?: {
    code?: ResponseCode;
    message?: string;
    action?: string;
    data?: T;
  }): ResponseData<T> {
    return {
      code: options?.code || ResponseCode.FAILURE,
      message: options?.message || 'Failure',
      action: options?.action,
      data: options?.data,
    };
  }
}

这样就完成了为不同类型的响应数据定制不同的数据结构了。

总结

前面我们介绍了9种自定义装饰器的场景,其实在实际项目中不止这些,类似还有获取token数据获取请求用户语言上报请求日志等十几种,但原理都大差不差,涉及到参数获取的场景我们都可以灵活利用请求头来传递或保存信息。

另外,涉及到请求限制、白名单相关的场景,优先考虑守卫或拦截器场景,采用组合多个装饰器的方式实现。