#!/usr/bin/env python3

# rknputop
# Copyright (C) 2024 Broox Technologies Ltd.

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
#      but WITHOUT ANY WARRANTY; without even the implied warranty of
#      MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#     GNU General Public License for more details.
#
#      You should have received a copy of the GNU General Public License
#      along with this program.  If not, see <http://www.gnu.org/licenses/>.

import sys
import re
import time
import threading
import queue
import math
import tty
import os
import termios
import subprocess
import optparse

try:
    import plotext as plt
except:
    print("Install plotext with `pip3 install plotext`")
    sys.exit(-1)
try:
    import psutil
except:
    print("Install psutil with `pip3 install psutil`")
    sys.exit(-1)


def getkey():
    """Block until a key is pressed on stdin and return its name.

    The terminal is put in cbreak mode for the duration of the read and the
    previous settings are restored afterwards, even if the read fails.

    Returns:
        "up", "down", "right" or "left" for the arrow-key escape sequences;
        "backspace", "return", "space", "tab" or "esc" for those keys; the
        typed character itself for anything else.

    Raises:
        EOFError: if nothing could be read from stdin.
    """
    fd = sys.stdin.fileno()
    old_settings = termios.tcgetattr(fd)
    tty.setcbreak(fd)
    try:
        # Arrow keys arrive as a 3-byte escape sequence, hence the read size.
        # errors="replace" keeps a truncated multi-byte character from raising.
        data = os.read(fd, 3).decode(errors="replace")
    finally:
        termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)

    if not data:
        raise EOFError("stdin closed while waiting for a key")

    # Only treat A/B/C/D as arrows when they terminate an escape sequence;
    # otherwise a plain capital letter would be reported as an arrow key.
    if len(data) == 3 and data[0] == "\x1b" and data[1] in "[O":
        arrows = {"A": "up", "B": "down", "C": "right", "D": "left"}
        return arrows.get(data[2], data[2])

    special_keys = {
        127: "backspace",
        10: "return",
        32: "space",
        9: "tab",
        27: "esc",
    }
    # If several characters were read at once, report the last one.
    last = data[-1]
    return special_keys.get(ord(last), last)


def readload():
    """Return the raw text of /sys/kernel/debug/rknpu/load.

    If the file cannot be read, a hint is printed and the program terminates
    via ``sys.exit(-2)``.
    """
    try:
        with open("/sys/kernel/debug/rknpu/load", "r") as f:
            return f.read()
    except (OSError, UnicodeError):
        # Only I/O failures are reported here; KeyboardInterrupt and friends
        # propagate instead of being mislabelled as a permissions problem.
        print("Cannot read /sys/kernel/debug/rknpu/load. Run with `sudo`")
        sys.exit(-2)

def _parse_rga_load(lines):
    """Parse the lines of the RGA load file into (scheduler count, loads)."""
    count = None
    current = ""
    loads = {}
    for raw in lines:
        line = raw.strip()
        # Header ("=== load ===") and separator ("-----") lines match none of
        # the prefixes below and are therefore ignored.
        if line.startswith("num of scheduler ="):
            count = int(line.split("=")[1])
        elif line.startswith("scheduler"):
            current = line.split(":")[1].strip()
        elif line.startswith("load ="):
            loads[current] = float(line.split("=")[1].replace("%", "").strip())
    if count is None:
        # No "num of scheduler" header: fall back to what was actually parsed.
        count = len(loads)
    return count, loads


