{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "eac9ce5b",
   "metadata": {},
   "source": [
    "# Loop fusion with Loki"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27a3debb",
   "metadata": {},
   "source": [
    "The objective of this notebook is to go through examples of how loop fusion can be performed using Loki. It is a continuation in a series of notebooks, and builds on the lessons of notebooks on [`Reading and writing files with Loki`](https://git.ecmwf.int/projects/RDX/repos/loki/browse/example/01_reading_and_writing_files.ipynb) and [`Working with Loki's internal representation`](https://git.ecmwf.int/projects/RDX/repos/loki/browse/example/02_working_with_the_ir.ipynb).\n",
    "\n",
    "Let us start by parsing the file `src/loop_fuse.F90` from the `example` directory and pick out the `loop_fuse` [_Subroutine_](https://sites.ecmwf.int/docs/loki/main/loki.subroutine.html#loki.subroutine.Subroutine) from that file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2e5feac7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Subroutine:: loop_fuse"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from loki import Sourcefile\n",
    "source = Sourcefile.from_file('src/loop_fuse.F90')\n",
    "routine = source['loop_fuse']\n",
    "routine"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32e6430e",
   "metadata": {},
   "source": [
    "`loop_fuse` starts with an [_Import_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Import) statement to load the parameters `jpim` and `jprb`. Even though we have not specified where the [_Module_](https://sites.ecmwf.int/docs/loki/main/loki.module.html#loki.module.Module) `parkind1` is located, Loki is still able to successfully parse the file and treats `jpim` and `jprb` as a [_DeferredTypeSymbol_](https://sites.ecmwf.int/docs/loki/main/loki.expression.symbols.html#loki.expression.symbols.DeferredTypeSymbol). We can verify this by examining the specification [_Section_](https://sites.ecmwf.int/do/docs/loki/main/loki.ir.html#loki.ir.Section) of the  `loop_fuse`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b2078568",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "USE parkind1, ONLY: jpim, jprb\n",
      "IMPLICIT NONE\n",
      "\n",
      "INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "INTEGER(KIND=jpim) :: i, j, k\n"
     ]
    }
   ],
   "source": [
    "from loki import fgen\n",
    "print(fgen(routine.spec))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "85708bf9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<Section::>\n",
      "  <Import:: parkind1 => (DeferredTypeSymbol('jpim', ('scope', Subroutine:: loop_fuse)), \n",
      "  DeferredTypeSymbol('jprb', ('scope', Subroutine:: loop_fuse)))>\n",
      "  <Intrinsic:: IMPLICIT NONE>\n",
      "  <Comment:: >\n",
      "  <VariableDeclaration:: n>\n",
      "  <VariableDeclaration:: var_in(n, n, n)>\n",
      "  <VariableDeclaration:: var_out(n, n, n)>\n",
      "  <VariableDeclaration:: i, j, k>"
     ]
    }
   ],
   "source": [
    "routine.spec.view()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7903824",
   "metadata": {},
   "source": [
    "Examining the body [_Section_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Section) of `loop_fuse` reveals a nested loop with three levels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "48e2c128",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "DO k=1,n\n",
      "  DO j=1,n\n",
      "    DO i=1,n\n",
      "      var_out(i, j, k) = var_in(i, j, k)\n",
      "    END DO\n",
      "    DO i=1,n\n",
      "      var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "    END DO\n",
      "  END DO\n",
      "END DO\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(fgen(routine.body))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9cb0d7e0",
   "metadata": {},
   "source": [
    "As a first exercise, let us try to merge all the loops that use `i` as the iteration variable. This will involve using Loki's visitor utilities to traverse, search and manipulate Loki's internal representation ([_IR_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#module-loki.ir)). If you are unfamiliar with these topics, then a quick read of [`Working with Loki's internal representation`](https://github.com/ecmwf-ifs/loki/blob/main/example/02_working_with_the_ir.ipynb) is highly recommended.\n",
    "\n",
    "Let us start by identifying all the instances of [_Loop_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Loop) that use `i` as the iteration variable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "98ae76b5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Loop:: i=1:n, Loop:: i=1:n]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from loki import FindNodes,Loop,flatten\n",
    "iloops = [node for node in FindNodes(Loop).visit(routine.body) if node.variable == 'i']\n",
    "iloops"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b8c054a",
   "metadata": {},
   "source": [
    "As the output shows, the visitor search correctly identified both loops. Merging these loops comprises of three main steps. The first is to build a new loop that contains the body of both the loops indentified above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d236f472",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "var_out(i, j, k) = var_in(i, j, k)\n",
      "var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n"
     ]
    }
   ],
   "source": [
    "loop_body = flatten([loop.body for loop in iloops])\n",
    "new_loop = Loop(variable=iloops[0].variable, body=loop_body, bounds=iloops[0].bounds)\n",
    "print(fgen(new_loop.body))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91b94218",
   "metadata": {},
   "source": [
    "`new_loop` now contains both the [_Assignment_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Assignment) statements of the original `iloops`. The next step is to build a transformation map - a dictionary that maps the original node to its replacement:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "40214ad6",
   "metadata": {},
   "outputs": [],
   "source": [
    "loop_map = {iloops[0]: new_loop, iloops[1]: None}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a120a02",
   "metadata": {},
   "source": [
    "Since we want to merge two loops into one, the first loop is mapped to `new_loop` and the secone is mapped to `None` i.e. it will be deleted. With the transformation map defined, we can execute the [_Transformer_](https://sites.ecmwf.int/docs/loki/main/loki.visitors.transform.html#loki.visitors.transform.Transformer):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5cb5d7c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE loop_fuse (n, var_in, var_out)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "  INTEGER(KIND=jpim) :: i, j, k\n",
      "  \n",
      "  DO k=1,n\n",
      "    DO j=1,n\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_in(i, j, k)\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "END SUBROUTINE loop_fuse\n"
     ]
    }
   ],
   "source": [
    "from loki import Transformer\n",
    "routine.body = Transformer(loop_map).visit(routine.body)\n",
    "print(routine.to_fortran())\n",
    "\n",
    "assert len([node for node in FindNodes(Loop).visit(routine.body) if node.variable == 'i']) == 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9100bf60",
   "metadata": {},
   "source": [
    "We have also added an `assert` statement to programatically check the output of our loop tranformation. The `assert` will allow `pytest` to determine if this notebook continues to function as expected with future updates to Loki."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71c989bf",
   "metadata": {},
   "source": [
    "## Loops separated by kernel call"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eea4d2ed",
   "metadata": {},
   "source": [
    "Let us now try a more complex loop fusion example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5e1ca42c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE loop_fuse_v1 (n, var_in, var_out)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "  INTEGER(KIND=jpim) :: i, j, k\n",
      "  \n",
      "  DO k=1,n\n",
      "    DO j=1,n\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_in(i, j, k)\n",
      "      END DO\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "    END DO\n",
      "    \n",
      "    CALL some_kernel(n, var_out(1, 1, k))\n",
      "    \n",
      "    DO j=1,n\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_out(i, j, k) + 1._JPRB\n",
      "      END DO\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "END SUBROUTINE loop_fuse_v1\n"
     ]
    }
   ],
   "source": [
    "routine = source['loop_fuse_v1']\n",
    "print(routine.to_fortran())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61b39ce5",
   "metadata": {},
   "source": [
    "In `loop_fuse_v1`, there are two `j`-loops separated by a [_CallStatement_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.CallStatement) to a kernel that modifies `var_out`. Therefore we can only merge the `i`-loops within each `j`-loop.\n",
    "\n",
    "Using the visitor to locate the `i`-loops as was done in the previous example is inappropriate in this case, because it will locate all four `i`-loops:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7f81d32d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Loop:: i=1:n, Loop:: i=1:n, Loop:: i=1:n, Loop:: i=1:n]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iloops = [node for node in FindNodes(Loop).visit(routine.body) if node.variable == 'i']\n",
    "iloops"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03c6ae62",
   "metadata": {},
   "source": [
    "Since we know the hierarchy of the loops, we can instead run the visitor on two levels. First to locate the `j`-loops, and then locate the `i`-loops within the body of each `j`-loop:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "25e64b2b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Loop:: i=1:n, Loop:: i=1:n] [Loop:: i=1:n, Loop:: i=1:n]\n"
     ]
    }
   ],
   "source": [
    "jloops = [node for node in FindNodes(Loop).visit(routine.body) if node.variable == 'j']\n",
    "iloops = [[node for node in FindNodes(Loop).visit(loop.body) if node.variable == 'i'] for loop in jloops]\n",
    "print(iloops[0], iloops[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c013040f",
   "metadata": {},
   "source": [
    "We can now merge the two blocks of `i`-loops:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e1257b63",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE loop_fuse_v1 (n, var_in, var_out)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "  INTEGER(KIND=jpim) :: i, j, k\n",
      "  \n",
      "  DO k=1,n\n",
      "    DO j=1,n\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_in(i, j, k)\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "    END DO\n",
      "    \n",
      "    CALL some_kernel(n, var_out(1, 1, k))\n",
      "    \n",
      "    DO j=1,n\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_out(i, j, k) + 1._JPRB\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "END SUBROUTINE loop_fuse_v1\n"
     ]
    }
   ],
   "source": [
    "for loop_block in iloops:\n",
    "    loop_body = flatten([loop.body for loop in loop_block])\n",
    "    new_loop = Loop(variable=loop_block[0].variable, body=loop_body, bounds=loop_block[0].bounds)\n",
    "    loop_map[loop_block[0]] = new_loop\n",
    "    loop_map.update({loop: None for loop in loop_block[1:]})\n",
    "routine.body = Transformer(loop_map).visit(routine.body)\n",
    "print(routine.to_fortran())\n",
    "\n",
    "assert len([node for node in FindNodes(Loop).visit(routine.body) if node.variable == 'i']) == 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9005ba3c",
   "metadata": {},
   "source": [
    "In `loop_fuse_v1`, identifying the two blocks of `i`-loops was relatively straightforward because they were nested in different `j`-loops. Let us now try an example where all the `i`-loops and the kernel call are within the same `j`-loop:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "aece6f3d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE loop_fuse_v2 (n, var_in, var_out)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "  INTEGER(KIND=jpim) :: i, j, k\n",
      "  \n",
      "  DO k=1,n\n",
      "    DO j=1,n\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_in(i, j, k)\n",
      "      END DO\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "      \n",
      "      CALL some_kernel(n, var_out(1, j, k))\n",
      "      \n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_out(i, j, k) + 1._JPRB\n",
      "      END DO\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "END SUBROUTINE loop_fuse_v2\n"
     ]
    }
   ],
   "source": [
    "routine = source['loop_fuse_v2']\n",
    "print(routine.to_fortran())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f9c763e",
   "metadata": {},
   "source": [
    "The [_FindNodes_](https://sites.ecmwf.int/docs/loki/main/loki.visitors.find.html#loki.visitors.find.FindNodes) visitor we used previously returns an ordered list of nodes that match a specified type. Previously we only searched for nodes of type [_Loop_](https://sites.ecmwf.int/do/docs/loki/main/loki.ir.html#loki.ir.Loop). We can easily extend this to also search for [_CallStatement_](https://sites.ecmwf.int/docs//docs/loki/main/loki.ir.html#loki.ir.CallStatement) by passing both node-types as a tuple when initializing the visitor:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ddf60a6a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Loop:: i=1:n, Loop:: i=1:n, Call:: some_kernel, Loop:: i=1:n, Loop:: i=1:n]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from loki import CallStatement\n",
    "jloops = [loop for loop in FindNodes(Loop).visit(routine.body) if loop.variable == 'j']\n",
    "assert len(jloops) == 1\n",
    "nodes = FindNodes((CallStatement,Loop)).visit(jloops[0].body)\n",
    "nodes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb60a77b",
   "metadata": {},
   "source": [
    "By first using `FindNodes` to locate the `j`-loop, and then applying `FindNodes` to that we have built an ordered list (`nodes`) containing just the `i`-loops and the kernel call. We can now identify the loops that appear before and after the kernel call:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "eec8d2b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "call_loc = [count for count,node in enumerate(nodes) if isinstance(node,CallStatement)][0]\n",
    "iloops[0] = [node for node in nodes[:call_loc] if node.variable == 'i']\n",
    "iloops[1] = [node for node in nodes[call_loc+1:] if node.variable == 'i']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "737c3769",
   "metadata": {},
   "source": [
    "We can now fuse the two blocks of `i`-loops:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7b751473",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE loop_fuse_v2 (n, var_in, var_out)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "  INTEGER(KIND=jpim) :: i, j, k\n",
      "  \n",
      "  DO k=1,n\n",
      "    DO j=1,n\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_in(i, j, k)\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "      \n",
      "      CALL some_kernel(n, var_out(1, j, k))\n",
      "      \n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_out(i, j, k) + 1._JPRB\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "END SUBROUTINE loop_fuse_v2\n"
     ]
    }
   ],
   "source": [
    "for loop_block in iloops:\n",
    "    loop_body = flatten([loop.body for loop in loop_block])\n",
    "    new_loop = Loop(variable=loop_block[0].variable, body=loop_body, bounds=loop_block[0].bounds)\n",
    "    loop_map[loop_block[0]] = new_loop\n",
    "    loop_map.update({loop: None for loop in loop_block[1:]})\n",
    "routine.body = Transformer(loop_map).visit(routine.body)\n",
    "print(routine.to_fortran())\n",
    "\n",
    "assert len([node for node in FindNodes(Loop).visit(routine.body) if node.variable == 'i']) == 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ec146b7",
   "metadata": {},
   "source": [
    "## Using the built-in `loop_fusion` utility"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a91b698",
   "metadata": {},
   "source": [
    "To facilitate loop fusion and make it readily available to users, Loki has a built-in `loop_fusion` transformation utility. However, currently this relies on manually annotating the loops with `!$loki` pragmas. To illustrate how it works, we outline its mechanics in the following:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c79326a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE loop_fuse_pragma (n, var_in, var_out)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "  INTEGER(KIND=jpim) :: i, j, k\n",
      "  \n",
      "  DO k=1,n\n",
      "    DO j=1,n\n",
      "      \n",
      "!$loki loop-fusion group( g1 )\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_in(i, j, k)\n",
      "      END DO\n",
      "!$loki loop-fusion group( g1 )\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "      \n",
      "      CALL some_kernel(n, var_out(1, j, k))\n",
      "      \n",
      "!$loki loop-fusion group( g2 )\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_out(i, j, k) + 1._JPRB\n",
      "      END DO\n",
      "!$loki loop-fusion group( g2 )\n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "      \n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "END SUBROUTINE loop_fuse_pragma\n"
     ]
    }
   ],
   "source": [
    "routine = source['loop_fuse_pragma']\n",
    "print(routine.to_fortran())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99e6b49a",
   "metadata": {},
   "source": [
    "The routine `loop_fuse_pragma` is identical to `loop_fuse_v2` except for the `i`-loops being preceded by `!$loki loop-fusion` pragmas. The loops that are candidates for fusion have been assigned to the same group i.e. `g1` or `g2`. Examining the body of `loop_fuse_pragma` reveals the following:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "cccf8d6a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<Section::>\n",
      "  <Comment:: >\n",
      "  <Loop:: k=1:n>\n",
      "    <Loop:: j=1:n>\n",
      "      <Comment:: >\n",
      "      <Pragma:: loki loop-fusion g...>\n",
      "      <Loop:: i=1:n>\n",
      "        <Assignment:: var_out(i, j, k) = var_in(i, j, k)>\n",
      "      <Pragma:: loki loop-fusion g...>\n",
      "      <Loop:: i=1:n>\n",
      "        <Assignment:: var_out(i, j, k) = 2._JPRB*var_out(i, j, k)>\n",
      "      <Comment:: >\n",
      "      <Call:: some_kernel>\n",
      "      <Comment:: >\n",
      "      <Pragma:: loki loop-fusion g...>\n",
      "      <Loop:: i=1:n>\n",
      "        <Assignment:: var_out(i, j, k) = var_out(i, j, k) + 1._JPRB>\n",
      "      <Pragma:: loki loop-fusion g...>\n",
      "      <Loop:: i=1:n>\n",
      "        <Assignment:: var_out(i, j, k) = 2._JPRB*var_out(i, j, k)>\n",
      "      <Comment:: >\n",
      "  <Comment:: >"
     ]
    }
   ],
   "source": [
    "routine.body.view()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aab07764",
   "metadata": {},
   "source": [
    "One of the difficulties in parsing pragmas is that it is not always immediately clear whether they should be associated with the subsequent or preceeding node, or should stand alone; as examples think of the differing behaviours of `!$omp do`, `!$omp end do` and `!$omp barrier`. Therefore in Loki, pragmas are not attached by default to other nodes. Instead, Loki treats pragmas essentially like comments, but gives them a separate node-type to easily distinguish them.\n",
    "\n",
    "In situations where we do wish to associate pragmas with certain nodes, we can do so using the [_pragmas_attached_](https://sites.ecmwf.int/docs/loki/main/loki.pragma_utils.html#loki.pragma_utils.pragmas_attached) context manager:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "88a095e5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<Section::>\n",
      "  <Comment:: >\n",
      "  <Loop:: k=1:n>\n",
      "    <Loop:: j=1:n>\n",
      "      <Comment:: >\n",
      "      <Loop:: i=1:n>\n",
      "        <Assignment:: var_out(i, j, k) = var_in(i, j, k)>\n",
      "      <Loop:: i=1:n>\n",
      "        <Assignment:: var_out(i, j, k) = 2._JPRB*var_out(i, j, k)>\n",
      "      <Comment:: >\n",
      "      <Call:: some_kernel>\n",
      "      <Comment:: >\n",
      "      <Loop:: i=1:n>\n",
      "        <Assignment:: var_out(i, j, k) = var_out(i, j, k) + 1._JPRB>\n",
      "      <Loop:: i=1:n>\n",
      "        <Assignment:: var_out(i, j, k) = 2._JPRB*var_out(i, j, k)>\n",
      "      <Comment:: >\n",
      "  <Comment:: >"
     ]
    }
   ],
   "source": [
    "from loki import pragmas_attached\n",
    "with pragmas_attached(routine,Loop):\n",
    "    routine.body.view()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e768ca6",
   "metadata": {},
   "source": [
    "We can now visit the loops and sort them into their respective fusion groups:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "16c23deb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "g1: [Loop:: i=1:n, Loop:: i=1:n], g2: [Loop:: i=1:n, Loop:: i=1:n]\n"
     ]
    }
   ],
   "source": [
    "from loki import is_loki_pragma,get_pragma_parameters,Pragma\n",
    "from collections import defaultdict\n",
    "\n",
    "fusion_groups = defaultdict(list)\n",
    "with pragmas_attached(routine,Loop):\n",
    "    for loop in FindNodes(Loop).visit(routine.body):\n",
    "        if is_loki_pragma(loop.pragma, starts_with='loop-fusion'):                         \n",
    "            parameters = get_pragma_parameters(loop.pragma, starts_with='loop-fusion')\n",
    "            group = parameters.get('group', 'default')\n",
    "            fusion_groups[group] += [loop]\n",
    "\n",
    "print(f\"g1: {fusion_groups['g1']}, g2: {fusion_groups['g2']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e327326",
   "metadata": {},
   "source": [
    "`fusion_groups` is now a dictionary with keys for the two fusion groups, and the associated `Loop` nodes are values for each key. We can now create and apply a transformation map similar to how it was done in the previous examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "625d0640",
   "metadata": {},
   "outputs": [],
   "source": [
    "for group,loops in fusion_groups.items():\n",
    "    loop_body = flatten([loop.body for loop in loops])\n",
    "    new_loop = Loop(variable=loops[0].variable, body=loop_body, bounds=loops[0].bounds)\n",
    "    loop_map[loops[0]] = new_loop\n",
    "    loop_map.update({loop: None for loop in loops[1:]})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f2f4222",
   "metadata": {},
   "source": [
    "Since `!$loki` pragmas are only intended to pass instructions/hints to Loki on source manipulations, and aren't needed for the eventual compilation, we can also remove them from the routine:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "785e12df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE loop_fuse_pragma (n, var_in, var_out)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "  INTEGER(KIND=jpim) :: i, j, k\n",
      "  \n",
      "  DO k=1,n\n",
      "    DO j=1,n\n",
      "      \n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_in(i, j, k)\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "      \n",
      "      CALL some_kernel(n, var_out(1, j, k))\n",
      "      \n",
      "      DO i=1,n\n",
      "        var_out(i, j, k) = var_out(i, j, k) + 1._JPRB\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "      END DO\n",
      "      \n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "END SUBROUTINE loop_fuse_pragma\n"
     ]
    }
   ],
   "source": [
    "routine_copy = routine.clone()\n",
    "routine.body = Transformer(loop_map).visit(routine.body)\n",
    "pragma_map = {pragma: None for pragma in FindNodes(Pragma).visit(routine.body)}\n",
    "routine.body = Transformer(pragma_map).visit(routine.body)\n",
    "print(routine.to_fortran())\n",
    "\n",
    "assert len([node for node in FindNodes(Loop).visit(routine.body) if node.variable == 'i']) == 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac7e3b33",
   "metadata": {},
   "source": [
    "You may have noticed in the previous code-cell we made a copy of the object `routine` before applying a transformation to it. We can now apply the `loop_fusion` utility directly on `routine_copy` and compare the results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "164b5054",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loop_fuse_pragma: fused 4 loops in 2 groups.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SUBROUTINE loop_fuse_pragma (n, var_in, var_out)\n",
      "  USE parkind1, ONLY: jpim, jprb\n",
      "  IMPLICIT NONE\n",
      "  \n",
      "  INTEGER(KIND=jpim), INTENT(IN) :: n\n",
      "  REAL(KIND=jprb), INTENT(IN) :: var_in(n, n, n)\n",
      "  REAL(KIND=jprb), INTENT(OUT) :: var_out(n, n, n)\n",
      "  INTEGER(KIND=jpim) :: i, j, k\n",
      "  \n",
      "  DO k=1,n\n",
      "    DO j=1,n\n",
      "      \n",
      "      ! Loki loop-fusion group(g1)\n",
      "      DO i=1,n\n",
      "        ! Loki loop-fusion - body 0 begin\n",
      "        var_out(i, j, k) = var_in(i, j, k)\n",
      "        ! Loki loop-fusion - body 0 end\n",
      "        ! Loki loop-fusion - body 1 begin\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "        ! Loki loop-fusion - body 1 end\n",
      "      END DO\n",
      "      ! Loki loop-fusion group(g1) - loop hoisted\n",
      "      \n",
      "      CALL some_kernel(n, var_out(1, j, k))\n",
      "      \n",
      "      ! Loki loop-fusion group(g2)\n",
      "      DO i=1,n\n",
      "        ! Loki loop-fusion - body 0 begin\n",
      "        var_out(i, j, k) = var_out(i, j, k) + 1._JPRB\n",
      "        ! Loki loop-fusion - body 0 end\n",
      "        ! Loki loop-fusion - body 1 begin\n",
      "        var_out(i, j, k) = 2._JPRB*var_out(i, j, k)\n",
      "        ! Loki loop-fusion - body 1 end\n",
      "      END DO\n",
      "      ! Loki loop-fusion group(g2) - loop hoisted\n",
      "      \n",
      "    END DO\n",
      "  END DO\n",
      "  \n",
      "END SUBROUTINE loop_fuse_pragma\n"
     ]
    }
   ],
   "source": [
    "from loki import do_loop_fusion\n",
    "do_loop_fusion(routine_copy)\n",
    "pragma_map = {pragma: None for pragma in FindNodes(Pragma).visit(routine_copy.body)}\n",
    "routine_copy.body = Transformer(pragma_map).visit(routine_copy.body)\n",
    "print(routine_copy.to_fortran())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2534fd4",
   "metadata": {},
   "source": [
    "As we can see, the built-in Loki utility `loop_fusion` achieves an identical result to our manual transformation. "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
