import cv2
import numpy as np
import serial
import time

# =====================
# CONFIGURATION
# =====================
# Index of the capture device handed to OpenCV (0 = default webcam).
CAMERA_INDEX = 0
# Serial device of the Arduino; adjust per machine (e.g., COM3, COM5, /dev/ttyUSB0).
SERIAL_PORT = "COM3"
BAUD_RATE = 9600

# Open the serial link once at start-up. Any failure is reported and leaves
# `arduino` as None so the vision part of the script still runs without hardware.
try:
    arduino = serial.Serial(port=SERIAL_PORT, baudrate=BAUD_RATE, timeout=1)
    time.sleep(2)  # give the Arduino time to reset before talking to it
    print("Connected to Arduino on", SERIAL_PORT)
except Exception as err:
    arduino = None
    print("Could not connect to Arduino:", err)

# =====================
# FUNCTIONS
# =====================

def nothing(x):
    """Do nothing; placeholder callback passed to ``cv2.createTrackbar``.

    The trackbar positions are polled with ``cv2.getTrackbarPos`` in the main
    loop, so the value delivered here is ignored.
    """
    return None

def get_mask(frame, hsv_ranges):
    """Threshold a BGR frame against an HSV range and return the binary mask.

    Args:
        frame: Image in BGR channel order (it is converted with
            ``cv2.COLOR_BGR2HSV``).
        hsv_ranges: Mapping with the keys "H Min", "S Min", "V Min",
            "H Max", "S Max" and "V Max" giving the inclusive bounds.

    Returns:
        The result of ``cv2.inRange`` for the converted frame and the bounds.
    """
    hsv_image = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    lower_bound = np.array([hsv_ranges[key] for key in ("H Min", "S Min", "V Min")])
    upper_bound = np.array([hsv_ranges[key] for key in ("H Max", "S Max", "V Max")])
    return cv2.inRange(hsv_image, lower_bound, upper_bound)

def get_contour_info(mask, frame, min_area=500):
    """Outline the largest contour in *mask* and return its aspect ratio.

    The largest external contour (by ``cv2.contourArea``) is drawn onto
    *frame* in place, in red, and the ratio of its horizontal extent to its
    vertical extent is computed from the contour points.

    Args:
        mask: Binary image passed to ``cv2.findContours``.
        frame: Image to draw the contour on. It is modified in place, so pass
            a copy if the original must stay untouched.
        min_area: Contours whose area is below this value are ignored.
            Defaults to 500, the threshold this function always used.

    Returns:
        A tuple ``(frame, oar)``. ``oar`` is the width/height ratio rounded to
        two decimals, or ``None`` when no contour is found, the largest one is
        smaller than *min_area*, or its extent is zero along either axis.
    """
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return frame, None

    largest = max(contours, key=cv2.contourArea)
    if cv2.contourArea(largest) < min_area:
        return frame, None

    cv2.drawContours(frame, [largest], -1, (0, 0, 255), 3)

    # Contour points are indexed as [point, 0, (x, y)]; flatten to (N, 2) and
    # take the per-axis extent with vectorised min/max instead of the Python
    # builtins, which would iterate over the array element by element.
    points = largest.reshape(-1, 2)
    width_px, height_px = (points.max(axis=0) - points.min(axis=0)).tolist()

    # Guard against division by zero for degenerate (line-like) contours.
    if width_px == 0 or height_px == 0:
        return frame, None

    oar = round(width_px / height_px, 2)
    return frame, oar

def send_command(cmd):
    """Write *cmd* plus a trailing newline to the Arduino, if one is connected.

    Does nothing when the module-level ``arduino`` handle is falsy (it is set
    to ``None`` when the connection attempt at start-up fails). The text is
    UTF-8 encoded before being written, and the command is echoed to stdout.
    """
    if not arduino:
        return
    payload = (cmd + "\n").encode("utf-8")
    arduino.write(payload)
    print("Sent:", cmd)

# =====================
# MAIN PROGRAM
# =====================

# (name, max value, initial value) for each HSV trackbar. Defined once so the
# creation loop and the per-frame polling cannot drift apart.
TRACKBARS = [
    ("H Min", 179, 0), ("S Min", 255, 0), ("V Min", 255, 0),
    ("H Max", 179, 179), ("S Max", 255, 255), ("V Max", 255, 255),
]

cap = cv2.VideoCapture(CAMERA_INDEX)
if not cap.isOpened():
    # Without this message the loop below would simply exit on the first
    # failed read, giving no hint about what went wrong.
    print("Could not open camera index", CAMERA_INDEX)

last_shape = None  # last shape a command was sent for, to avoid spamming

# Everything that uses the camera, the windows or the serial port runs inside
# try/finally so they are released even on an exception or Ctrl+C.
try:
    # --- Create resizable windows ---
    cv2.namedWindow("Trackbars", cv2.WINDOW_NORMAL)
    cv2.resizeWindow("Trackbars", 400, 100)
    for window_name in ("Original", "Mask", "Result"):
        cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(window_name, 640, 480)

    # Create HSV trackbars
    for name, maxval, initval in TRACKBARS:
        cv2.createTrackbar(name, "Trackbars", initval, maxval, nothing)

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Poll the current slider positions for this frame.
        hsv_ranges = {name: cv2.getTrackbarPos(name, "Trackbars")
                      for name, _, _ in TRACKBARS}

        mask = get_mask(frame, hsv_ranges)
        # Draw on a copy so the "Original" window shows the untouched frame.
        result, oar = get_contour_info(mask, frame.copy())

        # ==========================
        # IF / ELSE + SERIAL CONTROL
        # ==========================
        # NOTE(review): every ratio >= 1 (as wide as, or wider than, tall) is
        # labelled "Square" and everything below 1 "Rectangle" -- confirm this
        # is the intended rule rather than a tolerance band around 1.
        if oar is None:
            shape = "No object"
        else:
            if oar >= 1:
                shape, command = "Square", "S1 180"
            else:
                shape, command = "Rectangle", "S1 0"
            if last_shape != shape:  # only send if changed
                send_command(command)
                last_shape = shape
        # ==========================

        # Compare against None explicitly: a ratio that rounds to 0.0 is
        # still a valid measurement and should be displayed.
        if oar is not None:
            cv2.putText(result, f"OAR={oar}", (30, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        cv2.putText(result, shape, (30, 90),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

        cv2.imshow("Original", frame)
        cv2.imshow("Mask", mask)
        cv2.imshow("Result", result)

        # Press 'q' in any OpenCV window to quit.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    cap.release()
    cv2.destroyAllWindows()

    if arduino is not None:
        arduino.close()
