A Novel "Reasoning"-Enhancing Technique for Large Language Models


Large language models (LLMs) demonstrate impressive capabilities across numerous domains, from programming and technical support to cooking instructions and research assistance. However, they continue to struggle with precise symbol manipulation in tasks requiring rigorous logical reasoning, mathematical proofs, and complex algebraic operations. As a student of philosophy and computer science, I find this intersection especially fascinating because of the connections between logical precision and conceptual clarity. I wonder whether Wittgenstein would have been fascinated by these models, given his notion of "language games"; but I digress.

The Reasoning Gap in Current LLMs

Despite the introduction of techniques like chain-of-thought prompting, which encourages models to spell out intermediate reasoning steps, fundamental limitations persist. LLMs still hallucinate and perform inconsistently when tackling tasks that demand sustained logical reasoning. The challenge is particularly evident when models need to correctly apply formal inference rules across extended chains of reasoning.
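To make the distinction concrete, here is a minimal illustration of a chain-of-thought prompt next to a direct prompt. The wording is my own and purely illustrative, not a fixed recipe:

# Minimal illustration of chain-of-thought prompting (wording is illustrative).
direct_prompt = "Q: If all A are B and all B are C, are all A C? Answer yes or no."
cot_prompt = (
    "Q: If all A are B and all B are C, are all A C?\n"
    "Let's reason step by step before giving a final yes/no answer."
)
# The CoT variant tends to elicit intermediate reasoning, but it does not
# guarantee that each inference step is logically valid.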

Several approaches have attempted to address these shortcomings, most of them operating at the level of prompting or post-processing rather than the model's architecture itself.

While these approaches show promise, they haven't fully resolved the reasoning deficiencies inherent in transformer-based LLMs.

Logic-Enhanced Technique: A Core Architectural Solution

[Figure: Transformer architecture diagram]

Rather than treating reasoning enhancement as a post-processing step or a prompt engineering challenge, I propose a novel method that targets the fundamental architecture of transformer models. As established in the groundbreaking "Attention Is All You Need" paper, the self-attention mechanism forms the backbone of modern LLMs, enabling them to process relationships between all tokens in a sequence.
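For readers who want the mechanism in code, here is a minimal PyTorch sketch of single-head scaled dot-product attention as described in that paper; variable names and shapes are illustrative:

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Q, K, V: (batch, seq_len, d_k), following the "Attention Is All You Need" formulation
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # (batch, seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)              # each row sums to 1
    return weights @ V                               # (batch, seq_len, d_k)

Q = K = V = torch.randn(1, 5, 16)
out = scaled_dot_product_attention(Q, K, V)          # -> shape (1, 5, 16)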

The Logic-Enhanced Technique modifies this core attention mechanism directly: a dependency parse marks logically salient tokens (operators, quantifiers, negations, conditionals, causal markers, and comparisons), and the attention weights are scaled toward those tokens and then renormalized, as shown in the sketch below and implemented in the supplemental code.
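Conceptually, the core step is small: scale the attention weights by per-token logic weights and renormalize. The supplemental code below does exactly this inside each attention layer; here is a stripped-down sketch of just that step, with illustrative shapes and weight values:

import torch
import torch.nn.functional as F

# attn_weights: (batch, heads, seq, seq), the softmax output of standard attention
# logic_mask:   (batch, seq) per-token weights, e.g. 2.5 for "not"/"all", 1.0 otherwise
def apply_logic_mask(attn_weights, logic_mask):
    row = logic_mask[:, None, :, None]   # scales rows (attending tokens)
    col = logic_mask[:, None, None, :]   # scales columns (attended-to tokens)
    reweighted = attn_weights * row * col
    # Renormalize so each row is a probability distribution again
    return F.softmax(reweighted, dim=-1)

attn = F.softmax(torch.randn(1, 2, 4, 4), dim=-1)
mask = torch.tensor([[1.0, 2.5, 1.0, 2.0]])
print(apply_logic_mask(attn, mask).shape)  # torch.Size([1, 2, 4, 4])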

Implementation Details and Technical Innovations

The Logic-Enhanced Transformer introduces several technical innovations. (Note: this model was fine-tuned on a base LLaMA 3 8B model; the same methods can be applied to other models, although some model-specific changes may be needed, as sketched below.)
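For example, porting the technique to another decoder-only model mostly comes down to pointing the layer lookup at that architecture's module path. The sketch below is a hedged generalization of the identify_layers_path helper from the supplemental code; the attribute paths are assumptions based on common Hugging Face layouts and should be verified against the target model:

def identify_layers_path_generic(model):
    """Best-effort lookup of decoder layers across common architectures.

    The attribute paths are assumptions based on typical Hugging Face layouts
    (LLaMA/Mistral-style model.layers, GPT-2-style transformer.h, GPT-NeoX-style
    gpt_neox.layers); verify them against the specific model you are adapting.
    """
    for path in ("model.layers", "transformer.h", "gpt_neox.layers"):
        obj = model
        try:
            for attr in path.split("."):
                obj = getattr(obj, attr)
            return obj
        except AttributeError:
            continue
    raise ValueError("Could not locate decoder layers; adjust the path for your model.")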

Expected Benefits

This architectural modification offers several advantages over previous approaches.

Evaluation and Future Directions

Initial testing with classic logical reasoning problems, such as the syllogism "If all birds can fly and penguins are birds, then penguins can fly" (valid in form but unsound, since the first premise is false), shows promising results in the model's ability to identify such flaws. Training on the FOLIO dataset has also shown promising results, although I'm still testing and refining the model.
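For a sense of what a single check looks like, here is a hedged sketch of how I build a FOLIO-style validity prompt; the wording is mine, and the dataset id "yale-nlp/FOLIO" on the Hugging Face Hub is an assumption to verify before use:

# Hedged sketch: building a FOLIO-style validity prompt for evaluation.
def build_validity_prompt(premises, conclusion):
    premise_text = "\n".join(f"- {p}" for p in premises)
    return (
        f"Premises:\n{premise_text}\n"
        f"Conclusion: {conclusion}\n"
        "Is the conclusion logically entailed by the premises? Answer True, False, or Unknown."
    )

prompt = build_validity_prompt(
    ["All birds can fly.", "Penguins are birds."],
    "Penguins can fly.",
)
# print(llm.generate_with_logic(prompt))  # llm: the wrapper class defined in the code below
# from datasets import load_dataset
# folio = load_dataset("yale-nlp/FOLIO", split="validation")  # dataset id is an assumption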

Future work should focus on continued evaluation (for example, on the FOLIO dataset) and further refinement of the model.

By addressing reasoning at the architectural level rather than through prompting techniques or post-processing, the Logic-Enhanced Transformer approach represents a fundamental advancement in our ability to equip language models with stronger logical reasoning capabilities.

You'll find supplemental code below. If there are any issues with the code, please let me know. Note that this is not the complete code, but the main mechanism is here; the full version still needs to be cleaned up.

Please note, I do not have a PhD in machine learning or computer science; all of this is done through experimentation, reading, and learning. I would also like to give a special shoutout to Sebastian Raschka, whose book "Build a Large Language Model (From Scratch)" greatly enhanced my knowledge. Thanks as well to Hugging Face for all of their material and, of course, for providing the models.

import os
import types
from typing import Optional, Tuple

import spacy
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

# CACHE_DIR is referenced below but was defined outside this excerpt;
# None falls back to the default Hugging Face cache location.
CACHE_DIR = None


class SpacyLogicParser:
   def __init__(self, model_name="en_core_web_lg"):
       debug_print(f"Initializing SpacyLogicParser with model {model_name}")
       try:
           self.nlp = spacy.load(model_name)
           debug_print("Loaded spaCy model successfully")
       except OSError as e:
           debug_print(f"Error loading spaCy model: {e}")
           debug_print("Falling back to smaller model")
           try:
               # Try to load the smaller model
               self.nlp = spacy.load("en_core_web_sm")
               debug_print("Loaded fallback spaCy model successfully")
           except OSError as e2:
               debug_print(f"Error loading fallback model: {e2}")
               raise
      
       self.logical_terms = {
           "operators": ["not", "and", "or", "if", "then", "implies", "unless", "except"],
           "quantifiers": ["all", "every", "any", "each", "some", "many", "few", "no", "none"],
           "modals": ["must", "should", "could", "may", "might", "can", "will", "would"],
           "causal": ["because", "since", "due", "causes", "results", "leads", "follows"],
           "comparison": ["more", "less", "greater", "equal", "same", "different", "than"]
       }
  
   def parse_logical_structure(self, text):
       debug_print(f"Parsing logical structure of text: {text[:50]}...")
       doc = self.nlp(text)
       logical_tokens = {}
       self._process_dependency_parse(doc, logical_tokens)
       debug_print(f"Found {len(logical_tokens)} logical tokens")
       return logical_tokens
  
   def _process_dependency_parse(self, doc, logical_tokens):
       for i, token in enumerate(doc):
           token_lower = token.text.lower()
           for category, terms in self.logical_terms.items():
               if token_lower in terms:
                   weight = 2.5 if category == "operators" else 2.0
                   logical_tokens[i] = {"type": f"logical_{category}", "weight": weight}
           if token.dep_ == "neg" or token_lower in ["not", "never", "no"]:
               logical_tokens[i] = {"type": "negation", "weight": 2.5}
               head_idx = token.head.i
               logical_tokens[head_idx] = {"type": "negated_term", "weight": 2.0}
           if token.pos_ == "DET" and token_lower in self.logical_terms["quantifiers"]:
               logical_tokens[i] = {"type": "quantifier", "weight": 2.5}
               head_idx = token.head.i
               logical_tokens[head_idx] = {"type": "quantified_term", "weight": 2.0}
           if token.pos_ == "SCONJ" and token_lower in ["if", "unless", "because", "although"]:
               logical_tokens[i] = {"type": "logical_connective", "weight": 2.5}
               for descendant in token.subtree:
                   if descendant.i != i:
                       logical_tokens[descendant.i] = {"type": "logical_clause", "weight": 1.8}
       self._identify_conditional_structures(doc, logical_tokens)
       self._identify_comparative_structures(doc, logical_tokens)
       self._identify_causal_relationships(doc, logical_tokens)
       self._identify_logical_equivalence(doc, logical_tokens)
  
   def _identify_conditional_structures(self, doc, logical_tokens):
       for token in doc:
           if token.text.lower() in ["if", "when", "unless"]:
               if_clause_tokens = list(token.subtree)
               for t in if_clause_tokens:
                   logical_tokens[t.i] = {"type": "condition_if", "weight": 2.0}
               if token.head.pos_ == "VERB":
                   main_verb = token.head
                   if_clause_indices = {t.i for t in if_clause_tokens}
                   for t in main_verb.subtree:
                       if t.i not in if_clause_indices and t.i != main_verb.i:
                           logical_tokens[t.i] = {"type": "condition_then", "weight": 1.8}
                   logical_tokens[main_verb.i] = {"type": "condition_then_verb", "weight": 2.2}
  
   def _identify_comparative_structures(self, doc, logical_tokens):
       for token in doc:
           if token.tag_ in ["JJR", "RBR"] or token.text.lower() in ["more", "less", "greater", "fewer"]:
               logical_tokens[token.i] = {"type": "comparison", "weight": 2.0}
               for child in token.children:
                   if child.dep_ == "than":
                       logical_tokens[child.i] = {"type": "comparison_than", "weight": 1.8}
                       for than_child in child.children:
                           logical_tokens[than_child.i] = {"type": "comparison_object", "weight": 1.8}
  
   def _identify_causal_relationships(self, doc, logical_tokens):
       for token in doc:
           if token.text.lower() in ["because", "since", "as"] and token.dep_ == "mark":
               logical_tokens[token.i] = {"type": "causal_marker", "weight": 2.5}
               for t in token.head.subtree:
                   if t.i != token.i:
                       logical_tokens[t.i] = {"type": "cause", "weight": 1.8}
               if token.head.head.pos_ == "VERB":
                   effect_verb = token.head.head
                   cause_indices = {t.i for t in token.head.subtree}
                   for t in effect_verb.subtree:
                       if t.i not in cause_indices and t.i != effect_verb.i:
                           logical_tokens[t.i] = {"type": "effect", "weight": 1.8}
                   logical_tokens[effect_verb.i] = {"type": "effect_verb", "weight": 2.0}
  
   def _identify_logical_equivalence(self, doc, logical_tokens):
       for token in doc:
           if token.lemma_ in ["be", "equal", "mean", "imply", "equivalent"]:
               if token.pos_ in ["VERB", "AUX"]:
                   logical_tokens[token.i] = {"type": "equivalence", "weight": 2.2}
                   subject = None
                   obj = None
                   for child in token.children:
                       if child.dep_ in ["nsubj", "nsubjpass"]:
                           subject = child
                       elif child.dep_ in ["attr", "dobj", "pobj"]:
                           obj = child
                   if subject:
                       for t in subject.subtree:
                           logical_tokens[t.i] = {"type": "equivalence_subject", "weight": 1.8}
                   if obj:
                       for t in obj.subtree:
                           logical_tokens[t.i] = {"type": "equivalence_object", "weight": 1.8}
  
   def create_logic_attention_mask(self, text, tokenizer):
       debug_print(f"Creating logic attention mask for text: {text[:30]}...")
       try:
           logical_tokens = self.parse_logical_structure(text)
           tokenized = tokenizer(text, return_tensors="pt")
           input_ids = tokenized["input_ids"][0]
           logic_mask = torch.ones(input_ids.size(), dtype=torch.float)
           spacy_to_transformer = self._align_tokenizations(text, tokenizer)
           for spacy_idx, token_info in logical_tokens.items():
               transformer_indices = spacy_to_transformer.get(spacy_idx, [])
               for idx in transformer_indices:
                   if 0 <= idx < len(logic_mask):
                       logic_mask[idx] = token_info["weight"]
          
           debug_print(f"Created logic mask with shape {logic_mask.unsqueeze(0).shape}")
           return logic_mask.unsqueeze(0)
       except Exception as e:
           debug_print(f"Error creating logic mask: {e}")
           import traceback
           traceback.print_exc()
           tokenized = tokenizer(text, return_tensors="pt")
           input_ids = tokenized["input_ids"][0]
           default_mask = torch.ones(input_ids.size(), dtype=torch.float).unsqueeze(0)
           debug_print(f"Returning default mask with shape {default_mask.shape}")
           return default_mask
  
   def _align_tokenizations(self, text, tokenizer):
       debug_print("Aligning spaCy and transformer tokenizations")
       doc = self.nlp(text)
       transformer_encoding = tokenizer.encode_plus(text, return_offsets_mapping=True)
       transformer_offsets = transformer_encoding["offset_mapping"]
       mapping = {}
       for spacy_token in doc:
           spacy_start = spacy_token.idx
           spacy_end = spacy_start + len(spacy_token.text)
           matching_transformer_indices = []
           for transformer_idx, (trans_start, trans_end) in enumerate(transformer_offsets):
               if trans_start == trans_end == 0:
                   continue
               if max(spacy_start, trans_start) < min(spacy_end, trans_end):
                   matching_transformer_indices.append(transformer_idx)
           mapping[spacy_token.i] = matching_transformer_indices
      
       debug_print(f"Alignment mapping created for {len(mapping)} tokens")
       return mapping
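

# Example usage (illustrative, kept as comments so the listing runs as-is):
# the parser can be inspected on its own before wiring it into the model.
# A fast tokenizer is required because create_logic_attention_mask relies on
# return_offsets_mapping; the model id below is just one example.
#
# parser = SpacyLogicParser()
# tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
# mask = parser.create_logic_attention_mask("If all birds fly, then penguins fly.", tok)
# print(mask)  # tokens such as "If", "all", and "then" receive weights above 1.0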


class LogicEnhancedLlamaAttention(torch.nn.Module):
   def __init__(self, original_attention):
       super().__init__()
       debug_print("Initializing LogicEnhancedLlamaAttention")
       self.original_attention = original_attention
       self.q_proj = original_attention.q_proj
       self.k_proj = original_attention.k_proj
       self.v_proj = original_attention.v_proj
       self.o_proj = original_attention.o_proj
       self.rotary_emb = getattr(original_attention, 'rotary_emb', None)
      
       self.original_forward = original_attention.forward
       self.rope_scaling = getattr(original_attention, 'rope_scaling', None)
      
       debug_print(f"Initialized with original attention of type {type(original_attention)}")
  
   def forward(
       self,
       hidden_states: torch.Tensor,
       attention_mask: Optional[torch.Tensor] = None,
       position_ids: Optional[torch.LongTensor] = None,
       past_key_value: Optional[Tuple[torch.Tensor]] = None,
       output_attentions: bool = False,
       use_cache: bool = False,
       logic_mask: Optional[torch.Tensor] = None,
       **kwargs
   ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
       debug_print(f"LogicEnhancedLlamaAttention.forward called with logic_mask: {logic_mask}")
      
       if logic_mask is None or not (torch.is_tensor(logic_mask) and logic_mask.numel() > 0):
           debug_print("Using original attention (no valid logic mask)")
           return self.original_forward(
               hidden_states=hidden_states,
               attention_mask=attention_mask,
               position_ids=position_ids,
               past_key_value=past_key_value,
               output_attentions=output_attentions,
               use_cache=use_cache,
               **kwargs
           )
      
       try:
           debug_print("Using logic-enhanced attention")
           attn_outputs = self.original_forward(
               hidden_states=hidden_states,
               attention_mask=attention_mask,
               position_ids=position_ids,
               past_key_value=past_key_value,
               output_attentions=True,
               use_cache=use_cache,
               **kwargs
           )
          
           # transformers' LlamaAttention returns (attn_output, attn_weights, past_key_value)
           attn_output = attn_outputs[0]
           attn_weights = attn_outputs[1] if len(attn_outputs) > 1 else None
           if use_cache and len(attn_outputs) > 2:
               past_key_value = attn_outputs[2]

           if attn_weights is not None:
               debug_print("Got attention weights, applying logic mask")
               debug_print(f"Attention weights shape: {attn_weights.shape}")
              
               if logic_mask.dim() == 2:
                   debug_print(f"Logic mask shape: {logic_mask.shape}")
                   seq_len = hidden_states.size(1)
                   mask_len = logic_mask.size(1)
                  
                   debug_print(f"Sequence length: {seq_len}, Mask length: {mask_len}")
                  
                   if mask_len < seq_len:
                       debug_print(f"Padding logic mask from {mask_len} to {seq_len}")
                       padding = torch.ones(logic_mask.size(0), seq_len - mask_len, device=logic_mask.device)
                       logic_mask = torch.cat([logic_mask, padding], dim=1)
                   elif mask_len > seq_len:
                       debug_print(f"Truncating logic mask from {mask_len} to {seq_len}")
                       logic_mask = logic_mask[:, :seq_len]
                  
                   row_mask = logic_mask.unsqueeze(1).unsqueeze(-1)
                   col_mask = logic_mask.unsqueeze(1).unsqueeze(2)
                  
                   debug_print("Applying row and column masks to attention weights")
                   modified_attn_weights = attn_weights * row_mask * col_mask
                  
                   debug_print("Normalizing modified attention weights")
                   modified_attn_weights = torch.nn.functional.softmax(
                       modified_attn_weights, dim=-1, dtype=torch.float32
                   ).to(attn_weights.dtype)
                  
                   batch_size, seq_length = hidden_states.shape[:2]
                   value_states = self.v_proj(hidden_states)
                  
                   if hasattr(self.original_attention, 'num_heads'):
                       num_heads = self.original_attention.num_heads
                       debug_print(f"Using num_heads={num_heads} from original attention")
                   else:
                       num_heads = attn_weights.size(1)
                       debug_print(f"Inferred num_heads={num_heads} from attention weights")

                   # LLaMA 3 uses grouped-query attention: v_proj produces num_key_value_heads
                   # heads, which must be repeated to match the query heads before the matmul.
                   num_kv_heads = getattr(self.original_attention, 'num_key_value_heads', num_heads)
                   head_dim = value_states.size(-1) // num_kv_heads
                   debug_print(f"num_kv_heads={num_kv_heads}, head_dim={head_dim}")

                   try:
                       debug_print(f"Reshaping value states from {value_states.shape}")
                       value_states = value_states.view(batch_size, seq_length, num_kv_heads, head_dim).transpose(1, 2)
                       if num_kv_heads != num_heads:
                           value_states = value_states.repeat_interleave(num_heads // num_kv_heads, dim=1)
                       debug_print(f"Reshaped value states to {value_states.shape}")
                      
                       debug_print("Computing new attention output")
                       attn_output = torch.matmul(modified_attn_weights, value_states)
                       debug_print(f"New attention output shape: {attn_output.shape}")
                       attn_output = attn_output.transpose(1, 2).contiguous().reshape(batch_size, seq_length, -1)
                      
                       debug_print("Applying output projection")
                       attn_output = self.o_proj(attn_output)
                      
                       # Match transformers' LlamaAttention return convention:
                       # (attn_output, attn_weights, past_key_value)
                       outputs = (
                           attn_output,
                           modified_attn_weights if output_attentions else None,
                           past_key_value,
                       )

                       debug_print("Returning modified attention outputs")
                       return outputs
                      
                   except RuntimeError as e:
                       debug_print(f"Error in logic-enhanced attention matrix operations: {e}")
                       import traceback
                       traceback.print_exc()
                       debug_print("Falling back to original attention output")
                       pass
          
           debug_print("Returning original attention outputs")
           return attn_outputs
          
       except Exception as e:
           debug_print(f"Exception in logic-enhanced attention: {e}")
           import traceback
           traceback.print_exc()
           debug_print("Falling back to original attention due to exception")
           return self.original_forward(
               hidden_states=hidden_states,
               attention_mask=attention_mask,
               position_ids=position_ids,
               past_key_value=past_key_value,
               output_attentions=output_attentions,
               use_cache=use_cache,
               **kwargs
           )


def add_logic_enhanced_attention(model):
   """
   Replace the standard attention in each layer with logic-enhanced attention.
   This is the primary mechanism for adding logic enhancement.
   """
   debug_print("Adding logic-enhanced attention to model")
   layers = identify_layers_path(model)
   debug_print(f"Found {len(layers)} layers")
  
   enhanced_count = 0
   for i, layer in enumerate(layers):
       if hasattr(layer, 'self_attn'):
           debug_print(f"Enhancing attention in layer {i}")
           original_attention = layer.self_attn
           layer.self_attn = LogicEnhancedLlamaAttention(original_attention)
           enhanced_count += 1
       else:
           debug_print(f"Layer {i} does not have self_attn attribute")
  
   debug_print(f"Enhanced {enhanced_count} attention layers")
   return model




def debug_print(msg):
   # Replace with your actual debug print implementation (e.g., logging)
   print(f"[DEBUG] {msg}")


def identify_layers_path(model):
   """Replace with your actual layer identification logic."""
   # Example assuming a LLaMA-like structure
   return model.model.layers  # Access LLaMA-specific layers


class MemoryOptimizedLogicEnhancedLLaMA:
   def __init__(self, model_name="meta-llama/Meta-Llama-3-8B"):
       debug_print(f"Initializing MemoryOptimizedLogicEnhancedLLaMA with model {model_name}")
      
       # Set environment variable for PyTorch memory management
       os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256"
      
       # Load tokenizer with same settings
       debug_print("Loading tokenizer")
       self.tokenizer = AutoTokenizer.from_pretrained(
           model_name,
           cache_dir=CACHE_DIR,
           use_auth_token=True,
           trust_remote_code=True
       )
       if not self.tokenizer.pad_token:
           self.tokenizer.pad_token = self.tokenizer.eos_token
      
       # Load model with memory optimization flags
       debug_print("Loading model")
       self.model = AutoModelForCausalLM.from_pretrained(
           model_name,
           cache_dir=CACHE_DIR,
           device_map="auto",
           use_auth_token=True,
           trust_remote_code=True,
           load_in_8bit=True,           # Use 8-bit quantization
           torch_dtype=torch.bfloat16,  # Use bfloat16 precision
           # "eager" attention is required here: Flash Attention 2 cannot return the
           # attention weights that the logic mask is applied to.
           attn_implementation="eager",
       )
      
       debug_print(f"Model loaded with type {type(self.model)}")
      
       # Initialize logic parser
       debug_print("Initializing logic parser")
       self.logic_parser = SpacyLogicParser()
      
       # Get and store reference to model layers for later use
       debug_print("Identifying model layers")
       self.model_layers = identify_layers_path(self.model)
      
       # Enhance the model with logic-aware attention
       debug_print("Enhancing model with logic-aware attention")
       self.model = add_logic_enhanced_attention(self.model)
      
       debug_print("Applying LoRA for fine-tuning")  # Re-added LoRA
       self.apply_lora()  # Re-added LoRA


   def apply_lora(self):
       """Apply LoRA adapter for memory-efficient fine-tuning"""
       debug_print("Configuring LoRA")
       lora_config = LoraConfig(
           r=16,                     # Rank of the update matrices
           lora_alpha=32,            # Parameter for scaling
           target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
           lora_dropout=0.05,
           bias="none",
           task_type=TaskType.CAUSAL_LM
       )
      
       debug_print("Applying LoRA to model")
       self.model = get_peft_model(self.model, lora_config)
       self.model.print_trainable_parameters()
  
   def process_with_logic(self, text, use_logic_mask=True):
       """Process text with or without logic enhancement."""
       debug_print(f"Processing text (use_logic_mask={use_logic_mask})")
       inputs = self.tokenizer(text, return_tensors="pt", padding=True).to(self.model.device)
      
       if use_logic_mask:
           debug_print("Creating logic mask")
           logic_mask = self.logic_parser.create_logic_attention_mask(text, self.tokenizer)
           logic_mask = logic_mask.to(self.model.device)
          
           debug_print("Forwarding with logic mask")
           outputs = self.forward_with_logic_mask(inputs, logic_mask)
       else:
           debug_print("Standard processing without logic mask")
           outputs = self.model(**inputs)
      
       return outputs
  
   def forward_with_logic_mask(self, inputs, logic_mask):
       """Helper method to forward inputs with a logic mask."""
       debug_print("In forward_with_logic_mask")
       input_dict = {k: v for k, v in inputs.items()}
      
       debug_print(f"Setting logic mask for {len(self.model_layers)} layers")
       for layer in self.model_layers:
           if hasattr(layer, 'self_attn') and isinstance(layer.self_attn, LogicEnhancedLlamaAttention):
               if not hasattr(layer.self_attn, '_original_forward'):
                   layer.self_attn._original_forward = layer.self_attn.forward


               def wrapped_forward(self_attn, *args, logic_mask=logic_mask, **kwargs):
                   kwargs['logic_mask'] = logic_mask
                   return self_attn._original_forward(*args, **kwargs)


               layer.self_attn.forward = types.MethodType(wrapped_forward, layer.self_attn)


       try:
           debug_print("Calling model with modified layers")
           outputs = self.model(**input_dict)


           for layer in self.model_layers:
               if hasattr(layer, 'self_attn') and hasattr(layer.self_attn, '_original_forward'):
                   layer.self_attn.forward = layer.self_attn._original_forward


           return outputs
       except Exception as e:
           debug_print(f"Error in forward_with_logic_mask: {e}")
           import traceback
           traceback.print_exc()
           for layer in self.model_layers:
               if hasattr(layer, 'self_attn') and hasattr(layer.self_attn, '_original_forward'):
                   layer.self_attn.forward = layer.self_attn._original_forward
           debug_print("Falling back to standard processing")
           return self.model(**input_dict)


   def generate_with_logic(self, text, use_logic_mask=True, max_new_tokens=1024, **kwargs):
       """Generate text with or without logic enhancement."""
       debug_print(f"Generating text (use_logic_mask={use_logic_mask})")
       inputs = self.tokenizer(text, return_tensors="pt", padding=True).to(self.model.device)
      
       generation_kwargs = {
           "max_new_tokens": max_new_tokens,
           "num_return_sequences": 1,
           "pad_token_id": self.tokenizer.pad_token_id,
           "eos_token_id": self.tokenizer.eos_token_id,
           **kwargs
       }
      
       if not use_logic_mask:
           debug_print("Using original attention (no logic mask)")
           outputs = self.model.generate(**inputs, **generation_kwargs)
           generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
           debug_print(f"Generated text length: {len(generated_text)}")
           return generated_text
      
       debug_print("Using logic-enhanced generation")
       try:
           logic_mask = self.logic_parser.create_logic_attention_mask(text, self.tokenizer)
           logic_mask = logic_mask.to(self.model.device)
           debug_print(f"Created logic mask with shape: {logic_mask.shape}")


           for layer in self.model_layers:
               if hasattr(layer, 'self_attn'):
                   original_attn_forward = layer.self_attn.forward


                   def make_enhanced_forward(original_forward):
                       def logic_enhanced_forward(self_attn, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False, **kwargs):
                           debug_print("Inside logic_enhanced_forward")
                           kwargs['logic_mask'] = logic_mask
                           return original_forward(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, **kwargs)
                       return logic_enhanced_forward


                   layer.self_attn.forward = types.MethodType(make_enhanced_forward(original_attn_forward), layer.self_attn)


           try:
               outputs = self.model.generate(**inputs, **generation_kwargs)
           finally:
               # Remove the temporary instance-level override so that the class's
               # LogicEnhancedLlamaAttention.forward is used again on later calls
               # (restoring `original_forward` here would bypass the wrapper entirely).
               for layer in self.model_layers:
                   if hasattr(layer, 'self_attn') and 'forward' in layer.self_attn.__dict__:
                       del layer.self_attn.__dict__['forward']


       except Exception as e:
           debug_print(f"Error in logic-enhanced generation: {e}")
           import traceback
           traceback.print_exc()
           outputs = self.model.generate(**inputs, **generation_kwargs)
      
       generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
       debug_print(f"Generated text length: {len(generated_text)}")
       return generated_text
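

Finally, here is a short usage sketch showing how the pieces fit together end to end. The model id and prompt are illustrative, and running it requires a GPU, bitsandbytes, and access to the gated LLaMA 3 weights:

if __name__ == "__main__":
    # Illustrative end-to-end run: compare baseline generation with the logic-enhanced path.
    llm = MemoryOptimizedLogicEnhancedLLaMA("meta-llama/Meta-Llama-3-8B")

    prompt = "If all birds can fly and penguins are birds, can penguins fly? Explain."
    baseline = llm.generate_with_logic(prompt, use_logic_mask=False, max_new_tokens=128)
    enhanced = llm.generate_with_logic(prompt, use_logic_mask=True, max_new_tokens=128)

    print("--- baseline ---")
    print(baseline)
    print("--- logic-enhanced ---")
    print(enhanced)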