src/providers/mlflow/mlflow.provider.ts

Index: Properties | Methods

Constructor

constructor(httpService: HttpService, influxDbProvider: InfluxDbProvider)

Defined in src/providers/mlflow/mlflow.provider.ts:14

Parameters:
- httpService: HttpService
- influxDbProvider: InfluxDbProvider
Async fetchMetricsFromSource

fetchMetricsFromSource(model: Model)

Defined in src/providers/mlflow/mlflow.provider.ts:46

Parameters:
- model: Model

Returns: Promise<any>
Async fetchModelMetricsForSync

fetchModelMetricsForSync(model: Model, queryOptions: { startTime: number; endTime: number })

Defined in src/providers/mlflow/mlflow.provider.ts:72

Parameters:
- model: Model
- queryOptions: { startTime: number; endTime: number }

Returns: unknown
Async fetchModelMetricsFromStore

fetchModelMetricsFromStore(model: Model, queryOptions: { startTime: number; endTime: number })

Defined in src/providers/mlflow/mlflow.provider.ts:56

Parameters:
- model: Model
- queryOptions: { startTime: number; endTime: number }

Returns: unknown
Async getModel

getModel()

Defined in src/providers/mlflow/mlflow.provider.ts:121

Returns: Promise<any>
Async getModelInfo

getModelInfo()

Defined in src/providers/mlflow/mlflow.provider.ts:88

Returns: Promise<any>
Async getRun

getRun(runId: string)

Defined in src/providers/mlflow/mlflow.provider.ts:139

Parameters:
- runId: string

Returns: Promise<any>
initialize

initialize(model: Model)

Defined in src/providers/mlflow/mlflow.provider.ts:21

Parameters:
- model: Model

Returns: void
Async listModels

listModels()

Defined in src/providers/mlflow/mlflow.provider.ts:185

Returns: unknown
Async saveMetrics

saveMetrics(mlflowMetrics, turoAttributes: any)

Defined in src/providers/mlflow/mlflow.provider.ts:217

Parameters:
- mlflowMetrics (no type annotation)
- turoAttributes: any

Returns: unknown
Async searchModels

searchModels(filter: string, max_results: number, order_by: string[], page_token: string)

Defined in src/providers/mlflow/mlflow.provider.ts:157

Parameters:
- filter: string
- max_results: number
- order_by: string[]
- page_token: string

Returns: Promise<any>
Async searchModelVersions

searchModelVersions(filter: string, max_results: number, order_by: string[], page_token: string)

Defined in src/providers/mlflow/mlflow.provider.ts:189

Parameters:
- filter: string
- max_results: number
- order_by: string[]
- page_token: string

Returns: Promise<any>
Private authOptions

Type: null
Default value: null

Defined in src/providers/mlflow/mlflow.provider.ts:12
Private headers

Type: null
Default value: null

Defined in src/providers/mlflow/mlflow.provider.ts:13
Private mlflowApiUrl

Type: null
Default value: null

Defined in src/providers/mlflow/mlflow.provider.ts:11
Private modelName

Type: null
Default value: null

Defined in src/providers/mlflow/mlflow.provider.ts:14
import { Injectable } from "@nestjs/common";
import { HttpService } from "@nestjs/axios";
import { AxiosResponse } from "axios";
import { lastValueFrom } from "rxjs";
import { findProdVersion } from "./utils";
import { Model, ModelProvider } from "../../models/entities/model.entity";
import { InfluxDbProvider } from "../influxDB/influxdb.provider";
/**
 * Model provider backed by an MLflow tracking server / model registry.
 *
 * Model metadata and run metrics are read from the MLflow REST API
 * (`<origin>/api/2.0/mlflow/...`). Storing and querying metrics is delegated
 * to `InfluxDbProvider`.
 *
 * Call `initialize()` before any REST method. It sets the base URL, the model
 * name and the auth headers that those methods read.
 *
 * NOTE(review): that per-model state lives on the instance. If this provider is
 * shared (e.g. Nest's default singleton scope), concurrent calls for different
 * models can interleave across `await`s and use each other's URL/credentials.
 * Confirm that callers serialize access, or pass the state explicitly.
 */
@Injectable()
export class MLflowProvider {
  /** REST base URL (`<origin>/api`), set by `initialize()`. */
  private mlflowApiUrl: string | null = null;
  /** Parsed `model.authInfo`; only `username` and `password` are read. */
  private authOptions: { username?: string; password?: string } | null = null;
  /** Basic-auth header sent with every request, or `null` for no auth. */
  private headers: Record<string, string> | null = null;
  /** Registered-model name taken from `model.modelUrl`. */
  private modelName: string | null = null;

  constructor(
    private readonly httpService: HttpService,
    private readonly influxDbProvider: InfluxDbProvider,
  ) {}

