daam package

Submodules

daam.evaluate module

class daam.evaluate.MeanEvaluator(name: str = 'MeanEvaluator')

Bases: object

log_intensity(pred: torch.Tensor)
log_iou(preds: Union[torch.Tensor, List[torch.Tensor]], truth: torch.Tensor)
property mean_intensity: float
property mean_iou: float
daam.evaluate.compute_iou(a: torch.Tensor, b: torch.Tensor) → float
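
A minimal usage sketch, assuming binary float masks of equal size; the tensors below are random placeholders rather than real DAAM predictions:

    import torch
    from daam.evaluate import MeanEvaluator, compute_iou

    pred = (torch.rand(512, 512) > 0.5).float()   # placeholder predicted mask
    truth = (torch.rand(512, 512) > 0.5).float()  # placeholder ground-truth mask

    print(compute_iou(pred, truth))  # IoU between the two masks

    evaluator = MeanEvaluator(name='example')
    evaluator.log_iou(pred, truth)   # fold this IoU into the running mean
    evaluator.log_intensity(pred)    # fold the mask's intensity into the running mean
    print(evaluator.mean_iou, evaluator.mean_intensity)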

daam.experiment module

class daam.experiment.GenerationExperiment(id: str, image: PIL.Image.Image, global_heat_map: torch.Tensor, seed: int, prompt: str, path: Optional[pathlib.Path] = None, truth_masks: Optional[Dict[str, torch.Tensor]] = None, prediction_masks: Optional[Dict[str, torch.Tensor]] = None, annotations: Optional[Dict[str, Any]] = None)

Bases: object

Class to hold experiment parameters. Pickleable.

annotate(key: str, value: Any) → daam.experiment.GenerationExperiment
annotations: Optional[Dict[str, Any]] = None
clear_prediction_masks(name: str)
static contains_truth_mask(path: str | pathlib.Path, prompt_id: Optional[str] = None) → bool
global_heat_map: torch.Tensor
static has_annotations(path: str | pathlib.Path) → bool
static has_experiment(path: str | pathlib.Path, prompt_id: str) → bool
id: str
image: PIL.Image.Image
classmethod load(path: str, pred_prefix: str = 'daam', composite: bool = False, simplify80: bool = False, vocab: Optional[List[str]] = None) → daam.experiment.GenerationExperiment
nsfw() → bool
path: Optional[pathlib.Path] = None
prediction_masks: Optional[Dict[str, torch.Tensor]] = None
prompt: str
static read_prompt(path: str | pathlib.Path, prompt_id: Optional[str] = None) → str
static read_seed(path: str | pathlib.Path, prompt_id: Optional[str] = None) → int
save(path: Optional[str] = None)
save_annotations(path: Optional[pathlib.Path] = None)
save_heat_map(tokenizer: transformers.tokenization_utils.PreTrainedTokenizer, word: str)
save_prediction_mask(mask: torch.Tensor, word: str, name: str)
seed: int
truth_masks: Optional[Dict[str, torch.Tensor]] = None
daam.experiment.build_word_list_coco80() → Dict[str, List[str]]
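
A sketch of constructing and persisting an experiment with placeholder data; the heat-map shape, the output directory, and what save() writes are assumptions rather than guarantees of the API:

    import torch
    from PIL import Image
    from daam.experiment import GenerationExperiment

    exp = GenerationExperiment(
        id='0',
        image=Image.new('RGB', (512, 512)),       # placeholder generated image
        global_heat_map=torch.zeros(77, 64, 64),  # assumed (token, height, width) layout
        seed=42,
        prompt='a dog running on the beach',
    )
    exp.annotate('model', 'stable-diffusion-v1-4')  # attach free-form metadata
    exp.save('experiments/run0')                    # persist the experiment under this directory (assumed layout)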

daam.hook module

class daam.hook.AggregateHooker(module: daam.hook.ModuleType)

Bases: daam.hook.ObjectHooker[daam.hook.ModuleListType]

register_hook(hook: daam.hook.ObjectHooker)
class daam.hook.ModuleLocator

Bases: Generic[daam.hook.ModuleType]

locate(model: torch.nn.modules.module.Module) → List[daam.hook.ModuleType]
class daam.hook.ObjectHooker(module: daam.hook.ModuleType)

Bases: Generic[daam.hook.ModuleType]

hook()
monkey_patch(fn_name, fn)
monkey_super(fn_name, *args, **kwargs)
unhook()
class daam.hook.UNetCrossAttentionLocator

Bases: daam.hook.ModuleLocator[diffusers.models.attention.CrossAttention]

locate(model: diffusers.models.unet_2d_condition.UNet2DConditionModel) → List[diffusers.models.attention.CrossAttention]

Locate all cross-attention modules in a UNet2DConditionModel.

Parameters

model (UNet2DConditionModel) – The model to locate the cross-attention modules in.

Returns

The list of cross-attention modules.

Return type

List[CrossAttention]
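
For example, the locator can be used on its own to enumerate a UNet's cross-attention layers; the model id and the diffusers loading call below are assumptions about the local environment:

    from diffusers import UNet2DConditionModel
    from daam.hook import UNetCrossAttentionLocator

    unet = UNet2DConditionModel.from_pretrained(
        'CompVis/stable-diffusion-v1-4', subfolder='unet'
    )
    modules = UNetCrossAttentionLocator().locate(unet)
    print(f'found {len(modules)} cross-attention modules')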

daam.trace module

class daam.trace.DiffusionHeatMapHooker(pipeline: diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline, weighted: bool = False)

Bases: daam.hook.AggregateHooker

property all_heat_maps
compute_global_heat_map(prompt: str, time_weights: Optional[List[float]] = None, time_idx: Optional[int] = None, last_n: Optional[int] = None, factors: Optional[List[float]] = None) → daam.trace.HeatMap

Compute the global heat map for the given prompt, aggregating across time (inference steps) and space (different spatial transformer block heat maps).

Parameters

  • prompt – The prompt to compute the heat map for.

  • time_weights – The weights to apply to each time step. If None, all time steps are weighted equally.

  • time_idx – The time step to compute the heat map for. If None, the heat map is computed for all time steps.

  • last_n – The number of time steps (last n) to use. If None, the heat map is computed for all time steps.

  • factors – Restrict the application to heat maps with spatial factors in this set. If None, use all sizes.
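
Putting this together with the trace alias and the HeatMap class documented below, a typical usage sketch looks like the following; the model id, step count, chosen word, and output file name are illustrative, and the explicit hook()/unhook() calls come from ObjectHooker above:

    from pathlib import Path

    import torch
    from diffusers import StableDiffusionPipeline
    from daam.trace import trace
    from daam.utils import expand_image, plot_overlay_heat_map, set_seed

    pipe = StableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4').to('cuda')
    prompt = 'a dog sitting on a couch'
    gen = set_seed(0)  # seeded torch.Generator for reproducible latents

    tc = trace(pipe)   # trace is an alias of DiffusionHeatMapHooker
    tc.hook()          # patch the pipeline's cross-attention modules
    try:
        with torch.no_grad():
            out = pipe(prompt, num_inference_steps=30, generator=gen)
        global_map = tc.compute_global_heat_map(prompt, last_n=10)  # aggregate the last 10 steps
    finally:
        tc.unhook()    # restore the original modules

    word_map = global_map.compute_word_heat_map('dog')  # low-resolution heat map for 'dog'
    word_map = expand_image(word_map, out=512)           # upsample to the image resolution
    plot_overlay_heat_map(out.images[0], word_map, word='dog',
                          out_file=Path('dog_heat_map.png'))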

class daam.trace.HeatMap(tokenizer: Any, prompt: str, heat_maps: torch.Tensor)

Bases: object

compute_word_heat_map(word: str, word_idx: Optional[int] = None) → torch.Tensor
class daam.trace.MmDetectHeatMap(pred_file: str | pathlib.Path, threshold: float = 0.95)

Bases: object

compute_word_heat_map(word: str) → torch.Tensor
daam.trace.trace

alias of daam.trace.DiffusionHeatMapHooker

daam.utils module

daam.utils.compute_token_merge_indices(tokenizer, prompt: str, word: str, word_idx: Optional[int] = None)
daam.utils.expand_image(im: torch.Tensor, out: int = 512, absolute: bool = False, threshold: Optional[float] = None) → torch.Tensor
daam.utils.plot_mask_heat_map(im: PIL.Image.Image, heat_map: torch.Tensor, threshold: float = 0.4)
daam.utils.plot_overlay_heat_map(im: PIL.Image.Image | numpy.ndarray, heat_map: torch.Tensor, word: Optional[str] = None, out_file: Optional[pathlib.Path] = None)
daam.utils.set_seed(seed: int) → torch._C.Generator
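
A standalone sketch of the expansion and plotting helpers, using a random stand-in for a word-level heat map; the 64×64 input resolution and the thresholding behaviour noted in the comments are assumptions:

    import torch
    from PIL import Image
    from daam.utils import expand_image, plot_mask_heat_map

    heat = torch.rand(64, 64)          # stand-in for a low-resolution word heat map
    im = Image.new('RGB', (512, 512))  # stand-in for the generated image

    soft = expand_image(heat, out=512)                 # upsample to image resolution
    hard = expand_image(heat, out=512, threshold=0.4)  # assumed to binarize at 0.4 after upsampling
    plot_mask_heat_map(im, soft, threshold=0.4)        # assumed to mask the image where heat < 0.4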

Module contents