Generative Adversarial Networks - GANs

deep-learning
gan

#1

Generative Discriminative Networks - GANs

GANs là gì?

GANs là một thuật toán học không giám sát (Unsupersived Learning) được Ian Goodfellow giới thiệu vào năm 2014 tại hội nghị NIPS, trong đó bao gồm hai thành phần chính là GeneratorDiscriminator:

  • Generator (ký hiệu G) nhận nhiệm vụ học ra cách áp xạ từ một không gian tìm ẩn Z (a latent space) vào một không gian với phân phối từ dữ liệu cho trước.

  • Discriminator (ký hiệu D) nhận nhiệm vụ phân biệt dữ liệu được tạo ra từ G và dữ liệu cho trước.

Một cách toán học: Giả sử ta có z \in Zz \sim p_Z(z), dữ liệu cho trước xx \sim p_{data}(x) (x gọi là real data). Ta có G sẽ ánh xạ z không gian dữ liệu cho trước \hat{x}=G(z) (\hat{x} gọi là fake data). D(x) là xác suất mà xreal data hay fake data. Mục tiêu của GANs là làm sao cho G cố gắng tạo ra được \hat{x} sao cho D không còn thể phân biệt được là fake data. Tối ưu GD giống như trò chơi minimax với hàm mục tiêu V(D, G), trong đó G cố gắng làm tăng xác suất mà \hat{x} được tạo ra là real dataD thì cố gắng làm điều ngược lại.

\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[log D(x)] + \mathbb{E}_{z \sim p_Z(z)}[log(1 - D(G(z)))]

Tối ưu GANs

Qúa tình tối ưu GANs cũng khá đơn giản:

  1. Lấy ngẫu nhiên m mẫu z \in Zm mẫu x từ dữ liệu cho trước.
  2. Tối ưu D dựa trên zx.
  3. Lấy ngẫu nhiên m mẫu z \in Z (có thể dùng lại z ở bước 1).
  4. Tối ưu G dựa trên z.
  5. Quay lại bước 1.

Minh hoạ bằng Python với Tensorflow

Sau đây mình sẽ minh hoạ GAN với tập dữ liệu MNIST. Toàn bộ mã nguồn có thể tìm được tại đây

Trước tiên mình cần cài đặt thư viện Tensorflow cho Python

[sudo] pip install tensorflow #Hoặc tensorflow-gpu đối với các bạn sử dụng Tensorflow với GPU

Trước tiên ta cần khai bái các thư viện cần thiết:

import tensorflow as tf
import tensorflow.contrib.slim as slim #slim cho phép khai báo nhanh các lớp thông dụng
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data #Tensorfow cung cấp sẵn API để lấy dữ liệu từ tập MNIST

Ta cần định nghĩa GD, ở đây mình sử dụng một mạng truyền thẳng (feed-forward) đơn giản với một lớp ẩn:

def generator(inputs):
	with tf.variable_scope("generator"):
		net = slim.fully_connected(inputs, 256, scope = "fc1")
		net = slim.fully_connected(net, 784, scope = "fake_images", activation_fn = tf.nn.sigmoid)
	return net

def discriminator(inputs):
	with tf.variable_scope("discriminator"):
		net = slim.fully_connected(inputs, 256, scope = "fc1")
		net = slim.fully_connected(net, 1, scope = "predictions", activation_fn = tf.nn.sigmoid)
	return net

Ta đinh nghĩa các hyperparameters cần thiết:

mnist_loader = input_data.read_data_sets('MNIST_data')
batch_size = 32
z_dim = 100 #Số chiều của Z
learning_rate = 0.0002
num_iters = 100000

Sau đó ta khởi tạo mạng:

random_z = tf.placeholder(shape = [batch_size, z_dim], dtype = tf.float32, name = "random_vector") #vector z
real_images = tf.placeholder(shape = [batch_size, 784], dtype = tf.float32, name = "real_images") #real data
fake_images = generator(random_z) #fake data

predictions = discriminator(tf.concat([real_images, fake_images], axis = 0)) #fake và real data được đưa qua Discriminator
real_preds = tf.slice(predictions, [0, 0], [batch_size, -1])
fake_preds = tf.slice(predictions, [batch_size, 0], [batch_size, -1])

