const std = @import("std");
const Io = std.Io;
const testing = std.testing;
const Monad = @import("monad.zig").Monad;
const monad = @import("monad.zig").monad;

// Monad law tests
// https://wiki.haskell.org/Monad_laws

fn expectEqual(io: std.Io, val: u32, m1: Monad(u32), m2: Monad(u32)) !void {
    // Assert that two monadic computations both resolve to `val`.
    //
    // Each `Monad(u32)` is converted back into a plain `std.Io.Future(u32)` built from
    // its `any_future` and `val` fields, because `testing.expectEqual` is not
    // monad-aware and needs concrete values.
    var f1: std.Io.Future(u32) = .{ .any_future = m1.any_future, .result = m1.val };
    var f2: std.Io.Future(u32) = .{ .any_future = m2.any_future, .result = m2.val };

    // Await BOTH futures before asserting anything. Otherwise a failing first
    // comparison would return early and leave the second future un-awaited.
    const r1 = f1.await(io);
    const r2 = f2.await(io);

    try testing.expectEqual(val, r1);
    try testing.expectEqual(val, r2);
}

fn h(val: u32) Monad(u32) {
    // Kleisli arrow for the law tests: lift `val + 1` into the monad.
    // Wrapping addition (`+%`) keeps fuzz inputs such as maxInt(u32) from tripping
    // a safety-checked overflow panic. The monad laws hold equally under
    // modular arithmetic.
    return monad(val +% 1);
}

fn g(val: u32) Monad(u32) {
    // Kleisli arrow for the law tests: lift `val * 5` into the monad.
    // Wrapping multiplication (`*%`) keeps any fuzz input above maxInt(u32)/5 from
    // tripping a safety-checked overflow panic. The monad laws hold equally under
    // modular arithmetic.
    return monad(val *% 5);
}

// Left identity
test "Left Identity" {
    const global = struct {
        // Fuzz body: lifting `a` and binding `h` must agree with calling `h a` directly.
        fn testOne(_: void, smith: *testing.Smith) anyerror!void {
            const a = smith.value(u32);

            // return a >>= h
            const cmp1 = monad(a).bind(h, testing.io);

            // h a
            const cmp2 = h(a);

            // Wrapping add: the fuzzer can produce maxInt(u32), where a plain `+`
            // would panic while computing the expected value.
            try expectEqual(testing.io, a +% 1, cmp1, cmp2);
        }
    };
    try testing.fuzz({}, global.testOne, .{});
}

// Right identity
test "Right Identity" {
    const Harness = struct {
        // Fuzz body: binding `return` onto a lifted value must give back that same value.
        fn check(_: void, smith: *testing.Smith) anyerror!void {
            const input = smith.value(u32);
            const original = monad(input);

            // m >>= return
            const rebound = original.bind(Monad(u32).@"return", testing.io);

            // Compare the rebound computation against the untouched `original` (m).
            try expectEqual(testing.io, input, rebound, original);
        }
    };
    try testing.fuzz({}, Harness.check, .{});
}

// Associativity
test "Associativity" {
    const global = struct {
        // (\x -> g x >>= h)
        fn applySecond(x: u32, io: std.Io) Monad(u32) {
            return g(x).bind(h, io);
        }

        // Fuzz body: nesting of binds must not matter.
        fn testOne(_: void, smith: *testing.Smith) anyerror!void {
            const a = smith.value(u32);

            // (m >>= g) >>= h
            const cmp1 = monad(a).bind(g, testing.io).bind(h, testing.io);

            // m >>= (\x -> g x >>= h)
            // NOTE(review): the lifted value here is the tuple `.{ a, testing.io }`, which
            // `bind` presumably spreads as `applySecond`'s arguments; confirm in monad.zig.
            const cmp2 = monad(.{ a, testing.io }).bind(applySecond, testing.io);

            // Wrapping ops: large fuzz inputs would otherwise overflow while
            // computing the expected value and panic in safe builds.
            try expectEqual(testing.io, (a *% 5) +% 1, cmp1, cmp2);
        }
    };
    try testing.fuzz({}, global.testOne, .{});
}
