#!/usr/bin/env python

# NOTE: the wildcard imports stay first so that the explicit standard-library
# imports below take precedence over any same-named symbol they export.
from azure import *
from azure.servicemanagement import *
import argparse
import base64
import os
import subprocess
import sys
import time
import urllib2

# Command-line interface definition. Each entry pairs an argument name/flag
# with the keyword arguments handed to ArgumentParser.add_argument(); entries
# are registered in the order they are listed.
_ARGUMENT_SPECS = [
    ('--version', dict(
        action='version', version='azure-coreos-cluster 0.1')),
    ('cloud_service_name', dict(
        help='cloud service name')),
    ('--ssh-cert', dict(
        help='certificate file with public key for ssh, in .cer format')),
    ('--ssh-thumb', dict(
        help='thumbprint of ssh cert')),
    ('--subscription', dict(
        required=True,
        help='required Azure subscription id')),
    ('--azure-cert', dict(
        required=True,
        help='required path to Azure cert pem file')),
    ('--blob-container-url', dict(
        required=True,
        help='required url to blob container where vm disk images will be created, including /, ex: https://patcoreos.blob.core.windows.net/vhds/')),
    ('--vm-size', dict(
        default='Small',
        help='optional, VM size [Small]')),
    ('--vm-name-prefix', dict(
        default='coreos',
        help='optional, VM name prefix [coreos]')),
    ('--availability-set', dict(
        default='coreos-as',
        help='optional, name of availability set for cluster [coreos-as]')),
    ('--location', dict(
        default='West US',
        help='optional, [West US]')),
    ('--ssh', dict(
        default=22001, type=int,
        help='optional, starts with 22001 and +1 for each machine in cluster')),
    ('--coreos-image', dict(
        default='2b171e93f07c4903bcad35bda10acf22__CoreOS-Stable-607.0.0',
        help='optional, [2b171e93f07c4903bcad35bda10acf22__CoreOS-Stable-607.0.0]')),
    ('--num-nodes', dict(
        default=3, type=int,
        help='optional, number of nodes to create (or add), defaults to 3')),
    ('--virtual-network-name', dict(
        help='optional, name of an existing virtual network to which we will add the VMs')),
    ('--subnet-names', dict(
        help='optional, subnet name to which the VMs will belong')),
    ('--custom-data', dict(
        help='optional, path to your own cloud-init file')),
    ('--discovery-service-url', dict(
        help='optional, url for an existing cluster discovery service. Else we will generate one.')),
    ('--pip', dict(
        action='store_true',
        help='optional, assigns public instance ip addresses to each VM')),
    ('--deis', dict(
        action='store_true',
        help='optional, automatically opens http and controller endpoints')),
    ('--data-disk', dict(
        action='store_true',
        help='optional, attaches a data disk to each VM')),
]

parser = argparse.ArgumentParser(description='Create a CoreOS cluster on Microsoft Azure.')
for _argument_name, _argument_options in _ARGUMENT_SPECS:
    parser.add_argument(_argument_name, **_argument_options)

# Default cloud-config document used when no --custom-data file is supplied.
# The single "{0}" placeholder is filled in with the etcd discovery URL via
# str.format(); the "$private_ipv4" tokens are left untouched by format().
cloud_init_template = (
    "#cloud-config\n"
    "\n"
    "coreos:\n"
    "  etcd:\n"
    "    # generate a new token for each unique cluster from https://discovery.etcd.io/new\n"
    "    discovery: {0}\n"
    "    # deployments across multiple cloud services will need to use $public_ipv4\n"
    "    addr: $private_ipv4:4001\n"
    "    peer-addr: $private_ipv4:7001\n"
    "  units:\n"
    "    - name: etcd.service\n"
    "      command: start\n"
    "    - name: fleet.service\n"
    "      command: start\n"
)

args = parser.parse_args()

