'''
Condition scanners for sagator
 
(c) 2006-2019 Jan ONDREJ (SAL) <ondrejj(at)salstar.sk>

 This program is free software; you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
 the Free Software Foundation; either version 2 of the License, or
 (at your option) any later version.
                                                                                
'''

from __future__ import absolute_import
                                                                                
from avlib import *
from .match import match_any
from .cache import cache
                                                                                
# scanner classes exported by "from <this module> import *"
__all__=['sql_find', 'regexp_find', 'check_level']

class sql_find(match_any):
  '''
  Recipient email address to index scanner, operating on SQL database.

  This scanner is a reimplementation of e2i_sql lmtpd() service from
  older sagator 0.7.2. It is useful only as lmtpd() scanner, because
  it operates on email recipients.

  Usage: sql_find(key, dbc, query, scanners)

  Where: key is a string, which defines which variable to compare.
           Currently only "recipient" or "sender" can be used here.
         dbc is a database connection
         query is an SQL query with one parameter (the compared email);
           the first column of the first returned row is used as a key
           into scanners
         scanners is a dictionary of scanners, each returned key from
           database must have a corresponding key in this dictionary.
           The '' key is the default scanner, used when the query returns
           no row or a key which is not in this dictionary.

  Example: sql_find(
             'recipient',
             db.sqlite(),
             "SELECT key FROM userpref WHERE email=%s",
             {
               'AV': b2f(libclam()),
               'AS': spamassassind(),
               '': s2f(libclam())+spamassassind() # default
             }
           )

  When no default ('') scanner is defined and no key matches, the message
  is not scanned and a clean result is returned.

  Limitations: It is not possible to include one sql_find() into another.

  New in version 0.9.0.
  '''
  name='sql_find()'
  def __init__(self,key,dbc,query,scanners):
      self.FIND_KEY=key # 'recipient' or 'sender'
      self.DBC=dbc
      self.QUERY=query
      self.SCANDICT=scanners
      self.reinit()
  def reinit(self):
      '''
      Forget per-recipient scanner selections and reinit all child scanners.
      '''
      self.MATCHED={}
      for scanner in list(self.SCANDICT.values()):
        scanner.reinit()
  def rcpt_signature(self,rcpt):
      '''
      Select a scanner for recipient rcpt and return its signature.

      The selected key ('' for the default scanner) is remembered in
      self.MATCHED, where scanbuffer() looks it up.
      '''
      if self.FIND_KEY=='sender':
        find_email=mail.sender
      else:
        find_email=rcpt
      try:
        key=self.DBC.query(self.QUERY, [find_email])[0][0]
      except IndexError:
        # no row (or no column) returned, use the default scanner
        key=''
      else:
        debug.echo(4,'%s: "%s" matched for %s %s' \
                     % (self.name, key, self.FIND_KEY, find_email))
        if key in self.SCANDICT:
          debug.echo(4,'%s: Selected scanner: %s' \
                       % (self.name, self.SCANDICT[key].name))
        else:
          key=''
      self.MATCHED[rcpt]=key
      if key not in self.SCANDICT:
        # no default ('') scanner configured; scanbuffer() will not scan
        return 'sql_find(%s,)' % key
      return 'sql_find(%s,%s)' % (key, self.SCANDICT[key].rcpt_signature(rcpt))
  def scanbuffer(self, buffer, args={}):
      '''
      Scan buffer with the scanner selected for the first recipient.

      Returns a clean result (0.0, b'', []) when there is no recipient
      or no scanner is available for it.
      '''
      try:
        # presumably all recipients of one call share the same signature
        # (lmtpd groups them by rcpt_signature()) - TODO confirm;
        # unknown recipients fall back to the default ('') scanner
        scanner=self.SCANDICT[self.MATCHED.get(mail.recip[0], '')]
      except (IndexError, KeyError):
        debug.echo(4, "%s: Default key '' not found!" % self.name)
        return 0.0, b'', []
      self.scanners=[scanner]
      return match_any.scanbuffer(self, buffer, args)

