import numpy as np
import cv2
from PIL import Image

def image_interpolation(image1, image2, alpha):
    """Linearly blend two images of identical shape.

    Computes ``(1 - alpha) * image1 + alpha * image2`` per element, so
    ``alpha == 0`` yields ``image1`` and ``alpha == 1`` yields ``image2``.

    Args:
        image1: Array-like of pixel values (start image).
        image2: Array-like of pixel values with the same shape as ``image1``
            (end image).
        alpha: Blend weight. Values outside ``[0, 1]`` are clamped into
            that range.

    Returns:
        A ``numpy.uint8`` array with the same shape as the inputs.

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    # Blend in float64 so the arithmetic never depends on the input dtype.
    first = np.asarray(image1, dtype=np.float64)
    second = np.asarray(image2, dtype=np.float64)

    # Raise instead of assert: asserts are stripped under ``python -O``, after
    # which mismatched inputs would broadcast silently or fail obscurely.
    if first.shape != second.shape:
        raise ValueError("Input images must have the same shape")

    # Normalize the alpha value
    alpha = max(0, min(1, alpha))

    blended = (1 - alpha) * first + alpha * second

    # Round to nearest (a bare uint8 cast truncates and biases the result
    # low), then clip so out-of-range inputs cannot wrap around.
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)

def pure_pil_alpha_to_color_v2(image, color=(0, 0, 0)):
    """Flatten a PIL image onto a solid background and return it as RGB.

    Args:
        image: The PIL image to flatten. Images that are not already in
            ``'RGBA'`` mode are converted first, so inputs without an
            alpha channel are accepted as well.
        color: Background colour passed to ``Image.new`` for the ``'RGB'``
            canvas; defaults to black.

    Returns:
        A new ``'RGB'`` image of the same size, with ``image`` pasted over
        ``color`` using its alpha channel as the mask.
    """
    # Without this, split()[3] raises IndexError for images with fewer than
    # four bands and would use a non-alpha band as the mask for other
    # four-band modes. convert() returns a new image, so the caller's
    # object is not replaced.
    if image.mode != 'RGBA':
        image = image.convert('RGBA')

    image.load()  # needed for split()
    background = Image.new('RGB', image.size, color)
    background.paste(image, mask=image.split()[3])  # 3 is the alpha channel
    return background

def _load_face(path, size=(128, 32)):
    """Load the image at *path* and return it as an RGB ``uint8`` array.

    The image is shrunk in place with ``Image.thumbnail`` to fit within
    *size*, flattened onto the default background of
    ``pure_pil_alpha_to_color_v2``, and converted to a NumPy array.
    """
    face = Image.open(path)
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    # under its current name.
    face.thumbnail(size, Image.LANCZOS)
    return np.array(pure_pil_alpha_to_color_v2(face), dtype=np.uint8)


image1 = _load_face("ciscoTheProot/faces/prootface1.bmp")
image2 = _load_face("ciscoTheProot/faces/prootface2.bmp")
image3 = _load_face("ciscoTheProot/faces/prootface3.bmp")


# Step from image2 to image3, showing one blended frame per alpha value.
# alpha = 0 shows image2 unchanged; alpha = 1 shows image3.
for alpha in (0, 0.2, 0.4, 0.6, 0.8, 1):
    interpolated_image = image_interpolation(image2, image3, alpha)

    # The arrays were built from PIL 'RGB' images; OpenCV displays BGR.
    im_bgr = cv2.cvtColor(interpolated_image, cv2.COLOR_RGB2BGR)
    cv2.imshow("Interpolated Image", im_bgr)
    cv2.waitKey(0)  # wait for a key press before showing the next frame

cv2.destroyAllWindows()