#!/usr/bin/env python3
"""Pymodbus Synchronous Client (TCP)"""

# ================================================== #
# //                 Module imports               // #
# ================================================== #

# Install the pymodbus, pyyaml and tabulate modules

import logging
import argparse
import sys
import random 
from pymodbus.client import ModbusTcpClient
from pymodbus.exceptions import ModbusIOException
from tabulate import tabulate 

# ================================================== #
# //         Global variable declarations         // #
# ================================================== #

__author__ = "Diarmuid Ó Briain"
__copyright__ = "Copyright 2025, SETU"
__licence__ = "European Union Public Licence v1.2"
__version__ = "2.0"

# ================================================== #
# //                   Constants                  // #
# ================================================== #

# Connection defaults, used when the -i, -p or -u switches are not given.

# Default server address for -i/--ip.
MB_SVR_HOST = "127.0.0.1" 
# Default server TCP port for -p/--port.
MB_SVR_PORT = 5002
# Default Modbus unit ID for -u/--unit.
UNIT = 1
# Number of consecutive addresses read, written and displayed per register type.
# There is no command-line switch for this value.
COUNT = 8
# First address of the block that is read, written and displayed.
# There is no command-line switch for this value.
START_ADDR = 0 

# Pool of possible Coil and Holding Register values to select from for writing.
# NOTE(review): CO_WRITE_POOL is not referenced anywhere else in this file;
# coil write values are produced by inverting the current coil state instead.
CO_WRITE_POOL = [True, False]
# Candidate Holding Register values used by get_different_register_values().
RR_WR_POOL = [11, 22, 33, 44, 55, 66, 77, 88, 99, 110, 121] 

# ================================================== #
# //                   Logging                    // #
# ================================================== #

# Emit bare messages only (no level, timestamp or logger-name prefix).
logging.basicConfig(format='%(message)s') 
# getLogger() without a name returns the root logger; every function in this
# file logs through it.
log = logging.getLogger()
# INFO and above are shown; DEBUG output is suppressed.
log.setLevel(logging.INFO)

# ================================================== #
# //        Terminal arguments (switches)         // #
# ================================================== #

def get_commandline_parser():
    """
    Build the argparse parser for the Modbus client and return it unparsed.

    Switches:
      -i / -p / -u   connection settings, defaulting to MB_SVR_HOST,
                     MB_SVR_PORT and UNIT.
      -d | -w        action; the two are mutually exclusive.
      -c | -r        register-type filter; the two are mutually exclusive.
      -s             device identification; registered directly on the
                     parser, outside both exclusive groups.
    """
    cli = argparse.ArgumentParser(description="Modbus Client")

    # Connection settings.
    cli.add_argument(
        "-i", "--ip",
        default=MB_SVR_HOST,
        help=f"IP address of the server (default: {MB_SVR_HOST})",
    )
    cli.add_argument(
        "-p", "--port",
        type=int,
        default=MB_SVR_PORT,
        help=f"Port of the server (default: {MB_SVR_PORT})",
    )
    cli.add_argument(
        "-u", "--unit",
        type=int,
        default=UNIT,
        help=f"Modbus Unit ID (Slave ID) (default: {UNIT})",
    )

    # Boolean switches, listed per exclusive group in their original order so
    # that the generated --help text is unchanged.
    action_flags = (
        ("-d", "--read",
         "Perform a READ-ONLY request of the current state (Default)."),
        ("-w", "--write",
         "Perform a WRITE TEST: Read initial state, WRITE to Coils/HRs, Read final state, and display both tables."),
    )
    filter_flags = (
        ("-c", "--coil-only",
         "Limit the operation to Coils (0x01) and Discrete Inputs (0x02) only."),
        ("-r", "--register-only",
         "Limit the operation to Holding Registers (0x03) and Input Registers (0x04) only."),
    )

    for flag_set in (action_flags, filter_flags):
        exclusive = cli.add_mutually_exclusive_group()
        for short_flag, long_flag, help_text in flag_set:
            exclusive.add_argument(short_flag, long_flag, action="store_true", help=help_text)

    # Device identification switch; the caller treats it as overriding the
    # read/write actions.
    cli.add_argument(
        "-s", "--server",
        action="store_true",
        help="Read and display server device identification information only.",
    )

    return cli