# Create SSH cert if it's not given
if not args.ssh_cert and not args.ssh_thumb:
  print 'SSH arguments not given, generating certificate'
  with open(os.devnull, 'w') as shutup:
      subprocess.call('openssl req -x509 -nodes -days 365 -newkey rsa:2048 -config cert.conf -keyout ssh-cert.key -out ssh-cert.pem', shell=True, stdout=shutup, stderr=shutup)
      subprocess.call('chmod 600 ssh-cert.key', shell=True, stdout=shutup, stderr=shutup)
      subprocess.call('openssl  x509 -outform der -in ssh-cert.pem -out ssh-cert.cer', shell=True, stdout=shutup, stderr=shutup)
      thumbprint = subprocess.check_output('openssl x509 -in ssh-cert.pem -sha1 -noout -fingerprint | sed s/://g', shell=True)
      args.ssh_thumb = thumbprint.split('=')[1].replace('\n', '')
      args.ssh_cert = './ssh-cert.cer'
  print 'Generated SSH certificate with thumbprint ' + args.ssh_thumb

# Setup custom data
if args.custom_data:
    with open(args.custom_data, 'r') as f:
      if not os.path.exists(args.custom_data):
        print "Couldn't find the user-data file. Did you remember to run `create-azure-user-data`?"
        sys.exit(1)
      cloud_init = f.read()
    f.closed
else:
    if args.discovery_service_url:
        cloud_init = cloud_init_template.format(args.discovery_service_url)
    else:
        response = urllib2.urlopen('https://discovery.etcd.io/new')
        discovery_url = response.read()
        cloud_init = cloud_init_template.format(discovery_url)

# Certificate format name passed to add_service_certificate().
SERVICE_CERT_FORMAT = 'pfx'

# The certificate is uploaded base64-encoded. Read it in binary mode: the
# certificate generated by this script is DER-encoded, and text mode could
# alter the bytes on platforms that translate line endings.
with open(args.ssh_cert, 'rb') as f:
    service_cert_file_data = base64.b64encode(f.read())

def wait_for_async(request_id, timeout):
    """Poll an asynchronous management operation until it leaves 'InProgress'.

    Progress dots are printed while waiting and the final status plus the
    elapsed time is printed at the end. If the operation reports an error,
    its code and attributes are printed as well -- including the case where
    the very first status query already shows the error.

    request_id -- id of the operation, as passed to sms.get_operation_status
    timeout    -- maximum number of polls (one poll every 5 seconds), not a
                  number of seconds

    Returns None in all cases; on timeout a message is printed and the
    function returns without printing a final status.
    """
    poll_interval = 5  # seconds between two status queries
    count = 0
    result = sms.get_operation_status(request_id)
    while result.status == 'InProgress':
        count = count + 1
        if count > timeout:
            print('Timed out waiting for async operation to complete.')
            return
        time.sleep(poll_interval)
        print('.'),
        sys.stdout.flush()
        result = sms.get_operation_status(request_id)
    # Report the error once, after polling has finished, so that a failure
    # visible on the initial query is not silently skipped.
    if result.error:
        print(result.error.code)
        print(vars(result.error))
    print(result.status + ' in ' + str(count * poll_interval) + 's')

def linux_config(hostname, args):
    """Build the Linux provisioning configuration set for one VM.

    hostname -- host name given to the configuration set
    args     -- parsed command-line arguments; only args.ssh_thumb is read

    The configuration uses the 'core' user with no password, carries the
    module-level cloud_init text as custom data, registers the SSH public key
    identified by args.ssh_thumb for /home/core/.ssh/authorized_keys and has
    SSH password authentication disabled.
    """
    config = LinuxConfigurationSet(hostname, 'core', None, True,
                                   custom_data=cloud_init)
    config.ssh.public_keys.public_keys.append(
        PublicKey(args.ssh_thumb, u'/home/core/.ssh/authorized_keys'))
    config.disable_ssh_password_authentication = True
    return config

