huge refactoring of loading fields due to mysql bug
authorMax Lapshin <max@maxidoors.ru>
Mon, 8 Dec 2008 20:53:51 +0000 (23:53 +0300)
committerMax Lapshin <max@maxidoors.ru>
Mon, 8 Dec 2008 20:53:51 +0000 (23:53 +0300)
mysql2psql

index 60c1e1b..d1e7227 100755 (executable)
@@ -12,8 +12,8 @@ class MysqlReader
   class Table
     attr_reader :name
     
-    def initialize(mysql, name)
-      @mysql = mysql
+    def initialize(reader, name)
+      @reader = reader
       @name = name
     end
     
@@ -28,33 +28,48 @@ class MysqlReader
       @columns ||= load_columns
     end
     
+    def convert_type(type)
+      case type
+      when "tinyint(1)"
+        "boolean"
+      when /tinyint/
+        "tinyint"
+      when /int/
+        "integer"
+      when /varchar/
+        "varchar"
+      else
+        type
+      end 
+    end
+    
     def load_columns
-      result = @mysql.list_fields(name)
-      fields = result.fetch_fields.map do |field|
-        desc = {
-          :name => field.name,
-          :table_name => field.table,
-          :default => field.def,
-          :type => @@types[field.type] || field.type,
-          :length => field.length,
-          :max_length => field.max_length,
-          :flags => field.flags,
-          :decimals => field.decimals ? field.decimals > 12 ? 12 : field.decimals : nil,
-          :null => !field.is_not_null?,
-          :numeric => field.is_num?,
-          :primary_key => field.is_pri_key?
-          }
-        if field.is_pri_key?
-          @mysql.query("SELECT max(`#{field.name}`) + 1 FROM `#{name}`") do |res|
-            desc[:maxval] = res.fetch_row[0].to_i
-          end
+      @reader.reconnect
+      result = @reader.mysql.list_fields(name)
+      mysql_flags = Mysql::Field.constants.select {|c| c =~ /FLAG/}
+      fields = []
+      @reader.mysql.query("EXPLAIN `#{name}`") do |res|
+        while field = res.fetch_row do
+          length = field[1][/\((\d+)\)/, 1]
+          desc = {
+            :name => field[0],
+            :table_name => name,
+            :default => field[4],
+            :type => convert_type(field[1]),
+            :length => length && length.to_i,
+            :decimals => 12,
+            :null => field[2] == "YES",
+            :primary_key => field[3] == "PRI"
+            }
+          fields << desc
         end
-        if desc[:type] == "blob" && field.flags & Mysql::Field::BINARY_FLAG == 0
-          desc[:type] = "text"
+      end
+
+      fields.select {|field| field[:primary_key]}.each do |field|
+        @reader.mysql.query("SELECT max(`#{field[:name]}`) + 1 FROM `#{name}`") do |res|
+          field[:maxval] = res.fetch_row[0].to_i
         end
-        desc
       end
-      result.free
       fields
     end
     
@@ -73,7 +88,7 @@ class MysqlReader
       @indexes = []
       @foreign_keys = []
       
-      @mysql.query("SHOW CREATE TABLE `#{name}`") do |result|
+      @reader.mysql.query("SHOW CREATE TABLE `#{name}`") do |result|
         explain = result.fetch_row[1]
         explain.split(/\n/).each do |line|
           next unless line =~ / KEY /
@@ -104,7 +119,7 @@ class MysqlReader
     
     def count_for_pager
       query = has_id? ? 'MAX(id)' : 'COUNT(*)'
-      @mysql.query("SELECT #{query} FROM #{name}") do |res|
+      @reader.mysql.query("SELECT #{query} FROM #{name}") do |res|
         return res.fetch_row[0].to_i
       end
     end
@@ -115,14 +130,26 @@ class MysqlReader
     end
   end
   
-  def initialize(host = nil, user = nil, passwd = nil, db = nil, sock = nil, flag = nil)
-    @mysql = Mysql.connect(host, user, passwd, db, sock, flag)
+  def connect
+    @mysql = Mysql.connect(@host, @user, @passwd, @db, @sock, @flag)
     @mysql.query("SET NAMES utf8")
     @mysql.query("SET SESSION query_cache_type = OFF")
   end
   
+  def reconnect
+    @mysql.close rescue false
+    connect
+  end
+  
+  def initialize(host = nil, user = nil, passwd = nil, db = nil, sock = nil, flag = nil)
+    @host, @user, @passwd, @db, @sock, @flag = host, user, passwd, db, sock, flag
+    connect
+  end
+  
+  attr_reader :mysql
+  
   def tables
-    @tables ||= @mysql.list_tables.map {|table| Table.new(@mysql, table)}
+    @tables ||= @mysql.list_tables.map {|table| Table.new(self, table)}
   end
   
   def paginated_read(table, page_size)
@@ -163,44 +190,35 @@ class PostgresWriter < Writer
     null = column[:null] ? "" : " NOT NULL"
     type = 
     case column[:type]
-    when "var_string"
+    when "varchar"
       default = default + "::character varying" if default
+#      puts "VARCHAR: #{column.inspect}"
       "character varying(#{column[:length]})"
-    when "long"
+    when "integer"
       default = " DEFAULT #{column[:default].nil? ? 'NULL' : column[:default]}" if default
-      "bigint"
-    when "longlong"
+      "integer"
+    when "tinyint"
       default = " DEFAULT #{column[:default].nil? ? 'NULL' : column[:default]}" if default
-      "bigint"
+      "smallint"
     when "datetime"
       default = nil
       "timestamp without time zone"
     when "date"
       default = nil
       "date"
-    when "char"
-      if column[:length] == 1
-        if default
-          default = " DEFAULT #{column[:default].to_i == 1 ? 'true' : 'false'}"
-        end
-        "boolean"
-      else
-        default = " DEFAULT #{column[:default].nil? ? 'NULL' : column[:default]}" if default
-        "int"
-      end
+    when "boolean"
+      default = " DEFAULT #{column[:default].to_i == 1 ? 'true' : 'false'}" if default
+      "boolean"
     when "blob"
       "bytea"
     when "text"
       "text"
     when "float"
       default = " DEFAULT #{column[:default].nil? ? 'NULL' : column[:default]}" if default
-      "numeric(#{column[:length] + column[:decimals]}, #{column[:decimals]})"
+      "real"
     when "decimal"
       default = " DEFAULT #{column[:default].nil? ? 'NULL' : column[:default]}" if default
       "numeric(#{column[:length] + column[:decimals]}, #{column[:decimals]})"
-    when "int24"
-      default = " DEFAULT #{column[:default].nil? ? 'NULL' : column[:default]}" if default
-      "int"
     else
       puts "Unknown #{column.inspect}"
       column[:type].inspect