def readrgaload():
    """Read the per-scheduler RGA load from /sys/kernel/debug/rkrga/load.

    The file is expected to look like::

        num of scheduler = 3
        ================= load ==================
        scheduler[0]: rga3_core0
                 load = 0%
        -----------------------------------
        scheduler[1]: rga3_core1
                 load = 0%
        -----------------------------------
        scheduler[2]: rga2
                 load = 0%
        -----------------------------------

    Returns:
        A tuple ``(count, loads)`` where ``count`` is the announced number of
        schedulers (or the number parsed when the header line is missing) and
        ``loads`` maps each scheduler name to its load percentage as a float.

    If the file cannot be read or parsed, a hint is printed and the program
    terminates via ``sys.exit(-2)``.
    """
    try:
        with open("/sys/kernel/debug/rkrga/load", "r") as f:
            lines = f.readlines()
        return _parse_rga_load(lines)
    except (OSError, ValueError, IndexError):
        print("Cannot read /sys/kernel/debug/rkrga/load Run with `sudo`")
        sys.exit(-2)


def readkver():
    """Return the RKNPU kernel driver version string.

    Reads /sys/kernel/debug/rknpu/version and returns the text between the
    first and second ":" (unstripped). If the file contains no ":", the whole
    content is returned.

    If the file cannot be read, a hint is printed and the program terminates
    via ``sys.exit(-2)``.
    """
    try:
        with open("/sys/kernel/debug/rknpu/version", "r") as f:
            rkver = f.read()
    except (OSError, UnicodeError):
        print("Cannot read /sys/kernel/debug/rknpu/version. Run with `sudo`")
        sys.exit(-2)
    fields = rkver.split(":")
    # Guard against a file without a "label: value" layout.
    return fields[1] if len(fields) > 1 else rkver


def readlibver():
    """Return the librknnrt runtime version, or "" if it cannot be determined.

    The version is extracted by piping ``strings /usr/lib/librknnrt.so``
    through ``grep``. The "librknnrt version: " prefix is removed, "(" is
    turned into a line break and ")" is dropped.

    This is best effort: a missing library, missing tools, no match or
    non-ASCII output all yield an empty string.
    """
    try:
        out = subprocess.check_output(
            'strings /usr/lib/librknnrt.so | grep "librknnrt version:"',
            shell=True,
            # Keep tool error messages from leaking onto the terminal.
            stderr=subprocess.DEVNULL,
        )
        ver = out.decode("ascii").strip()
    except (subprocess.CalledProcessError, OSError, UnicodeDecodeError):
        return ""
    ver = ver.replace("librknnrt version: ", "")
    return ver.replace("(", "\n").replace(")", "")

def readrgaver():
    """Return the RGA kernel driver version string.

    Reads /sys/kernel/debug/rkrga/driver_version and returns the text between
    the first and second ":" (unstripped). If the file contains no ":", the
    whole content is returned.

    If the file cannot be read, a hint is printed and the program terminates
    via ``sys.exit(-2)``.
    """
    try:
        with open("/sys/kernel/debug/rkrga/driver_version", "r") as f:
            rgaver = f.read()
    except (OSError, UnicodeError):
        print("Cannot read /sys/kernel/debug/rkrga/driver_version. Run with `sudo`")
        sys.exit(-2)
    fields = rgaver.split(":")
    # Guard against a file without a "label: value" layout.
    return fields[1] if len(fields) > 1 else rgaver


def parseload(txt):
    """Extract the NPU load percentages from the load file text.

    When *txt* contains "Core0:" every "Core<N>: <pct>%" entry is collected
    (multi-core NPU); otherwise every "NPU load: <pct>%" entry is collected
    (single-core NPU).

    Returns:
        A list of integer percentages in order of appearance.
    """
    if "Core0:" not in txt:
        # single core NPU
        return [int(value) for value in re.findall(r"NPU load:\s+(\d+)%", txt)]

    # multi core NPU: group 2 of each match is the percentage
    matches = re.finditer(r"Core(\d+):\s*(\d+)%", txt)
    return [int(match.group(2)) for match in matches]


def _kget_thread(q):
    """Thread body: push every key returned by getkey() onto queue *q*, forever."""
    while True:
        pressed = getkey()
        q.put(pressed)


