Variational Autoencoder (VAE) cơ bản

generative-models

#1

Thời gian gần đây các mô hình sinh ngày càng trở nên phổ biến. Bên cạnh GANs, chúng ta còn có một mô hình nổi tiếng khác là Variational Autoencoder. Thời gian vừa rồi mình có cơ hội được làm việc với nó nên mình muốn viết 1 bài chia sẻ tổng quan về những gì mình hiểu. Vừa để chia sẻ những gì mình đã đọc và cũng mong được mọi người chỉ ra giúp nếu có chỗ nào mình hiểu chưa đúng :smiley:.

I. Autoencoder là gì?

Đây vẫn là một dạng neural network, có hơi khác các kiến trúc mạng khác ở chỗ: Mạng dùng chính input làm ground truth. Một cách dễ hình dung thì mạng sẽ có dạng thế này

Nguồn ảnh: https://towardsdatascience.com/applied-deep-learning-part-3-autoencoders-1c083af4d798 Mạng gồm 3 phần chính là encoder, latent representation và decoder. Khối encoder sẽ nhận đầu vào là vector biểu diễn dữ liệu n chiều, output là vector trong không gian ẩn m chiều (latent representation). Thông thường thì m < n, do đó đây cũng là một phương pháp để giảm số chiều của dữ liệu. Decoder nhận đầu vào là vector m chiều bên trên (output của encoder) và cố gắng khôi phục lại dữ liệu ban đầu (vector n chiều đầu vào của encoder). Có thể hình dung việc này giống như chúng ta nén một tín hiệu và khôi phục lại tín hiệu đó. Để khôi phục lại được tín hiệu gốc, biểu diễn ẩn phải học được cách lưu giữ các thông tin quan trọng của tín hiệu đầu vào. Tính toán trên không gian dữ liệu ít chiều sẽ giúp việc tính toán hiệu quả hơn, đồng thời tránh đc curse of dimensionality. Để tìm hiểu kỹ hơn về autoencoder, các bạn có thể tham khảo bài viết này

https://towardsdatascience.com/deep-inside-autoencoders-7e41f319999f

Do bài viết của mình tập trung vào VAE nên không tiện viết quá nhiều về Autoencoder. Tuy nhiên VAE là một topic nâng cao của Autoencoder nên mình khuyến khích việc mọi người tìm hiểu kỹ về Autoencoder trước khi chuyển qua phần 2 :blush:.

II. Variational Autoencoder

Quay trở lại một ví dụ với bộ dữ liệu MNIST. Giả sử với ảnh chứa số 4 sau khi qua bộ encoder giảm xuống chỉ còn 2 chiều, mình thu được biểu diễn của nó là vector có tọa độ (2, 2). Khi mình đưa vector (2, 2) này vào làm đầu vào cho decoder, tất nhiên mình sẽ thu lại được bức ảnh mang số 4. Model biết điều đó do đây là dữ liệu training. Nhưng giả sử mình đưa vào input của decoder vector (2.1, 2.1) khá tương đồng với vector (2, 2), không có gì đảm bảo được rằng mình sẽ thu lại được 1 bức ảnh chứa số 4 hay một số gần giống số 4. Điều đó không tốt chút nào, vì chúng ta đang cần mô hình có thể sinh ra được dữ liệu mới mà mô hình chưa từng nhìn thấy. Kết quả sau khi đưa bộ dữ liệu MNIST vào một Autoencoder:

Dễ dàng nhận thấy dù kết quả có vẻ tốt hơn khá nhiều so với phương pháp cố điển như PCA, tuy nhiên vẫn còn nhiều số bị chồng lên nhau. Ngoài ra rất khó để dự đoán xem những vùng màu trắng không có điểm dữ liệu nào sẽ ra sao nếu đưa những điểm đó vào decoder. VAE sinh ra để giải quyết bài toán này. Tư tưởng chính của VAE là thay vì tìm một điểm biểu diễn trong không gian ẩn cho một điểm dữ liệu trong không gian gốc, chúng ta sẽ đi tìm một phân phối xác suất cho điểm dữ liệu đó.

Nguồn ảnh: Sách Generative Deep Learning - David Foster

