'''
smtpd.py - SMTP daemon service for sagator.

(c) 2003-2024 Jan ONDREJ (SAL) <ondrejj(at)salstar.sk>

 This program is free software; you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
 the Free Software Foundation; either version 2 of the License, or
 (at your option) any later version.

'''

from aglib import *
import stats, resource

# Only the smtpd service is exported; smtps is the per-connection handler
# it instantiates in accept().
__all__ = ['smtpd']

class smtps(smtp):
  '''
  SMTP server class (one instance per accepted client connection).

  The whole SMTP session is served from __init__: the upstream greeting
  from smtpc() is passed to the client, then client lines are handled by
  recv() until it returns a true value.  Commands are relayed to the
  upstream server with check_reply(); the DATA part is buffered in the
  global mail object, checked by checkvir() and then either forwarded
  upstream or answered locally (reject, drop, custom reply, tempfail).
  '''
  def __init__(self, scanners, conn, addr):
      '''
      Serve one SMTP session on the accepted socket conn.

      scanners -- scanners passed to checkvir() (and reinit() on RSET)
      conn     -- accepted client socket
      addr     -- peer address as returned by socket.accept()
      '''
      self.SCANNERS = scanners
      self.f = conn.makefile('rwb', self.bufsize)
      self.write = conn.sendall
      self.readline = self.f.readline
      self.RECV = RECV_SMTP
      self.stats = stats.statistics()
      self.srv = smtpc()
      # pass the upstream server's greeting through to the client
      self.write(self.srv.readline())
      mail.__init__()
      # resolve the peer name once; it is used both in mail.addr and as
      # 'client_name' in recipient policy requests
      client_name = mail.getnamebyaddr(addr[0])
      mail.addr = (addr[0].encode(), addr[1], client_name.encode())
      self.connect_from = addr
      self.connect_name = client_name
      self.last_cmd = None   # last command match, like EHLO
      self.last_reply = b''  # final line of the last upstream reply
      while True:
        if self.recv(): break
  def send_ok(self, r=b"Ok"):
      '''Send a "250 <r>" reply to the client.'''
      self.write(b"250 "+r+b"\r\n")
  def quit_connect(self, cmd=b'QUIT'):
      '''
      Close the upstream connection and stop relaying.

      cmd (if not empty) is sent upstream before closing.  Afterwards the
      session is in RECV_QUIT state, where the client only gets 502
      replies until it sends QUIT.  Return the result of the next recv().
      '''
      debug.echo(5, "SMTPS: Closing connect to server")
      self.RECV=RECV_QUIT
      if cmd:
        self.srv.conn.sendall(cmd+b"\r\n")
      self.srv.close()
      return self.recv()
  def check_reply(self, line, rw=True):
      '''
      Send line upstream and read its (possibly multi-line) reply.

      When rw is true, reply lines are copied to the client, except a
      "250-STARTTLS" line in an EHLO reply, which is filtered out.
      The final reply line is stored in self.last_reply (b'' when the
      upstream connection was closed).

      Return True, if 2XX or 3XX returned (reg_smtp_reply_ok matches the
      final line), False otherwise or when the connection was closed.
      '''
      debug.echo(6, "SMTPC<: ", line)
      self.srv.conn.sendall(line)
      self.last_reply = b''
      while True:
        reply = self.srv.readline()
        if not reply:
          debug.echo(3, "SMTPS: Connection closed after:", line)
          return False
        if rw:
          if not (reply.startswith(b'250-STARTTLS') \
                  and self.last_cmd
                  and self.last_cmd.group(1).upper()==b'EHLO'):
            debug.echo(7, 'SMTPC>: ', reply)
            self.write(reply)
          else:
            debug.echo(7, 'SMTPC>: Filtered STARTTLS: ', reply)
        if self.reg_smtp_reply.search(reply):
          break
      self.last_reply = reply
      if self.reg_smtp_reply_ok.search(reply):
        return True
      debug.echo(4, "SMTPS: ", reply.rstrip(), ' --- ', line)
      return False
  def recv(self):
      '''
      Read and handle one line from the client.

      Return a true value when the session is over (read error, closed
      connection, QUIT), 0 to continue with the next line.
      '''
      try:
        line=self.readline()
      except OSError as err:
        # socket.timeout and similar errors do not carry an
        # (errno, strerror) pair in err.args, so never unpack it
        debug.echo(0, "smtps(): ", err.strerror or str(err))
        return 1
      if not line:
        debug.echo(9, "smtps(): Empty line")
        return 1
      if not self.RECV: # SMTP command phase
        return self._recv_command(line)
      if self.RECV==RECV_QUIT: # waiting for quit?
        return self._recv_quit(line)
      if line in (b'.\r\n', b'.\n'): # end of DATA part
        return self._recv_end_of_data(line)
      mail.df.write(line)
      return 0
  def _recv_quit(self, line):
      '''Handle a client line after the upstream connection was closed.'''
      if self.reg_quit.search(line):
        # response for QUIT is not required, just close it
        self.f.close()
        # connection shutdown follows from parent function
        return 1
      self.write(b"502 SMTP communication stopped by Sagator\r\n")
      return 0
  def _recv_end_of_data(self, line):
      '''Scan the finished message and forward it or answer the client.'''
      mail.close()
      debug.echo(3, "SMTPS: BODY DONE, size: ", len(mail.data), " B")
      self.RECV = RECV_SMTP
      v, _level, virname = checkvir(self.SCANNERS)
      if v!=S_TEMPFAIL:
        debug.echo(2, "STATS: %s seconds, %s bytes, status: %s"
                     % (self.stats.end(), len(mail.data), tostr(virname)))
      if v in (S_OK, S_FORCE_SEND):
        return self._send_upstream(line)
      if v==S_REJECT:
        debug.echo(1, "SMTPS: REJECT: ","550 Content rejected - ",
                      tostr(virname))
        self.write(b"550 Content rejected - %s\r\n" % virname)
        self.stats.update(len(mail.data), 1)
        return self.quit_connect()
      if v==S_DROP:
        debug.echo(1, "SMTPS: DROP: 250 mail dropped - ", tostr(virname))
        self.write(b"250 mail dropped - %s\r\n" % virname)
        self.stats.update(len(mail.data), 1)
        return self.quit_connect()
      if v==S_CUSTOM:
        # REPLY is a tuple of bytes (see the write below); convert each item
        debug.echo(1, "SMTPS: CUSTOM: %s %s"
                      % tuple(tostr(x) for x in globals.REPLY))
        self.write(b"%s %s\r\n" % globals.REPLY)
        self.stats.update(len(mail.data), 1)
        return self.quit_connect()
      # S_TEMPFAIL (and any other verdict)
      debug.echo(1, "SMTPS: TEMPFAIL: 451 ", tostr(virname))
      self.write(b"451 %s\r\n" % virname)
      self.stats.update(tempfail=1) # update fail statistics
      return self.quit_connect()
  def _send_upstream(self, line):
      '''Forward the accepted message to the upstream server.'''
      # the client already got our own 354; consume the upstream one
      if not self.check_reply(b'DATA\r\n', rw=False):
        # Upstream refused DATA (e.g. no accepted recipient) or closed the
        # connection.  Sending the body now would make the server parse
        # message lines as SMTP commands, so relay the refusal instead.
        reply = self.last_reply or b"451 Upstream server refused DATA\r\n"
        debug.echo(1, "SMTPS: DATA refused by upstream: ", reply.rstrip())
        self.write(reply)
        mail.__init__() # reinit
        return 0
      debug.echo(5, "SMTPS: Sending data")
      self.srv.conn.sendall(mail.xheader)
      self.srv.conn.sendall(mail.data)
      debug.echo(1, "SMTPS: OK: 250 Ok")
      self.srv.conn.sendall(line) # send .
      self.write(self.srv.readline())
      self.stats.update(len(mail.data))
      mail.__init__() # reinit
      return 0
  def _recv_command(self, line):
      '''Handle one SMTP command line from the client (envelope phase).'''
      mail.comm = mail.comm+line
      rcptto = self.reg_rcptto.search(line)
      mailfrom = self.reg_mailfrom.search(line)
      self.last_cmd = self.reg_cmd.search(line)
      if rcptto:
        return self._recv_rcpt_to(line, rcptto)
      elif mailfrom:
        if self.check_reply(line):
          debug.echo(2, 'SMTPS: ', line)
          try:
            mail.sender = parseaddr(tostr(mailfrom.group(1)))[1]
          except Exception:
            mail.sender = mailfrom.group(1)
      elif self.reg_rset.search(line):
        if self.check_reply(line):
          debug.echo(2, 'SMTPS: ', line.rstrip(), ", Resetting mail class")
          mail.__init__() # reinit mail class
          for scnr in self.SCANNERS:
            scnr.reinit() # reinit scanners
      elif self.reg_data.search(line):
        debug.echo(5, 'SMTPS: ', line)
        # answer DATA locally; upstream DATA is sent after scanning
        self.RECV = RECV_BODY
        self.write(b"354 End data with <CR><LF>.<CR><LF>\r\n")
      elif self.reg_quit.search(line):
        debug.echo(5, 'SMTPS: ', line)
        self.check_reply(line) # 221 Bye
        return 1
      elif self.reg_xforward.search(line):
        debug.echo(2, 'SMTPS: ', line)
        self.check_reply(line)
      else:
        debug.echo(6, 'SMTPS: ', line)
        self.check_reply(line)
      return 0
  def _recv_rcpt_to(self, line, rcptto):
      '''Handle RCPT TO: optional policy check, then relay upstream.'''
      try:
        recipient_addr = parseaddr(tostr(rcptto.group(1)))[1]
      except Exception:
        recipient_addr = b''
      if globals.recipient_policy:
        policy_reply = self._check_recipient_policy(recipient_addr)
        # policy_reply is bytes; indexing it would give an int, so compare
        # the first byte with startswith()
        if not policy_reply.startswith(b'2'):
          self.write(policy_reply+b"\r\n")
          return 0
      if self.check_reply(line):
        debug.echo(2, 'SMTPS: ', line)
        if recipient_addr:
          mail.recip.append(recipient_addr)
      return 0
  def _check_recipient_policy(self, recipient_addr):
      '''Build mail.policy_request and return checkpolicy()'s reply.'''
      helo = self.reg_helo_ehlo.search(mail.comm)
      mail.policy_request = {
        'client_address': self.connect_from[0],
        'client_name':    self.connect_name,
        # the client may not have sent HELO/EHLO at all
        'helo_name':      helo.group(1) if helo else b'',
        'sender':         mail.sender,
        'recipient':      recipient_addr
      }
      # prefer the original client data passed by XFORWARD, if present
      try:
        mail.policy_request['client_address'] = \
          self.reg_xforward_addr.search(mail.comm).group(1)
        mail.policy_request['client_name'] = \
          self.reg_xforward_name.search(mail.comm).group(1)
      except AttributeError:
        pass
      policy_reply = checkpolicy(globals.recipient_policy,True)
      self.stats.policy_update()
      return policy_reply

