File

src/providers/mlflow/mlflow.provider.ts

Index

Properties
Methods

Constructor

constructor(httpService: HttpService, influxDbProvider: InfluxDbProvider)
Parameters :
Name Type Optional
httpService HttpService No
influxDbProvider InfluxDbProvider No

Methods

Async fetchMetricsFromSource
fetchMetricsFromSource(model: Model)
Parameters :
Name Type Optional
model Model No
Returns : Promise<any>
Async fetchModelMetricsForSync
fetchModelMetricsForSync(model: Model, queryOptions: { startTime: number; endTime: number })
Parameters :
Name Type Optional
model Model No
queryOptions { startTime: number; endTime: number } No
Returns : unknown
Async fetchModelMetricsFromStore
fetchModelMetricsFromStore(model: Model, queryOptions: { startTime: number; endTime: number })
Parameters :
Name Type Optional
model Model No
queryOptions { startTime: number; endTime: number } No
Returns : unknown
Async getModel
getModel()
Returns : Promise<any>
Async getModelInfo
getModelInfo()
Returns : Promise<any>
Async getRun
getRun(runId: string)
Parameters :
Name Type Optional
runId string No
Returns : Promise<any>
initialize
initialize(model: Model)
Parameters :
Name Type Optional
model Model No
Returns : void
Async listModels
listModels()
Returns : unknown
Async saveMetrics
saveMetrics(mlflowMetrics, turoAttributes: any)
Parameters :
Name Type Optional
mlflowMetrics No
turoAttributes any No
Returns : unknown
Async searchModels
searchModels(filter: string, max_results: number, order_by: string[], page_token: string)
Parameters :
Name Type Optional
filter string No
max_results number No
order_by string[] No
page_token string No
Returns : Promise<any>
Async searchModelVersions
searchModelVersions(filter: string, max_results: number, order_by: string[], page_token: string)
Parameters :
Name Type Optional
filter string No
max_results number No
order_by string[] No
page_token string No
Returns : Promise<any>

Properties

Private authOptions
Type : null
Default value : null
Private headers
Type : null
Default value : null
Private mlflowApiUrl
Type : null
Default value : null
Private modelName
Type : null
Default value : null
import { Injectable } from "@nestjs/common";
import { HttpService } from "@nestjs/axios";
import { AxiosResponse } from "axios";
import { lastValueFrom } from "rxjs";
import { findProdVersion } from "./utils";
import { Model, ModelProvider } from "../../models/entities/model.entity";
import { InfluxDbProvider } from "../influxDB/influxdb.provider";

/**
 * Model provider backed by an MLflow server's REST API (`/api/2.0/mlflow/...`).
 *
 * Reads registered models, model versions and run metrics from MLflow, and
 * delegates metric storage and queries to {@link InfluxDbProvider}.
 *
 * `initialize(model)` must be called before the registry methods (`getModel`,
 * `getRun`, `getModelInfo`, `searchModels`, `searchModelVersions`,
 * `listModels`). Those methods read the API base URL, model name and auth
 * headers that `initialize` stores on the instance.
 *
 * NOTE(review): NestJS providers are singletons by default. `initialize` writes
 * per-model state onto the instance, and the registry calls `await` between
 * requests. Two concurrent `fetchMetricsFromSource` calls for different models
 * can therefore interleave and use each other's URL/credentials. Consider
 * passing the connection context explicitly, or using a transient/request
 * scope.
 */
@Injectable()
export class MLflowProvider {
  /** Base URL of the MLflow REST API (`<origin>/api`); set by `initialize`. */
  private mlflowApiUrl: string | null = null;
  /** Parsed `model.authInfo` JSON, or null when the model has no auth info. */
  private authOptions: Record<string, unknown> | null = null;
  /** Headers sent with every request (Basic auth when credentials exist). */
  private headers: Record<string, string> | null = null;
  /** Registered-model name extracted from `model.modelUrl`. */
  private modelName: string | null = null;

  constructor(
    private readonly httpService: HttpService,
    private readonly influxDbProvider: InfluxDbProvider,
  ) {}

