#!/usr/bin/env python
#       m   
#      u    apswtrace.py - Sat Sep 07 19:22 CEST 2013
#  SQLite   enhanced APSW SQL tracer
#    d      part of sqmediumlite
#   e       Copyright (C): Roger Binns
#  m        

"""
Lets you automatically trace SQL operations in a program 
using APSW without having to modify the program in any way. 
Copied from apswtrace.py by Roger Binns, as distributed
in APSW "Another Python Sqlite Wrapper" version 3.7.9-R1.
Provided as open source, according to the terms of use 
described on the APSW Copyright and License page:
    http://apidoc.apsw.googlecode.com/hg/copyright.html
Changes to the original version are marked 'sqmediumlite'.

Changes:
- add a rowcount to the aggregate reports, displaying the
  number of rows fetched, inserted, updated or deleted. 
- work around a shortcoming in the SQLite profiler that 
  simply ignores queries that are closed before the last
  row has been fetched (still in SQLite library version 3.7.5)
- work around another shortcoming in the SQLite profiler
  and do not include processing time between cursor steps
- add command option explainplan
- add command option orderby
- add command option exeonly
- scale displayed timings depending on their magnitude
  (of minor importance, but more readable)
"""

import time
import sys
import weakref

class APSWTracer(object):
    """
    Traces SQL executed through APSW and collects per-query statistics.

    The constructor appends connection_hook to apsw.connection_hooks, so
    every connection opened afterwards gets the profiler, the execution
    tracer and (where needed) the row tracer installed.  run() executes
    the target script in __main__; report() writes the selected reports
    through self.writer.
    """

    # sqmediumlite begin
    class Timing (object):
        """One execution of a query on one cursor (or connection)."""
        __slots__ = \
                "query", \
                "isconnection", \
                "nanoseconds", \
                "rowschanged", \
                "rowsreturned", \
                "starttime"
        def __init__ (self,query):
            self.query=query            # Query instance this execution belongs to
            self.isconnection=None      # True when executed on the connection itself
            self.nanoseconds=None       # elapsed time; only the first value set is kept
            self.rowschanged=0
            self.rowsreturned=0
            self.starttime=time.time()
        def endtime (self):
            """Derive nanoseconds from the wall clock, unless already known."""
            if self.nanoseconds is None:
                self.nanoseconds=int(1000000000*(time.time()-self.starttime))

    class Query:
        """Statistics for one distinct (stripped) SQL text."""
        def __init__ (self):
            self.timings=[]             # one Timing per execution
            self.threadsused={}         # thread id -> Threadused
            self.plan=None              # rows of EXPLAIN QUERY PLAN, if requested

    class Threadused:
        """Per query and thread: the cursor currently executing the query."""
        def __init__ (self):
            self.activecursor=None      # weakref to the cursor
            self.actcon=None            # weakref to that cursor's connection

    class Condata:
        """Per connection bookkeeping."""
        def __init__ (self):
            self.totalchanges=0         # Connection.totalchanges() already attributed
            self.auxcursor=None         # cursor used for EXPLAIN QUERY PLAN
    # sqmediumlite end

    def __init__(self, options):
        if sys.version_info<(3,):
            self.u=eval("u''")
            import thread
            self.threadid=thread.get_ident
            self.stringtypes=(unicode,str)
            self.numtypes=(int, long, float)
            self.binarytypes=(buffer,)
        else:
            self.u=""
            import _thread
            self.threadid=_thread.get_ident
            self.stringtypes=(str,)
            self.numtypes=(int, float)
            self.binarytypes=(bytes,)
        self.options=options
        if options.output in ("-", "stdout"):
            self._writer=sys.stdout.write
        elif options.output=="stderr":
            self._writer=sys.stderr.write
        else:
            self._writer=open(options.output, "wt").write

        global apsw # sqmediumlite
        try:
            import apsw
        except ImportError:
            sys.stderr.write(self.u+"Unable to import apsw\n")
            raise
        apsw.connection_hooks.append(self.connection_hook)

        self.mapping_open_flags=apsw.mapping_open_flags
        self.zeroblob=apsw.zeroblob
        self.apswConnection=apsw.Connection
        self.newcon={} # sqmediumlite: weakref(connection) -> Condata
        self.newcursor={} # weakref(cursor) -> its current Timing
        self.threadsused={} # really want a set
        self.queries={} # stripped sql -> Query
        # sqmediumlite begin
        ### self.timings={}
        ### self.rowsreturned=0
        # sqmediumlite end
        self.numcursors=0
        self.numconnections=0
        self.timestart=time.time()

    def writerpy2(self, s):
        # s should be a unicode string
        self._writer(s.encode("utf-8")+"\n")

    def writerpy3(self, s):
        self._writer(s+"\n")

    if sys.version_info<(3,):
        writer=writerpy2
    else:
        writer=writerpy3

    def format(self, obj):
        """Return a short single line text for a binding value or row."""
        if isinstance(obj, dict):
            return self.formatdict(obj)
        if isinstance(obj, tuple):
            return self.formatseq(obj, '()')
        if isinstance(obj, list):
            return self.formatseq(obj, '[]')
        if isinstance(obj, self.stringtypes):
            return self.formatstring(obj)
        if obj is True:
            return "True"
        if obj is False:
            return "False"
        if obj is None:
            return "None"
        if isinstance(obj, self.numtypes):
            return repr(obj)
        if isinstance(obj, self.binarytypes):
            return self.formatbinary(obj)
        if isinstance(obj, self.zeroblob):
            return "zeroblob(%d)" % (obj.length(),)
        return repr(obj)

    def formatstring(self, obj, quote='"', checkmaxlen=True):
        """Quote obj with newlines escaped, truncated to options.length."""
        obj=obj.replace("\n", "\\n").replace("\r", "\\r")
        if checkmaxlen and len(obj)>self.options.length:
            obj=obj[:self.options.length]+'..'
        return self.u+quote+obj+quote

    def formatdict(self, obj):
        """Format a dict with its items sorted."""
        items=sorted(obj.items())
        op=[self.format(k)+": "+self.format(v) for k,v in items]
        return self.u+"{"+", ".join(op)+"}"

    def formatseq(self, obj, paren):
        """Format a sequence between the two characters of paren."""
        return self.u+paren[0]+", ".join([self.format(v) for v in obj])+paren[1]

    # Blobs are shown as SQL blob literals: two hex digits per byte, at most
    # options.length bytes; longer blobs get their size prefixed and "..".
    def formatbinarypy2(self, obj):
        limit=self.options.length
        if len(obj)<=limit:
            return "X'"+"".join(["%02x" % ord(c) for c in obj[:]])+"'"
        return "(%d) X'"%(len(obj),)+"".join(["%02x" % ord(c) for c in obj[:limit]])+"..'"

    def formatbinarypy3(self, obj):
        limit=self.options.length
        if len(obj)<=limit:
            return "X'"+"".join(["%02x" % b for b in obj])+"'"
        return "(%d) X'"%(len(obj),)+"".join(["%02x" % b for b in obj[:limit]])+"..'"

    if sys.version_info<(3,):
        formatbinary=formatbinarypy2
    else:
        formatbinary=formatbinarypy3

    # sqmediumlite begin
    ### def sanitizesql(self, sql):
    ###     sql=sql.strip("; \t\r\n")
    ###     while sql.startswith("--"):
    ###         sql=sql.split("\n", 1)[1]
    ###         sql=sql.lstrip("; \t\r\n")
    ###     return sql
    # sqmediumlite end

    def profiler(self, sql, nanoseconds):
        # sqmediumlite begin
        """
        The profiler saves time and rows changed in the list
        of timing instance for the given query. To obtain
        the rows changed, we need the current timing instance.
        This is found in a two-dimensional dictionary of queries 
        and threads. The timing instance was created earlier by 
        the execution tracer.

        This runs as an SQLite callback inside the traced program,
        so anything it cannot attribute is silently skipped rather
        than raised.
        """
        ### sql=self.sanitizesql(sql)
        ### if sql not in self.timings:
        ###     self.timings[sql]=[nanoseconds]
        ### else:
        ###     self.timings[sql].append(nanoseconds)
        if sql.startswith ('explaiN') or sql.endswith ('sqlite_stat1'):
            return # internal recursive query
        sql=sql.strip("; \t\r\n")
        query=self.queries.get(sql)
        if query is None:
            return # not seen by the execution tracer
        threadused=query.threadsused.get(self.threadid())
        if threadused is None:
            return
        timing=self.newcursor.get(threadused.activecursor)
        if timing is None:
            return # already taken care of in cursorfinished
        if timing.query is not query:
            return # cursor has moved on to another query
        if timing.nanoseconds is None:
            timing.nanoseconds=nanoseconds
        if not timing.isconnection:
            wr=threadused.actcon
            connection=wr() if wr is not None else None
            if connection is None:
                return # connection already garbage collected
            condata=self.newcon.get(wr)
            if condata is None:
                return
            rowschanged=connection.totalchanges()-condata.totalchanges
            if rowschanged:
                condata.totalchanges+=rowschanged
                timing.rowschanged=rowschanged
        # sqmediumlite end

    def cursorfinished(self, cursor):
        """Weakref callback: close the last timing of a collected cursor."""
        timing=self.newcursor.pop(cursor, None) # sqmediumlite
        if timing is not None:
            timing.endtime ()

    # sqmediumlite begin
    def confinished(self, con):
        """Weakref callback: forget a collected connection."""
        self.newcon.pop(con, None)
    # sqmediumlite end

    def exectracer(self, cursor, sql, bindings):
        # sqmediumlite begin
        """
        Execution tracer; always returns True so the statement runs.

        With regards to the reports:
        - Maintain the dictionaries of threads, queries and
          cursors.
        - Initialize a new timing for the current query.
        - Register the timing as active in the current
          thread.
        """
        if sql.startswith ('explaiN'): # apswtrace internal
            return True
        if self.options.report:
            tid=self.threadid()
            self.threadsused[tid]=True
            fix=sql.strip("; \t\r\n")
            query=self.queries.get(fix)
            if query is None:
                query=self.queries[fix]=self.Query()
            threadused=query.threadsused.get(tid)
            if threadused is None:
                threadused=query.threadsused[tid]=self.Threadused()
            wr=weakref.ref(cursor)
            if wr not in self.newcursor:
                # First execution on this cursor: put our row tracer in
                # front of any tracer already set, and track the cursor
                # with a weakref whose callback closes its last timing.
                externalrowtracer=cursor.getrowtrace()
                if externalrowtracer:
                    cursor.setrowtrace(lambda cur, row:
                            externalrowtracer(cur, self.rowtracer(cur, row)))
                else:
                    cursor.setrowtrace(self.rowtracer)
                wr=weakref.ref(cursor, self.cursorfinished)
                self.numcursors+=1
            else:
                # Re-execution: the previous statement on this cursor is over.
                self.newcursor[wr].endtime()
            timing=self.Timing(query)
            self.newcursor[wr]=timing
            query.timings.append(timing)
            if isinstance(cursor, self.apswConnection): # after "with con:"
                timing.isconnection=True
                connection=cursor
            else:
                connection=cursor.getconnection()
            if threadused.activecursor!=wr:
                threadused.activecursor=wr
                threadused.actcon=weakref.ref(connection)
            if self.options.explainplan and query.plan is None:
                auxcursor=self.newcon[weakref.ref(connection)].auxcursor
                try:
                    auxcursor.execute('explaiN query plan\n'+fix,bindings)
                    query.plan=list(auxcursor)
                except apsw.Error:
                    # Best effort: tracing must not break the traced
                    # program.  An empty plan is not retried.
                    query.plan=[]
        # sqmediumlite end
        if self.options.sql:
            args=[id(cursor), "SQL:", self.formatstring(sql, '', False)]
            if bindings:
                args.extend(["BINDINGS:", self.format(bindings)])
            self.log(*args)
        return True


    def rowtracer(self, cursor, row):
        # sqmediumlite begin
        """
        Row tracer; returns row unchanged.

        With regards to the reports:
        - Skip apswtrace internal cursor
        - Increment the rowcount on the timing instance.
        - Again, register the timing as active in the 
          current thread.
        - Optionally compute the timing right after fetching 
          one of many rows. This may be more realistic than
          the total elapsed time for fetching all rows, as
          is otherwise obtained from the profiler.
        Cursors without a timing (reports disabled) are only logged.
        """
        if self.options.explainplan:
            condata=self.newcon.get(weakref.ref(cursor.getconnection()))
            if condata is not None and cursor is condata.auxcursor:
                return row
        wr=weakref.ref(cursor)
        timing=self.newcursor.get(wr)
        if timing is not None:
            if self.options.report:
                timing.rowsreturned+=1
                tid=self.threadid()
                threadused=timing.query.threadsused.get(tid)
                if threadused is None:
                    # rows fetched in a thread other than the executing one
                    threadused=timing.query.threadsused[tid]=self.Threadused()
                if threadused.activecursor!=wr:
                    threadused.activecursor=wr
                    threadused.actcon=weakref.ref(cursor.getconnection())
            if self.options.exeonly:
                timing.endtime()
        # sqmediumlite end
        if self.options.rows:
            self.log(id(cursor), "ROW:", self.format(row))
        return row

    def flagme(self, value, mapping, strip=""):
        """Return the names of the int keys of mapping set in value, '|'-joined."""
        flags=sorted((bit, name) for bit, name in mapping.items() if isinstance(bit, int))
        names=[]
        for bit, name in flags:
            if value&bit:
                if name.startswith(strip):
                    name=name[len(strip):]
                names.append(name)
        return self.u+"|".join(names)

    def connection_hook(self, con):
        """Install the tracers on a newly opened connection."""
        self.numconnections+=1
        # sqmediumlite begin
        wr=weakref.ref(con, self.confinished)
        self.newcon[wr]=self.Condata()
        # sqmediumlite end
        if self.options.report:
            con.setprofile(self.profiler)
        if self.options.sql or self.options.report:
            con.setexectrace(self.exectracer)
        # sqmediumlite begin
        ### if self.options.rows or self.options.report:
        ###     con.setrowtrace(self.rowtracer)
        # With reports, exectracer installs the row tracer per cursor.
        # Without them, rows are still logged via a connection wide tracer.
        if self.options.rows and not self.options.report:
            con.setrowtrace(self.rowtracer)
        # sqmediumlite end
        if self.options.sql:
            self.log(id(con), "OPEN:", self.formatstring(con.filename, checkmaxlen=False), con.open_vfs, self.flagme(con.open_flags, self.mapping_open_flags, "SQLITE_OPEN_"))
        # sqmediumlite begin
        if self.options.explainplan:
            condata=self.newcon[wr]
            condata.auxcursor=apsw.Connection.cursor(con)
        # sqmediumlite end

    def log(self, lid, ltype, *args):
        """Write one log line: hex id, optional timestamp/thread, type, args."""
        out=["%x" % (lid,)]
        if self.options.timestamps:
            out.append("%.03f" % (time.time()-self.timestart,))
        if self.options.thread:
            out.append("%x" % (self.threadid(),))
        out.append(ltype)
        out.extend(args)
        self.writer(self.u+" ".join(out))

    def run(self):
        """Execute sys.argv[0] as the __main__ module."""
        import sys
        import __main__
        d=__main__.__dict__
        if sys.version_info<(3,):
            execfile(sys.argv[0],d, d)
        else:
            # We use compile so that filename is present in printed exceptions
            with open(sys.argv[0], "rb") as f:
                source=f.read()
            code=compile(source, sys.argv[0], "exec")
            exec(code, d, d)

    def mostpopular(self, howmany):
        """Return up to howmany (execution count, sql) pairs, most executed first."""
        # sqmediumlite begin
        ### all=[(v,k) for k,v in self.queries.items()]
        counts=[(len(v.timings),k) for k,v in self.queries.items()]
        # sqmediumlite end
        counts.sort()
        counts.reverse()
        return counts[:howmany]

    def longestrunningaggregate(self, howmany):
        """
        Return up to howmany (nanoseconds, count, rows, sql, plan) tuples in
        descending order of the column chosen by options.orderby.
        """
        # sqmediumlite begin
        ### all=[(sum(v),len(v),k) for k,v in self.timings.items()]
        entries=[(
                sum(t.nanoseconds for t in v.timings),
                len(v.timings),
                sum(t.rowschanged+t.rowsreturned for t in v.timings),
                k,
                v.plan,
                ) for k,v in self.queries.items()]
        # --orderby report column -> tuple index; 2 (time) and anything
        # else use the natural tuple order, which starts with the time.
        keyindex={1: 1, 3: 2, 4: 3}.get(self.options.orderby)
        if keyindex is not None:
            entries.sort(key=lambda e: (e[keyindex], e))
        else:
            entries.sort()
        # sqmediumlite end
        entries.reverse()
        return entries[:howmany]

    def longestrunningindividual(self, howmany):
        """Return up to howmany (nanoseconds, rows, sql) tuples, slowest first."""
        res=[]
        # sqmediumlite begin
        ### for k,v in self.timings.items():
        ###     for t in v:
        ###         res.append( (t, k) )
        for k,v in self.queries.items():
            for t in v.timings:
                res.append( (t.nanoseconds, t.rowschanged+t.rowsreturned, k) )
        # sqmediumlite end
        res.sort()
        res.reverse()
        res=res[:howmany]
        return res

    def report(self):
        """Write the reports selected in options.reports."""
        # sqmediumlite begin
        # Close the timings of cursors still alive; iterate over a copy as
        # weakref callbacks may remove entries meanwhile.
        for timing in list(self.newcursor.values ()):
            timing.endtime ()
        # sqmediumlite end
        if not self.options.report:
            return
        w=lambda *args: self.writer(self.u+" ".join(args))
        if "summary" in self.options.reports:
            w("APSW TRACE SUMMARY REPORT")
            w()
            w("Program run time                   ", "%.03f seconds" % (time.time()-self.timestart,))
            w("Total connections                  ", str(self.numconnections))
            w("Total cursors                      ", str(self.numcursors))
            w("Number of threads used for queries ", str(len(self.threadsused)))
            # sqmediumlite begin
            ### total=0
            ### for k,v in self.queries.items():
            ###     total+=v
            ### fmtq=len("%d" % (total,))+1
            ### if "summary" in self.options.reports:
            total=0
            for k,v in self.queries.items():
                total+=len(v.timings)
            # sqmediumlite end
            w("Total queries                      ", str(total))
            w("Number of distinct queries         ", str(len(self.queries)))
            total=0
            # sqmediumlite begin
            ### for k,v in self.timings.items():
            ###     for v2 in v:
            ###         total+=v2
            rowschanged=0
            rowsreturned=0
            for k,v in self.queries.items():
                for timing in v.timings:
                    total+=timing.nanoseconds
                    rowschanged+=timing.rowschanged
                    rowsreturned+=timing.rowsreturned
            w("Number of rows returned            ", str(rowsreturned))
            w("Number of rows changed             ", str(rowschanged))
            # sqmediumlite end
            w("Time spent processing queries      ", "%.03f seconds" % (total/1000000000.0))

        # show most popular queries
        if "popular" in self.options.reports:
            w()
            w("MOST POPULAR QUERIES")
            w()
            fmtc=None # sqmediumlite
            for count, query in self.mostpopular(self.options.reportn):
                # sqmediumlite begin
                ### w("% *d" % (fmtq, count,), self.formatstring(query, '', False))
                if fmtc is None:
                    # first count is the largest: it sets the column width
                    fmtc="% "+str(1+len(str(count)))+"d"
                w(fmtc % (count,), self.formatstring(query, '', False))
                # sqmediumlite end

        # show longest running (aggregate)
        if "aggregate" in self.options.reports:
            w()
            w("LONGEST RUNNING - AGGREGATE")
            w()
            # sqmediumlite begin
            ### for total, count, query in self.longestrunningaggregate(self.options.reportn):
            ###     if fmtt is None:
            ###         fmtt=len(fmtfloat(total/1000000000.0))+1
            ###     w("% *d %s" % (fmtq, count, fmtfloat(total/1000000000.0, total=fmtt)), self.formatstring(query, '', False))
            entries=self.longestrunningaggregate(self.options.reportn)
            if entries:
                # Widths come from the largest value of each column.  The
                # seconds keep about five significant digits; precision is
                # clamped at zero for totals of 10000 seconds and more.
                fmtc="% "+str(1+len(str(max(c for t,c,r,q,p in entries))))+"d"
                fmtr=len(str(max(r for t,c,r,q,p in entries)))
                fmtt=len(str(int(max(t for t,c,r,q,p in entries)/1000000000)))
                qfmt=(fmtc+"%"+str(6+fmtt)+"."+str(max(0,4-fmtt))+"f%"
                      +str(2+fmtr)+"i")
            for total, count, rows, query, plan in entries:
                w(qfmt % (count, total/1000000000.0, rows), self.formatstring(query, '', False))
                if self.options.explainplan:
                    if plan: # if no syntaxerror
                        for sid,order,frm,dtl in plan:
                            if frm==order:
                                w(' %2.0i%2i    %s%s'%(sid,order,order*'  ',dtl))
                            else:
                                w(' %2.0i%2i%2i %s%s'%(sid,order,frm,order*'  ',dtl))
                    w()
            # sqmediumlite end

        # show longest running (individual)
        if "individual" in self.options.reports:
            w()
            w("LONGEST RUNNING - INDIVIDUAL")
            w()
            # sqmediumlite begin
            ### fmtt=None
            ### for t,query in self.longestrunningindividual(self.options.reportn):
            ###     if fmtt is None:
            ###         fmtt=len(fmtfloat(total/1000000000.0))+1
            ###     w(fmtfloat(t/1000000000.0, total=fmtt), self.formatstring(query, '', False))
            entries=self.longestrunningindividual(self.options.reportn)
            if entries:
                fmtr=len(str(max(r for t,r,q in entries)))
                # entries are sorted slowest first
                fmtt=len(str(int(entries[0][0]/1000000000)))
                fmt="%"+str(6+fmtt)+"."+str(max(0,4-fmtt))+"f%"+str(2+fmtr)+"i"
            for t,r,query in entries:
                w(fmt % (t/1000000000.0,r), self.formatstring(query, '', False))
            # sqmediumlite end