# ================================================== #
# //        Read and Display Device Information   // #
# ================================================== #

def read_and_display_device_id(client):
    """
    Query the server's device identification objects and log them as a table.

    One Read Device Identification request is issued per object so that a
    failure on one field does not hide the others. Any error response or
    exception is reported in that field's table cell instead of aborting.

    Args:
        client: A connected Modbus client exposing read_device_information().

    Returns:
        None. The result is written to the log as a "fancy_grid" table.
    """

    ID_FIELDS = {
        0x00: "Vendor Name", 0x01: "Product Code", 0x02: "Major/Minor Revision",
        0x03: "Product Name", 0x04: "Model Name", 0x05: "Vendor URL",
    }

    # Highest object ID that belongs to the Basic identification category.
    LAST_BASIC_OBJECT_ID = 0x02

    table_data = list()

    for object_id, field_name in ID_FIELDS.items():
        # Per the Modbus Read Device Identification function (0x2B/0x0E), the
        # Basic category (read code 0x01) only covers objects 0x00-0x02, while
        # objects 0x03-0x06 belong to the Regular category (read code 0x02).
        # Asking for a Regular object with the Basic read code never returns
        # it, so the read code is chosen to match the object's category.
        read_code = 0x01 if object_id <= LAST_BASIC_OBJECT_ID else 0x02

        try:
            response = client.read_device_information(read_code=read_code, object_id=object_id)

            if response.isError():
                value = f"Failed to read (Error: {response})"
            else:
                raw_value = response.information.get(object_id)
                if raw_value is None:
                    # The server answered but did not include this object.
                    value = "N/A"
                elif isinstance(raw_value, (bytes, bytearray)):
                    # Undecodable bytes are replaced rather than raising, so a
                    # badly encoded field is still shown.
                    value = raw_value.decode('utf-8', errors='replace').strip()
                else:
                    value = str(raw_value).strip()

        except Exception as e:
            # Deliberately broad: this is a best-effort display, and one
            # failing field must not prevent the remaining fields from being
            # queried.
            value = f"Exception: {e}"

        table_data.append([field_name, value])

    log.info("\n--- Modbus Server Device Identification ---")
    log.info(tabulate(table_data, headers=["Field", "Value"], tablefmt="fancy_grid"))

# ================================================== #
# //         Helper for Reading Modbus State      // #
# ================================================== #

def get_MB_state(client, read_coils=True, read_contacts=True, read_holdings=True, read_inputs=True):
    """
    Read the selected Modbus data types and return their values keyed by type.

    COUNT items starting at START_ADDR are requested for each enabled type.

    Args:
        client: Connected Modbus client.
        read_coils: Read Coils (function 0x01) when True.
        read_contacts: Read Discrete Inputs (function 0x02) when True.
        read_holdings: Read Holding Registers (function 0x03) when True.
        read_inputs: Read Input Registers (function 0x04) when True.

    Returns:
        dict with the four keys 'Coils (0x01, R/W)',
        'Discrete Inputs (0x02, R/O)', 'Holding Registers (0x03, R/W)' and
        'Input Registers (0x04, R/O)', in that order. Each value is a list of
        at most COUNT items, or None when the type was not requested or the
        read returned an error response. A failed read is additionally
        reported through log.error so it can be told apart from a skip.
    """
    # (state key, requested?, client method name, response attribute with the values)
    read_plan = (
        ('Coils (0x01, R/W)', read_coils, 'read_coils', 'bits'),
        ('Discrete Inputs (0x02, R/O)', read_contacts, 'read_discrete_inputs', 'bits'),
        ('Holding Registers (0x03, R/W)', read_holdings, 'read_holding_registers', 'registers'),
        ('Input Registers (0x04, R/O)', read_inputs, 'read_input_registers', 'registers'),
    )

    state = dict()

    for key, requested, method_name, values_attr in read_plan:
        if not requested:
            state[key] = None
            continue

        response = getattr(client, method_name)(address=START_ADDR, count=COUNT)

        if response.isError():
            # Keep None as the failure marker for callers, but do not let the
            # failure pass silently.
            log.error(f"[FAIL] Read of {key} failed: {response}")
            state[key] = None
            continue

        # Bit responses travel in whole bytes, so the decoded list may be
        # longer than COUNT when COUNT is not a multiple of 8. Trimming keeps
        # every entry at exactly the requested span.
        state[key] = list(getattr(response, values_attr))[:COUNT]

    return state