class regexp_find(sql_find):
  '''
  Recipient email address to index scanner, operating with regexp.

  It is useful only as lmtpd() scanner, because it operates on email
  recipients.

  Usage: regexp_find(key, scanners)

  Where: key is a string, which defines which variable to compare.
           Currently only "recipient" can be used here.
         scanners is a dictionary of scanners, keyed by regular
           expressions. Each expression is searched (case-insensitively)
           in the recipient address and the first matching one selects
           its scanner. The '' key is the default scanner, used when no
           regular expression matches.

  Example: regexp_find(
             'recipient',
             {
               '@somedomain.com$': b2f(libclam()),
               '@anotherdomain.sk$': spamassassind(),
               '': s2f(libclam())+spamassassind()
             }
           )

  For backward compatibility the old three argument form
  regexp_find(dbc, query, scanners) is also accepted; dbc and query
  are not used.

  Limitations: It is not possible to include one regexp_find() into another.

  New in version 0.9.0.
  '''
  name='regexp_find()'
  def __init__(self, key, scanners, legacy_scanners=None):
      if legacy_scanners is not None:
        # old signature regexp_find(dbc, query, scanners)
        self.DBC, self.QUERY = key, scanners
        key, scanners = 'recipient', legacy_scanners
      else:
        self.DBC, self.QUERY = None, None
      self.FIND_KEY=key
      self.SCANDICT=scanners
      # compile regexps, ignore default ('') here
      self.REDICT={re.compile(pattern, re.I): pattern
                   for pattern in scanners
                   if pattern!=''}
      # reinit
      self.reinit()
  def rcpt_signature(self,rcpt):
      '''
      Select a scanner for recipient rcpt by the first matching regexp.

      The selected key ('' when nothing matches) is remembered in
      self.MATCHED, where scanbuffer() looks it up, and is returned as
      the signature.
      '''
      for reg, key in self.REDICT.items():
        if reg.search(rcpt):
          self.MATCHED[rcpt]=key
          debug.echo(4,'%s: "%s" matched for %s' % (self.name, key, rcpt))
          return key
      self.MATCHED[rcpt]=''
      return ''

class check_level(match_any):
  '''
  Select scanner based on tested scanner return status.

  Usage: check_level(tested_scanner, {
                       (min,max): scanner,
                       (min,max): scanner, ...
                     })
     or: check_level()

  Where: tested_scanner is a scanner, which return level will be tested
         min is a number, minimal level for this scanner (inclusive)
         max is a number, maximal level for this scanner (exclusive)

  This scanner with no arguments will return previously saved status.
  Evaluation function is: min <= LEVEL < max .
  Ranges are tried in the dictionary's iteration order and the first
  matching one is used, so keep ranges non-overlapping.
  When no range is found, cached reply will be returned without changes.

  Example: check_level(spamassassind(), {
               (1.0, 5.0):     deliver(
                                 modify_subject('[SPAM:%L]',
                                   check_level()
                                 )
                               ),
               (5.0, 99999.0): drop('.', check_level())
           })

  New in version 0.9.0.
  '''
  name='check_level()'
  def __init__(self, tested_scanner=None, scanners=None):
      # avoid a shared mutable default dictionary between instances
      if scanners is None:
        scanners={}
      if tested_scanner is None:
        cached=cache(self.name)
      else:
        cached=cache(self.name, tested_scanner)
      # scanners[0] is always the cache wrapper, the rest are range scanners
      match_any.__init__(self, [cached] + list(scanners.values()))
      self.SCANDICT=scanners
  def scanbuffer(self, buffer, args={}):
      '''
      Run the (cached) tested scanner and dispatch on its level.

      Returns the result of the scanner whose range contains the level,
      or the tested scanner's result when no range matches.
      '''
      tested=self.scanners[0]
      if not self.SCANDICT:
        return tested.scanbuffer(buffer, args)
      level, detected, virlist = tested.scanbuffer(buffer, args)
      for (lmin, lmax), scanner in self.SCANDICT.items():
        if lmin <= level < lmax:
          debug.echo(5, "check_level(): %f<=%f<%f: %s" \
                        % (lmin, level, lmax, scanner.name))
          return scanner.scanbuffer(buffer, args)
      return level, detected, virlist
