公告

任何建议或需求可联系我!


Skip to content

客户端基础

学习创建和管理 MCP 客户端的基础知识，包括生命周期管理、初始化和错误处理。

创建客户端

MCP-Go 为每种支持的传输方式提供客户端构造函数。传输方式的选择决定了客户端如何与服务端通信。

客户端构造函数模式

go
// STDIO 客户端 - 用于命令行工具
client, err := client.NewStdioMCPClient("command", "arg1", "arg2")

// StreamableHTTP 客户端 - 用于 Web 服务
client := client.NewStreamableHttpClient("http://localhost:8080/mcp")

// SSE 客户端 - 用于实时 Web 应用程序
client := client.NewSSEMCPClient("http://localhost:8080/mcp/sse")

// 进程内客户端 - 用于测试和嵌入式场景
client := client.NewInProcessClient(server)

STDIO 客户端创建

go
package main

import (
    "context"
    "errors"
    "fmt"
    "log"
    "math"
    "net/http"
    "sync"
    "time"

    "github.com/mark3labs/mcp-go/client"
    "github.com/mark3labs/mcp-go/mcp"
)

// createStdioClient launches the MCP server as a child process
// ("go run /path/to/server/main.go") and returns a client connected to it.
// No extra environment variables are passed to the child.
func createStdioClient() (client.Client, error) {
	const command = "go"
	args := []string{"run", "/path/to/server/main.go"}

	stdioClient, err := client.NewStdioMCPClient(command, []string{}, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to create STDIO client: %w", err)
	}
	return stdioClient, nil
}

// createStdioClientWithEnv launches "go run /path/to/server/main.go" as a
// child process, passing it extra environment variables in KEY=VALUE form,
// and returns a client connected to it.
func createStdioClientWithEnv() (client.Client, error) {
	args := []string{"run", "/path/to/server/main.go"}
	extraEnv := []string{"LOG_LEVEL=debug", "DATABASE_URL=sqlite://test.db"}

	stdioClient, err := client.NewStdioMCPClient("go", extraEnv, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to create STDIO client: %w", err)
	}
	return stdioClient, nil
}

StreamableHTTP 客户端创建

go
// createStreamableHTTPClient builds a StreamableHTTP transport for the MCP
// endpoint at http://localhost:8080/mcp (replace with your server's URL) and
// wraps it in a client. The transport gets a 30-second HTTP timeout, two
// custom request headers and a custom *http.Client.
//
// Transport construction errors are fatal (log.Fatalf) because this helper
// has no error return.
//
// NOTE: this snippet uses the transport package
// (github.com/mark3labs/mcp-go/client/transport), which the import block in
// the STDIO example does not list; add it when copying this code.
func createStreamableHTTPClient() client.Client {
	const serverURL = "http://localhost:8080/mcp"

	httpTransport, err := transport.NewStreamableHTTP(serverURL,
		// Use a custom HTTP client. Options are applied in order, so this
		// comes before WithHTTPTimeout; otherwise swapping in the client
		// afterwards may discard the timeout.
		transport.WithHTTPBasicClient(&http.Client{}),
		// Set the timeout.
		transport.WithHTTPTimeout(30*time.Second),
		// Set custom request headers.
		transport.WithHTTPHeaders(map[string]string{
			"X-Custom-Header":  "custom-value",
			"Y-Another-Header": "another-value",
		}),
	)
	if err != nil {
		log.Fatalf("Failed to create StreamableHTTP transport: %v", err)
	}
	return client.NewClient(httpTransport)
}

SSE 客户端创建

go
// createSSEClient creates an SSE-based MCP client for the endpoint at
// http://localhost:8080/mcp/sse (replace with your server's URL), configured
// with two custom request headers.
//
// Creation errors are fatal (log.Fatalf) because this helper has no error
// return; previously the error was silently dropped.
func createSSEClient() client.Client {
	const sseURL = "http://localhost:8080/mcp/sse"

	c, err := client.NewSSEMCPClient(sseURL,
		// Set custom request headers.
		client.WithHeaders(map[string]string{
			"X-Custom-Header":  "custom-value",
			"Y-Another-Header": "another-value",
		}),
	)
	if err != nil {
		log.Fatalf("Failed to create SSE client: %v", err)
	}
	return c
}