def tempbars():
    """Collect the temperature sensors reported by psutil for plotting.

    Sensors whose name starts with "test", that have no readings, or whose
    first reading lacks a current or high value are left out.

    Returns:
        A tuple ``(labels, temps, tops)`` of parallel lists: the sensor name
        with "_thermal" removed, the floored current value and the floored
        high value of each sensor's first reading.
    """
    first_readings = [
        (name, readings[0])
        for name, readings in psutil.sensors_temperatures().items()
        if not name.startswith("test") and readings
    ]
    usable = [
        (name, reading)
        for name, reading in first_readings
        if reading.current is not None and reading.high is not None
    ]

    labels = [name.replace("_thermal", "") for name, _ in usable]
    temps = [math.floor(reading.current) for _, reading in usable]
    tops = [math.floor(reading.high) for _, reading in usable]
    return labels, temps, tops


def plot_npu_lines(plt, n_pts, samples):
    """Draw the NPU load history as one line per core.

    Args:
        plt: plotting target providing ``title``, ``ylim`` and ``plot``.
        n_pts: number of cores to draw.
        samples: list of per-core load lists, one entry per sample.
    """
    plt.title("NPU Load")
    plt.ylim(lower=0, upper=100)
    for core in range(n_pts):
        history = [sample[core] for sample in samples]
        plt.plot(history, label=f"Core {core}")


def plot_npu_bars(plt, n_pts, samples):
    """Draw the most recent NPU sample as one labelled bar per core.

    Args:
        plt: plotting target providing ``title``, ``ylim``, ``bar`` and ``text``.
        n_pts: unused; kept for a signature parallel to ``plot_npu_lines``.
        samples: list of per-core load lists; only the last one is drawn.
    """
    plt.title("NPU Load")
    plt.ylim(lower=0, upper=100)
    latest = samples[-1]
    names = [str(idx) for idx, _ in enumerate(latest)]
    plt.bar(names, latest, width=1 / 5)
    # Put the percentage just above each bar.
    for idx, pct in enumerate(latest):
        plt.text(f"{pct}%", x=idx + 1, y=pct + 1.5, alignment="center")

def plot_npu_rga_bars(plt, n_pts, samples, rgaload):
    """Draw the latest NPU sample and the RGA loads as horizontal bars.

    Args:
        plt: plotting target providing ``title``, ``stacked_bar`` and ``text``.
        n_pts: unused; kept for a signature parallel to the other plotters.
        samples: list of per-core NPU load lists; only the last one is drawn
            and it is not modified.
        rgaload: mapping of RGA scheduler name to load percentage.
    """
    plt.title("NPU/RGA Load")
    npu = samples[-1]
    labels = [f"NPU{i}" for i in range(len(npu))] + list(rgaload)
    # Build a new list: extending samples[-1] in place would leak the RGA
    # values into the caller's sample history.
    bars = list(npu) + list(rgaload.values())
    # NOTE(review): a constant 100 is stacked on every value, so each full bar
    # spans value + 100; confirm that this is the intended scale.
    tops = [100.0] * len(bars)
    plt.stacked_bar(labels, [bars, tops], width=1 / 5, orientation="h")
    for i, value in enumerate(bars):
        plt.text(f"{value}%", x=18, y=i + 1, alignment="center")


