Stable Diffusion Jax How To

🧨 Stable Diffusion in JAX / Flax!

🤗 Hugging Face Diffusers has supported Flax since version 0.5.1! This allows for super fast inference on Google TPUs, such as those available in Colab, Kaggle or Google Cloud Platform.

This notebook shows how to run inference using JAX / Flax. If you want more details about how Stable Diffusion works, or want to run it on a GPU, please refer to this Colab notebook.

First, make sure you are using a TPU backend. If you are running this notebook in Colab, open the Runtime menu above, select "Change runtime type", and then choose TPU under the Hardware accelerator setting.

Note that JAX is not exclusive to TPUs, but it shines on that hardware because each TPU server has 8 TPU accelerators working in parallel.
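To make the "8 accelerators working in parallel" point concrete, here is a minimal data-parallel sketch with `jax.pmap`: the batch's leading axis is split into one slice per device, and the same compiled function runs on every slice at once. `shard_batch` and `double_and_sum` are hypothetical names used only for illustration (Flax ships a comparable `shard` helper in `flax.training.common_utils`), and the toy computation stands in for real model work.

```python
import jax
import jax.numpy as jnp
import numpy as np


def shard_batch(batch: np.ndarray, n_devices: int) -> np.ndarray:
    """Reshape (batch, ...) into (n_devices, batch // n_devices, ...).

    Hypothetical helper: each device gets an equal slice of the batch.
    """
    if batch.shape[0] % n_devices != 0:
        raise ValueError("batch size must be divisible by the device count")
    per_device = batch.shape[0] // n_devices
    return batch.reshape((n_devices, per_device) + batch.shape[1:])


@jax.pmap
def double_and_sum(x):
    # Compiled once, then executed on every device in parallel;
    # each call sees only its own (per_device, 3) slice.
    return jnp.sum(x * 2.0, axis=-1)


n_devices = jax.device_count()  # 8 on the Colab TPU described above, often 1 on CPU
batch = np.arange(n_devices * 2 * 3, dtype=np.float32).reshape(n_devices * 2, 3)

sharded = shard_batch(batch, n_devices)  # shape (n_devices, 2, 3)
result = double_and_sum(sharded)         # shape (n_devices, 2)
print(sharded.shape, result.shape)
```

The same code runs unchanged on a single CPU device or on all eight TPU cores; only the size of the leading axis changes.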

Setup

The first cell installs the required libraries with pip: flax, transformers and ftfy, followed by diffusers pinned to version 0.6.0. In the captured run all of them were already installed (flax 0.6.1, transformers 4.23.1, ftfy 6.1.1 and diffusers 0.6.0, with jax 0.3.23 pulled in as a dependency of flax).
Found 8 JAX devices of type Cloud TPU.
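A minimal sketch of a device check that prints a line in the format above, using only `jax.devices()` and each device's `device_kind` attribute. The helper name `describe_devices` is hypothetical. On Colab TPU runtimes of this era, JAX also had to be connected to the TPU first (for example with `jax.tools.colab_tpu.setup_tpu()`); that step is omitted here so the sketch runs on any backend.

```python
from typing import Tuple

import jax


def describe_devices() -> Tuple[int, str]:
    """Return how many JAX devices are visible and the kind of the first one."""
    devices = jax.devices()
    return len(devices), devices[0].device_kind


num_devices, device_type = describe_devices()
print(f"Found {num_devices} JAX devices of type {device_type}.")

if "TPU" not in device_type:
    # The rest of the notebook assumes a TPU runtime, so flag anything else.
    print("No TPU detected: switch the runtime type to TPU before continuing.")
```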

Then we import all the dependencies.


Model Loading

Before using the model, you need to accept the model license in order to download and use the weights.

The license is designed to mitigate the potential harmful effects of such a powerful machine learning system. We request users to read the license entirely and carefully. Here we offer a summary:

  1. You can't use the model to deliberately produce or share illegal or harmful outputs or content,

  2. We claim no rights on the outputs you generate; you are free to use them and are accountable for their use, which should not go against the provisions set in the license, and

  3. You may redistribute the weights and use the model commercially and/or as a service. If you do, please be aware that you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M license with all your users.

Flax weights are available on the Hugging Face Hub as part of the Stable Diffusion repo. To use them, you need to be a registered user on the Hugging Face Hub and provide an access token for the code to work. You have two options to provide your access token:

  • Use the huggingface-cli login command-line tool in your terminal and paste your token when prompted. It will be saved to a file on your computer.
  • Or use notebook_login() in a notebook, which does the same thing.

The following cell will present a login interface unless you've already authenticated on this computer. You'll need to paste your access token.

Login successful
Your token has been saved to /root/.huggingface/token

TPU devices support bfloat16, an efficient 16-bit floating-point type. We'll use it in this notebook, but you can use float32 instead if you want full precision.
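As a quick illustration of the trade-off, the sketch below casts a few float32 values to bfloat16 with plain `jax.numpy`. It shows the halved memory per element, the reduced mantissa precision, and why bfloat16 (8 exponent bits, like float32) is safer than IEEE float16 (5 exponent bits) for large values. The variable name `dtype` is only an assumption about how the chosen precision might be passed around.

```python
import jax.numpy as jnp

# Precision used for the rest of the run; use jnp.float32 for full precision.
dtype = jnp.bfloat16

x32 = jnp.array([1.0, 3.14159, 1e30], dtype=jnp.float32)
x_bf16 = x32.astype(dtype)

# bfloat16 halves the memory per element: 2 bytes instead of 4.
print(x32.dtype.itemsize, x_bf16.dtype.itemsize)

# Only 7 explicit mantissa bits survive, so pi-like values are rounded.
print(float(x_bf16[1]))  # 3.140625

# bfloat16 keeps float32's 8-bit exponent, so 1e30 stays finite,
# whereas IEEE float16 (5-bit exponent, max 65504) overflows to inf.
print(bool(jnp.isfinite(x_bf16[2])), bool(jnp.isfinite(x32.astype(jnp.float16)[2])))
```

For image generation, the lost mantissa bits are usually an acceptable price for the memory and speed savings, which is why bfloat16 is a common choice on TPUs.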


Flax is a functional framework, so models are stateless and parameters are stored outside them. Loading the pre-trained Flax pipeline will return both the pipeline itself and the model weights (or parameters). We are using a bf16 version of the weights, which leads to type warnings that you can safely ignore.
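Because Flax models are stateless, their weights live in an ordinary nested dictionary (a pytree) that is passed into every call. The sketch below imitates that pattern in plain JAX with a hypothetical one-layer model (`init_params` and `apply` are illustrative names, not the diffusers API) and shows how an entire parameter tree can be cast to bfloat16 with `jax.tree_util.tree_map`. It illustrates the idea only; it is not how the pipeline actually loads its bf16 weights.

```python
import jax
import jax.numpy as jnp


def init_params(key, in_dim=4, out_dim=2):
    """Build the weights of a tiny linear layer as a nested dict (a pytree)."""
    return {
        "dense": {
            "kernel": jax.random.normal(key, (in_dim, out_dim), dtype=jnp.float32),
            "bias": jnp.zeros((out_dim,), dtype=jnp.float32),
        }
    }


def apply(params, x):
    """Stateless forward pass: all weights come in through `params`."""
    layer = params["dense"]
    return x @ layer["kernel"] + layer["bias"]


params = init_params(jax.random.PRNGKey(0))

# Cast every leaf of the parameter tree to bfloat16 in a single call.
params_bf16 = jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), params)

x = jnp.ones((3, 4), dtype=jnp.bfloat16)
y = apply(params_bf16, x)
print(y.shape, y.dtype)
```

The original float32 `params` are left untouched, so you can keep both copies and choose which one to pass at call time; this separation is why loading the pipeline returns the weights as a separate object.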

Some of the weights of FlaxStableDiffusionSafetyChecker were initialized in bfloat16 precision from the model checkpoint at /root/.cache/huggingface/diffusers/models--CompVis--stable-diffusion-v1-4/snapshots/295cccdedbd5f87458186972858dc85c7e70c10a/safety_checker:
[('concept_embeds',), ('concept_embeds_weights',), ('special_care_embeds',), ('special_care_embeds_weights',), ... one entry for every embedding, layer-norm, MLP and self-attention parameter of each vision_model encoder layer ...
'layers', '4', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', 
'5', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'kernel'), 
('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 
'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'layer_norm1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'layer_norm1', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'layer_norm2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'layer_norm2', 'scale'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', 
'9', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'vision_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'vision_model', 'post_layernorm', 'bias'), ('vision_model', 'vision_model', 'post_layernorm', 'scale'), ('vision_model', 'vision_model', 'pre_layrnorm', 'bias'), ('vision_model', 'vision_model', 'pre_layrnorm', 'scale'), ('visual_projection', 'kernel')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.
Some of the weights of FlaxCLIPTextModel were initialized in bfloat16 precision from the model checkpoint at /root/.cache/huggingface/diffusers/models--CompVis--stable-diffusion-v1-4/snapshots/295cccdedbd5f87458186972858dc85c7e70c10a/text_encoder:
[('text_model', 'embeddings', 'position_embedding', 'embedding'), ('text_model', 'embeddings', 'token_embedding', 'embedding'), ('text_model', 'encoder', 'layers', '0', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '0', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '0', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '0', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '1', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '1', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '1', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), 
('text_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '10', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '10', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '10', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '11', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '11', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '11', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 
'layers', '11', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '2', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '2', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '2', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '3', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '3', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '3', 
'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '4', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '4', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '4', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'kernel'), 
('text_model', 'encoder', 'layers', '5', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '5', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '5', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '5', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '6', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '6', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '6', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 
'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '7', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '7', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '7', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '8', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '8', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '8', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '8', 
'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'layer_norm1', 'bias'), ('text_model', 'encoder', 'layers', '9', 'layer_norm1', 'scale'), ('text_model', 'encoder', 'layers', '9', 'layer_norm2', 'bias'), ('text_model', 'encoder', 'layers', '9', 'layer_norm2', 'scale'), ('text_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'bias'), ('text_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'bias'), ('text_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'bias'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'bias'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'bias'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'kernel'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'bias'), ('text_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'kernel'), ('text_model', 'final_layer_norm', 'bias'), ('text_model', 'final_layer_norm', 'scale')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.
Some of the weights of FlaxAutoencoderKL were initialized in bfloat16 precision from the model checkpoint at /root/.cache/huggingface/diffusers/models--CompVis--stable-diffusion-v1-4/snapshots/295cccdedbd5f87458186972858dc85c7e70c10a/vae:
[('decoder', 'conv_in', 'bias'), ('decoder', 'conv_in', 'kernel'), ('decoder', 'conv_norm_out', 'bias'), ('decoder', 'conv_norm_out', 'scale'), ('decoder', 'conv_out', 'bias'), ('decoder', 'conv_out', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'group_norm', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'group_norm', 'scale'), ('decoder', 'mid_block', 'attentions_0', 'key', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'key', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'proj_attn', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'proj_attn', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'query', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'query', 'kernel'), ('decoder', 'mid_block', 'attentions_0', 'value', 'bias'), ('decoder', 'mid_block', 'attentions_0', 'value', 'kernel'), ('decoder', 'mid_block', 'resnets_0', 'conv1', 'bias'), ('decoder', 'mid_block', 'resnets_0', 'conv1', 'kernel'), ('decoder', 'mid_block', 'resnets_0', 'conv2', 'bias'), ('decoder', 'mid_block', 'resnets_0', 'conv2', 'kernel'), ('decoder', 'mid_block', 'resnets_0', 'norm1', 'bias'), ('decoder', 'mid_block', 'resnets_0', 'norm1', 'scale'), ('decoder', 'mid_block', 'resnets_0', 'norm2', 'bias'), ('decoder', 'mid_block', 'resnets_0', 'norm2', 'scale'), ('decoder', 'mid_block', 'resnets_1', 'conv1', 'bias'), ('decoder', 'mid_block', 'resnets_1', 'conv1', 'kernel'), ('decoder', 'mid_block', 'resnets_1', 'conv2', 'bias'), ('decoder', 'mid_block', 'resnets_1', 'conv2', 'kernel'), ('decoder', 'mid_block', 'resnets_1', 'norm1', 'bias'), ('decoder', 'mid_block', 'resnets_1', 'norm1', 'scale'), ('decoder', 'mid_block', 'resnets_1', 'norm2', 'bias'), ('decoder', 'mid_block', 'resnets_1', 'norm2', 'scale'), ('decoder', 'up_blocks_0', 'resnets_0', 'conv1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_0', 'conv1', 'kernel'), ('decoder', 'up_blocks_0', 'resnets_0', 'conv2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_0', 'conv2', 'kernel'), ('decoder', 'up_blocks_0', 
'resnets_0', 'norm1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_0', 'norm1', 'scale'), ('decoder', 'up_blocks_0', 'resnets_0', 'norm2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_0', 'norm2', 'scale'), ('decoder', 'up_blocks_0', 'resnets_1', 'conv1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_1', 'conv1', 'kernel'), ('decoder', 'up_blocks_0', 'resnets_1', 'conv2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_1', 'conv2', 'kernel'), ('decoder', 'up_blocks_0', 'resnets_1', 'norm1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_1', 'norm1', 'scale'), ('decoder', 'up_blocks_0', 'resnets_1', 'norm2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_1', 'norm2', 'scale'), ('decoder', 'up_blocks_0', 'resnets_2', 'conv1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_2', 'conv1', 'kernel'), ('decoder', 'up_blocks_0', 'resnets_2', 'conv2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_2', 'conv2', 'kernel'), ('decoder', 'up_blocks_0', 'resnets_2', 'norm1', 'bias'), ('decoder', 'up_blocks_0', 'resnets_2', 'norm1', 'scale'), ('decoder', 'up_blocks_0', 'resnets_2', 'norm2', 'bias'), ('decoder', 'up_blocks_0', 'resnets_2', 'norm2', 'scale'), ('decoder', 'up_blocks_0', 'upsamplers_0', 'conv', 'bias'), ('decoder', 'up_blocks_0', 'upsamplers_0', 'conv', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_0', 'conv1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_0', 'conv1', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_0', 'conv2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_0', 'conv2', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_0', 'norm1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_0', 'norm1', 'scale'), ('decoder', 'up_blocks_1', 'resnets_0', 'norm2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_0', 'norm2', 'scale'), ('decoder', 'up_blocks_1', 'resnets_1', 'conv1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_1', 'conv1', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_1', 'conv2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_1', 'conv2', 'kernel'), ('decoder', 
'up_blocks_1', 'resnets_1', 'norm1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_1', 'norm1', 'scale'), ('decoder', 'up_blocks_1', 'resnets_1', 'norm2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_1', 'norm2', 'scale'), ('decoder', 'up_blocks_1', 'resnets_2', 'conv1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_2', 'conv1', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_2', 'conv2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_2', 'conv2', 'kernel'), ('decoder', 'up_blocks_1', 'resnets_2', 'norm1', 'bias'), ('decoder', 'up_blocks_1', 'resnets_2', 'norm1', 'scale'), ('decoder', 'up_blocks_1', 'resnets_2', 'norm2', 'bias'), ('decoder', 'up_blocks_1', 'resnets_2', 'norm2', 'scale'), ('decoder', 'up_blocks_1', 'upsamplers_0', 'conv', 'bias'), ('decoder', 'up_blocks_1', 'upsamplers_0', 'conv', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv1', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv2', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv_shortcut', 'bias'), ('decoder', 'up_blocks_2', 'resnets_0', 'conv_shortcut', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_0', 'norm1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_0', 'norm1', 'scale'), ('decoder', 'up_blocks_2', 'resnets_0', 'norm2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_0', 'norm2', 'scale'), ('decoder', 'up_blocks_2', 'resnets_1', 'conv1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_1', 'conv1', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_1', 'conv2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_1', 'conv2', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_1', 'norm1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_1', 'norm1', 'scale'), ('decoder', 'up_blocks_2', 'resnets_1', 'norm2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_1', 'norm2', 'scale'), ('decoder', 'up_blocks_2', 'resnets_2', 'conv1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_2', 'conv1', 
'kernel'), ('decoder', 'up_blocks_2', 'resnets_2', 'conv2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_2', 'conv2', 'kernel'), ('decoder', 'up_blocks_2', 'resnets_2', 'norm1', 'bias'), ('decoder', 'up_blocks_2', 'resnets_2', 'norm1', 'scale'), ('decoder', 'up_blocks_2', 'resnets_2', 'norm2', 'bias'), ('decoder', 'up_blocks_2', 'resnets_2', 'norm2', 'scale'), ('decoder', 'up_blocks_2', 'upsamplers_0', 'conv', 'bias'), ('decoder', 'up_blocks_2', 'upsamplers_0', 'conv', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv1', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv2', 'bias'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv2', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv_shortcut', 'bias'), ('decoder', 'up_blocks_3', 'resnets_0', 'conv_shortcut', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_0', 'norm1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_0', 'norm1', 'scale'), ('decoder', 'up_blocks_3', 'resnets_0', 'norm2', 'bias'), ('decoder', 'up_blocks_3', 'resnets_0', 'norm2', 'scale'), ('decoder', 'up_blocks_3', 'resnets_1', 'conv1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_1', 'conv1', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_1', 'conv2', 'bias'), ('decoder', 'up_blocks_3', 'resnets_1', 'conv2', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_1', 'norm1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_1', 'norm1', 'scale'), ('decoder', 'up_blocks_3', 'resnets_1', 'norm2', 'bias'), ('decoder', 'up_blocks_3', 'resnets_1', 'norm2', 'scale'), ('decoder', 'up_blocks_3', 'resnets_2', 'conv1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_2', 'conv1', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_2', 'conv2', 'bias'), ('decoder', 'up_blocks_3', 'resnets_2', 'conv2', 'kernel'), ('decoder', 'up_blocks_3', 'resnets_2', 'norm1', 'bias'), ('decoder', 'up_blocks_3', 'resnets_2', 'norm1', 'scale'), ('decoder', 'up_blocks_3', 'resnets_2', 'norm2', 'bias'), ('decoder', 'up_blocks_3', 
'resnets_2', 'norm2', 'scale'), ('encoder', 'conv_in', 'bias'), ('encoder', 'conv_in', 'kernel'), ('encoder', 'conv_norm_out', 'bias'), ('encoder', 'conv_norm_out', 'scale'), ('encoder', 'conv_out', 'bias'), ('encoder', 'conv_out', 'kernel'), ('encoder', 'down_blocks_0', 'downsamplers_0', 'conv', 'bias'), ('encoder', 'down_blocks_0', 'downsamplers_0', 'conv', 'kernel'), ('encoder', 'down_blocks_0', 'resnets_0', 'conv1', 'bias'), ('encoder', 'down_blocks_0', 'resnets_0', 'conv1', 'kernel'), ('encoder', 'down_blocks_0', 'resnets_0', 'conv2', 'bias'), ('encoder', 'down_blocks_0', 'resnets_0', 'conv2', 'kernel'), ('encoder', 'down_blocks_0', 'resnets_0', 'norm1', 'bias'), ('encoder', 'down_blocks_0', 'resnets_0', 'norm1', 'scale'), ('encoder', 'down_blocks_0', 'resnets_0', 'norm2', 'bias'), ('encoder', 'down_blocks_0', 'resnets_0', 'norm2', 'scale'), ('encoder', 'down_blocks_0', 'resnets_1', 'conv1', 'bias'), ('encoder', 'down_blocks_0', 'resnets_1', 'conv1', 'kernel'), ('encoder', 'down_blocks_0', 'resnets_1', 'conv2', 'bias'), ('encoder', 'down_blocks_0', 'resnets_1', 'conv2', 'kernel'), ('encoder', 'down_blocks_0', 'resnets_1', 'norm1', 'bias'), ('encoder', 'down_blocks_0', 'resnets_1', 'norm1', 'scale'), ('encoder', 'down_blocks_0', 'resnets_1', 'norm2', 'bias'), ('encoder', 'down_blocks_0', 'resnets_1', 'norm2', 'scale'), ('encoder', 'down_blocks_1', 'downsamplers_0', 'conv', 'bias'), ('encoder', 'down_blocks_1', 'downsamplers_0', 'conv', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv1', 'bias'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv1', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv2', 'bias'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv2', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv_shortcut', 'bias'), ('encoder', 'down_blocks_1', 'resnets_0', 'conv_shortcut', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_0', 'norm1', 'bias'), ('encoder', 'down_blocks_1', 'resnets_0', 'norm1', 'scale'), ('encoder', 
'down_blocks_1', 'resnets_0', 'norm2', 'bias'), ('encoder', 'down_blocks_1', 'resnets_0', 'norm2', 'scale'), ('encoder', 'down_blocks_1', 'resnets_1', 'conv1', 'bias'), ('encoder', 'down_blocks_1', 'resnets_1', 'conv1', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_1', 'conv2', 'bias'), ('encoder', 'down_blocks_1', 'resnets_1', 'conv2', 'kernel'), ('encoder', 'down_blocks_1', 'resnets_1', 'norm1', 'bias'), ('encoder', 'down_blocks_1', 'resnets_1', 'norm1', 'scale'), ('encoder', 'down_blocks_1', 'resnets_1', 'norm2', 'bias'), ('encoder', 'down_blocks_1', 'resnets_1', 'norm2', 'scale'), ('encoder', 'down_blocks_2', 'downsamplers_0', 'conv', 'bias'), ('encoder', 'down_blocks_2', 'downsamplers_0', 'conv', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv1', 'bias'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv1', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv2', 'bias'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv2', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv_shortcut', 'bias'), ('encoder', 'down_blocks_2', 'resnets_0', 'conv_shortcut', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_0', 'norm1', 'bias'), ('encoder', 'down_blocks_2', 'resnets_0', 'norm1', 'scale'), ('encoder', 'down_blocks_2', 'resnets_0', 'norm2', 'bias'), ('encoder', 'down_blocks_2', 'resnets_0', 'norm2', 'scale'), ('encoder', 'down_blocks_2', 'resnets_1', 'conv1', 'bias'), ('encoder', 'down_blocks_2', 'resnets_1', 'conv1', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_1', 'conv2', 'bias'), ('encoder', 'down_blocks_2', 'resnets_1', 'conv2', 'kernel'), ('encoder', 'down_blocks_2', 'resnets_1', 'norm1', 'bias'), ('encoder', 'down_blocks_2', 'resnets_1', 'norm1', 'scale'), ('encoder', 'down_blocks_2', 'resnets_1', 'norm2', 'bias'), ('encoder', 'down_blocks_2', 'resnets_1', 'norm2', 'scale'), ('encoder', 'down_blocks_3', 'resnets_0', 'conv1', 'bias'), ('encoder', 'down_blocks_3', 'resnets_0', 'conv1', 'kernel'), ('encoder', 'down_blocks_3', 'resnets_0', 
'conv2', 'bias'), ('encoder', 'down_blocks_3', 'resnets_0', 'conv2', 'kernel'), ('encoder', 'down_blocks_3', 'resnets_0', 'norm1', 'bias'), ('encoder', 'down_blocks_3', 'resnets_0', 'norm1', 'scale'), ('encoder', 'down_blocks_3', 'resnets_0', 'norm2', 'bias'), ('encoder', 'down_blocks_3', 'resnets_0', 'norm2', 'scale'), ('encoder', 'down_blocks_3', 'resnets_1', 'conv1', 'bias'), ('encoder', 'down_blocks_3', 'resnets_1', 'conv1', 'kernel'), ('encoder', 'down_blocks_3', 'resnets_1', 'conv2', 'bias'), ('encoder', 'down_blocks_3', 'resnets_1', 'conv2', 'kernel'), ('encoder', 'down_blocks_3', 'resnets_1', 'norm1', 'bias'), ('encoder', 'down_blocks_3', 'resnets_1', 'norm1', 'scale'), ('encoder', 'down_blocks_3', 'resnets_1', 'norm2', 'bias'), ('encoder', 'down_blocks_3', 'resnets_1', 'norm2', 'scale'), ('encoder', 'mid_block', 'attentions_0', 'group_norm', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'group_norm', 'scale'), ('encoder', 'mid_block', 'attentions_0', 'key', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'key', 'kernel'), ('encoder', 'mid_block', 'attentions_0', 'proj_attn', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'proj_attn', 'kernel'), ('encoder', 'mid_block', 'attentions_0', 'query', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'query', 'kernel'), ('encoder', 'mid_block', 'attentions_0', 'value', 'bias'), ('encoder', 'mid_block', 'attentions_0', 'value', 'kernel'), ('encoder', 'mid_block', 'resnets_0', 'conv1', 'bias'), ('encoder', 'mid_block', 'resnets_0', 'conv1', 'kernel'), ('encoder', 'mid_block', 'resnets_0', 'conv2', 'bias'), ('encoder', 'mid_block', 'resnets_0', 'conv2', 'kernel'), ('encoder', 'mid_block', 'resnets_0', 'norm1', 'bias'), ('encoder', 'mid_block', 'resnets_0', 'norm1', 'scale'), ('encoder', 'mid_block', 'resnets_0', 'norm2', 'bias'), ('encoder', 'mid_block', 'resnets_0', 'norm2', 'scale'), ('encoder', 'mid_block', 'resnets_1', 'conv1', 'bias'), ('encoder', 'mid_block', 'resnets_1', 'conv1', 'kernel'), 
('encoder', 'mid_block', 'resnets_1', 'conv2', 'bias'), ('encoder', 'mid_block', 'resnets_1', 'conv2', 'kernel'), ('encoder', 'mid_block', 'resnets_1', 'norm1', 'bias'), ('encoder', 'mid_block', 'resnets_1', 'norm1', 'scale'), ('encoder', 'mid_block', 'resnets_1', 'norm2', 'bias'), ('encoder', 'mid_block', 'resnets_1', 'norm2', 'scale'), ('post_quant_conv', 'bias'), ('post_quant_conv', 'kernel'), ('quant_conv', 'bias'), ('quant_conv', 'kernel')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~ModelMixin.to_fp32`] for further information on how to do this.
Some of the weights of FlaxUNet2DConditionModel were initialized in bfloat16 precision from the model checkpoint at /root/.cache/huggingface/diffusers/models--CompVis--stable-diffusion-v1-4/snapshots/295cccdedbd5f87458186972858dc85c7e70c10a/unet:
[('conv_in', 'bias'), ('conv_in', 'kernel'), ('conv_norm_out', 'bias'), ('conv_norm_out', 'scale'), ('conv_out', 'bias'), ('conv_out', 'kernel'), ('down_blocks_0', 'attentions_0', 'norm', 'bias'), ('down_blocks_0', 'attentions_0', 'norm', 'scale'), ('down_blocks_0', 'attentions_0', 'proj_in', 'bias'), ('down_blocks_0', 'attentions_0', 'proj_in', 'kernel'), ('down_blocks_0', 'attentions_0', 'proj_out', 'bias'), ('down_blocks_0', 'attentions_0', 'proj_out', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), 
('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_0', 'attentions_1', 'norm', 'bias'), ('down_blocks_0', 'attentions_1', 'norm', 'scale'), ('down_blocks_0', 'attentions_1', 'proj_in', 'bias'), ('down_blocks_0', 'attentions_1', 'proj_in', 'kernel'), ('down_blocks_0', 'attentions_1', 'proj_out', 'bias'), ('down_blocks_0', 'attentions_1', 'proj_out', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm2', 
'scale'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_0', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_0', 'downsamplers_0', 'conv', 'bias'), ('down_blocks_0', 'downsamplers_0', 'conv', 'kernel'), ('down_blocks_0', 'resnets_0', 'conv1', 'bias'), ('down_blocks_0', 'resnets_0', 'conv1', 'kernel'), ('down_blocks_0', 'resnets_0', 'conv2', 'bias'), ('down_blocks_0', 'resnets_0', 'conv2', 'kernel'), ('down_blocks_0', 'resnets_0', 'norm1', 'bias'), ('down_blocks_0', 'resnets_0', 'norm1', 'scale'), ('down_blocks_0', 'resnets_0', 'norm2', 'bias'), ('down_blocks_0', 'resnets_0', 'norm2', 'scale'), ('down_blocks_0', 'resnets_0', 'time_emb_proj', 'bias'), ('down_blocks_0', 'resnets_0', 'time_emb_proj', 'kernel'), ('down_blocks_0', 'resnets_1', 'conv1', 'bias'), ('down_blocks_0', 'resnets_1', 'conv1', 'kernel'), ('down_blocks_0', 'resnets_1', 'conv2', 'bias'), ('down_blocks_0', 'resnets_1', 'conv2', 'kernel'), ('down_blocks_0', 'resnets_1', 'norm1', 'bias'), ('down_blocks_0', 'resnets_1', 'norm1', 'scale'), ('down_blocks_0', 'resnets_1', 'norm2', 'bias'), ('down_blocks_0', 'resnets_1', 'norm2', 'scale'), ('down_blocks_0', 'resnets_1', 'time_emb_proj', 'bias'), ('down_blocks_0', 'resnets_1', 'time_emb_proj', 'kernel'), ('down_blocks_1', 'attentions_0', 'norm', 'bias'), ('down_blocks_1', 'attentions_0', 'norm', 'scale'), ('down_blocks_1', 'attentions_0', 'proj_in', 'bias'), ('down_blocks_1', 'attentions_0', 'proj_in', 'kernel'), ('down_blocks_1', 'attentions_0', 'proj_out', 'bias'), ('down_blocks_1', 'attentions_0', 'proj_out', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('down_blocks_1', 
'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_1', 'attentions_1', 'norm', 'bias'), ('down_blocks_1', 'attentions_1', 'norm', 'scale'), ('down_blocks_1', 'attentions_1', 'proj_in', 'bias'), ('down_blocks_1', 'attentions_1', 'proj_in', 'kernel'), ('down_blocks_1', 'attentions_1', 'proj_out', 'bias'), ('down_blocks_1', 'attentions_1', 'proj_out', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), 
('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_1', 'downsamplers_0', 'conv', 'bias'), ('down_blocks_1', 'downsamplers_0', 'conv', 'kernel'), ('down_blocks_1', 'resnets_0', 'conv1', 'bias'), ('down_blocks_1', 'resnets_0', 'conv1', 'kernel'), ('down_blocks_1', 'resnets_0', 'conv2', 'bias'), ('down_blocks_1', 'resnets_0', 'conv2', 'kernel'), ('down_blocks_1', 'resnets_0', 'conv_shortcut', 'bias'), ('down_blocks_1', 'resnets_0', 'conv_shortcut', 'kernel'), ('down_blocks_1', 'resnets_0', 'norm1', 'bias'), ('down_blocks_1', 'resnets_0', 'norm1', 'scale'), ('down_blocks_1', 'resnets_0', 'norm2', 'bias'), ('down_blocks_1', 'resnets_0', 'norm2', 'scale'), ('down_blocks_1', 'resnets_0', 'time_emb_proj', 
'bias'), ('down_blocks_1', 'resnets_0', 'time_emb_proj', 'kernel'), ('down_blocks_1', 'resnets_1', 'conv1', 'bias'), ('down_blocks_1', 'resnets_1', 'conv1', 'kernel'), ('down_blocks_1', 'resnets_1', 'conv2', 'bias'), ('down_blocks_1', 'resnets_1', 'conv2', 'kernel'), ('down_blocks_1', 'resnets_1', 'norm1', 'bias'), ('down_blocks_1', 'resnets_1', 'norm1', 'scale'), ('down_blocks_1', 'resnets_1', 'norm2', 'bias'), ('down_blocks_1', 'resnets_1', 'norm2', 'scale'), ('down_blocks_1', 'resnets_1', 'time_emb_proj', 'bias'), ('down_blocks_1', 'resnets_1', 'time_emb_proj', 'kernel'), ('down_blocks_2', 'attentions_0', 'norm', 'bias'), ('down_blocks_2', 'attentions_0', 'norm', 'scale'), ('down_blocks_2', 'attentions_0', 'proj_in', 'bias'), ('down_blocks_2', 'attentions_0', 'proj_in', 'kernel'), ('down_blocks_2', 'attentions_0', 'proj_out', 'bias'), ('down_blocks_2', 'attentions_0', 'proj_out', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_2', 'attentions_0', 
'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_2', 'attentions_1', 'norm', 'bias'), ('down_blocks_2', 'attentions_1', 'norm', 'scale'), ('down_blocks_2', 'attentions_1', 'proj_in', 'bias'), ('down_blocks_2', 'attentions_1', 'proj_in', 'kernel'), ('down_blocks_2', 'attentions_1', 'proj_out', 'bias'), ('down_blocks_2', 'attentions_1', 'proj_out', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('down_blocks_2', 
'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias'), ('down_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale'), ('down_blocks_2', 'downsamplers_0', 'conv', 'bias'), ('down_blocks_2', 'downsamplers_0', 'conv', 'kernel'), ('down_blocks_2', 'resnets_0', 'conv1', 'bias'), ('down_blocks_2', 'resnets_0', 'conv1', 'kernel'), ('down_blocks_2', 'resnets_0', 'conv2', 'bias'), ('down_blocks_2', 'resnets_0', 'conv2', 'kernel'), ('down_blocks_2', 'resnets_0', 'conv_shortcut', 'bias'), ('down_blocks_2', 'resnets_0', 'conv_shortcut', 'kernel'), ('down_blocks_2', 'resnets_0', 'norm1', 'bias'), ('down_blocks_2', 'resnets_0', 'norm1', 'scale'), ('down_blocks_2', 'resnets_0', 'norm2', 'bias'), ('down_blocks_2', 'resnets_0', 'norm2', 'scale'), ('down_blocks_2', 'resnets_0', 'time_emb_proj', 'bias'), ('down_blocks_2', 'resnets_0', 'time_emb_proj', 'kernel'), ('down_blocks_2', 'resnets_1', 'conv1', 'bias'), ('down_blocks_2', 'resnets_1', 'conv1', 'kernel'), ('down_blocks_2', 'resnets_1', 'conv2', 'bias'), ('down_blocks_2', 'resnets_1', 'conv2', 'kernel'), ('down_blocks_2', 'resnets_1', 'norm1', 'bias'), ('down_blocks_2', 'resnets_1', 'norm1', 'scale'), ('down_blocks_2', 'resnets_1', 'norm2', 'bias'), ('down_blocks_2', 'resnets_1', 'norm2', 'scale'), ('down_blocks_2', 'resnets_1', 'time_emb_proj', 'bias'), ('down_blocks_2', 'resnets_1', 'time_emb_proj', 'kernel'), ('down_blocks_3', 'resnets_0', 'conv1', 'bias'), ('down_blocks_3', 'resnets_0', 'conv1', 'kernel'), ('down_blocks_3', 'resnets_0', 'conv2', 
'bias'), ('down_blocks_3', 'resnets_0', 'conv2', 'kernel'), ('down_blocks_3', 'resnets_0', 'norm1', 'bias'), ('down_blocks_3', 'resnets_0', 'norm1', 'scale'), ('down_blocks_3', 'resnets_0', 'norm2', 'bias'), ('down_blocks_3', 'resnets_0', 'norm2', 'scale'), ('down_blocks_3', 'resnets_0', 'time_emb_proj', 'bias'), ('down_blocks_3', 'resnets_0', 'time_emb_proj', 'kernel'), ('down_blocks_3', 'resnets_1', 'conv1', 'bias'), ('down_blocks_3', 'resnets_1', 'conv1', 'kernel'), ('down_blocks_3', 'resnets_1', 'conv2', 'bias'), ('down_blocks_3', 'resnets_1', 'conv2', 'kernel'), ('down_blocks_3', 'resnets_1', 'norm1', 'bias'), ('down_blocks_3', 'resnets_1', 'norm1', 'scale'), ('down_blocks_3', 'resnets_1', 'norm2', 'bias'), ('down_blocks_3', 'resnets_1', 'norm2', 'scale'), ('down_blocks_3', 'resnets_1', 'time_emb_proj', 'bias'), ('down_blocks_3', 'resnets_1', 'time_emb_proj', 'kernel'), ('mid_block', 'attentions_0', 'norm', 'bias'), ('mid_block', 'attentions_0', 'norm', 'scale'), ('mid_block', 'attentions_0', 'proj_in', 'bias'), ('mid_block', 'attentions_0', 'proj_in', 'kernel'), ('mid_block', 'attentions_0', 'proj_out', 'bias'), ('mid_block', 'attentions_0', 'proj_out', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 
'attn2', 'to_v', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('mid_block', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('mid_block', 'resnets_0', 'conv1', 'bias'), ('mid_block', 'resnets_0', 'conv1', 'kernel'), ('mid_block', 'resnets_0', 'conv2', 'bias'), ('mid_block', 'resnets_0', 'conv2', 'kernel'), ('mid_block', 'resnets_0', 'norm1', 'bias'), ('mid_block', 'resnets_0', 'norm1', 'scale'), ('mid_block', 'resnets_0', 'norm2', 'bias'), ('mid_block', 'resnets_0', 'norm2', 'scale'), ('mid_block', 'resnets_0', 'time_emb_proj', 'bias'), ('mid_block', 'resnets_0', 'time_emb_proj', 'kernel'), ('mid_block', 'resnets_1', 'conv1', 'bias'), ('mid_block', 'resnets_1', 'conv1', 'kernel'), ('mid_block', 'resnets_1', 'conv2', 'bias'), ('mid_block', 'resnets_1', 'conv2', 'kernel'), ('mid_block', 'resnets_1', 'norm1', 'bias'), ('mid_block', 'resnets_1', 'norm1', 'scale'), ('mid_block', 'resnets_1', 'norm2', 'bias'), ('mid_block', 'resnets_1', 'norm2', 'scale'), ('mid_block', 'resnets_1', 'time_emb_proj', 'bias'), ('mid_block', 'resnets_1', 'time_emb_proj', 'kernel'), ('time_embedding', 'linear_1', 'bias'), ('time_embedding', 'linear_1', 'kernel'), ('time_embedding', 'linear_2', 'bias'), ('time_embedding', 'linear_2', 'kernel'), ('up_blocks_0', 'resnets_0', 'conv1', 'bias'), ('up_blocks_0', 'resnets_0', 'conv1', 
'kernel'), ('up_blocks_0', 'resnets_0', 'conv2', 'bias'), ('up_blocks_0', 'resnets_0', 'conv2', 'kernel'), ('up_blocks_0', 'resnets_0', 'conv_shortcut', 'bias'), ('up_blocks_0', 'resnets_0', 'conv_shortcut', 'kernel'), ('up_blocks_0', 'resnets_0', 'norm1', 'bias'), ('up_blocks_0', 'resnets_0', 'norm1', 'scale'), ('up_blocks_0', 'resnets_0', 'norm2', 'bias'), ('up_blocks_0', 'resnets_0', 'norm2', 'scale'), ('up_blocks_0', 'resnets_0', 'time_emb_proj', 'bias'), ('up_blocks_0', 'resnets_0', 'time_emb_proj', 'kernel'), ('up_blocks_0', 'resnets_1', 'conv1', 'bias'), ('up_blocks_0', 'resnets_1', 'conv1', 'kernel'), ('up_blocks_0', 'resnets_1', 'conv2', 'bias'), ('up_blocks_0', 'resnets_1', 'conv2', 'kernel'), ('up_blocks_0', 'resnets_1', 'conv_shortcut', 'bias'), ('up_blocks_0', 'resnets_1', 'conv_shortcut', 'kernel'), ('up_blocks_0', 'resnets_1', 'norm1', 'bias'), ('up_blocks_0', 'resnets_1', 'norm1', 'scale'), ('up_blocks_0', 'resnets_1', 'norm2', 'bias'), ('up_blocks_0', 'resnets_1', 'norm2', 'scale'), ('up_blocks_0', 'resnets_1', 'time_emb_proj', 'bias'), ('up_blocks_0', 'resnets_1', 'time_emb_proj', 'kernel'), ('up_blocks_0', 'resnets_2', 'conv1', 'bias'), ('up_blocks_0', 'resnets_2', 'conv1', 'kernel'), ('up_blocks_0', 'resnets_2', 'conv2', 'bias'), ('up_blocks_0', 'resnets_2', 'conv2', 'kernel'), ('up_blocks_0', 'resnets_2', 'conv_shortcut', 'bias'), ('up_blocks_0', 'resnets_2', 'conv_shortcut', 'kernel'), ('up_blocks_0', 'resnets_2', 'norm1', 'bias'), ('up_blocks_0', 'resnets_2', 'norm1', 'scale'), ('up_blocks_0', 'resnets_2', 'norm2', 'bias'), ('up_blocks_0', 'resnets_2', 'norm2', 'scale'), ('up_blocks_0', 'resnets_2', 'time_emb_proj', 'bias'), ('up_blocks_0', 'resnets_2', 'time_emb_proj', 'kernel'), ('up_blocks_0', 'upsamplers_0', 'conv', 'bias'), ('up_blocks_0', 'upsamplers_0', 'conv', 'kernel'), ('up_blocks_1', 'attentions_0', 'norm', 'bias'), ('up_blocks_1', 'attentions_0', 'norm', 'scale'), ('up_blocks_1', 'attentions_0', 'proj_in', 'bias'), ('up_blocks_1', 
'attentions_0', 'proj_in', 'kernel'), ('up_blocks_1', 'attentions_0', 'proj_out', 'bias'), ('up_blocks_1', 'attentions_0', 'proj_out', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_1', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_1', 'attentions_1', 'norm', 'bias'), ('up_blocks_1', 'attentions_1', 'norm', 'scale'), ('up_blocks_1', 'attentions_1', 'proj_in', 'bias'), ('up_blocks_1', 'attentions_1', 'proj_in', 'kernel'), 
('up_blocks_1', 'attentions_1', 'proj_out', 'bias'), ('up_blocks_1', 'attentions_1', 'proj_out', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_1', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_1', 'attentions_2', 'norm', 'bias'), ('up_blocks_1', 'attentions_2', 'norm', 'scale'), ('up_blocks_1', 'attentions_2', 'proj_in', 'bias'), ('up_blocks_1', 'attentions_2', 'proj_in', 'kernel'), ('up_blocks_1', 'attentions_2', 'proj_out', 
'bias'), ('up_blocks_1', 'attentions_2', 'proj_out', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_1', 'attentions_2', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_1', 'resnets_0', 'conv1', 'bias'), ('up_blocks_1', 'resnets_0', 'conv1', 'kernel'), ('up_blocks_1', 'resnets_0', 'conv2', 'bias'), ('up_blocks_1', 'resnets_0', 'conv2', 'kernel'), ('up_blocks_1', 'resnets_0', 'conv_shortcut', 'bias'), ('up_blocks_1', 'resnets_0', 'conv_shortcut', 
'kernel'), ('up_blocks_1', 'resnets_0', 'norm1', 'bias'), ('up_blocks_1', 'resnets_0', 'norm1', 'scale'), ('up_blocks_1', 'resnets_0', 'norm2', 'bias'), ('up_blocks_1', 'resnets_0', 'norm2', 'scale'), ('up_blocks_1', 'resnets_0', 'time_emb_proj', 'bias'), ('up_blocks_1', 'resnets_0', 'time_emb_proj', 'kernel'), ('up_blocks_1', 'resnets_1', 'conv1', 'bias'), ('up_blocks_1', 'resnets_1', 'conv1', 'kernel'), ('up_blocks_1', 'resnets_1', 'conv2', 'bias'), ('up_blocks_1', 'resnets_1', 'conv2', 'kernel'), ('up_blocks_1', 'resnets_1', 'conv_shortcut', 'bias'), ('up_blocks_1', 'resnets_1', 'conv_shortcut', 'kernel'), ('up_blocks_1', 'resnets_1', 'norm1', 'bias'), ('up_blocks_1', 'resnets_1', 'norm1', 'scale'), ('up_blocks_1', 'resnets_1', 'norm2', 'bias'), ('up_blocks_1', 'resnets_1', 'norm2', 'scale'), ('up_blocks_1', 'resnets_1', 'time_emb_proj', 'bias'), ('up_blocks_1', 'resnets_1', 'time_emb_proj', 'kernel'), ('up_blocks_1', 'resnets_2', 'conv1', 'bias'), ('up_blocks_1', 'resnets_2', 'conv1', 'kernel'), ('up_blocks_1', 'resnets_2', 'conv2', 'bias'), ('up_blocks_1', 'resnets_2', 'conv2', 'kernel'), ('up_blocks_1', 'resnets_2', 'conv_shortcut', 'bias'), ('up_blocks_1', 'resnets_2', 'conv_shortcut', 'kernel'), ('up_blocks_1', 'resnets_2', 'norm1', 'bias'), ('up_blocks_1', 'resnets_2', 'norm1', 'scale'), ('up_blocks_1', 'resnets_2', 'norm2', 'bias'), ('up_blocks_1', 'resnets_2', 'norm2', 'scale'), ('up_blocks_1', 'resnets_2', 'time_emb_proj', 'bias'), ('up_blocks_1', 'resnets_2', 'time_emb_proj', 'kernel'), ('up_blocks_1', 'upsamplers_0', 'conv', 'bias'), ('up_blocks_1', 'upsamplers_0', 'conv', 'kernel'), ('up_blocks_2', 'attentions_0', 'norm', 'bias'), ('up_blocks_2', 'attentions_0', 'norm', 'scale'), ('up_blocks_2', 'attentions_0', 'proj_in', 'bias'), ('up_blocks_2', 'attentions_0', 'proj_in', 'kernel'), ('up_blocks_2', 'attentions_0', 'proj_out', 'bias'), ('up_blocks_2', 'attentions_0', 'proj_out', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 
'attn1', 'to_k', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_2', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_2', 'attentions_1', 'norm', 'bias'), ('up_blocks_2', 'attentions_1', 'norm', 'scale'), ('up_blocks_2', 'attentions_1', 'proj_in', 'bias'), ('up_blocks_2', 'attentions_1', 'proj_in', 'kernel'), ('up_blocks_2', 'attentions_1', 'proj_out', 'bias'), ('up_blocks_2', 'attentions_1', 'proj_out', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_2', 
'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_2', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_2', 'attentions_2', 'norm', 'bias'), ('up_blocks_2', 'attentions_2', 'norm', 'scale'), ('up_blocks_2', 'attentions_2', 'proj_in', 'bias'), ('up_blocks_2', 'attentions_2', 'proj_in', 'kernel'), ('up_blocks_2', 'attentions_2', 'proj_out', 'bias'), ('up_blocks_2', 'attentions_2', 'proj_out', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 
'attn1', 'to_out_0', 'bias'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_2', 'attentions_2', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_2', 'resnets_0', 'conv1', 'bias'), ('up_blocks_2', 'resnets_0', 'conv1', 'kernel'), ('up_blocks_2', 'resnets_0', 'conv2', 'bias'), ('up_blocks_2', 'resnets_0', 'conv2', 'kernel'), ('up_blocks_2', 'resnets_0', 'conv_shortcut', 'bias'), ('up_blocks_2', 'resnets_0', 'conv_shortcut', 'kernel'), ('up_blocks_2', 'resnets_0', 'norm1', 'bias'), ('up_blocks_2', 'resnets_0', 'norm1', 'scale'), ('up_blocks_2', 'resnets_0', 'norm2', 'bias'), ('up_blocks_2', 'resnets_0', 'norm2', 'scale'), 
('up_blocks_2', 'resnets_0', 'time_emb_proj', 'bias'), ('up_blocks_2', 'resnets_0', 'time_emb_proj', 'kernel'), ('up_blocks_2', 'resnets_1', 'conv1', 'bias'), ('up_blocks_2', 'resnets_1', 'conv1', 'kernel'), ('up_blocks_2', 'resnets_1', 'conv2', 'bias'), ('up_blocks_2', 'resnets_1', 'conv2', 'kernel'), ('up_blocks_2', 'resnets_1', 'conv_shortcut', 'bias'), ('up_blocks_2', 'resnets_1', 'conv_shortcut', 'kernel'), ('up_blocks_2', 'resnets_1', 'norm1', 'bias'), ('up_blocks_2', 'resnets_1', 'norm1', 'scale'), ('up_blocks_2', 'resnets_1', 'norm2', 'bias'), ('up_blocks_2', 'resnets_1', 'norm2', 'scale'), ('up_blocks_2', 'resnets_1', 'time_emb_proj', 'bias'), ('up_blocks_2', 'resnets_1', 'time_emb_proj', 'kernel'), ('up_blocks_2', 'resnets_2', 'conv1', 'bias'), ('up_blocks_2', 'resnets_2', 'conv1', 'kernel'), ('up_blocks_2', 'resnets_2', 'conv2', 'bias'), ('up_blocks_2', 'resnets_2', 'conv2', 'kernel'), ('up_blocks_2', 'resnets_2', 'conv_shortcut', 'bias'), ('up_blocks_2', 'resnets_2', 'conv_shortcut', 'kernel'), ('up_blocks_2', 'resnets_2', 'norm1', 'bias'), ('up_blocks_2', 'resnets_2', 'norm1', 'scale'), ('up_blocks_2', 'resnets_2', 'norm2', 'bias'), ('up_blocks_2', 'resnets_2', 'norm2', 'scale'), ('up_blocks_2', 'resnets_2', 'time_emb_proj', 'bias'), ('up_blocks_2', 'resnets_2', 'time_emb_proj', 'kernel'), ('up_blocks_2', 'upsamplers_0', 'conv', 'bias'), ('up_blocks_2', 'upsamplers_0', 'conv', 'kernel'), ('up_blocks_3', 'attentions_0', 'norm', 'bias'), ('up_blocks_3', 'attentions_0', 'norm', 'scale'), ('up_blocks_3', 'attentions_0', 'proj_in', 'bias'), ('up_blocks_3', 'attentions_0', 'proj_in', 'kernel'), ('up_blocks_3', 'attentions_0', 'proj_out', 'bias'), ('up_blocks_3', 'attentions_0', 'proj_out', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), 
('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_3', 'attentions_0', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_3', 'attentions_1', 'norm', 'bias'), ('up_blocks_3', 'attentions_1', 'norm', 'scale'), ('up_blocks_3', 'attentions_1', 'proj_in', 'bias'), ('up_blocks_3', 'attentions_1', 'proj_in', 'kernel'), ('up_blocks_3', 'attentions_1', 'proj_out', 'bias'), ('up_blocks_3', 'attentions_1', 'proj_out', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_3', 'attentions_1', 
'transformer_blocks_0', 'attn1', 'to_q', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_3', 'attentions_1', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_3', 'attentions_2', 'norm', 'bias'), ('up_blocks_3', 'attentions_2', 'norm', 'scale'), ('up_blocks_3', 'attentions_2', 'proj_in', 'bias'), ('up_blocks_3', 'attentions_2', 'proj_in', 'kernel'), ('up_blocks_3', 'attentions_2', 'proj_out', 'bias'), ('up_blocks_3', 'attentions_2', 'proj_out', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_k', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_out_0', 'bias'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_out_0', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_q', 
'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn1', 'to_v', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_k', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'bias'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_out_0', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_q', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'attn2', 'to_v', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'bias'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_0', 'proj', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'bias'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'ff', 'net_2', 'kernel'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm1', 'bias'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm1', 'scale'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm2', 'bias'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm2', 'scale'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm3', 'bias'), ('up_blocks_3', 'attentions_2', 'transformer_blocks_0', 'norm3', 'scale'), ('up_blocks_3', 'resnets_0', 'conv1', 'bias'), ('up_blocks_3', 'resnets_0', 'conv1', 'kernel'), ('up_blocks_3', 'resnets_0', 'conv2', 'bias'), ('up_blocks_3', 'resnets_0', 'conv2', 'kernel'), ('up_blocks_3', 'resnets_0', 'conv_shortcut', 'bias'), ('up_blocks_3', 'resnets_0', 'conv_shortcut', 'kernel'), ('up_blocks_3', 'resnets_0', 'norm1', 'bias'), ('up_blocks_3', 'resnets_0', 'norm1', 'scale'), ('up_blocks_3', 'resnets_0', 'norm2', 'bias'), ('up_blocks_3', 'resnets_0', 'norm2', 'scale'), ('up_blocks_3', 'resnets_0', 'time_emb_proj', 'bias'), ('up_blocks_3', 'resnets_0', 'time_emb_proj', 'kernel'), ('up_blocks_3', 'resnets_1', 'conv1', 'bias'), ('up_blocks_3', 'resnets_1', 'conv1', 
'kernel'), ('up_blocks_3', 'resnets_1', 'conv2', 'bias'), ('up_blocks_3', 'resnets_1', 'conv2', 'kernel'), ('up_blocks_3', 'resnets_1', 'conv_shortcut', 'bias'), ('up_blocks_3', 'resnets_1', 'conv_shortcut', 'kernel'), ('up_blocks_3', 'resnets_1', 'norm1', 'bias'), ('up_blocks_3', 'resnets_1', 'norm1', 'scale'), ('up_blocks_3', 'resnets_1', 'norm2', 'bias'), ('up_blocks_3', 'resnets_1', 'norm2', 'scale'), ('up_blocks_3', 'resnets_1', 'time_emb_proj', 'bias'), ('up_blocks_3', 'resnets_1', 'time_emb_proj', 'kernel'), ('up_blocks_3', 'resnets_2', 'conv1', 'bias'), ('up_blocks_3', 'resnets_2', 'conv1', 'kernel'), ('up_blocks_3', 'resnets_2', 'conv2', 'bias'), ('up_blocks_3', 'resnets_2', 'conv2', 'kernel'), ('up_blocks_3', 'resnets_2', 'conv_shortcut', 'bias'), ('up_blocks_3', 'resnets_2', 'conv_shortcut', 'kernel'), ('up_blocks_3', 'resnets_2', 'norm1', 'bias'), ('up_blocks_3', 'resnets_2', 'norm1', 'scale'), ('up_blocks_3', 'resnets_2', 'norm2', 'bias'), ('up_blocks_3', 'resnets_2', 'norm2', 'scale'), ('up_blocks_3', 'resnets_2', 'time_emb_proj', 'bias'), ('up_blocks_3', 'resnets_2', 'time_emb_proj', 'kernel')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~ModelMixin.to_fp32`] for further information on how to do this.

Inference

Since TPUs usually have 8 devices working in parallel, we'll replicate our prompt once per device. Then we'll perform inference on the 8 devices at once, each responsible for generating one image. Thus, we'll get 8 images in roughly the same amount of time it takes one chip to generate a single one.

After replicating the prompt, we obtain the tokenized text ids by calling the pipeline's prepare_inputs method. The tokenized text is padded (or truncated) to 77 tokens, the fixed length required by the configuration of the underlying CLIP text model.
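The pipeline's `prepare_inputs` handles tokenization for us. As a rough illustration of why the result is a fixed `(8, 77)` array, here is a minimal sketch that replicates one prompt per device and pads or truncates token ids to 77. The fake token ids, the pad id `0`, the example prompt and the helper `pad_to_length` are all hypothetical; the real CLIP tokenizer uses its own vocabulary and special tokens.

```python
import numpy as np

NUM_DEVICES = 8   # assumption: an 8-core TPU host such as a v2-8
MAX_LENGTH = 77   # CLIP text encoder context length

def pad_to_length(ids, max_length=MAX_LENGTH, pad_id=0):
    """Truncate or right-pad a list of token ids to exactly max_length (hypothetical helper)."""
    ids = list(ids)[:max_length]
    return ids + [pad_id] * (max_length - len(ids))

prompt = "a photo of an astronaut riding a horse on mars"  # example prompt
prompts = [prompt] * NUM_DEVICES  # one copy of the prompt per device

# Stand-in tokenization: fake ids derived from word lengths, for shape purposes only.
fake_ids = [[len(word) for word in p.split()] for p in prompts]
prompt_ids = np.array([pad_to_length(ids) for ids in fake_ids], dtype=np.int32)
print(prompt_ids.shape)  # (8, 77)
```

Padding every prompt to the same length is what keeps the input shape constant from call to call.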

[8]
(8, 77)

Replication and parallelization

Model parameters have to be replicated across the 8 parallel devices, while the inputs have to be split among them. The parameters dictionary is replicated using flax.jax_utils.replicate, which traverses the dictionary and adds a leading axis to each weight so that it is repeated 8 times, one copy per device. Input arrays are split using shard, which reshapes the leading batch dimension into (number of devices, batch per device).
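To make the shapes concrete, here is a NumPy-only sketch. It does not place anything on devices; `shard_like` and `replicate_like` are hypothetical helpers that only mimic the shape changes of `flax.training.common_utils.shard` (split the leading axis) and `flax.jax_utils.replicate` (add a leading device axis to every leaf). The parameter names are made up.

```python
import numpy as np

NUM_DEVICES = 8  # assumption: 8 TPU cores

def shard_like(x, num_devices=NUM_DEVICES):
    """Reshape (batch, ...) into (num_devices, batch // num_devices, ...)."""
    return x.reshape((num_devices, -1) + x.shape[1:])

def replicate_like(tree, num_devices=NUM_DEVICES):
    """Give every array in a nested dict a leading axis of size num_devices."""
    if isinstance(tree, dict):
        return {k: replicate_like(v, num_devices) for k, v in tree.items()}
    return np.broadcast_to(tree, (num_devices,) + np.shape(tree))

# Hypothetical parameter tree and tokenized inputs.
params = {"text_encoder": {"kernel": np.zeros((768, 768)), "bias": np.zeros(768)}}
prompt_ids = np.zeros((8, 77), dtype=np.int32)

print(shard_like(prompt_ids).shape)                            # (8, 1, 77)
print(replicate_like(params)["text_encoder"]["kernel"].shape)  # (8, 768, 768)
```

The asymmetry is the point: every device holds a full copy of the weights, but only its own slice of the inputs.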

[9]
[10]
(8, 1, 77)

That shape means that each of the 8 devices will receive as input a jnp array with shape (1, 77), so 1 is the batch size per device. On TPUs with sufficient memory, it could be larger than 1 if we wanted to generate multiple images per chip at once.
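With more prompts than devices, the same split yields a larger per-device batch, provided the total divides evenly by the device count. A small sketch of that arithmetic, where `per_device_batch` is a hypothetical helper:

```python
import numpy as np

def per_device_batch(total_batch, num_devices=8):
    """Return how many examples each device processes; the split must be even."""
    if total_batch % num_devices != 0:
        raise ValueError(
            f"batch of {total_batch} cannot be split evenly over {num_devices} devices"
        )
    return total_batch // num_devices

# 16 prompts on 8 devices: each device would generate 2 images at once.
ids = np.zeros((16, 77), dtype=np.int32)
sharded = ids.reshape((8, per_device_batch(16)) + ids.shape[1:])
print(sharded.shape)  # (8, 2, 77)
```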

We are almost ready to generate images! We just need to create a random number generator to pass to the generation function. This is the standard procedure in Flax, which is very serious and opinionated about random numbers: there is no hidden global random state, so every function that deals with random numbers is expected to receive a generator (in JAX terms, a PRNG key). This ensures reproducibility, even when we are training across multiple distributed devices.

The helper function below uses a seed to initialize a random number generator. As long as we use the same seed, we'll get the exact same results. Feel free to use different seeds when exploring results later in the notebook.
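A minimal version of such a helper, assuming the classic `jax.random.PRNGKey` API (newer JAX releases also offer `jax.random.key`). The name `create_key` is our choice; the demo checks that the same seed reproduces the same sample.

```python
import jax

def create_key(seed=0):
    """Build a JAX PRNG key from an integer seed (hypothetical helper name)."""
    return jax.random.PRNGKey(seed)

a = jax.random.normal(create_key(0), (4,))
b = jax.random.normal(create_key(0), (4,))
c = jax.random.normal(create_key(1), (4,))

print(bool((a == b).all()))  # True: same seed, same numbers
print(bool((a == c).all()))  # False: different seed
```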

[11]

We obtain an rng and then "split" it into 8 keys so each device receives a different generator. Therefore, each device will create a different image, and the full process is reproducible.
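Here is a sketch of the split with `jax.random.split`, assuming 8 devices. Each resulting key produces different noise, yet re-running with the same seed yields the same 8 keys.

```python
import jax

NUM_DEVICES = 8  # assumption: one key per TPU core

rng = jax.random.PRNGKey(0)
rngs = jax.random.split(rng, NUM_DEVICES)  # 8 independent keys, stacked along axis 0

# Each key drives its own noise, so each device starts from a different latent.
noise = [jax.random.normal(k, (3,)) for k in rngs]
print(rngs.shape)  # (8, 2) for the default raw uint32 keys
print(bool((noise[0] == noise[1]).all()))  # False
```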

[12]

JAX code can be compiled to an efficient representation that runs very fast. However, we need to ensure that all inputs have the same shape in subsequent calls; otherwise, JAX will have to recompile the code, and we wouldn't be able to take advantage of the optimized speed.
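One way to observe this is with a Python side effect inside a jitted function. Such code runs only while JAX traces the function, so counting it shows exactly when recompilation happens. The function `double` and the counter are illustrative only.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python code runs only during tracing, not on cached calls
    return x * 2.0

double(jnp.ones((8, 77)))   # first call: trace and compile
double(jnp.zeros((8, 77)))  # same shape and dtype: cached executable reused
double(jnp.ones((8, 76)))   # new shape: traced and compiled again
print(trace_count)  # 2
```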

The Flax pipeline can compile the code for us if we pass jit=True as an argument. It will also ensure that the model runs in parallel on the 8 available devices.
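The JAX primitive for running one compiled function on every device is `jax.pmap`: the leading axis of the input must equal the device count, and each device receives one slice. We assume the pipeline's `jit=True` path relies on something along these lines internally; the snippet only shows the primitive itself and runs on however many local devices are present (1 on a plain CPU runtime).

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # 8 on a TPU v2-8, 1 on a plain CPU runtime

# Leading axis must equal the number of devices: one row per device.
x = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)

# Each device squares its own row in parallel; outputs keep the device axis.
y = jax.pmap(lambda row: row ** 2)(x)
print(y.shape)  # (n, 3)
```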

The first time we run the following cell it will take a long time to compile, but subsequent calls (even with different inputs, as long as they have the same shapes) will be much faster. For example, compiling took more than a minute on a TPU v2-8 when I tested, but subsequent inference runs took about 7s.
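To take this kind of measurement yourself, time a first call and a second call with identical shapes. Because JAX dispatches work asynchronously, call `block_until_ready()` before stopping the clock. The function `step` is a stand-in workload, not the pipeline, and absolute numbers depend on the hardware.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Stand-in workload: a few matrix multiplications (loop is unrolled at trace time).
    for _ in range(4):
        x = jnp.tanh(x @ x.T)
    return x

x = jnp.ones((256, 256))

def timed(fn, arg):
    start = time.perf_counter()
    out = fn(arg).block_until_ready()  # wait for the async computation to finish
    return out, time.perf_counter() - start

out, first = timed(step, x)  # includes tracing + XLA compilation
_, second = timed(step, x)   # reuses the compiled executable
print(f"first call: {first:.3f}s, second call: {second:.3f}s")
```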

[13]
CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
Wall time: 1min 29s

The returned array has shape (8, 1, 512, 512, 3). We reshape it to get rid of the second dimension, obtaining 8 images of 512 × 512 × 3, and then convert them to PIL images.
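A sketch of the reshape, assuming the pipeline returns floats in [0, 1] (check this against your diffusers version). Each resulting `uint8` image can then be wrapped with `PIL.Image.fromarray`; diffusers pipelines also provide a `numpy_to_pil` helper for this step. The input below is random dummy data with the same layout.

```python
import numpy as np

def to_uint8_images(images):
    """Flatten (devices, per_device, H, W, C) floats in [0, 1] into (N, H, W, C) uint8."""
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return (np.clip(images, 0.0, 1.0) * 255).round().astype(np.uint8)

# Dummy output with the pipeline's layout: 8 devices x 1 image each.
raw = np.random.default_rng(0).random((8, 1, 512, 512, 3), dtype=np.float32)
batch = to_uint8_images(raw)
print(batch.shape, batch.dtype)  # (8, 512, 512, 3) uint8
```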

[14]

Visualization

Let's create a helper function to display images in a grid.
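A minimal sketch of such a helper, written with NumPy rather than PIL so it only rearranges pixels: it tiles an `(N, H, W, C)` batch into one array, filling the grid row by row. The name `image_grid` and its `rows`/`cols` parameters are our choice; the result can be shown directly or wrapped with `PIL.Image.fromarray`.

```python
import numpy as np

def image_grid(images, rows, cols):
    """Tile an (N, H, W, C) batch into one (rows*H, cols*W, C) array, row by row."""
    n, h, w, c = images.shape
    assert n == rows * cols, "need exactly rows * cols images"
    grid = images.reshape(rows, cols, h, w, c)  # group images by grid row
    grid = grid.transpose(0, 2, 1, 3, 4)        # (rows, H, cols, W, C)
    return grid.reshape(rows * h, cols * w, c)

batch = np.zeros((8, 64, 64, 3), dtype=np.uint8)
batch[1] = 255  # make the second image white so its position is easy to check
grid = image_grid(batch, rows=2, cols=4)
print(grid.shape)  # (128, 256, 3)
```

With 8 images, a 2 × 4 layout keeps the grid wide and short, which suits a notebook cell.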

[15]
[ ]