Source code for chatterbot.trainers

import logging
import os
import sys

from .conversation import Statement, Response
from . import utils


class Trainer(object):
    """
    Base class for all other trainer classes.
    """

    def __init__(self, storage, **kwargs):
        # The chat bot reference is optional so trainers can run standalone.
        self.chatbot = kwargs.get('chatbot')
        self.storage = storage
        self.logger = logging.getLogger(__name__)
        self.show_training_progress = kwargs.get('show_training_progress', True)

    def get_preprocessed_statement(self, input_statement):
        """
        Preprocess the input statement.
        """
        # The chatbot is optional to prevent backwards-incompatible changes
        if not self.chatbot:
            return input_statement

        for preprocessor in self.chatbot.preprocessors:
            input_statement = preprocessor(self, input_statement)

        return input_statement

    def train(self, *args, **kwargs):
        """
        This method must be overridden by a child class.
        """
        raise self.TrainerInitializationException()

    def get_or_create(self, statement_text):
        """
        Return a statement if it exists.
        Create and return the statement if it does not exist.
        """
        processed = self.get_preprocessed_statement(
            Statement(text=statement_text)
        )

        # Fall back to a brand-new statement when storage has no match.
        return self.storage.find(processed.text) or Statement(processed.text)

    class TrainerInitializationException(Exception):
        """
        Exception raised when a base class has not overridden
        the required methods on the Trainer base class.
        """

        def __init__(self, value=None):
            default = (
                'A training class must be specified before calling train(). '
                'See http://chatterbot.readthedocs.io/en/stable/training.html'
            )
            self.value = value or default

        def __str__(self):
            return repr(self.value)

    def _generate_export_data(self):
        # Flatten the stored statements into [response_text, statement_text]
        # pairs, one pair per known response.
        return [
            [response.text, statement.text]
            for statement in self.storage.filter()
            for response in statement.in_response_to
        ]

    def export_for_training(self, file_path='./export.json'):
        """
        Create a file from the database that can be used to
        train other chat bots.
        """
        import json

        export = {'conversations': self._generate_export_data()}

        with open(file_path, 'w+') as jsonfile:
            json.dump(export, jsonfile, ensure_ascii=False)

class ListTrainer(Trainer):
    """
    Allows a chat bot to be trained using a list of strings
    where the list represents a conversation.
    """

    def train(self, conversation):
        """
        Train the chat bot based on the provided list of
        statements that represents a single conversation.
        """
        total = len(conversation)
        last_statement_text = None

        for index, text in enumerate(conversation, start=1):
            if self.show_training_progress:
                utils.print_progress_bar(
                    'List Trainer',
                    index, total
                )

            statement = self.get_or_create(text)

            # Link each statement back to the one that preceded it
            if last_statement_text:
                statement.add_response(
                    Response(last_statement_text)
                )

            last_statement_text = statement.text
            self.storage.update(statement)

class ChatterBotCorpusTrainer(Trainer):
    """
    Allows the chat bot to be trained using data from the
    ChatterBot dialog corpus.
    """

    def __init__(self, storage, **kwargs):
        super(ChatterBotCorpusTrainer, self).__init__(storage, **kwargs)
        from .corpus import Corpus

        self.corpus = Corpus()

    def train(self, *corpus_paths):
        """
        Train the chat bot on every conversation found at the
        given corpus path(s).
        """
        # Allow a list of corpora to be passed instead of arguments
        if len(corpus_paths) == 1 and isinstance(corpus_paths[0], list):
            corpus_paths = corpus_paths[0]

        # Train the chat bot with each statement and response pair
        for corpus_path in corpus_paths:
            corpora = self.corpus.load_corpus(corpus_path)
            corpus_files = self.corpus.list_corpus_files(corpus_path)

            for file_index, corpus in enumerate(corpora):
                # Label the progress bar with the corpus file's name
                progress_label = str(
                    os.path.basename(corpus_files[file_index])
                ) + ' Training'

                for dialog_index, conversation in enumerate(corpus):
                    if self.show_training_progress:
                        utils.print_progress_bar(
                            progress_label,
                            dialog_index + 1,
                            len(corpus)
                        )

                    last_statement_text = None

                    for text in conversation:
                        statement = self.get_or_create(text)
                        statement.add_tags(corpus.categories)

                        if last_statement_text:
                            statement.add_response(
                                Response(last_statement_text)
                            )

                        last_statement_text = statement.text
                        self.storage.update(statement)

