通过SRCNN对图片进行超分辨率重建

好久之前看到一篇关于深度学习的相关应用,里面提到了对图片进行超分辨率重建(Super Resolution),很感兴趣,因为手上好多低分辨率的图片,想upgrade.

论文点这里

看了论文,然后网上也刚好有个专门针对动漫图片进行超分辨率重建的项目waifu2x,也把代码看了一下

训练数据可以通过高分辨率图片下采样到低分辨率,然后插值重建,二者即可构成高低分辨率数据对。为了提高训练速度,可以通过在高分辨率数据对集中随机截取一小块进行训练,降低训练数据规模。

论文里面只用了三层的卷积层,而waifu2x使用了七层。

整个网络结构相对简单,区别于其他超分辨率重建的方法,如字典映射和稀疏编码方法。

用稀疏编码方法去看待CNN,是有一定启发性的。但是不同于前者,后者是一个端到端的完整实现。而前者训练sparse coding和mapping都是分开的,然后组成pipeline.

参考了https://github.com/lon9/waifu2x-keras的keras实现

用TensorFlow加载waifu2x的模型试了下,waifu2x使用了1600张高分辨率的图片进行训练。模型文件点这里,model目录里面。

..稍稍大点的图片就OOM了,模型还是太大了,传播到第四层,就Out of Memory了,算了一下造成溢出的输出tensor大小,内存一下就能占用达到了2G多.. 所以caffe版本是怎么处理这个问题的….

暂时想到的办法是对图片进行切块重建,最后重新合成大图,this absolutely would work, i think。

(5.25 更新,尝试了切块重建,没有OOM,但是他喵的,合成后的图,每一块图片都带有黑边,Orz,网络生成的图片边缘,确实有细节损失)

原图

原图

近邻插值

近邻插值(边缘锯齿明显)

CNN重建

SRCNN重建(边缘平滑,锐度稍降)

有点大的图

会导致OOM的大图

切割重建后,各切块之间出现黑边

切块重组就变成这样了…..

调试了下,似乎切块之后的图块,经过SRCNN,边缘就会出现黑边,然而通过最邻近插值的图块就没黑边。 so 似乎 还需要 折腾一下。

import tensorflow as tf
from PIL import Image
from scipy import misc
import numpy as np
import json

flags = tf.app.flags

#parameters for cli
flags.DEFINE_string('m','model/scale2.0x_model.json','Directory of Model')
flags.DEFINE_string('i','','Path for image to super resolution')

FLAGS= flags.FLAGS

#model parameter
parameters = []

#read model parameter from json
def get_model_params(path):
    with open(path,'rb') as model_file:
         return json.load(model_file)

def conv2d(x,w):
    return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding="SAME")

def LeakyRelu(z, alpha=0.1, name="LeakyRelu"):
         return tf.maximum(z, tf.scalar_mul(alpha, z), name=name)
    
#load image and convert to YCbCr
def load_img2ycbcr(image_path,is_noise):
    image = Image.open(image_path).convert('YCbCr')
    if is_noise:
        image =  misc.fromimage(image).astype('float32')
    else : 
        image = misc.fromimage(image.resize((2*image.size[0], 2*image.size[1]), resample=Image.NEAREST)).astype('float32')
    
    misc.toimage(image, mode='YCbCr').convert('RGB').save(r'./NEAREST_2.0x.jpg')
    
    #normalization
    x = np.reshape(np.array(image[:,:,0]), (1, image.shape[0], image.shape[1],1)) / 255.0

    return image,x

#load image and convert to RGB
def load_img2rgb(image_path,is_noise):
    image = Image.open(image_path).convert('RGB')
    if is_noise:
        image = misc.fromimage(image).astype('float32')
    else:
        image = misc.fromimage(image.resize((2*image.size[0],2*image.size[1]),resample=Image.NEAREST)).astype('float32')
    
    misc.toimage(image, mode='RGB').save(r'./NEAREST_2.0x.jpg')
    
    #normalization to accelerate the caculation
    x = np.reshape(np.array(image),(1,image.shape[0],image.shape[1],image.shape[2]))/255.0

    return image,x

# for output name
def output_rename(path):
   file_ext=path.split('\\')[-1].split('.')
   rename = file_ext[0]+'_2.0x'
   rename = rename + '.' + file_ext[1]
   return rename

def main(_):

  if not tf.gfile.Exists(FLAGS.i):
       print("image didn't exist!")
       exit()
  if not tf.gfile.Exists(FLAGS.m):
       print("model didn't exist!")
       exit()


 #load the whole parameters from json    
  parameters.append(get_model_params(FLAGS.m)) 
  input_plane = parameters[0][0]['nInputPlane'] 
  if input_plane == 3:
     image,x = load_img2rgb(FLAGS.i,False)
  elif input_plane == 1:
     image,x = load_img2ycbcr(FLAGS.i,False)
  else:
      print('model\'s input channel was not equal 3 or 1')
      exit()
  
#convert model to tensorflow
  for i in range(len(parameters[0])):
    params = parameters[0][i]
    #notice here,transpose the np array to fit the right parameter order
    w_array = np.float32(params['weight']).transpose(2,3,1,0)
    weights = tf.Variable(w_array,name='w')
    b_array = np.float32(params['bias'])
    bias = tf.Variable(b_array,name='b')
    if i == 0 :
        activation = LeakyRelu(conv2d(x,weights)+bias)
    else:
        activation = LeakyRelu(conv2d(activation,weights)+bias)

  #initialize all variables  
  init_op = tf.global_variables_initializer()

  #using GPU to accelerate
  config = tf.ConfigProto(
        device_count = {'GPU': 1}
    )

  sess = tf.InteractiveSession(config=config)
  sess.run(init_op)
  y = sess.run(activation)

  if input_plane ==3:
       image = np.clip(y, 0, 1)[0]*255
       misc.toimage(image, mode='RGB').save(r'./'+output_rename(FLAGS.i))
  else: 
       image[:,:,0] = np.clip(y,0,1)*255
       misc.toimage(image, mode='YCbCr').convert('RGB').save(r'./'+output_rename(FLAGS.i))
  sess.close()

if __name__ == "__main__":
    tf.app.run()

Leave a Reply

Your email address will not be published. Required fields are marked *