#!/usr/bin/env python

import argparse
import base64
import os
import subprocess
import sys
import time
import urllib2

from azure import *
from azure.servicemanagement import *

# Command-line interface: required Azure credentials/targets plus optional
# cluster shape, networking, and provisioning switches.
parser = argparse.ArgumentParser(description='Create a CoreOS cluster on Microsoft Azure.')
parser.add_argument('--version', action='version', version='azure-coreos-cluster 0.1')
# Positional: the Azure cloud service that will front the cluster.
parser.add_argument('cloud_service_name',
                   help='cloud service name')
# SSH credentials; both are generated below if neither is supplied.
parser.add_argument('--ssh-cert',
                   help='certificate file with public key for ssh, in .cer format')
parser.add_argument('--ssh-thumb',
                   help='thumbprint of ssh cert')
# Required Azure account/credential arguments.
parser.add_argument('--subscription', required=True,
                   help='required Azure subscription id')
parser.add_argument('--azure-cert', required=True,
                   help='required path to Azure cert pem file')
parser.add_argument('--blob-container-url', required=True,
                   help='required url to blob container where vm disk images will be created, including /, ex: https://patcoreos.blob.core.windows.net/vhds/')
# VM sizing, naming, and placement options.
parser.add_argument('--vm-size', default='Small',
                   help='optional, VM size [Small]')
parser.add_argument('--vm-name-prefix', default='coreos',
                   help='optional, VM name prefix [coreos]')
# NOTE(review): --availability-set is parsed but not referenced anywhere
# else in this file — confirm whether it should be wired into deployment.
parser.add_argument('--availability-set', default='coreos-as',
                   help='optional, name of availability set for cluster [coreos-as]')
parser.add_argument('--location', default='West US',
                   help='optional - overriden by affinity-group, [West US]')
parser.add_argument('--affinity-group', default='',
                   help='optional, overrides location if specified')
# Base SSH port; node i is exposed on public port (--ssh + i).
parser.add_argument('--ssh', default=22001, type=int,
                   help='optional, starts with 22001 and +1 for each machine in cluster')
parser.add_argument('--coreos-image', default='2b171e93f07c4903bcad35bda10acf22__CoreOS-Stable-681.2.0',
                   help='optional, [2b171e93f07c4903bcad35bda10acf22__CoreOS-Stable-681.2.0]')
parser.add_argument('--num-nodes', default=3, type=int,
                   help='optional, number of nodes to create (or add), defaults to 3')
# Optional virtual-network placement.
parser.add_argument('--virtual-network-name',
                   help='optional, name of an existing virtual network to which we will add the VMs')
parser.add_argument('--subnet-names',
                   help='optional, subnet name to which the VMs will belong')
# cloud-init / etcd discovery configuration.
parser.add_argument('--custom-data',
                   help='optional, path to your own cloud-init file')
parser.add_argument('--discovery-service-url',
                   help='optional, url for an existing cluster discovery service. Else we will generate one.')
# Boolean feature flags.
parser.add_argument('--pip', action='store_true',
                   help='optional, assigns public instance ip addresses to each VM')
parser.add_argument('--deis', action='store_true',
                   help='optional, automatically opens http and controller endpoints')
parser.add_argument('--data-disk', action='store_true',
                   help='optional, attaches a data disk to each VM')
parser.add_argument('--nohttps', action='store_true',
                   help='optional, disables the creation of the https load balanced endpoint')
parser.add_argument('--no-discovery-url', action='store_true',
                   help='optional, disables the creation of a new coreos discovery url')

# Default cloud-config user data passed to every VM; {0} is filled with the
# etcd discovery URL. Used only when --custom-data is not supplied.
cloud_init_template = """#cloud-config

coreos:
  etcd:
    # generate a new token for each unique cluster from https://discovery.etcd.io/new
    discovery: {0}
    # deployments across multiple cloud services will need to use $public_ipv4
    addr: $private_ipv4:4001
    peer-addr: $private_ipv4:7001
  units:
    - name: etcd.service
      command: start
    - name: fleet.service
      command: start
"""

args = parser.parse_args()

