#!/usr/bin/env python

# RaidGuessFS, a FUSE pseudo-filesystem to guess RAID parameters of a damaged device
# Copyright (C) 2015 Ludovic Pouzenc <ludovic@pouzenc.fr>
#
# This file is part of RaidGuessFS.
#
# RaidGuessFS is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RaidGuessFS is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with RaidGuessFS. If not, see <http://www.gnu.org/licenses/>

import logging, numpy
import mydisks

class MyRaid(object):
    """Auxiliary class, managing RAID layer.

    Holds the RAID parameters being guessed (start/end offsets on the member
    disks, chunk size, logical disk order, RAID 5 parity layout) and offers
    readers that produce the reassembled RAID data, the raw XOR of all
    members and a textual per-sector parity report.

    The member disks come from the object given to set_disks(); this class
    uses its `disk_count` and `disks` attributes and its
    `is_readable(disk_no, offset, size)` method.
    """
    RAID_TYPES = [ '0', '1', '5', '5+0' ]
    # NOTE(review): presumably left/right (a)symmetric, as in Linux md naming - confirm
    RAID5_LAYOUTS = [ 'la', 'ra', 'ls', 'rs' ]

    @staticmethod
    def xor_blocks(fd_list, offset, size):
        """Compute bitwise XOR against a bunch of disks slice.

        Reads `size` bytes at `offset` from every file object of `fd_list`
        (which must not be empty) and XORs them together.

        Returns a tuple (status, xored) where `xored` is a numpy array of
        little-endian 64 bit words and `status` is one of :
          'z' : every block read contains only zeros
          'g' : XOR is zero and at least one block contains only zeros
          'G' : XOR is zero and no block contains only zeros
          'b' : XOR is not zero

        Raises ValueError if `size` is not a multiple of 8.
        """
        logging.info("Enter xor_blocks(fd_list(%i),0x%011x,%d)"%(len(fd_list), offset, size))

        if size % 8 != 0:
            raise ValueError('xor_blocks : size must be multiple of 8')
        # '<u8' is the canonical spelling of the little-endian uint64 that '<Q8' meant
        dt = numpy.dtype('<u8')

        fd_list[0].seek(offset)
        # frombuffer() returns a read-only view, the accumulator needs its own writable copy
        xored = numpy.frombuffer(fd_list[0].read(size), dtype=dt).copy()
        all_zero = (numpy.count_nonzero(xored) == 0)
        any_zero = all_zero

        for fd in fd_list[1:]:
            fd.seek(offset)
            block = numpy.frombuffer(fd.read(size), dtype=dt)
            block_zero = (numpy.count_nonzero(block) == 0)
            all_zero = all_zero and block_zero
            any_zero = any_zero or block_zero
            # In-place XOR into the accumulator
            numpy.bitwise_xor(xored, block, xored)

        if all_zero:
            result = 'z'
        elif numpy.count_nonzero(xored) == 0:
            result = 'g' if any_zero else 'G'
        else:
            result = 'b'

        logging.info("Exit. xor_blocks(fd_list,%d,%d)"%(offset, size))
        return (result, xored)


    def __init__(self, *args, **kwargs):
        """Start with no disks attached and default RAID parameters."""
        self.d = None
        self.raid_start = 0
        self.raid_end = 0
        self.raid_sector_size = 512 # TODO : should be self.d.sector_size
        self.raid_chunk_size = 65536
        self.raid_disk_order = []
        self.raid_disk_count = 0
        self.raid_layout = 'ls'
        self.raid_disks = []
        self.nested_subraid = 2

    def get_raid_start(self):
        return self.raid_start

    def get_raid_end(self):
        return self.raid_end

    def get_raid_chunk_size(self):
        return self.raid_chunk_size

    def get_raid_disk_order(self):
        return self.raid_disk_order

    def get_raid_disk_order_str(self):
        """Returns the logical disk order as space separated disk numbers"""
        return ' '.join(map(str,self.raid_disk_order))

    def get_raid_layout(self):
        return self.raid_layout

    def set_disks(self, new_mydisks):
        """Attach the disks provider and reset the disk order to 0..disk_count-1"""
        # FIXME : self.d doesn't need to be updated (pass it on __init__)
        self.d = new_mydisks
        self.set_raid_disk_order(list(range(self.d.disk_count)))

    def set_raid_start(self, new_raid_start):
        """Update the start offset of raid data on underlying disks"""
        self.raid_start = new_raid_start

    def set_raid_end(self, new_raid_end):
        """Update the end offset of raid data on underlying disks"""
        self.raid_end = new_raid_end

    def set_raid_chunk_size(self, new_raid_chunk_size):
        """Update the size of chunks of data (or slice size)"""
        self.raid_chunk_size = new_raid_chunk_size

    def set_raid_disk_order(self, new_raid_disk_order):
        """Update the raid logical disk order.

        `new_raid_disk_order` is an iterable of disk numbers (ints or anything
        int() accepts). Each disk may appear at most once. Raises ValueError
        on an out of range or repeated disk number; in that case the current
        order is left untouched.
        """
        # Normalize once, so the stored order really holds the ints that were validated
        order = [ int(item) for item in new_raid_disk_order ]
        check = [0] * self.d.disk_count
        for d in order:
            if not 0 <= d < self.d.disk_count:
                raise ValueError('Value out of range : %i [0,%i]'%(d,self.d.disk_count-1))
            check[d] += 1

        for d, times in enumerate(check):
            if times > 1:
                raise ValueError('Disk %i appears %i times (must be 0 or 1)'%(d,times))

        self.raid_disk_count = len(order)
        self.raid_disk_order = order
        self.raid_disks = [ self.d.disks[i] for i in order ]

    def set_raid_layout(self, new_raid_layout):
        """Update the RAID 5 layout, raises ValueError if not in RAID5_LAYOUTS"""
        if new_raid_layout not in MyRaid.RAID5_LAYOUTS:
            raise ValueError('raid_layout has to be one of %s'%' '.join(MyRaid.RAID5_LAYOUTS))
        self.raid_layout = new_raid_layout

    def _disk_count_ok(self, raid_type):
        """Tell if the current disk count is acceptable for raid_type (KeyError if unknown)"""
        count = self.raid_disk_count
        return {
            '0'  : True,
            '1'  : count == 2,
            '5'  : count >= 3,
            '5+0': count >= 6 and count % 2 == 0,
            }[raid_type]

    def sizeof_raid_result(self, raid_type):
        """Returns the size of the reassembled RAID data, 0 if the disk count is wrong"""
        size = max(0, self.raid_end - self.raid_start)
        if not self._disk_count_ok(raid_type):
            return 0
        data_disks = {
            '0'  : self.raid_disk_count,
            '1'  : 1,
            '5'  : self.raid_disk_count - 1,
            '5+0': self.raid_disk_count - 2,
            }[raid_type]
        return size * data_disks

    def sizeof_disk_xor(self, raid_type):
        """Returns the size of the raw XOR pseudo-file (raid_type is not used)"""
        return max(0, self.raid_end - self.raid_start)

    def sizeof_disk_parity(self, raid_type):
        """Returns the size of the parity report : 16 characters per sector.

        When read_disk_parity() would only return an explanation message
        (RAID 0 or wrong disk count), 64 is returned instead.
        """
        if raid_type == '0' or not self._disk_count_ok(raid_type):
            return 64
        # Explicit integer division : the result is a byte count
        return max(0, self.raid_end - self.raid_start) // self.raid_sector_size * 16

    def read_disk_xor(self,raid_type,offset,size):
        """Returns raw bitwise XOR against a bunch of disks slice"""
        return MyRaid.xor_blocks(self.raid_disks,offset,size)[1].tobytes()

    def read_disk_parity(self,raid_type,offset,size):
        """Returns textual information about parity status of each sector.

        The report has one 16 character line per sector : the sector address
        on disk followed by the status letter computed by xor_blocks().
        `offset` and `size` are expressed in report characters. For RAID 0 or
        a wrong disk count, a slice of an explanation message is returned.
        """
        logging.warning("Enter read_disk_parity(%s,%d,%d)"%(raid_type,offset,size))
        msg = {
                '0'  : 'There no notion of parity in RAID 0 mode\n',
                '1'  : 'Wrong disk count (should be 2)\n',
                '5'  : 'Wrong disk count (should be >=3)\n',
                '5+0': 'Wrong disk count (should be >=6 and even)\n',
                }[raid_type]
        if raid_type == '0' or not self._disk_count_ok(raid_type):
            return msg[offset:offset+size]

        # 16 report characters stand for one sector; keep the arithmetic in integers
        start = self.raid_start + offset * self.raid_sector_size // 16
        end = start + size * self.raid_sector_size // 16

        #TODO : improve for nested levels
        result = ''.join(
                [ '0x%011x %c\n'%( addr, MyRaid.xor_blocks(self.raid_disks, addr, self.raid_sector_size)[0])
                        for addr in range(start, end, self.raid_sector_size)
                ])

        logging.warning("Exit. read_disk_parity(%s,%d,%d)"%(raid_type,offset,size))
        return result


    def _raid5_layout(self, disk_count, stripe_no, segment_no):
        """Returns (par_disk, data_disk) positions inside a RAID 5 set of disk_count disks"""
        if self.raid_layout in ['ls','la']:
            par_disk = (disk_count-1) - (stripe_no % disk_count)
        else: # self.raid_layout in ['rs','ra']:
            par_disk = stripe_no % disk_count

        data_index = segment_no % (disk_count-1)
        if self.raid_layout in ['ls','rs']:
            # Data segments start right after the parity disk and wrap around
            data_disk = (par_disk+1 + data_index) % disk_count
        else: # self.raid_layout in ['la','ra']:
            # Data segments are laid out in disk order, skipping the parity disk
            data_disk = data_index
            if data_disk >= par_disk:
                data_disk = data_disk + 1
        return (par_disk, data_disk)

    def _map_offset(self, raid_type, offset, size):
        """Translate a RAID offset into (data_disk, off_disk, size2).

        `data_disk` is a position in the logical disk order, `off_disk` the
        offset on that disk and `size2` how many of the `size` requested bytes
        can be read there before leaving the current chunk.
        """
        chunk_size = self.raid_chunk_size
        disk_count = self.raid_disk_count
        segment_no = segment_off = stripe_no = subraid_no = -1

        if raid_type == '1':
            par_disk = 1
            data_disk = 0
            off_disk = self.raid_start + offset
            size2 = size

        elif raid_type in ['0', '5', '5+0']:
            segment_no = offset // chunk_size
            segment_off = offset % chunk_size

            if raid_type == '0':
                stripe_no = segment_no // disk_count
                par_disk = -1
                data_disk = segment_no % disk_count

            elif raid_type == '5':
                stripe_no = segment_no // (disk_count-1)
                par_disk, data_disk = self._raid5_layout(disk_count, stripe_no, segment_no)

            else: # raid_type == '5+0'
                subraid_disk_count = disk_count // self.nested_subraid
                stripe_no = segment_no // (disk_count - self.nested_subraid)
                subraid_no = (segment_no // (subraid_disk_count-1) ) % self.nested_subraid
                sub_par, sub_data = self._raid5_layout(subraid_disk_count, stripe_no, segment_no)
                par_disk  = subraid_no * subraid_disk_count + sub_par
                data_disk = subraid_no * subraid_disk_count + sub_data

            off_disk = self.raid_start + stripe_no * chunk_size + segment_off
            # Never cross a chunk boundary in one go : the next chunk lives on another disk
            size2 = min(size, (segment_no+1) * chunk_size - offset)

        else:
            raise NotImplementedError('Unimplemented read_raid_result() for raid_type == %s' % raid_type)

        logging.debug("raid.read_result(%s): offset=%d,segment_no=%d,segment_off=%d,stripe_no=%d,subraid_no=%d,par_disk=%d(disk%02d),data_disk=%d(disk%02d),off_disk=%d,size2=%d,segment_off+size2=%d" 
        % (raid_type,offset,segment_no,segment_off,stripe_no,subraid_no,par_disk,self.raid_disk_order[par_disk],data_disk,self.raid_disk_order[data_disk],off_disk,size2,segment_off+size2) )

        return (data_disk, off_disk, size2)

    def _read_mapped(self, offset, data_disk, off_disk, size2):
        """Read size2 bytes at off_disk on member data_disk, rebuilding unreadable sectors.

        `offset` is the RAID offset, only used for logging. Sectors that can
        neither be read nor rebuilt are returned as zeros.
        """
        disk_no = self.raid_disk_order[data_disk]
        data_fd = self.raid_disks[data_disk]

        if self.d.is_readable(disk_no,off_disk,size2):
            # No damaged sectors until the end of the chunk, so just read the data disk
            data_fd.seek(off_disk)
            return data_fd.read(size2)

        logging.warning('Try to recovering damaged chunck (raid_offset: 0x%011x, data_disk: %i, disk_offset: 0x%011x'
                % (offset, disk_no, off_disk) )
        # Damaged sectors, check / recover every sector
        # NOTE(review): for '5+0' this XORs every other member, not only the sub-array ones
        other_disks = [ no for pos, no in enumerate(self.raid_disk_order) if pos != data_disk ]
        other_fds = [ fd for pos, fd in enumerate(self.raid_disks) if pos != data_disk ]
        sector_size = self.raid_sector_size
        end = off_disk + size2

        data_arr = []
        for s in range(off_disk, end, sector_size):
            # The last step may be a partial sector : never return more than size2 bytes
            wanted = min(sector_size, end - s)
            if self.d.is_readable(disk_no,s,sector_size):
                # Current sector is readable from data disk, read it
                logging.debug('-> 0x%011x : readable'%s)
                data_fd.seek(s)
                data_arr.append(data_fd.read(wanted))
            elif other_disks and all( self.d.is_readable(other_disk,s,sector_size) for other_disk in other_disks ):
                # Current sector is dead on data disk but alive on every other disk
                logging.info('-> 0x%011x : recoverable'%s)
                data_arr.append( MyRaid.xor_blocks(other_fds, s, sector_size)[1].tobytes()[:wanted] )
            else:
                logging.warning('-> 0x%011x : unrecoverable'%s)
                data_arr.append( b'\0' * wanted )
        return b''.join(data_arr)

    def read_raid_result(self,raid_type,offset,size):
        """Returns actual RAID data.

        Reads `size` bytes at RAID offset `offset`, one chunk at a time.
        Raises NotImplementedError (an Exception) for an unknown raid_type.
        """
        # Prevent short reads, seems mandatory for losetup'ing raid_result but kills performance
        pieces = []
        while True:
            data_disk, off_disk, size2 = self._map_offset(raid_type, offset, size)
            pieces.append(self._read_mapped(offset, data_disk, off_disk, size2))
            if size2 >= size:
                break
            offset += size2
            size -= size2
        return b''.join(pieces)