GrpcSupport.scala

package wechaty.padplus.support

import java.io.InputStream
import java.util.UUID
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.{Executors, TimeUnit}

import com.amazonaws.auth.{AWSStaticCredentialsProvider, BasicAWSCredentials}
import com.amazonaws.regions.Regions
import com.amazonaws.services.s3.AmazonS3ClientBuilder
import com.amazonaws.services.s3.model.{CannedAccessControlList, GeneratePresignedUrlRequest, ObjectMetadata, PutObjectRequest}
import com.fasterxml.jackson.databind.JsonNode
import com.typesafe.scalalogging.LazyLogging
import io.grpc.stub.{ClientCalls, StreamObserver}
import io.grpc.{ManagedChannel, ManagedChannelBuilder, MethodDescriptor}
import wechaty.padplus.PuppetPadplus
import wechaty.padplus.grpc.PadPlusServerGrpc
import wechaty.padplus.grpc.PadPlusServerOuterClass._
import wechaty.puppet.schemas.Puppet

import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.reflect.ClassTag
import scala.reflect.runtime.universe._

/**
  * gRPC support mixed into [[PuppetPadplus]].
  *
  * Owns the channel and async stub used to talk to the PadPlus server, schedules a periodic
  * heartbeat, builds `RequestObject`s for the server's unary `request` method, and uploads
  * files to S3.
  *
  * @author <a href="mailto:jcai@ganshane.com">Jun Tsai</a>
  * @since 2020-06-21
  */
trait GrpcSupport {
  self: PuppetPadplus with LazyLogging =>

  /** Single thread that runs the periodic heartbeat. */
  private val executorService = Executors.newSingleThreadScheduledExecutor()
  //from https://github.com/wechaty/java-wechaty/blob/master/wechaty-puppet/src/main/kotlin/Puppet.kt
  /** Heartbeat period (and initial delay), in milliseconds. */
  private val HOSTIE_KEEPALIVE_TIMEOUT = 15 * 1000L

  private var asyncGrpcClient: PadPlusServerGrpc.PadPlusServerStub = _
  protected var channel: ManagedChannel = _
  /** Endpoint given to startGrpc, kept so reconnectStream can rebuild a shut-down channel. */
  private var endpointOpt: Option[String] = None
  protected implicit val executionContext: ExecutionContext = scala.concurrent.ExecutionContext.Implicits.global

  /**
    * Opens the channel to `endpoint` (or adopts `option.channelOpt`), creates the async stub,
    * and schedules a heartbeat every HOSTIE_KEEPALIVE_TIMEOUT milliseconds.
    *
    * @param endpoint gRPC target passed to `ManagedChannelBuilder.forTarget`
    */
  protected def startGrpc(endpoint: String): Unit = {
    endpointOpt = Some(endpoint)
    initChannel(endpoint)
    internalStartGrpc()
    //from https://github.com/wechaty/java-wechaty/blob/master/wechaty-puppet/src/main/kotlin/Puppet.kt
    executorService.scheduleAtFixedRate(
      () => sendHeartbeat(),
      HOSTIE_KEEPALIVE_TIMEOUT,
      HOSTIE_KEEPALIVE_TIMEOUT,
      TimeUnit.MILLISECONDS)
  }

  /**
    * Sends one HEARTBEAT request and logs (but otherwise ignores) any failure.
    *
    * Both synchronous exceptions and failures of the returned Future are logged. Non-fatal
    * exceptions must not escape: an exception thrown from a `scheduleAtFixedRate` task
    * suppresses all of its subsequent executions.
    */
  private def sendHeartbeat(): Unit = {
    try {
      asyncRequest[JsonNode](ApiType.HEARTBEAT).failed.foreach { e =>
        logger.warn("ding exception:{}", e.getMessage)
      }
    } catch {
      case scala.util.control.NonFatal(e) =>
        logger.warn("ding exception:{}", e.getMessage)
    }
  }

  /**
    * Sets `channel`: the injected `option.channelOpt` when present, otherwise a plaintext
    * channel to `endpoint` that accepts inbound messages of up to 150 MiB.
    */
  protected def initChannel(endpoint: String): Unit = {
    option.channelOpt match {
      case Some(injected) =>
        this.channel = injected
      case _ =>
        /* Alternative Netty-based configuration, kept for reference:
        this.channel = NettyChannelBuilder
          .forTarget(endpoint)
          .keepAliveTime(20, TimeUnit.SECONDS)
          //      .keepAliveTimeout(2, TimeUnit.SECONDS)
          .keepAliveWithoutCalls(true)
          .idleTimeout(2, TimeUnit.HOURS)
          .enableRetry()
          .usePlaintext().build()
        */
        this.channel = ManagedChannelBuilder.forTarget(endpoint)
          .maxInboundMessageSize(1024 * 1024 * 150)
          .usePlaintext().build()
    }
  }

