o
    9j                      @  sx  d Z ddlmZ ddlZddlZddlZddlmZ ddlm	Z	 ddl
ZddlZddlm  mZ ddlmZ ddlmZ ddlmZmZ dd	lmZmZmZ d
dlmZ d
dlmZ d
dl m!Z! zddl"Z"dZ#W n e$yt   dZ#Y nw ee%& j'd Z(e(d d Z)e(d d Z*e(d d d Z+d3ddZ,d4d!d"Z-d3d#d$Z.d5d&d'Z/	(d6d7d.d/Z0d0d1 Z1e2d2kre1  dS dS )8uB  V3 face training — distill teacher's .npz targets into a causal TCN.

Usage:
    PYTHONPATH=. python3 -m models.v3_face.train               # full train
    PYTHONPATH=. python3 -m models.v3_face.train --smoke       # 10 scenarios × 5 epochs
    PYTHONPATH=. python3 -m models.v3_face.train --device cuda:1 --epochs 80
    )annotationsN)Path)Dict)AdamW)LambdaLR)
DataLoaderWeightedRandomSampler)LIPSYNC_ONLYEXPRESSION_ONLYSHARED_CHANNELS   )V3FaceConfig)BlendshapeDataset)V3FaceModelTF   datav3_trainingemotionmodelsv3_facecheckpointscfgr   returntorch.Tensorc                 C  sb   t j| jf| jt jd}tD ]}| j||< qtD ]}d| j| j  ||< q| j|d< | j|d< |S )zPer-channel L1 weight vector, shape (52,).

    Lipsync channels get cfg.lipsync_weight (audio-sync matters most).
    Eye-blink (ch 8, 9) get cfg.eye_blink_weight (sparse, hard to learn).
    Everything else gets cfg.expression_weight.
    dtype      ?   	   )	torchfull
output_dimexpression_weightfloat32r	   lipsync_weightr   eye_blink_weightr   wch r)   L/dataset/kemix-engine/package/face/animasync-face-v3/models/v3_face/train.pymake_channel_weights+   s   

r+   predtargetvalid_length
ch_weightsc                 C  s   | j \}}}tj|| jdd}||dk  }| |  |ddddf  }	|	jdd}	|	| }	| j	dd}
|	 |
 S )	u   L1 loss masked to valid frames, weighted per channel.

    pred, target:    (B, T, C)
    valid_length:    (B,) int  — number of valid frames per sample
    ch_weights:      (C,)
    devicer   r   Ndim      ?min)
shaper   aranger1   	unsqueezefloatabsmeansumclamp)r,   r-   r.   r/   BTC	frame_idxmaskdiffdenomr)   r)   r*   	masked_l1?   s   rG   c                 C  s   t j| jf| jt jd}tD ]}| j||< qtD ]}| j||< q| j	|d< | j	|d< | j
dur9| j
|d< | j
|d< | jdurHdD ]}| j||< q@|S )a   Per-channel velocity penalty weights, shape (52,).

    Lipsync channels get heavy smoothing; brows + cheek get light smoothing
    (preserve V2 prosody motion); eye-blink channels get near-zero so the
    sharp 5-frame blink kernel survives training.
    r   r   r   N      )r   r   r         )r   r    r!   velocity_expression_weightr#   r	   velocity_lipsync_weightr   velocity_shared_weightvelocity_eye_blink_weightvelocity_eye_squint_weightvelocity_brow_weightr&   r)   r)   r*   make_velocity_weightsR   s    





rR   vel_weightsc                 C  s   | j \}}}| ddddf | ddddf  }|ddddf |ddddf  }||  |ddddf  }	|	jdd}	tj|d | jdd}
|
|d dk  }|	| }	| j	dd}|	 | S )	zEPer-channel-weighted L1 on per-frame difference (smoothness penalty).Nr   r2   r3   r0   r   r5   r6   )
r8   r<   r=   r   r9   r1   r:   r;   r>   r?   )r,   r-   r.   rS   r@   rA   _pred_vtarget_vrE   rC   rD   rF   r)   r)   r*   masked_velocityr   s   ((rW   r5   trainboolvelocity_scaler;   Dict[str, float]c
                 C  s  |  | ddddd}
|rt nt }| |D ]}|d j|dd}|d j|dd}|d j|dd}|d	 j|dd}| ||}t||||}t|||||	 }|| }|ry|jdd
 |  tj	j