class TwitterTrainer(Trainer):
    """
    Allows the chat bot to be trained using data
    gathered from Twitter.

    :param random_seed_word: The seed word to be used to get random tweets from the Twitter API.
                             This parameter is optional. By default it is the word 'random'.
    :param twitter_lang: Language for results as ISO 639-1 code.
                         This parameter is optional. Default is None (all languages).
    """

    def __init__(self, storage, **kwargs):
        super(TwitterTrainer, self).__init__(storage, **kwargs)
        from twitter import Api as TwitterApi

        # The word to be used as the first search term when searching for tweets
        self.random_seed_word = kwargs.get('random_seed_word', 'random')
        self.lang = kwargs.get('twitter_lang')

        self.api = TwitterApi(
            consumer_key=kwargs.get('twitter_consumer_key'),
            consumer_secret=kwargs.get('twitter_consumer_secret'),
            access_token_key=kwargs.get('twitter_access_token_key'),
            access_token_secret=kwargs.get('twitter_access_token_secret')
        )

    def random_word(self, base_word, lang=None):
        """
        Generate a random word using the Twitter API.

        Search twitter for recent tweets containing the term 'random'.
        Then randomly select one word from those tweets and do another
        search with that word. Return a randomly selected word from the
        new set of results.
        """
        import random

        seed_tweets = self.api.GetSearch(term=base_word, count=5, lang=lang)
        candidate = random.choice(list(self.get_words_from_tweets(seed_tweets)))

        followup_tweets = self.api.GetSearch(term=candidate, count=5, lang=lang)
        return random.choice(list(self.get_words_from_tweets(followup_tweets)))

    def get_words_from_tweets(self, tweets):
        """
        Given a list of tweets, return the set of
        words from the tweets.
        """
        # Keep only words that contain letters alone and are 4 to 9
        # characters long
        return {
            word
            for tweet in tweets
            for word in tweet.text.split()
            if word.isalpha() and 3 < len(word) <= 9
        }

    def get_statements(self):
        """
        Returns list of random statements from the API.
        """
        from twitter import TwitterError

        statements = []

        # Generate a random word
        random_word = self.random_word(self.random_seed_word, self.lang)

        self.logger.info(u'Requesting 50 random tweets containing the word {}'.format(random_word))
        tweets = self.api.GetSearch(term=random_word, count=50, lang=self.lang)

        for tweet in tweets:
            statement = Statement(tweet.text)

            # Only keep tweets that reply to another tweet we can fetch
            if tweet.in_reply_to_status_id:
                try:
                    status = self.api.GetStatus(tweet.in_reply_to_status_id)
                    statement.add_response(Response(status.text))
                    statements.append(statement)
                except TwitterError as error:
                    self.logger.warning(str(error))

        self.logger.info('Adding {} tweets with responses'.format(len(statements)))

        return statements

    def train(self):
        """
        Gather ten batches of tweet statements and store each one.
        """
        for _ in range(0, 10):
            for statement in self.get_statements():
                self.storage.update(statement)