def endpoint_config(name, port, probe=False):
    """Create a TCP input endpoint that uses `port` for both port arguments.

    name  -- endpoint name; also passed as the fifth constructor argument
    port  -- port value used for both port arguments of the endpoint
    probe -- optional load balancer probe; attached only when truthy

    Returns the ConfigurationSetInputEndpoint.
    """
    tcp_endpoint = ConfigurationSetInputEndpoint(name, 'tcp', port, port, name)
    if not probe:
        return tcp_endpoint
    tcp_endpoint.load_balancer_probe = probe
    return tcp_endpoint

def load_balancer_probe(path, port, protocol):
    """Return a LoadBalancerProbe with the given path, port and protocol set."""
    probe = LoadBalancerProbe()
    probe.path, probe.port, probe.protocol = path, port, protocol
    return probe

def network_config(subnet_name=None, port='59913', public_ip_name=None):
    """Build the 'NetworkConfiguration' configuration set for one VM.

    subnet_name    -- optional subnet name to add to the configuration
    port           -- public port mapped to the VM's SSH port 22
    public_ip_name -- optional name; when given, a public IP with a 20 minute
                      idle timeout is added

    When the module-level args.deis flag is set, load-balanced 'web' (80,
    probed over http at /health-check) and 'builder' (2222, tcp probe)
    endpoints are added as well.
    """
    config = ConfigurationSet()
    config.configuration_set_type = 'NetworkConfiguration'

    endpoints = config.input_endpoints.input_endpoints
    endpoints.append(ConfigurationSetInputEndpoint('ssh', 'tcp', port, '22'))

    if subnet_name:
        config.subnet_names.append(subnet_name)

    if public_ip_name:
        public_ip = PublicIP(name=public_ip_name)
        public_ip.idle_timeout_in_minutes = 20
        config.public_ips.public_ips.append(public_ip)

    if args.deis:
        # create web endpoint with probe checking /health-check
        web_probe = load_balancer_probe('/health-check', '80', 'http')
        endpoints.append(endpoint_config('web', '80', web_probe))
        # create builder endpoint TCP probe check
        builder_probe = load_balancer_probe(None, '2222', 'tcp')
        endpoints.append(endpoint_config('builder', '2222', builder_probe))

    return config

def data_hd(target_container_url, target_blob_name, target_lun, target_disk_size_in_gb):
    """Describe a data disk stored at target_container_url + target_blob_name.

    target_container_url   -- blob container URL; concatenated directly with
                              the blob name, so it has to end with '/'
    target_blob_name       -- blob name, also used as the disk label
    target_lun             -- logical unit number assigned to the disk
    target_disk_size_in_gb -- logical disk size in GB

    Returns the populated DataVirtualHardDisk.
    """
    disk = DataVirtualHardDisk()
    disk.disk_label = target_blob_name
    disk.logical_disk_size_in_gb = target_disk_size_in_gb
    disk.lun = target_lun
    disk.media_link = target_container_url + target_blob_name
    return disk

sms = ServiceManagementService(args.subscription, args.azure_cert)

#Create the cloud service
try:
  print 'Creating the hosted service...',
  sys.stdout.flush()
  sms.create_hosted_service(
      args.cloud_service_name, label=args.cloud_service_name, location=args.location)
  print('Successfully created hosted service ' + args.cloud_service_name)
  sys.stdout.flush()
  time.sleep(2)
except WindowsAzureConflictError:
  print "Hosted service {} already exists. Delete it or try again with a different name.".format(args.cloud_service_name)
  sys.exit(1)

#upload ssh cert to cloud-service
print 'Uploading SSH certificate...',
sys.stdout.flush()
result = sms.add_service_certificate(args.cloud_service_name,
                                     service_cert_file_data, SERVICE_CERT_FORMAT, '')
wait_for_async(result.request_id, 15)

def get_vm_name(args, i):
    """Return '<cloud_service_name>-<vm_name_prefix>-<i>' for node index i."""
    name_parts = (args.cloud_service_name, args.vm_name_prefix, str(i))
    return '-'.join(name_parts)