  /**
   * Points this provider at the MLflow server and registered model described
   * by `model`.
   *
   * All per-model state is recomputed and reassigned. In particular, auth
   * headers from a previously initialized model are cleared, so they cannot
   * leak into requests for a model without credentials.
   *
   * @param model - Must have provider `MLFLOW` and a `modelUrl` such as
   *   `http://host:5000/#/models/<name>`. `authInfo`, when present, is a JSON
   *   string; its `username`/`password` fields become a Basic auth header.
   * @throws Error if the provider is not MLflow, or the URL is missing or has
   *   no `/models/<name>` part.
   * @throws SyntaxError if `authInfo` is present but not valid JSON.
   */
  initialize(model: Model): void {
    if (model.provider !== ModelProvider.MLFLOW) {
      throw new Error(
        `Can't initialize MLflowProvider when model provider isn't ${ModelProvider.MLFLOW}`,
      );
    }
    if (!model.modelUrl) {
      throw new Error("Model doesn't have a valid modelUrl");
    }

    // Compute everything first, so a failure leaves the previous state intact
    // instead of half-overwritten.
    const url = new URL(model.modelUrl);
    const apiUrl = `${url.origin}/api`;
    const modelName = MLflowProvider.extractModelName(url);
    // A missing/empty authInfo means "no auth"; JSON.parse would throw on it.
    const authOptions = model.authInfo ? JSON.parse(model.authInfo) : null;

    let headers: Record<string, string> | null = null;
    if (authOptions) {
      const auth = Buffer.from(
        `${authOptions["username"]}:${authOptions["password"]}`,
      ).toString("base64");
      headers = {
        Authorization: `Basic ${auth}`,
      };
    }

    this.mlflowApiUrl = apiUrl;
    this.modelName = modelName;
    this.authOptions = authOptions;
    this.headers = headers;
  }

  /**
   * Initializes for `model` and returns the metrics of the run behind its
   * chosen model version (see `getModelInfo`).
   *
   * @returns The run's `data.metrics`, or `[]` when MLflow omits the field
   *   because the run has no metrics.
   * @throws Error when no run data could be fetched, plus anything thrown by
   *   `initialize` / `getModelInfo`.
   */
  async fetchMetricsFromSource(model: Model): Promise<any> {
    this.initialize(model);
    const modelInfo = await this.getModelInfo();
    const runData = modelInfo?.run?.data;
    if (!runData) {
      throw new Error("Couldn't find a valid run");
    }
    // MLflow's JSON omits empty repeated fields, so "no metrics" arrives as
    // a missing key rather than [].
    return runData.metrics ?? [];
  }

  /**
   * Reads stored metrics for `model` from InfluxDB over a time window.
   *
   * @param queryOptions - `startTime`/`endTime` as epoch milliseconds
   *   (converted with `new Date(ms)`).
   */
  async fetchModelMetricsFromStore(
    model: Model,
    queryOptions: {
      startTime: number;
      endTime: number;
    },
  ) {
    const start = new Date(queryOptions.startTime);
    const end = new Date(queryOptions.endTime);
    return await this.influxDbProvider.fetchModelMetricsFromStore(
      model,
      start,
      end,
    );
  }

  /**
   * Reads metrics for `model` from InfluxDB in the form used for syncing.
   *
   * @param queryOptions - `startTime`/`endTime` as epoch milliseconds
   *   (converted with `new Date(ms)`).
   */
  async fetchModelMetricsForSync(
    model: Model,
    queryOptions: {
      startTime: number;
      endTime: number;
    },
  ) {
    const start = new Date(queryOptions.startTime);
    const end = new Date(queryOptions.endTime);
    return await this.influxDbProvider.fetchModelMetricsForSync(
      model,
      start,
      end,
    );
  }

  /**
   * Fetches the registered model, picks a version, and loads that version's
   * run.
   *
   * The version chosen is the one `findProdVersion` returns from all versions
   * of this model. If there is none, the first entry of the model's
   * `latest_versions` is used.
   *
   * @returns `{ model, version, run }`.
   * @throws Error when no version or no `run_id` is available.
   */
  async getModelInfo(): Promise<any> {
    // The two lookups are independent, so issue them concurrently. Use exact
    // `=`: with LIKE, `_` and `%` in a model name act as wildcards and can
    // match versions of other models.
    const [modelInfo, modelVersions] = await Promise.all([
      this.getModel(),
      this.searchModelVersions(`name = "${this.modelName}"`, 1000, [], null),
    ]);

    const latestModelVersion = modelInfo?.latest_versions?.[0] ?? null;
    const prodModelVersion = findProdVersion(modelVersions);
    const chosenVersion = prodModelVersion || latestModelVersion;

    if (!chosenVersion) {
      throw new Error("Model doesn't have a valid ModelVersion");
    }
    if (!chosenVersion["run_id"]) {
      throw new Error("Model doesn't have a valid Run");
    }

    const run = await this.getRun(chosenVersion["run_id"]);
    return {
      model: modelInfo,
      version: chosenVersion,
      run: run,
    };
  }