def transpose_MB_state(data_dict, show_coils_di, show_hrs_ir):
    """
    Turn the type-keyed state dictionary into one table row per address.

    Each row starts with the address (START_ADDR + offset) and is followed by
    the Coil and Discrete Input cells when show_coils_di is set, then the
    Holding and Input Register cells when show_hrs_ir is set. COUNT rows are
    produced.

    Cell contents:
      'SKIP'   the series is None (not read, or the read failed)
      'ERROR'  the series is the literal string 'ERROR', or it is shorter
               than the requested offset
      str(v)   the value found at that offset
    """
    # All four series are looked up regardless of the display flags.
    coil_series = data_dict.get('Coils (0x01, R/W)')
    contact_series = data_dict.get('Discrete Inputs (0x02, R/O)')
    holding_series = data_dict.get('Holding Registers (0x03, R/W)')
    input_series = data_dict.get('Input Registers (0x04, R/O)')

    def cell(series, offset):
        """Return the printable cell for one series at one offset."""
        if series is None:
            return 'SKIP'
        if series == 'ERROR' or offset >= len(series):
            return 'ERROR'
        return str(series[offset])

    # Series to render, in column order.
    selected = []
    if show_coils_di:
        selected += [coil_series, contact_series]
    if show_hrs_ir:
        selected += [holding_series, input_series]

    return [
        [START_ADDR + offset] + [cell(series, offset) for series in selected]
        for offset in range(COUNT)
    ]

def display_MB_table(data_dict, title, show_coils_di=True, show_hrs_ir=True):
    """
    Log the Modbus state as a centred "fancy_grid" table under a title line.

    The address column is always present. The Coil/DI pair and the HR/IR pair
    of columns are each included only when their flag is set, matching the
    rows produced by transpose_MB_state().
    """
    # Build the rows first; they carry exactly the columns selected below.
    rows = transpose_MB_state(data_dict, show_coils_di, show_hrs_ir)

    column_titles = ["Addr"]
    if show_coils_di:
        column_titles += ["Coil\n(0x01, R/W)", "DI\n(0x02, R/O)"]
    if show_hrs_ir:
        column_titles += ["HR\n(0x03, R/W)", "IR\n(0x04, R/O)"]

    log.info(f"\n--- {title} ---")
    # Numbers and strings are both centred so mixed columns line up.
    rendered = tabulate(
        rows,
        headers=column_titles,
        tablefmt="fancy_grid",
        numalign="center",
        stralign="center",
    )
    log.info(rendered)
    log.info('-' * 65)

# ================================================== #
# //        Dynamic Value Selection Helpers       // #
# ================================================== #

def get_different_CO_values(current_coils):
    """
    Return coil values that each differ from the current state.

    Every entry of current_coils is logically negated, so the result has the
    same length as the input. When current_coils is empty or None (for
    example, because the initial read failed), COUNT True values are returned.
    """
    if current_coils:
        # Negation is the only way a boolean can differ from itself.
        return [not coil_state for coil_state in current_coils]

    # No known starting state: switch everything on.
    return [True] * COUNT