Cụ thể ở đây chúng ta sẽ giả thiết rằng dữ liệu được biểu diễn thông qua phân phối chuẩn. Mỗi điểm dữ liệu sẽ được xác định bởi 2 tham số là giá trị kì vọng (m) và phương sai (sigma^2). Tuy nhiên do ở đây mỗi điểm dữ liệu được biểu diễn bởi một vector nhiều chiều nên giá trị kì vọng cũng sẽ là một vector, còn phương sai được biểu diễn bởi một ma trận là ma trận hiệp phương sai. Các phần tử trên đường chéo chính là phương sai của từng chiều dữ liệu tương ứng, các phân từ không thuộc đường chéo chính tương đương với hiệp phương sai của 2 chiều dữ liệu.

image

Tuy nhiên khá may mắn là để đơn giản trong tính toán, tác giả đã giả sử các chiều dữ liệu là các biến ngẫu nhiên độc lập với nhau. Do đó ngoài các phần tử nằm trên đường chéo chính thì các phần tử còn lại đều bằng 0 (dễ thấy theo hình trên, khi các biến ngẫu nhiên độc lập với nhau, correlation giữa chúng cũng sẽ bằng 0). Bây giờ thay vì cần dùng một ma trận để biểu diễn phương sai thì chúng ta chỉ cần dùng một vector là đủ. Ngoài ra để đảm bảo độ ổn định trong quá trình huấn luyện mô hình, chúng ta sẽ dùng logarithm của phương sai thay vì dùng trực tiếp phương sai. Từ đó, 2 tham số cần xác định cho mỗi điểm dữ liệu bây giờ là:

  1. m: điểm kì vọng của phân phối
  2. log_var: logarithm phương sai của các chiều dữ liệu trong không gian ẩn.

Để tìm một điểm cụ thể z cho dữ liệu trong không gian ẩn, chúng ta sẽ tính toán như sau:

z = m + epsilon*sigma

trong đó: sigma = exp(log_var /2)

image

epsilon được lấy mẫu ngẫu nhiên theo phân phối chuẩn với mean = 0 và variance = 1

Giải thích một chút về việc tại sao không sampling một điểm dữ liệu z theo phân phối q(z|x) mà lại dùng phương pháp trên (reparameterization trick).

Nguồn ảnh:https://stats.stackexchange.com/questions/199605/how-does-the-reparameterization-trick-for-vaes-work-and-why-is-it-important

Nếu dùng sampling ngay tại node z, chúng ta sẽ có một random node trong mạng. Điều này khiến cho chúng ta khi tính back-propagation qua một biến ngẫu nhiên, điều này là không thể. Do đó tác giả đã dùng một trick để đưa phần ngẫu nhiên (epsilon) ra ngoài và sẽ ko cần đạo hàm phần này khi tính back-propagation.

Trong không gian ẩn, chúng ta sẽ lấy mẫu ở đây và đưa vào decoder để sinh ra dữ liệu mới. Để giải quyết vấn đề lấy điểm (2.1, 2.1) vẫn sinh ra được ảnh mang số 4 nêu trên, cần có một quy luật nhất định cho việc lấy mẫu này. Cách đơn giản mà khá hiệu quả được đề xuất là sử dụng phân phối chuẩn N(0, 1) (mean bằng 0 và variance bằng 1). Nghĩa là ta thêm một ràng buộc mỗi điểm dữ liệu sẽ được biểu diễn bởi một phân phối chuẩn xấp xỉ phân phối N(0, 1). Để xấp xỉ 2 phân phối, phương pháp thường được dùng là KL-Divergence. KL loss của 2 phân phối Gauss với hàm đơn biến có thể được chứng minh như hình dưới:

image

Bên trên là chứng minh ở dạng đơn biến. Tuy nhiên trong trường hợp này KL loss đối với đa biến thực chất là tổng của các đơn biến (tổng KL loss cho từng chiều). Chứng minh phần này các bạn có thể tham khảo tại đây:

Loss function của VAE gồm 2 phần:

  1. Reconstruction loss (giống với Autoencoder thông thường)
  2. KL loss mình đã trình bày phía trên

Khi kết hợp 2 loss này lại các bạn có thể tùy biến tỷ lệ cho mỗi bên tùy mục đích sử dụng. Đây cũng là các hyperparameter có thể chỉnh được.

