#!/usr/bin/python3

import configparser
import datetime
import hashlib
import json
import logging
import os
import queue
import random
import shlex
import shutil
import signal
import socket
import string
import subprocess
import sys
import tempfile
import threading
import time
from http.server import HTTPServer, BaseHTTPRequestHandler
from socketserver import ThreadingMixIn
from xml.etree import ElementTree
# Module-level alias for the process environment mapping. It is not referenced
# anywhere else in this file.
my_env = os.environ

def resource_to_filename(url):
    # create something like https_github_com_SUSE_doc_sle from https://github.com/SUSE/doc-sle
    replace = "/\\-.,:;#+`´{}()[]!\"§$%&"
    for char in replace:
        url = str(url).replace(char, '_')
    return url

class RepoLock:
    """Handle for one named lock inside a shared registry of locks.

    The registry (``gitLocks``) maps a resource name to a ``threading.Lock``.
    Each RepoLock remembers whether *it* currently holds that lock, so
    release() is a no-op unless this handle acquired the lock before.
    """

    def __init__(self, repo_dir, thread_id, gitLocks, gitLocksLock):
        """Register a lock for *repo_dir* in *gitLocks* if there is none yet.

        repo_dir     -- repository name; converted with resource_to_filename()
        thread_id    -- number of the owning thread, only used for log output
        gitLocks     -- shared dict: resource name -> threading.Lock
        gitLocksLock -- lock that protects modifications of *gitLocks*
        """
        self.gitLocks = gitLocks
        self.gitLocksLock = gitLocksLock
        self.resource_name = resource_to_filename(repo_dir)
        self.acquired = False
        self.thread_id = thread_id
        # The with-statement releases the registry lock even if the block raises.
        with self.gitLocksLock:
            if self.resource_name not in self.gitLocks:
                self.gitLocks[self.resource_name] = threading.Lock()

    def acquire(self, blocking=True):
        """Acquire the resource lock; return True on success, False otherwise.

        With ``blocking=False`` the call returns False immediately if the lock
        is held elsewhere.
        """
        if self.gitLocks[self.resource_name].acquire(blocking):
            self.acquired = True
            logger.debug("Thread %i: Acquired lock %s." % (self.thread_id, self.resource_name))
            return True
        return False

    def release(self):
        """Release the resource lock if this handle holds it; otherwise do nothing."""
        if self.acquired:
            self.gitLocks[self.resource_name].release()
            self.acquired = False

