Large language models (LLMs) demonstrate impressive capabilities across numerous domains, from programming and technical support to cooking instructions and research assistance. However, they continue to struggle with precise symbol manipulation in tasks requiring rigorous logical reasoning, mathematical proofs, and complex algebraic operations. As a student of philosophy and computer science, I find this intersection especially fascinating because of the connections between logical precision and conceptual clarity. I wonder whether Wittgenstein would have been fascinated by these models, given his notion of "language games"; but I digress.
The Reasoning Gap in Current LLMs
Despite the introduction of techniques like chain-of-thought prompting, which encourages models to reason step by step and check their own work, fundamental limitations persist. LLMs still hallucinate and perform inconsistently when tackling tasks that demand sustained logical reasoning. The challenge is particularly evident when models need to correctly apply formal logic rules across extended inference chains.
Several approaches have attempted to address these shortcomings:
- Logical Neural Networks (LNNs) implement logical operators (AND, OR, NOT) using specialized neural network architectures
- Neural Semantic Parsing converts natural language into structured logical representations
- Chain-of-thought prompting encourages step-by-step reasoning and self-verification
While these approaches show promise, they haven't fully resolved the reasoning deficiencies inherent in transformer-based LLMs.
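To make the last of these concrete, chain-of-thought prompting requires nothing more than a carefully worded prompt. The snippet below is a minimal illustration; the prompt text and the generate call are placeholders for whichever model interface you use.

# Minimal chain-of-thought prompt: the instruction to reason step by step is the
# entire technique; no changes to the model itself are involved.
cot_prompt = (
    "Premises: All birds can fly. Penguins are birds.\n"
    "Question: Does it follow that penguins can fly?\n"
    "Work through the premises step by step, check each inference, "
    "and then state your final answer."
)
# response = some_llm.generate(cot_prompt)  # hypothetical call; any text-generation API works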
Logic-Enhanced Technique: A Core Architectural Solution

Rather than treating reasoning enhancement as a post-processing step or a prompt engineering challenge, I propose a novel method that targets the fundamental architecture of transformer models. As established in the groundbreaking "Attention Is All You Need" paper, the self-attention mechanism forms the backbone of modern LLMs, enabling them to process relationships between all tokens in a sequence.
The Logic-Enhanced Transformer approach modifies this core attention mechanism in three steps:
- Logical Structure Identification: Using computational linguistics (via libraries like spaCy), the model identifies logical constructs within text, including:
- Operators (not, and, or, if-then)
- Quantifiers (all, every, some, none)
- Modal operators (must, could, should)
- Causal relationships (because, since, due to)
- Comparative structures (more than, less than, equal to)
- Weighted Attention Masks: The model dynamically generates an attention weighting mask that amplifies the connections between logically related elements.
- Enhanced Attention Computation: During the self-attention calculation, logical tokens receive elevated attention weights, ensuring the model properly captures and prioritizes logical relationships (a minimal sketch of this weighting step follows this list).
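To make the weighting step concrete, here is a stripped-down sketch, independent of the full implementation at the end of this post. The tensor names and toy dimensions are illustrative only.

import torch

# Toy shapes: batch=1, heads=2, seq_len=4; values are illustrative only.
attn_weights = torch.softmax(torch.randn(1, 2, 4, 4), dim=-1)    # standard attention weights
logic_mask = torch.tensor([[1.0, 2.5, 1.0, 2.0]])                # per-token logical weights (1.0 = neutral)

# Amplify cells whose query and/or key position is a logical token...
row_mask = logic_mask.unsqueeze(1).unsqueeze(-1)   # (1, 1, 4, 1): scales query positions
col_mask = logic_mask.unsqueeze(1).unsqueeze(2)    # (1, 1, 1, 4): scales key positions
boosted = attn_weights * row_mask * col_mask

# ...then renormalize so each row sums to 1 again (the full implementation
# re-applies softmax; dividing by the row sum is the same idea).
boosted = boosted / boosted.sum(dim=-1, keepdim=True)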
Implementation Details and Technical Innovations
The Logic-Enhanced Transformer introduces several technical innovations. (Note: this model was fine-tuned from a base LLaMA 3 8B model; the same methods can be applied to other models, though some model-specific adjustments may be needed.) A brief usage sketch follows the list.
- SpacyLogicParser: A specialized component that identifies logical structures in text and creates a corresponding attention weighting mask
- LogicEnhancedLlamaAttention: A modified attention mechanism that incorporates logic-based weighting alongside standard self-attention
- Memory Optimization: Implementation of 8-bit quantization, bfloat16 precision, and Flash Attention 2 for efficient processing
- Parameter-Efficient Fine-Tuning: Integration with LoRA (Low-Rank Adaptation) for resource-efficient adaptation to logical reasoning tasks
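As a rough usage sketch: the class and method names below are the ones defined in the supplemental code at the end of this post, the prompt text is illustrative, and the snippet assumes you have access to the gated meta-llama/Meta-Llama-3-8B weights and a suitable GPU environment.

# Build the logic-enhanced model (loads LLaMA 3 8B in 8-bit, patches attention, applies LoRA)
llm = MemoryOptimizedLogicEnhancedLLaMA("meta-llama/Meta-Llama-3-8B")

prompt = "If all birds can fly and penguins are birds, does it follow that penguins can fly?"

# Generate with and without the logic-weighted attention for comparison
enhanced = llm.generate_with_logic(prompt, use_logic_mask=True, max_new_tokens=256)
baseline = llm.generate_with_logic(prompt, use_logic_mask=False, max_new_tokens=256)
print(enhanced)
print(baseline)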
Expected Benefits
This architectural modification offers several advantages over previous approaches:
- Targeted Enhancement: Rather than requiring complete retraining, it enhances the specific capability of logical reasoning
- Interpretable Approach: The logical structure identification provides transparency about which elements receive enhanced attention
- Integration with Existing Models: Works with mainstream transformer architectures like LLaMA without requiring specialized architectures
- Efficiency: The implementation is optimized for memory usage and computational overhead
Evaluation and Future Directions
Initial testing with classic logical reasoning problems, such as superficially valid syllogisms that rest on a false premise ("If all birds can fly and penguins are birds, then penguins can fly"), shows promising results in the model's ability to identify where such arguments go wrong. Training on the FOLIO dataset has also shown promising results, although I'm still testing and refining the model.
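For anyone who wants to reproduce this kind of informal probing, a minimal skeleton might look like the following. It assumes the llm object from the earlier usage sketch; the probe sentences and the keyword check are my own simplifications, not a rigorous benchmark.

# Hand-written probes: each pairs a flawed argument with a keyword we hope to see
# in a correct critique. This is a sanity check, not an evaluation benchmark.
probes = [
    ("If all birds can fly and penguins are birds, then penguins can fly. Is this argument sound?", "premise"),
    ("Some cats are black. Felix is a cat. Therefore Felix is black. Is this valid?", "invalid"),
]

for prompt, expected_keyword in probes:
    answer = llm.generate_with_logic(prompt, use_logic_mask=True, max_new_tokens=200)
    hit = expected_keyword.lower() in answer.lower()
    print(f"{'PASS' if hit else 'CHECK'} | {prompt[:60]}...")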
Future work should focus on:
- Rigorous evaluation across diverse logical reasoning benchmarks
- Analysis of the computational overhead introduced by the parsing step
- Investigation into how this enhancement affects other model capabilities
- Integration with other reasoning-enhancement techniques like chain-of-thought prompting
By addressing reasoning at the architectural level rather than through prompting techniques or post-processing, the Logic-Enhanced Transformer approach represents a promising step toward equipping language models with stronger logical reasoning capabilities.
You'll find supplemental code below. If there are any issues with the code, please let me know. Note that this is not the complete codebase, but it should suffice: the main mechanism is here. (The full code still needs to be cleaned up.)
Please note, I do not have a PhD in machine learning or computer science; all of this is done through experimentation, reading, and learning. I would also like to give a special shoutout to Sebastian Raschka, whose book "Build a Large Language Model (From Scratch)" greatly helped enhance my knowledge. Also, thanks to Hugging Face for all of their material and, of course, for providing the models.
# Imports assumed by the snippets below (the full script defines these at the top)
import os
import types
from typing import Optional, Tuple

import spacy
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

CACHE_DIR = None  # placeholder: set to your preferred Hugging Face cache directory

class SpacyLogicParser:
def __init__(self, model_name="en_core_web_lg"):
debug_print(f"Initializing SpacyLogicParser with model {model_name}")
try:
self.nlp = spacy.load(model_name)
debug_print("Loaded spaCy model successfully")
except OSError as e:
debug_print(f"Error loading spaCy model: {e}")
debug_print("Falling back to smaller model")
try:
# Try to load the smaller model
self.nlp = spacy.load("en_core_web_sm")
debug_print("Loaded fallback spaCy model successfully")
except OSError as e2:
debug_print(f"Error loading fallback model: {e2}")
raise
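# Lexical cues grouped by logical category; tokens matching these receive
# attention weights greater than the neutral 1.0 in the mask built below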
self.logical_terms = {
"operators": ["not", "and", "or", "if", "then", "implies", "unless", "except"],
"quantifiers": ["all", "every", "any", "each", "some", "many", "few", "no", "none"],
"modals": ["must", "should", "could", "may", "might", "can", "will", "would"],
"causal": ["because", "since", "due", "causes", "results", "leads", "follows"],
"comparison": ["more", "less", "greater", "equal", "same", "different", "than"]
}
def parse_logical_structure(self, text):
debug_print(f"Parsing logical structure of text: {text[:50]}...")
doc = self.nlp(text)
logical_tokens = {}
self._process_dependency_parse(doc, logical_tokens)
debug_print(f"Found {len(logical_tokens)} logical tokens")
return logical_tokens
def _process_dependency_parse(self, doc, logical_tokens):
for i, token in enumerate(doc):
token_lower = token.text.lower()
for category, terms in self.logical_terms.items():
if token_lower in terms:
weight = 2.5 if category == "operators" else 2.0
logical_tokens[i] = {"type": f"logical_{category}", "weight": weight}
if token.dep_ == "neg" or token_lower in ["not", "never", "no"]:
logical_tokens[i] = {"type": "negation", "weight": 2.5}
head_idx = token.head.i
logical_tokens[head_idx] = {"type": "negated_term", "weight": 2.0}
if token.pos_ == "DET" and token_lower in self.logical_terms["quantifiers"]:
logical_tokens[i] = {"type": "quantifier", "weight": 2.5}
head_idx = token.head.i
logical_tokens[head_idx] = {"type": "quantified_term", "weight": 2.0}
if token.pos_ == "SCONJ" and token_lower in ["if", "unless", "because", "although"]:
logical_tokens[i] = {"type": "logical_connective", "weight": 2.5}
for descendant in token.subtree:
if descendant.i != i:
logical_tokens[descendant.i] = {"type": "logical_clause", "weight": 1.8}
self._identify_conditional_structures(doc, logical_tokens)
self._identify_comparative_structures(doc, logical_tokens)
self._identify_causal_relationships(doc, logical_tokens)
self._identify_logical_equivalence(doc, logical_tokens)
def _identify_conditional_structures(self, doc, logical_tokens):
for token in doc:
if token.text.lower() in ["if", "when", "unless"]:
if_clause_tokens = list(token.subtree)
for t in if_clause_tokens:
logical_tokens[t.i] = {"type": "condition_if", "weight": 2.0}
if token.head.pos_ == "VERB":
main_verb = token.head
if_clause_indices = {t.i for t in if_clause_tokens}
for t in main_verb.subtree:
if t.i not in if_clause_indices and t.i != main_verb.i:
logical_tokens[t.i] = {"type": "condition_then", "weight": 1.8}
logical_tokens[main_verb.i] = {"type": "condition_then_verb", "weight": 2.2}
def _identify_comparative_structures(self, doc, logical_tokens):
for token in doc:
if token.tag_ in ["JJR", "RBR"] or token.text.lower() in ["more", "less", "greater", "fewer"]:
logical_tokens[token.i] = {"type": "comparison", "weight": 2.0}
for child in token.children:
if child.dep_ == "than":
logical_tokens[child.i] = {"type": "comparison_than", "weight": 1.8}
for than_child in child.children:
logical_tokens[than_child.i] = {"type": "comparison_object", "weight": 1.8}
def _identify_causal_relationships(self, doc, logical_tokens):
for token in doc:
if token.text.lower() in ["because", "since", "as"] and token.dep_ == "mark":
logical_tokens[token.i] = {"type": "causal_marker", "weight": 2.5}
for t in token.head.subtree:
if t.i != token.i:
logical_tokens[t.i] = {"type": "cause", "weight": 1.8}
if token.head.head.pos_ == "VERB":
effect_verb = token.head.head
cause_indices = {t.i for t in token.head.subtree}
for t in effect_verb.subtree:
if t.i not in cause_indices and t.i != effect_verb.i:
logical_tokens[t.i] = {"type": "effect", "weight": 1.8}
logical_tokens[effect_verb.i] = {"type": "effect_verb", "weight": 2.0}
def _identify_logical_equivalence(self, doc, logical_tokens):
for token in doc:
if token.lemma_ in ["be", "equal", "mean", "imply", "equivalent"]:
if token.pos_ in ["VERB", "AUX"]:
logical_tokens[token.i] = {"type": "equivalence", "weight": 2.2}
subject = None
obj = None
for child in token.children:
if child.dep_ in ["nsubj", "nsubjpass"]:
subject = child
elif child.dep_ in ["attr", "dobj", "pobj"]:
obj = child
if subject:
for t in subject.subtree:
logical_tokens[t.i] = {"type": "equivalence_subject", "weight": 1.8}
if obj:
for t in obj.subtree:
logical_tokens[t.i] = {"type": "equivalence_object", "weight": 1.8}
def create_logic_attention_mask(self, text, tokenizer):
debug_print(f"Creating logic attention mask for text: {text[:30]}...")
try:
logical_tokens = self.parse_logical_structure(text)
tokenized = tokenizer(text, return_tensors="pt")
input_ids = tokenized["input_ids"][0]
logic_mask = torch.ones(input_ids.size(), dtype=torch.float)
spacy_to_transformer = self._align_tokenizations(text, tokenizer)
for spacy_idx, token_info in logical_tokens.items():
transformer_indices = spacy_to_transformer.get(spacy_idx, [])
for idx in transformer_indices:
if 0 <= idx < len(logic_mask):
logic_mask[idx] = token_info["weight"]
debug_print(f"Created logic mask with shape {logic_mask.unsqueeze(0).shape}")
return logic_mask.unsqueeze(0)
except Exception as e:
debug_print(f"Error creating logic mask: {e}")
import traceback
traceback.print_exc()
tokenized = tokenizer(text, return_tensors="pt")
input_ids = tokenized["input_ids"][0]
default_mask = torch.ones(input_ids.size(), dtype=torch.float).unsqueeze(0)
debug_print(f"Returning default mask with shape {default_mask.shape}")
return default_mask
def _align_tokenizations(self, text, tokenizer):
debug_print("Aligning spaCy and transformer tokenizations")
doc = self.nlp(text)
transformer_encoding = tokenizer.encode_plus(text, return_offsets_mapping=True)
transformer_offsets = transformer_encoding["offset_mapping"]
mapping = {}
for spacy_token in doc:
spacy_start = spacy_token.idx
spacy_end = spacy_start + len(spacy_token.text)
matching_transformer_indices = []
for transformer_idx, (trans_start, trans_end) in enumerate(transformer_offsets):
if trans_start == trans_end == 0:
continue
if max(spacy_start, trans_start) < min(spacy_end, trans_end):
matching_transformer_indices.append(transformer_idx)
mapping[spacy_token.i] = matching_transformer_indices
debug_print(f"Alignment mapping created for {len(mapping)} tokens")
return mapping
class LogicEnhancedLlamaAttention(torch.nn.Module):
def __init__(self, original_attention):
super().__init__()
debug_print("Initializing LogicEnhancedLlamaAttention")
self.original_attention = original_attention
self.q_proj = original_attention.q_proj
self.k_proj = original_attention.k_proj
self.v_proj = original_attention.v_proj
self.o_proj = original_attention.o_proj
self.rotary_emb = getattr(original_attention, 'rotary_emb', None)
self.original_forward = original_attention.forward
self.rope_scaling = getattr(original_attention, 'rope_scaling', None)
debug_print(f"Initialized with original attention of type {type(original_attention)}")
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
logic_mask: Optional[torch.Tensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
debug_print(f"LogicEnhancedLlamaAttention.forward called with logic_mask: {logic_mask}")
if logic_mask is None or not (torch.is_tensor(logic_mask) and logic_mask.numel() > 0):
debug_print("Using original attention (no valid logic mask)")
return self.original_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
**kwargs
)
try:
debug_print("Using logic-enhanced attention")
attn_outputs = self.original_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=True,
use_cache=use_cache,
**kwargs
)
attn_output = attn_outputs[0]
# LlamaAttention.forward returns (attn_output, attn_weights, past_key_value),
# so the attention weights sit at index 1 when output_attentions=True
attn_weights = attn_outputs[1] if len(attn_outputs) > 1 else None
if use_cache and len(attn_outputs) > 2:
    past_key_value = attn_outputs[2]
if attn_weights is not None:
    debug_print("Got attention weights, applying logic mask")
    debug_print(f"Attention weights shape: {attn_weights.shape}")
if logic_mask.dim() == 2:
debug_print(f"Logic mask shape: {logic_mask.shape}")
seq_len = hidden_states.size(1)
mask_len = logic_mask.size(1)
debug_print(f"Sequence length: {seq_len}, Mask length: {mask_len}")
if mask_len < seq_len:
debug_print(f"Padding logic mask from {mask_len} to {seq_len}")
padding = torch.ones(logic_mask.size(0), seq_len - mask_len, device=logic_mask.device)
logic_mask = torch.cat([logic_mask, padding], dim=1)
elif mask_len > seq_len:
debug_print(f"Truncating logic mask from {mask_len} to {seq_len}")
logic_mask = logic_mask[:, :seq_len]
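# Boost every attention cell whose query and/or key token carries logical weight:
# row_mask scales along the query dimension, col_mask along the key dimension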
row_mask = logic_mask.unsqueeze(1).unsqueeze(-1)
col_mask = logic_mask.unsqueeze(1).unsqueeze(2)
debug_print("Applying row and column masks to attention weights")
modified_attn_weights = attn_weights * row_mask * col_mask
debug_print("Normalizing modified attention weights")
modified_attn_weights = torch.nn.functional.softmax(
modified_attn_weights, dim=-1, dtype=torch.float32
).to(attn_weights.dtype)
batch_size, seq_length = hidden_states.shape[:2]
value_states = self.v_proj(hidden_states)
if hasattr(self.original_attention, 'num_heads'):
    num_heads = self.original_attention.num_heads
    debug_print(f"Using num_heads={num_heads} from original attention")
else:
    num_heads = attn_weights.size(1)
    debug_print(f"Inferred num_heads={num_heads} from attention weights")
# LLaMA 3 uses grouped-query attention: v_proj produces num_key_value_heads heads,
# which must be repeated to line up with the number of attention heads
num_kv_heads = getattr(self.original_attention, 'num_key_value_heads', num_heads)
head_dim = value_states.size(-1) // num_kv_heads
debug_print(f"head_dim={head_dim}, num_kv_heads={num_kv_heads}")
try:
    # Note: this recomputation assumes query and key lengths match
    # (a full prefill pass with no KV-cache offset)
    debug_print(f"Reshaping value states from {value_states.shape}")
    value_states = value_states.view(batch_size, seq_length, num_kv_heads, head_dim).transpose(1, 2)
    if num_kv_heads != num_heads:
        value_states = value_states.repeat_interleave(num_heads // num_kv_heads, dim=1)
    debug_print(f"Reshaped value states to {value_states.shape}")
    debug_print("Computing new attention output")
    attn_output = torch.matmul(modified_attn_weights, value_states)
    debug_print(f"New attention output shape: {attn_output.shape}")
    attn_output = attn_output.transpose(1, 2).contiguous().reshape(batch_size, seq_length, -1)
    debug_print("Applying output projection")
    attn_output = self.o_proj(attn_output)
    # Match the (attn_output, attn_weights, past_key_value) ordering that the
    # LLaMA decoder layer unpacks from its attention module
    outputs = (attn_output,)
    outputs += (modified_attn_weights if output_attentions else None,)
    outputs += (past_key_value if use_cache else None,)
    debug_print("Returning modified attention outputs")
    return outputs
except RuntimeError as e:
debug_print(f"Error in logic-enhanced attention matrix operations: {e}")
import traceback
traceback.print_exc()
debug_print("Falling back to original attention output")
pass
debug_print("Returning original attention outputs")
return attn_outputs
except Exception as e:
debug_print(f"Exception in logic-enhanced attention: {e}")
import traceback
traceback.print_exc()
debug_print("Falling back to original attention due to exception")
return self.original_forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
**kwargs
)
def add_logic_enhanced_attention(model):
"""
Replace the standard attention in each layer with logic-enhanced attention.
This is the primary mechanism for adding logic enhancement.
"""
debug_print("Adding logic-enhanced attention to model")
layers = identify_layers_path(model)
debug_print(f"Found {len(layers)} layers")
enhanced_count = 0
for i, layer in enumerate(layers):
if hasattr(layer, 'self_attn'):
debug_print(f"Enhancing attention in layer {i}")
original_attention = layer.self_attn
layer.self_attn = LogicEnhancedLlamaAttention(original_attention)
enhanced_count += 1
else:
debug_print(f"Layer {i} does not have self_attn attribute")
debug_print(f"Enhanced {enhanced_count} attention layers")
return model
def debug_print(msg):
# Replace with your actual debug print implementation (e.g., logging)
print(f"[DEBUG] {msg}")
def identify_layers_path(model):
"""Replace with your actual layer identification logic."""
# Example assuming a LLaMA-like structure
return model.model.layers # Access LLaMA-specific layers
class MemoryOptimizedLogicEnhancedLLaMA:
def __init__(self, model_name="meta-llama/Meta-Llama-3-8B"):
debug_print(f"Initializing MemoryOptimizedLogicEnhancedLLaMA with model {model_name}")
# Set environment variable for PyTorch memory management
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256"
# Load tokenizer with same settings
debug_print("Loading tokenizer")
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
cache_dir=CACHE_DIR,
use_auth_token=True,
trust_remote_code=True
)
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Load model with memory optimization flags
debug_print("Loading model")
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
cache_dir=CACHE_DIR,
device_map="auto",
use_auth_token=True,
trust_remote_code=True,
load_in_8bit=True, # Use 8-bit quantization
torch_dtype=torch.bfloat16, # Use bfloat16 precision
attn_implementation="flash_attention_2",  # Flash Attention 2 for H100; note it does not return attention weights, so the logic re-weighting path needs an eager attention implementation to receive them
)
debug_print(f"Model loaded with type {type(self.model)}")
# Initialize logic parser
debug_print("Initializing logic parser")
self.logic_parser = SpacyLogicParser()
# Get and store reference to model layers for later use
debug_print("Identifying model layers")
self.model_layers = identify_layers_path(self.model)
# Enhance the model with logic-aware attention
debug_print("Enhancing model with logic-aware attention")
self.model = add_logic_enhanced_attention(self.model)
debug_print("Applying LoRA for fine-tuning") # Re-added LoRA
self.apply_lora() # Re-added LoRA
def apply_lora(self):
"""Apply LoRA adapter for memory-efficient fine-tuning"""
debug_print("Configuring LoRA")
lora_config = LoraConfig( # Re-added LoRA
r=16, # Rank of the update matrices
lora_alpha=32, # Parameter for scaling
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
task_type=TaskType.CAUSAL_LM
)
debug_print("Applying LoRA to model")
self.model = get_peft_model(self.model, lora_config)
self.model.print_trainable_parameters()
def process_with_logic(self, text, use_logic_mask=True):
"""Process text with or without logic enhancement."""
debug_print(f"Processing text (use_logic_mask={use_logic_mask})")
inputs = self.tokenizer(text, return_tensors="pt", padding=True).to(self.model.device)
if use_logic_mask:
debug_print("Creating logic mask")
logic_mask = self.logic_parser.create_logic_attention_mask(text, self.tokenizer)
logic_mask = logic_mask.to(self.model.device)
debug_print("Forwarding with logic mask")
outputs = self.forward_with_logic_mask(inputs, logic_mask)
else:
debug_print("Standard processing without logic mask")
outputs = self.model(**inputs)
return outputs
def forward_with_logic_mask(self, inputs, logic_mask):
"""Helper method to forward inputs with a logic mask."""
debug_print("In forward_with_logic_mask")
input_dict = {k: v for k, v in inputs.items()}
debug_print(f"Setting logic mask for {len(self.model_layers)} layers")
for layer in self.model_layers:
if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, LogicEnhancedLlamaAttention):
if not hasattr(layer.self_attn, '_original_forward'):
layer.self_attn._original_forward = layer.self_attn.forward
def wrapped_forward(self_attn, *args, logic_mask=logic_mask, **kwargs):
kwargs['logic_mask'] = logic_mask
return self_attn._original_forward(*args, **kwargs)
layer.self_attn.forward = types.MethodType(wrapped_forward, layer.self_attn)
try:
debug_print("Calling model with modified layers")
outputs = self.model(**input_dict)
for layer in self.model_layers:
if hasattr(layer, 'self_attn') and hasattr(layer.self_attn, '_original_forward'):
layer.self_attn.forward = layer.self_attn._original_forward
return outputs
except Exception as e:
debug_print(f"Error in forward_with_logic_mask: {e}")
import traceback
traceback.print_exc()
for layer in self.model_layers:
if hasattr(layer, 'self_attn') and hasattr(layer.self_attn, '_original_forward'):
layer.self_attn.forward = layer.self_attn._original_forward
debug_print("Falling back to standard processing")
return self.model(**input_dict)
def generate_with_logic(self, text, use_logic_mask=True, max_new_tokens=1024, **kwargs):
"""Generate text with or without logic enhancement."""
debug_print(f"Generating text (use_logic_mask={use_logic_mask})")
inputs = self.tokenizer(text, return_tensors="pt", padding=True).to(self.model.device)
generation_kwargs = {
"max_new_tokens": max_new_tokens,
"num_return_sequences": 1,
"pad_token_id": self.tokenizer.pad_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
**kwargs
}
if not use_logic_mask:
debug_print("Using original attention (no logic mask)")
outputs = self.model.generate(**inputs, **generation_kwargs)
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
debug_print(f"Generated text length: {len(generated_text)}")
return generated_text
debug_print("Using logic-enhanced generation")
try:
logic_mask = self.logic_parser.create_logic_attention_mask(text, self.tokenizer)
logic_mask = logic_mask.to(self.model.device)
debug_print(f"Created logic mask with shape: {logic_mask.shape}")
for layer in self.model_layers:
    if hasattr(layer, 'self_attn'):
        # Remember the currently bound forward so it can be restored after generation
        layer.self_attn._saved_forward = layer.self_attn.forward
        def make_enhanced_forward(original_forward):
            def logic_enhanced_forward(self_attn, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False, **kwargs):
                debug_print("Inside logic_enhanced_forward")
                kwargs['logic_mask'] = logic_mask
                return original_forward(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, **kwargs)
            return logic_enhanced_forward
        layer.self_attn.forward = types.MethodType(make_enhanced_forward(layer.self_attn._saved_forward), layer.self_attn)
try:
    outputs = self.model.generate(**inputs, **generation_kwargs)
finally:
    # Restore the forward that was bound before this call (restoring original_forward
    # here would silently bypass the logic-enhanced attention on later calls)
    for layer in self.model_layers:
        if hasattr(layer, 'self_attn') and hasattr(layer.self_attn, '_saved_forward'):
            layer.self_attn.forward = layer.self_attn._saved_forward
except Exception as e:
debug_print(f"Error in logic-enhanced generation: {e}")
import traceback
traceback.print_exc()
outputs = self.model.generate(**inputs, **generation_kwargs)
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
debug_print(f"Generated text length: {len(generated_text)}")
return generated_text