  /**
   * GET `registered-models/get` for the initialized model name.
   *
   * @returns `registered_model` from the response, or `{}` on a non-200 2xx
   *   status. Axios rejects non-2xx statuses by default, so those throw.
   */
  async getModel(): Promise<any> {
    const response = await this.apiGet("registered-models/get", {
      name: this.modelName,
    });
    if (response.status === 200) {
      return response.data.registered_model;
    }
    console.error(
      `Failed to get registered model "${this.modelName}" (HTTP ${response.status})`,
    );
    return {};
  }

  /**
   * GET `runs/get` for `runId`.
   *
   * @returns `run` from the response, or `{}` on a non-200 2xx status.
   */
  async getRun(runId: string): Promise<any> {
    const response = await this.apiGet("runs/get", { run_id: runId });
    if (response.status === 200) {
      return response.data.run;
    }
    console.error(`Failed to get run ${runId} (HTTP ${response.status})`);
    return {};
  }

  /**
   * GET `registered-models/search`. Null/undefined arguments are omitted from
   * the query.
   *
   * @returns The `registered_models` array (`[]` when MLflow omits it because
   *   nothing matched, or on a non-200 2xx status). Only the first page is
   *   returned; `next_page_token` is not surfaced.
   */
  async searchModels(
    filter: string,
    max_results: number,
    order_by: string[],
    page_token: string,
  ): Promise<any> {
    const response = await this.apiGet("registered-models/search", {
      filter: filter,
      max_results: max_results,
      order_by: order_by,
      page_token: page_token,
    });
    if (response.status === 200) {
      return response.data.registered_models ?? [];
    }
    console.error(
      `Failed to search for models (HTTP ${response.status}).`,
    );
    return [];
  }

  /** Lists up to 1000 registered models (first page only). */
  async listModels() {
    return await this.searchModels(null, 1000, [], null);
  }

  /**
   * GET `model-versions/search`. Null/undefined arguments are omitted from
   * the query.
   *
   * @returns The `model_versions` array (`[]` when MLflow omits it because
   *   nothing matched, or on a non-200 2xx status). Only the first page is
   *   returned.
   */
  async searchModelVersions(
    filter: string,
    max_results: number,
    order_by: string[],
    page_token: string,
  ): Promise<any> {
    const response = await this.apiGet("model-versions/search", {
      filter: filter,
      max_results: max_results,
      order_by: order_by,
      page_token: page_token,
    });
    if (response.status === 200) {
      return response.data.model_versions ?? [];
    }
    console.error(
      `Failed to search for model versions (HTTP ${response.status})`,
    );
    return [];
  }

  /** Persists MLflow metrics through the InfluxDB provider. */
  async saveMetrics(mlflowMetrics, turoAttributes: any) {
    return await this.influxDbProvider.saveMetrics(
      mlflowMetrics,
      turoAttributes,
    );
  }

  /**
   * Sends an authenticated GET to `<mlflowApiUrl>/2.0/mlflow/<path>`.
   *
   * Query params are built explicitly rather than passed to axios as an
   * object:
   * - null/undefined values are skipped, as axios does;
   * - array values become repeated keys (`order_by=a&order_by=b`), which is
   *   what MLflow's GET handlers read. Axios' default serializer would
   *   produce `order_by[]=a` instead.
   */
  private async apiGet(
    path: string,
    params: Record<string, string | number | string[] | null | undefined>,
  ): Promise<AxiosResponse<any>> {
    const query = new URLSearchParams();
    for (const [key, value] of Object.entries(params)) {
      if (value === null || value === undefined) {
        continue;
      }
      for (const item of Array.isArray(value) ? value : [value]) {
        query.append(key, String(item));
      }
    }
    return lastValueFrom(
      this.httpService.get(`${this.mlflowApiUrl}/2.0/mlflow/${path}`, {
        params: query,
        headers: this.headers ?? undefined,
      }),
    );
  }

  /**
   * Extracts the registered-model name from an MLflow UI URL such as
   * `http://host:5000/#/models/My%20Model`, yielding `My Model`.
   *
   * Only the part after the origin is searched, so "models" in the host or
   * userinfo is ignored. Everything after the first `/models/` is kept, so
   * names that themselves contain "models" survive intact.
   *
   * @throws Error when the URL has no `/models/<name>` part.
   */
  private static extractModelName(url: URL): string {
    const rest = `${url.pathname}${url.search}${url.hash}`;
    const match = /\/models\/+(.+)$/.exec(rest);
    if (!match) {
      throw new Error("Model doesn't have a valid modelUrl");
    }
    const encoded = match[1];
    // The URL parser percent-encodes e.g. spaces. Decode so axios doesn't
    // encode them a second time; fall back to the raw text if it is malformed.
    try {
      return decodeURIComponent(encoded);
    } catch {
      return encoded;
    }
  }
}

results matching ""

    No results matching ""