【Day13】 AutoVC 实作 Pytorch 篇 - 2

衔接昨日

Part 5 - AutoVC

这部分我们暂时先参考官网 model_vc.py 即可

Part6 - 制作 Solver

把官网 solver_encoder.py 的第 90,91 行改成

 # 原本会造成 shape miss-match 导致收敛过快无法学习 
 g_loss_id = F.mse_loss(x_real, x_identic.squeeze())   
 g_loss_id_psnt = F.mse_loss(x_real, x_identic_psnt.squeeze())   

Part6 - Train

import torch
from solver_encoder import Solver
from data_loader import get_loader

class Config:
    def __init__(self):
        self.data_dir = './spmel'
        self.batch_size = 2
        self.len_crop = 176
        self.lambda_cd = 1
        self.dim_neck = 44
        self.dim_emb = 256
        self.dim_pre = 512
        self.freq = 22
        self.num_iters = 1000000
        self.log_step = 10
config = Config()
vcc_loader = get_loader(config.data_dir, config.batch_size, config.len_crop)
solver = Solver(vcc_loader, config)
solver.train()
torch.save(solver.G.state_dict(), "autovc")

Part7 - Inference

import IPython.display as ipd
import pickle
import torch
import numpy as np
from model_vc import Generator

device = 'cuda:0'
G = Generator(32,256,512,22).eval().to(device)
G.load_state_dict(torch.load('autovc'))
metadata = pickle.load(open('spmel/train.pkl', "rb"))

source = 0
target = 3
# 因为我的 Source 是 p226
uttr = np.load(f"spmel/p226/p226_014_mic1.npy")[50:226]
# (1,256)
emb_org = torch.from_numpy(np.expand_dims(metadata[source][1],axis=0)).to(device)
# (1,256)
emb_trg = torch.from_numpy( np.expand_dims(metadata[target][1],axis=0)).to(device)
# (1,178,80)
uttr = torch.from_numpy( np.expand_dims(uttr,axis=0)).to(device)
uttr_trg = None
with torch.no_grad():
_, x_identic_psnt, _ = G(uttr, emb_org, emb_trg)
# (176,80)
uttr_trg = x_identic_psnt[0, 0, :, :].cpu().numpy()

# To Waveform
from interface import *
vocoder = MelVocoder()
audio = np.squeeze(vocoder.inverse(torch.from_numpy(np.expand_dims(uttr_trg.T,axis=0))).cpu().numpy())
ipd.Audio(audio,rate = 22050)

小整理

你的根目录下大概会有以下内容:

root - 
    - /spmel 
    - /wavs
    - train.ipynb
    - make_spec.ipynb (生 spmel 的)
    - make_d_vector.ipynb (生 train.pkl 的)
    - dataloader.py
    - model_vc.py 
    - solver_encoder.py
    - MelGan 的 model
    - interface.py (For MelGan)
    - modules.py (For MelGan)
    - D_VECTOR 的 model
  • 如果你不想 train 的话下载官方的 pre_train model 会是 16khz 的,资料前处理方式也不一样,不可以用 MelGan 来转。

  • 你可以在这里下载我训练的 22khz 版本 ,它可以用 MelGan 转回 Waveform

  • Inference 出来的效果确实跟他们官网发表(听看看)的是一样的!

小结

到此我们快速 Run 了一次 Pytorch 的版本,更详细的内容我们留到明天开始用 TF 做会更清楚,但到这里我们已经可以体验到声音转换的魅力了!

因为 TF 用习惯了总觉得比较好解释,虽然身边的朋友们都说 TF 没救了,现在是 Pytorch 的时代,但是我们信仰要坚定RRR!

/images/emoticon/emoticon09.gif/images/emoticon/emoticon13.gif/images/emoticon/emoticon14.gif/images/emoticon/emoticon22.gif/images/emoticon/emoticon28.gif


<<:  Day 13 Self-attention(七) Positional Encoding、self-attention和其他model的比较

>>:  Day 1 初探Flutter

成熟度模型( A maturity model)

-CMM 和 CMMI 成熟度水平比较 成熟度模型“可以”(而不是应该或必须)定义五个成熟度级别,...

Day4 带着烤肉香的JavaScript

JavaScript约诞生於1995年,用於操作网页的DOM、BOM,负责处理所有网页与使用者的互动...

Ruby--Find the Difference

Find the Difference 题目连结:https://leetcode.com/pro...

成员 25 人:如何安稳地抽走猫躺的软垫

「好企业没有舒适圈;  如果你觉得今天很舒适,也撑不过明天。」 新创公司,总能看清人生百态。 你知道...

JavaScript Day01 - 说明

前言 这次主要是更新我之前的笔记,那时候刚学习 JavaScript,对於一些结果可能不是很懂,刚好...