客户端生命周期

理解客户端生命周期对于正确的资源管理和错误处理至关重要。

生命周期阶段

  1. 创建 - 实例化客户端
  2. 初始化 - 建立连接并交换能力
  3. 操作 - 使用工具、资源和提示词
  4. 清理 - 关闭连接并释放资源

完整生命周期示例

go
// demonstrateClientLifecycle walks through the four client lifecycle phases:
// create, initialize, operate, and clean up. Cleanup is registered with
// defer right after creation so it runs on every return path.
func demonstrateClientLifecycle() error {
	// 1. Create. "server-command" is a command to launch, so this needs the
	// STDIO constructor; the SSE constructor expects a URL instead.
	c, err := client.NewStdioMCPClient("server-command", nil)
	if err != nil {
		return fmt.Errorf("client creation failed: %w", err)
	}

	// Make sure cleanup happens; a Close failure is only logged.
	defer func() {
		if closeErr := c.Close(); closeErr != nil {
			log.Printf("Error closing client: %v", closeErr)
		}
	}()

	ctx := context.Background()

	// 2. Initialize
	if err := c.Initialize(ctx); err != nil {
		return fmt.Errorf("client initialization failed: %w", err)
	}

	// 3. Operate
	if err := performClientOperations(ctx, c); err != nil {
		return fmt.Errorf("client operations failed: %w", err)
	}

	// 4. Clean up (handled by the defer above)
	return nil
}

// performClientOperations lists the server's tools and calls each of them
// with the same sample arguments. A failing tool call is logged and skipped;
// only a ListTools failure is returned to the caller.
func performClientOperations(ctx context.Context, c client.Client) error {
	listing, err := c.ListTools(ctx, mcp.ListToolsRequest{})
	if err != nil {
		return err
	}
	log.Printf("Found %d tools", len(listing.Tools))

	for i := range listing.Tools {
		name := listing.Tools[i].Name

		req := mcp.CallToolRequest{
			Params: mcp.CallToolRequestParams{
				Name: name,
				Arguments: map[string]interface{}{
					"input":  "example input",
					"format": "json",
				},
			},
		}

		res, callErr := c.CallTool(ctx, req)
		if callErr != nil {
			log.Printf("Tool %s failed: %v", name, callErr)
			continue
		}
		log.Printf("Tool %s result: %+v", name, res)
	}

	return nil
}

初始化过程

初始化过程会建立 MCP 连接，并让客户端与服务端交换各自支持的能力（capabilities）：

go
// initializeClientWithDetails performs the initialize handshake with an
// explicit request (protocol version, capabilities, client identity) and
// logs the server's name, version and capabilities from the response.
//
// NOTE(review): confirm that Tools/Resources/Prompts exist on
// mcp.ClientCapabilities in the mcp-go version in use. These capabilities
// are typically advertised by the server, not the client.
func initializeClientWithDetails(ctx context.Context, c client.Client) error {
	params := mcp.InitializeRequestParams{
		ProtocolVersion: "2024-11-05",
		Capabilities: mcp.ClientCapabilities{
			Tools:     &mcp.ToolsCapability{},
			Resources: &mcp.ResourcesCapability{},
			Prompts:   &mcp.PromptsCapability{},
		},
		ClientInfo: mcp.ClientInfo{Name: "My Application", Version: "1.0.0"},
	}

	result, err := c.InitializeWithRequest(ctx, mcp.InitializeRequest{Params: params})
	if err != nil {
		return fmt.Errorf("initialization failed: %w", err)
	}

	server := result.ServerInfo
	log.Printf("Connected to server: %s v%s", server.Name, server.Version)
	log.Printf("Server capabilities: %+v", result.Capabilities)
	return nil
}

优雅关闭

go
// ManagedClient owns a client whose initialization runs in a background
// goroutine. It offers a bounded wait for readiness and a graceful Close.
type ManagedClient struct {
	client client.Client
	ctx    context.Context // cancelled by Close to abort initialization
	cancel context.CancelFunc
	done   chan struct{} // closed when background initialization finishes
	// initErr is written by the init goroutine before done is closed and
	// read only after <-done, so the channel close orders the accesses.
	initErr error
}

