#!/usr/bin/env python3

import argparse
import copy
import json
import re
import subprocess
from enum import Enum as PyEnum
from pathlib import Path
from typing import Callable
from urllib import request

# Type of the zero-argument callbacks passed to CheatcodesPrinter helpers
# such as `_with_indent` and `_p_line`.
VoidFn = Callable[[], None]

# Source of the cheatcodes JSON (Foundry master branch) used when `--from` is not given.
CHEATCODES_JSON_URL = "https://raw.githubusercontent.com/foundry-rs/foundry/master/crates/cheatcodes/assets/cheatcodes.json"
# Generated file; a relative path, so it is resolved against the current working directory.
OUT_PATH = "src/Vm.sol"

# Doc comment emitted verbatim above the `VmSafe` interface.
VM_SAFE_DOC = """\
/// The `VmSafe` interface does not allow manipulation of the EVM state or other actions that may
/// result in Script simulations differing from on-chain execution. It is recommended to only use
/// these cheats in scripts.
"""

# Doc comment emitted verbatim above the `Vm` interface.
VM_DOC = """\
/// The `Vm` interface does allow manipulation of the EVM state. These are all intended to be used
/// in tests, but it is not recommended to use these cheats in scripts.
"""


def main():
    """Regenerate `OUT_PATH` from Foundry's cheatcodes JSON and format it.

    The JSON is read from `--from PATH` when given, otherwise downloaded from
    `CHEATCODES_JSON_URL`. Cheatcodes with status "experimental" or "internal"
    are dropped. "safe" cheatcodes go into the `VmSafe` interface and "unsafe"
    ones into `Vm is VmSafe`. The result is written to `OUT_PATH` and then
    formatted with `forge fmt`.

    Raises:
        subprocess.CalledProcessError: if `forge fmt` exits with a non-zero status.
    """
    parser = argparse.ArgumentParser(
        description="Generate Vm.sol based on the cheatcodes json created by Foundry")
    parser.add_argument(
        "--from",
        metavar="PATH",
        dest="path",
        required=False,
        help="path to a json file containing the Vm interface, as generated by Foundry")
    args = parser.parse_args()

    if args.path is None:
        # Use the response as a context manager so the connection is closed.
        with request.urlopen(CHEATCODES_JSON_URL) as response:
            json_str = response.read().decode("utf-8")
    else:
        # Decode explicitly as UTF-8 rather than the platform's locale encoding.
        json_str = Path(args.path).read_text(encoding="utf-8")
    contract = Cheatcodes.from_json(json_str)

    ccs = [cc for cc in contract.cheatcodes if cc.status not in ("experimental", "internal")]
    ccs.sort(key=lambda cc: cc.func.id)

    safe = [cc for cc in ccs if cc.safety == "safe"]
    safe.sort(key=CmpCheatcode)
    unsafe = [cc for cc in ccs if cc.safety == "unsafe"]
    unsafe.sort(key=CmpCheatcode)
    # Every remaining cheatcode must be classified as either "safe" or "unsafe".
    assert len(safe) + len(unsafe) == len(ccs)

    # Inserts "// ======== <Group> ========" pseudo-entries into each list in place.
    prefix_with_group_headers(safe)
    prefix_with_group_headers(unsafe)

    out = ""

    out += "// Automatically @generated by scripts/vm.py. Do not modify manually.\n\n"

    pp = CheatcodesPrinter(
        spdx_identifier="MIT OR Apache-2.0",
        solidity_requirement=">=0.6.2 <0.9.0",
        abicoder_pragma=True,
    )
    # Emit the prelude once for the whole file, then stop p_contract from repeating it.
    pp.p_prelude()
    pp.prelude = False
    out += pp.finish()

    out += "\n\n"
    out += VM_SAFE_DOC
    vm_safe = Cheatcodes(
        # TODO: Custom errors were introduced in 0.8.4
        errors=[],  # contract.errors
        events=contract.events,
        enums=contract.enums,
        structs=contract.structs,
        cheatcodes=safe,
    )
    pp.p_contract(vm_safe, "VmSafe")
    out += pp.finish()

    out += "\n\n"
    out += VM_DOC
    # `Vm` inherits `VmSafe`, so the shared events/enums/structs are declared only once.
    vm_unsafe = Cheatcodes(
        errors=[],
        events=[],
        enums=[],
        structs=[],
        cheatcodes=unsafe,
    )
    pp.p_contract(vm_unsafe, "Vm", "VmSafe")
    out += pp.finish()

    # Compatibility with <0.8.0
    # NOTE(review): `.*` is greedy and re.sub matches do not overlap, so at most the
    # first " memory " preceding "returns" on each line is rewritten, and lines with
    # no "returns" are left untouched — confirm this is the intended coverage.
    def memory_to_calldata(m: re.Match) -> str:
        return " calldata " + m.group(1)

    out = re.sub(r" memory (.*returns)", memory_to_calldata, out)

    with open(OUT_PATH, "w", encoding="utf-8") as f:
        f.write(out)

    forge_fmt = ["forge", "fmt", OUT_PATH]
    # check=True raises CalledProcessError on failure; an `assert` would vanish under -O.
    subprocess.run(forge_fmt, check=True)

    print(f"Wrote to {OUT_PATH}")