class Deliverable:
    """One build job: a single DC file of a working copy, built in one format.

    run() calls /usr/bin/d2d_runner to build into a temporary directory and
    then calls rsync twice to copy that directory to the backup path and to
    the target path of the configured target.
    """

    def __init__(self, local_repo_build_dir, build_instruction, dc_file, build_format, target):
        """Store the parameters of the build.

        local_repo_build_dir -- working copy directory passed to d2d_runner (-i)
        build_instruction    -- dict; the keys 'lang', 'product' and 'docset' are read
        dc_file              -- DC file name, last argument of d2d_runner
        build_format         -- output format passed to d2d_runner (-f)
        target               -- target configuration dict; 'remarks', 'draft',
                                'backup_path' and 'target_path' are read
        """
        self.local_repo_build_dir = local_repo_build_dir
        self.language = build_instruction['lang']
        self.dc_file = dc_file
        self.build_format = build_format
        self.target = target
        self.product_id = build_instruction['product']
        self.docset_id = build_instruction['docset']
        logger.debug("Created deliverable: %s %s %s %s %s" % (local_repo_build_dir, self.language, dc_file, build_format, target['target_path']))

    @staticmethod
    def _write_params_file(content):
        """Write *content* plus a newline to a new temporary file and return its path."""
        handle, path = tempfile.mkstemp(text=True)
        # fdopen takes over the descriptor, so closing the file object closes it too.
        with os.fdopen(handle, "w") as params_file:
            params_file.write(content + "\n")
        return path

    def _run_command(self, cmd, thread_id, expected_return=0):
        """Run the argument list *cmd*; return True if it exits with *expected_return*."""
        printable = " ".join(shlex.quote(arg) for arg in cmd)
        logger.debug("Thread %i: %s" % (thread_id, printable))
        # NOTE(review): the one second pause before every command is kept from
        # the previous implementation; its purpose is not documented.
        time.sleep(1)
        try:
            s = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except OSError as error:
            logger.warning("Build failed! Could not start '%s': %s" % (printable, error))
            return False
        stderr = s.communicate()[1]
        if s.returncode != expected_return:
            logger.warning("Build failed! Unexpected return value %i for '%s'" % (s.returncode, printable))
            if stderr:
                logger.debug("Thread %i: stderr: %s" % (thread_id, stderr.decode('utf-8', 'replace').rstrip()))
            return False
        return True

    def run(self, thread_id):
        """Build the deliverable and copy the result to backup and target path.

        Returns True if every command succeeded and False as soon as one
        fails. The temporary parameter files and the temporary build
        directory are removed in both cases.
        """
        logger.info("Thread %i: Building deliverable" % thread_id)

        xslt_params = ""
        daps_flags = []
        if self.target['remarks'] == "yes":
            daps_flags.append("--remarks")
        if self.target['draft'] == "yes":
            daps_flags.append("--draft")
        daps_params = " ".join(daps_flags)

        # Both paths are appended to the configured prefix without a separator,
        # so 'backup_path' and 'target_path' have to end with one.
        sub_path = os.path.join(self.language, self.product_id, self.docset_id, self.build_format)
        backup_target = self.target['backup_path'] + sub_path
        deploy_target = self.target['target_path'] + sub_path

        xslt_params_file = None
        daps_params_file = None
        tmp_build_target = None
        try:
            # The files are written directly: a '>' redirect only works inside a shell.
            xslt_params_file = self._write_params_file(xslt_params)
            daps_params_file = self._write_params_file(daps_params)
            tmp_build_target = tempfile.mkdtemp()

            commands = [
                # Run daps in the docker container, results go to the temporary directory
                ["/usr/bin/d2d_runner", "-v=0",
                 "-x", xslt_params_file,
                 "-d", daps_params_file,
                 "-o", tmp_build_target,
                 "-i", self.local_repo_build_dir,
                 "-f", self.build_format,
                 self.dc_file],
                # NOTE(review): rsync is called without -r/-a as before, which
                # skips directories - confirm the intended options.
                ["rsync", tmp_build_target, backup_target],
                ["rsync", tmp_build_target, deploy_target],
            ]
            for cmd in commands:
                if not self._run_command(cmd, thread_id):
                    return False
            return True
        finally:
            for path in (daps_params_file, xslt_params_file):
                if path is None:
                    continue
                try:
                    os.remove(path)
                except OSError as error:
                    logger.warning("Thread %i: Could not remove %s: %s" % (thread_id, path, error))
            if tmp_build_target is not None:
                shutil.rmtree(tmp_build_target, ignore_errors=True)