class smtpd(service):
  '''
  SMTP daemon service.

  This service can be used to start sagator as a separate filtering SMTP
  daemon. It is useful for postfix and any other SMTP daemon, which
  can use these filters.

  Usage: smtpd(scanners, host, port, prefork=2)

  Where: scanners is an array of scanners (see README.scanners for more info)
         host is an ip address to bind
         port is a port to bind
         prefork is a number, which defines preforked process count.
           Set this parameter to actual processor count + 1
           or leave its default (2). For multicore servers you can use
           core_count() function to use autodetection.

  Example: smtpd(SCANNERS, '127.0.0.1', 27)
  '''
  name='smtpd()'
  def accept(self,connects=0):
      '''
      Accept one client connection and serve its whole SMTP session.

      connects -- not used by this implementation (presumably part of the
                  service accept() interface).

      The client socket is always shut down and closed afterwards, also
      when the session handler raises (the exception still propagates).
      '''
      conn,addr = self.s.accept()
      socket_settimeout(conn,120)
      try:
        # reinit scanners
        for scnr in self.SCANNERS:
          scnr.reinit()
        # generate ID from the current second and a per-second counter
        self.time2 = time.strftime("%Y%m%d-%H%M%S", time.localtime())
        if self.time2!=self.time1:
          self.time1,self.timec=self.time2,1
        # NOTE(review): timec is not incremented here, so two sessions
        # accepted by this process within one second get the same counter;
        # confirm whether gen_id() compensates for that.
        globals.gen_id(self.time2, self.timec)
        debug.echo(1, "Connection from: ", addr[0], " at ", \
          time.strftime("%c", time.localtime()))
        debug.echo(1, "Process id: %s" % globals.id)
        smtps(self.SCANNERS, conn, addr)
      finally:
        debug.echo(1, "%s: Closing connection." % self.name)
        # shutdown fails e.g. when the peer is already gone; close the
        # socket anyway, so a failed shutdown does not skip close()
        try:
          conn.shutdown(socket.SHUT_RDWR)
        except socket.error:
          pass
        try:
          conn.close()
        except socket.error:
          pass