Thực chất để xây dựng hàm loss của VAE, người ta đi từ tư tưởng của Variational Inference. Chúng ta cần tìm phân phối xác suất ánh xạ z khi biết x, nghĩa là p(z|x) theo tham số theta. Tuy nhiên xác định được theta lại là công việc rất khó, một giải pháp khả thi hơn là xấp xỉ phân phối p(z|x) theo tham số theta bằng phân phối q(z|x) theo tham số phi nào đó. Việc xấp xỉ này chúng ta tiếp tục sử dụng KL Divergence. Mình không biết viết Latex nên viết tạm ra giấy tiếp ^^.

image

Lưu ý một chút là chúng ta sẽ maximize L trong hình. Tuy nhiên reconstruction loss là một cross entropy, Đặt dấu trừ ra ngoài sẽ quay trở về bài toán minimize -L, phù hợp với nội dung của VAE loss bên trên :smiley: .

III. Thực hành

Thật ra trang chủ của Keras đã đăng một bài code hướng dẫn khá chi tiết với bộ dữ liệu MNIST, các bạn có thể kéo về chạy thử luôn. Kết hợp với phần lí thuyết giải thích bên trên, phần code này viết một cách khá đầy đủ từng bước một mà mình đã đề cập :slight_smile:

https://keras.io/examples/variational_autoencoder/

Bài viết của mình nếu chỗ nào chưa rõ hoặc sai mong các bạn comment bên dưới giúp mình để có thể cải thiện chất lượng bài viết. Cảm ơn mọi người đã đọc hết bài :smiley:


#2

curse dimesionality : hiểu như thế nào ạ? nếu dịch sang tiếng việt để có khi sử dụng trong báo cáo nên chuyển ngữ ntn là hợp lý ạ?


#3

Về cơ bản thì nó liên quan đến độ phủ của dữ liệu, đặc biệt khi bạn dùng các phương pháp dựa trên Euclidean distance. Bạn có thể tìm hiểu thêm tại đây

Còn dịch ra tiếng việt thì thực sự mình cũng ko biết nên dịch nó là gì ^^


#4

e có đọc thấy có khái niệm posterior distribution. tra wiki mà chưa thông não được. bác giải thích giúp em vs


#5

Giải thích như của bạn chưa chuẩn. Curse of dimension được hiểu là "những đặc tính trên chiều không gian gốc - n chiều, có thể không còn duy trì được trên miền không gian ít chiều hơn m, (m<n).

Vậy tìm được chiều không gian nhỏ hơn mà vẫn giữ các đặc tính mong muốn là rất khó. Vì thế mới có câu kiểu là “lời nguyền của số chiều” - chưa nghĩ ra từ gì hay hơn.

Ví dụ là ảnh 3D cho ta cảm giác về khoảng cách và độ sâu, nhưng ảnh 2D không trả về chính xác về khoảng cách nữa.

L2 distance chỉ là một ví dụ thôi chứ không phải là chính như bạn nói.


#6

Thật ra đoạn này em cũng không chắc lắm. Theo em đọc và hiểu trong cuốn machine learning a probabilistic perspective của Murphy thì curse of dimensionality tưởng tượng không gian theo dạng hypercube với D-dimension và dữ liệu trong không gian này uniformly distributed. Giả sử có một cube nhỏ trong đó, nếu muốn cube này chứa một tỷ lệ f dữ liệu của không gian (f thuộc [0, 1]) thì chiều dài cạnh kì vọng của cube đó sẽ được tính bằng công thức eD(f) = f^1/D. Nghĩa là lúc này mật độ của dữ liệu sẽ càng thưa nếu số chiều tăng lên.

Để bảo toàn mật độ dữ liệu cũng như Euclidean distance đối với các “hàng xóm” thì số lượng dữ liệu nếu em nhớ không nhầm cũng phải tăng lên theo hàm N^D với N là số lượng điểm dữ liệu giả định trong không gian 1 chiều, D là số chiều.

Cách hiểu khi giảm số chiều từ không gian nhiều chiều xuống ít chiều sẽ gây mất mát các đặc tính mong muốn em thấy cũng có lý. Có thể em đang hiểu sai đoạn này


