Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ public interface IText2SqlHook : IHookBase
{
// Get database type
string GetDatabaseType(RoleDialogModel message);
string? GetConnectionString(RoleDialogModel message);
string? GetConnectionString(RoleDialogModel message, string? dataSource = null);
Task SqlGenerated(RoleDialogModel message);
Task SqlExecuting(RoleDialogModel message);
Task SqlExecuted(RoleDialogModel message);
Expand Down
16 changes: 16 additions & 0 deletions src/Plugins/BotSharp.Plugin.SqlDriver/Constants/StateKeys.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using EntityFrameworkCore.BootKit;
using System;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace BotSharp.Plugin.SqlDriver.Constants
{
public class StateKeys
{
public const string DBType = "db_type";
public const string DataSource = "data_source_name";
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using BotSharp.Abstraction.Models;
using BotSharp.Plugin.SqlDriver.Constants;
using BotSharp.Plugin.SqlDriver.Controllers.ViewModels;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
Expand Down Expand Up @@ -30,8 +31,8 @@ public async Task<IActionResult> ExecuteSqlQuery([FromRoute] string conversation
var conv = _services.GetRequiredService<IConversationService>();
await conv.SetConversationId(conversationId,
[
new MessageState("database_type", sqlQueryRequest.DbType),
new MessageState("data_source_name", sqlQueryRequest.DataSource),
new MessageState(StateKeys.DBType, sqlQueryRequest.DbType),
new MessageState(StateKeys.DataSource, sqlQueryRequest.DataSource),
]);

var msg = new RoleDialogModel(AgentRole.User, sqlQueryRequest.SqlStatement)
Expand Down
58 changes: 14 additions & 44 deletions src/Plugins/BotSharp.Plugin.SqlDriver/Functions/ExecuteQueryFn.cs
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
using BotSharp.Core.Infrastructures;
using Dapper;
using Microsoft.Data.SqlClient;
using Microsoft.Data.Sqlite;
using MySqlConnector;
using Npgsql;
using System.Text;

namespace BotSharp.Plugin.SqlDriver.Functions;

public class ExecuteQueryFn : IFunctionCallback
{
public string Name => "execute_sql";
public string Indication => "Performing data retrieval operation.";
private readonly SqlDriverSetting _setting;
private readonly SqlExecuteService _sqlExecuteService;
private readonly IServiceProvider _services;
private readonly ILogger _logger;

public ExecuteQueryFn(IServiceProvider services, SqlDriverSetting setting, ILogger<ExecuteQueryFn> logger)
public ExecuteQueryFn(IServiceProvider services,
SqlDriverSetting setting,
SqlExecuteService sqlExecuteService,
ILogger<ExecuteQueryFn> logger)
{
_services = services;
_setting = setting;
_sqlExecuteService = sqlExecuteService;
_logger = logger;
}

Expand All @@ -30,22 +27,23 @@ public async Task<bool> Execute(RoleDialogModel message)
var dbHook = _services.GetRequiredService<IText2SqlHook>();
var dbType = dbHook.GetDatabaseType(message);
var connectionString = _setting.Connections.FirstOrDefault(x => x.Name.Equals(args.DataSource, StringComparison.OrdinalIgnoreCase))?.ConnectionString;
var dbConnectionString = dbHook.GetConnectionString(message) ?? connectionString ?? throw new Exception("database connection is not found");
var dbConnectionString = dbHook.GetConnectionString(message, args.DataSource) ?? connectionString ?? throw new Exception("database connection is not found");

// Print all the SQL statements for debugging
_logger.LogInformation("Executing SQL Statements: {SqlStatements}", string.Join("\r\n", args.SqlStatements));

IEnumerable<dynamic> results = [];
try
{
results = dbType.ToLower() switch
results = await (dbType.ToLower() switch
{
"mysql" => RunQueryInMySql(dbConnectionString, args.SqlStatements),
"sqlserver" or "mssql" => RunQueryInSqlServer(dbConnectionString, args.SqlStatements),
"redshift" => RunQueryInRedshift(dbConnectionString, args.SqlStatements),
"sqlite" => RunQueryInSqlite(dbConnectionString, args.SqlStatements),
"mysql" => _sqlExecuteService.RunQueryInMySql(dbConnectionString, args.SqlStatements),
"sqlserver" or "mssql" => _sqlExecuteService.RunQueryInSqlServer(dbConnectionString, args.SqlStatements),
"redshift" => _sqlExecuteService.RunQueryInRedshift(dbConnectionString, args.SqlStatements),
"sqlite" => _sqlExecuteService.RunQueryInSqlite(dbConnectionString, args.SqlStatements),
"mongodb" => _sqlExecuteService.RunQueryInMongoDb(dbConnectionString, args.SqlStatements.FirstOrDefault(), []),
_ => throw new NotImplementedException($"Database type {dbType} is not supported.")
};
});

if (args.SqlStatements.Length == 1 && args.SqlStatements[0].StartsWith("DROP TABLE"))
{
Expand Down Expand Up @@ -215,34 +213,6 @@ private string EscapeMarkdownField(string field)
return field.Replace("|", "\\|");
}

private IEnumerable<dynamic> RunQueryInMySql(string connectionString, string[] sqlTexts)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
using var connection = new MySqlConnection(connectionString);
return connection.Query(string.Join(";\r\n", sqlTexts));
}

private IEnumerable<dynamic> RunQueryInSqlServer(string connectionString, string[] sqlTexts)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
using var connection = new SqlConnection(connectionString);
return connection.Query(string.Join("\r\n", sqlTexts));
}

private IEnumerable<dynamic> RunQueryInRedshift(string connectionString, string[] sqlTexts)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
using var connection = new NpgsqlConnection(connectionString);
return connection.Query(string.Join("\r\n", sqlTexts));
}

private IEnumerable<dynamic> RunQueryInSqlite(string connectionString, string[] sqlTexts)
{
var settings = _services.GetRequiredService<SqlDriverSetting>();
using var connection = new SqliteConnection(connectionString);
return connection.Query(string.Join("\r\n", sqlTexts));
}

private async Task<ExecuteQueryArgs> RefineSqlStatement(RoleDialogModel message, ExecuteQueryArgs args)
{
if (args.Tables == null || args.Tables.Length == 0)
Expand Down
168 changes: 0 additions & 168 deletions src/Plugins/BotSharp.Plugin.SqlDriver/Functions/SqlSelect.cs

This file was deleted.

55 changes: 55 additions & 0 deletions src/Plugins/BotSharp.Plugin.SqlDriver/Functions/SqlSelectFn.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
namespace BotSharp.Plugin.SqlDriver.Functions;

public class SqlSelectFn : IFunctionCallback
{
public string Name => "sql_select";
private readonly IServiceProvider _services;
private readonly SqlExecuteService _sqlExecuteService;

public SqlSelectFn(IServiceProvider services,
SqlExecuteService sqlExecuteService)
{
_services = services;
_sqlExecuteService = sqlExecuteService;
}

public async Task<bool> Execute(RoleDialogModel message)
{
var args = JsonSerializer.Deserialize<SqlStatement>(message.FunctionArgs);

if (args.GeneratedWithoutTableDefinition)
{
message.Content = $"Get the table definition first.";
return false;
}

// check if need to instantely
var dbHook = _services.GetRequiredService<IText2SqlHook>();
var dbType = dbHook.GetDatabaseType(message);
var dbConnectionString = dbHook.GetConnectionString(message) ??
throw new Exception("database connectdion is not found");

var result = await (dbType switch
{
"mysql" => _sqlExecuteService.RunQueryInMySql(dbConnectionString, args.Statement, args.Parameters),
"sqlserver" or "mssql" => _sqlExecuteService.RunQueryInSqlServer(dbConnectionString, args.Statement, args.Parameters),
"redshift" => _sqlExecuteService.RunQueryInRedshift(dbConnectionString, args.Statement, args.Parameters),
"sqlite" => _sqlExecuteService.RunQueryInSqlite(dbConnectionString, args.Statement, args.Parameters),
"mongodb" => _sqlExecuteService.RunQueryInMongoDb(dbConnectionString, args.Statement, args.Parameters),
_ => throw new NotImplementedException($"Database type {dbType} is not supported.")
});

if (result == null)
{
message.Content = "Record not found";
}
else
{
if (dbType == "mongodb") message.StopCompletion = true;
message.Content = JsonSerializer.Serialize(result);
args.Return.Value = message.Content;
}

return true;
}
}
Loading
Loading