class DocConfParser:
    """Prepare one build instruction and turn it into Deliverable objects.

    Construction validates the build instruction, stitches the XML
    configuration of the target, clones the configured branch into a new build
    directory and compares its newest commit with the commit stored in the
    instruction. ``initialized`` is True only if all of these steps succeeded
    and the commit differs, i.e. there is something to build.
    """

    def __init__(self, build_instruction, config, gitLocks, gitLocksLock, thread_id):
        """Run all preparation steps; stop at the first one that fails.

        build_instruction -- dict with the string values 'docset', 'lang',
                             'product' and 'target' (plus 'commit')
        config            -- configuration dict with 'server' and 'targets'
        gitLocks, gitLocksLock -- shared lock registry handed to RepoLock
        thread_id         -- number of the calling thread, used for log output
        """
        self.deliverables = queue.Queue()
        self.thread_id = thread_id
        # Defaults, so that __str__, __dict__ and __del__ also work when the
        # preparation is aborted early.
        self.build_instruction = build_instruction
        self.config = config
        self.tree = None
        self.local_repo_build_dir = None
        self.initialized = False
        if not self.validate(build_instruction, config):
            return
        self.initialized = True
        if not self.read_conf_dir():
            self.initialized = False
            return
        self.git_lock = RepoLock(resource_to_filename(self.remote_repo), thread_id, gitLocks, gitLocksLock)
        if not self.prepare_repo():
            self.initialized = False
            return
        self.compare_commit_id()

    def __del__(self):
        """Remove the build directory, if one was created."""
        build_dir = getattr(self, 'local_repo_build_dir', None)
        if not build_dir:
            return
        logger.debug("Cleaning up %s" % json.dumps(self.build_instruction))
        shutil.rmtree(build_dir, ignore_errors=True)

    def __str__(self):
        return json.dumps(self.build_instruction)

    # NOTE(review): this shadows the usual ``__dict__`` attribute with a method;
    # it has to be called as ``doc.__dict__()``.
    def __dict__(self):
        return self.build_instruction

    def __getitem__(self, arg):
        # NOTE(review): prints the key and returns the whole build instruction
        # regardless of *arg* - looks like debugging code, kept unchanged.
        print(arg)
        return self.build_instruction

    def _run_process(self, args):
        """Run the argument list *args*; return (return code, stdout bytes).

        The return code is None if the program could not be started.
        """
        try:
            s = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except OSError as error:
            logger.warning("Could not execute '%s': %s" % (args[0], error))
            return None, b""
        output = s.communicate()[0]
        return s.returncode, output

    def read_conf_dir(self):
        """Stitch the XML configuration and read branch and git remote from it.

        Sets ``tree``, ``branch`` and ``remote_repo``. Returns True on success;
        on any problem ``initialized`` is set to False and False is returned.
        """
        target = self.build_instruction['target']
        try:
            target_config = self.config['targets'][target]
        except KeyError:
            logger.warning("Target %s is not configured." % target)
            self.initialized = False
            return False
        if not target_config['active'] == "yes":
            logger.debug("Target %s not active." % target)
            self.initialized = False
            return False

        config_dir = target_config['config_dir']
        # docserv-stitch writes the content of the config directory into one XML file
        tmp_handle, tmp_filename = tempfile.mkstemp(text=True)
        os.close(tmp_handle)  # only the name is needed
        try:
            logger.info("Stitching XML config directory to %s" % tmp_filename)
            cmd = ['/usr/bin/docserv-stitch', config_dir, tmp_filename]
            logger.debug("Stitching command: %s" % " ".join(cmd))
            rc, _ = self._run_process(cmd)
            if rc != 0:
                logger.warning("Stitching of %s failed!" % config_dir)
                self.initialized = False
                return False
            logger.debug("Stitching of %s successful" % config_dir)
            # then read the stitched file into an xml tree
            try:
                self.tree = ElementTree.parse(tmp_filename)
            except ElementTree.ParseError as error:
                logger.warning("Stitched configuration of %s is not valid XML: %s" % (config_dir, error))
                self.initialized = False
                return False
        finally:
            try:
                os.remove(tmp_filename)
            except OSError as error:
                logger.debug("Could not remove %s: %s" % (tmp_filename, error))

        xml_root = self.tree.getroot()
        builddocs_xpath = ".//productindex[@id='%s']/docset[@id='%s']/builddocs" % (
            self.build_instruction['product'], self.build_instruction['docset'])
        try:
            branch = xml_root.find("%s/language[@code='%s']/branch" % (builddocs_xpath, self.build_instruction['lang']))
            remote = xml_root.find("%s/git/remote" % builddocs_xpath)
        except SyntaxError:
            # ElementTree raises SyntaxError for a path it cannot parse, e.g.
            # when an id contains a quote character
            branch = remote = None
        if branch is None or remote is None or not branch.text or not remote.text:
            logger.warning("No branch or git remote found in the configuration for %s" % json.dumps(self.build_instruction))
            self.tree = None
            self.initialized = False
            return False
        self.branch = branch.text
        self.remote_repo = remote.text
        return True

    def prepare_repo(self):
        """Update the cached clone of the remote and clone the branch into a build directory.

        Returns True on success. On failure ``initialized`` is set to False
        and False is returned.
        """
        local_repo_cache_dir = os.path.join(self.config['server']['repo_dir'], resource_to_filename(self.remote_repo))
        local_repo_build_dir = os.path.join(self.config['server']['temp_repo_dir'], ''.join(random.choices(string.ascii_uppercase + string.digits, k=12)))
        # Set right away, so that __del__ also removes a partially created clone.
        self.local_repo_build_dir = local_repo_build_dir

        # (arguments, expected return value or None to ignore it, repository lock needed)
        commands = [
            # Clone locally cached repository; fails if it exists, which is ignored
            (["git", "clone", self.remote_repo, local_repo_cache_dir], None, True),
            # Update locally cached repo
            (["git", "-C", local_repo_cache_dir, "pull", "--all"], 0, True),
            # Create local copy of the branch in the temp build dir
            (["git", "clone", "--single-branch", "--branch", self.branch, local_repo_cache_dir, local_repo_build_dir], 0, False),
        ]

        for args, expected, needs_lock in commands:
            printable = " ".join(args)
            if needs_lock:
                self.git_lock.acquire()
            try:
                logger.debug("Thread %i: %s" % (self.thread_id, printable))
                returncode, _ = self._run_process(args)
            finally:
                # release() does nothing if the lock was not acquired above
                self.git_lock.release()
            if returncode is None:
                logger.warning("Build failed! Could not run '%s'" % printable)
                self.initialized = False
                return False
            if expected is not None and expected != returncode:
                logger.warning("Build failed! Unexpected return value %i for '%s'" % (returncode, printable))
                self.initialized = False
                return False

        # a working copy containing the required branch has been created in local_repo_build_dir
        return True

    def compare_commit_id(self):
        """Compare the newest commit of the working copy with the build instruction.

        If both are equal, nothing has to be built and ``initialized`` is set
        to False. Otherwise the new commit id is stored in the instruction.
        """
        returncode, output = self._run_process(["git", "-C", self.local_repo_build_dir, "log", "--format=%H", "-n", "1"])
        current_commit_id = output.decode('utf-8').rstrip()
        if returncode != 0 or not current_commit_id:
            # An empty id would otherwise compare equal to the empty commit of
            # an instruction that has never been built.
            logger.warning("Could not determine current commit for %s" % json.dumps(self.build_instruction))
            self.initialized = False
            return
        logger.debug("Current commit hash: %s" % current_commit_id)
        if current_commit_id == self.build_instruction.get('commit'):
            logger.info("Build instruction %s is already up to date. Current commit: %s" % (json.dumps(self.build_instruction), current_commit_id))
            self.initialized = False
        else:
            self.build_instruction['commit'] = current_commit_id

    def validate(self, build_instruction, config):
        """Return True if *build_instruction* is a dict with the required string values.

        Expected shape:
        {'docset': '15ga', 'lang': 'en-us', 'product': 'sles', 'target': 'external'}
        *config* is not used by the checks.
        """
        if not isinstance(build_instruction, dict):
            logger.warning("Validation: Is not a dict")
            return False
        for key in ('docset', 'lang', 'product', 'target'):
            # .get() makes a missing key fail the validation instead of raising KeyError
            if not isinstance(build_instruction.get(key), str):
                logger.warning("Validation: %s is not a string" % key)
                return False
        logger.debug("Valid build instruction: %s" % json.dumps(build_instruction))
        return True

    def generate_deliverables(self):
        """Fill ``deliverables`` with one Deliverable per DC file and format.

        Returns False if the preparation did not succeed (or the XML tree has
        already been consumed), True otherwise.
        """
        if not self.initialized or self.tree is None:
            return False
        target = self.build_instruction['target']
        xml_root = self.tree.getroot()
        for xml_deliverable in xml_root.findall(".//productindex[@id='%s']/docset[@id='%s']/builddocs/language[@code='%s']/deliverable" % (self.build_instruction['product'], self.build_instruction['docset'], self.build_instruction['lang'])):
            dc_node = xml_deliverable.find(".//dc")
            if dc_node is None:
                logger.warning("Ignoring deliverable without dc element in %s" % json.dumps(self.build_instruction))
                continue
            for build_format in xml_deliverable.findall(".//format"):
                deliverable = Deliverable(self.local_repo_build_dir,
                                          self.build_instruction,
                                          dc_node.text,
                                          build_format.text,
                                          self.config['targets'][target])
                self.deliverables.put(deliverable)
        # after all deliverables are generated, we don't need the xml tree anymore
        self.tree = None
        return True