# Create SSH cert if it's not given
if not args.ssh_cert and not args.ssh_thumb:
  print 'SSH arguments not given, generating certificate'
  # Silence openssl's output; only the thumbprint matters below.
  with open(os.devnull, 'w') as shutup:
      # Self-signed keypair (reads cert.conf from the working directory),
      # key locked down to the owner, then converted to DER (.cer) for Azure.
      subprocess.call('openssl req -x509 -nodes -days 365 -newkey rsa:2048 -config cert.conf -keyout ssh-cert.key -out ssh-cert.pem', shell=True, stdout=shutup, stderr=shutup)
      subprocess.call('chmod 600 ssh-cert.key', shell=True, stdout=shutup, stderr=shutup)
      subprocess.call('openssl  x509 -outform der -in ssh-cert.pem -out ssh-cert.cer', shell=True, stdout=shutup, stderr=shutup)
      # Output looks like 'SHA1 Fingerprint=AB:12:...'; sed strips the colons
      # and we keep everything after the '=' minus the trailing newline.
      thumbprint = subprocess.check_output('openssl x509 -in ssh-cert.pem -sha1 -noout -fingerprint | sed s/://g', shell=True)
      args.ssh_thumb = thumbprint.split('=')[1].replace('\n', '')
      args.ssh_cert = './ssh-cert.cer'
  print 'Generated SSH certificate with thumbprint ' + args.ssh_thumb

# generate coreos discovery url
if not args.no_discovery_url:
  print 'Generating new CoreOS discovery URL for the cluster'
  # Delegates to the sibling create-azure-user-data script, feeding it a
  # freshly minted etcd discovery token.
  subprocess.call(['bash', '-c', 'python ./create-azure-user-data $(curl -s https://discovery.etcd.io/new)'])

# Setup custom data
if args.custom_data:
    with open(args.custom_data, 'r') as f:
      if not os.path.exists(args.custom_data):
        print "Couldn't find the user-data file. Did you remember to run `create-azure-user-data`?"
        sys.exit(1)
      cloud_init = f.read()
    f.closed
else:
    if args.discovery_service_url:
        cloud_init = cloud_init_template.format(args.discovery_service_url)
    else:
        response = urllib2.urlopen('https://discovery.etcd.io/new')
        discovery_url = response.read()
        cloud_init = cloud_init_template.format(discovery_url)

# Format flag passed to add_service_certificate for the uploaded cert payload.
SERVICE_CERT_FORMAT = 'pfx'

# Read the ssh certificate and base64-encode it for the upload call.
# The with-block closes the file; the original stray `f.closed` no-op
# (an attribute access mistaken for close()) has been removed.
with open(args.ssh_cert) as f:
    service_cert_file_data = base64.b64encode(f.read())

def wait_for_async(request_id, timeout):
    """Poll an Azure async operation until it leaves the 'InProgress' state.

    request_id -- id returned by a service-management call
    timeout -- maximum number of polls (one every 5 seconds) before giving up

    Prints a '.' per poll and the final status/elapsed time; returns None
    whether the operation finished or timed out.
    """
    count = 0
    result = sms.get_operation_status(request_id)
    while result.status == 'InProgress':
        count = count + 1
        if count > timeout:
            # Give up quietly; callers do not check a return value.
            print('Timed out waiting for async operation to complete.')
            return
        time.sleep(5)
        print('.'),  # trailing comma: Python 2 idiom to suppress the newline
        sys.stdout.flush()
        result = sms.get_operation_status(request_id)
        if result.error:
            # Surface error details but keep polling until the status changes.
            print(result.error.code)
            print(vars(result.error))
    print result.status + ' in ' + str(count*5) + 's'

def linux_config(hostname, args):
    """Build the Linux provisioning config for one VM: user 'core', the
    module-level cloud-init payload as custom data, and key-only SSH auth."""
    config = LinuxConfigurationSet(hostname, 'core', None, True,
                                   custom_data=cloud_init)
    authorized_key = PublicKey(args.ssh_thumb,
                               u'/home/core/.ssh/authorized_keys')
    config.ssh.public_keys.public_keys.append(authorized_key)
    config.disable_ssh_password_authentication = True
    return config

def lb_endpoint_config(name, port, probe=False, idle_timeout_minutes=4):
    """Create a load-balanced tcp endpoint named *name* on *port* (same
    public and local port), optionally attaching a health probe."""
    ep = ConfigurationSetInputEndpoint(name, 'tcp', port, port, name, False,
                                       idle_timeout_minutes)
    if probe:
        ep.load_balancer_probe = probe
    return ep