// NewManagedClient creates a client of the given type ("stdio",
// "streamablehttp" or "sse") and starts initializing it in the background.
// For "streamablehttp" and "sse", address is the server URL. The "stdio"
// variant launches "server-command" and ignores address.
func NewManagedClient(clientType, address string) (*ManagedClient, error) {
	var (
		c   client.Client
		err error
	)

	switch clientType {
	case "stdio":
		// A command-line server needs the STDIO constructor (SSE takes a URL).
		c, err = client.NewStdioMCPClient("server-command", nil)
	case "streamablehttp":
		c, err = client.NewStreamableHttpClient(address)
	case "sse":
		c, err = client.NewSSEMCPClient(address)
	default:
		return nil, fmt.Errorf("unknown client type: %s", clientType)
	}
	if err != nil {
		return nil, err
	}

	ctx, cancel := context.WithCancel(context.Background())

	mc := &ManagedClient{
		client: c,
		ctx:    ctx,
		cancel: cancel,
		done:   make(chan struct{}),
	}

	// Initialize in the background; record the outcome for WaitForReady.
	go func() {
		defer close(mc.done)
		if err := c.Initialize(ctx); err != nil {
			log.Printf("Client initialization failed: %v", err)
			mc.initErr = err
		}
	}()

	return mc, nil
}

// WaitForReady blocks until background initialization finishes, timeout
// elapses, or the client is closed. If initialization finished with an
// error, that error is returned (wrapped) instead of nil, so callers do not
// mistake a failed initialization for readiness.
func (mc *ManagedClient) WaitForReady(timeout time.Duration) error {
	select {
	case <-mc.done:
		if mc.initErr != nil {
			return fmt.Errorf("client initialization failed: %w", mc.initErr)
		}
		return nil
	case <-time.After(timeout):
		return fmt.Errorf("client initialization timeout")
	case <-mc.ctx.Done():
		return mc.ctx.Err()
	}
}

// Close cancels any in-flight initialization, waits up to 5 seconds for the
// init goroutine to finish, and then closes the underlying client.
func (mc *ManagedClient) Close() error {
	mc.cancel()

	// Wait for initialization to finish, or time out.
	select {
	case <-mc.done:
	case <-time.After(5 * time.Second):
		log.Println("Timeout waiting for client shutdown")
	}

	return mc.client.Close()
}

错误处理

正确的错误处理对于健壮的客户端应用程序至关重要。

错误类型

go
// Sentinel errors, grouped by category. Compare with errors.Is so that
// wrapped errors still match.
var (
	// Connection errors.
	ErrConnectionFailed = errors.New("connection failed")
	ErrConnectionLost   = errors.New("connection lost")
	ErrTimeout          = errors.New("operation timeout")

	// Protocol errors.
	ErrInvalidResponse    = errors.New("invalid response")
	ErrProtocolViolation  = errors.New("protocol violation")
	ErrUnsupportedVersion = errors.New("unsupported protocol version")

	// Operation errors.
	ErrToolNotFound     = errors.New("tool not found")
	ErrResourceNotFound = errors.New("resource not found")
	ErrInvalidArguments = errors.New("invalid arguments")
	ErrPermissionDenied = errors.New("permission denied")
)

全面错误处理

