#!/usr/bin/python3
import argparse
import subprocess
import os
import io
import sys
import datetime
import time
import csv
import numpy
            
def parseArgs():
    """Parse the command line of the process performance logger.

    Reads the arguments from ``sys.argv`` (argparse default).

    Returns:
        argparse.Namespace with the attributes ``pname`` (str),
        ``duration`` (int, seconds), ``output`` (str or None, accepted but
        ignored) and ``retry`` (bool, False unless ``--retry`` is given).
    """
    cli = argparse.ArgumentParser(
        description='Log process performance (CPU, memory usage)'
    )

    # Mandatory positional arguments: (name, converter, help text).
    positionals = (
        ('pname', str, 'process name to get info from'),
        ('duration', int, 'time(seconds) to keep the metrics being recorded'),
    )
    for dest, converter, help_text in positionals:
        cli.add_argument(dest, type=converter, help=help_text)

    # Optional switches. --output is only kept so old invocations still parse.
    cli.add_argument(
        "--output",
        type=str,
        help='IGNORED (no output file used any more)',
    )
    cli.add_argument(
        "--retry",
        default=False,
        action="store_true",
        help='if the pid of the process is not found on startup, keep trying',
    )

    return cli.parse_args()

class SystemPerformanceMon(object):
    """Sample CPU/memory usage of one process with ``pidstat`` and check it for drift.

    Typical use: call :meth:`startMonitoring` (blocks while ``pidstat`` runs),
    then :meth:`processData` to analyse the captured output.

    Both methods report their outcome as an integer status:

    * ``0``  - success
    * ``2``  - no matching process was found (immediately, or after retrying)
    * ``3``  - the stat tool could not be executed or failed
    * ``10`` - there is no usable monitoring data to analyse
    * ``11`` - computing the per-window deviations failed
    * ``12`` - at least one window deviates more than ``MAX_DEVIATION``
    """

    # Seconds between two pidstat samples; used for sampling and for analysis.
    SAMPLE_INTERVAL_SECONDS = 5
    # With retry enabled, stop looking for the process after this many minutes.
    MAX_RETRY_MINUTES = 5
    # Length (seconds) of one analysis window: 600 s / 5 s = 120 samples.
    DEVIATION_WINDOW_SECONDS = 600
    # Threshold for abs(window_average / overall_average - 1).
    MAX_DEVIATION = 2
    # At most this many deviation lines are printed; the rest is summarised.
    MAX_REPORTED_DEVIATIONS = 100

    def __init__(self, processName, duration, retry=False):
        """Store the monitoring parameters.

        Args:
            processName: pattern used to find the process in ``ps aux``.
            duration: monitoring time in seconds.
            retry: keep looking for the process if it is not running yet.
        """
        self.processName = processName
        # duration is in seconds
        self.duration = duration
        self.retry = retry
        self.stat_tool_output = None  # replacement for outputfile

    def findProcessesByProcessName(self):
        """Return the PIDs of all processes whose ``ps aux`` line matches the name.

        The match is done by ``grep -i`` on the complete ``ps aux`` line, so
        the process name is a case-insensitive grep pattern. No shell is
        involved, so the name cannot inject commands, and ``grep`` only runs
        after ``ps`` has finished and therefore never lists itself. This
        script's own process (whose command line contains the name) and the
        ``ps`` header line are filtered out.

        Returns:
            str: the matching PIDs, one per line, in ``ps`` output order;
            an empty string if nothing matched or ``ps``/``grep`` could not
            be started.
        """
        try:
            listing = subprocess.run(["ps", "aux"], stdout=subprocess.PIPE)
            matches = subprocess.run(
                ["grep", "-i", "--", self.processName],
                input=listing.stdout,
                stdout=subprocess.PIPE,
            )
        except OSError as exc:
            print("Failed to list processes:", exc)
            return ""

        own_pid = str(os.getpid())
        pids = []
        # "replace": odd bytes in a command line must not abort the PID lookup.
        for entry in matches.stdout.decode("utf-8", "replace").splitlines():
            fields = entry.split()
            # The second column of `ps aux` is the PID; the header line has
            # the literal text "PID" there and is skipped by the digit check.
            if len(fields) < 2 or not fields[1].isdigit():
                continue
            if fields[1] == own_pid:
                continue
            pids.append(fields[1])
        return "\n".join(pids)

    def startMonitoring(self):
        """Find the process and record its ``pidstat`` output.

        Blocks until ``pidstat`` has finished. On success the decoded tool
        output is stored in ``self.stat_tool_output``.

        Returns:
            int: ``0`` on success, ``2`` if the process was not found,
            ``3`` if the stat tool could not be run or reported an error.
        """
        started = time.monotonic()
        while True:
            pidStr = self.findProcessesByProcessName()
            if pidStr != "":
                # pid was found, stop searching
                break
            if not self.retry:
                print(f"Pid was not found for process name: {self.processName}")
                return 2
            # we should keep retrying until we find the process or time out
            minutes = (time.monotonic() - started) / 60
            if minutes > self.MAX_RETRY_MINUTES:
                print(f"Pid was not found for process name: {self.processName} for {minutes} minutes")
                # Must be a non-zero int: False would compare equal to 0 and
                # be mistaken for success by the caller.
                return 2
            time.sleep(1)

        # If several processes match, the first one listed by ps is used.
        # The original intent is to pick the oldest one (smallest pid);
        # NOTE(review): this relies on ps listing processes in pid order.
        pid = int(pidStr.splitlines()[0])

        interv = self.SAMPLE_INTERVAL_SECONDS
        repeat = int(self.duration / interv)
        # pidstat options: u=CPU, r=MEM, h=one-line, H=no-whitespace-in-timestamp
        cmdCall = ["pidstat", "-hurH", str(interv), str(repeat), "-p", str(pid)]

        try:
            result = subprocess.run(cmdCall, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        except OSError as exc:
            # e.g. pidstat is not installed
            print("Failed to execute stat-tool", exc)
            return 3
        if result.returncode != 0:
            print("Failed to execute stat-tool", result)
            return 3

        self.stat_tool_output = self.outputToString(result)

        return 0

    def outputToString(self, subprocessOutput):
        """Return the captured stdout of a finished subprocess as stripped UTF-8 text."""
        return subprocessOutput.stdout.decode("utf-8", "strict").strip()

    @staticmethod
    def _windowDeviations(values, window_size, avg_over_all):
        """Return abs(window_average / avg_over_all - 1) for each window of values.

        ``values`` is cut into consecutive chunks of ``window_size`` items
        (the last chunk may be shorter).
        """
        avg_over_all = max(avg_over_all, 0.01)  # prevent division by zero

        deviations = []
        for start in range(0, len(values), window_size):
            window = values[start:start + window_size]
            window_avg = sum(window) / len(window)
            deviations.append(abs(window_avg / avg_over_all - 1))  # metrics by DKU
        return deviations

    def processData(self):
        """Analyse the recorded ``pidstat`` output for CPU/memory deviations.

        The %CPU and %MEM columns are averaged over all samples. Then the
        samples, without the first and the last one, are cut into windows of
        ``DEVIATION_WINDOW_SECONDS`` and each window average is compared to
        the overall average. Windows exceeding ``MAX_DEVIATION`` for either
        metric are printed (at most ``MAX_REPORTED_DEVIATIONS`` lines).

        Returns:
            int: ``0`` if no window deviates, ``10`` if there is no usable
            data, ``11`` if the deviation computation failed, ``12`` if
            deviations were found.
        """
        output = self.stat_tool_output
        if not output:
            print("No monitoring data available to process")
            return 10

        # example output:
        #
        # Linux 5.17.12-100.fc34.x86_64 (eltmal24)        01/11/2024      _x86_64_        (16 CPU)
        #
        # # Time        UID       PID    %usr %system  %guest   %wait    %CPU   CPU  minflt/s  majflt/s     VSZ     RSS   %MEM  Command
        # 1704977039   6173      5548    0.00    0.00    0.00    0.00    0.00    15      0.00      0.00 1173200    3400   0.01  redis-server
        #
        # # Time        UID       PID    %usr %system  %guest   %wait    %CPU   CPU  minflt/s  majflt/s     VSZ     RSS   %MEM  Command
        # 1704977041   6173      5548    0.00    0.50    0.00    0.00    0.50    15      0.00      0.00 1173200    3400   0.01  redis-server

        # Split output into single lines and collapse runs of whitespace so
        # the text can be read as space-delimited csv.
        lines = [" ".join(raw.split()) for raw in output.split('\n')]

        # The first "#" line is the column header (it is repeated before
        # every sample); everything else that is non-empty and not a "#"
        # line after it is a sample row.
        header_index = next(
            (idx for idx, text in enumerate(lines) if text.startswith("#")), None
        )
        if header_index is None:
            print("No column header found in the monitoring data")
            return 10

        data = [lines[header_index][1:].strip()]  # csv header without the "#"
        data.extend(
            text for text in lines[header_index + 1:]
            if text != "" and not text.startswith("#")
        )

        # now we have:
        # ['Time UID PID %usr %system %guest %wait %CPU CPU minflt/s majflt/s VSZ RSS %MEM Command',
        #  '1704977827 6173 5548 0.00 0.00 0.00 0.00 0.00 1 0.00 0.00 1173200 3400 0.01 redis-server',
        #  '1704977829 6173 5548 0.00 0.00 0.00 0.00 0.00 1 0.00 0.00 1173200 3400 0.01 redis-server']

        # read as csv, extract interesting columns
        results = list(csv.DictReader(data, delimiter=" "))
        if not results:
            print("No samples found in the monitoring data")
            return 10
        try:
            cpu_values = numpy.array([float(r["%CPU"]) for r in results])
            mem_values = numpy.array([float(r["%MEM"]) for r in results])
        except (KeyError, TypeError, ValueError) as exc:
            print("Unexpected format of the monitoring data:", exc)
            return 10

        # 1. calculate average

        cpuAverage = numpy.average(cpu_values)
        memAverage = numpy.average(mem_values)

        # 2. calculate deviations

        # Not sure how reasonable this metric is, it reminds of std-dev.
        # I would have gone for trend (numpy.polyfit) and variance.
        # For now, just ported the algorithm from SQL to Python.

        # we measured every 5 secs, so 120 measurements per 10 minutes
        window_size = int(self.DEVIATION_WINDOW_SECONDS / self.SAMPLE_INTERVAL_SECONDS)

        try:
            # ignore first and last measurement (from dku)
            cpu_met = self._windowDeviations(cpu_values[1:-1], window_size, cpuAverage)
            mem_met = self._windowDeviations(mem_values[1:-1], window_size, memAverage)
        except Exception as exc:
            # Deliberately broad: any failure here is reported as status 11.
            print("Error processing deviations:", exc)
            return 11

        maxDeviation = self.MAX_DEVIATION
        deviations = []

        for number, (cpu, memory) in enumerate(zip(cpu_met, mem_met), start=1):
            if cpu > maxDeviation or memory > maxDeviation:
                deviations.append(
                    f"window {number}: cpu deviation {cpu:.2f}, memory deviation {memory:.2f}"
                )

        if deviations:
            # NOTE(review): the threshold is compared against a ratio
            # (2 == 200 % off the overall average), while the message says
            # "%"; message kept unchanged, confirm which one is intended.
            print(f"Deviations going over {maxDeviation}% in either memory or cpu:")
            limit = self.MAX_REPORTED_DEVIATIONS
            for entry in deviations[:limit]:
                print(entry)
            more = len(deviations) - limit
            if more > 0:
                print(f"... and {more} more deviations")
            return 12

        return 0


if __name__ == '__main__':
    # Script entry point: parse the command line, record the process metrics,
    # analyse them and exit with the resulting status code
    # (0 = success, anything else = failure).
    args = parseArgs()
    sysMon = SystemPerformanceMon(args.pname, args.duration, args.retry)
    try:
        exitCode = sysMon.startMonitoring()
        # Guard against a boolean False result: False == 0 would otherwise be
        # read as success, and sys.exit(False) would exit with status 0.
        if exitCode is False or exitCode != 0:
            print("Failed to start monitoring")
            sys.exit(exitCode or 1)

        exitCode = sysMon.processData()
        if exitCode != 0:
            print("Something went wrong when processing the monitoring data")
        sys.exit(exitCode)
    except KeyboardInterrupt:
        print("KeyboardInterrupt")
        # An aborted run must not look like a successful one;
        # 130 is the conventional status for termination by Ctrl-C (128 + SIGINT).
        sys.exit(130)
