Dynatemp ooba hf #1

Open · wants to merge 29 commits into base: main

Commits (29)
fad6d37
WIP dynamic temp ooba HF support
kalomaze Jan 5, 2024
07929de
update max
kalomaze Jan 6, 2024
00ed020
Initiate float to store temp + preliminary fixes
kalomaze Jan 6, 2024
31c671b
oops
kalomaze Jan 6, 2024
501ba08
ok the last 2 commits broke idk why
kalomaze Jan 6, 2024
3672b53
Remove irrelevant copypaste init bits
kalomaze Jan 6, 2024
e33acf4
Make it use the dynatemp range value properly
kalomaze Jan 6, 2024
939bf05
Fix mirostat (?) and Max Entropy calculation
kalomaze Jan 6, 2024
31d3824
Remove nonsensical comments and return old ones
kalomaze Jan 6, 2024
5788ad8
Fix whitespace and remove irrelevant comment
kalomaze Jan 6, 2024
ae9461e
Initialize Dynamic Temp value for OAI extension
kalomaze Jan 6, 2024
6c70dc7
Ensure min_temp has a maximum of zero
kalomaze Jan 6, 2024
4f3264c
Guard against div by zero for both entropy calcs
kalomaze Jan 6, 2024
25f3cab
Properly ensure div by zero for max entropy calc
kalomaze Jan 6, 2024
ae476d6
Update UI description
kalomaze Jan 6, 2024
529daac
Fix tiny merge conflict
kalomaze Jan 6, 2024
85828d8
Sync Dynatemp branch with latest mainline ooba
kalomaze Jan 6, 2024
1cc7a14
Attempt to fix duplicated Temperature logic
kalomaze Jan 6, 2024
44e8a92
Lint
oobabooga Jan 7, 2024
941d257
Use a single warper for temperature and dynamic temperature
oobabooga Jan 7, 2024
33821b0
Comment the debug statements
oobabooga Jan 7, 2024
4849c57
Various minor changes
oobabooga Jan 7, 2024
951b268
Minor changes
oobabooga Jan 7, 2024
4023be2
Always replace temperature with TemperatureLogitsWarperWithDynatemp
oobabooga Jan 7, 2024
2fc441f
Add an extension for dynamic temperature with range
oobabooga Jan 7, 2024
6306927
Fix silent exception when temperature is int
oobabooga Jan 7, 2024
ba65b3c
Fix a logits issue with llamacpp_HF
oobabooga Jan 7, 2024
aa78dfd
Add a Dynamic Temperature preset
oobabooga Jan 7, 2024
09d5dd7
Document the new extension
oobabooga Jan 7, 2024
1 change: 1 addition & 0 deletions docs/03 - Parameters Tab.md
@@ -54,6 +54,7 @@ For more information about the parameters, the [transformers documentation](http
* **mirostat_mode**: Activates the Mirostat sampling technique. It aims to control perplexity during sampling. See the [paper](https://arxiv.org/abs/2007.14966).
* **mirostat_tau**: No idea, see the paper for details. According to the Preset Arena, 8 is a good value.
* **mirostat_eta**: No idea, see the paper for details. According to the Preset Arena, 0.1 is a good value.
* **dynatemp**: Dynamic Temperature is activated when this parameter is greater than 0. The temperature range is determined by adding and subtracting dynatemp from the current temperature.
* **temperature_last**: Makes temperature the last sampler instead of the first. With this, you can remove low probability tokens with a sampler like min_p and then use a high temperature to make the model creative without losing coherency.
* **do_sample**: When unchecked, sampling is entirely disabled, and greedy decoding is used instead (the most likely token is always picked).
* **Seed**: Set the Pytorch seed to this number. Note that some loaders do not use Pytorch (notably llama.cpp), and others are not deterministic (notably ExLlama v1 and v2). For these loaders, the seed has no effect.
17 changes: 17 additions & 0 deletions extensions/dynatemp_with_range/README.md
@@ -0,0 +1,17 @@
# dynatemp_with_range

This extension makes it possible to set the minimum and maximum temperatures for dynamic temperature explicitly.

For instance, you can directly set

```
min_T = 0.1
max_T = 3
```

instead of having to convert that to

```
T = 1.55
dynatemp = 1.45
```
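
As a quick illustration of the conversion described above (a standalone sketch, not taken from the diff), it is just a midpoint/half-width calculation:

```python
# Illustrative sketch: convert an explicit [min_T, max_T] range into the
# temperature/dynatemp pair that the sampler expects.
def to_temperature_and_dynatemp(min_t, max_t):
    temperature = (min_t + max_t) / 2  # midpoint of the range
    dynatemp = max_t - temperature     # half-width of the range
    return temperature, dynatemp

print(to_temperature_and_dynatemp(0.1, 3))  # (1.55, 1.45)
```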
50 changes: 50 additions & 0 deletions extensions/dynatemp_with_range/script.py
@@ -0,0 +1,50 @@
import gradio as gr

params = {
    "activate": True,
    "minimum_temperature": 0.1,
    "maximum_temperature": 2,
}

def convert_to_dynatemp():
    temperature = 0.5 * (params["minimum_temperature"] + params["maximum_temperature"])
    dynatemp = params["maximum_temperature"] - temperature
    return temperature, dynatemp


def state_modifier(state):
    """
    Modifies the state variable, which is a dictionary containing the input
    values in the UI like sliders and checkboxes.
    """

    if params["activate"]:
        temperature, dynatemp = convert_to_dynatemp()

        state["temperature"] = temperature
        state["dynatemp"] = dynatemp

    return state


def generate_info():
    temperature, dynatemp = convert_to_dynatemp()
    return f"The combination above is equivalent to: T={temperature:.2f}, dynatemp={dynatemp:.2f}"


def ui():
    activate = gr.Checkbox(value=params['activate'], label='Activate Dynamic Temperature Range', info='When checked, the default temperature/dynatemp parameters are ignored and the parameters below are used instead.')
    with gr.Row():
        minimum_temperature = gr.Slider(0, 5, step=0.01, label="Minimum temperature", value=params["minimum_temperature"], interactive=True)
        maximum_temperature = gr.Slider(0, 5, step=0.01, label="Maximum temperature", value=params["maximum_temperature"], interactive=True)

    info = gr.HTML(generate_info())

    activate.change(lambda x: params.update({"activate": x}), activate, None)
    minimum_temperature.change(
        lambda x: params.update({"minimum_temperature": x}), minimum_temperature, None).then(
        generate_info, None, info, show_progress=False)

    maximum_temperature.change(
        lambda x: params.update({"maximum_temperature": x}), maximum_temperature, None).then(
        generate_info, None, info, show_progress=False)
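
A minimal sketch of what the `state_modifier` hook above amounts to, with the conversion inlined so it runs standalone (the `state` dictionary here is a stand-in for the real UI state):

```python
# Standalone sketch of the extension's state_modifier behaviour.
params = {"activate": True, "minimum_temperature": 0.1, "maximum_temperature": 2}

def state_modifier(state):
    if params["activate"]:
        temperature = 0.5 * (params["minimum_temperature"] + params["maximum_temperature"])
        state["temperature"] = temperature
        state["dynatemp"] = params["maximum_temperature"] - temperature
    return state

print(state_modifier({"temperature": 1.0, "dynatemp": 0.0}))
# {'temperature': 1.05, 'dynatemp': 0.95}
```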
1 change: 1 addition & 0 deletions extensions/openai/typing.py
@@ -8,6 +8,7 @@
class GenerationOptions(BaseModel):
    preset: str | None = Field(default=None, description="The name of a file under text-generation-webui/presets (without the .yaml extension). The sampling parameters that get overwritten by this option are the keys in the default_preset() function in modules/presets.py.")
    min_p: float = 0
    dynatemp: float = 0
    top_k: int = 0
    repetition_penalty: float = 1
    repetition_penalty_range: int = 1024
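
`GenerationOptions` is the request model for the project's OpenAI-compatible API, so the new field can be sent directly in a completion request. A sketch, assuming the API extension is running locally on its default port (URL, port, and prompt are placeholders, not taken from the diff):

```python
import requests

payload = {
    "prompt": "Once upon a time",
    "max_tokens": 64,
    "temperature": 1.55,
    "dynatemp": 1.45,  # the field added by this PR
}
# Endpoint and port are assumptions; adjust to your own setup.
response = requests.post("http://127.0.0.1:5000/v1/completions", json=payload)
print(response.json()["choices"][0]["text"])
```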
8 changes: 6 additions & 2 deletions modules/callbacks.py
@@ -10,6 +10,10 @@
import modules.shared as shared


class StopNowException(Exception):
    pass


class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
    def __init__(self):
        transformers.StoppingCriteria.__init__(self)
@@ -49,13 +53,13 @@ def __init__(self, func, args=None, kwargs=None, callback=None):

        def _callback(val):
            if self.stop_now or shared.stop_everything:
                raise ValueError
                raise StopNowException
            self.q.put(val)

        def gentask():
            try:
                ret = self.mfunc(callback=_callback, *args, **self.kwargs)
            except ValueError:
            except StopNowException:
                pass
            except:
                traceback.print_exc()
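
The point of the dedicated exception: previously a bare `ValueError` doubled as the stop signal, so a genuine `ValueError` raised inside the wrapped model call was silently treated as a user stop. A minimal sketch of the pattern (simplified, not the actual `Iteratorize` class):

```python
class StopNowException(Exception):
    """Raised only to abort generation on purpose."""

def run_with_stop(model_step, should_stop):
    def _callback(val):
        if should_stop():
            raise StopNowException  # deliberate stop signal
        return val

    try:
        return model_step(_callback)
    except StopNowException:
        return None  # clean, intentional stop
    # any real ValueError from model_step now propagates instead of being hidden

print(run_with_stop(lambda cb: cb("token"), should_stop=lambda: False))  # token
print(run_with_stop(lambda cb: cb("token"), should_stop=lambda: True))   # None
```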
3 changes: 3 additions & 0 deletions modules/llamacpp_hf.py
@@ -144,6 +144,9 @@ def __call__(self, *args, **kwargs):
                self.model.n_tokens = longest_prefix
                if len(seq_tensor) - longest_prefix > 0:
                    self.model.eval(seq[longest_prefix:])
                else:
                    self.model.n_tokens -= 1
                    self.model.eval([seq[-1]])

        if reset:
            self.model.reset()
3 changes: 3 additions & 0 deletions modules/loaders.py
@@ -155,6 +155,7 @@ def transformers_samplers():
    return {
        'temperature',
        'temperature_last',
        'dynatemp',
        'top_p',
        'min_p',
        'top_k',
@@ -220,6 +221,7 @@ def transformers_samplers():
    'ExLlamav2_HF': {
        'temperature',
        'temperature_last',
        'dynatemp',
        'top_p',
        'min_p',
        'top_k',
@@ -272,6 +274,7 @@ def transformers_samplers():
    'llamacpp_HF': {
        'temperature',
        'temperature_last',
        'dynatemp',
        'top_p',
        'min_p',
        'top_k',
2 changes: 1 addition & 1 deletion modules/logits.py
@@ -8,7 +8,7 @@
global_scores = None


def get_next_logits(prompt, state, use_samplers, previous, top_logits=50, return_dict=False):
def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False):
    if shared.model is None:
        logger.error("No model is loaded! Select one in the Model tab.")
        return 'Error: No model is loaded! Select one in the Model tab.', previous
1 change: 1 addition & 0 deletions modules/presets.py
@@ -12,6 +12,7 @@ def default_preset():
    return {
        'temperature': 1,
        'temperature_last': False,
        'dynatemp': 0,
        'top_p': 1,
        'min_p': 0,
        'top_k': 0,
98 changes: 94 additions & 4 deletions modules/sampler_hijack.py
@@ -10,9 +10,84 @@
    TemperatureLogitsWarper
)

from modules import shared

global_scores = None


class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
    def __init__(self, temperature: float, dynatemp: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(temperature, float) or not (temperature > 0):
            except_msg = (
                f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
                "scores will be invalid."
            )
            if isinstance(temperature, float) and temperature == 0.0:
                except_msg += " If you're looking for greedy decoding strategies, set `do_sample=False`."

            raise ValueError(except_msg)

        self.temperature = temperature
        self.dynatemp = dynatemp
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:

        # Regular temperature
        if self.dynatemp == 0:
            scores = scores / self.temperature
            return scores

        # Dynamic temperature
        else:
            min_temp = max(0.0, self.temperature - self.dynatemp)
            max_temp = self.temperature + self.dynatemp
            exponent_val = 1.0

            # Convert logits to probabilities
            probs = torch.softmax(scores, dim=-1)

            # Calculate entropy of the softmax probabilities
            entropy = -1.0 * torch.where(probs > 0, probs * torch.log(probs), torch.zeros_like(probs)).sum()

            # Guard against future possible division by zero
            entropy = max(entropy, torch.tensor(1e-10))  # Ensures entropy is slightly greater than 0

            # Any logits which are not -Infinity will be considered for calculating max entropy.
            num_valid_tokens = torch.sum(scores > -float('inf')).item()

            # Now, calculate the max entropy by using only the valid tokens' count
            max_entropy = math.log(num_valid_tokens)

            # Guard against future possible division by zero
            max_entropy = max_entropy if max_entropy > 0.0 else 1e-10

            # Normalize the entropy
            normalized_entropy = entropy / max_entropy

            # Map the normalized entropy to the desired temperature range using the power function
            dyn_temp = min_temp + (max_temp - min_temp) * (normalized_entropy.pow(exponent_val))

            # Apply the dynamically calculated temperature scaling
            scores = scores / dyn_temp

            # print("----------------------\nTemperature from generation_config:", self.temperature)
            # print("min_temp:", min_temp)
            # print("max_temp:", max_temp)
            # print("Entropy:", entropy.item())
            # print("Max Possible Entropy considering valid tokens only:", max_entropy)
            # print("Normalized Entropy:", normalized_entropy.item())
            # print("Dynamic Temperature (dyn_temp):", dyn_temp.item())
            # print("----------------------")

            # max_prob_token_id = torch.argmax(scores, dim=-1)  # Get the token ID with the highest probability
            # max_prob_token = shared.tokenizer.convert_ids_to_tokens(int(max_prob_token_id))  # Convert ID to token
            # print("--- T=", float(dyn_temp), "token=", max_prob_token, "min=", min_temp, "max=", max_temp)

            return scores


class MinPLogitsWarper(LogitsWarper):
    def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if min_p < 0 or min_p > 1.0:
@@ -198,14 +273,28 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
            # presence_penalty and frequency_penalty
            raw_presence_penalty = (counts > 0).to(scores.dtype)
            raw_frequency_penalty = counts.to(scores.dtype)
            additive_penalty = raw_presence_penalty*self.presence_penalty + raw_frequency_penalty*self.frequency_penalty
            additive_penalty = raw_presence_penalty * self.presence_penalty + raw_frequency_penalty * self.frequency_penalty
            scores_row.scatter_add_(0, unique_ids, -additive_penalty)

        return scores


def get_logits_warper_patch(self, generation_config):
    # Make sure that temperature is float and not int
    if isinstance(generation_config.temperature, int):
        generation_config.temperature = float(generation_config.temperature)

    temperature = generation_config.temperature
    if generation_config.dynatemp > 0:
        # Make sure TemperatureLogitsWarper will be created by temporarily
        # setting temperature to a value != 1.
        generation_config.temperature = 1.1

    warpers = self._get_logits_warper_old(generation_config)
    for i in range(len(warpers)):
        if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
            warpers[i] = TemperatureLogitsWarperWithDynatemp(temperature, generation_config.dynatemp)

    warpers_to_add = LogitsProcessorList()
    min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1

@@ -232,18 +321,18 @@ def get_logits_warper_patch(self, generation_config):
    if generation_config.temperature_last:
        temperature_idx = None
        for i in range(len(warpers)):
            if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
            if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'TemperatureLogitsWarperWithDynatemp']:
                temperature_idx = i
                break

        if temperature_idx is not None:
            warpers = warpers[:temperature_idx] + warpers[temperature_idx + 1:] + [warpers[temperature_idx]]
            warpers = LogitsProcessorList(warpers)
            warpers.append(warpers.pop(temperature_idx))

    if normalize is not None:
        warpers.append(normalize)

    warpers.append(SpyLogitsWarper())
    warpers = LogitsProcessorList(warpers)
    # for i in range(len(warpers)):
    #     print(warpers[i].__class__.__name__)
    return warpers
@@ -272,6 +361,7 @@ def get_logits_processor_patch(self, **kwargs):
def generation_config_init_patch(self, **kwargs):
    self.__init___old(**kwargs)
    self.min_p = kwargs.pop("min_p", 0.0)
    self.dynatemp = kwargs.pop("dynatemp", 0.0)
    self.tfs = kwargs.pop("tfs", 1.0)
    self.top_a = kwargs.pop("top_a", 0.0)
    self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
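
To make the behaviour of `TemperatureLogitsWarperWithDynatemp` concrete, here is a standalone sketch of the same entropy-to-temperature mapping (slightly simplified, with `exponent_val = 1`; not taken from the diff) applied to two toy distributions:

```python
import math

import torch

def dynamic_temperature(scores, temperature, dynatemp):
    min_temp = max(0.0, temperature - dynatemp)
    max_temp = temperature + dynatemp
    probs = torch.softmax(scores, dim=-1)
    entropy = -(probs * probs.clamp_min(1e-10).log()).sum().item()
    max_entropy = max(math.log(int((scores > -float("inf")).sum())), 1e-10)
    return min_temp + (max_temp - min_temp) * (entropy / max_entropy)

# A peaked distribution (low entropy) is sampled almost greedily,
# while a flat distribution (high entropy) gets a temperature near max_temp.
print(dynamic_temperature(torch.tensor([10.0, 1.0, 1.0, 1.0]), 1.55, 1.45))  # ~0.11
print(dynamic_temperature(torch.tensor([1.0, 1.0, 1.0, 1.0]), 1.55, 1.45))   # ~3.0
```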
2 changes: 1 addition & 1 deletion modules/text_generation.py
@@ -283,7 +283,7 @@ def get_reply_from_output_ids(output_ids, state, starting_from=0):

def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
    generate_params = {}
    for k in ['max_new_tokens', 'do_sample', 'temperature', 'temperature_last', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
    for k in ['max_new_tokens', 'do_sample', 'temperature', 'temperature_last', 'dynatemp', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
        generate_params[k] = state[k]

    if state['negative_prompt'] != '':
1 change: 1 addition & 0 deletions modules/ui.py
@@ -115,6 +115,7 @@ def list_interface_input_elements():
'seed',
'temperature',
'temperature_last',
'dynatemp',
'top_p',
'min_p',
'top_k',
1 change: 1 addition & 0 deletions modules/ui_parameters.py
@@ -49,6 +49,7 @@ def create_ui(default_preset):
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
shared.gradio['dynatemp'] = gr.Slider(0, 5, value=generate_params['dynatemp'], step=0.01, label='dynatemp')
shared.gradio['temperature_last'] = gr.Checkbox(value=generate_params['temperature_last'], label='temperature_last', info='Makes temperature the last sampler instead of the first.')
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
4 changes: 4 additions & 0 deletions presets/Dynamic Temperature.yaml
@@ -0,0 +1,4 @@
temperature: 1.55
temperature_last: true
dynatemp: 1.45
min_p: 0.05
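
For reference, this preset corresponds to the range used in the extension's README example above: per token, the effective temperature varies between `temperature - dynatemp` and `temperature + dynatemp`, i.e. roughly 0.1 to 3.0 depending on the entropy of the distribution. A quick check (illustrative only):

```python
temperature, dynatemp = 1.55, 1.45
print(temperature - dynatemp, temperature + dynatemp)  # roughly 0.1 and 3.0
```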