go
// handleClientErrors calls "example_tool" and demonstrates how to react to
// each class of client error:
//   - connection lost / timeout: recoverable; reconnect or extend the
//     per-call timeout, then retry (at most maxAttempts calls in total);
//   - protocol errors: logged, not retried;
//   - operation errors: logged, with extra diagnostics where useful;
//   - anything else: logged as unexpected.
//
// A successful result is logged. The function has no return value, so
// retries are done with a bounded loop rather than by recursing.
func handleClientErrors(ctx context.Context, c client.Client) {
	// Bound the retries: repeated connection losses or timeouts must not
	// loop forever.
	const maxAttempts = 3

	req := mcp.CallToolRequest{
		Params: mcp.CallToolRequestParams{
			Name: "example_tool",
			Arguments: map[string]interface{}{
				"param": "value",
			},
		},
	}

	// Extra per-call deadline; zero means "use ctx as-is". It is raised
	// after a timeout so the retry gets more time.
	var callTimeout time.Duration

	for attempt := 1; attempt <= maxAttempts; attempt++ {
		result, err := callToolWithTimeout(ctx, c, req, callTimeout)
		if err == nil {
			// Handle the successful result.
			log.Printf("Tool result: %+v", result)
			return
		}

		switch {
		// Connection errors - possibly recoverable.
		case errors.Is(err, client.ErrConnectionLost):
			log.Println("Connection lost, attempting reconnect...")
			if reconnectErr := reconnectClient(c); reconnectErr != nil {
				log.Printf("Reconnection failed: %v", reconnectErr)
				return
			}
			// Reconnected: loop around and retry the call.

		case errors.Is(err, client.ErrTimeout):
			log.Println("Operation timed out, retrying with longer timeout...")
			// The longer deadline is derived from ctx, not
			// context.Background(), so the caller's cancellation still
			// applies. A shorter caller deadline still takes precedence.
			callTimeout = 60 * time.Second

		// Protocol errors - usually not recoverable.
		case errors.Is(err, client.ErrProtocolViolation):
			log.Printf("Protocol violation: %v", err)
			return

		case errors.Is(err, client.ErrUnsupportedVersion):
			log.Printf("Unsupported protocol version: %v", err)
			return

		// Operation errors - inspect and fix the request.
		case errors.Is(err, client.ErrToolNotFound):
			log.Printf("Tool not found: %v", err)
			// List the available tools and suggest alternatives.
			suggestAlternativeTools(ctx, c)
			return

		case errors.Is(err, client.ErrInvalidArguments):
			log.Printf("Invalid arguments: %v", err)
			// Fetch the tool schema and show the required parameters.
			showToolSchema(ctx, c, "example_tool")
			return

		case errors.Is(err, client.ErrPermissionDenied):
			log.Printf("Permission denied: %v", err)
			// Prompt for authentication.
			return

		// Unknown errors.
		default:
			log.Printf("Unexpected error: %v", err)
			return
		}
	}

	log.Printf("Giving up on example_tool after %d attempts", maxAttempts)
}

// callToolWithTimeout calls c.CallTool with req. When timeout > 0, the call
// is additionally bounded by a deadline derived from ctx. That derived
// context is released before returning.
func callToolWithTimeout(ctx context.Context, c client.Client, req mcp.CallToolRequest, timeout time.Duration) (*mcp.CallToolResult, error) {
	if timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}
	return c.CallTool(ctx, req)
}

// reconnectClient closes c (logging, not returning, any Close error), waits
// one second, and then runs Initialize on the same client under a 30-second
// deadline. It returns the Initialize error.
//
// NOTE(review): this re-initializes the instance that was just closed.
// Confirm the client implementation supports Initialize after Close; if not,
// build a fresh client instead, as ResilientClient does with its factory.
func reconnectClient(c client.Client) error {
	const (
		settleDelay = 1 * time.Second
		initTimeout = 30 * time.Second
	)

	if closeErr := c.Close(); closeErr != nil {
		log.Printf("Error closing client: %v", closeErr)
	}

	time.Sleep(settleDelay)

	initCtx, cancel := context.WithTimeout(context.Background(), initTimeout)
	defer cancel()
	return c.Initialize(initCtx)
}

// suggestAlternativeTools logs the name and description of every tool the
// server offers. If the tools cannot be listed, the failure is logged and
// nothing else is printed.
func suggestAlternativeTools(ctx context.Context, c client.Client) {
	listing, err := c.ListTools(ctx, mcp.ListToolsRequest{})
	if err != nil {
		log.Printf("Failed to list tools: %v", err)
		return
	}

	log.Println("Available tools:")
	for i := range listing.Tools {
		log.Printf("- %s: %s", listing.Tools[i].Name, listing.Tools[i].Description)
	}
}