class CmpCheatcode:
    """Sort-key wrapper that orders cheatcodes using `cmp_cheatcode`.

    Designed to be passed as `key=CmpCheatcode` to `list.sort`. The `<`, `==`
    and `>` operators all delegate to the same three-way comparison.
    """

    cheatcode: "Cheatcode"

    def __init__(self, cheatcode: "Cheatcode"):
        self.cheatcode = cheatcode

    def _compare(self, other: "CmpCheatcode") -> int:
        # Negative, zero or positive, as returned by cmp_cheatcode.
        return cmp_cheatcode(self.cheatcode, other.cheatcode)

    def __lt__(self, other: "CmpCheatcode") -> bool:
        return self._compare(other) < 0

    def __eq__(self, other: "CmpCheatcode") -> bool:
        return self._compare(other) == 0

    def __gt__(self, other: "CmpCheatcode") -> bool:
        return self._compare(other) > 0


def cmp_cheatcode(a: "Cheatcode", b: "Cheatcode") -> int:
    """Three-way comparison of two cheatcodes.

    Keys, in priority order: group, status, safety, function id. The first key
    that differs decides the result.

    Returns:
        -1 if `a` sorts before `b`, 1 if after, 0 if all keys are equal.
    """
    key_pairs = (
        (a.group, b.group),
        (a.status, b.status),
        (a.safety, b.safety),
        (a.func.id, b.func.id),
    )
    for lhs, rhs in key_pairs:
        if lhs == rhs:
            continue
        return -1 if lhs < rhs else 1
    return 0


# HACK: A way to add group header comments without having to modify printer code
def prefix_with_group_headers(cheats: list["Cheatcode"]):
    """Insert a header pseudo-cheatcode before the first cheatcode of each group.

    Each header is a deep copy of the group's first cheatcode with an empty
    description and a declaration of `// ======== <Group> ========`. The printer
    therefore emits it like any other function line.

    Args:
        cheats: cheatcodes in output order; modified in place.

    Returns:
        The same list object, now containing the header entries.
    """
    seen: set[str] = set()
    result: list["Cheatcode"] = []
    for cheat in cheats:
        if cheat.group not in seen:
            seen.add(cheat.group)
            header = copy.deepcopy(cheat)
            header.func.description = ""
            header.func.declaration = f"// ======== {group(header.group)} ========"
            result.append(header)
        result.append(cheat)
    # Rebuild instead of inserting while iterating; slice assignment keeps the
    # caller's list object (callers rely on the in-place update).
    cheats[:] = result
    return cheats


def group(s: str) -> str:
    """Return the display name of a cheatcode group for header comments.

    "evm" and "json" are rendered as the acronyms "EVM" and "JSON". Any other
    name gets its first character upper-cased. An empty string is returned
    unchanged instead of raising IndexError.
    """
    acronyms = {"evm": "EVM", "json": "JSON"}
    if s in acronyms:
        return acronyms[s]
    # Slicing (not indexing) keeps the empty string safe.
    return s[:1].upper() + s[1:]


class Visibility(PyEnum):
    """Solidity function visibility; each member's value is its source keyword."""

    EXTERNAL = "external"
    PUBLIC = "public"
    INTERNAL = "internal"
    PRIVATE = "private"

    def __str__(self):
        # Render as the bare keyword (e.g. "external") rather than "Visibility.EXTERNAL".
        return f"{self.value}"


class Mutability(PyEnum):
    """Solidity state mutability; NONE (empty string) means no modifier."""

    PURE = "pure"
    VIEW = "view"
    NONE = ""

    def __str__(self):
        # Render as the bare keyword; NONE renders as "".
        return f"{self.value}"


