csharp/A7ocin/PPOL/unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs

ExternalCommunicator.cs
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

using Newtonsoft.Json;
using System.Linq;
using System.Net.Sockets;
using System.Text;
using System.IO;


/// Responsible for communication with the Python API.
/// Connects over TCP to localhost on the port given by the "--port"
/// command-line argument, sends the academy and brain parameters once, then
/// exchanges commands, per-brain step data and actions with the external
/// process.
public class ExternalCommunicator : Communicator
{

    Academy academy;

    /// Agent ids per brain name, in the order used by the last giveBrainInfo
    /// call. UpdateActions slices the flat action/memory/value lists it
    /// receives using this same order.
    Dictionary<string, List<int>> current_agents;

    List<Brain> brains;

    /// Per brain name: whether that brain has already sent its state during
    /// the current step.
    Dictionary<string, bool> hasSentState;

    Dictionary<string, Dictionary<int, float[]>> storedActions;
    Dictionary<string, Dictionary<int, float[]>> storedMemories;
    Dictionary<string, Dictionary<int, float>> storedValues;

    private int comPort;
    Socket sender;
    byte[] messageHolder;

    /// Size of the receive buffer: a single Receive call reads at most this
    /// many bytes.
    const int messageLength = 12000;

    string logPath;

    const string api = "API-2";

    /// JSON payload sent to the external process for one brain per step.
    /// Property names are the JSON keys and must not be renamed.
    private class StepMessage
    {
        public string brain_name { get; set; }

        public List<int> agents { get; set; }

        public List<float> states { get; set; }

        public List<float> rewards { get; set; }

        public List<float> actions { get; set; }

        public List<float> memories { get; set; }

        public List<bool> dones { get; set; }
    }

    /// JSON payload received from the external process in UpdateActions:
    /// per brain name, flat lists covering all of that brain's agents.
    private class AgentMessage
    {
        public Dictionary<string, List<float>> action { get; set; }

        public Dictionary<string, List<float>> memory { get; set; }

        public Dictionary<string, List<float>> value { get; set; }

    }

    /// JSON payload received in GetResetParameters.
    private class ResetParametersMessage
    {
        public Dictionary<string, float> parameters { get; set; }

        public bool train_model { get; set; }
    }

    /// Constructor for the External Communicator.
    public ExternalCommunicator(Academy aca)
    {
        academy = aca;
        brains = new List<Brain>();
        current_agents = new Dictionary<string, List<int>>();

        hasSentState = new Dictionary<string, bool>();

        storedActions = new Dictionary<string, Dictionary<int, float[]>>();
        storedMemories = new Dictionary<string, Dictionary<int, float[]>>();
        storedValues = new Dictionary<string, Dictionary<int, float>>();
    }

    /// Registers a brain with the communicator and marks it as not having
    /// sent its state yet for the current step.
    public void SubscribeBrain(Brain brain)
    {
        brains.Add(brain);
        hasSentState[brain.gameObject.name] = false;
    }


    /// First contact with the external process: reads the port to connect
    /// to from the command line.
    /// Returns false when no usable "--port" argument was provided.
    public bool CommunicatorHandShake()
    {
        return ReadArgs();
    }

    /// Initializes the log file and the socket, then sends the academy and
    /// brain parameters to the external process.
    public void InitializeCommunicator()
    {
        logPath = Path.GetFullPath(".") + "/unity-environment.log";
        // Truncate the log file and write a timestamp header.
        using (StreamWriter writer = new StreamWriter(logPath, false))
        {
            writer.WriteLine(System.DateTime.Now.ToString());
            writer.WriteLine(" ");
        }
        // Subscribe only once logPath is set: HandleLog appends to that file.
        Application.logMessageReceived += HandleLog;
        messageHolder = new byte[messageLength];

        // Create a TCP/IP socket.
        sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        sender.Connect("localhost", comPort);

        AcademyParameters accParameters = new AcademyParameters();
        accParameters.brainParameters = new List<BrainParameters>();
        accParameters.brainNames = new List<string>();
        accParameters.externalBrainNames = new List<string>();
        accParameters.apiNumber = api;
        accParameters.logPath = logPath;
        foreach (Brain b in brains)
        {
            accParameters.brainParameters.Add(b.brainParameters);
            accParameters.brainNames.Add(b.gameObject.name);
            if (b.brainType == BrainType.External)
            {
                accParameters.externalBrainNames.Add(b.gameObject.name);
            }
        }
        accParameters.AcademyName = academy.gameObject.name;
        accParameters.resetParameters = academy.resetParameters;

        SendParameters(accParameters);
    }

