Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion comfy/latent_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ class HunyuanVideo(LatentFormat):
]

latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
taesd_decoder_name = "taehv"

class Cosmos1CV8x8x8(LatentFormat):
latent_channels = 16
Expand Down Expand Up @@ -445,7 +446,7 @@ def __init__(self):
]).view(1, self.latent_channels, 1, 1, 1)


self.taesd_decoder_name = None #TODO
self.taesd_decoder_name = "lighttaew2_1"

def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
Expand Down Expand Up @@ -516,6 +517,7 @@ class Wan22(Wan21):

def __init__(self):
self.scale_factor = 1.0
self.taesd_decoder_name = "lighttaew2_2"
self.latents_mean = torch.tensor([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
Expand Down Expand Up @@ -670,6 +672,7 @@ class HunyuanVideo15(LatentFormat):
latent_channels = 32
latent_dimensions = 3
scale_factor = 1.03682
taesd_decoder_name = "lighttaehy1_5"

class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
Expand Down
34 changes: 33 additions & 1 deletion comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
import comfy.hooks
import comfy.t2i_adapter.adapter
import comfy.taesd.taesd
import comfy.taesd.taehv
import comfy.latent_formats

import comfy.ldm.flux.redux

Expand Down Expand Up @@ -496,13 +498,14 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
else: # Wan 2.1 VAE
dim = sd["decoder.head.0.gamma"].shape[0]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
Expand Down Expand Up @@ -572,6 +575,35 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float32]
self.crop_input = False
elif "decoder.22.bias" in sd: # taehv, taew and lighttae
self.latent_channels = sd["decoder.1.weight"].shape[1]
self.latent_dim = 3
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
if self.latent_channels == 48: # Wan 2.2
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.process_output = lambda image: image
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else:
if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
latent_format=comfy.latent_formats.HunyuanVideo
else:
latent_format=None # lighttaew2_1 doesn't need scaling
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=latent_format)
self.process_input = self.process_output = lambda image: image
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
Expand Down
171 changes: 171 additions & 0 deletions comfy/taesd/taehv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Tiny AutoEncoder for HunyuanVideo and WanVideo https://github.com/madebyollin/taehv

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple, deque

import comfy.ops
operations=comfy.ops.disable_weight_init

DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))

def conv(n_in, n_out, **kwargs):
return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs)

class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3

class MemBlock(nn.Module):
def __init__(self, n_in, n_out, act_func):
super().__init__()
self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out))
self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = act_func
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))

class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = operations.Conv2d(n_f*stride,n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
return self.conv(x.reshape(-1, self.stride * C, H, W))

class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
x = self.conv(x)
return x.reshape(-1, C, H, W)

def apply_model_with_memblocks(model, x, parallel, show_progress_bar):

B, T, C, H, W = x.shape
if parallel:
x = x.reshape(B*T, C, H, W)
# parallel over input timesteps, iterate over blocks
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
BT, C, H, W = x.shape
T = BT // B
_x = x.reshape(B, T, C, H, W)
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
x = b(x, mem)
else:
x = b(x)
BT, C, H, W = x.shape
T = BT // B
x = x.view(B, T, C, H, W)
else:
out = []
work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))])
progress_bar = tqdm(range(T), disable=not show_progress_bar)
mem = [None] * len(model)
while work_queue:
xt, i = work_queue.popleft()
if i == 0:
progress_bar.update(1)
if i == len(model):
out.append(xt)
del xt
else:
b = model[i]
if isinstance(b, MemBlock):
if mem[i] is None:
xt_new = b(xt, xt * 0)
mem[i] = xt.detach().clone()
else:
xt_new = b(xt, mem[i])
mem[i] = xt.detach().clone()
del xt
work_queue.appendleft(TWorkItem(xt_new, i+1))
elif isinstance(b, TPool):
if mem[i] is None:
mem[i] = []
mem[i].append(xt.detach().clone())
if len(mem[i]) == b.stride:
B, C, H, W = xt.shape
xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W))
mem[i] = []
work_queue.appendleft(TWorkItem(xt, i+1))
elif isinstance(b, TGrow):
xt = b(xt)
NT, C, H, W = xt.shape
for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)):
work_queue.appendleft(TWorkItem(xt_next, i+1))
del xt
else:
xt = b(xt)
work_queue.appendleft(TWorkItem(xt, i+1))
progress_bar.close()
x = torch.stack(out, 1)
return x