  /**
   * Tears down the current channel and recreates the stub on a fresh one.
   *
   * A failure while stopping is logged and ignored.
   */
  protected def reconnectStream(): Unit = {
    logger.info("reconnect stream stream...")
    try {
      stopGrpc()
    } catch {
      case scala.util.control.NonFatal(e) =>
        logger.warn("fail to stop grpc {}", e.getMessage)
    }
    // stopGrpc() calls shutdownNow() on the channel, and a stub bound to a shut-down channel
    // fails every call, so rebuild the channel before recreating the stub.
    endpointOpt.foreach(endpoint => initChannel(endpoint))
    internalStartGrpc()
    logger.info("reconnect stream stream done")
  }

  /** Creates the async stub on the current `channel`. */
  private def internalStartGrpc(): Unit = {
    logger.info("start grpc client ....")
    this.asyncGrpcClient = PadPlusServerGrpc.newStub(channel)
    logger.info("start grpc client done")
  }

  /**
    * Sends an InitConfig carrying `option.token` through the stub's `init` call, passing this
    * puppet as the response observer.
    *
    * NOTE(review): `option.token.get` throws if no token is configured.
    */
  private[wechaty] def startStream(): Unit = {
    val initConfig = InitConfig.newBuilder().setToken(option.token.get).build()
    this.asyncGrpcClient.init(initConfig, this)
  }

  /**
    * Stops the stream and shuts the channel down.
    *
    * An injected channel (`option.channelOpt`, used in tests) is left untouched.
    */
  protected def stopGrpc(): Unit = {
    if (option.channelOpt.isEmpty) { //if no test!
      stopStream()
      // channel is still null when startGrpc/initChannel has never run
      Option(this.channel).foreach(_.shutdownNow())
    }
  }

  /** Placeholder: there is currently nothing to stop on the stream side. */
  private def stopStream(): Unit = {
    //do nothing
  }

  /**
    * Returns a random UUID used as a request's trace id.
    *
    * `apiType` is unused here; it is available to overriding implementations.
    */
  protected def generateTraceId(apiType: ApiType): String = {
    UUID.randomUUID().toString
  }

  /**
    * Builds a request carrying `option.token`, the uin when `uinOpt` is defined, the api type,
    * and the optional params.
    *
    * A String payload is sent verbatim; any other payload is serialized with
    * `Puppet.objectMapper`.
    */
  private def newRequest(apiType: ApiType, data: Option[Any]): RequestObject.Builder = {
    val request = RequestObject.newBuilder()
    request.setToken(option.token.get)
    uinOpt.foreach(id => request.setUin(id))
    request.setApiType(apiType)
    data match {
      case Some(str: String) => request.setParams(str)
      case Some(d)           => request.setParams(Puppet.objectMapper.writeValueAsString(d))
      case None              =>
    }
    request
  }

  /**
    * Sends a request whose only interesting outcome is success or failure.
    *
    * A `Promise[Nothing]` cannot be completed with a value, so this variant yields
    * `Future[Unit]`.
    *
    * @return a future that fails with IllegalAccessException when the server's result is not
    *         "success"
    */
  protected def asyncRequestNothing(apiType: ApiType, data: Option[Any] = None): Future[Unit] = {
    asyncCall(PadPlusServerGrpc.getRequestMethod, newRequest(apiType, data).build()).map { rep =>
      if (rep.getResult != "success") {
        logger.warn("fail to request {}", rep)
        throw new IllegalAccessException("fail to request ,grpc result:" + rep)
      }
    }
  }

  /**
    * Sends a request and decodes the matching StreamResponse's data as `T`.
    *
    * A Promise keyed by the request's trace id is registered with `CallbackHelper`. It is
    * presumably completed by the stream handler when the corresponding StreamResponse arrives.
    * Decoding uses `readTree` for JsonNode and `readValue` for any other type.
    *
    * @throws IllegalAccessException when `T` is Nothing / RuntimeClass (use asyncRequestNothing)
    * @return a future that fails with IllegalAccessException if the unary result is not "success"
    */
  protected def asyncRequest[T: TypeTag](apiType: ApiType, data: Option[Any] = None)(implicit classTag: ClassTag[T]): Future[T] = {
    typeOf[T] match {
      case t if t =:= typeOf[Nothing] =>
        throw new IllegalAccessException("generic type is nothing,maybe you should use asyncRequestNothing !")
      case t if t =:= typeOf[RuntimeClass] =>
        throw new IllegalAccessException("generic type is nothing,maybe you should use asyncRequestNothing !")
      case other =>
        logger.debug(s"async request generic type is $other")
    }
    val builder = newRequest(apiType, data)
    builder.setRequestId(UUID.randomUUID().toString)
    val traceId = generateTraceId(apiType)
    builder.setTraceId(traceId)
    val request = builder.build()
    logger.debug("request:{}", request)

    val callbackPromise = Promise[StreamResponse]
    CallbackHelper.pushCallbackToPool(traceId, callbackPromise)
    // NOTE(review): if the unary call itself fails, the promise registered above is not removed
    // from CallbackHelper here; confirm whether the pool expires stale entries.
    asyncCall(PadPlusServerGrpc.getRequestMethod, request).flatMap { rep =>
      if (rep.getResult != "success") {
        logger.warn("fail to request:{}", rep)
        // tryFailure: the stream side may already have completed this promise.
        callbackPromise.tryFailure(new IllegalAccessException("fail to request ,grpc result:" + rep))
      }
      callbackPromise.future
    }.map { streamResponse =>
      typeOf[T] match {
        case t if t =:= typeOf[JsonNode] =>
          Puppet.objectMapper.readTree(streamResponse.getData).asInstanceOf[T]
        case _ =>
          Puppet.objectMapper.readValue(streamResponse.getData, classTag.runtimeClass).asInstanceOf[T]
      }
    }
  }

