Zig 0.16 学习笔记 - 优化上次编写的 HTTP/1.1 服务器

在上一篇笔记中,我基于 Zig 0.16 编写了一个简易单文件 HTTP/1.1 服务器,实现了最基础的请求解析、数据响应功能。但初版代码存在明显短板:所有逻辑堆砌在单个文件中,耦合度极高,请求解析、连接处理、数据读写逻辑混杂,不利于后期维护、迭代和功能扩展。

本次我对原有代码进行结构性重构优化,核心目标:解耦代码层级、拆分功能模块、封装通用逻辑、优化缓冲区读写机制。

上一篇笔记:Zig 0.16 学习笔记 - 编写一个简易的 HTTP/1.1 服务器

代码实现

  • main.zig:负责网络监听、连接管理与任务分发
  • http.zig:负责 HTTP 协议解析、数据结构定义与连接封装

main.zig

zig
const std = @import("std");
const http = @import("http.zig");

/// Program entry point: listens on 127.0.0.1:3000 and hands every accepted
/// connection to `handleConnect` on its own thread.
///
/// Returns an error if the address cannot be parsed, the listener cannot be
/// created, or `accept` fails. A failure to spawn a worker thread is logged
/// and only drops that one connection.
pub fn main(init: std.process.Init) !void {
    const io = init.io;
    const addr = try std.Io.net.IpAddress.parse("127.0.0.1", 3000);

    // Create the listening socket; it is released when `main` returns.
    var server = try addr.listen(io, .{});
    defer server.deinit(io);

    std.log.info("Listen: 127.0.0.1:3000", .{});

    // Accept clients until `accept` itself reports an error.
    while (server.accept(io)) |stream| {
        std.log.info("Accept: {}", .{stream.socket.address});

        const worker = std.Thread.spawn(.{}, handleConnect, .{ io, init.gpa, stream }) catch |err| {
            // No worker took ownership of the stream, so close it here.
            stream.close(io);
            std.log.err("Thread.spawn error: {}", .{err});
            continue;
        };
        // The worker is never joined, so detach it; otherwise the thread's
        // resources are never released after it finishes.
        worker.detach();
    } else |err| {
        return err;
    }
}

/// Serves one client connection, answering requests back-to-back on the same
/// stream until the peer closes it or an error occurs.
///
/// The stream is closed on every exit path: directly when `Connection.init`
/// fails, otherwise through `Connection.deinit`.
fn handleConnect(io: std.Io, alloc: std.mem.Allocator, stream: std.Io.net.Stream) !void {
    // Wrap the raw stream so that reads go through the connection's buffers.
    var conn = http.Connection.init(io, alloc, stream) catch |err| {
        stream.close(io);
        return err;
    };
    defer conn.deinit(alloc);

    var served: i64 = 0;
    while (true) {
        // Parse the next request head from the connection.
        var request = http.Request.fromConnection(alloc, &conn) catch |err| {
            // The peer closing the connection ends the loop without an error.
            if (err == http.HttpParseError.ConnectionClose) return;
            return err;
        };
        // Runs at the end of every iteration, releasing this request's memory.
        defer request.deinit();

        served += 1;
        std.log.debug("Reuse count: (thread id: {}) {}", .{ std.Thread.getCurrentId(), served });

        try printRequest(&request);

        // Echo the request body back to the client.
        const payload = try request.body();
        try response(io, &conn.stream, payload);
    }
}

/// Logs the request line, every header and the body of `request` at info level.
///
/// The body is obtained through `request.body()`, so any error it returns is
/// propagated; lines logged before that point have already been emitted.
fn printRequest(request: *http.Request) !void {
    std.log.info("request: vvvvvvvvvvvvvvvvvvvvvvvvvv", .{});

    const method_text = request.method.toString();
    const version_text = request.version.toString();
    std.log.info("         {s} {s} {s}", .{ method_text, request.path, version_text });

    var it = request.headers.iterator();
    while (it.next()) |entry| {
        std.log.info("         {s}: {s}", .{ entry.key_ptr.*, entry.value_ptr.* });
    }

    // Empty log line separating the headers from the body.
    std.log.info("", .{});

    const payload = try request.body();
    std.log.info("         {s}", .{payload});
    std.log.info("         ^^^^^^^^^^^^^^^^^^^^^^^^^^", .{});
}

/// Sends an `HTTP/1.1 200 OK` response whose body is `message`, with a
/// matching `content-length` header.
///
/// A 4096-byte stack buffer is handed to the stream writer, and the writer is
/// flushed before returning.
fn response(io: std.Io, stream: *const std.Io.net.Stream, message: []const u8) !void {
    var scratch: [4096]u8 = undefined;
    var stream_writer = stream.writer(io, &scratch);
    // Use the interface through a pointer; it must stay inside `stream_writer`.
    const out = &stream_writer.interface;

    try out.print("HTTP/1.1 200 OK\r\ncontent-length: {}\r\n\r\n{s}", .{ message.len, message });
    try out.flush();
}