def load_balancer_probe(path, port, protocol):
    """Build a LoadBalancerProbe; *path* may be None for plain tcp probes."""
    probe = LoadBalancerProbe()
    probe.path = path
    probe.port = port
    probe.protocol = protocol
    return probe

def network_config(subnet_name=None, port='59913', public_ip_name=None):
    """Assemble the network configuration for one VM: an ssh endpoint mapped
    to the VM's port 22, optional subnet membership, optional public instance
    IP, and (with --deis) the web/https/builder load-balanced endpoints."""
    cfg = ConfigurationSet()
    cfg.configuration_set_type = 'NetworkConfiguration'
    endpoints = cfg.input_endpoints.input_endpoints
    endpoints.append(ConfigurationSetInputEndpoint('ssh', 'tcp', port, '22'))
    if subnet_name:
        cfg.subnet_names.append(subnet_name)
    if public_ip_name:
        public_ip = PublicIP(name=public_ip_name)
        public_ip.idle_timeout_in_minutes = 20
        cfg.public_ips.public_ips.append(public_ip)
    if args.deis:
        # Web endpoint is health-checked via HTTP GET /health-check.
        endpoints.append(lb_endpoint_config('web', '80', load_balancer_probe('/health-check', '80', 'http')))
        if not args.nohttps:
            endpoints.append(lb_endpoint_config('https', '443', load_balancer_probe(None, '443', 'tcp')))
        # Builder endpoint: tcp probe plus an extended 20-minute idle timeout.
        endpoints.append(lb_endpoint_config('builder', '2222', load_balancer_probe(None, '2222', 'tcp'), 20))
    return cfg

def data_hd(target_container_url, target_blob_name, target_lun, target_disk_size_in_gb):
    """Describe a data disk whose backing blob lives at
    *target_container_url* + *target_blob_name*, attached at *target_lun*."""
    disk = DataVirtualHardDisk()
    disk.disk_label = target_blob_name
    disk.logical_disk_size_in_gb = target_disk_size_in_gb
    disk.lun = target_lun
    disk.media_link = target_container_url + target_blob_name
    return disk

# Service-management client authenticated with the subscription's .pem cert.
sms = ServiceManagementService(args.subscription, args.azure_cert)

#Create the cloud service
try:
  print 'Creating the hosted service...',
  sys.stdout.flush()
  # An affinity group, when given, takes precedence over a bare location.
  if args.affinity_group:
    sms.create_hosted_service(
        args.cloud_service_name, label=args.cloud_service_name, affinity_group=args.affinity_group)
  else:
    sms.create_hosted_service(
        args.cloud_service_name, label=args.cloud_service_name, location=args.location)
  print('Successfully created hosted service ' + args.cloud_service_name)
  sys.stdout.flush()
  time.sleep(2)
except WindowsAzureConflictError:
  # Cloud service names are globally unique across Azure.
  print "Hosted service {} already exists. Delete it or try again with a different name.".format(args.cloud_service_name)
  sys.exit(1)

#upload ssh cert to cloud-service
# Empty string is the certificate password (the cert has none).
# NOTE(review): the payload is a .cer (DER) declared as 'pfx' — appears to be
# accepted by Azure for password-less certs; confirm before changing.
print 'Uploading SSH certificate...',
sys.stdout.flush()
result = sms.add_service_certificate(args.cloud_service_name,
                                     service_cert_file_data, SERVICE_CERT_FORMAT, '')
wait_for_async(result.request_id, 15)

def get_vm_name(args, i):
    """Return the role name for node *i*: '<cloud-service>-<prefix>-<i>'."""
    return '{0}-{1}-{2}'.format(args.cloud_service_name, args.vm_name_prefix, i)

# Connection details for each created VM, printed as ~/.ssh/config at the end.
vms =[]

