const std = @import("std");
const Io = std.Io;

/// Lifts a plain value into an already-resolved `Monad`.
///
/// Shorthand for `Monad(@TypeOf(val)).@"return"(val)`: the element type is
/// inferred from the argument, so callers do not have to spell it out.
pub fn monad(val: anytype) Monad(@TypeOf(val)) {
    const Value = @TypeOf(val);
    const Wrapped = Monad(Value);
    return Wrapped.@"return"(val);
}

/// Extracts the payload type of the `Monad` returned by a transform function.
///
/// Given the type of a function `fn (...) Monad(R)`, returns `R`.
/// Emits a compile error when `Transform` is not a function type, when its
/// return type is not known from the function type alone (e.g. it depends on
/// a `comptime`/`anytype` parameter), or when the return type is not a
/// `Monad(...)` instantiation.
fn ReturnType(Transform: type) type {
    const info = @typeInfo(Transform);
    if (info != .@"fn")
        @compileError("transform must be a function.");

    // `return_type` is optional: it is null when the return type can only be
    // resolved per call. Unwrapping it with `.?` would surface as an opaque
    // "attempt to use null value" error, so report the real problem instead.
    const ResultM = info.@"fn".return_type orelse
        @compileError("transform must have a concrete (non-generic) return type.");

    // `or` short-circuits at comptime, so `@hasField` is only evaluated on
    // struct types.
    if (@typeInfo(ResultM) != .@"struct" or !@hasField(ResultM, "val"))
        @compileError("transform does not return a Monad type");

    const Result = @FieldType(ResultM, "val");
    // Type-returning functions are memoized per argument, so this identity
    // check confirms `ResultM` really is `Monad(Result)` and not merely some
    // other struct that happens to have a `val` field.
    if (ResultM != Monad(Result))
        @compileError("transform does not return a Monad type");

    return Result;
}

/// Represents a computation which produces a value of type `T`.
///
/// A `Monad(T)` is either already resolved (`any_future == null`, with the
/// result in `val`) or pending on a future started through `std.Io`. In the
/// pending case, `val` is not meaningful until that future has been awaited.
///
/// NOTE(review): this is a by-value handle, so every copy refers to the same
/// `any_future`. Binding the same pending monad twice, or calling `yield` on
/// it more than once, awaits that one future more than once. Confirm that the
/// `std.Io` implementation in use permits this before relying on it.
pub fn Monad(T: type) type {
    return struct {
        /// In-flight computation, or null when `val` already holds the result.
        any_future: ?*std.Io.AnyFuture,
        /// The produced value; only read directly when `any_future` is null.
        val: T,

        /// Wraps an already-available value. No asynchronous work is started.
        pub fn @"return"(t: T) Monad(T) {
            return .{
                .any_future = null,
                .val = t,
            };
        }

        /// Returns the value of `m`. If a future is pending, first waits on
        /// it through `io`.
        fn await(m: @This(), io: std.Io) T {
            // Resolved monads carry their value inline.
            const any_future = m.any_future orelse return m.val;

            var val: T = undefined;
            io.vtable.await(
                io.userdata,
                any_future,
                @ptrCast(&val),
                .of(T),
            );
            return val;
        }

        /// Starts `transform` on `io` once the value of `m` is available.
        /// Returns a monad for the value of the monad that `transform`
        /// returns.
        ///
        /// `transform` must be a function returning `Monad(R)` (checked at
        /// compile time by `ReturnType`); the result is a `Monad(R)`. If `T`
        /// is a tuple, its fields are spread into `transform`'s parameters;
        /// otherwise the value is passed as the only argument. `transform` is
        /// invoked with `@call(.always_inline, ...)`, so it must be inlinable.
        pub fn bind(
            m: Monad(T),
            transform: anytype,
            io: std.Io,
        ) Monad(ReturnType(@TypeOf(transform))) {
            const Result = ReturnType(@TypeOf(transform));
            // State handed to the task started below.
            const Context = struct { m: Monad(T), io: std.Io };

            const TypeErased = struct {
                /// Type-erased entry point passed to `io.vtable.async`.
                /// It recovers the typed context and result slot, runs
                /// `transform`, and stores the awaited value it yields.
                fn start(context_: *const anyopaque, result_: *anyopaque) void {
                    const ctx: *const Context = @ptrCast(@alignCast(context_));
                    const out: *Result = @ptrCast(@alignCast(result_));

                    // Wait for the input value so it can be passed to transform.
                    const arg = ctx.m.await(ctx.io);

                    // Tuple values (https://ziglang.org/documentation/master/#Tuples)
                    // are used directly as the argument list; any other value
                    // is wrapped in a one-element tuple. The condition is
                    // comptime-known, so only one branch is analyzed.
                    const args = if (@typeInfo(T) == .@"struct" and @typeInfo(T).@"struct".is_tuple)
                        arg
                    else
                        .{arg};

                    // transform yields another monad. Awaiting it gives the
                    // value represented by the monad returned from bind().
                    const next = @call(.always_inline, transform, args);
                    out.* = next.await(ctx.io);
                }
            };

            // Named local rather than the address of a temporary literal.
            // NOTE(review): only a pointer to this stack value is passed, so
            // the Io implementation presumably copies it before returning.
            // Confirm this against the targeted std.Io version.
            const task_context: Context = .{ .m = m, .io = io };

            var bound: Monad(Result) = undefined;
            // A null future makes `await` read `bound.val` directly. In that
            // case the value is expected to have been written through the
            // result pointer passed here.
            bound.any_future = io.vtable.async(
                io.userdata,
                @ptrCast(&bound.val),
                .of(Result),
                @ptrCast(&task_context),
                .of(Context),
                TypeErased.start,
            );
            return bound;
        }

        /// Waits for the computation to finish and discards its value. This
        /// is a no-op when no future is pending.
        pub fn yield(self: @This(), io: std.Io) void {
            _ = self.await(io);
        }
    };
}