http.zig

zig
const std = @import("std");

/// HTTP protocol versions recognised in a request line.
pub const HttpVersion = enum {
    http1_1,
    http2,
    http3,

    /// Maps a version token such as `"HTTP/1.1"` to its enum value.
    ///
    /// The comparison is an exact byte match against `toString` of each
    /// variant. Returns `HttpParseError.InvalidVersion` for any other token.
    pub fn fromString(version: []const u8) HttpParseError!HttpVersion {
        for (std.enums.values(HttpVersion)) |candidate| {
            if (std.mem.eql(u8, candidate.toString(), version)) return candidate;
        }
        return HttpParseError.InvalidVersion;
    }

    /// Returns the textual token for this version.
    pub fn toString(self: HttpVersion) []const u8 {
        return switch (self) {
            .http1_1 => "HTTP/1.1",
            .http2 => "HTTP/2",
            .http3 => "HTTP/3",
        };
    }
};

/// HTTP request method: the well-known methods, plus `custom` for any other
/// non-empty token.
pub const HttpMethod = union(enum) {
    get,
    post,
    put,
    patch,
    delete,
    connect,
    head,
    options,
    custom: []const u8,

    /// The payload-free variants, in the order `fromString` tries them.
    const BUILTIN = [_]HttpMethod{ .get, .post, .put, .patch, .delete, .connect, .head, .options };

    /// Parses a method token.
    ///
    /// Built-in methods are matched ignoring ASCII case. Any other non-empty
    /// token is copied with `alloc` and returned as `.custom`; the copy must
    /// be released through the same allocator.
    ///
    /// Returns `HttpParseError.InvalidMethod` for an empty string and
    /// `error.OutOfMemory` if the copy cannot be allocated.
    // NOTE(review): RFC 9110 defines method tokens as case-sensitive, so this
    // matching is more lenient than the specification — confirm it is intended.
    pub fn fromString(alloc: std.mem.Allocator, str: []const u8) HttpParseError!HttpMethod {
        if (str.len == 0) return HttpParseError.InvalidMethod;

        for (BUILTIN) |method| {
            if (std.ascii.eqlIgnoreCase(method.toString(), str)) return method;
        }

        const owned = try alloc.dupe(u8, str);
        return .{ .custom = owned };
    }

    /// Returns the upper-case token for built-in methods, or the stored text
    /// for `.custom`.
    pub fn toString(self: HttpMethod) []const u8 {
        return switch (self) {
            .custom => |text| text,
            .get => "GET",
            .post => "POST",
            .put => "PUT",
            .patch => "PATCH",
            .delete => "DELETE",
            .connect => "CONNECT",
            .head => "HEAD",
            .options => "OPTIONS",
        };
    }
};