class DocservState:
    """Bookkeeping of queued, running and finished build instructions.

    All containers are class attributes and therefore shared by all instances.
    save_state() and load_state() read ``self.config['server']['name']``,
    which this class does not set itself.
    """
    # id -> build instruction, for everything that is queued or being built
    in_queues = {}
    in_queues_lock = threading.Lock()
    # build instructions waiting to be picked up
    scheduled_docs = queue.Queue()
    # id -> build instruction, for everything that was picked up
    building_docs = {}
    building_docs_lock = threading.Lock()
    # id -> build instruction of finished builds
    past_builds = {}
    past_builds_lock = threading.Lock()
    # objects with a ``deliverables`` queue that still have deliverables to build
    currently_building = queue.Queue()
    currently_building_lock = threading.Lock()

    def __del__(self):
        pass

    def __str__(self):
        return json.dumps(self.__dict__())

    # NOTE(review): this shadows the usual ``__dict__`` attribute with a method;
    # it has to be called as ``state.__dict__()``.
    def __dict__(self):
        return {
            'building': self.in_queues,
            'finished': self.past_builds
        }

    def generate_id(self, build_instruction):
        """Return a 9 character id derived from target, docset, lang and product."""
        return hashlib.md5((build_instruction['target'] +
                            build_instruction['docset'] +
                            build_instruction['lang'] +
                            build_instruction['product']).encode('utf-8')
                          ).hexdigest()[:9]

    def add_build_instruction(self, build_instruction):
        """Queue *build_instruction* unless one with the same id is already queued.

        Adds 'id' and 'commit' to the dict; 'commit' is taken from the last
        finished build with that id, or "" if there is none. Returns True if
        the instruction was queued, False if it is already queued or building.
        """
        build_id = self.generate_id(build_instruction)
        build_instruction['id'] = build_id
        with self.past_builds_lock:
            previous = self.past_builds.get(build_id)
        if isinstance(previous, dict):
            build_instruction['commit'] = previous.get('commit', "")
        else:
            build_instruction['commit'] = ""
        with self.in_queues_lock:
            if build_id in self.in_queues:
                # Already building or queued for building
                return False
            self.in_queues[build_id] = build_instruction
        self.scheduled_docs.put(build_instruction)
        return True

    def get_scheduled_build_instruction(self):
        """Return the next scheduled build instruction, or None if there is none."""
        try:
            build_instruction = self.scheduled_docs.get(False)
        except queue.Empty:
            return None
        with self.building_docs_lock:
            self.building_docs[build_instruction['id']] = build_instruction
        return build_instruction

    def get_deliverable(self):
        """Return the next deliverable of a doc that is being built, or None.

        None is returned if another thread is currently inside this method, if
        no doc is being built, or if the doc taken from the queue has no
        deliverables left; in the last case the doc is marked as finished.
        """
        if not self.currently_building_lock.acquire(False):
            return None
        try:
            try:
                doc = self.currently_building.get(False)
            except queue.Empty:
                return None
            try:
                deliverable = doc.deliverables.get(False)
            except queue.Empty:
                logger.info("Finished building deliverables for %s" % doc)
            else:
                logger.info("Got deliverable from %s" % doc)
                self.currently_building.put(doc)
                return deliverable
        finally:
            self.currently_building_lock.release()
        # only reached when the doc has no deliverables left
        self.finished_building_doc(doc)
        return None

    def finished_building_doc(self, doc):
        """Move the build instruction of *doc* from the running to the finished builds."""
        record = doc.__dict__()
        build_id = record['id']
        with self.building_docs_lock:
            self.building_docs.pop(build_id, None)
        with self.in_queues_lock:
            self.in_queues.pop(build_id, None)
        with self.past_builds_lock:
            self.past_builds[build_id] = record

    def save_state(self):
        """Write the state as JSON to /etc/docserv/<name>.json; return True on success."""
        filepath = "/etc/docserv/%s.json" % self.config['server']['name']
        try:
            # the with-statement makes sure the data is flushed and the file closed
            with open(filepath, "w") as state_file:
                state_file.write(self.__str__())
        except OSError as error:
            logger.warning("Could not save state to %s: %s" % (filepath, error))
            return False
        return True

    def load_state(self):
        """Restore finished builds and re-queue unfinished ones from the state file.

        Returns True if a usable state file was read, False otherwise.
        """
        logger.info("Reading previous state.")
        filepath = "/etc/docserv/%s.json" % self.config['server']['name']
        if not os.path.isfile(filepath):
            return False
        try:
            with open(filepath, "r") as state_file:
                state = json.load(state_file)
        except (OSError, ValueError) as error:
            # ValueError covers invalid JSON
            logger.warning("Could not read previous state from %s: %s" % (filepath, error))
            return False
        if not isinstance(state, dict):
            return False
        finished = state.get('finished')
        building = state.get('building')
        if not isinstance(finished, dict) or not isinstance(building, dict):
            return False
        with self.past_builds_lock:
            self.past_builds = finished
        for build_instruction in building.values():
            try:
                self.add_build_instruction(build_instruction)
            except (KeyError, TypeError):
                logger.warning("Ignoring malformed build instruction in %s" % filepath)
        return True