vms =[]

#Create the VMs
for i in range(args.num_nodes):
    ssh_port = args.ssh +i
    vm_name = get_vm_name(args, i)
    if args.pip:
        pip_name = vm_name
    else:
        pip_name = None
    media_link = args.blob_container_url + vm_name
    os_hd = OSVirtualHardDisk(media_link=media_link,
                            source_image_name=args.coreos_image)
    system = linux_config(vm_name, args)
    network = network_config(subnet_name=args.subnet_names, port=ssh_port, public_ip_name=pip_name)
    #specifiy the data disk, important to start at lun = 0
    if args.data_disk:
        data_disk = data_hd(args.blob_container_url, vm_name + '-data.vhd', 0, 100)
        data_disks = DataVirtualHardDisks()
        data_disks.data_virtual_hard_disks.append(data_disk)
    else:
        data_disks = None

    try:
      if i == 0:
          result = sms.create_virtual_machine_deployment(
                      args.cloud_service_name, deployment_name=args.cloud_service_name,
                      deployment_slot='production', label=vm_name,
                      role_name=vm_name, system_config=system, os_virtual_hard_disk=os_hd,
                      role_size=args.vm_size, network_config=network, data_virtual_hard_disks=data_disks)
      else:
          result = sms.add_role(
                      args.cloud_service_name, deployment_name=args.cloud_service_name,
                      role_name=vm_name,
                      system_config=system, os_virtual_hard_disk=os_hd,
                      role_size=args.vm_size, network_config=network, data_virtual_hard_disks=data_disks)
    except WindowsAzureError as e:
      if "Forbidden" in str(e):
        print "Unable to use this CoreOS image. This usually means a newer image has been published."
        print "See https://coreos.com/docs/running-coreos/cloud-providers/azure/ for the latest stable image,"
        print "and supply it to this script with --coreos-image. If it works, please open a pull request to update this script."
        sys.exit(1)
      else:
        pass

    print 'Creating VM ' + vm_name + '...',
    sys.stdout.flush()
    wait_for_async(result.request_id, 30)
    vms.append({'name':vm_name,
                'host':args.cloud_service_name + '.cloudapp.net',
                'port':ssh_port,
                'user':'core',
                'identity':args.ssh_cert.replace('.cer','.key')})

#get the ip addresses
def get_ips(service_name, deployment_name):
    """Return the first public IP address of every role instance.

    service_name    -- cloud service that holds the deployment
    deployment_name -- name of the deployment to inspect

    The deployment is looked up through the module-level `sms` client. A new
    list is built on every call, so the function does not depend on (or
    modify) any module-level list. Each instance is expected to have at least
    one public IP; indexing fails otherwise.
    """
    deployment = sms.get_deployment_by_name(service_name, deployment_name)
    return [instance.public_ips[0].address
            for instance in deployment.role_instance_list]

#print dns config
if args.pip:
    ips = []
    ips = get_ips(args.cloud_service_name, args.cloud_service_name)
    print ''
    print '-------'
    print "You'll need to configure DNS records for a domain you wish to use with your Deis cluster."
    print 'For convenience, the public IP addresses are printed below, along with sane DNS timeouts.'
    print ''
    for ip in ips:
        print '@ 10800 IN A ' + ip
    print '* 10800 IN CNAME @'
    print '-------'
    print 'For more information, see: http://docs.deis.io/en/latest/managing_deis/configure-dns/'
    print ''

#print ~/.ssh/config
print ''
print '-------'
print "Instances on Azure don't use typical SSH ports. It is recommended to configure ~/.ssh/config"
print 'so the instances can easily be referenced when logging in via SSH. For convenience, the config'
print 'directives for your instances are below:'
print ''
for vm in vms:
    print 'Host ' + vm['name']
    print '    HostName ' + vm['host']
    print '    Port ' + str(vm['port'])
    print '    User ' + vm['user']
    print '    IdentityFile ' + vm['identity']
print '-------'
print ''