def get_different_register_values(current_registers):
    """
    Pick new register values from RR_WR_POOL that differ from the current ones.

    Args:
        current_registers: Current register values, or None/empty when the
            initial read failed.

    Returns:
        A list of new values. With a non-empty input there is one value per
        input register, each chosen at random from the pool entries that
        differ from that register's current value. With an empty or None
        input, COUNT random pool values are returned. These are distinct when
        the pool is large enough, and drawn with replacement otherwise.
    """
    if not current_registers:
        # No baseline to differ from, so any pool values will do.
        if COUNT <= len(RR_WR_POOL):
            return random.sample(RR_WR_POOL, COUNT)
        # random.sample() raises ValueError when asked for more items than the
        # pool holds, so fall back to drawing with replacement.
        return random.choices(RR_WR_POOL, k=COUNT)

    new_values = list()

    for current_val in current_registers:
        # Pool entries that would actually change this register.
        candidates = [val for val in RR_WR_POOL if val != current_val]

        if candidates:
            new_values.append(random.choice(candidates))
        else:
            # Only reachable when the pool holds nothing but current_val.
            log.warning(f"Could not find a unique write value for register value {current_val}. Using 1 as a fallback.")
            new_values.append(1)

    return new_values

# ================================================== #
# //               Run the client                 // #
# ================================================== #

def _write_and_verify_coils(client, initial_coils):
    """Write inverted coil values, read them back and log whether they match."""
    # Trim to COUNT so the write never spills past the tested span, even if
    # the initial bit list was padded to a whole byte.
    coils_values = get_different_CO_values(initial_coils)[:COUNT]

    log.info(f"\n[INFO] Attempting to WRITE Coils ({START_ADDR}-{START_ADDR + COUNT - 1}) with values: {coils_values}")
    rq = client.write_coils(START_ADDR, coils_values)

    if rq.isError():
        log.error(f"Write Coils error: {rq}")
        return

    rr_verify = client.read_coils(address=START_ADDR, count=COUNT)
    if rr_verify.isError():
        log.error(f"[FAIL] Coil Verification Read Error: {rr_verify}")
        return

    # Compare only the coils that were written; any padding bits in the
    # response are not part of the test.
    read_back = rr_verify.bits[:len(coils_values)]
    if read_back == coils_values:
        log.info("[VERIFY] Coils read back MATCHES written values.")
    else:
        log.error(f"[FAIL] Coils read back does NOT match written values. Read: {read_back}")


def _write_and_verify_holding_registers(client, initial_registers):
    """Write new holding register values, read them back and log whether they match."""
    holding_values = get_different_register_values(initial_registers)

    log.info(f"\n[INFO] Attempting to WRITE Holding Registers ({START_ADDR}-{START_ADDR + COUNT - 1}) with values: {holding_values}")
    rq = client.write_registers(START_ADDR, holding_values)

    if rq.isError():
        log.error(f"Write Registers error: {rq}")
        return

    rr_verify = client.read_holding_registers(address=START_ADDR, count=COUNT)
    if rr_verify.isError():
        log.error(f"[FAIL] Holding Register Verification Read Error: {rr_verify}")
    elif rr_verify.registers == holding_values:
        log.info("[VERIFY] Holding Registers read back MATCHES written values.")
    else:
        log.error(f"[FAIL] Holding Registers read back does NOT match written values. Read: {rr_verify.registers}")


def _run_write_test(client, args):
    """Run the -w sequence: read initial state, write, verify, read final state."""
    # With neither -c nor -r, both coils and holding registers are tested.
    # argparse already guarantees -c and -r are not both set.
    is_full_test = not args.coil_only and not args.register_only
    test_coils = args.coil_only or is_full_test
    test_hrs = args.register_only or is_full_test

    if is_full_test:
        read_title = "WRITE TEST SUITE (Read Initial, Write, Read Final)"
    elif args.coil_only:
        read_title = "WRITE TEST SUITE (Coils Only)"
    else:
        read_title = "WRITE TEST SUITE (Registers Only)"

    log.info(f"\n--- ACTION: {read_title} ---")

    # All four types are always read; the flags only limit what is displayed.
    initial_state = get_MB_state(client, True, True, True, True)
    display_MB_table(initial_state, "INITIAL MODBUS REGISTER STATE (Read before write)",
                     show_coils_di=test_coils, show_hrs_ir=test_hrs)

    if test_coils:
        _write_and_verify_coils(client, initial_state.get('Coils (0x01, R/W)'))
    else:
        log.info("\n[SKIP] Coil writes skipped (due to -r flag).")

    if test_hrs:
        _write_and_verify_holding_registers(client, initial_state.get('Holding Registers (0x03, R/W)'))
    else:
        log.info("\n[SKIP] Holding Register writes skipped (due to -c flag).")

    final_state = get_MB_state(client, True, True, True, True)
    display_MB_table(final_state, "FINAL MODBUS REGISTER STATE (Read after write)",
                     show_coils_di=test_coils, show_hrs_ir=test_hrs)