/// A parsed HTTP request head with lazy access to the request body.
///
/// Every string reachable from a `Request` (custom method text, path, header
/// names and values, body) is allocated from `arena` and is freed by `deinit`.
pub const Request = struct {
    version: HttpVersion,
    method: HttpMethod,
    path: []u8,
    /// Header fields. Names are stored lower-cased; values have surrounding
    /// spaces and tabs removed. A repeated name keeps only the last value.
    headers: std.StringHashMap([]const u8),
    /// Cached body; `null` until `body` or `discardBody` has run.
    payload: ?[]u8,

    /// Connection the head was read from; the body is pulled from it on demand.
    connection: *Connection,
    /// Heap-allocated arena backing all per-request allocations.
    arena: *std.heap.ArenaAllocator,

    /// Frees everything allocated for this request, including the arena object
    /// itself. The request must not be used afterwards.
    pub fn deinit(self: *const Request) void {
        // Grab the parent allocator before tearing the arena down.
        const parent = self.arena.child_allocator;
        self.arena.deinit();
        parent.destroy(self.arena);
    }

    /// Reads one request head from `connection` and parses it.
    ///
    /// Bytes already sitting in the connection's head buffer are examined
    /// before anything is read from the socket, so a request that arrived
    /// together with the previous one is served without waiting for more data.
    ///
    /// On return (success or parse error) the head and its terminating blank
    /// line have been removed from the connection's buffer. The caller owns
    /// the returned request and must call `deinit`.
    ///
    /// Errors: `ConnectionClose` when the stream ends before a full head is
    /// available, `Io` on a read failure, `HeadTooLarge` when the head does not
    /// fit in the buffer, `OutOfMemory`, and the parse errors produced by the
    /// request line and header parsing.
    pub fn fromConnection(alloc: std.mem.Allocator, connection: *Connection) HttpParseError!Request {
        const header_end_index = try waitForHeadEnd(connection);

        // Drop the head plus its "\r\n\r\n" terminator on every exit path so
        // no stale bytes remain in the buffer.
        defer connection.consume(header_end_index + 4);

        const arena = try alloc.create(std.heap.ArenaAllocator);
        errdefer alloc.destroy(arena);

        arena.* = std.heap.ArenaAllocator.init(alloc);
        errdefer arena.deinit();

        var request: Request = .{
            // Filled in by `parseRequestLine` below.
            .version = undefined,
            .method = undefined,
            .path = undefined,
            .headers = .init(arena.allocator()),
            .payload = null,
            .connection = connection,
            .arena = arena,
        };

        var lines = std.mem.splitSequence(
            u8,
            connection.bufferData()[0..header_end_index],
            "\r\n",
        );

        // First line: method, path and version.
        try request.parseRequestLine(lines.next() orelse {
            return HttpParseError.MissingRequestLine;
        });

        // Remaining lines: `name: value` header fields.
        while (lines.next()) |header| {
            if (header.len == 0) {
                break;
            }

            const index = std.mem.indexOfScalar(u8, header, ':') orelse {
                return HttpParseError.InvalidHeader;
            };

            const header_name = try std.ascii.allocLowerString(arena.allocator(), header[0..index]);
            // Optional whitespace around a field value is spaces and tabs.
            const header_value = try arena.allocator().dupe(
                u8,
                std.mem.trim(u8, header[index + 1 ..], " \t"),
            );

            try request.headers.put(header_name, header_value);
        }

        return request;
    }

    /// Returns the offset of the first "\r\n\r\n" in the connection's head
    /// buffer, reading from the socket only while the marker is still missing.
    fn waitForHeadEnd(connection: *Connection) HttpParseError!usize {
        while (true) {
            if (std.mem.indexOf(u8, connection.bufferData(), "\r\n\r\n")) |index| {
                return index;
            }

            if (connection.bufferFree().len == 0) {
                // The buffer is full and still holds no complete head.
                connection.data_len = 0;
                return HttpParseError.HeadTooLarge;
            }

            _ = connection.readToBuffer() catch |e| {
                return switch (e) {
                    std.Io.Reader.Error.EndOfStream => HttpParseError.ConnectionClose,
                    std.Io.Reader.Error.ReadFailed => HttpParseError.Io,
                };
            };
        }
    }

    /// Returns the request body, reading it from the connection on first use
    /// and caching it in `payload`.
    ///
    /// The length is taken from the `content-length` header (0 when absent).
    /// The returned slice lives in the request's arena.
    // NOTE(review): the length comes straight from the client with no upper
    // bound, so a large value leads to an equally large allocation.
    pub fn body(self: *Request) ![]u8 {
        if (self.payload) |payload| return payload;

        const size = try self.parseContentLength();

        self.payload = if (size != 0) blk: {
            const payload = try self.arena.allocator().alloc(u8, size);
            try self.connection.fill(payload);
            break :blk payload;
        } else "";

        return self.payload.?;
    }

    /// Skips the request body without keeping it.
    ///
    /// If the body was already read it is simply forgotten; otherwise
    /// `content-length` bytes are discarded from the connection. Afterwards
    /// `body` returns an empty slice.
    pub fn discardBody(self: *Request) !void {
        if (self.payload) |_| {
            self.payload = "";
            return;
        }

        const size = try self.parseContentLength();

        if (size != 0) {
            try self.connection.discard(size);
        }

        self.payload = "";
    }

    /// Returns the value of the `content-length` header as an integer, or 0
    /// when the header is absent. Fails if the value is not a decimal `usize`.
    fn parseContentLength(self: *const Request) !usize {
        const content_length = self.headers.get("content-length") orelse "0";
        return try std.fmt.parseInt(usize, content_length, 10);
    }

    /// Parses `data` as a request line (`METHOD SP PATH SP VERSION`) and stores
    /// the three parts in `self`. Anything after the version token is ignored.
    fn parseRequestLine(self: *Request, data: []const u8) HttpParseError!void {
        var parts = std.mem.splitScalar(u8, data, ' ');

        const method_text = parts.next() orelse return HttpParseError.MissingMethod;
        self.method = try HttpMethod.fromString(self.arena.allocator(), method_text);

        const path_text = parts.next() orelse return HttpParseError.MissingPath;
        self.path = try self.arena.allocator().dupe(u8, path_text);

        const version_text = parts.next() orelse return HttpParseError.MissingVersion;
        self.version = try HttpVersion.fromString(version_text);
    }
};