Sau đó ta định nghĩa hàm mất cho GeneratorDiscriminator. Ở đây mình sử dụng Adam Optimizer để tối ưu GD.

gen_loss = -tf.reduce_mean(tf.log(fake_preds))
dis_loss = -tf.reduce_mean(tf.log(real_preds) + tf.log(1. - fake_preds))

gen_vars = slim.get_variables(scope = "generator")
dis_vars = slim.get_variables(scope = "discriminator")

optimizer = tf.train.AdamOptimizer(learning_rate)
gen_train_op = optimizer.minimize(gen_loss, var_list = gen_vars)
dis_train_op = optimizer.minimize(dis_loss, var_list = dis_vars)

Sau đó ta tiến hành tối ưu GD

sess = tf.Session()
sess.run(tf.global_variables_initializer())
for iter in xrange(1, num_iters + 1):
	## Vector ngẫu nhiên z được lấy từ phân phối đều trên [-1, 1]
    feed_dict = {
            random_z: np.random.uniform(-1., 1., size=[batch_size, z_dim]),
            real_images: mnist_loader.train.next_batch(batch_size=batch_size)[0]
            }
    _, _, _gen_loss, _dis_loss = sess.run(
		    [gen_train_op, dis_train_op, gen_loss, dis_loss],
			feed_dict = feed_dict
			)
			
    if (iter % 50) == 0:
        print("Iteration [{:06d}/{:06d}]".format(iter, num_iters))
        print("\t>> Generator Loss: {}".format(_gen_loss))
        print("\t>> Discriminator Loss: {}".format(_dis_loss))

Ứng dụng của GANs và Những điều lưu ý

Trong những năm gần đây, GAN đã có những ứng dụng mạnh mẽ trong nhiều bài toán như Image Super Resolution, Image Translation, Domain Adaptaion. Tuy nhiên để tối ưu GANs là điều không phải dễ, điều này đòi hỏi về phần cứng cũng như sự phân tích bài toán:

  • GAN đòi hỏi chi phí phần cứng cao với các bài toán xử lý ảnh có kích thước lớn.

  • Việc tìm sự cân bằng giữa GeneratorDiscriminaotor là một bài toán thực sự khó.

  • Việc điều chỉnh Learning Rate cũng không phải là điểu dễ dàng, phải cân bằng làm sao cho DiscriminatorGenerator cân bằng lẫn nhau, nếu không dễ dẫn đến một bên lấn áp phần còn lại, dẫn đến Generator không cho ra kết quả tốt.

Bài viết này chỉ cung cấp một khái niệm và ví dụ cơ bản nhất về GANs, nếu có gì sai sót mong các bạn có thể đóng góp kiến.

Tham khảo

  1. Generative Adversarial Networks Wiki
  2. Generative Adversarial Nets. Ian Goodfellow. NIPS 2014
  3. Mã nguồn tham khảo

#3

Anh bỏ một dấu $ để hiện thị rồi nhé.


#4

Bài viết khá công phu. Anh gợi ý em một số điểm để bài viết tốt hơn:

  1. Cho một số hình minh họa về GAN. Ví dụ quá trình training của GAN em có thể cho hình vẽ sau để người đọc dễ hình dung

generative-adversarial-networks-34-638

  1. Train GAN thực chất là quá trình training luân phiên: train G rồi lại train D… Trong code tensorflow em nên tách nó ra để người đọc dễ hình dung ( mapping từ thuật toán đến implementation). Em nên có hình vẽ về loss G và loss D vsf kết quả first and last epoch của generated image.

  2. Thêm một số trick để cải thiện quá trình train cho GAN. Em google từ khóa ganhacks sẽ có

Good job!


#5

Cảm ơn anh đã gợi ý ạ. Em đang viết một bài dài và chi tiết hơn, và viết luôn sang một số ứng dụng của GAN. Nhưng thiện tại em đang thực tập nên cũng khá bận, khi nào về Việt Nam em sẽ hoàn thiện bài viết mới đầy đủ và gửi đến đọc giả sớm nhất.