class Function:
    """A cheatcode's Solidity function, as described in the cheatcodes JSON."""

    id: str
    description: str
    declaration: str
    visibility: Visibility
    mutability: Mutability
    signature: str
    selector: str
    selector_bytes: bytes

    def __init__(
        self,
        id: str,
        description: str,
        declaration: str,
        visibility: Visibility,
        mutability: Mutability,
        signature: str,
        selector: str,
        selector_bytes: bytes,
    ):
        # Tuple assignment binds left to right, preserving attribute order.
        self.id, self.description, self.declaration = id, description, declaration
        self.visibility, self.mutability = visibility, mutability
        self.signature, self.selector, self.selector_bytes = signature, selector, selector_bytes

    @staticmethod
    def from_dict(d: dict) -> "Function":
        """Build a `Function` from one `func` object of the cheatcodes JSON.

        `visibility` and `mutability` are converted to their enums. `selectorBytes`
        is passed through `bytes()`; presumably it is a list of byte values in the
        JSON (TODO confirm).

        Raises:
            KeyError: if a required key is missing.
            ValueError: if visibility/mutability is not a known enum value.
        """
        return Function(
            id=d["id"],
            description=d["description"],
            declaration=d["declaration"],
            visibility=Visibility(d["visibility"]),
            mutability=Mutability(d["mutability"]),
            signature=d["signature"],
            selector=d["selector"],
            selector_bytes=bytes(d["selectorBytes"]),
        )


class Cheatcode:
    """One cheatcode: its function plus its group, status and safety labels."""

    func: Function
    group: str
    status: str
    safety: str

    def __init__(self, func: Function, group: str, status: str, safety: str):
        self.func, self.group = func, group
        self.status, self.safety = status, safety

    @staticmethod
    def from_dict(d: dict) -> "Cheatcode":
        """Build a `Cheatcode` from one element of the JSON `cheatcodes` array.

        `func` is parsed with `Function.from_dict`. `group`, `status` and `safety`
        are coerced with `str()`.
        """
        func = Function.from_dict(d["func"])
        labels = [str(d[key]) for key in ("group", "status", "safety")]
        return Cheatcode(func, *labels)


class Error:
    """A custom error declared on the cheatcodes interface."""

    name: str
    description: str
    declaration: str

    def __init__(self, name: str, description: str, declaration: str):
        self.name, self.description, self.declaration = name, description, declaration

    @staticmethod
    def from_dict(d: dict) -> "Error":
        """Build an `Error` by forwarding the JSON object's keys as keyword arguments.

        Raises:
            TypeError: if `d` has keys other than name/description/declaration,
                or lacks one of them.
        """
        kwargs = dict(d)
        return Error(**kwargs)


class Event:
    """An event declared on the cheatcodes interface."""

    name: str
    description: str
    declaration: str

    def __init__(self, name: str, description: str, declaration: str):
        self.name, self.description, self.declaration = name, description, declaration

    @staticmethod
    def from_dict(d: dict) -> "Event":
        """Build an `Event` by forwarding the JSON object's keys as keyword arguments.

        Raises:
            TypeError: if `d` has keys other than name/description/declaration,
                or lacks one of them.
        """
        return Event(**{**d})


class EnumVariant:
    """One member of an `Enum`, with its documentation text."""

    name: str
    description: str

    def __init__(self, name: str, description: str):
        self.name, self.description = name, description


class Enum:
    """A Solidity enum from the cheatcodes JSON (not Python's `enum.Enum`)."""

    name: str
    description: str
    variants: list[EnumVariant]

    def __init__(self, name: str, description: str, variants: list[EnumVariant]):
        self.name, self.description, self.variants = name, description, variants

    @staticmethod
    def from_dict(d: dict) -> "Enum":
        """Build an `Enum` from a JSON object.

        Each element of `variants` is passed as keyword arguments to `EnumVariant`.
        """
        # Read keys in the same order as before so errors surface identically.
        name, description = d["name"], d["description"]
        variants = [EnumVariant(**raw) for raw in d["variants"]]
        return Enum(name, description, variants)


class StructField:
    """One field of a `Struct`: name, Solidity type and documentation."""

    name: str
    ty: str
    description: str

    def __init__(self, name: str, ty: str, description: str):
        self.name, self.ty, self.description = name, ty, description