// showToolSchema logs the description and input schema of the tool named
// toolName. If the tool is not in the server's listing, or the listing
// fails, that is logged instead.
func showToolSchema(ctx context.Context, c client.Client, toolName string) {
	listing, err := c.ListTools(ctx, mcp.ListToolsRequest{})
	if err != nil {
		log.Printf("Failed to list tools: %v", err)
		return
	}

	for i := range listing.Tools {
		candidate := &listing.Tools[i]
		if candidate.Name != toolName {
			continue
		}
		log.Printf("Tool schema for %s:", toolName)
		log.Printf("Description: %s", candidate.Description)
		log.Printf("Input schema: %+v", candidate.InputSchema)
		return
	}

	log.Printf("Tool %s not found", toolName)
}

带指数退避的重试逻辑

go
// RetryConfig controls WithRetry: how many retries to make, how the delay
// between attempts grows, and which errors are worth retrying.
type RetryConfig struct {
	MaxRetries int // retries after the first attempt

	// Delay before the first retry, and the upper bound on any delay.
	InitialDelay, MaxDelay time.Duration

	BackoffFactor float64 // multiplier applied to the delay per attempt

	RetryableErrors []error // failures matching one of these (errors.Is) are retried
}

// DefaultRetryConfig returns a config allowing up to 3 retries. The delay
// starts at 1s, doubles on each attempt and is capped at 30s. Only lost or
// failed connections and timeouts are retried.
func DefaultRetryConfig() RetryConfig {
	var cfg RetryConfig
	cfg.MaxRetries = 3
	cfg.InitialDelay = time.Second
	cfg.MaxDelay = 30 * time.Second
	cfg.BackoffFactor = 2
	cfg.RetryableErrors = []error{
		client.ErrConnectionLost,
		client.ErrTimeout,
		client.ErrConnectionFailed,
	}
	return cfg
}

// IsRetryable reports whether err matches, via errors.Is, any entry of
// rc.RetryableErrors. An empty list makes every error non-retryable.
func (rc RetryConfig) IsRetryable(err error) bool {
	matched := false
	for i := 0; i < len(rc.RetryableErrors) && !matched; i++ {
		matched = errors.Is(err, rc.RetryableErrors[i])
	}
	return matched
}

// WithRetry runs operation until it succeeds, fails with an error that
// config does not consider retryable, or config.MaxRetries retries have been
// used up. That allows at most MaxRetries+1 attempts; a negative MaxRetries
// is treated as 0, so there is always at least one attempt.
//
// Between attempts it waits InitialDelay * BackoffFactor^n, capped at
// MaxDelay. If ctx is cancelled during a wait, it returns ctx.Err()
// immediately.
//
// On failure the returned error wraps the last operation error and reports
// how many attempts were actually made.
func WithRetry[T any](ctx context.Context, config RetryConfig, operation func() (T, error)) (T, error) {
	var zero T

	maxRetries := config.MaxRetries
	if maxRetries < 0 {
		maxRetries = 0 // always make at least one attempt
	}

	var lastErr error
	attempts := 0

	for attempt := 0; attempt <= maxRetries; attempt++ {
		attempts++
		result, err := operation()
		if err == nil {
			return result, nil
		}
		lastErr = err

		// Do not retry non-retryable errors, and do not wait after the
		// final attempt.
		if !config.IsRetryable(err) || attempt == maxRetries {
			break
		}

		delay := backoffDelay(config, attempt)
		log.Printf("Attempt %d failed, retrying in %v: %v", attempt+1, delay, err)

		// Wait, honoring context cancellation.
		select {
		case <-time.After(delay):
		case <-ctx.Done():
			return zero, ctx.Err()
		}
	}

	return zero, fmt.Errorf("failed after %d attempts: %w", attempts, lastErr)
}

// backoffDelay returns InitialDelay * BackoffFactor^attempt, capped at
// MaxDelay. The cap is applied in float64 before converting. Converting an
// out-of-range float64 to time.Duration is implementation-dependent and can
// produce a negative delay, so the comparison must happen first.
func backoffDelay(config RetryConfig, attempt int) time.Duration {
	d := float64(config.InitialDelay) * math.Pow(config.BackoffFactor, float64(attempt))
	if math.IsNaN(d) || d > float64(config.MaxDelay) {
		return config.MaxDelay
	}
	return time.Duration(d)
}