  /**
   * Configures this provider for `model`.
   *
   * Derives the REST base URL and the registered-model name from
   * `model.modelUrl`. Builds Basic-auth headers from `model.authInfo`, a JSON
   * string whose `username`/`password` are used.
   *
   * @throws Error if the model is not an MLflow model, has no `modelUrl`, or
   *   the URL has no `/models/<name>` segment. `new URL` throws a TypeError for
   *   an unparsable URL, and `JSON.parse` throws for malformed `authInfo`.
   */
  initialize(model: Model): void {
    if (model.provider !== ModelProvider.MLFLOW) {
      throw new Error(
        `Can't initialize MLflowProvider when model provider isn't ${ModelProvider.MLFLOW}`,
      );
    }
    if (!model.modelUrl) {
      throw new Error("Model doesn't have a valid modelUrl");
    }
    const url = new URL(model.modelUrl);
    this.mlflowApiUrl = `${url.origin}/api`;
    // A missing/empty authInfo means "no authentication", not a JSON error.
    this.authOptions = model.authInfo ? JSON.parse(model.authInfo) : null;
    this.modelName = MLflowProvider.extractModelName(url);
    // Always reset, so credentials from a previously initialized model are
    // never sent to a model that has none.
    this.headers = null;
    if (this.authOptions) {
      const auth = Buffer.from(
        `${this.authOptions.username}:${this.authOptions.password}`,
      ).toString("base64");
      this.headers = { Authorization: `Basic ${auth}` };
    }
  }

  /**
   * Returns the text after the first `/models/` segment of an MLflow UI URL.
   *
   * The hash route (`#/models/<name>`, `#mlflow/models/<name>`) is checked
   * before the path, so a `/models/` inside a proxy prefix does not win.
   * Trailing slashes are dropped and percent-escapes are decoded, so the name
   * is not double-encoded when sent as a query parameter.
   *
   * @throws Error when no non-empty name follows a `/models/` segment.
   */
  private static extractModelName(url: URL): string {
    const marker = "/models/";
    // Prefix the fragment with "/" so `#models/x` is matched as well.
    const candidates = [`/${url.hash.slice(1)}`, url.pathname];
    for (const part of candidates) {
      const idx = part.indexOf(marker);
      if (idx === -1) {
        continue;
      }
      const raw = part.slice(idx + marker.length).replace(/\/+$/, "");
      if (!raw) {
        break;
      }
      try {
        return decodeURIComponent(raw);
      } catch {
        // Malformed escape sequence: use the text as-is.
        return raw;
      }
    }
    throw new Error("Model doesn't have a valid modelUrl");
  }

  /**
   * Initializes for `model` and returns the metrics of the run behind its
   * chosen model version (see `getModelInfo`).
   *
   * @returns `run.data.metrics` as returned by MLflow (may be undefined when
   *   the run has no metrics).
   * @throws Error("Couldn't find a valid run") when no run data was returned.
   */
  async fetchMetricsFromSource(model: Model): Promise<any> {
    this.initialize(model);
    const { run } = await this.getModelInfo();
    // getRun() yields `{}` on a non-200 response, so check for `data` rather
    // than just the presence of a `run` key (which is always there).
    if (!run?.data) {
      throw new Error("Couldn't find a valid run");
    }
    return run.data.metrics;
  }

  /**
   * Delegates to `InfluxDbProvider.fetchModelMetricsFromStore` with the
   * window converted from epoch-millisecond numbers to `Date`s.
   */
  async fetchModelMetricsFromStore(
    model: Model,
    queryOptions: {
      startTime: number;
      endTime: number;
    },
  ) {
    return await this.influxDbProvider.fetchModelMetricsFromStore(
      model,
      new Date(queryOptions.startTime),
      new Date(queryOptions.endTime),
    );
  }

  /**
   * Delegates to `InfluxDbProvider.fetchModelMetricsForSync` with the window
   * converted from epoch-millisecond numbers to `Date`s.
   */
  async fetchModelMetricsForSync(
    model: Model,
    queryOptions: {
      startTime: number;
      endTime: number;
    },
  ) {
    return await this.influxDbProvider.fetchModelMetricsForSync(
      model,
      new Date(queryOptions.startTime),
      new Date(queryOptions.endTime),
    );
  }

