torchTextClassifiers/torchTextClassifiers.py: 15 changes (7 additions, 8 deletions)
@@ -50,11 +50,11 @@ class ModelConfig:
     """Base configuration class for text classifiers."""
 
     embedding_dim: int
-    num_classes: int
     categorical_vocabulary_sizes: Optional[List[int]] = None
     categorical_embedding_dims: Optional[Union[List[int], int]] = None
+    num_classes: Optional[int] = None
     attention_config: Optional[AttentionConfig] = None
-    label_attention_config: Optional[LabelAttentionConfig] = None
+    n_heads_label_attention: Optional[int] = None
 
     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
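
The hunk above flattens the public configuration: callers now request label attention with a plain head count (`n_heads_label_attention`) instead of building a `LabelAttentionConfig` themselves, and `num_classes` becomes optional. A minimal sketch of the new call site, assuming the import path matches the file layout in this diff; the concrete values are placeholders:

```python
from torchTextClassifiers.torchTextClassifiers import ModelConfig

# After this PR: label attention is switched on with a plain integer.
config = ModelConfig(
    embedding_dim=128,           # placeholder value
    num_classes=10,              # now optional; still needed when label attention is used
    n_heads_label_attention=4,   # replaces passing a LabelAttentionConfig directly
)

# Before this PR, the caller had to build the nested config object itself,
# roughly: ModelConfig(..., label_attention_config=LabelAttentionConfig(...))
```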
@@ -142,7 +142,7 @@ def __init__(
         self.embedding_dim = model_config.embedding_dim
         self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes
         self.num_classes = model_config.num_classes
-        self.enable_label_attention = model_config.label_attention_config is not None
+        self.enable_label_attention = model_config.n_heads_label_attention is not None
 
         if self.tokenizer.output_vectorized:
             self.text_embedder = None
@@ -156,7 +156,10 @@ def __init__(
                 embedding_dim=self.embedding_dim,
                 padding_idx=tokenizer.padding_idx,
                 attention_config=model_config.attention_config,
-                label_attention_config=model_config.label_attention_config,
+                label_attention_config=LabelAttentionConfig(
+                    n_head=model_config.n_heads_label_attention,
+                    num_classes=model_config.num_classes,
+                ),
             )
             self.text_embedder = TextEmbedder(
                 text_embedder_config=text_embedder_config,
@@ -697,10 +700,6 @@ def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassif
 
         # Reconstruct model_config
         model_config = ModelConfig.from_dict(metadata["model_config"])
-        if isinstance(model_config.label_attention_config, dict):
-            model_config.label_attention_config = LabelAttentionConfig(
-                **model_config.label_attention_config
-            )
 
         # Create instance
         instance = cls(
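
The deleted block in `load()` is no longer needed because the label-attention setting is now a plain integer: it survives the `to_dict()` / `from_dict()` round trip as-is, with no nested `LabelAttentionConfig` to rebuild from a dict. A hedged sketch, assuming `from_dict` simply unpacks a plain field dict back into the dataclass:

```python
config = ModelConfig(embedding_dim=128, num_classes=10, n_heads_label_attention=4)

payload = config.to_dict()            # asdict(): n_heads_label_attention stays a plain int
restored = ModelConfig.from_dict(payload)

# Nothing label-attention-related has to be coerced back into a dataclass
# after loading; the integer round-trips untouched.
assert restored.n_heads_label_attention == 4
```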