在上一篇笔记中，我基于 Zig 0.16 编写了一个简易的单文件 HTTP/1.1 服务器，实现了最基础的请求解析与数据响应功能。但初版代码存在明显短板：所有逻辑都堆砌在单个文件中，请求解析、连接处理与数据读写逻辑相互混杂，耦合度极高，不利于后期的维护、迭代和功能扩展。
本次我对原有代码进行了结构性重构，核心目标是：解耦代码层级、拆分功能模块、封装通用逻辑、优化缓冲区读写机制。
上一篇笔记:Zig 0.16 学习笔记 - 编写一个简易的 HTTP/1.1 服务器
代码实现
- main.zig:负责网络监听、连接管理与任务分发
- http.zig:负责 HTTP 协议解析、数据结构定义与连接封装
main.zig
zig
const std = @import("std");
const http = @import("http.zig");
/// Listen on 127.0.0.1:3000 and serve every accepted client on its own thread.
///
/// Returns the first error from `accept`; a failed `Thread.spawn` is logged and
/// the affected client socket is closed, but the accept loop keeps running.
pub fn main(init: std.process.Init) !void {
    const io = init.io;
    const addr = try std.Io.net.IpAddress.parse("127.0.0.1", 3000);
    // Listening socket, released when main returns.
    var server = try addr.listen(io, .{});
    defer server.deinit(io);
    std.log.info("Listen: 127.0.0.1:3000", .{});

    while (server.accept(io)) |stream| {
        std.log.info("Accept: {}", .{stream.socket.address});
        const thread = std.Thread.spawn(.{}, handleConnect, .{ io, init.gpa, stream }) catch |err| {
            // No thread took ownership of the socket, so close it here.
            stream.close(io);
            std.log.err("Thread.spawn error: {}", .{err});
            continue;
        };
        // Nobody ever joins connection threads; detach so their resources are
        // reclaimed when handleConnect returns instead of leaking per client.
        thread.detach();
    } else |err| {
        return err;
    }
}
/// Serve one client connection: repeatedly parse a request from it, log the
/// request, and answer with a 200 response whose body echoes the request body.
///
/// Returns normally when parsing reports `ConnectionClose`; any other error is
/// propagated. The stream is closed on every exit path.
fn handleConnect(io: std.Io, alloc: std.mem.Allocator, stream: std.Io.net.Stream) !void {
    // Once init succeeds the Connection owns the stream (its deinit closes it);
    // before that, closing is our job.
    var connection = http.Connection.init(io, alloc, stream) catch |err| {
        stream.close(io);
        return err;
    };
    defer connection.deinit(alloc);

    var served: i64 = 0;
    while (true) {
        var request = http.Request.fromConnection(alloc, &connection) catch |err| switch (err) {
            // fromConnection maps EndOfStream while reading the head to
            // ConnectionClose: the peer hung up, so just stop serving.
            http.HttpParseError.ConnectionClose => return,
            else => return err,
        };
        defer request.deinit();

        served += 1;
        std.log.debug("Reuse count: (thread id: {}) {}", .{ std.Thread.getCurrentId(), served });

        try printRequest(&request);
        try response(io, &connection.stream, try request.body());
    }
}
/// Log the request line, every header, and the body of `request` at info level.
///
/// Logging the body calls `request.body()`, which reads it from the connection
/// if it has not been read yet; errors from that read are propagated.
fn printRequest(request: *http.Request) !void {
    const log = std.log;
    log.info("request: vvvvvvvvvvvvvvvvvvvvvvvvvv", .{});
    log.info(" {s} {s} {s}", .{ request.method.toString(), request.path, request.version.toString() });

    var it = request.headers.iterator();
    while (it.next()) |entry| {
        log.info(" {s}: {s}", .{ entry.key_ptr.*, entry.value_ptr.* });
    }

    log.info("", .{});
    const payload = try request.body();
    log.info(" {s}", .{payload});
    log.info(" ^^^^^^^^^^^^^^^^^^^^^^^^^^", .{});
}
/// Send a `HTTP/1.1 200 OK` response carrying `message` as its body, with a
/// matching `content-length`, and flush it to `stream`.
fn response(io: std.Io, stream: *const std.Io.net.Stream, message: []const u8) !void {
    var buf: [4096]u8 = undefined;
    var stream_writer = stream.writer(io, &buf);
    const out = &stream_writer.interface;
    try out.print("HTTP/1.1 200 OK\r\ncontent-length: {}\r\n\r\n{s}", .{ message.len, message });
    try out.flush();
}
http.zig
zig
const std = @import("std");
/// Protocol versions recognised in the request line.
pub const HttpVersion = enum {
    http1_1,
    http2,
    http3,

    /// Wire spelling of each version, consulted by `fromString`.
    const TABLE = [_]struct { name: []const u8, version: HttpVersion }{
        .{ .name = "HTTP/1.1", .version = .http1_1 },
        .{ .name = "HTTP/2", .version = .http2 },
        .{ .name = "HTTP/3", .version = .http3 },
    };

    /// Map a version token (exact, case-sensitive match such as "HTTP/1.1")
    /// to its enum value; anything else yields `InvalidVersion`.
    pub fn fromString(version: []const u8) HttpParseError!HttpVersion {
        for (TABLE) |entry| {
            if (std.mem.eql(u8, entry.name, version)) return entry.version;
        }
        return HttpParseError.InvalidVersion;
    }

    /// Return the wire spelling of `self`, e.g. "HTTP/1.1".
    pub fn toString(self: HttpVersion) []const u8 {
        const name: []const u8 = switch (self) {
            .http1_1 => "HTTP/1.1",
            .http2 => "HTTP/2",
            .http3 => "HTTP/3",
        };
        return name;
    }
};
/// HTTP request method. The standard methods have dedicated tags; any other
/// syntactically valid method token is stored verbatim in `custom`.
pub const HttpMethod = union(enum) {
    get,
    post,
    put,
    patch,
    delete,
    connect,
    head,
    options,
    custom: []const u8,

    /// Parse a method token taken from the request line.
    ///
    /// Matching is case-sensitive (RFC 9110 §9.1), so "get" is a custom
    /// method rather than GET. A custom method must consist solely of token
    /// characters (RFC 9110 §5.6.2) and is copied with `alloc`. An empty or
    /// non-token string yields `InvalidMethod`.
    pub fn fromString(alloc: std.mem.Allocator, str: []const u8) HttpParseError!HttpMethod {
        if (str.len == 0) return HttpParseError.InvalidMethod;
        inline for (@typeInfo(HttpMethod).@"union".fields) |field| {
            // Only payload-less tags are standard methods; `custom` is skipped
            // at comptime so @unionInit is never instantiated for it.
            if (field.type == void) {
                const method = @unionInit(HttpMethod, field.name, {});
                if (std.mem.eql(u8, method.toString(), str)) return method;
            }
        }
        // The value is untrusted input that gets logged and echoed; refuse
        // anything that is not a valid HTTP token.
        for (str) |c| {
            if (!isTokenChar(c)) return HttpParseError.InvalidMethod;
        }
        return .{ .custom = try alloc.dupe(u8, str) };
    }

    /// Return the canonical upper-case name, or the stored text for `custom`.
    pub fn toString(self: HttpMethod) []const u8 {
        return switch (self) {
            .get => "GET",
            .post => "POST",
            .put => "PUT",
            .patch => "PATCH",
            .delete => "DELETE",
            .connect => "CONNECT",
            .head => "HEAD",
            .options => "OPTIONS",
            .custom => |method| method,
        };
    }

    /// Whether `c` is a `tchar` as defined by RFC 9110 §5.6.2.
    fn isTokenChar(c: u8) bool {
        return std.ascii.isAlphanumeric(c) or std.mem.indexOfScalar(u8, "!#$%&'*+-.^_`|~", c) != null;
    }
};
/// An HTTP request whose head has been parsed from a `Connection`.
///
/// Everything the request owns (path, lower-cased header names, header values,
/// a custom method name, the body) is allocated from `arena`, and `deinit`
/// frees it all at once. The body is not read during parsing: call `body` to
/// read it or `discardBody` to skip it.
pub const Request = struct {
    version: HttpVersion,
    method: HttpMethod,
    path: []u8,
    /// Header names are lower-cased; when a header repeats, the last value wins.
    headers: std.StringHashMap([]const u8),
    /// Cached body; `null` until `body` or `discardBody` has run.
    payload: ?[]u8,
    connection: *Connection,
    arena: *std.heap.ArenaAllocator,

    /// Release every allocation made for this request, including the arena itself.
    pub fn deinit(self: *const Request) void {
        // Capture the child allocator first so nothing reads the arena struct
        // after it has been torn down.
        const child = self.arena.child_allocator;
        self.arena.deinit();
        child.destroy(self.arena);
    }

    /// Parse one request head (request line and headers) from `connection`.
    ///
    /// Bytes already staged in the connection are scanned before the socket is
    /// read. This covers a pipelined request that arrived together with the
    /// previous one. The head bytes are consumed from the staging buffer on
    /// every exit path after the head is found; any body bytes that follow are
    /// left for `body`/`discardBody`.
    ///
    /// Errors:
    /// - `ConnectionClose` if the peer closes while the head is being read.
    /// - `Io` on a read failure.
    /// - `HeadTooLarge` if the staging buffer fills without a blank line.
    /// - `Missing*`/`Invalid*` for a malformed head.
    pub fn fromConnection(alloc: std.mem.Allocator, connection: *Connection) HttpParseError!Request {
        const arena = try alloc.create(std.heap.ArenaAllocator);
        errdefer alloc.destroy(arena);
        arena.* = std.heap.ArenaAllocator.init(alloc);
        errdefer arena.deinit();

        const header_end_index = try readHead(connection);
        // Everything kept from the head is copied into the arena, so the raw
        // bytes can be dropped even if parsing below fails.
        defer connection.consume(header_end_index + 4);

        var request: Request = undefined;
        request.connection = connection;
        request.arena = arena;
        request.payload = null;

        var lines = std.mem.splitSequence(u8, connection.bufferData()[0..header_end_index], "\r\n");
        try request.parseRequestLine(lines.next() orelse return HttpParseError.MissingRequestLine);

        request.headers = .init(arena.allocator());
        while (lines.next()) |line| {
            if (line.len == 0) break;
            try request.parseHeaderLine(line);
        }
        return request;
    }

    /// Read the body, sized by `Content-Length` (0 when absent).
    ///
    /// The result is cached, so later calls return the same slice without
    /// touching the connection.
    pub fn body(self: *Request) ![]u8 {
        if (self.payload) |payload| return payload;
        const size = try self.parseContentLength();
        self.payload = if (size != 0) blk: {
            const payload = try self.arena.allocator().alloc(u8, size);
            try self.connection.fill(payload);
            break :blk payload;
        } else "";
        return self.payload.?;
    }

    /// Skip the body on the connection without storing it.
    ///
    /// If the body was already read, only the cached slice is cleared.
    pub fn discardBody(self: *Request) !void {
        if (self.payload) |_| {
            self.payload = "";
            return;
        }
        const size = try self.parseContentLength();
        if (size != 0) {
            try self.connection.discard(size);
        }
        self.payload = "";
    }

    /// Fill the connection's staging buffer until it holds a complete request
    /// head, and return the index of the terminating "\r\n\r\n".
    fn readHead(connection: *Connection) HttpParseError!usize {
        var scan_from: usize = 0;
        while (true) {
            const data = connection.bufferData();
            // Check staged bytes before blocking on the socket: a pipelined
            // request may already be fully buffered.
            if (std.mem.indexOfPos(u8, data, scan_from, "\r\n\r\n")) |index| return index;
            if (connection.bufferFree().len == 0) {
                connection.data_len = 0;
                return HttpParseError.HeadTooLarge;
            }
            // The terminator may straddle old and new data, so back up 3 bytes.
            scan_from = data.len -| 3;
            _ = connection.readToBuffer() catch |e| {
                return switch (e) {
                    std.Io.Reader.Error.EndOfStream => HttpParseError.ConnectionClose,
                    std.Io.Reader.Error.ReadFailed => HttpParseError.Io,
                };
            };
        }
    }

    /// Parse a single `name: value` header line into `self.headers`.
    fn parseHeaderLine(self: *Request, line: []const u8) HttpParseError!void {
        const a = self.arena.allocator();
        const colon = std.mem.indexOfScalar(u8, line, ':') orelse return HttpParseError.InvalidHeader;
        const raw_name = line[0..colon];
        // RFC 9112 §5.1: whitespace inside or right before the colon of a field
        // name must be rejected, not guessed at (request-smuggling vector).
        if (raw_name.len == 0 or std.mem.indexOfAny(u8, raw_name, " \t") != null) {
            return HttpParseError.InvalidHeader;
        }
        const name = try std.ascii.allocLowerString(a, raw_name);
        // Optional whitespace around a field value is SP or HTAB (RFC 9110 §5.6.3).
        const value = try a.dupe(u8, std.mem.trim(u8, line[colon + 1 ..], " \t"));
        try self.headers.put(name, value);
    }

    /// Return the body size from `Content-Length`, or 0 when the header is absent.
    fn parseContentLength(self: *const Request) !usize {
        const text = self.headers.get("content-length") orelse return 0;
        // RFC 9110 §8.6 allows only DIGITs; std.fmt.parseInt would also
        // accept a leading '+' and '_' separators from untrusted input.
        if (text.len == 0) return error.InvalidContentLength;
        for (text) |c| {
            if (!std.ascii.isDigit(c)) return error.InvalidContentLength;
        }
        return std.fmt.parseInt(usize, text, 10);
    }

    /// Parse the request line "METHOD SP target SP VERSION" into `self`.
    fn parseRequestLine(self: *Request, data: []const u8) HttpParseError!void {
        const a = self.arena.allocator();
        var parts = std.mem.splitScalar(u8, data, ' ');

        const method_text = parts.next() orelse return HttpParseError.MissingMethod;
        self.method = try HttpMethod.fromString(a, method_text);

        const path_text = parts.next() orelse return HttpParseError.MissingPath;
        // A doubled space ("GET  HTTP/1.1") yields an empty request-target.
        if (path_text.len == 0) return HttpParseError.MissingPath;
        self.path = try a.dupe(u8, path_text);

        const version_text = parts.next() orelse return HttpParseError.MissingVersion;
        self.version = try HttpVersion.fromString(version_text);
    }
};
/// State for one client socket: the stream, a buffered reader over it, and a
/// separate staging area that request-head parsing reads raw bytes into.
pub const Connection = struct {
    stream: std.Io.net.Stream,
    io: std.Io,
    /// A single allocation split in two. `[0..READER_BUFFER_SIZE]` is the
    /// internal buffer of `reader`; the remainder is the staging area.
    buf: []u8,
    /// Number of valid bytes at the start of the staging area.
    data_len: usize,
    reader: std.Io.net.Stream.Reader,

    const READER_BUFFER_SIZE = 4096;
    const STAGING_BUFFER_SIZE = 4096;

    /// Allocate the buffers and wrap `stream`; `deinit` undoes both.
    pub fn init(io: std.Io, alloc: std.mem.Allocator, stream: std.Io.net.Stream) !Connection {
        const storage = try alloc.alloc(u8, READER_BUFFER_SIZE + STAGING_BUFFER_SIZE);
        return .{
            .stream = stream,
            .io = io,
            .buf = storage,
            .data_len = 0,
            .reader = stream.reader(io, storage[0..READER_BUFFER_SIZE]),
        };
    }

    /// Close the socket and free the buffer allocated by `init`.
    pub fn deinit(self: *Connection, alloc: std.mem.Allocator) void {
        self.stream.close(self.io);
        alloc.free(self.buf);
    }

    /// The whole staging area, used or not.
    pub fn buffer(self: *Connection) []u8 {
        return self.buf[READER_BUFFER_SIZE..];
    }

    /// The staged bytes that have not been consumed yet.
    pub fn bufferData(self: *Connection) []u8 {
        return self.buffer()[0..self.data_len];
    }

    /// The unused tail of the staging area.
    pub fn bufferFree(self: *Connection) []u8 {
        return self.buffer()[self.data_len..];
    }

    /// Append whatever one `readVec` call yields to the staging area and
    /// return the number of bytes added.
    pub fn readToBuffer(self: *Connection) std.Io.Reader.Error!usize {
        var targets: [1][]u8 = .{self.bufferFree()};
        const got = try self.reader.interface.readVec(&targets);
        self.data_len += got;
        return got;
    }

    /// Fill `buf` completely. Staged bytes are used first; any shortfall is
    /// read straight from `reader`.
    pub fn fill(self: *Connection, buf: []u8) std.Io.Reader.Error!void {
        const staged = self.bufferData();
        const take = @min(staged.len, buf.len);
        @memcpy(buf[0..take], staged[0..take]);
        if (take < buf.len) {
            try self.reader.interface.readSliceAll(buf[take..]);
        }
        self.consume(take);
    }

    /// Skip `n` bytes. Staged bytes are dropped first, then the rest is
    /// discarded from `reader`.
    pub fn discard(self: *Connection, n: usize) std.Io.Reader.Error!void {
        const staged = @min(self.data_len, n);
        const remaining = n - staged;
        if (remaining > 0) {
            try self.reader.interface.discardAll(remaining);
        }
        self.consume(staged);
    }

    /// Drop the first `n` staged bytes, shifting the rest to the front.
    pub fn consume(self: *Connection, n: usize) void {
        std.debug.assert(n <= self.data_len);
        const staged = self.bufferData();
        @memmove(staged[0 .. staged.len - n], staged[n..]);
        self.data_len -= n;
    }
};
/// Errors returned while parsing a request head from a `Connection`, plus
/// allocation failure.
pub const HttpParseError = error{
    // Request line.
    MissingRequestLine,
    MissingMethod,
    InvalidMethod,
    MissingPath,
    MissingVersion,
    InvalidVersion,
    // Header section.
    InvalidHeader,
    HeadTooLarge,
    // Transport.
    Io,
    ConnectionClose,
} || std.mem.Allocator.Error;