Upload files to "model"
This commit is contained in:
1
model/__init__.py
Normal file
1
model/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .modelling import TransClassifier
|
||||
52
model/modelling.py
Normal file
52
model/modelling.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoModel
|
||||
|
||||
class TransClassifier(nn.Module):
    """Classification head on top of a frozen pretrained transformer backbone.

    The backbone is loaded with ``AutoModel.from_pretrained`` in bfloat16,
    moved to ``device``, put in eval mode and has gradients disabled for all
    of its parameters. Only the MLP head in ``self.classifier`` is trainable.

    Args:
        model_dir: Path or identifier passed to ``AutoModel.from_pretrained``.
        device: Device the backbone and the head are moved to.
        num_classes: Size of the output layer (defaults to 2, the value that
            was previously hard-coded).
    """

    def __init__(self, model_dir: str, device: str = "cuda", num_classes: int = 2):
        super().__init__()
        # NOTE(review): the ``dtype=`` keyword is kept exactly as in the
        # original; whether it or ``torch_dtype=`` is accepted depends on the
        # installed transformers version -- confirm against the pinned release.
        self.backbone = AutoModel.from_pretrained(
            model_dir,
            dtype="bfloat16",
        ).to(device).eval()
        self.device = device
        self.torch_dtype = torch.bfloat16
        self.hidden_size = self.backbone.config.hidden_size

        half_width = self.hidden_size // 2
        quarter_width = self.hidden_size // 4
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Linear(self.hidden_size, half_width),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(half_width, quarter_width),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(quarter_width, num_classes),
        ).to(device=device, dtype=self.torch_dtype)

        # Freeze the backbone: only the classifier head receives gradients.
        self.backbone.requires_grad_(False)

    def train(self, mode: bool = True):
        """Set the training mode of the head while keeping the backbone in eval.

        ``nn.Module.train`` recurses into every child module, which would
        silently undo the ``.eval()`` applied to the frozen backbone in
        ``__init__``. The backbone is therefore forced back to eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, as ``nn.Module.train`` does.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, model_inputs: dict) -> torch.Tensor:
        """Run the backbone and classify the hidden state of the last real token.

        Args:
            model_inputs: Mapping unpacked as keyword arguments into the
                backbone. If it contains a 2-D ``attention_mask``, that mask is
                used to locate the last attended token of every sequence.

        Returns:
            The classifier output for the selected token of each sequence.
        """
        outputs = self.backbone(**model_inputs)
        last_hidden_state = outputs.last_hidden_state

        attention_mask = model_inputs.get("attention_mask")
        if attention_mask is None or attention_mask.dim() != 2:
            # No usable mask: fall back to the final position of each sequence.
            pooled_hidden_state = last_hidden_state[:, -1, :]
        else:
            # With right padding the final position is a pad token, so pick the
            # last position whose mask is non-zero instead. For left-padded or
            # unpadded batches this is still the final position.
            attended = (attention_mask != 0).to(
                device=last_hidden_state.device, dtype=torch.long
            )
            # argmax returns the first maximal index, i.e. the number of
            # trailing pad positions once the mask is reversed. An all-zero
            # row yields 0 and therefore falls back to the final position.
            trailing_pad = attended.flip(dims=[1]).argmax(dim=1)
            last_token_idx = attended.shape[1] - 1 - trailing_pad
            batch_idx = torch.arange(
                last_hidden_state.shape[0], device=last_hidden_state.device
            )
            pooled_hidden_state = last_hidden_state[batch_idx, last_token_idx]

        logits = self.classifier(pooled_hidden_state)
        return logits
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: build the classifier from a local checkpoint, print
    # its structure, and report how many parameters are trainable.
    classifier = TransClassifier(r"C:\Users\GA\Desktop\models\Qwen3-1.7B", "cuda")
    print(classifier.hidden_size)
    print(classifier)

    # Tally total and trainable parameter counts in a single pass.
    param_total = 0
    param_trainable = 0
    for tensor in classifier.parameters():
        count = tensor.numel()
        param_total += count
        if tensor.requires_grad:
            param_trainable += count

    # The printed labels read "total parameter count" and
    # "trainable parameter count" respectively.
    print(f"总参数量: {param_total:,}")
    print(f"可训练参数量: {param_trainable:,}")
|
||||
Reference in New Issue
Block a user