class TAEHV(nn.Module):
def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
super().__init__()
self.image_channels = 3
self.patch_size = 1
self.latent_channels = latent_channels
self.parallel = parallel
self.latent_format = latent_format
self.show_progress_bar = show_progress_bar
self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x)
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
self.patch_size = 2
if self.latent_channels == 32: # HunyuanVideo1.5
act_func = nn.LeakyReLU(0.2, inplace=True)
else: # HunyuanVideo, Wan 2.1
act_func = nn.ReLU(inplace=True)

self.encoder = nn.Sequential(
conv(self.image_channels*self.patch_size**2, 64), act_func,
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
conv(64, self.latent_channels),
)
n_f = [256, 128, 64, 64]
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
self.decoder = nn.Sequential(
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
)
@property
def show_progress_bar(self):
return self._show_progress_bar

@show_progress_bar.setter
def show_progress_bar(self, value):
self._show_progress_bar = value

def encode(self, x, **kwargs):
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
if x.shape[1] % 4 != 0:
# pad at end to multiple of 4
n_pad = 4 - x.shape[1] % 4
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
x = torch.cat([x, padding], 1)
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
return self.process_out(x)

def decode(self, x, **kwargs):
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
return x[:, self.frames_to_trim:].movedim(2, 1)
28 changes: 22 additions & 6 deletions latent_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,24 @@
from PIL import Image
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
from comfy.sd import VAE
import comfy.model_management
import folder_paths
import comfy.utils
import logging

MAX_PREVIEW_RESOLUTION = args.preview_size
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]

def preview_to_image(latent_image):
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
)
def preview_to_image(latent_image, do_scale=True):
if do_scale:
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
)
else:
latents_ubyte = (latent_image.clamp(0, 1)
.mul(0xFF) # to 0..255
)
if comfy.model_management.directml_enabled:
latents_ubyte = latents_ubyte.to(dtype=torch.uint8)
latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
Expand All @@ -35,6 +42,10 @@ def decode_latent_to_preview(self, x0):
x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2)
return preview_to_image(x_sample)

class TAEHVPreviewerImpl(TAESDPreviewerImpl):
def decode_latent_to_preview(self, x0):
x_sample = self.taesd.decode(x0[:1, :, :1])[0][0]
return preview_to_image(x_sample, do_scale=False)

class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
Expand Down Expand Up @@ -78,8 +89,13 @@ def get_previewer(device, latent_format):

if method == LatentPreviewMethod.TAESD:
if taesd_decoder_path:
taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
previewer = TAESDPreviewerImpl(taesd)
if latent_format.taesd_decoder_name in VIDEO_TAES:
taesd = VAE(comfy.utils.load_torch_file(taesd_decoder_path))
taesd.first_stage_model.show_progress_bar = False
previewer = TAEHVPreviewerImpl(taesd)
else:
taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
previewer = TAESDPreviewerImpl(taesd)
else:
logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))

Expand Down
18 changes: 14 additions & 4 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,10 @@ def load_lora_model_only(self, model, lora_name, strength_model):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)

class VAELoader:
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
@staticmethod
def vae_list():
def vae_list(s):
vaes = folder_paths.get_filename_list("vae")
approx_vaes = folder_paths.get_filename_list("vae_approx")
sdxl_taesd_enc = False
Expand Down Expand Up @@ -722,6 +724,11 @@ def vae_list():
f1_taesd_dec = True
elif v.startswith("taef1_decoder."):
f1_taesd_enc = True
else:
for tae in s.video_taes:
if v.startswith(tae):
vaes.append(v)

if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc:
Expand Down Expand Up @@ -765,7 +772,7 @@ def load_taesd(name):

@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (s.vae_list(), )}}
return {"required": { "vae_name": (s.vae_list(s), )}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"

Expand All @@ -776,10 +783,13 @@ def load_vae(self, vae_name):
if vae_name == "pixel_space":
sd = {}
sd["pixel_space_vae"] = torch.tensor(1.0)
elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
elif vae_name in self.image_taes:
sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
if os.path.splitext(vae_name)[0] in self.video_taes:
vae_path = folder_paths.get_full_path_or_raise("vae_approx", vae_name)
else:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
vae.throw_exception_if_invalid()
Expand Down