Update model/modelling.py
@@ -4,16 +4,18 @@ import torch.nn.functional as F
 from transformers import AutoModel
 
 class TransClassifier(nn.Module):
-    def __init__(self, model_dir: str, device: str="cuda"):
+    def __init__(self, model_dir: str, output_classes: int, device: str="cuda"):
         super().__init__()
         self.backbone = AutoModel.from_pretrained(
             model_dir,
-            dtype = "bfloat16"
+            dtype = "bfloat16",
+            attn_implementation="flash_attention_2"
         ).to(device).eval()
         self.device = device
         self.torch_dtype = torch.bfloat16
         self.hidden_size = self.backbone.config.hidden_size
 
+        self.token_proj = nn.Linear(self.hidden_size, self.hidden_size).to(device=device, dtype=self.torch_dtype)
         self.classifier = nn.Sequential(
             nn.LayerNorm(self.hidden_size),
             nn.Linear(self.hidden_size, self.hidden_size//2),
@@ -22,7 +24,7 @@ class TransClassifier(nn.Module):
             nn.Linear(self.hidden_size//2, self.hidden_size//4),
             nn.GELU(),
             nn.Dropout(0.2),
-            nn.Linear(self.hidden_size//4, 2)
+            nn.Linear(self.hidden_size//4, output_classes)
         ).to(device=device, dtype=self.torch_dtype)
 
         for param in self.backbone.parameters():
@@ -30,12 +32,15 @@ class TransClassifier(nn.Module):
 
     def forward(self, model_inputs: dict):
         outputs = self.backbone(**model_inputs)
+        proj_states = self.token_proj(outputs.last_hidden_state)
 
-        last_hidden_state = outputs.last_hidden_state
-        # take last token hidden state
-        cls_hidden_state = last_hidden_state[:, -1, :]
+        attention_mask = model_inputs['attention_mask']
+        mask_expanded = attention_mask.unsqueeze(-1).expand_as(proj_states).to(proj_states.dtype)
+        sum_states = (proj_states * mask_expanded).sum(dim=1)
+        valid_tokens = mask_expanded.sum(dim=1)
+        pooled = sum_states / valid_tokens.clamp(min=1e-9)
 
-        logits = self.classifier(cls_hidden_state)
+        logits = self.classifier(pooled)
         return logits
 
 if __name__ == "__main__":
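
For reference, a minimal usage sketch of the updated class. The checkpoint directory, tokenizer, example texts, and class count below are placeholders, the import assumes model/ is importable as a package, and attn_implementation="flash_attention_2" assumes the flash-attn package and a CUDA device are available:

import torch
from transformers import AutoTokenizer

from model.modelling import TransClassifier

# Placeholder path: any Hugging Face checkpoint whose model exposes last_hidden_state.
MODEL_DIR = "path/to/backbone"

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
# output_classes=3 is an arbitrary example; the head now sizes its final layer from it.
model = TransClassifier(MODEL_DIR, output_classes=3, device="cuda")

batch = tokenizer(
    ["a short example", "a second, somewhat longer example sentence"],
    padding=True,
    truncation=True,
    return_tensors="pt",
).to("cuda")

with torch.no_grad():
    logits = model(dict(batch))   # shape: (batch_size, output_classes)
print(logits.shape)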
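
The forward change swaps the last-token read-out for an attention-mask-weighted mean over the projected token states. A standalone sketch of that arithmetic on toy tensors (not the real backbone), showing what pooled ends up being:

import torch

torch.manual_seed(0)
hidden = torch.randn(2, 5, 8)                # toy "proj_states": (batch, seq_len, hidden)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])       # padding mask, like model_inputs['attention_mask']

# Same computation as the updated forward():
mask_expanded = mask.unsqueeze(-1).expand_as(hidden).to(hidden.dtype)
pooled = (hidden * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1).clamp(min=1e-9)

# Reference: mean over only the non-padding positions of each sequence.
reference = torch.stack([hidden[0, :3].mean(dim=0), hidden[1].mean(dim=0)])
assert torch.allclose(pooled, reference, atol=1e-6)

Unlike the previous last-token read-out, the masked mean ignores padding positions, so a sequence's logits do not depend on how much padding it receives when batched.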