diff --git a/src/base64.zig b/src/base64.zig index 012f150..5bc4dc5 100644 --- a/src/base64.zig +++ b/src/base64.zig @@ -5,105 +5,103 @@ pub const err = error{ IndexNotFound, }; -pub const Base64 = struct { - table: *const [64]u8, +fn char_at(index: u8) u8 { + const table: *const [64]u8 = comptime "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + return table[index]; +} - pub const init: Base64 = .{ - .table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", - }; - - fn char_at(self: Base64, index: u8) u8 { - return self.table[index]; +pub fn encode(allocator: std.mem.Allocator, input: []const u8) ![]u8 { + if (input.len == 0) { + return ""; } - pub fn encode(self: Base64, allocator: std.mem.Allocator, input: []const u8) ![]u8 { - if (input.len == 0) { - return ""; + const out_sz = try calc_encode_length(input); + var out = try allocator.alloc(u8, out_sz); + var buf = [3]u8{0, 0, 0}; + var count: u16 = 0; + var outc: u16 = 0; + + for (input) |b| { + buf[count] = b; + count += 1; + if (count == 3) { + out[outc] = char_at(buf[0] >> 2); + out[outc + 1] = char_at(((buf[0] & 0x03) << 4) + (buf[1] >> 4)); + out[outc + 2] = char_at(((buf[1] & 0x0F) << 2) + (buf[2] >> 6)); + out[outc + 3] = char_at(buf[2] & 0x3F); + count = 0; + outc += 4; } + } - const encode_length = try calc_encode_length(input); - var out = try allocator.alloc(u8, encode_length); - var buf = [3]u8{0, 0, 0}; - var count: u16 = 0; - var outc: u16 = 0; + if (count == 2) { + out[outc] = char_at(buf[0] >> 2); + out[outc + 1] = char_at((buf[0] & 0x03) << 4) + (buf[1] >> 4); + out[outc + 2] = char_at((buf[1] & 0x0F) << 2); + out[outc + 3] = '='; + } else if (count == 1) { + out[outc] = char_at(buf[0] >> 2); + out[outc + 1] = char_at((buf[0] & 0x03) << 4); + out[outc + 2] = '='; + out[outc + 3] = '='; + } - for (input) |b| { - buf[count] = b; - count += 1; - if (count == 3) { - out[outc] = self.char_at(buf[0] >> 2); - out[outc + 1] = self.char_at(((buf[0] & 0x03) << 4) + (buf[1] >> 4)); - out[outc + 2] = self.char_at(((buf[1] & 0x0F) << 2) + (buf[2] >> 6)); - out[outc + 3] = self.char_at(buf[2] & 0x3F); - count = 0; - outc += 4; + return out; +} + +fn decode_index(char: u8) err!u8 { + if (char == '=') { + return 64; + } + + if (char >= 'A' and char <= 'Z') { + return char - 'A'; + } else if (char >= 'a' and char <= 'z') { + return char - 'a' + 26; + } else if (char >= '0' and char <= '9') { + return char - '0' + 52; + } + + switch (char) { + '+' => { return 62; }, + '/' => { return 63; }, + '=' => { return 64; }, + else => { return err.IndexNotFound; } + } +} + +pub fn decode(allocator: std.mem.Allocator, input: []const u8) ![]u8 { + if (input.len == 0) { + return ""; + } + + const out_sz = try calc_decode_length(input); + var out = try allocator.alloc(u8, out_sz); + var count: u8 = 0; + var iout: u64 = 0; + var buf = [4]u8{ 0, 0, 0, 0 }; + var cutoff :u8 = 0; + + for (0..input.len) |i| { + buf[count] = try decode_index(input[i]); + count += 1; + if (count == 4) { + out[iout] = (buf[0] << 2) + (buf[1] >> 4); + if (buf[2] != 64) { + out[iout + 1] = (buf[1] << 4) + (buf[2] >> 2); + cutoff = 2; } - } - - if (count == 2) { - out[outc] = self.char_at(buf[0] >> 2); - out[outc + 1] = self.char_at((buf[0] & 0x03) << 4) + (buf[1] >> 4); - out[outc + 2] = self.char_at((buf[1] & 0x0F) << 2); - out[outc + 3] = '='; - } else if (count == 1) { - out[outc] = self.char_at(buf[0] >> 2); - out[outc + 1] = self.char_at((buf[0] & 0x03) << 4); - out[outc + 2] = '='; - out[outc + 3] = '='; - } - - return out; - } - - fn decode_index(self: Base64, char: u8) err!u8 { - if (char == '=') { - return 64; - } - if (char >= 'A' and char <= 'Z') { - return char - 'A'; - } else if (char >= 'a' and char <= 'z') { - return char - 'a' + 26; - } else { - for (52..64) |i| { - const idx: u8 = @intCast(i); - if (self.char_at(idx) == char) { - return idx; - } + if (buf[3] != 64) { + out[iout + 2] = (buf[2] << 6) + buf[3]; + cutoff = 1; } + iout += 3; + count = 0; } - return err.IndexNotFound; } - pub fn decode(self: Base64, allocator: std.mem.Allocator, input: []const u8) ![]u8 { - if (input.len == 0) { - return ""; - } - - const output_sz = try calc_decode_length(input); - var output = try allocator.alloc(u8, output_sz); - var count: u8 = 0; - var iout: u64 = 0; - var buf = [4]u8{ 0, 0, 0, 0 }; - - for (0..input.len) |i| { - buf[count] = try self.decode_index(input[i]); - count += 1; - if (count == 4) { - output[iout] = (buf[0] << 2) + (buf[1] >> 4); - if (buf[2] != 64) { - output[iout + 1] = (buf[1] << 4) + (buf[2] >> 2); - } - if (buf[3] != 64) { - output[iout + 2] = (buf[2] << 6) + buf[3]; - } - iout += 3; - count = 0; - } - } - - return output; - } -}; + return out[0..out.len - cutoff + 1]; +} fn calc_encode_length(input: []const u8) !usize { if (input.len < 3) { @@ -122,57 +120,51 @@ fn calc_decode_length(input: []const u8) !usize { } test "encode hello" { - var b = Base64.init; var gpa = std.heap.GeneralPurposeAllocator(.{}).init; const allocator = gpa.allocator(); - const encoded = try b.encode(allocator, "hello"); + const encoded = try encode(allocator, "hello"); try std.testing.expect(std.mem.eql(u8, encoded, "aGVsbG8=")); } test "encode long" { - var b = Base64.init; var gpa = std.heap.GeneralPurposeAllocator(.{}).init; const allocator = gpa.allocator(); - const encoded = try b.encode(allocator, "Hey, it's me. I'm the problem. It's me"); + const encoded = try encode(allocator, "Hey, it's me. I'm the problem. It's me"); try std.testing.expect(std.mem.eql(u8, encoded, "SGV5LCBpdCdzIG1lLiBJJ20gdGhlIHByb2JsZW0uIEl0J3MgbWU=")); } test "decode hello" { - var b = Base64.init; var gpa = std.heap.GeneralPurposeAllocator(.{}).init; const allocator = gpa.allocator(); - const encoded = try b.decode(allocator, "aGVsbG8="); + const encoded = try decode(allocator, "aGVsbG8="); try stdout.print("{s}\n", .{encoded}); try std.testing.expect(std.mem.eql(u8, encoded, "hello")); } test "decode long" { - var b = Base64.init; var gpa = std.heap.GeneralPurposeAllocator(.{}).init; const allocator = gpa.allocator(); - const encoded = try b.decode(allocator, "SGV5LCBpdCdzIG1lLiBJJ20gdGhlIHByb2JsZW0uIEl0J3MgbWU="); + const encoded = try decode(allocator, "SGV5LCBpdCdzIG1lLiBJJ20gdGhlIHByb2JsZW0uIEl0J3MgbWU="); try stdout.print("{s}\n", .{encoded}); try std.testing.expect(std.mem.eql(u8, encoded, "Hey, it's me. I'm the problem. It's me")); } test "decode_index" { - var b = Base64.init; - - try std.testing.expectError(err.IndexNotFound, b.decode_index('{')); + try std.testing.expectError(err.IndexNotFound, decode_index('{')); var r :u8 = 0; - r = try b.decode_index('A'); + r = try decode_index('A'); try std.testing.expect(r == 0); - r = try b.decode_index('a'); + r = try decode_index('a'); try std.testing.expect(r == 26); - r = try b.decode_index('0'); + r = try decode_index('0'); try std.testing.expect(r == 52); - r = try b.decode_index('/'); + r = try decode_index('/'); try std.testing.expect(r == 63); - r = try b.decode_index('='); + r = try decode_index('='); try std.testing.expect(r == 64); } diff --git a/src/main.zig b/src/main.zig index 56015e8..e621fef 100644 --- a/src/main.zig +++ b/src/main.zig @@ -1,7 +1,7 @@ const std = @import("std"); const stdout = std.io.getStdOut().writer(); -const b64 = @import("base64.zig"); +const Base64 = @import("base64.zig"); const usageText = \\Usage: baze64 {-d/-D} input @@ -16,22 +16,34 @@ pub fn main() !void { var args = try std.process.argsWithAllocator(allocator); defer args.deinit(); - var b = b64.Base64.init; + var decodeSwitch = false; + var input: []const u8 = undefined; // Skip arg[0] _ = args.skip(); - var a = args.next(); - if (a) |arg| { - if (std.mem.eql(u8, arg, "-d") or std.mem.eql(u8, arg, "-D")) { - a = args.next(); - const decoded = try b.decode(allocator, arg); - try stdout.print("{s}", .{decoded}); + while (true) { + const argVal = args.next(); + var arg: []const u8 = undefined; + if (argVal) |a| { + arg = a; } else { - const encoded = try b.decode(allocator, arg); - try stdout.print("{s}", .{encoded}); + break; } + + if (std.mem.eql(u8, arg, "-d") or std.mem.eql(u8, arg, "-D")) { + decodeSwitch = true; + } else { + input = arg; + break; + } + } + + if (decodeSwitch) { + const decoded = try Base64.decode(allocator, input); + try stdout.print("{s}", .{decoded}); } else { - try stdout.print("{s}", .{usageText}); + const encoded = try Base64.encode(allocator, input); + try stdout.print("{s}", .{encoded}); } }