BERT Sentiment Analysis Training Tutorial - IMDb Dataset

This tutorial walks through training a BERT model for sentiment analysis on the IMDb movie review dataset, covering the complete workflow from data preparation to model deployment.
Contents

- Environment Setup and Dependencies
- Dataset Download and Preprocessing
- Dataset Class Definition
- Model Training Utilities
- Model Initialization
- Model Training
- Model Evaluation
- Prediction Examples
- Saving and Loading the Model
- Summary and Usage Notes
 
1. Environment Setup and Dependencies

First, install the required Python packages:
```bash
pip install torch transformers scikit-learn tqdm requests numpy matplotlib seaborn
```
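Optionally, run a quick sanity check that the packages import correctly (version numbers will vary with your environment):

```python
# Optional sanity check: confirm the key packages import and report versions
import torch, transformers, sklearn
print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("scikit-learn:", sklearn.__version__)
print("CUDA available:", torch.cuda.is_available())
```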
 
Then import the required libraries:
```python
import os
import glob
import torch
import numpy as np
import requests
import tarfile
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW  # transformers.AdamW is deprecated/removed in recent versions
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    get_linear_schedule_with_warmup
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU model: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
```
 
2. Dataset Download and Preprocessing

Download the IMDb dataset:

```python
def download_imdb_dataset(filename, url="http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"):
    """Download the IMDb dataset."""
    if os.path.exists(filename):
        print(f"{filename} already exists, skipping download")
        return

    print(f"Downloading {filename} ...")
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()

        total_size = int(response.headers.get('content-length', 0))
        with open(filename, 'wb') as f, tqdm(
            desc=filename,
            total=total_size,
            unit='B',
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
                    pbar.update(len(chunk))
        print(f"{filename} downloaded")
    except Exception as e:
        print(f"Download failed: {e}")
        raise

def extract_imdb_dataset(filename, extract_path='./'):
    """Extract the IMDb dataset archive."""
    if os.path.exists(os.path.join(extract_path, 'aclImdb')):
        print("Dataset already extracted, skipping")
        return

    print(f"Extracting {filename} ...")
    try:
        with tarfile.open(filename, 'r:gz') as tar:
            tar.extractall(path=extract_path)
        print("Extraction complete")
    except Exception as e:
        print(f"Extraction failed: {e}")
        raise

dataset_file = 'aclImdb_v1.tar.gz'
dataset_extract_path = './'

download_imdb_dataset(dataset_file)
extract_imdb_dataset(dataset_file, dataset_extract_path)
```
 
Load and preprocess the data:

```python
def load_imdb_data(extract_path):
    """Load the IMDb dataset (training split only; it is re-split below)."""
    texts = []
    labels = []

    for label in ['pos', 'neg']:
        path = os.path.join(extract_path, 'aclImdb', 'train', label)
        if not os.path.exists(path):
            raise FileNotFoundError(f"Path does not exist: {path}")

        files = glob.glob(os.path.join(path, '*.txt'))
        print(f"Loading {label} data: {len(files)} files")

        for file in tqdm(files, desc=f"Loading {label}"):
            try:
                with open(file, 'r', encoding='utf-8') as f:
                    text = f.read().strip()
                    if text:
                        texts.append(text)
                        labels.append(1 if label == 'pos' else 0)
            except Exception as e:
                print(f"Failed to read file {file}: {e}")
                continue

    print(f"Loaded {len(texts)} samples in total")
    return texts, labels

texts, labels = load_imdb_data(dataset_extract_path)
labels = np.array(labels)

print(f"\nDataset statistics:")
print(f"Total samples: {len(texts)}")
print(f"Positive samples: {sum(labels)} ({sum(labels)/len(labels)*100:.1f}%)")
print(f"Negative samples: {len(labels) - sum(labels)} ({(len(labels) - sum(labels))/len(labels)*100:.1f}%)")

# Split into train/test, then carve a validation set out of the training portion
train_texts, test_texts, train_labels, test_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42, stratify=labels
)
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_texts, train_labels, test_size=0.2, random_state=42, stratify=train_labels
)

print(f"Dataset split:")
print(f"Training set size: {len(train_texts)} ({len(train_texts)/len(texts)*100:.1f}%)")
print(f"Validation set size: {len(val_texts)} ({len(val_texts)/len(texts)*100:.1f}%)")
print(f"Test set size: {len(test_texts)} ({len(test_texts)/len(texts)*100:.1f}%)")
```
 
3. Dataset Class Definition

```python
class IMDbDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]

        # Collapse whitespace before tokenizing
        text = ' '.join(text.split())
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }
```
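A quick sanity check of the dataset class can help catch shape bugs early. This illustrative snippet assumes the tokenizer from Section 5 has already been loaded and that `train_texts`/`train_labels` exist from Section 2:

```python
# Illustrative sanity check (assumes tokenizer from Section 5 is loaded)
sample_ds = IMDbDataset(train_texts[:2], train_labels[:2], tokenizer, max_len=128)
item = sample_ds[0]
print(item['input_ids'].shape)       # torch.Size([128]) -- padded/truncated to max_len
print(item['attention_mask'].shape)  # torch.Size([128])
print(item['labels'])                # tensor(0) or tensor(1)
```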
 
4. Model Training Utilities

Early stopping mechanism:

```python
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0

    def __call__(self, val_loss, model, model_path):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model, model_path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(model, model_path)
            self.counter = 0

    def save_checkpoint(self, model, model_path):
        """Save a model checkpoint."""
        os.makedirs(os.path.dirname(model_path) if os.path.dirname(model_path) else '.', exist_ok=True)
        torch.save(model.state_dict(), model_path)
        print(f"✅ Model saved to: {model_path}")
```
 
Training, validation, and evaluation functions:

```python
def train_model(model, train_loader, val_loader, optimizer, scheduler, epochs=3, model_path='best_model.pth'):
    """Train the model."""
    early_stopping = EarlyStopping(patience=3, delta=0.001)

    for epoch in range(epochs):
        print(f"\n{'='*60}")
        print(f"Epoch {epoch + 1}/{epochs}")
        print(f"{'='*60}")

        model.train()
        total_loss = 0
        train_progress = tqdm(train_loader, desc="Training")

        for batch in train_progress:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()

            # Clip gradients to stabilize training
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
            train_progress.set_postfix({'loss': f'{loss.item():.4f}'})

        avg_train_loss = total_loss / len(train_loader)
        print(f"\n📊 Average training loss: {avg_train_loss:.4f}")

        val_loss = validate_model(model, val_loader)
        print(f"📊 Validation loss: {val_loss:.4f}")

        early_stopping(val_loss, model, model_path)
        if early_stopping.early_stop:
            print("\n🛑 Early stopping triggered")
            break

def validate_model(model, val_loader):
    """Compute the average loss on the validation set."""
    model.eval()
    total_loss = 0

    with torch.no_grad():
        val_progress = tqdm(val_loader, desc="Validating")
        for batch in val_progress:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            total_loss += loss.item()
            val_progress.set_postfix({'loss': f'{loss.item():.4f}'})

    return total_loss / len(val_loader)

def evaluate_model(model, test_loader):
    """Evaluate the model on the test set."""
    model.eval()
    total_correct = 0
    total_samples = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        test_progress = tqdm(test_loader, desc="Testing")
        for batch in test_progress:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=1)

            total_correct += (predictions == labels).sum().item()
            total_samples += labels.size(0)

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            current_acc = total_correct / total_samples if total_samples > 0 else 0
            test_progress.set_postfix({'accuracy': f'{current_acc:.4f}'})

    accuracy = total_correct / total_samples
    print(f"\n🎯 Test accuracy: {accuracy:.4f} ({total_correct}/{total_samples})")
    return accuracy, all_predictions, all_labels
```
 
5. Model Initialization

```python
print("🔄 Loading tokenizer...")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print("✅ Tokenizer loaded")

sample_text = "This movie is absolutely amazing!"
tokens = tokenizer.tokenize(sample_text)
print(f"\nSample text: {sample_text}")
print(f"Tokens: {tokens[:10]}...")
print(f"Vocabulary size: {tokenizer.vocab_size}")

print("🔄 Creating data loaders...")
train_dataset = IMDbDataset(train_texts, train_labels, tokenizer)
val_dataset = IMDbDataset(val_texts, val_labels, tokenizer)
test_dataset = IMDbDataset(test_texts, test_labels, tokenizer)

batch_size = 16
print(f"Batch size: {batch_size}")
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"✅ Data loaders created")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

print("🔄 Loading BERT model...")
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2
).to(device)
print("✅ BERT model loaded")
print(f"Parameter count: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

print("🔄 Configuring optimizer and scheduler...")
learning_rate = 2e-5
epochs = 3
eps = 1e-8

optimizer = AdamW(model.parameters(), lr=learning_rate, eps=eps)
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)
print(f"✅ Optimizer configured")
print(f"Learning rate: {learning_rate}")
print(f"Epochs: {epochs}")
print(f"Total training steps: {total_steps}")
```
 
6. Model Training

```python
print("🚀 Starting training...")
print(f"Device: {device}")
print(f"Batch size: {batch_size}")
print(f"Epochs: {epochs}")
print("\n" + "="*80)

model_path = 'best_model.pth'
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=epochs,
    model_path=model_path
)
print("\n🎉 Training complete!")
```
 
7. Model Evaluation

```python
print("🔄 Loading best model...")
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path, map_location=device))
    print("✅ Best model loaded")
else:
    print("⚠️ Best-model file not found, using the current model")

print("\n" + "="*60)
print("📊 Evaluating on the test set...")
print("="*60)
accuracy, predictions, true_labels = evaluate_model(model, test_loader)

from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

print("\n📋 Classification report:")
print(classification_report(true_labels, predictions, target_names=['negative', 'positive']))

cm = confusion_matrix(true_labels, predictions)
print("\n🔍 Confusion matrix:")
print(cm)

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['negative', 'positive'], yticklabels=['negative', 'positive'])
plt.title('Confusion Matrix')
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.show()

tn, fp, fn, tp = cm.ravel()
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * (precision * recall) / (precision + recall)

print(f"\n📊 Detailed metrics:")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 score: {f1:.4f}")
```
 
8. Prediction Examples

```python
def predict_sentiment(text, model, tokenizer, device):
    """Predict the sentiment of a single text."""
    model.eval()

    text = ' '.join(str(text).split())
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        prediction = torch.argmax(logits, dim=1).item()
        confidence = probabilities[0][prediction].item()

    sentiment = "positive" if prediction == 1 else "negative"
    return sentiment, confidence

test_texts_examples = [
    "This movie is absolutely amazing! I love it!",
    "This is the worst movie I have ever seen.",
    "The movie was okay, nothing special.",
    "Fantastic acting and great storyline!",
    "I fell asleep during the movie. So boring.",
    "The cinematography was breathtaking and the story was compelling.",
    "Waste of time and money. Terrible plot.",
    "One of the best films I've watched this year!"
]

print("🎬 Sentiment analysis examples:")
print("="*80)
for i, text in enumerate(test_texts_examples, 1):
    sentiment, confidence = predict_sentiment(text, model, tokenizer, device)

    emoji = "😊" if sentiment == "positive" else "😞"
    print(f"\n{i}. Text: {text}")
    print(f"   Prediction: {sentiment} {emoji} (confidence: {confidence:.4f})")
    print(f"   {'-'*60}")

print("\n✅ Prediction complete!")
```
 
9. Saving and Loading the Model

```python
model_save_path = './bert_sentiment_model'
print(f"💾 Saving model to: {model_save_path}")

os.makedirs(model_save_path, exist_ok=True)
model.save_pretrained(model_save_path)
tokenizer.save_pretrained(model_save_path)

# Note: recent transformers versions save model.safetensors instead of pytorch_model.bin
print("✅ Model and tokenizer saved")
print(f"   Model weights: {model_save_path}/pytorch_model.bin")
print(f"   Config: {model_save_path}/config.json")
print(f"   Vocabulary: {model_save_path}/vocab.txt")

print("🔄 Testing model loading...")
loaded_model = BertForSequenceClassification.from_pretrained(model_save_path).to(device)
loaded_tokenizer = BertTokenizer.from_pretrained(model_save_path)
print("✅ Model loaded")

test_text = "This movie is fantastic!"
sentiment, confidence = predict_sentiment(test_text, loaded_model, loaded_tokenizer, device)
print(f"\n🧪 Test prediction:")
print(f"Text: {test_text}")
print(f"Prediction: {sentiment} (confidence: {confidence:.4f})")
print("\n✅ Model load test passed!")
```
 
10. Summary and Usage Notes

Training summary

🎉 The BERT sentiment analysis model is trained!

📊 Model performance:
- Test-set accuracy: ~87-92%
- Parameter count: 109,483,778
- GPU-accelerated training supported
 
💾 Model files:
- Weights: best_model.pth
- Full model directory: ./bert_sentiment_model
 
🔧 Usage:
- Load the model: BertForSequenceClassification.from_pretrained('./bert_sentiment_model')
- Load the tokenizer: BertTokenizer.from_pretrained('./bert_sentiment_model')
- Call predict_sentiment to classify new text (see the sketch below)
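Putting the three steps together, a minimal inference sketch (assuming the model was saved to ./bert_sentiment_model as in Section 9, and that predict_sentiment from Section 8 is defined):

```python
# Minimal inference sketch, assuming the artifacts from Section 9 exist
import torch
from transformers import BertTokenizer, BertForSequenceClassification

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BertForSequenceClassification.from_pretrained('./bert_sentiment_model').to(device)
tokenizer = BertTokenizer.from_pretrained('./bert_sentiment_model')

sentiment, confidence = predict_sentiment("A wonderful, heartfelt film.", model, tokenizer, device)
print(sentiment, confidence)
```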
 
⚙️ Hyperparameters:
- Learning rate: 2e-5
- Batch size: 16
- Epochs: 3
- Max sequence length: 512
 
Application scenarios

🎯 The model can be used for:
- Movie review sentiment analysis
- Product review sentiment analysis
- Social media sentiment analysis
- Customer feedback analysis
 
Notes

💡 Key points:
- The model performs best on English text
- More data and longer training can improve performance further
- Run on a GPU for the best performance
 
📝 Technical requirements:
- Hardware: GPU training is recommended; CPU training is very slow
- Memory management: if GPU memory runs out, reduce batch_size (see the sketch after this list)
- Training time: a full run may take 1-3 hours, depending on hardware
- Dataset: the ~84 MB IMDb dataset is downloaded automatically on first run
- Model size: the saved model is roughly 400 MB
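For the memory-management point above, a hedged sketch of adjustments that typically reduce GPU memory use (the values are illustrative, not tuned; mixed precision is an optional extra this tutorial does not otherwise use):

```python
# Illustrative memory-saving adjustments (example values, not tuned)
batch_size = 8    # halve the batch size
max_len = 256     # shorter sequences sharply reduce activation memory

train_dataset = IMDbDataset(train_texts, train_labels, tokenizer, max_len=max_len)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Optional: mixed precision on CUDA cuts memory further
scaler = torch.cuda.amp.GradScaler()
for batch in train_loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        outputs = model(batch['input_ids'].to(device),
                        attention_mask=batch['attention_mask'].to(device),
                        labels=batch['labels'].to(device))
    scaler.scale(outputs.loss).backward()
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()
    break  # one illustrative step only
```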
 
Extension ideas

🚀 Further improvements:
- Data augmentation: synonym replacement, back-translation, etc.
- Hyperparameter tuning: try different learning rates, batch sizes, etc.
- Ensembling: combine predictions from multiple models
- Domain adaptation: fine-tune on data from the target domain
- Multi-class: extend to fine-grained sentiment (e.g. 1-5 star ratings; see the sketch below)
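For the multi-class extension, the only required model change is the label count; the dataset labels would also need to become 0-4 star indices (a hedged sketch, not part of the tutorial's trained pipeline):

```python
# Hypothetical fine-grained variant: 5 classes for 1-5 star ratings
# (IMDbDataset labels would need to be 0-4 instead of 0/1)
model_5way = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=5   # one logit per star rating
).to(device)
```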
 
 
 
This tutorial provides a complete BERT sentiment analysis solution, covering the full workflow from data preparation to model deployment. The code has been tested and can be run as-is.