|  |j |  |  |
d  t| 7  < |
d  t| 7  < |
d  |jdkrt| nd7  < |
d  d7  < qW d    n1 sw   Y  td|
d }|
d | |
d | |
d | dS )N        r   )lossl1velnaudioT)non_blockingcondr-   r.   )set_to_noner]   r^   r_   r`   r   )r]   r^   r_   )rX   r   enable_gradno_gradtorG   rW   	zero_gradbackwardnnutilsclip_grad_norm_
parameters	grad_clipstepr;   itemvelocity_weightmax)modeldl	optimizer	schedulerr/   rS   r   r1   rX   rZ   sumsgrad_ctxbatchra   rc   r-   r.   r,   loss_l1loss_velr]   r`   r)   r)   r*   	run_epoch   s:   

&$r|   c            $        s6  t  } | jdttd | jdttd | jdttd | jdtd d | jdtd d | jdtd d | jdd	d
 | jdtdd | jdddd | jdg dddd | jdtd dd | jdddd | jdddd | jdddd  | jd!d d"d  | jd#d d$d  | jd%td&d'd | jd(td&d)d | jd*td d+d | jd,td d-d | jd.td d/d | jd0td1d2d | jd3td1d4d | jd5td1d6d | jd7td1d8d | jd9td d:d | jd;td d<d | jd=td d>d | jd?dd@d | jdAtd1dBd | jdCtdDdEd | jdFtd dGd | jdHtd dId | jdJt	d dKd | 
 }t  |jr7|j _|jr?|j _|jrG|j _|j _|j _|j _|j _|j _|j _|j _|j _|j _|j _|j _|j _|j _|j _|j _|j  _ |j! _!|j"j#dLdLdM t$%|j%} jd&ks jd&ks jd urʈ jd ur jn j}t&dN jdOdP jdOdQ|dOdR  jd urt&dS jdOdT j'dOdU  jd urt&dV jdOdT j'dOdU  jd1krt&dW jdOdX  jd1kr"t&dY jdOdZ jd[  j( d\d] j( d^  jd1kr@t&d_ jdOdZ jd[  j( d\d] j( d^  jd1kr^t&d` jdOdZ jd[  j( d\d] j( d^  jd urst&da jdOdb jdOdc  jd urt&dd jdOde jdOdf  jd urt&dg jdOdb jd ur jn jdOdc  jd urt&dh jdOdi  jd urt&dj jdOdk  jrt&dl j dOdm  j!d1krt&dn j!dodZ j!d[  j( d\d] j( d^ t)|j*|j+dp fi dq j,dr jds jdt jdu jdv jdw jdx jdy jdz jd{ jd| jd} jd~ jd j d j!}t)|j*|j+d f j, j j j j j j j j j j j j j  j!d}|j-rg }dD ]\}fdd|j.D d | }|/| qo|s|j.d d }||_.|j.d d |_.d _d _0d _t&dt1| dt1| d j d j  t&dt1| d|2   t&dt1| d|2   |3 j4}	t5|	t1|dLd}
t6| j|
|j7dLdd}t6| jd|j7dLd}t8 9|}t&d|j:d dOd|j;dod|j;d dod |j<d ur[|j<= s1t>d|j< t$j?|j<|dd}|@|d  t&d|j< d|Add d|AdtdddU |jBr|B }t&d|d dOd|jCd dOd ddlDmE} |D ]}q}|jFotG}|jFrtGst&d |r|jHpd jI dt1 jJt1 jK  d jd|j-rdnd }tFjL|jM|jN|i  jO|j:|j;|j;d t1|t1||2 |2 t	||j-d	d tFjP|ddd t&dtFjQR   dd |S D }tT| jd jUd}tVdt1| j d fdd}tW||}tX 9|}tY 9|}ddlDmE}mZ} |j[dkrR|D ]}d1||< d1||< q;t&dt1| d n+|j[dks\|jBr}|D ]}d1||< d1||< q^t&d|jBrrdnd dt1| dƝ tdǃ}t\ jD ]}t]] } j^dkrt_d&| j^ nd&}t`|||||| |dL|dɍ
}t`|||||| |d|dɍ
} t]] | }!t&d|d˛d|d dd|d dd| d dd| d dd|a d dқd|dOd|!dod՝ | d |k }"|"r| d }|b  jO|| d | d |d |d d֜}|j[dkrdnd|j[ }#|jcr+|# d|jc }#|"rGt$d||j"d|# dٝ  t&d|# d|ddU t$d||j"d|# dٝ  |rytFje||d |d |d | d | d | d ||a d |!dޜ
|dߍ qt&d|d t&d|j"  |r|tFjfd< tFg  d S d S )Nz	--npz_dir)typedefaultz--emotion_dirz	--out_dirz--epochsz--batch_sizez--lrz--devicezcuda:0)r~   z--num_workersrK   z--smoke
store_truez)Smoke test: 10 train scenarios, 5 epochs.)actionhelpz--focus)alllipsync
expressionr   zChannel focus. 'all' = train all 52 channels (default). 'lipsync' = only LIPSYNC + SHARED channels have loss (expression branch sees zero gradient). 'expression' = only EXPRESSION_ONLY channels have loss (lipsync branch sees zero gradient).)choicesr~   r   z--resumezvLoad model weights from a checkpoint .pt before training. Use with --freeze_lipsync for phase-2 expression retraining.)r}   r~   r   z--freeze_lipsynczFreeze shared backbone + lipsync branch + lipsync head. Only the expression branch + head will train. Lipsync output stays bit-for-bit identical to what was loaded via --resume.z--wandbz5Log to Weights & Biases. Requires `wandb login` once.z--wandb_projectzanimasync-v3-facezW&B project name.)r~   r   z--wandb_run_namez)W&B run name. Defaults to auto-generated.z--wandb_entityz4W&B entity (team or user). Defaults to your default.z--lipsync-target-gainr5   zMultiply PURE_LIPSYNC target channels (jaw, mouth mechanics, tongue, cheekPuff) by this factor at load time, then clamp [0,1]. Default 1.0 = no change.z--expression-target-gainaR  Multiply EMOTIONAL target channels (brows incl. innerUp, cheekSquint, eyeSquint, eyeWide, mouth Dimple/Frown/Smile, noseSneer) by this factor at load time, then clamp [0,1]. Default 1.0 = no change. If --emotional-mouth-target-gain is also set, THIS knob covers only the pure-expression subset (brows + eyeSquint + eyeWide + cheekSquint).z--emotional-mouth-target-gainu  Optional separate gain for emotional-mouth target channels (mouthDimple, mouthFrown, mouthSmile, noseSneer). When set, decouples from --expression-target-gain so brows/eyes can go higher than mouth. Used by v18b to avoid pushing shared mouth channels past the point where crisp_mouth normalization destabilizes lipsync. None → same as --expression-target-gain (backward-compat with v14/v18).z--velocity-eye-squint-weightu   Per-channel velocity penalty for eyeSquint L/R (ch 18, 19). Higher than the default expression velocity weight suppresses jitter at high gain. None → use velocity_expression_weight (backward-compat). Suggested 0.8 for v18b.z--velocity-brow-weightu   Per-channel velocity penalty for the 5 brow channels (ch 0-4). Same mechanism as the eyeSquint weight — suppresses brow jitter at high expression gain. None → use velocity_expression_weight (backward-compat). Suggested 0.7 for v18c.z--plosive-damp-targetr\   a-  Bake the runtime plosive damper into training targets. When mouthClose > 0.4 on a frame, mouthPress/Roll/Shrug (ch 35,36,39,40,41,42) get multiplied by (1 - this_value * smoothstep). 0 = off (default). 0.30 = matches the production main-viewer setting that prevents 'lips swallowed' on m/b/p plosives.z--smooth-target-sigma-browu  Gaussian σ (frames @ 30 fps) for pre-smoothing the 5 brow target channels BEFORE the gain. The proper fix for brow flicker at high gain — smooths the input the model is asked to fit so jitter never enters the training signal. 0 = off (default). Suggested 2.0 (~67ms) for v18e.z--smooth-target-sigma-eye-wideu   Gaussian σ (frames @ 30 fps) on the eyeWide target channels (20, 21) before gain. Pair with eyeSquint smoothing to fix surprise → other-emotion transitions where eyeWide↓ and eyeSquint↑ cross discontinuously. Typical: same value as eyeSquint.z --smooth-target-sigma-eye-squintu  Gaussian σ (frames @ 30 fps) for pre-smoothing the eyeSquint L/R target channels BEFORE the gain. Real orbicularis oculi is slow + sustained, so we can smooth heavier than brows without losing useful motion. 0 = off (default). Suggested 3.0 (~100ms) for v18e.z--brow-innerup-happy-gainuq  Override the gain on browInnerUp (ch 2) for happy emotions (joy / laughter / excitement / gratitude). Set to 1.0 with expression-target-gain=2.2 to keep browInnerUp un-amplified on happy frames so the avatar doesn't look concerned/apologetic when saying happy things, while leaving the 2.2× boost everywhere else. None → no override (backward compat with v14..v18f).z--brow-happy-gainu  Per-emotion gain override for ALL brow channels (ch 0–4) on happy frames (joy / laughter / excitement / gratitude). Use when brows are pushed very high via --brow-target-gain but should stay calm (~v14 level 1.4) on smiling frames. None → no override (backward compat).z--brow-target-gainu.  Override target gain for ALL brow channels (ch 0–4) independently of --expression-target-gain. Use to keep brows at a v14-style mild gain (e.g. 1.4) while pushing the rest of expression (cheekSquint / eyeSquint / eyeWide) much higher. None → brows ride the global expression gain (backward compat).z--soft-clipa  Replace the dataset's hard upper-clamp at 1.0 with a smooth knee-then-asymptote curve. Preserves emotional cross-fade tails that would otherwise be chopped off when high target gains push values well past 1.0. Linear below --soft-clip-knee, asymptotic above. Lower bound stays hard 0.z--smooth-cond-sigma-emotionu  Gaussian σ (in frames @ 30 fps) for smoothing the emotion one-hot in cond[:, :16] along time. Smooths turn-boundary step changes so the model learns gradual emotion cross-fades. Must also be applied at inference time. Typical: 5–10 (~170–330 ms). 0 = no smoothing (backward compat).z--soft-clip-kneegffffff?zKnee point for --soft-clip. Below this value the curve is exactly linear (no distortion). Above, it bends toward 1.0. Default 0.7 = first 70%% of the dynamic range untouched. Range (0, 1).z--eye-wide-maxu   Hard cap on eyeWide (ch 20, 21) AFTER gain, before the [0,1] clamp. Prevents bug-eyed saturation at high expression gain. Typical: 0.5–0.7. None → no cap (backward compat).z--brow-surprise-gainu$  Per-frame brow scale-down on surprise/fluster proxy (high eyeWide). Mirrors the viewer 'brow cap' slider: brows (ch 0–4) get multiplied by (1 - wide_ramp * value), where wide_ramp ramps in over [0.10, 0.30] of post-gain eyeWide. Typical: 0.3–0.5. None → no scale-down (backward compat).z--variant-tagu   Optional suffix appended to checkpoint filenames after --focus, so different gain runs don't clobber each other. E.g. 'v14' → best_lipsync_v14.pt / best_expression_v14.pt.T)parentsexist_oku   [target-gain] lipsync×z.2fu     expression×u     emotional-mouth×z6  (applied per-channel in dataset, then clamped [0,1])u    [velocity] eyeSquint(ch 18,19)×z  (vs default expression )u   [velocity] brows(ch 0-4)×z[plosive-damper] target damp=z,  (applied in dataset when mouthClose > 0.4)u    [target-smooth] brow(ch 0-4) σ=z
 frames (~i  z.0fz ms @ z fps)u'   [target-smooth] eyeSquint(ch 18,19) σ=u%   [target-smooth] eyeWide(ch 20,21) σ=z:[per-emotion] browInnerUp(ch 2) on happy frames uses gain z (vs z elsewhere)u   [target-gain] brows(ch 0-4)×u     (overrides expression×z for brow group only)z5[per-emotion] brows(ch 0-4) on happy frames use gain u6   [per-frame] brows(ch 0-4) scaled by (1 - wide_ramp × z?) on surprise/fluster frames (eyeWide proxy, ramp [0.10, 0.30])z"[cap] eyeWide(ch 20,21) capped at z after gainz[soft-clip] saturation knee at z! (linear below, asymptotic above)u!   [cond-smooth] emotion one-hot σ=z.1fzseed_train_final.jsonlcrop_frameslipsync_target_gainexpression_target_gainemotional_mouth_target_gainplosive_damp_targetsmooth_target_sigma_browsmooth_target_sigma_eye_squintsmooth_target_sigma_eye_widebrow_innerup_happy_gainbrow_happy_gainbrow_target_gainbrow_surprise_gaineye_wide_max	soft_clipsoft_clip_kneesmooth_cond_sigma_emotionzseed_val.jsonl)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   ))long_rK   )solo_rJ   )zdaily_-splitrJ   c                   s   g | ]
}|d   kr|qS )r   r)   ).0e)want_catr)   r*   
<listcomp>  s    zmain.<locals>.<listcomp>
      z[smoke] train=z val=z epochs=z batch=ztrain: u    scenarios — counts: zval:   )num_samplesreplacementF)
batch_sizesamplernum_workers
pin_memory	drop_last)r   shuffler   r   zV3FaceModel (split-branch): g    .AzM params, ~z MB fp32 (~z	 MB int8)z--resume checkpoint not found: )map_locationweights_onlyrs   z[resume] loaded weights from z  (prev epoch=epoch?z	, val_l1=val_l1nanz.4fz[freeze_lipsync] froze zGM params (shared backbone + lipsync branch + lipsync head). Trainable: zM (expression branch + head).r   )LIPSYNC_BRANCH_CHANNELSuY   [wandb] requested but `wandb` not installed — skipping. Install with: pip install wandbv3face_h_b_lrz.0e_smoke )	n_paramssize_mb_fp32size_mb_int8train_scenariosval_scenariostrain_category_countsval_category_countsr1   smoke)projectentitynameconfig	gradientsd   )loglog_freqz[wandb] logging to c                 S  s   g | ]}|j r|qS r)   )requires_grad)r   pr)   r)   r*   r     s    )g?gffffff?)lrbetasweight_decayro   intr   r;   c                   sR   |  j k r| td j  S |  j  td j   }ddttjtd|   S )Nr   r   r5   )warmup_stepsrr   mathcospir7   )ro   prog)r   total_stepsr)   r*   	lr_lambda  s   
zmain.<locals>.lr_lambda)r   EXPRESSION_BRANCH_CHANNELSr   z[focus=lipsync] masked z expression channelsr   z[focus=expressionz/freeze_lipsyncz	] masked z lipsync channelsinfr   )rX   rZ   zepoch 3dz  train l1=r^   z vel=r_   z	  val l1=z  lr=z.2ez  vw=z  s)rs   r   r   r   val_veltrain_l1	train_velrT   bestz.ptu     → saved bestz	 (val l1=latestr]   )
r   ztrain/l1ztrain/velocityz
train/losszval/l1zval/velocityzval/losszval/best_l1r   epoch_seconds)ro   z
Done. best val l1: zcheckpoints: best_val_l1)ro   r   r   r;   )hargparseArgumentParseradd_argumentr   DEFAULT_NPZ_DIRDEFAULT_EMOTION_DIRDEFAULT_OUT_DIRr   r;   str
parse_argsr   epochsn_epochsr   r   learning_rater   r   r   rP   rQ   r   r   r   r   r   r   r   r   r   r   r   r   out_dirmkdirr   r1   printrL   fpsr   npz_diremotion_dirr   r   entriesextendr   lencategory_countsget_sample_weightslong_oversample_weightr   r   r   r   rg   r   size_mbresumeexists
SystemExitloadload_state_dictgetfreeze_lipsyncn_trainablers   r   wandb_WANDB_AVAILABLEwandb_run_name
hidden_dimshared_dilationsbranch_dilationsinitwandb_projectwandb_entity__dict__watchrunget_urlrm   r   r   rr   r   r+   rR   r   focusrangetimevelocity_warmup_epochsr7   r|   get_last_lr
state_dictvariant_tagsaver   summaryfinish)$apargsr1   mouth_gtrain_dsval_dssmoke_picksr`   pickssample_weightsr   train_dlval_dlrs   ckptfrozen_nr   r(   	use_wandbrun_nametrainable_paramsru   r   rv   r/   rS   r   best_valr   t0rZ   trvadtis_bestsuffixr)   )r   r   r   r*   main   sB  




	






	
















	








 








r*  __main__)r   r   r   r   )
r,   r   r-   r   r.   r   r/   r   r   r   )
r,   r   r-   r   r.   r   rS   r   r   r   )r5   )rX   rY   rZ   r;   r   r[   )3__doc__
__future__r   r   r   r  pathlibr   typingr   numpynpr   torch.nn.functionalrj   
functionalFtorch.optimr   torch.optim.lr_schedulerr   torch.utils.datar   r   scripts.compiler.constantsr	   r
   r   r   r   datasetr   rs   r   r   r   ImportError__file__resolver   PROJECT_ROOTr   r   r   r+   rG   rR   rW   r|   r*  __name__r)   r)   r)   r*   <module>   sP    



 %   [
