#!/usr/bin/env python

# RaidGuessFS, a FUSE pseudo-filesystem to guess RAID parameters of a damaged device
# Copyright (C) 2015 Ludovic Pouzenc <ludovic@pouzenc.fr>
#
# This file is part of RaidGuessFS.
#
# RaidGuessFS is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RaidGuessFS is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with RaidGuessFS. If not, see <http://www.gnu.org/licenses/>

import logging, numpy

class MyRaid():
    """Auxiliary class, managing the RAID layer.

    Holds the guessed RAID parameters (data start/end offsets on the
    underlying disks, chunk size, sector size, logical disk order) and maps
    reads of the virtual RAID volume onto reads of the underlying disks.
    """

    def __init__(self, *args, **kwargs):
        # Byte offsets delimiting the RAID data area on each underlying disk.
        self.raid_start = 0
        self.raid_end = 0
        # Block size (bytes) used by check_data() for parity checks.
        self.raid_sector_size = 512
        # Size (bytes) of one chunk on a single disk.
        self.raid_chunk_size = 65536
        # Logical position -> index into the `disks` list given to read_data().
        self.raid_disk_order = []
        self.raid_disk_count = 0
        # RAID levels understood by get_raid_size().
        self.raid_types = [ '0', '1', '5', '5+0' ]

    def get_raid_size(self, raid_type):
        """Return the usable size in bytes of a RAID volume of `raid_type`.

        The per-disk data area is raid_end - raid_start bytes. Returns 0 when
        that area is empty, or when the disk count is not valid for the level
        (RAID 5 needs >= 3 disks, RAID 5+0 an even count >= 6).

        Raises KeyError for a raid_type other than '0', '1', '5' or '5+0'.
        """
        size = self.raid_end - self.raid_start
        if size <= 0:
            return 0
        count = self.raid_disk_count
        return {
            '0': size * count,
            '1': size,
            '5': size * (count - 1) if count >= 3 else 0,
            # Two RAID 5 legs, each losing one disk worth of space to parity.
            '5+0': size * (count - 2) if count >= 6 and count % 2 == 0 else 0,
            }[raid_type]

    def get_raid_start(self):
        """Return the start offset of raid data on underlying disks."""
        return self.raid_start

    def get_raid_chunk_size(self):
        """Return the size of chunks of data (or slice size)."""
        return self.raid_chunk_size

    def get_raid_disk_order(self):
        """Return the raid logical disk order as a list of ints."""
        return self.raid_disk_order

    def get_raid_disk_order_str(self):
        """Return the raid logical disk order as a space-separated string."""
        return ' '.join(map(str, self.raid_disk_order))

    def set_raid_start(self, new_raid_start):
        """Update the start offset of raid data on underlying disks"""
        self.raid_start = new_raid_start

    def set_raid_end(self, new_raid_end):
        """Update the end offset of raid data on underlying disks"""
        self.raid_end = new_raid_end

    def set_raid_chunk_size(self, new_raid_chunk_size):
        """Update the size of chunks of data (or slice size)"""
        self.raid_chunk_size = new_raid_chunk_size

    def set_raid_disk_order(self, new_raid_disk_order):
        """Update the raid logical disk order.

        `new_raid_disk_order` must be a permutation of 0..N-1. Items may be
        ints or int-convertible strings. They are stored as ints because
        read_data() uses them as indexes into the `disks` list.

        Raises ValueError if an item is not an integer, is out of range, or
        does not appear exactly once. On error the current order is kept.
        """
        order = [int(item) for item in new_raid_disk_order]
        card = len(order)
        check = [0] * card
        for d in order:
            if not 0 <= d < card:
                raise ValueError('Value out of range : %i [0,%i]' % (d, card - 1))
            check[d] += 1

        for d, seen in enumerate(check):
            if seen != 1:
                raise ValueError('Disk %i appears %i times (must be 1)' % (d, seen))
        self.raid_disk_order = order
        self.raid_disk_count = card

    def read_data(self, raid_type, disks, offset, size):
        """Read up to `size` bytes at `offset` of the virtual RAID volume.

        Only the RAID 5 layout is implemented; `raid_type` is currently used
        for logging only. `disks` is indexed with the values of
        self.raid_disk_order, and each item must support seek() and read().

        The read never crosses a chunk boundary, so fewer than `size` bytes
        may be returned even before the end of the volume.

        Raises ValueError if fewer than 2 disks are configured.
        """
        disk_count = len(self.raid_disk_order)
        if disk_count < 2:
            raise ValueError('RAID 5 read needs at least 2 disks, got %i' % disk_count)
        # In each stripe, one chunk holds parity and the others hold data.
        data_per_stripe = disk_count - 1

        # Integer division (//) keeps every offset an int on Python 3 too.
        slice_no = offset // self.raid_chunk_size    # logical chunk index
        slice_off = offset % self.raid_chunk_size    # offset inside that chunk
        segment = slice_no // data_per_stripe        # stripe index
        # Parity rotates from the last disk towards the first, which equals
        # (-segment - 1) % disk_count. Data chunks follow the parity disk and
        # wrap around. NOTE(review): this looks like the left-symmetric
        # layout (Linux md default); confirm.
        par_disk = (disk_count - 1) - (segment % disk_count)
        data_disk = (par_disk + 1 + (slice_no % data_per_stripe)) % disk_count
        off_disk = self.raid_start + segment * self.raid_chunk_size + slice_off

        # Clamp the read to the end of the current chunk.
        size2 = min(size, (slice_no + 1) * self.raid_chunk_size - offset)

        logging.info("raid.read_data(%s): offset=%d,slice_no=%d,slice_off=%d,segment=%d,"
                     "par_disk=%d,data_disk=%d,off_disk=%d,size2=%d,slice_off+size2=%d",
                     raid_type, offset, slice_no, slice_off, segment, par_disk,
                     data_disk, off_disk, size2, slice_off + size2)

        data_fd = disks[self.raid_disk_order[data_disk]]
        data_fd.seek(off_disk)
        # Looping until `size` bytes are read would avoid short reads before
        # EOF, but it hurt performance, so only the current chunk is read.
        return data_fd.read(size2)

    @staticmethod
    def _read_words(fd, offset, size):
        """Read `size` bytes at `offset` of `fd` as little-endian uint64s, zero-padding short reads."""
        fd.seek(offset)
        buf = fd.read(size)
        if len(buf) < size:
            buf = buf.ljust(size, b'\0')
        return numpy.frombuffer(buf, dtype=numpy.dtype('<u8'))

    def xor_blocks(self, fd_list, offset, size):
        """XOR together the `size`-byte blocks at `offset` of every file in `fd_list`.

        Returns a one-character verdict:
          'z' - every block is all zeros;
          'g' - the XOR is zero and at least one block is all zeros;
          'G' - the XOR is zero and no block is all zeros;
          'b' - the XOR is not zero.
        Short reads (e.g. past the end of a file) are treated as zero bytes.

        Raises ValueError if `size` is not a multiple of 8 or `fd_list` is empty.
        """
        logging.info("Enter xor_blocks(fd_list,%d,%d)", offset, size)
        if size % 8 != 0:
            raise ValueError("Size must be multiple of 8")
        if not fd_list:
            raise ValueError("fd_list must not be empty")

        acc = None
        all_zero = True
        any_zero = False
        for fd in fd_list:
            words = self._read_words(fd, offset, size)
            is_zero = not words.any()
            all_zero = all_zero and is_zero
            any_zero = any_zero or is_zero
            if acc is None:
                # frombuffer() arrays are read-only; copy to XOR in place.
                acc = words.copy()
            else:
                numpy.bitwise_xor(acc, words, acc)

        if all_zero:
            result = 'z'
        elif not acc.any():
            result = 'g' if any_zero else 'G'
        else:
            result = 'b'

        logging.info("Exit. xor_blocks(fd_list,%d,%d)", offset, size)
        return result

    def check_data(self, raid_type, disks, offset, size):
        """Render the byte range [offset, offset+size) of the parity-check report.

        The report is text with one 16-character line per sector of
        self.raid_sector_size bytes: '0x<11 hex digits> <verdict>\\n'. The
        hex number is the sector's offset on the disks, and the verdict is
        xor_blocks() computed over all `disks` at that offset. `offset` is
        rounded down to a line boundary, and only whole lines are returned.
        `raid_type` is used for logging only.
        """
        logging.warning("Enter check_data(%s,disks,%d,%d)", raid_type, offset, size)

        line_len = 16  # == len('0x%011x %c\n' % (0, 'z'))
        sector = self.raid_sector_size
        first_line = offset // line_len
        lines = []
        for line_no in range(first_line, first_line + size // line_len):
            disk_off = line_no * sector
            lines.append('0x%011x %c\n' % (disk_off, self.xor_blocks(disks, disk_off, sector)))
        result = ''.join(lines)
        # TODO: report RAID volume offsets rather than disk offsets.

        logging.warning("Exit. check_data(%s,disks,%d,%d)", raid_type, offset, size)
        return result