#  Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
#  with the License. A copy of the License is located at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  or in the 'license' file accompanying this file. This file is distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES
#  OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
#  and limitations under the License.
#
# This script will:
# 1. Download the RES app zip from S3
# 2. Unzip the app code recursively
# 3. Apply a patch on the source code
# 4. Rebuild and zip the updated source code
# 5. Upload the updated app zip back to S3
#
# Prerequisites:
# - AWS CLI v2 is installed
# - Make sure you've configured the AWS CLI for the account / region where RES is deployed and
# that you have S3 permissions to read from and write to the bucket created by RES.
# - Python 3.9.16 or above and boto3 are installed
# - The `tar` and `patch` command-line tools are available on PATH

import argparse
import os
import shutil
import subprocess
import tempfile
import atexit

import boto3


def cleanup(working_dir: tempfile.TemporaryDirectory) -> None:
    """Delete the scratch directory *working_dir* together with its contents.

    Intended to be registered with ``atexit`` so the temporary workspace is
    removed when the script finishes.
    """
    remove_tree = working_dir.cleanup
    remove_tree()


# Command-line interface. --region and --account-id are optional; when omitted
# they are filled in from the boto3 session after parsing.
parser = argparse.ArgumentParser(description="Apply patch on RES app")
parser.add_argument(
    "--res-version",
    choices=["2024.04.02", "2024.06"],
    type=str,
    required=True,
    help="RES version",
)
parser.add_argument(
    "--environment-name",
    type=str,
    required=True,
    help="Name of the RES environment",
)
parser.add_argument(
    "--module",
    choices=["cluster-manager", "virtual-desktop-controller"],
    type=str,
    required=True,
    help="Name of the module to update",
)
parser.add_argument("--region", type=str, help="AWS region of the RES environment")
parser.add_argument("--account-id", type=str, help="AWS account of the RES environment")
parser.add_argument("--patch", type=str, required=True, help="Path of the patch file to apply")

args = parser.parse_args()

# Fail fast on a bad patch path before touching AWS. isfile (rather than exists)
# also rejects directories, which `patch` cannot read as input. Raising
# SystemExit with a message prints it to stderr and exits with status 1, without
# relying on the interactive-only `exit` builtin.
if not os.path.isfile(args.patch):
    raise SystemExit(f"Patch file {args.patch} doesn't exist or is not a file")

session = boto3.session.Session()
# Fall back to the session's configured region when --region was not given.
args.region = args.region or session.region_name
if not args.region:
    # Without a region the bucket name below would silently contain "None".
    raise SystemExit(
        "Unable to determine the AWS region: pass --region or configure a default region"
    )
# Fall back to the account of the current credentials when --account-id was not given.
args.account_id = args.account_id or session.client("sts").get_caller_identity().get("Account")

# Scratch workspace for the downloaded archive and its extracted contents.
temp_dir = tempfile.TemporaryDirectory()
atexit.register(cleanup, temp_dir)
temp_dir_name = temp_dir.name

# The app archive is read from, and later written back to, this bucket/key.
cluster_bucket_name = f"{args.environment_name}-cluster-{args.region}-{args.account_id}"
app_zip_name = f"idea-{args.module}-{args.res_version}.tar.gz"
app_zip_path = os.path.join(temp_dir_name, app_zip_name)
s3_object_key = f"idea/releases/{app_zip_name}"

print(f"Downloading {app_zip_name} to {app_zip_path} ...")
s3_client = session.client("s3")
s3_client.download_file(Bucket=cluster_bucket_name, Key=s3_object_key, Filename=app_zip_path)

# Extract the release archive into a sibling directory named like the archive
# (without the ".tar.gz" suffix) inside the same temporary directory.
print(f"Unzipping {app_zip_name} ...")
unzip_dir = os.path.join(temp_dir_name, f"idea-{args.module}-{args.res_version}")
os.makedirs(unzip_dir, exist_ok=True)
tar_command = ["tar", "-xvf", app_zip_path, "-C", unzip_dir]
# tar's listing and diagnostics are discarded; a non-zero exit still raises.
subprocess.check_call(tar_command, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)