class Docserv(DocservState):
    """The server: reads the configuration, starts the REST listener and the build threads."""
    # shared registry of repository locks and the lock that protects it
    gitLocks = {}
    gitLocksLock = threading.Lock()
    # any item on this queue tells the worker threads to end
    end_end_all = queue.Queue()

    def __init__(self, argv):
        """Read the configuration named by *argv* and restore the previous state."""
        self.parse_config(argv)
        self.load_state()

    def parse_config(self, argv):
        """Read /etc/docserv/<name>.ini into ``self.config``; exit if it is invalid.

        <name> is ``argv[1]`` or "docserv" if no argument was given.
        """
        config = configparser.ConfigParser()
        config_file = "docserv" if len(argv) < 2 else argv[1]
        logger.info("Reading /etc/docserv/%s.ini" % config_file)
        config.read("/etc/docserv/%s.ini" % config_file)
        LOGLEVELS = {0: logging.WARNING,
                     1: logging.INFO,
                     2: logging.DEBUG,
        }
        target_keys = ('template_dir', 'active', 'draft', 'remarks',
                       'beta_warning', 'target_path', 'backup_path', 'config_dir')
        try:
            server = config['server']
            loglevel = int(server['loglevel'])
            if loglevel not in LOGLEVELS:
                raise ValueError("loglevel must be one of %s" % sorted(LOGLEVELS))
            logger.setLevel(LOGLEVELS[loglevel])
            self.config = {
                'server': {
                    'name': config_file,
                    'loglevel': loglevel,
                    'host': server['host'],
                    'port': int(server['port']),
                    'repo_dir': server['repo_dir'],
                    'temp_repo_dir': server['temp_repo_dir'],
                    'max_threads': int(server['max_threads']),
                },
                'targets': {},
            }
            for section in config.sections():
                if not str(section).startswith("target_"):
                    continue
                target = {key: config[section][key] for key in target_keys}
                # The key used to be misspelled; the old spelling is kept as an alias.
                target['templdate_dir'] = target['template_dir']
                self.config['targets'][config[section]['name']] = target
        except KeyError as error:
            logger.warning("Invalid configuration file, missing configuration key '%s'. Exiting." % error.args[0])
            sys.exit(1)
        except ValueError as error:
            # raised by int() for non-numeric values and by the loglevel check
            logger.warning("Invalid configuration file, %s. Exiting." % error)
            sys.exit(1)

    def start(self):
        """Start the listener and the build threads, then wait for them to end."""
        # Defined before the try block, so the join loop below always finds it.
        workers = []
        try:
            thread_receive = threading.Thread(target=self.listen)
            thread_receive.start()
            # os.cpu_count() returns None if the number cannot be determined
            cpu_count = os.cpu_count() or 1
            for i in range(0, min(cpu_count, self.config['server']['max_threads'])):
                logger.info("Starting build thread %i" % i)
                worker = threading.Thread(target=self.worker, args=(i,))
                worker.start()
                workers.append(worker)
            # to have a clean shutdown, wait for all threads to finish
            thread_receive.join()
        except KeyboardInterrupt:
            self.exit()
        for worker in workers:
            worker.join()

    def exit(self):
        """Tell all threads to end, stop the HTTP server and save the state."""
        logger.warning("Received SIGINT. Telling all threads to end. Please wait.")
        self.end_end_all.put("now")
        # self.rest only exists once listen() has created the server
        rest = getattr(self, 'rest', None)
        if rest is not None:
            rest.shutdown()
        self.save_state()

    def _drop_build_instruction(self, build_instruction):
        """Forget a build instruction that will not be built, so it can be queued again."""
        build_id = build_instruction.get('id')
        with self.building_docs_lock:
            self.building_docs.pop(build_id, None)
        with self.in_queues_lock:
            self.in_queues.pop(build_id, None)

    def _prepare_doc(self, build_instruction, thread_id):
        """Create the doc for *build_instruction* and queue it if it has something to build."""
        try:
            doc = DocConfParser(build_instruction, self.config, self.gitLocks, self.gitLocksLock, thread_id)
            if doc.generate_deliverables():
                self.currently_building.put(doc)
                return True
        except Exception:
            # Boundary of the worker thread: log the error instead of losing the thread.
            logger.exception("Thread %i: Could not prepare build instruction %s" % (thread_id, build_instruction))
        self._drop_build_instruction(build_instruction)
        return False

    def worker(self, thread_id):
        """Build loop of one thread; returns True once the end signal was given."""
        while True:
            # 1. parse input from rest api and put the instance of the doc class on the currently building queue
            build_instruction = self.get_scheduled_build_instruction()
            if build_instruction is not None:
                self._prepare_doc(build_instruction, thread_id)

            # 2. get doc from currently_building queue and then a deliverable from doc.
            #    after that, put doc back on the currently building queue. unless it was
            #    the last deliverable.
            deliverable = self.get_deliverable()
            if deliverable is not None:
                try:
                    deliverable.run(thread_id)
                except Exception:
                    logger.exception("Thread %i: Building deliverable failed" % thread_id)

            # 3. end thread if sigint
            if not self.end_end_all.empty():
                return True

            # 4. wait for a short while, then repeat
            time.sleep(2)

    def listen(self):
        """Create the REST server as ``self.rest`` and serve until it is shut down."""
        server_address = (self.config['server']['host'], int(self.config['server']['port']))
        self.rest = ThreadedRESTServer(server_address, RESTServer, self)
        self.rest.serve_forever()
        return True