class UbuntuCorpusTrainer(Trainer):
    """
    Allow chatbots to be trained with the data from
    the Ubuntu Dialog Corpus.
    """

    def __init__(self, storage, **kwargs):
        super(UbuntuCorpusTrainer, self).__init__(storage, **kwargs)

        self.data_download_url = kwargs.get(
            'ubuntu_corpus_data_download_url',
            'http://cs.mcgill.ca/~jpineau/datasets/ubuntu-corpus-1.0/ubuntu_dialogs.tgz'
        )

        self.data_directory = kwargs.get(
            'ubuntu_corpus_data_directory',
            './data/'
        )

        self.extracted_data_directory = os.path.join(
            self.data_directory, 'ubuntu_dialogs'
        )

        # Create the data directory if it does not already exist
        if not os.path.exists(self.data_directory):
            os.makedirs(self.data_directory)

    def is_downloaded(self, file_path):
        """
        Check if the data file is already downloaded.
        """
        if os.path.exists(file_path):
            self.logger.info('File is already downloaded')
            return True
        return False

    def is_extracted(self, file_path):
        """
        Check if the data file is already extracted.
        """
        if os.path.isdir(file_path):
            self.logger.info('File is already extracted')
            return True
        return False

    def download(self, url, show_status=True):
        """
        Download a file from the given url.
        Show a progress indicator for the download status.
        Based on: http://stackoverflow.com/a/15645088/1547223

        :param url: Location of the archive to fetch.
        :param show_status: When True, draw a text progress bar on stdout.
        :returns: The local path of the (possibly pre-existing) file.
        """
        import requests

        file_name = url.split('/')[-1]
        file_path = os.path.join(self.data_directory, file_name)

        # Do not download the data if it already exists
        if self.is_downloaded(file_path):
            return file_path

        with open(file_path, 'wb') as open_file:
            print('Downloading %s' % url)
            response = requests.get(url, stream=True)
            total_length = response.headers.get('content-length')

            if total_length is None:
                # No content length header
                open_file.write(response.content)
            else:
                # Renamed from `download`, which shadowed this method's name
                downloaded_bytes = 0
                total_length = int(total_length)
                for data in response.iter_content(chunk_size=4096):
                    downloaded_bytes += len(data)
                    open_file.write(data)
                    if show_status:
                        done = int(50 * downloaded_bytes / total_length)
                        sys.stdout.write('\r[%s%s]' % ('=' * done, ' ' * (50 - done)))
                        sys.stdout.flush()

            # Add a new line after the download bar
            sys.stdout.write('\n')

        print('Download location: %s' % file_path)
        return file_path

    def extract(self, file_path):
        """
        Extract a tar file at the specified file path.
        """
        import tarfile

        print('Extracting {}'.format(file_path))

        if not os.path.exists(self.extracted_data_directory):
            os.makedirs(self.extracted_data_directory)

        def track_progress(members):
            sys.stdout.write('.')
            for member in members:
                # This will be the current file being extracted
                yield member

        with tarfile.open(file_path) as tar:
            # NOTE(review): extractall does not validate member paths, so a
            # crafted archive could write outside the target directory.
            # Acceptable only because the archive source is a trusted URL.
            tar.extractall(path=self.extracted_data_directory, members=track_progress(tar))

        self.logger.info('File extracted to {}'.format(self.extracted_data_directory))

        return True

    def train(self):
        """
        Download, extract and read the Ubuntu corpus TSV files,
        storing each line of dialog as a statement.
        """
        import glob
        import csv

        # Download and extract the Ubuntu dialog corpus if needed
        corpus_download_path = self.download(self.data_download_url)

        # Extract if the directory does not already exist
        if not self.is_extracted(self.extracted_data_directory):
            self.extract(corpus_download_path)

        extracted_corpus_path = os.path.join(
            self.extracted_data_directory,
            '**', '**', '*.tsv'
        )

        file_kwargs = {}

        if sys.version_info[0] > 2:
            # Specify the encoding in Python versions 3 and up
            file_kwargs['encoding'] = 'utf-8'
            # WARNING: This might fail to read a unicode corpus file in Python 2.x

        for data_file in glob.iglob(extracted_corpus_path):
            self.logger.info('Training from: {}'.format(data_file))

            with open(data_file, 'r', **file_kwargs) as tsv:
                reader = csv.reader(tsv, delimiter='\t')

                previous_statement_text = None

                for row in reader:
                    # A well-formed row is: datetime, speaker, addressee, text.
                    # Require the full width so malformed rows cannot raise an
                    # IndexError (the old guard was len(row) > 0 yet row[3]
                    # was read below). A leftover debug print of each row was
                    # also removed here.
                    if len(row) > 3:
                        text = row[3]
                        statement = self.get_or_create(text)

                        statement.add_extra_data('datetime', row[0])
                        statement.add_extra_data('speaker', row[1])

                        if row[2].strip():
                            statement.add_extra_data('addressing_speaker', row[2])

                        if previous_statement_text:
                            statement.add_response(
                                Response(previous_statement_text)
                            )

                        previous_statement_text = statement.text
                        self.storage.update(statement)