// Usage example:
// callToolWithRetry calls the tool described by req, retrying transient
// failures according to DefaultRetryConfig.
func callToolWithRetry(ctx context.Context, c client.Client, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	call := func() (*mcp.CallToolResult, error) {
		return c.CallTool(ctx, req)
	}
	return WithRetry(ctx, DefaultRetryConfig(), call)
}

上下文和超时管理

go
// demonstrateContextUsage calls "long_running_tool" (asked to run for 60
// seconds) under a 30-second deadline. It logs whether the call completed,
// hit the deadline, or failed for another reason.
func demonstrateContextUsage(c client.Client) {
	const deadline = 30 * time.Second

	ctx, cancel := context.WithTimeout(context.Background(), deadline)
	defer cancel()

	req := mcp.CallToolRequest{
		Params: mcp.CallToolRequestParams{
			Name:      "long_running_tool",
			Arguments: map[string]interface{}{"duration": 60}, // seconds
		},
	}

	result, err := c.CallTool(ctx, req)
	switch {
	case err == nil:
		log.Printf("Tool completed: %+v", result)
	case errors.Is(err, context.DeadlineExceeded):
		log.Println("Tool call timed out")
	default:
		log.Printf("Tool call failed: %v", err)
	}
}

// demonstrateCancellation starts a long-running tool call in a goroutine and
// cancels its context after 5 seconds. It then waits up to 1 second for the
// goroutine to report the outcome. If the call finishes before 5 seconds,
// the function returns as soon as it does.
func demonstrateCancellation(c client.Client) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// done is closed when the goroutine finishes, so we can wait for it
	// instead of guessing with a fixed sleep.
	done := make(chan struct{})

	// Run the operation in a goroutine.
	go func() {
		defer close(done)

		result, err := c.CallTool(ctx, mcp.CallToolRequest{
			Params: mcp.CallToolRequestParams{
				Name: "long_running_tool",
			},
		})
		if err != nil {
			if errors.Is(err, context.Canceled) {
				log.Println("Tool call was cancelled")
			} else {
				log.Printf("Tool call failed: %v", err)
			}
			return
		}

		log.Printf("Tool completed: %+v", result)
	}()

	// Cancel after 5 seconds, unless the call has already finished.
	select {
	case <-done:
		return
	case <-time.After(5 * time.Second):
	}
	cancel()

	// Give the goroutine a moment to observe the cancellation.
	select {
	case <-done:
	case <-time.After(1 * time.Second):
		log.Println("Tool call did not stop within 1s of cancellation")
	}
}

连接监控

健康检查

go
// ClientHealthMonitor periodically probes a client with ListTools and
// records whether the latest probe succeeded. IsHealthy may be called from
// other goroutines while Start runs; access to healthy is guarded by mutex.
type ClientHealthMonitor struct {
	client   client.Client
	interval time.Duration // time between probes (passed to time.NewTicker)
	timeout  time.Duration // deadline for each individual probe
	healthy  bool          // result of the latest probe
	mutex    sync.RWMutex  // guards healthy
}

// NewClientHealthMonitor returns a monitor for c. It reports unhealthy until
// a probe succeeds; call Start to begin probing.
func NewClientHealthMonitor(c client.Client, interval, timeout time.Duration) *ClientHealthMonitor {
	return &ClientHealthMonitor{
		client:   c,
		interval: interval,
		timeout:  timeout,
		healthy:  false,
	}
}

// Start probes the client once immediately and then every interval until
// ctx is cancelled. It blocks, so run it in its own goroutine.
func (chm *ClientHealthMonitor) Start(ctx context.Context) {
	if ctx.Err() != nil {
		return
	}

	// Probe right away. Otherwise IsHealthy would report false for a full
	// interval after startup even when the client is fine.
	chm.checkHealth(ctx)

	ticker := time.NewTicker(chm.interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			chm.checkHealth(ctx)
		}
	}
}

// checkHealth runs one ListTools probe, bounded by chm.timeout, and records
// whether it succeeded. Failures are logged.
func (chm *ClientHealthMonitor) checkHealth(ctx context.Context) {
	ctx, cancel := context.WithTimeout(ctx, chm.timeout)
	defer cancel()

	// Listing tools serves as the health probe.
	_, err := chm.client.ListTools(ctx, mcp.ListToolsRequest{})

	chm.mutex.Lock()
	chm.healthy = (err == nil)
	chm.mutex.Unlock()

	if err != nil {
		log.Printf("Health check failed: %v", err)
	}
}