#Create the VMs
for i in range(args.num_nodes):
    # Each node gets its own public ssh port: --ssh base + node index.
    ssh_port = args.ssh +i
    vm_name = get_vm_name(args, i)
    # With --pip each VM gets a public instance IP named after the VM.
    if args.pip:
        pip_name = vm_name
    else:
        pip_name = None
    # The OS disk blob is created inside the user-supplied container.
    media_link = args.blob_container_url + vm_name
    os_hd = OSVirtualHardDisk(media_link=media_link,
                            source_image_name=args.coreos_image)
    system = linux_config(vm_name, args)
    network = network_config(subnet_name=args.subnet_names, port=ssh_port, public_ip_name=pip_name)
    #specifiy the data disk, important to start at lun = 0
    if args.data_disk:
        # 100 GB data disk at lun 0, stored next to the OS disk blob.
        data_disk = data_hd(args.blob_container_url, vm_name + '-data.vhd', 0, 100)
        data_disks = DataVirtualHardDisks()
        data_disks.data_virtual_hard_disks.append(data_disk)
    else:
        data_disks = None

    try:
      # The first VM creates the production deployment itself; subsequent
      # VMs are added to that deployment as extra roles.
      if i == 0:
          result = sms.create_virtual_machine_deployment(
                      args.cloud_service_name, deployment_name=args.cloud_service_name,
                      deployment_slot='production', label=vm_name,
                      role_name=vm_name, system_config=system, os_virtual_hard_disk=os_hd, virtual_network_name=args.virtual_network_name,
                      role_size=args.vm_size, network_config=network, data_virtual_hard_disks=data_disks)
      else:
          result = sms.add_role(
                      args.cloud_service_name, deployment_name=args.cloud_service_name,
                      role_name=vm_name,
                      system_config=system, os_virtual_hard_disk=os_hd,
                      role_size=args.vm_size, network_config=network, data_virtual_hard_disks=data_disks)
    except WindowsAzureError as e:
      # A Forbidden error here almost always means the CoreOS image name is stale.
      if "Forbidden" in str(e):
        print "Unable to use this CoreOS image. This usually means a newer image has been published."
        print "See https://coreos.com/docs/running-coreos/cloud-providers/azure/ for the latest stable image,"
        print "and supply it to this script with --coreos-image. If it works, please open a pull request to update this script."
        sys.exit(1)
      else:
        # NOTE(review): other WindowsAzureErrors are swallowed, so the wait
        # below re-polls the *previous* request id — confirm this is intended.
        pass

    print 'Creating VM ' + vm_name + '...',
    sys.stdout.flush()
    wait_for_async(result.request_id, 30)
    # Record ssh connection details for the summary printed at the end;
    # the identity path assumes the key sits next to the .cer file.
    vms.append({'name':vm_name,
                'host':args.cloud_service_name + '.cloudapp.net',
                'port':ssh_port,
                'user':'core',
                'identity':args.ssh_cert.replace('.cer','.key')})

#get the ip addresses
def get_ips(service_name, deployment_name):
  try:
    result = sms.get_deployment_by_name(service_name, deployment_name)
    for instance in result.role_instance_list:
        ips.append(instance.public_ips[0].address)
    return ips
  except WindowsAzureMissingResourceError:
    # some helpful user error info
    print 'Could not find cloud service ip address. This is likely due to the fact that the'
    print 'cloud service failed to start. Check that the storage account for the'
    print '--blob-container-url argument exists, ends with a \'/\' and that there is a container named \'vhds\' '
    print 'within it. You may need to delete the cloud service \'' + service_name + '\' if you try'
    print 'this script again'
    sys.exit(1)
#print dns config
# With --pip, print ready-to-paste DNS records for the cluster's public IPs.
if args.pip:
    # Seed then fetch the cluster's public IP addresses.
    ips = []
    ips = get_ips(args.cloud_service_name, args.cloud_service_name)
    print ''
    print '-------'
    print "You'll need to configure DNS records for a domain you wish to use with your Deis cluster."
    print 'For convenience, the public IP addresses are printed below, along with sane DNS timeouts.'
    print ''
    # 10800s (3h) TTL: one A record per instance, wildcard CNAME to the root.
    for ip in ips:
        print '@ 10800 IN A ' + ip
    print '* 10800 IN CNAME @'
    print '-------'
    print 'For more information, see: http://docs.deis.io/en/latest/managing_deis/configure-dns/'
    print ''

#print ~/.ssh/config
print ''
print '-------'
print "Instances on Azure don't use typical SSH ports. It is recommended to configure ~/.ssh/config"
print 'so the instances can easily be referenced when logging in via SSH. For convenience, the config'
print 'directives for your instances are below:'
print ''
for vm in vms:
    print 'Host ' + vm['name']
    print '    HostName ' + vm['host']
    print '    Port ' + str(vm['port'])
    print '    User ' + vm['user']
    print '    IdentityFile ' + vm['identity']
print '-------'
print ''