#7

" Trong không gian ẩn, chúng ta sẽ lấy mẫu ở đây và đưa vào decoder để sinh ra dữ liệu mới. Để giải quyết vấn đề lấy điểm (2.1, 2.1) vẫn sinh ra được ảnh mang số 4 nêu trên, cần có một quy luật nhất định cho việc lấy mẫu này. Cách đơn giản mà khá hiệu quả được đề xuất là sử dụng phân phối chuẩn N(0, 1) (mean bằng 0 và variance bằng 1). Nghĩa là ta thêm một ràng buộc mỗi điểm dữ liệu sẽ được biểu diễn bởi một phân phối chuẩn xấp xỉ phân phối N(0, 1). "

Cái đoạn này e ko hiểu lắm.

  • Khi 1 điểm dữ liệu đưa vào latent space sẽ xác định được mean vs log_var tương ứng của nó. Và từ đó sẽ lấy mẫu theo 2 cái tham số đó và đưa vào decoder ạ?
  • Em thắc mắc nữa là tại sao khi việc lấy mẫu tuân theo phân phối chuẩn tắc u=0, sigma = 1 thì sẽ làm cho như trong ví dụ (2.1, 2.1) vs (2.0, 2.0) có cùng tạo ra được số 4 ạ?
  • Giả sử như đã áp dụng KL-loss đó rồi thì khi print ra output thì (2.1, 2.1) vs (2, 2) có cùng tạo ra 1 đầu ra có giá trị giống nhau ko ạ

Mong đc mn giải đáp ạ


#8
  1. Đúng

  2. Do khi lấy mẫu theo phân phối này trong quá trình training sẽ giúp decoder học được quy luật decode general hơn là chỉ biết cách decode từ duy nhất 1 điểm dữ liệu trong latent space.

  3. Không chắc được. Đối với ví dụ như hình dưới, VAE sẽ giúp có được tính chất nội suy trong latent space. Còn có chắc 2 điểm với khoảng cách cho trước có chắc chắn cho output giống nhau hay không thì nó phụ thuộc vào nhiều yếu tố. Phía trên ví dụ (2, 2) và (2.1, 2.1) chỉ lấy để làm ví dụ thôi bạn.

Nguồn ảnh: https://towardsdatascience.com/intuitively-understanding-variational-autoencoders-1bfe67eb5daf


#9

em hỏi ngu tí. Trong quá trình huấn luyện thì mô hình sẽ học cho các số 0->9) một phân phối chuẩn nào đó ạ? Ví dụ số " 0 " có 1 phân phối chuẩn xác định dựa trên mean/ variance. Thì khi KL-divergency nó sẽ làm cho mean -> 0, variance -> 1. Thì có khi nào nó vô tình làm cho các phân phối chuẩn của các số đó giống nhau ko ạ?


#10

Mình nghĩ là nó không giống nhau được vì các tham số khi reparameterize khác nhau mà. Việc dùng KL-divergency theo mình hiểu là để model không cheat (tránh bị overfit).

Posterior distribution trên machinelearningcoban viết cũng dễ hiểu mà.


#11

thanks kiu ạ. Cho e hỏi thêm là “cheat” ở đây hiểu ntn ạ


#12

Mình dùng từ trong sách thôi. Ý là thay vì tìm phân phối cho mỗi loại như mục tiêu mong muốn thì với mỗi data point trong train set nó sẽ gán với 1 phân phối => overfit.


#13

Giả sử ban đầu bạn có 1 tập các observation X = {x_1, x_2, …, x_n}. Bạn muốn biết các observation này tác động thế nào đến tập y = {y_1, y_2, …, y_n}. Theo bayes rule thì:

P(y|X) = P(X|y)P(y) / P(X)

Ở công thức trên, X là observed data, hay evidence, là dữ liệu bạn quan sát được.

P(y|X) là posterior, là cái bạn muốn tìm.

P(X|y) là likelihood

P(y) là prior, là ban thông tin đầu về y trước khi biết được evidence X.

Khả năng dịch của mình hạn chế, không biết viết vậy có ổn không. Bạn có thể đọc phần Introduction của cuốn Pattern Recognition and Machine Learning để biết thêm chi tiết.