  /**
   * Resolves the registered model, the version to use and that version's run.
   *
   * The version chosen by `findProdVersion` wins. Otherwise the newest entry
   * of the model's `latest_versions` is used.
   *
   * @returns `{ model, version, run }`.
   * @throws Error when no version or no `run_id` can be found.
   */
  async getModelInfo(): Promise<any> {
    // The two lookups are independent, so issue them concurrently.
    const [modelInfo, modelVersions] = await Promise.all([
      this.getModel(),
      // Exact match: with LIKE, `_` and `%` in the name act as wildcards and
      // could pull in versions of other models.
      this.searchModelVersions(`name = '${this.modelName}'`, 1000, [], null),
    ]);
    const latestModelVersion = MLflowProvider.pickLatestVersion(
      modelInfo?.latest_versions,
    );
    const chosenVersion = findProdVersion(modelVersions) || latestModelVersion;
    if (!chosenVersion) {
      throw new Error("Model doesn't have a valid ModelVersion");
    }
    if (!chosenVersion["run_id"]) {
      throw new Error("Model doesn't have a valid Run");
    }
    const run = await this.getRun(chosenVersion["run_id"]);
    return {
      model: modelInfo,
      version: chosenVersion,
      run,
    };
  }

  /**
   * Picks the entry with the highest numeric `version` from `latest_versions`.
   *
   * That list holds one entry per stage, so its first element is not
   * necessarily the newest. Falls back to the first entry when versions are
   * not numeric, and returns `null` for a missing or empty list.
   */
  private static pickLatestVersion(versions: unknown): any {
    if (!Array.isArray(versions) || versions.length === 0) {
      return null;
    }
    return versions.reduce((best, candidate) =>
      Number(candidate?.version) > Number(best?.version) ? candidate : best,
    );
  }

  /**
   * Issues a GET against `<mlflowApiUrl>/2.0/mlflow/<path>`, sending the auth
   * headers when present.
   *
   * @throws Error if `initialize()` has not been called.
   */
  private async mlflowGet(
    path: string,
    params: Record<string, unknown>,
  ): Promise<AxiosResponse<any>> {
    if (!this.mlflowApiUrl) {
      throw new Error("MLflowProvider has not been initialized");
    }
    return lastValueFrom(
      this.httpService.get(`${this.mlflowApiUrl}/2.0/mlflow/${path}`, {
        params,
        // `undefined` (not `null`) keeps axios' default headers intact.
        headers: this.headers ?? undefined,
      }),
    );
  }

  /**
   * Fetches the registered model named by `modelName`.
   *
   * @returns `registered_model` from the response, or `{}` on a non-200
   *   status. Axios rejects non-2xx responses by default, so HTTP errors
   *   normally surface as rejections instead.
   */
  async getModel(): Promise<any> {
    const response = await this.mlflowGet("registered-models/get", {
      name: this.modelName,
    });
    if (response.status !== 200) {
      console.error(
        `Failed to get registered model "${this.modelName}" (HTTP ${response.status})`,
      );
      return {};
    }
    return response.data.registered_model;
  }

  /**
   * Fetches an MLflow run by id.
   *
   * @returns `run` from the response, or `{}` on a non-200 status.
   */
  async getRun(runId: string): Promise<any> {
    const response = await this.mlflowGet("runs/get", { run_id: runId });
    if (response.status !== 200) {
      console.error(
        `Failed to get run "${runId}" (HTTP ${response.status})`,
      );
      return {};
    }
    return response.data.run;
  }

  /**
   * Searches registered models. `null` arguments are omitted from the query
   * by axios.
   *
   * @returns `registered_models` from the response, or `[]` when the response
   *   has none or the status is not 200.
   */
  async searchModels(
    filter: string | null,
    max_results: number,
    order_by: string[],
    page_token: string | null,
  ): Promise<any> {
    const response = await this.mlflowGet("registered-models/search", {
      filter,
      max_results,
      order_by,
      page_token,
    });
    if (response.status !== 200) {
      console.error(`Failed to search for models.`);
      return [];
    }
    // An empty result may come back without the key; normalize like the
    // error branch does.
    return response.data.registered_models ?? [];
  }

  /** Lists up to 1000 registered models (no filter, no paging). */
  async listModels() {
    return await this.searchModels(null, 1000, [], null);
  }

  /**
   * Searches model versions. `null` arguments are omitted from the query by
   * axios.
   *
   * @returns `model_versions` from the response, or `[]` when the response
   *   has none or the status is not 200.
   */
  async searchModelVersions(
    filter: string | null,
    max_results: number,
    order_by: string[],
    page_token: string | null,
  ): Promise<any> {
    const response = await this.mlflowGet("model-versions/search", {
      filter,
      max_results,
      order_by,
      page_token,
    });
    if (response.status !== 200) {
      console.error(`Failed to search for models versions`);
      return [];
    }
    // An empty result may come back without the key; normalize like the
    // error branch does.
    return response.data.model_versions ?? [];
  }

  /** Delegates persistence of `mlflowMetrics` to `InfluxDbProvider.saveMetrics`. */
  async saveMetrics(mlflowMetrics: any, turoAttributes: any) {
    return await this.influxDbProvider.saveMetrics(
      mlflowMetrics,
      turoAttributes,
    );
  }
}