// IsHealthy reports the result of the most recent probe.
func (chm *ClientHealthMonitor) IsHealthy() bool {
	chm.mutex.RLock()
	defer chm.mutex.RUnlock()
	return chm.healthy
}

连接恢复

go
// ResilientClient wraps a client built by factory. When a call fails with a
// connection error, it replaces that client with a fresh one and retries the
// call once.
type ResilientClient struct {
	factory func() (client.Client, error)
	client  client.Client // current connection; nil when a new one is needed
	mutex   sync.RWMutex  // guards client and gen
	// gen is incremented every time client is replaced. A caller that saw a
	// failure on generation g only invalidates the client if it is still g,
	// so concurrent failures cannot close a healthy replacement.
	gen uint64
}

// NewResilientClient returns a ResilientClient that creates its connection
// lazily, with factory, on first use.
func NewResilientClient(factory func() (client.Client, error)) *ResilientClient {
	return &ResilientClient{
		factory: factory,
	}
}

// acquire returns an initialized client and its generation. If none is
// currently usable, it creates one with factory and initializes it. Only one
// goroutine connects at a time; the others wait on the write lock and then
// reuse the new client.
func (rc *ResilientClient) acquire(ctx context.Context) (client.Client, uint64, error) {
	rc.mutex.RLock()
	cur, gen := rc.client, rc.gen
	rc.mutex.RUnlock()
	if cur != nil {
		return cur, gen, nil
	}

	rc.mutex.Lock()
	defer rc.mutex.Unlock()

	// Check again under the write lock: another goroutine may have connected
	// while we were waiting.
	if rc.client != nil {
		return rc.client, rc.gen, nil
	}

	// Create a new client.
	newClient, err := rc.factory()
	if err != nil {
		return nil, 0, fmt.Errorf("failed to create client: %w", err)
	}

	// Initialize the new client.
	if err := newClient.Initialize(ctx); err != nil {
		if closeErr := newClient.Close(); closeErr != nil {
			log.Printf("Error closing client: %v", closeErr)
		}
		return nil, 0, fmt.Errorf("failed to initialize client: %w", err)
	}

	rc.client = newClient
	rc.gen++
	return rc.client, rc.gen, nil
}

// ensureConnected makes sure an initialized client is available, creating
// one if necessary.
func (rc *ResilientClient) ensureConnected(ctx context.Context) error {
	_, _, err := rc.acquire(ctx)
	return err
}

// invalidate closes and drops the current client if it is still generation
// gen, so the next acquire builds a fresh one. If another goroutine has
// already replaced or dropped it, this is a no-op. Clearing the field also
// means a closed client is never handed out again, even if the reconnect
// that follows fails.
func (rc *ResilientClient) invalidate(gen uint64) {
	rc.mutex.Lock()
	defer rc.mutex.Unlock()

	if rc.gen != gen || rc.client == nil {
		return
	}
	if err := rc.client.Close(); err != nil {
		log.Printf("Error closing client: %v", err)
	}
	rc.client = nil
}

// CallTool calls the tool described by req. On a connection error it
// replaces the client and retries once; the retry's result is returned
// as-is. Other errors are returned unchanged.
func (rc *ResilientClient) CallTool(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	cur, gen, err := rc.acquire(ctx)
	if err != nil {
		return nil, err
	}

	result, err := cur.CallTool(ctx, req)
	if err == nil || !isConnectionError(err) {
		return result, err
	}

	// Connection error: drop this client (unless already replaced) and
	// retry once on a fresh one.
	rc.invalidate(gen)

	cur, _, err = rc.acquire(ctx)
	if err != nil {
		return nil, fmt.Errorf("recovery failed: %w", err)
	}
	return cur.CallTool(ctx, req)
}

// isConnectionError reports whether err indicates a lost or failed connection.
func isConnectionError(err error) bool {
	return errors.Is(err, client.ErrConnectionLost) ||
		errors.Is(err, client.ErrConnectionFailed)
}

下一步