    /// Appends each Unity log message (type, message, stack trace) to the
    /// log file at logPath. The writer is disposed even if a write fails.
    void HandleLog(string logString, string stackTrace, LogType type)
    {
        using (StreamWriter writer = new StreamWriter(logPath, true))
        {
            writer.WriteLine(type.ToString());
            writer.WriteLine(logString);
            writer.WriteLine(stackTrace);
        }
    }

    /// Listens to the socket for a command and returns the corresponding
    /// External Command.
    public ExternalCommand GetCommand()
    {
        string message = Receive();
        switch (message)
        {
            case "STEP":
                return ExternalCommand.STEP;
            case "RESET":
                return ExternalCommand.RESET;
            default:
                // "QUIT", an empty read (the peer closed the connection) and
                // any unrecognized command all end the environment.
                return ExternalCommand.QUIT;
        }
    }

    /// Requests the new reset parameters from the external process.
    /// Also updates academy.isInference from the received train_model flag.
    public Dictionary<string, float> GetResetParameters()
    {
        sender.Send(Encoding.ASCII.GetBytes("CONFIG_REQUEST"));
        ResetParametersMessage resetParams = JsonConvert.DeserializeObject<ResetParametersMessage>(Receive());
        academy.isInference = !resetParams.train_model;
        return resetParams.parameters;
    }


    /// Reads the Python-provided "--port" command-line argument into comPort.
    /// If "--port" appears several times, the last one with a value wins.
    /// Returns false when the argument is missing, has no value, is not an
    /// integer, or is not a valid TCP port.
    private bool ReadArgs()
    {
        string[] args = System.Environment.GetCommandLineArgs();
        string inputPort = null;
        // Stop one short of the end: "--port" must be followed by its value.
        for (int i = 0; i < args.Length - 1; i++)
        {
            if (args[i] == "--port")
            {
                inputPort = args[i + 1];
            }
        }

        int port;
        if (inputPort == null || !int.TryParse(inputPort, out port))
        {
            return false;
        }
        if (port <= 0 || port > System.Net.IPEndPoint.MaxPort)
        {
            return false;
        }
        comPort = port;
        return true;
    }

    /// Sends Academy parameters to the external agent as indented JSON
    /// (no length prefix).
    private void SendParameters(AcademyParameters envParams)
    {
        string envMessage = JsonConvert.SerializeObject(envParams, Formatting.Indented);
        sender.Send(Encoding.ASCII.GetBytes(envMessage));
    }

    /// Receives one message from the external agent with a single Receive
    /// call and decodes it as ASCII. Returns "" if the peer has closed the
    /// connection.
    /// NOTE(review): inbound messages carry no length prefix, so this assumes
    /// each message arrives in one read of at most messageLength bytes.
    private string Receive()
    {
        int location = sender.Receive(messageHolder);
        string message = Encoding.ASCII.GetString(messageHolder, 0, location);
        return message;
    }


    /// Ends the connection and closes the socket.
    /// Shutdown must precede Close: Close disposes the socket, after which
    /// Shutdown would throw ObjectDisposedException.
    private void OnApplicationQuit()
    {
        if (sender == null)
        {
            return;
        }
        try
        {
            sender.Shutdown(SocketShutdown.Both);
        }
        catch (SocketException)
        {
            // The socket may already be disconnected; it still gets closed below.
        }
        finally
        {
            sender.Close();
        }
    }

    /// Encodes a texture as PNG bytes to send to the external agent, then
    /// destroys the texture and unloads unused assets.
    private byte[] TexToByteArray(Texture2D tex)
    {
        byte[] bytes = tex.EncodeToPNG();
        Object.DestroyImmediate(tex);
        Resources.UnloadUnusedAssets();
        return bytes;
    }

    /// Returns a copy of input prefixed with its length as a 4-byte int
    /// (BitConverter, i.e. the machine's native byte order).
    private byte[] AppendLength(byte[] input){
        byte[] newArray = new byte[input.Length + 4];
        input.CopyTo(newArray, 4);
        System.BitConverter.GetBytes(input.Length).CopyTo(newArray, 0);
        return newArray;
    }