class Struct:
    """A Solidity struct from the cheatcodes JSON."""

    name: str
    description: str
    fields: list[StructField]

    def __init__(self, name: str, description: str, fields: list[StructField]):
        self.name, self.description, self.fields = name, description, fields

    @staticmethod
    def from_dict(d: dict) -> "Struct":
        """Build a `Struct` from a JSON object.

        Each element of `fields` is passed as keyword arguments to `StructField`.
        """
        # Read keys in the same order as before so errors surface identically.
        name, description = d["name"], d["description"]
        fields = [StructField(**raw) for raw in d["fields"]]
        return Struct(name, description, fields)


class Cheatcodes:
    """The whole cheatcodes document: errors, events, enums, structs and cheatcodes."""

    errors: list[Error]
    events: list[Event]
    enums: list[Enum]
    structs: list[Struct]
    cheatcodes: list[Cheatcode]

    def __init__(
        self,
        errors: list[Error],
        events: list[Event],
        enums: list[Enum],
        structs: list[Struct],
        cheatcodes: list[Cheatcode],
    ):
        self.errors = errors
        self.events = events
        self.enums = enums
        self.structs = structs
        self.cheatcodes = cheatcodes

    @staticmethod
    def from_dict(d: dict) -> "Cheatcodes":
        """Build from a decoded JSON object.

        `d` must have the keys "errors", "events", "enums", "structs" and
        "cheatcodes". Each is a list whose elements are parsed by the matching
        class's `from_dict`.
        """
        return Cheatcodes(
            errors=[Error.from_dict(e) for e in d["errors"]],
            events=[Event.from_dict(e) for e in d["events"]],
            enums=[Enum.from_dict(e) for e in d["enums"]],
            structs=[Struct.from_dict(e) for e in d["structs"]],
            cheatcodes=[Cheatcode.from_dict(e) for e in d["cheatcodes"]],
        )

    @staticmethod
    def from_json(s: str | bytes | bytearray) -> "Cheatcodes":
        """Parse a JSON document (anything `json.loads` accepts)."""
        return Cheatcodes.from_dict(json.loads(s))

    @staticmethod
    def from_json_file(file_path: str) -> "Cheatcodes":
        """Load and parse a cheatcodes JSON file.

        The file is decoded as UTF-8 rather than with the platform's locale
        default, so non-ASCII descriptions load the same on every OS.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            return Cheatcodes.from_dict(json.load(f))


class Item(PyEnum):
    """Kinds of declarations that can be emitted inside a generated interface."""

    ERROR = "error"
    EVENT = "event"
    ENUM = "enum"
    STRUCT = "struct"
    FUNCTION = "function"


class ItemOrder:
    """Order in which the kinds of items (`Item`) are emitted inside an interface."""

    _list: list[Item]

    def __init__(self, list: list[Item]) -> None:
        """
        Args:
            list: item kinds in output order. Must contain only `Item` members and
                no duplicates. The sequence is copied, so later changes made by the
                caller do not affect this order.
        """
        # NOTE: the parameter name shadows the builtin `list`; kept for keyword-call compatibility.
        items = [*list]
        assert all(isinstance(item, Item) for item in items), "list must only contain Item members"
        assert len(items) <= len(Item), "list must not contain more items than Item"
        assert len(items) == len(set(items)), "list must not contain duplicates"
        self._list = items

    def get_list(self) -> list[Item]:
        """Return a copy of the order; mutating it does not change this object."""
        return self._list.copy()

    @staticmethod
    def default() -> "ItemOrder":
        """Errors, events, enums, structs, then functions."""
        return ItemOrder(
            [
                Item.ERROR,
                Item.EVENT,
                Item.ENUM,
                Item.STRUCT,
                Item.FUNCTION,
            ]
        )


class CheatcodesPrinter:
    """Render `Cheatcodes` as Solidity interface source into a string buffer.

    Text is appended to `buffer`. `finish()` returns the accumulated text with
    trailing whitespace stripped and clears the buffer, so one printer can
    produce several consecutive chunks.

    Item printers (`p_error`, `p_event`, `p_enum`, `p_struct`, `p_struct_field`,
    `p_function`) and `_p_comment` indent every line they emit themselves. An
    empty description therefore cannot leave a dangling indent behind.
    """

    buffer: str

    prelude: bool
    spdx_identifier: str
    solidity_requirement: str
    abicoder_v2: bool

    block_doc_style: bool

    indent_level: int
    _indent_str: str

    nl_str: str

    items_order: ItemOrder

    def __init__(
        self,
        buffer: str = "",
        prelude: bool = True,
        spdx_identifier: str = "UNLICENSED",
        solidity_requirement: str = "",
        abicoder_pragma: bool = False,
        block_doc_style: bool = False,
        indent_level: int = 0,
        indent_with: int | str = 4,
        nl_str: str = "\n",
        items_order: ItemOrder | None = None,
    ):
        """Configure the printer.

        Args:
            buffer: initial buffer contents.
            prelude: if true, `p_contract` emits the SPDX/pragma prelude first.
            spdx_identifier: value of the `SPDX-License-Identifier` line.
            solidity_requirement: `pragma solidity` range. If empty, `p_prelude`
                picks one.
            abicoder_pragma: emit `pragma experimental ABIEncoderV2;` in the prelude.
            block_doc_style: use `/* */` / `/** */` comments instead of `//` / `///`.
            indent_level: starting indentation depth.
            indent_with: number of spaces per level (int), or the literal indent
                string (str).
            nl_str: line terminator.
            items_order: order of item kinds inside a contract. Defaults to a
                fresh `ItemOrder.default()`.
        """
        self.prelude = prelude
        self.spdx_identifier = spdx_identifier
        self.solidity_requirement = solidity_requirement
        self.abicoder_v2 = abicoder_pragma
        self.block_doc_style = block_doc_style
        self.buffer = buffer
        self.indent_level = indent_level
        self.nl_str = nl_str

        if isinstance(indent_with, int):
            assert indent_with >= 0
            self._indent_str = " " * indent_with
        elif isinstance(indent_with, str):
            self._indent_str = indent_with
        else:
            assert False, "indent_with must be int or str"

        # Build the default per instance. An `ItemOrder.default()` in the signature
        # would be evaluated once and shared by every printer.
        self.items_order = ItemOrder.default() if items_order is None else items_order

    def finish(self) -> str:
        """Return the buffered text (right-stripped) and reset the buffer."""
        ret = self.buffer.rstrip()
        self.buffer = ""
        return ret

    def p_contract(self, contract: Cheatcodes, name: str, inherits: str = ""):
        """Emit `interface <name> [is <inherits>] { ... }` with the contract's items.

        The prelude is emitted first when `self.prelude` is true.
        """
        if self.prelude:
            self.p_prelude(contract)

        self._p_str("interface ")
        name = name.strip()
        if name != "":
            self._p_str(name)
            self._p_str(" ")
        if inherits != "":
            self._p_str("is ")
            self._p_str(inherits)
            self._p_str(" ")
        self._p_str("{")
        self._p_nl()
        self._with_indent(lambda: self._p_items(contract))
        self._p_str("}")
        self._p_nl()

    def _p_items(self, contract: Cheatcodes):
        """Emit each item kind of `contract` in `self.items_order`."""
        for item in self.items_order.get_list():
            if item == Item.ERROR:
                self.p_errors(contract.errors)
            elif item == Item.EVENT:
                self.p_events(contract.events)
            elif item == Item.ENUM:
                self.p_enums(contract.enums)
            elif item == Item.STRUCT:
                self.p_structs(contract.structs)
            elif item == Item.FUNCTION:
                self.p_functions(contract.cheatcodes)
            else:
                assert False, f"unknown item {item}"

    def p_prelude(self, contract: Cheatcodes | None = None):
        """Emit the SPDX line, `pragma solidity`, the optional ABIEncoderV2 pragma and a blank line.

        Without an explicit `solidity_requirement`, the range is ">=0.8.4 <0.9.0"
        when `contract` declares errors (custom errors need Solidity 0.8.4), and
        ">=0.6.0 <0.9.0" otherwise.
        """
        self._p_str(f"// SPDX-License-Identifier: {self.spdx_identifier}")
        self._p_nl()

        if self.solidity_requirement != "":
            req = self.solidity_requirement
        elif contract and len(contract.errors) > 0:
            req = ">=0.8.4 <0.9.0"
        else:
            req = ">=0.6.0 <0.9.0"
        self._p_str(f"pragma solidity {req};")
        self._p_nl()

        if self.abicoder_v2:
            self._p_str("pragma experimental ABIEncoderV2;")
            self._p_nl()

        self._p_nl()

    def p_errors(self, errors: list[Error]):
        """Emit each error followed by a blank line."""
        for error in errors:
            self.p_error(error)
            self._p_nl()

    def p_error(self, error: Error):
        """Emit one error: its doc comment (if any) and its declaration line."""
        self._p_comment(error.description, doc=True)
        self._p_line(lambda: self._p_str(error.declaration))

    def p_events(self, events: list[Event]):
        """Emit each event followed by a blank line."""
        for event in events:
            self.p_event(event)
            self._p_nl()

    def p_event(self, event: Event):
        """Emit one event: its doc comment (if any) and its declaration line."""
        self._p_comment(event.description, doc=True)
        self._p_line(lambda: self._p_str(event.declaration))

    def p_enums(self, enums: list[Enum]):
        """Emit each enum followed by a blank line."""
        for enum in enums:
            self.p_enum(enum)
            self._p_nl()

    def p_enum(self, enum: Enum):
        """Emit `enum Name { ... }` with its doc comment and indented variants."""
        self._p_comment(enum.description, doc=True)
        self._p_line(lambda: self._p_str(f"enum {enum.name} {{"))
        self._with_indent(lambda: self.p_enum_variants(enum.variants))
        self._p_line(lambda: self._p_str("}"))

    def p_enum_variants(self, variants: list[EnumVariant]):
        """Emit variants one per line, comma-separated (no comma after the last)."""
        for i, variant in enumerate(variants):
            self._p_comment(variant.description)

            self._p_indent()
            self._p_str(variant.name)
            if i < len(variants) - 1:
                self._p_str(",")
            self._p_nl()

    def p_structs(self, structs: list[Struct]):
        """Emit each struct followed by a blank line."""
        for struct in structs:
            self.p_struct(struct)
            self._p_nl()

    def p_struct(self, struct: Struct):
        """Emit `struct Name { ... }` with its doc comment and indented fields."""
        self._p_comment(struct.description, doc=True)
        self._p_line(lambda: self._p_str(f"struct {struct.name} {{"))
        self._with_indent(lambda: self.p_struct_fields(struct.fields))
        self._p_line(lambda: self._p_str("}"))

    def p_struct_fields(self, fields: list[StructField]):
        """Emit each field on its own line."""
        for field in fields:
            self.p_struct_field(field)
            self._p_nl()

    def p_struct_field(self, field: StructField):
        """Emit a field's comment (if any) and `<type> <name>;` without the final newline."""
        self._p_comment(field.description)
        self._p_indented(lambda: self._p_str(f"{field.ty} {field.name};"))

    def p_functions(self, cheatcodes: list[Cheatcode]):
        """Emit each cheatcode's function followed by a blank line."""
        for cheatcode in cheatcodes:
            self.p_function(cheatcode.func)
            self._p_nl()

    def p_function(self, func: Function):
        """Emit one function: its doc comment (if any) and its declaration line."""
        self._p_comment(func.description, doc=True)
        self._p_line(lambda: self._p_str(func.declaration))

    def _p_comment(self, s: str, doc: bool = False):
        """Emit `s` as a comment, one output line per input line, each indented.

        Leading whitespace is stripped from every line. Nothing (not even
        indentation) is emitted when `s` is blank.

        Args:
            s: comment text; may span several lines.
            doc: use doc-comment markers (`///` or `/** */`) instead of `//` / `/* */`.
        """
        s = s.strip()
        if s == "":
            return

        lines = [line.lstrip() for line in s.split("\n")]
        if self.block_doc_style:
            self._p_indent()
            self._p_str("/*")
            if doc:
                self._p_str("*")
            self._p_nl()
            for line in lines:
                self._p_indent()
                self._p_str(" ")
                if doc:
                    self._p_str("* ")
                self._p_str(line)
                self._p_nl()
            self._p_indent()
            self._p_str(" */")
            self._p_nl()
        else:
            marker = "/// " if doc else "// "
            for line in lines:
                self._p_indent()
                self._p_str(marker)
                self._p_str(line)
                self._p_nl()

    def _with_indent(self, f: VoidFn):
        """Run `f` one indentation level deeper, restoring the level even if `f` raises."""
        self._inc_indent()
        try:
            f()
        finally:
            self._dec_indent()

    def _p_line(self, f: VoidFn):
        """Emit indentation, whatever `f` prints, then a newline."""
        self._p_indent()
        f()
        self._p_nl()

    def _p_indented(self, f: VoidFn):
        """Emit indentation followed by whatever `f` prints (no newline)."""
        self._p_indent()
        f()

    def _p_indent(self):
        """Emit the indent string once per current indentation level."""
        for _ in range(self.indent_level):
            self._p_str(self._indent_str)

    def _p_nl(self):
        """Emit the configured line terminator."""
        self._p_str(self.nl_str)

    def _p_str(self, txt: str):
        """Append raw text to the buffer."""
        self.buffer += txt

    def _inc_indent(self):
        self.indent_level += 1

    def _dec_indent(self):
        self.indent_level -= 1


# Generate the file only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