/// One client connection: the network stream, its reader, and a staging
/// buffer in which request heads are assembled.
///
/// `buf` is a single allocation split in two. The first `READER_BUFFER_SIZE`
/// bytes are handed to the stream reader; the remainder is the staging buffer,
/// whose first `data_len` bytes hold data that has been read but not consumed.
pub const Connection = struct {
    stream: std.Io.net.Stream,
    io: std.Io,
    buf: []u8,
    data_len: usize,
    reader: std.Io.net.Stream.Reader,

    const READER_BUFFER_SIZE = 4096;

    /// Allocates the buffers with `alloc` and wraps `stream`.
    ///
    /// On success the connection owns `stream` (closed by `deinit`); on
    /// failure the stream is left untouched.
    pub fn init(io: std.Io, alloc: std.mem.Allocator, stream: std.Io.net.Stream) !Connection {
        const storage = try alloc.alloc(u8, READER_BUFFER_SIZE + 4096);
        errdefer alloc.free(storage);

        const reader_storage = storage[0..READER_BUFFER_SIZE];
        return .{
            .stream = stream,
            .io = io,
            .buf = storage,
            .data_len = 0,
            .reader = stream.reader(io, reader_storage),
        };
    }

    /// Closes the stream and frees the buffers. `alloc` must be the allocator
    /// that was passed to `init`.
    pub fn deinit(self: *Connection, alloc: std.mem.Allocator) void {
        self.stream.close(self.io);
        alloc.free(self.buf);
    }

    /// The whole staging buffer (everything after the reader's part of `buf`).
    pub fn buffer(self: *Connection) []u8 {
        return self.buf[READER_BUFFER_SIZE..];
    }

    /// The part of the staging buffer that currently holds data.
    pub fn bufferData(self: *Connection) []u8 {
        return self.buffer()[0..self.data_len];
    }

    /// The unused tail of the staging buffer.
    pub fn bufferFree(self: *Connection) []u8 {
        return self.buffer()[self.data_len..];
    }

    /// Reads from the stream into the free part of the staging buffer and
    /// returns how many bytes were added.
    pub fn readToBuffer(self: *Connection) std.Io.Reader.Error!usize {
        var targets = [_][]u8{self.bufferFree()};
        const got = try self.reader.interface.readVec(&targets);
        self.data_len += got;
        return got;
    }

    /// Fills `buf` completely: staged bytes are copied first, and anything
    /// still missing is read from the stream. The staged bytes that were
    /// copied are consumed.
    pub fn fill(self: *Connection, buf: []u8) std.Io.Reader.Error!void {
        const buffered = @min(self.data_len, buf.len);
        @memcpy(buf[0..buffered], self.bufferData()[0..buffered]);
        if (buffered < buf.len) {
            try self.reader.interface.readSliceAll(buf[buffered..]);
        }
        self.consume(buffered);
    }

    /// Skips `n` bytes: whatever the staging buffer does not cover is
    /// discarded from the stream, then the staged bytes are consumed.
    pub fn discard(self: *Connection, n: usize) std.Io.Reader.Error!void {
        const buffered = @min(self.data_len, n);
        const pending = n - buffered;
        if (pending != 0) {
            try self.reader.interface.discardAll(pending);
        }
        self.consume(buffered);
    }

    /// Removes the first `n` staged bytes, shifting the rest to the front.
    /// `n` must not exceed `data_len`.
    pub fn consume(self: *Connection, n: usize) void {
        std.debug.assert(n <= self.data_len);
        const remaining = self.data_len - n;
        const staged = self.bufferData();
        @memmove(staged[0..remaining], staged[n..]);
        self.data_len = remaining;
    }
};

/// Errors reported while reading and parsing a request head, merged with
/// allocator failure (`error.OutOfMemory`).
pub const HttpParseError = error{
    /// The version token matches none of the names known to `HttpVersion`.
    InvalidVersion,
    /// The method token is an empty string.
    InvalidMethod,
    /// A header line contains no `:` separator.
    InvalidHeader,
    /// The request line yields no version token after the path.
    MissingVersion,
    /// The request line yields no method token.
    MissingMethod,
    /// The request line yields no path token after the method.
    MissingPath,
    /// The request head yields no request line.
    MissingRequestLine,
    /// The underlying reader reported `ReadFailed`.
    Io,
    /// The head buffer has no free space left and no `\r\n\r\n` was found.
    HeadTooLarge,
    /// The underlying reader reported `EndOfStream` while a head was being read.
    ConnectionClose,
} || std.mem.Allocator.Error;