# sqmediumlite begin
### def fmtfloat(n, decimals=3, total=None):
###     "Work around borken python float formatting"
###     s="%0.*f" % (decimals, n)
###     if total:
###         s=(" "*total+s)[-total:]
###     return s
# sqmediumlite end

def main():
    """
    Command line entry point: parse the options, run the target script
    under an APSWTracer and write the reports when the script ends, also
    when it raises.  Invalid options end the process via parser.error.
    """
    import optparse
    import os
    import sys

    reports=("summary", "popular", "aggregate", "individual")

    parser=optparse.OptionParser(usage="%prog [options] pythonscript.py [pythonscriptoptions]",
                                 description="This script runs a Python program that uses APSW "
                                 "and reports on SQL queries without modifying the program.  This is "
                                 "done by using connection_hooks and registering row and execution "
                                 "tracers.  See APSW documentation for more details on the output.")
    parser.add_option("-o", "--output", dest="output", default="stdout",
                      help="Where to send the output.  Use a filename, a single dash for stdout, or the words stdout and stderr. [%default]")
    parser.add_option("-s", "--sql", dest="sql", default=False, action="store_true",
                      help="Log SQL statements as they are executed. [%default]")
    parser.add_option("-r", "--rows", dest="rows", default=False, action="store_true",
                      help="Log returned rows as they are returned (turns on sql). [%default]")
    parser.add_option("-t", "--timestamps", dest="timestamps", default=False, action="store_true",
                      help="Include timestamps in logging")
    parser.add_option("-i", "--thread", dest="thread", default=False, action="store_true",
                      help="Include thread id in logging")
    parser.add_option("-l", "--length", dest="length", default=30, type="int",
                      help="Max amount of a string to print [%default]")
    parser.add_option("--no-report", dest="report", default=True, action="store_false",
                      help="A summary report is normally generated at program exit.  This turns off the report and saves memory.")
    parser.add_option("--report-items", dest="reportn", metavar="N", default=15, type="int",
                      help="How many items to report in top lists [%default]")
    parser.add_option("--reports", dest="reports", default=",".join(reports),
                      help="Which reports to show [%default]")
    parser.add_option("--orderby", dest="orderby", default=2, type="int",
                      help="Order report by which column number (1..4) [%default]")
    parser.add_option("--explainplan", dest="explainplan", default=False, action="store_true",
                      help="Add query plan to aggregate report  [%default]")
    parser.add_option("--exeonly", dest="exeonly", default=False, action="store_true",
                      help="The time for fetching all rows is not included in the reports [%default]")

    # Options after the script name belong to the script itself.
    parser.disable_interspersed_args()
    options, args=parser.parse_args()

    options.reports=[x.strip() for x in options.reports.split(",") if x.strip()]
    for r in options.reports:
        if r not in reports:
            parser.error(r+" is not a valid report.  You should supply one or more of "+", ".join(reports))

    # The aggregate report has four columns; anything else used to be
    # silently treated as the default (time).
    if not 1<=options.orderby<=4:
        parser.error("--orderby must be a column number from 1 to 4, not %d" % (options.orderby,))

    if options.rows:
        options.sql=True

    if not args:
        parser.error("You must specify a python script to execute")

    if not os.path.exists(args[0]):
        parser.error("Unable to find script %r\n" % (args[0],))

    # Make the traced script see its own argv and import path.
    sys.argv=args
    sys.path[0]=os.path.split(os.path.abspath(sys.argv[0]))[0]

    t=APSWTracer(options)

    try:
        t.run()
    finally:
        t.report()

# Run as a command line tool (not when imported as a module).
if __name__ == "__main__":
    main()