if __name__ == "__main__":
    # Script entry point: parse the options, read the static version
    # information once, then redraw the dashboard every second until the user
    # presses q, Q or Esc.
    parser = optparse.OptionParser("Show different NPU/CPU stats")
    parser.add_option(
        "-n",
        "--npu-only",
        dest="npuonly",
        default=False,
        help="Only show the NPU load",
        action="store_true",
    )
    parser.add_option(
        "-b",
        "--npu-bars",
        dest="npubars",
        default=False,
        action="store_true",
        help="Show the NPU with bars instead of lines",
    )
    parser.add_option(
        "-r",
        "--rga",
        dest="rgaload",
        default=False,
        action="store_true",
        help="Show the RGA load",
    )
    opts, _ = parser.parse_args()

    rkload = readload()
    if not rkload:
        print("Cannot read anything in /sys/kernel/debug/rknpu/load. Run with `sudo`")
        sys.exit(-2)

    rkver = readkver().strip()
    libver = readlibver().strip()
    rgaver = readrgaver().strip()
    # The number of cores is taken from the first reading.
    n_pts = len(parseload(rkload))

    MAX_SAMPLES = 100
    samples = []

    # Remember the terminal state before the key-reader thread starts. That
    # thread is a daemon normally blocked inside getkey() with the terminal in
    # cbreak mode, and its own cleanup does not run when the process exits,
    # so the settings are restored from here instead.
    try:
        stdin_fd = sys.stdin.fileno()
        saved_tty = termios.tcgetattr(stdin_fd)
    except (termios.error, OSError, ValueError):
        # stdin is not a terminal: there is nothing to restore later.
        stdin_fd = None
        saved_tty = None

    input_queue = queue.Queue()
    input_thread = threading.Thread(target=_kget_thread, args=(input_queue,))
    input_thread.daemon = True
    input_thread.start()

    try:
        while True:
            loads = parseload(readload())
            samples.append(loads)
            # Keep a sliding window of the most recent samples.
            if len(samples) > MAX_SAMPLES:
                samples.pop(0)
            plt.clf()

            if opts.npuonly:
                if opts.npubars:
                    plot_npu_bars(plt, n_pts, samples)
                else:
                    plot_npu_lines(plt, n_pts, samples)
            else:
                plt.subplots(2, 2)
                topper = plt.subplot(1, 1)
                if opts.rgaload:
                    _, rgal = readrgaload()
                    plot_npu_rga_bars(topper, n_pts, samples, rgal)
                elif opts.npubars:
                    plot_npu_bars(topper, n_pts, samples)
                else:
                    plot_npu_lines(topper, n_pts, samples)

                # CPU Cores
                cpus = psutil.cpu_percent(percpu=True)
                plt.subplot(1, 2).title("CPU Load per core")
                plt.subplot(1, 2).ylim(lower=0, upper=100)
                plt.subplot(1, 2).bar(
                    [str(i) for i in range(len(cpus))], cpus, width=1 / 5
                )
                for i, pct in enumerate(cpus):
                    plt.subplot(1, 2).text(
                        f"{pct}%", x=i + 1, y=pct + 1.5, alignment="center"
                    )

                # Thermals
                labels, temps, tops = tempbars()
                plt.subplot(2, 1).title("Thermals")
                plt.subplot(2, 1).stacked_bar(
                    labels, [temps, tops], orientation="h", width=1 / 5
                )
                for i, temp in enumerate(temps):
                    plt.subplot(2, 1).text(f"{temp}ºC", x=18, y=i + 1)

                # The bottom-right cell is split again: memory above, versions below.
                plt.subplot(2, 2).subplots(2, 1)

                # Memory
                mem = psutil.virtual_memory().percent
                swp = psutil.swap_memory().percent
                plt.subplot(2, 2).subplot(1, 1).title("Memory")
                plt.subplot(2, 2).subplot(1, 1).stacked_bar(
                    ["Memory", "Swap"],
                    [[mem, swp], [100 - mem, 100 - swp]],
                    orientation="h",
                    width=1 / 5,
                )
                # Versions
                plt.subplot(2, 2).subplot(2, 1).indicator(
                    f"librknnrt: {libver}\nKmod: {rkver} RGA: {rgaver}"
                )

            plt.show()
            plt.sleep(1)

            # Handle at most one pending key press per refresh.
            if not input_queue.empty():
                k = input_queue.get()
                if k in ("Q", "q", "esc"):
                    sys.exit(0)
    finally:
        # Runs on quit (SystemExit), Ctrl-C and any other error.
        if saved_tty is not None:
            termios.tcsetattr(stdin_fd, termios.TCSADRAIN, saved_tty)
