import torch
import torch.nn as nn
from transformers import AutoModel


class TransClassifier(nn.Module):
    def __init__(self, model_dir: str, output_classes: int, device: str = "cuda"):
        super().__init__()
        # Load the frozen transformer backbone in bfloat16 with FlashAttention-2.
        self.backbone = AutoModel.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        ).to(device).eval()
        self.device = device
        self.torch_dtype = torch.bfloat16
        self.hidden_size = self.backbone.config.hidden_size

        # Trainable head: a per-token projection followed by an MLP classifier.
        self.token_proj = nn.Linear(self.hidden_size, self.hidden_size).to(
            device=device, dtype=self.torch_dtype
        )
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Linear(self.hidden_size, self.hidden_size // 2),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_size // 2, self.hidden_size // 4),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(self.hidden_size // 4, output_classes),
        ).to(device=device, dtype=self.torch_dtype)

        # Freeze the backbone so only token_proj and classifier receive gradients.
        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, model_inputs: dict):
        outputs = self.backbone(**model_inputs)
        proj_states = self.token_proj(outputs.last_hidden_state)

        # Masked mean pooling: average token states over non-padding positions only.
        attention_mask = model_inputs["attention_mask"]
        mask_expanded = attention_mask.unsqueeze(-1).expand_as(proj_states).to(proj_states.dtype)
        sum_states = (proj_states * mask_expanded).sum(dim=1)
        valid_tokens = mask_expanded.sum(dim=1)
        pooled = sum_states / valid_tokens.clamp(min=1e-9)

        logits = self.classifier(pooled)
        return logits


if __name__ == "__main__":
    model_dir = r"C:\Users\GA\Desktop\models\Qwen3-1.7B"
    device = "cuda"
    # output_classes was missing from the original call; 2 here is just an example value.
    model = TransClassifier(model_dir, output_classes=2, device=device)
    print(model.hidden_size)
    print(model)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
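
    # Minimal forward-pass smoke test (a sketch, not part of the original script).
    # It assumes a tokenizer ships with the same model_dir and that a CUDA device
    # is available; the sample sentence is purely illustrative.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    sample = tokenizer(
        ["This is a short example sentence."],
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)

    with torch.no_grad():
        logits = model(sample)
    print(f"Logits shape: {tuple(logits.shape)}")  # expected: (1, output_classes)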