#6

Hôm trước mình đọc 1 bài về face-swap, ý tưởng có vẻ khá giống với GAN. có điều là ko dùng autoencoder mà dùng multiscale CNN.

Ý tưởng chính của bài này thì mình vẫn chưa hiểu rõ lắm, comment ở đây mong mng chỉ thêm. Có phải là mình chỉ train cái multiscale CNN thôi ko? và loss function được tính từ layers của VGG?

Cám ơn mng!


#7

Bạn có thể share paper để mọi người cùng đọc để mình và các bạn khác có thể giúp bạn.


#8

https://arxiv.org/pdf/1611.09577.pdf

Nó là bài này ạ!


#9

Mọi người cho em hỏi với ạ:

Hàm tối ưu sử dụng cross-entropy tức là muốn Discriminator nhận real image là real image và fake image là fake image. Vậy em tưởng phải tối ưu hàm đạt giá trị nhỏ nhất của hàm cross-entropy ạ? Vì tính chất của hàm cross-entropy là giá trị càng nhỏ thì càng ép cho 2 phân phối xác suất giống nhau ạ? Mà trong GANs lại tối ưu maximum tức là giá trị lớn nhất. Mong mọi người giải đáp ạ! Em cảm ơn!


#10

Chào bạn, hàm cross entropy loss có dạng là \mathcal{L} = - \{ \mathbb{E}_{x \sim p_{data}(x)}[log(D(x))] + \mathbb{E}_{z \sim p_z(z)}[log(1 - D(G(z)))] \}. Khi bạn làm min hàm \mathcal{L} tức là bạn muốn làm max -\mathcal{L}. Trong hình của bạn, hàm đó đang ngược dấu với \mathcal{L}.


#11

À vâng đúng rồi ạ, em quên ko để ý dấu :joy:


#12

Mình đóng góp thêm là GAN mục tiêu của nó chính là bộ Generator với mục tiêu tạo ra sample mới có cùng tính chất với tập sample cũ.

Từ đặc tính này dẫn đến các ứng dụng của GAN mang yếu tố “realistic” - rất giống thật. Ví dụ như các ứng dụng low-level vision như image denoising, image super-resolution, inpainting, .v.v.

Về điểm yếu của GAN, như mình xem kết quả từ nhiều bài báo liên quan tới Super-resolution thì ảnh nhìn khá thực (nếu ko có ảnh gốc). Nhưng hoàn toàn có thể biến object thành một cái gì đó hoàn toàn khác điển hình như ký tự. Ngoài ra các đặc tính vật lý của bức ảnh như hình ảnh phản chiếu có thể khác biệt rất lớn. Bạn nhìn lần đầu có thể ấn tượng, khó phân biệt được đâu là ảnh GAN nhưng nếu bạn nhìn kỹ sẽ nhận ra được các lỗi thường gặp.

Nên dùng GAN cho ảnh kết hợp vói object recognition, segmentation cho ảnh y tế thì mình ko rõ ảnh hưởng của nó như thế nào.

GAN còn cần nhiều phát triển nữa, đặc biệt là hàm loss của bộ Net D


#13

Đã có ứng dụng nào của GANs trong lĩnh vực NLP chưa bạn nhỉ?


#14

Em có một số thắc mắc về GANs thế này mong mọi người giúp em

  • Theo những gì em học được thì mục tiêu chính của GANs là tạo bộ data giả sao cho giống thật nhất ( Generator đánh lừa được Discriminator). Nhưng mọi người cho em hỏi là: với một bài toán là cần sự đặc trưng của một khuôn mặt (Face recognition) thì việc Generator tạo ra các data giả mang bản chất khuôn mặt giống với ảnh gốc thì nó thực sự không mang nhiều ý nghĩa đúng không mọi người ? (do em đang muốn làm data của face recognition của em trở nên đa dạng và phong phú hơn. Do với bối cảnh khác nhau em đều phải thêm dữ liệu vào cho mô hình học lại thực sự cũng rất bất tiện) Cảm ơn mọi người rất nhiều