def _run_read_only(client, args):
    """Run the -d action: read and display the requested register groups."""
    show_coils_di = args.coil_only
    show_hrs_ir = args.register_only

    if not show_coils_di and not show_hrs_ir:
        # No filter given: read and show everything.
        show_coils_di = True
        show_hrs_ir = True
        read_title = "CURRENT MODBUS REGISTER STATE (All Types Read)"
    elif show_coils_di:
        read_title = "CURRENT MODBUS REGISTER STATE (Coils/DIs Only Read)"
    else:
        read_title = "CURRENT MODBUS REGISTER STATE (HRs/IRs Only Read)"

    log.info("\n--- ACTION: READ-ONLY OF CURRENT MODBUS STATE ---")

    # Only the displayed groups are requested, to save network traffic.
    current_state = get_MB_state(
        client,
        read_coils=show_coils_di,
        read_contacts=show_coils_di,
        read_holdings=show_hrs_ir,
        read_inputs=show_hrs_ir
    )
    display_MB_table(current_state, read_title, show_coils_di, show_hrs_ir)


def run_client(args):
    """
    Connect to the Modbus server and run the action selected on the command line.

    Precedence: -s (device identification only), then -w (write test), then
    -d (read-only). If none of the three is set, the client just connects and
    disconnects.

    Args:
        args: Parsed arguments providing ip, port, unit, server, write, read,
            coil_only and register_only.

    Returns:
        None. A failed connection is logged and the function returns early.

    NOTE(review): args.unit is only reported in the log line below. It is not
    passed to the client constructor or to any request in this file, so the
    -u switch presumably has no effect on the unit actually addressed.
    TODO: wire it through using the keyword the installed pymodbus expects.
    """
    log.info(f"Connecting to Modbus Server at {args.ip}:{args.port} (Unit ID: {args.unit})...")

    client = ModbusTcpClient(args.ip, port=args.port)

    if not client.connect():
        log.error("Failed to connect to Modbus Server. Exiting.")
        return

    log.info("Client connected.")

    try:
        if args.server:
            # -s is exclusive: it overrides -d and -w.
            log.info("\n[INFO] Server identification requested (-s). Overriding all other actions.")
            read_and_display_device_id(client)
        elif args.write:
            _run_write_test(client, args)
        elif args.read:
            _run_read_only(client, args)
    finally:
        # Always release the socket, including when a Modbus call raises.
        client.close()
        log.info("\nClient disconnected.")

# ================================================== #
# //                Main Starter                  // #
# ================================================== #

if __name__ == "__main__":
    # Build the parser and parse the command line.
    parser = get_commandline_parser()
    args = parser.parse_args()

    # Neither -d nor -w is required by the parser, so default to a read here.
    if not (args.read or args.write):
        args.read = True
        # With -s the client only performs the identification query, so
        # announcing a read-only default would be misleading.
        if not args.server:
            log.info("[INFO] No action flag (-d or -w) provided. Defaulting to READ-ONLY (-d) mode.")

    try:
        run_client(args)
    except KeyboardInterrupt:
        log.info("Client manually interrupted.")
    except Exception as e:
        # Top-level boundary: report the failure and signal it via the exit code.
        log.critical(f"Unhandled exception in main thread: {e}")
        sys.exit(1)
