# -*- coding: utf-8 -*-

"""IR utilities."""

from __future__ import print_function, division, absolute_import
import collections
import difflib
import contextlib

from pykit.utils import nestedmap, listify

def linearize(func):
    """
    Return a linearized form of the IR and a dict mapping basic blocks to
    instruction offsets.

    :param func: object whose ``blocks`` are visited in order; each block
                 must expose ``name`` and ``ops``.
    :returns: ``(ops, blockstarts)`` where ``ops`` is a flat list of every
              operation in block order, and ``blockstarts`` maps each block
              name to the index in ``ops`` of that block's first operation.
              An empty block maps to the offset where the next block's ops
              begin (or ``len(ops)`` if it is the last block).
    """
    result = []
    blockstarts = {}  # { block_label : instruction offset }
    for block in func.blocks:
        # Record the offset *before* appending, so it points at the block's
        # first op in the flattened list.
        blockstarts[block.name] = len(result)
        result.extend(block.ops)

    return result, blockstarts

def index(function, indexed=None):
    """
    Index the IR, returning { opcode: [operations] }.

    :param function: object whose ``blocks`` each expose ``ops``; every op
                     must have an ``opcode`` attribute.
    :param indexed: optional mapping to add to. It may be a plain ``dict`` or
                    a ``defaultdict(list)``, and it is updated in place, even
                    when it starts out empty. When omitted, a new
                    ``collections.defaultdict(list)`` is created.
    :returns: the mapping, with operations grouped by opcode in IR order.
    """
    # Use an explicit None check: ``indexed or ...`` would silently discard
    # an empty mapping supplied by the caller.
    if indexed is None:
        indexed = collections.defaultdict(list)
    for block in function.blocks:
        for op in block.ops:
            # setdefault works for both plain dicts and defaultdicts.
            indexed.setdefault(op.opcode, []).append(op)

    return indexed

# ______________________________________________________________________

def _getops(func_or_block_or_list):
    if isinstance(func_or_block_or_list, list):
        return func_or_block_or_list
    return func_or_block_or_list.ops

def findop(container, opcode):
    """Return the first Operation in `container` whose opcode is `opcode`.

    `container` may be a function, a block, or a list of operations.
    Returns None when no operation matches.
    """
    matches = (op for op in _getops(container) if op.opcode == opcode)
    return next(matches, None)

def findallops(container, opcode):
    """Find all Operations with the given opcode.

    :param container: a function, a block, or a list of operations.
    :param opcode: opcode to match against each operation's ``opcode``.
    :returns: a new list of the matching operations, in iteration order
              (empty when nothing matches).
    """
    return [op for op in _getops(container) if op.opcode == opcode]

def opcodes(container):
    """Yield the opcode of every operation in `container`, in order.

    This is a generator, not a list. `container` may be a function, a block,
    or a list of operations.
    """
    operations = _getops(container)
    for operation in operations:
        yield operation.opcode

def optypes(container):
    """Yield the ``type`` of every operation in `container`, in order.

    This is a generator, not a list. `container` may be a function, a block,
    or a list of operations.
    """
    operations = _getops(container)
    for operation in operations:
        yield operation.type

# ______________________________________________________________________

def vmap(f, func):
    """
    Apply `f` over all the values in `func`, that is, all Op, Const, FuncArg
    and GlobalValue.

    `f` is called only for its side effects; its return value is ignored and
    this function returns None.

    Visiting order:
      1. every argument in ``func.args``;
      2. every operation in ``func.ops``, each followed by the GlobalValue and
         Const operands it references. Operands nested inside lists or tuples
         are included.

    Op and FuncArg operands are not passed to `f` from the operand position.
    They are reached through ``func.ops`` / ``func.args`` instead.
    """
    from . import GlobalValue, Const

    def _visit_operands(args):
        # Check for the value types first so a value type is never mistaken
        # for a container. Then descend into nested sequences, since operand
        # lists may nest (NOTE(review): e.g. call/phi argument lists --
        # confirm against the Op definitions).
        for arg in args:
            if isinstance(arg, (GlobalValue, Const)):
                f(arg)
            elif isinstance(arg, (list, tuple)):
                _visit_operands(arg)

    for arg in func.args:
        f(arg)
    for op in func.ops:
        f(op)
        _visit_operands(op.args)

# ______________________________________________________________________

def diff(before, after):
    """Return a line-by-line ``difflib.Differ`` comparison of two strings.

    Each output line carries Differ's two-character prefix
    (``"  "``, ``"- "``, ``"+ "`` or ``"? "``), and the lines are joined
    with newlines.
    """
    old_lines = before.splitlines()
    new_lines = after.splitlines()
    differ = difflib.Differ()
    return "\n".join(differ.compare(old_lines, new_lines))

@contextlib.contextmanager
def passdiff(func):
    """
    Context manager that prints a diff of ``str(func)`` taken before and
    after the body of the ``with`` statement runs::

        with passdiff(func):
            run_some_pass(func)

    If the body raises, the exception propagates and no diff is printed.
    """
    before = str(func)
    yield
    after = str(func)
    print(diff(before, after))