    /// Collects the information from a brain and sends it across the socket.
    /// Sends a length-prefixed JSON StepMessage, then one length-prefixed PNG
    /// per (camera resolution, agent), waiting for a reply after each send.
    /// Once every tracked brain has sent its state, sends "True"/"False" for
    /// academy.done and resets the per-brain flags for the next step.
    public void giveBrainInfo(Brain brain)
    {
        string brainName = brain.gameObject.name;
        List<int> agentIds = new List<int>(brain.agents.Keys);
        current_agents[brainName] = agentIds;

        List<float> concatenatedStates = new List<float>();
        List<float> concatenatedRewards = new List<float>();
        List<float> concatenatedMemories = new List<float>();
        List<bool> concatenatedDones = new List<bool>();
        List<float> concatenatedActions = new List<float>();
        var collectedObservations = brain.CollectObservations();
        var collectedStates = brain.CollectStates();
        var collectedRewards = brain.CollectRewards();
        var collectedMemories = brain.CollectMemories();
        var collectedDones = brain.CollectDones();
        var collectedActions = brain.CollectActions();

        foreach (int id in agentIds)
        {
            // AddRange appends in place; Concat(...).ToList() re-copied the
            // whole accumulated list for every agent.
            concatenatedStates.AddRange(collectedStates[id]);
            concatenatedRewards.Add(collectedRewards[id]);
            concatenatedMemories.AddRange(collectedMemories[id]);
            concatenatedDones.Add(collectedDones[id]);
            concatenatedActions.AddRange(collectedActions[id]);
        }
        StepMessage message = new StepMessage()
        {
            brain_name = brainName,
            agents = agentIds,
            states = concatenatedStates,
            rewards = concatenatedRewards,
            actions = concatenatedActions,
            memories = concatenatedMemories,
            dones = concatenatedDones
        };
        string envMessage = JsonConvert.SerializeObject(message, Formatting.Indented);
        sender.Send(AppendLength(Encoding.ASCII.GetBytes(envMessage)));
        Receive(); // reply content is not used; the read paces the exchange

        int cameraIndex = 0;
        foreach (resolution res in brain.brainParameters.cameraResolutions)
        {
            foreach (int id in agentIds)
            {
                sender.Send(AppendLength(TexToByteArray(brain.ObservationToTex(collectedObservations[id][cameraIndex], res.width, res.height))));
                Receive();
            }
            cameraIndex++;
        }

        hasSentState[brainName] = true;

        if (hasSentState.Values.All(x => x))
        {
            // All tracked brains have sent their state for this step.
            sender.Send(Encoding.ASCII.GetBytes((academy.done ? "True" : "False")));
            // Snapshot the keys: the dictionary is written while iterating.
            foreach (string k in hasSentState.Keys.ToList())
            {
                hasSentState[k] = false;
            }
        }

    }

    /// Listens for actions, memories, and values and stores them per brain
    /// for the External brains.
    /// Each flat list is sliced in current_agents order: continuous actions
    /// take actionSize floats per agent, other action spaces one value per
    /// agent; memories take memorySize floats per agent; values one per agent.
    public void UpdateActions()
    {
        // TO MODIFY
        sender.Send(Encoding.ASCII.GetBytes("STEPPING"));
        AgentMessage agentMessage = JsonConvert.DeserializeObject<AgentMessage>(Receive());

        foreach (Brain brain in brains)
        {
            if (brain.brainType != BrainType.External)
            {
                continue;
            }

            string brainName = brain.gameObject.name;
            List<int> agentIds = current_agents[brainName];
            int actionSize = brain.brainParameters.actionSize;
            int memorySize = brain.brainParameters.memorySize;
            bool isContinuous = brain.brainParameters.actionSpaceType == StateType.continuous;

            Dictionary<int, float[]> actionDict = new Dictionary<int, float[]>();
            List<float> brainActions = agentMessage.action[brainName];
            for (int i = 0; i < agentIds.Count; i++)
            {
                float[] agentAction = isContinuous
                    ? brainActions.GetRange(i * actionSize, actionSize).ToArray()
                    : brainActions.GetRange(i, 1).ToArray();
                actionDict.Add(agentIds[i], agentAction);
            }
            storedActions[brainName] = actionDict;

            Dictionary<int, float[]> memoryDict = new Dictionary<int, float[]>();
            List<float> brainMemories = agentMessage.memory[brainName];
            for (int i = 0; i < agentIds.Count; i++)
            {
                memoryDict.Add(agentIds[i], brainMemories.GetRange(i * memorySize, memorySize).ToArray());
            }
            storedMemories[brainName] = memoryDict;

            Dictionary<int, float> valueDict = new Dictionary<int, float>();
            List<float> brainValues = agentMessage.value[brainName];
            for (int i = 0; i < agentIds.Count; i++)
            {
                valueDict.Add(agentIds[i], brainValues[i]);
            }
            storedValues[brainName] = valueDict;
        }
    }

    /// Returns the actions, keyed by agent id, received through the socket
    /// for the brain called brainName.
    public Dictionary<int, float[]> GetDecidedAction(string brainName)
    {
        return storedActions[brainName];
    }

    /// Returns the memories, keyed by agent id, received through the socket
    /// for the brain called brainName.
    public Dictionary<int, float[]> GetMemories(string brainName)
    {
        return storedMemories[brainName];
    }

    /// Returns the values, keyed by agent id, received through the socket
    /// for the brain called brainName.
    public Dictionary<int, float> GetValues(string brainName)
    {
        return storedValues[brainName];
    }

}