# Unpack every bundled "*-lib.tar.gz" archive in place, then delete the archive
# so only the extracted source directories remain next to each other.
for lib_archive in os.listdir(unzip_dir):
    if not lib_archive.endswith("-lib.tar.gz"):
        continue
    print(f"Unzipping lib {lib_archive} ...")
    subprocess.check_call(
        ["tar", "-xvf", lib_archive, "-C", unzip_dir],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.STDOUT,
        cwd=unzip_dir,
    )
    os.remove(os.path.join(unzip_dir, lib_archive))

print(f"Applying patch {args.patch} ...")
# Feed the patch to `patch` through stdin instead of a shell redirect:
# - the file is opened relative to the caller's working directory (the same
#   place it was validated), whereas a shell redirect running with
#   cwd=unzip_dir would resolve a relative --patch path against the temp dir;
# - no shell is involved, so paths with spaces or shell metacharacters are safe.
with open(args.patch, "rb") as patch_file:
    subprocess.check_call(["patch", "-p0"], cwd=unzip_dir, stdin=patch_file)

# Every top-level directory that contains a setup.py is an extracted (and
# possibly patched) lib source tree. Rebuild each one into an sdist and store it
# as "<name>-lib.tar.gz", the naming scheme the bundled lib archives use.
lib_src_dirs = [
    entry
    for entry in os.listdir(unzip_dir)
    if os.path.exists(os.path.join(unzip_dir, entry, "setup.py"))
]
for entry in lib_src_dirs:
    absolute_path = os.path.join(unzip_dir, entry)
    print(f"Rebuilding lib {entry} ...")
    # Write the sdist straight into unzip_dir, next to the source directory.
    subprocess.check_call(
        ["python3", "setup.py", "sdist", "-d", unzip_dir],
        cwd=absolute_path,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.STDOUT,
    )
    # `setup.py --version` ends its output with a newline; strip it so the value
    # can be matched exactly against the directory name.
    normalized_lib_version = (
        subprocess.check_output(
            ["python3", "setup.py", "--version"],
            cwd=absolute_path,
            stderr=subprocess.DEVNULL,
        )
        .decode()
        .strip()
    )

    # Assumes the source directory is named "<name>-<version>" and that the sdist
    # is written as "<that directory name>.tar.gz".
    # NOTE(review): newer setuptools releases normalize sdist file names
    # (PEP 625), which may break this assumption - confirm against the
    # setuptools version in use.
    version_suffix = f"-{normalized_lib_version}"
    if not normalized_lib_version or not entry.endswith(version_suffix):
        raise SystemExit(
            f"Cannot derive lib archive name: {entry!r} does not end with {version_suffix!r}"
        )
    sdist = absolute_path + ".tar.gz"
    # removesuffix drops exactly the version string. str.rstrip would instead
    # strip any trailing characters that merely occur somewhere in the version.
    lib_archive = absolute_path.removesuffix(normalized_lib_version) + "lib.tar.gz"
    # Source and target are in the same directory, so a rename replaces the
    # previous copy-then-delete.
    os.replace(sdist, lib_archive)

    shutil.rmtree(absolute_path)

# Re-pack the patched tree. make_archive appends ".tar.gz" to unzip_dir, which
# yields exactly app_zip_path, so the downloaded archive is overwritten in place.
print(f"Rebuilding {app_zip_name} ...")
shutil.make_archive(base_name=unzip_dir, format="gztar", root_dir=unzip_dir)

# Push the rebuilt archive back to the same bucket/key it was downloaded from.
print(f"Uploading {app_zip_name} ...")
s3_uploader = session.client("s3")
s3_uploader.upload_file(Filename=app_zip_path, Bucket=cluster_bucket_name, Key=s3_object_key)

completion_message = (
    "Patch applied successfully. Please terminate the existing infra instance and "
    "wait for a new one to be launched by the corresponding autoscaling group automatically."
)
print(completion_message)