  type ClientCallback[RespT, T] = RespT => T

  /** Performs a unary call and returns the raw response. */
  protected def asyncCall[ReqT, RespT](call: MethodDescriptor[ReqT, RespT], req: ReqT): Future[RespT] = {
    asyncCallback(call, req)(resp => resp)
  }

  /**
    * Performs a unary call on `channel` and maps its single response through `callback`.
    *
    * The future fails with:
    *  - the gRPC error reported by the call;
    *  - the exception thrown by `callback`;
    *  - IllegalStateException("server completed") if the call completes without a response.
    */
  def asyncCallback[ReqT, RespT, T](callMethod: MethodDescriptor[ReqT, RespT], req: ReqT)(callback: ClientCallback[RespT, T]): Future[T] = {
    val call    = channel.newCall(callMethod, asyncGrpcClient.getCallOptions)
    val promise = Promise[T]
    ClientCalls.asyncUnaryCall(call, req, new StreamObserver[RespT] {
      override def onNext(value: RespT): Unit = {
        // Capture a throwing callback in the future instead of letting it escape into gRPC.
        promise.tryComplete(scala.util.Try(callback(value)))
      }

      override def onError(t: Throwable): Unit = {
        logger.error(t.getMessage, t)
        promise.tryFailure(t)
      }

      override def onCompleted(): Unit = {
        // no-op when onNext already completed the promise
        promise.tryFailure(new IllegalStateException("server completed"))
      }
    })
    promise.future
  }

  // SECURITY NOTE(review): these AWS credentials are committed to source control and should be
  // rotated and removed. They can be overridden through the environment in the meantime.
  private val ACCESS_KEY_ID     = sys.env.getOrElse("PADPLUS_S3_ACCESS_KEY_ID", "AKIA3PQY2OQG5FEXWMH6")
  private val BUCKET            = "macpro-message-file"
  /** Lifetime of generated pre-signed URLs, in seconds (3 days). */
  private val EXPIRE_TIME       = 3600 * 24 * 3
  private val PATH              = "image-message"
  private val SECRET_ACCESS_KEY = sys.env.getOrElse("PADPLUS_S3_SECRET_ACCESS_KEY", "jw7Deo+W8l4FTOL2BXd/VubTJjt1mhm55sRhnsEn")
  // lazy: build the S3 client only when a file is actually uploaded
  private lazy val s3 = AmazonS3ClientBuilder.standard()
    .withCredentials(new AWSStaticCredentialsProvider(new BasicAWSCredentials(ACCESS_KEY_ID, SECRET_ACCESS_KEY)))
    .enablePayloadSigning()
    .withRegion(Regions.CN_NORTHWEST_1) // change this to match your S3 bucket's region
    .build()

  /**
    * Uploads `stream` to S3 under `image-message/<filename>` with a public-read ACL.
    *
    * The stream is not closed here; the caller owns it. No content length is set on the
    * metadata, so the AWS SDK may buffer the whole stream in memory to compute it.
    *
    * @return a pre-signed URL for the uploaded object, valid for EXPIRE_TIME seconds
    */
  def uploadFile(filename: String, stream: InputStream): java.net.URL = {
    val key     = PATH + "/" + filename
    val request = new PutObjectRequest(BUCKET, key, stream, new ObjectMetadata)
    s3.putObject(request.withCannedAcl(CannedAccessControlList.PublicRead))
    // Build the request for a public URL to the uploaded object.
    val urlRequest = new GeneratePresignedUrlRequest(BUCKET, key)
    // Without an explicit expiration the SDK applies its short default, so apply EXPIRE_TIME.
    urlRequest.setExpiration(new java.util.Date(System.currentTimeMillis() + EXPIRE_TIME * 1000L))
    s3.generatePresignedUrl(urlRequest)
  }
}