{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fa3b1299",
   "metadata": {},
   "source": [
    "# Creating new visitors"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23abb66d",
   "metadata": {},
   "source": [
    "In the previous notebook, we relied heavily on the [_FindNodes_](https://sites.ecmwf.int/docs/loki/main/loki.visitors.find.html#loki.visitors.find.FindNodes) visitor, which looks through a given IR tree and returns a list of matching instances of a specified [_Node_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Node) type. Although this functionality is sufficient for most use cases, there may be scenarios that require the implementation of bespoke visitors.\n",
    "\n",
    "For node types that could appear in a nested structure, for example [_Loop_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Loop) or [_Conditional_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Conditional), we may be interested in knowing at what depth they appear in a given IR tree. The following illustrates how this can be achieved by building a new `FindNodesDepth` visitor based on `FindNodes`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d758ee9",
   "metadata": {},
   "source": [
    "## Dataclass to store return values"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0537525",
   "metadata": {},
   "source": [
    "The default return value for `FindNodes` is a list of nodes. For `FindNodesDepth`, we would also like to return the depth of the node. We can create a new dataclass (essentially a c-style struct) called `DepthNode` to store both these pieces of information:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "547ef8bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from loki import Node\n",
    "from dataclasses import dataclass\n",
    "\n",
    "@dataclass\n",
    "class DepthNode:\n",
    "    \"\"\"Store node object and depth in c-style struct.\"\"\"\n",
    "    \n",
    "    node: Node\n",
    "    depth: int"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d63f588",
   "metadata": {},
   "source": [
    "## Modifying initialization method"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee2518e0",
   "metadata": {},
   "source": [
    "`FindNodes` has two operating modes. The first (and default mode) is to look through a given IR tree and return a list of matching instances of a specified node type. The second, which is enabled by passing `mode='scope'` when creating the visitor, returns the [_InternalNode_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.InternalNode) i.e. the [_Scope_](https://sites.ecmwf.int/docs/loki/main/loki.scope.html#loki.scope.Scope) in which a specified node appears.\n",
    "\n",
    "For our new visitor, we are only interested in the default operating mode of `FindNodes`. Therefore let us define a new initialization function for our `FindNodesDepth` class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "37350bd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from loki import FindNodes\n",
    "\n",
    "class FindNodesDepth(FindNodes):\n",
    "    \"\"\"Visitor that computes node-depth relative to subroutine body. Returns list of DepthNode objects.\"\"\"\n",
    "    \n",
    "    def __init__(self, match, greedy=False):\n",
    "        super().__init__(match, mode='type', greedy=greedy)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "613e4b3a",
   "metadata": {},
   "source": [
    "## Modifying the `visit_Node` method"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77845b25",
   "metadata": {},
   "source": [
    "In order to achieve the desired functionality of our new visitor, we will need a new `visit_Node` method. We start from a copy of `FindNodes.visit_Node` and make only a few changes to it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "67983caa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visit_Node(self, o, **kwargs):\n",
    "    \"\"\"\n",
    "    Add the node to the returned list if it matches the criteria and increment depth\n",
    "    before visiting all children.\n",
    "    \"\"\"\n",
    "\n",
    "    ret = kwargs.pop('ret', self.default_retval())\n",
    "    depth = kwargs.pop('depth', 0)\n",
    "    if self.rule(self.match, o): \n",
    "        ret.append(DepthNode(o, depth))\n",
    "        if self.greedy:\n",
    "            return ret \n",
    "    for i in o.children:\n",
    "        ret = self.visit(i, depth=depth+1, ret=ret, **kwargs)\n",
    "    return ret or self.default_retval()\n",
    "\n",
    "FindNodesDepth.visit_Node = visit_Node"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5d6ba1d",
   "metadata": {},
   "source": [
    "The first change to `visit_Node` is the addition of a line that sets `depth`. If `visit_Node` is called from the base IR tree, then `depth` is initialized to 0. If on the other hand `visit_Node` is called recursively, then the current `depth` of node `o` is retrieved. The second and final change is the addition of a `depth` keyword argument to the recursive call to `visit` for the children of node `o`. As recursion signifies moving down one level in the IR tree, the `depth+1` is passed as an argument.\n",
    "\n",
    "Having now fully defined our new visitor, we can test it on the following routine containing nested loops:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cf2196d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "DO k=1,n\n",
      "  DO j=1,n\n",
      "    DO i=1,n\n",
      "      var_out(i, j, k) = var_in(i, j, k)\n",
      "    END DO\n",
      "    DO i=1,n\n",
      "      var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "  CALL some_kernel(n, var_out(1, 1, k))\n",
      "  \n",
      "  DO j=1,n\n",
      "    DO i=1,n\n",
      "      var_out(i, j, k) = var_out(i, j, k) + 1._JPRB\n",
      "    END DO\n",
      "    DO i=1,n\n",
      "      var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "    END DO\n",
      "  END DO\n",
      "END DO\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from loki import Sourcefile\n",
    "from loki import fgen\n",
    "\n",
    "source = Sourcefile.from_file('src/loop_fuse.F90')\n",
    "routine = source['loop_fuse_v1']\n",
    "print(fgen(routine.body))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02b6e47c",
   "metadata": {},
   "source": [
    "`loop_fuse_v1` contains a total of 7 loops, with a maximum nesting depth of 3. Let us see if our new visitor can identify the loops and their depth correctly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "df95eda3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 Loop:: k=1:n 1\n",
      "1 Loop:: j=1:n 2\n",
      "2 Loop:: i=1:n 3\n",
      "3 Loop:: i=1:n 3\n",
      "4 Loop:: j=1:n 2\n",
      "5 Loop:: i=1:n 3\n",
      "6 Loop:: i=1:n 3\n"
     ]
    }
   ],
   "source": [
    "from loki import Loop\n",
    "\n",
    "loops = FindNodesDepth(Loop).visit(routine.body)\n",
    "\n",
    "for k, loop in enumerate(loops):\n",
    "    print(k, loop.node, loop.depth)\n",
    "    \n",
    "depths = [1, 2, 3, 3, 2, 3, 3]\n",
    "assert(depths == [loop.depth for loop in loops])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74ce856d",
   "metadata": {},
   "source": [
    "As the output shows, the depth of all 7 loops was identified correctly. Note that the subroutine body itself is assigned a depth of 0, and because the outermost `k`-loop is a child of the subroutine body, it has a depth of 1.\n",
    "\n",
    "We can also use our new visitor to find the depth of the [_Assignment_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Assignment) statements within the bodies of the loops:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2aa221a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 Assignment:: var_out(i, j, k) = var_in(i, j, k)             4\n",
      "1 Assignment:: var_out(i, j, k) = 2._JPRB*var_out(i, j, k)    4\n",
      "2 Assignment:: var_out(i, j, k) = var_out(i, j, k) + 1._JPRB  4\n",
      "3 Assignment:: var_out(i, j, k) = 2._JPRB*var_out(i, j, k)    4\n"
     ]
    }
   ],
   "source": [
    "from loki import Assignment\n",
    "\n",
    "assigns = FindNodesDepth(Assignment).visit(routine.body)\n",
    "\n",
    "for k, assign in enumerate(assigns):\n",
    "    print(f'{k} {str(assign.node):<60}{assign.depth}')\n",
    "    \n",
    "depths = [4, 4, 4, 4]\n",
    "assert(depths == [assign.depth for assign in assigns])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "067e5919",
   "metadata": {},
   "source": [
    "All the `Assignment` statements and their respective depths are identified correctly. We can do a similar test on nested `if` statements:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "56f6f076",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 i 1\n",
      "1 j 2\n",
      "2 k 3\n",
      "3 h 3\n"
     ]
    }
   ],
   "source": [
    "from loki import Subroutine\n",
    "from loki import Conditional\n",
    "\n",
    "fcode = \"\"\" \n",
    "subroutine nested_conditionals(i,j,k,h)\n",
    "    \n",
    "    logical,intent(in) :: i,j,k,h\n",
    "\n",
    "    if(i)then\n",
    "      if(j)then\n",
    "\n",
    "        if(k)then\n",
    "          ! do something\n",
    "        else\n",
    "          ! do something else\n",
    "        endif\n",
    "        \n",
    "        if(h)then\n",
    "          ! also test h\n",
    "        endif\n",
    "\n",
    "      endif\n",
    "    endif\n",
    "\n",
    "end subroutine nested_conditionals\n",
    "\"\"\"\n",
    "\n",
    "routine = Subroutine.from_source(fcode)\n",
    "\n",
    "conds = FindNodesDepth(Conditional).visit(routine.body)\n",
    "for k, cond in enumerate(conds):\n",
    "    print(k, cond.node.condition, cond.depth)\n",
    "    \n",
    "depths = [1, 2, 3, 3]\n",
    "assert(depths == [cond.depth for cond in conds])"
   ]
  }
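   ,
   {
    "cell_type": "markdown",
    "id": "b7e41c90",
    "metadata": {},
    "source": [
     "The depth bookkeeping in `visit_Node` does not depend on anything specific to Loki. As a minimal sketch, the cell below applies the same pattern (a `depth` keyword that defaults to 0 and is passed on as `depth + 1` for every child) to a toy tree. `ToyNode` and `find_with_depth` are hypothetical names introduced purely for illustration; they are not part of Loki, and the toy tree only mimics the loop nesting seen earlier."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "c2d85fa3",
    "metadata": {},
    "outputs": [],
    "source": [
     "from dataclasses import dataclass, field\n",
     "\n",
     "\n",
     "@dataclass\n",
     "class ToyNode:\n",
     "    \"\"\"Hypothetical stand-in for an IR node (not a Loki class).\"\"\"\n",
     "\n",
     "    name: str\n",
     "    children: list = field(default_factory=list)\n",
     "\n",
     "\n",
     "def find_with_depth(node, depth=0, ret=None):\n",
     "    \"\"\"Collect (name, depth) pairs, adding 1 to depth for each level of children.\"\"\"\n",
     "    ret = [] if ret is None else ret\n",
     "    ret.append((node.name, depth))\n",
     "    for child in node.children:\n",
     "        # Children sit one level below the current node\n",
     "        find_with_depth(child, depth=depth + 1, ret=ret)\n",
     "    return ret\n",
     "\n",
     "\n",
     "# Toy tree: a body holding a k-loop, which holds a j-loop with two i-loops\n",
     "tree = ToyNode('body', [ToyNode('k', [ToyNode('j', [ToyNode('i1'), ToyNode('i2')])])])\n",
     "result = find_with_depth(tree)\n",
     "\n",
     "for name, depth in result:\n",
     "    print(name, depth)\n",
     "\n",
     "assert [depth for _, depth in result] == [0, 1, 2, 3, 3]"
    ]
   }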
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.8 ('loki_env': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  },
  "vscode": {
   "interpreter": {
    "hash": "5b6429b76fde06fc4400bf3c27b3ae893ffb7a047f8b8ee9418a3bc77878d107"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