class RESTServer(BaseHTTPRequestHandler):
    """Request handler of the REST API.

    GET returns the state of the server as JSON. POST expects a JSON list of
    build jobs, e.g.
    [{"docset": "15ga", "lang": "en-us", "product": "sles", "target": "external"}]
    and queues each of them. The Docserv instance is reached through
    ``self.server.docserv``.
    """
    # every job needs these keys with string values
    required_job_keys = ('docset', 'lang', 'product', 'target')

    def _set_headers(self, status=200):
        """Send the status line and the headers of a JSON response."""
        self.send_response(status)
        self.send_header('Content-type', 'application/json')
        self.end_headers()

    def _read_build_jobs(self):
        """Return the list of build jobs from the request body, or None if it is invalid."""
        try:
            content_length = int(self.headers.get('Content-Length'))
        except (TypeError, ValueError):
            # header missing or not a number
            return None
        if content_length < 0:
            return None
        try:
            build_jobs = json.loads(self.rfile.read(content_length))
        except ValueError:
            # invalid JSON or undecodable bytes
            return None
        if not isinstance(build_jobs, list):
            return None
        for job in build_jobs:
            if not isinstance(job, dict):
                return None
            if not all(isinstance(job.get(key), str) for key in self.required_job_keys):
                return None
        return build_jobs

    def do_GET(self):
        self._set_headers()
        self.wfile.write(bytes(json.dumps(self.server.docserv.__dict__()), "utf-8"))

    def do_POST(self):
        build_jobs = self._read_build_jobs()
        if build_jobs is None:
            # Answer with an error instead of raising inside the handler, which
            # would leave the client without a response.
            self._set_headers(400)
            self.wfile.write(bytes(json.dumps({"error": "expected a JSON list of build jobs"}), "utf-8"))
            return
        for job in build_jobs:
            logger.info("Queueing %s" % json.dumps(job))
            self.server.docserv.add_build_instruction(job)
        self._set_headers()

class ThreadedRESTServer(ThreadingMixIn, HTTPServer):
    """HTTPServer with the ThreadingMixIn that carries a reference to the Docserv instance.

    Request handlers reach the instance as ``self.server.docserv``.
    """

    def __init__(self, server_address, RequestHandlerClass, docserv, bind_and_activate=True):
        # ThreadingMixIn defines no __init__, so this resolves to the HTTPServer constructor.
        super().__init__(server_address, RequestHandlerClass, bind_and_activate)
        start_message = "Starting HTTP server on %s:%i" % server_address
        logger.info(start_message)
        self.docserv = docserv

# Logger used by all classes in this file. INFO is only the initial level;
# Docserv.parse_config() sets the level from the configuration file.
logger = logging.getLogger('docserv')
logger.setLevel(logging.INFO)

# Write log records to stdout. The handler accepts everything from DEBUG
# upwards, so the logger level above decides what is actually printed.
ch = logging.StreamHandler(sys.stdout)
ch.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)

# Script entry point: the optional first command line argument is the name of
# the configuration in /etc/docserv/ (see Docserv.parse_config).
if __name__ == "__main__":
    docserv = Docserv(sys.argv)
    docserv.start()
