// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc.  All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//     * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <thirdparty/protobuf/compiler/python/pyi_generator.h>

#include <string>

#include <thirdparty/protobuf/stubs/strutil.h>
#include <thirdparty/protobuf/compiler/python/helpers.h>
#include <thirdparty/protobuf/descriptor.h>
#include <thirdparty/protobuf/descriptor.pb.h>
#include <thirdparty/protobuf/io/printer.h>
#include <thirdparty/protobuf/io/zero_copy_stream.h>

namespace google {
namespace protobuf {
namespace compiler {
namespace python {

template <typename DescriptorT>
struct SortByName {
  bool operator()(const DescriptorT* l, const DescriptorT* r) const {
    return l->name() < r->name();
  }
};

PyiGenerator::PyiGenerator() : file_(nullptr) {}

PyiGenerator::~PyiGenerator() {}

void PyiGenerator::PrintItemMap(
    const std::map<std::string, std::string>& item_map) const {
  for (const auto& entry : item_map) {
    printer_->Print("$key$: $value$\n", "key", entry.first, "value",
                    entry.second);
  }
}

template <typename DescriptorT>
std::string PyiGenerator::ModuleLevelName(
    const DescriptorT& descriptor,
    const std::map<std::string, std::string>& import_map) const {
  std::string name = NamePrefixedWithNestedTypes(descriptor, ".");
  if (descriptor.file() != file_) {
    std::string module_alias;
    std::string filename = descriptor.file()->name();
    if (import_map.find(filename) == import_map.end()) {
      std::string module_name = ModuleName(descriptor.file()->name());
      std::vector<std::string> tokens = Split(module_name, ".");
      module_alias = "_" + tokens.back();
    } else {
      module_alias = import_map.at(filename);
    }
    name = module_alias + "." + name;
  }
  return name;
}

struct ImportModules {
  bool has_repeated = false;    // _containers
  bool has_iterable = false;    // typing.Iterable
  bool has_messages = false;    // _message
  bool has_enums = false;       // _enum_type_wrapper
  bool has_extendable = false;  // _python_message
  bool has_mapping = false;     // typing.Mapping
  bool has_optional = false;    // typing.Optional
  bool has_union = false;       // typing.Union
  bool has_well_known_type = false;
};

// Checks whether a descriptor name matches a well-known type.
bool IsWellKnownType(const std::string& name) {
  // LINT.IfChange(wktbases)
  return (name == "google.protobuf.Any" ||
          name == "google.protobuf.Duration" ||
          name == "google.protobuf.FieldMask" ||
          name == "google.protobuf.ListValue" ||
          name == "google.protobuf.Struct" ||
          name == "google.protobuf.Timestamp");
  // LINT.ThenChange(//depot/google3/net/proto2/python/internal/well_known_types.py:wktbases)
}

// Checks what modules should be imported for this message
// descriptor.
void CheckImportModules(const Descriptor* descriptor,
                        ImportModules* import_modules) {
  if (descriptor->extension_range_count() > 0) {
    import_modules->has_extendable = true;
  }
  if (descriptor->enum_type_count() > 0) {
    import_modules->has_enums = true;
  }
  if (IsWellKnownType(descriptor->full_name())) {
    import_modules->has_well_known_type = true;
  }
  for (int i = 0; i < descriptor->field_count(); ++i) {
    const FieldDescriptor* field = descriptor->field(i);
    if (IsPythonKeyword(field->name())) {
      continue;
    }
    import_modules->has_optional = true;
    if (field->is_repeated()) {
      import_modules->has_repeated = true;
    }
    if (field->is_map()) {
      import_modules->has_mapping = true;
      const FieldDescriptor* value_des = field->message_type()->field(1);
      if (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE ||
          value_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
        import_modules->has_union = true;
      }
    } else {
      if (field->is_repeated()) {
        import_modules->has_iterable = true;
      }
      if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
        import_modules->has_union = true;
        import_modules->has_mapping = true;
      }
      if (field->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
        import_modules->has_union = true;
      }
    }
  }
  for (int i = 0; i < descriptor->nested_type_count(); ++i) {
    CheckImportModules(descriptor->nested_type(i), import_modules);
  }
}

void PyiGenerator::PrintImportForDescriptor(
    const FileDescriptor& desc,
    std::map<std::string, std::string>* import_map,
    std::set<std::string>* seen_aliases) const {
  const std::string& filename = desc.name();
  std::string module_name = StrippedModuleName(filename);
  size_t last_dot_pos = module_name.rfind('.');
  std::string import_statement;
  if (last_dot_pos == std::string::npos) {
    import_statement = "import " + module_name;
  } else {
    import_statement = "from " + module_name.substr(0, last_dot_pos) +
                       " import " + module_name.substr(last_dot_pos + 1);
    module_name = module_name.substr(last_dot_pos + 1);
  }
  std::string alias = "_" + module_name;
  // Generate a unique alias by adding _1 suffixes until we get an unused alias.
  while (seen_aliases->find(alias) != seen_aliases->end()) {
    alias = alias + "_1";
  }
  printer_->Print("$statement$ as $alias$\n", "statement",
                  import_statement, "alias", alias);
  (*import_map)[filename] = alias;
  seen_aliases->insert(alias);
}

void PyiGenerator::PrintImports(
    std::map<std::string, std::string>* item_map,
    std::map<std::string, std::string>* import_map) const {
  // Prints imported dependent _pb2 files.
  std::set<std::string> seen_aliases;
  for (int i = 0; i < file_->dependency_count(); ++i) {
    const FileDescriptor* dep = file_->dependency(i);
    PrintImportForDescriptor(*dep, import_map, &seen_aliases);
    for (int j = 0; j < dep->public_dependency_count(); ++j) {
      PrintImportForDescriptor(
          *dep->public_dependency(j), import_map, &seen_aliases);
    }
  }

  // Checks what modules should be imported.
  ImportModules import_modules;
  if (file_->message_type_count() > 0) {
    import_modules.has_messages = true;
  }
  if (file_->enum_type_count() > 0) {
    import_modules.has_enums = true;
  }
  for (int i = 0; i < file_->message_type_count(); i++) {
    CheckImportModules(file_->message_type(i), &import_modules);
  }

  // Prints modules (e.g. _containers, _messages, typing) that are
  // required in the proto file.
  if (import_modules.has_repeated) {
    printer_->Print(
        "from google.protobuf.internal import containers as "
        "_containers\n");
  }
  if (import_modules.has_enums) {
    printer_->Print(
        "from google.protobuf.internal import enum_type_wrapper"
        " as _enum_type_wrapper\n");
  }
  if (import_modules.has_extendable) {
    printer_->Print(
        "from google.protobuf.internal import python_message"
        " as _python_message\n");
  }
  if (import_modules.has_well_known_type) {
    printer_->Print(
        "from google.protobuf.internal import well_known_types"
        " as _well_known_types\n");
  }
  printer_->Print(
      "from google.protobuf import"
      " descriptor as _descriptor\n");
  if (import_modules.has_messages) {
    printer_->Print(
        "from google.protobuf import message as _message\n");
  }
  if (HasGenericServices(file_)) {
    printer_->Print(
        "from google.protobuf import service as"
        " _service\n");
  }
  printer_->Print("from typing import ");
  printer_->Print("ClassVar as _ClassVar");
  if (import_modules.has_iterable) {
    printer_->Print(", Iterable as _Iterable");
  }
  if (import_modules.has_mapping) {
    printer_->Print(", Mapping as _Mapping");
  }
  if (import_modules.has_optional) {
    printer_->Print(", Optional as _Optional");
  }
  if (import_modules.has_union) {
    printer_->Print(", Union as _Union");
  }
  printer_->Print("\n\n");

  // Public imports
  for (int i = 0; i < file_->public_dependency_count(); ++i) {
    const FileDescriptor* public_dep = file_->public_dependency(i);
    std::string module_name = StrippedModuleName(public_dep->name());
    // Top level messages in public imports
    for (int i = 0; i < public_dep->message_type_count(); ++i) {
      printer_->Print("from $module$ import $message_class$\n", "module",
                      module_name, "message_class",
                      public_dep->message_type(i)->name());
    }
    // Top level enums for public imports
    for (int i = 0; i < public_dep->enum_type_count(); ++i) {
      printer_->Print("from $module$ import $enum_class$\n", "module",
                      module_name, "enum_class",
                      public_dep->enum_type(i)->name());
    }
    // Enum values for public imports
    for (int i = 0; i < public_dep->enum_type_count(); ++i) {
      const EnumDescriptor* enum_descriptor = public_dep->enum_type(i);
      for (int j = 0; j < enum_descriptor->value_count(); ++j) {
        (*item_map)[enum_descriptor->value(j)->name()] =
            ModuleLevelName(*enum_descriptor, *import_map);
      }
    }
    // Top level extensions for public imports
    AddExtensions(*public_dep, item_map);
  }
}

void PyiGenerator::PrintEnum(const EnumDescriptor& enum_descriptor) const {
  std::string enum_name = enum_descriptor.name();
  printer_->Print(
      "class $enum_name$(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):\n"
      "    __slots__ = []\n",
      "enum_name", enum_name);
}

// Adds enum value to item map which will be ordered and printed later.
void PyiGenerator::AddEnumValue(
    const EnumDescriptor& enum_descriptor,
    std::map<std::string, std::string>* item_map,
    const std::map<std::string, std::string>& import_map) const {
  // enum values
  std::string module_enum_name = ModuleLevelName(enum_descriptor, import_map);
  for (int j = 0; j < enum_descriptor.value_count(); ++j) {
    const EnumValueDescriptor* value_descriptor = enum_descriptor.value(j);
    (*item_map)[value_descriptor->name()] = module_enum_name;
  }
}

// Prints top level enums
void PyiGenerator::PrintTopLevelEnums() const {
  for (int i = 0; i < file_->enum_type_count(); ++i) {
    printer_->Print("\n");
    PrintEnum(*file_->enum_type(i));
  }
}

// Add top level extensions to item_map which will be ordered and
// printed later.
template <typename DescriptorT>
void PyiGenerator::AddExtensions(
    const DescriptorT& descriptor,
    std::map<std::string, std::string>* item_map) const {
  for (int i = 0; i < descriptor.extension_count(); ++i) {
    const FieldDescriptor* extension_field = descriptor.extension(i);
    std::string constant_name = extension_field->name() + "_FIELD_NUMBER";
    ToUpper(&constant_name);
    (*item_map)[constant_name] = "_ClassVar[int]";
    (*item_map)[extension_field->name()] = "_descriptor.FieldDescriptor";
  }
}

// Returns the string format of a field's cpp_type
std::string PyiGenerator::GetFieldType(
    const FieldDescriptor& field_des, const Descriptor& containing_des,
    const std::map<std::string, std::string>& import_map) const {
  switch (field_des.cpp_type()) {
    case FieldDescriptor::CPPTYPE_INT32:
    case FieldDescriptor::CPPTYPE_UINT32:
    case FieldDescriptor::CPPTYPE_INT64:
    case FieldDescriptor::CPPTYPE_UINT64:
      return "int";
    case FieldDescriptor::CPPTYPE_DOUBLE:
    case FieldDescriptor::CPPTYPE_FLOAT:
      return "float";
    case FieldDescriptor::CPPTYPE_BOOL:
      return "bool";
    case FieldDescriptor::CPPTYPE_ENUM:
      return ModuleLevelName(*field_des.enum_type(), import_map);
    case FieldDescriptor::CPPTYPE_STRING:
      if (field_des.type() == FieldDescriptor::TYPE_STRING) {
        return "str";
      } else {
        return "bytes";
      }
    case FieldDescriptor::CPPTYPE_MESSAGE: {
      // If the field is inside a nested message and the nested message has the
      // same name as a top-level message, then we need to prefix the field type
      // with the module name for disambiguation.
      std::string name = ModuleLevelName(*field_des.message_type(), import_map);
      if ((containing_des.containing_type() != nullptr &&
           name == containing_des.name())) {
        std::string module = ModuleName(field_des.file()->name());
        name = module + "." + name;
      }
      return name;
    }
    default:
      GOOGLE_LOG(FATAL) << "Unsupported field type.";
  }
  return "";
}

void PyiGenerator::PrintMessage(
    const Descriptor& message_descriptor, bool is_nested,
    const std::map<std::string, std::string>& import_map) const {
  if (!is_nested) {
    printer_->Print("\n");
  }
  std::string class_name = message_descriptor.name();
  std::string extra_base;
  // A well-known type needs to inherit from its corresponding base class in
  // net/proto2/python/internal/well_known_types.
  if (IsWellKnownType(message_descriptor.full_name())) {
    extra_base = ", _well_known_types." + message_descriptor.name();
  } else {
    extra_base = "";
  }
  printer_->Print("class $class_name$(_message.Message$extra_base$):\n",
                  "class_name", class_name, "extra_base", extra_base);
  printer_->Indent();
  printer_->Indent();

  std::vector<const FieldDescriptor*> fields;
  fields.reserve(message_descriptor.field_count());
  for (int i = 0; i < message_descriptor.field_count(); ++i) {
    fields.push_back(message_descriptor.field(i));
  }
  std::sort(fields.begin(), fields.end(), SortByName<FieldDescriptor>());

  // Prints slots
  printer_->Print("__slots__ = [", "class_name", class_name);
  bool first_item = true;
  for (const auto& field_des : fields) {
    if (IsPythonKeyword(field_des->name())) {
      continue;
    }
    if (first_item) {
      first_item = false;
    } else {
      printer_->Print(", ");
    }
    printer_->Print("\"$field_name$\"", "field_name", field_des->name());
  }
  printer_->Print("]\n");

  std::map<std::string, std::string> item_map;
  // Prints Extensions for extendable messages
  if (message_descriptor.extension_range_count() > 0) {
    item_map["Extensions"] = "_python_message._ExtensionDict";
  }

  // Prints nested enums
  std::vector<const EnumDescriptor*> nested_enums;
  nested_enums.reserve(message_descriptor.enum_type_count());
  for (int i = 0; i < message_descriptor.enum_type_count(); ++i) {
    nested_enums.push_back(message_descriptor.enum_type(i));
  }
  std::sort(nested_enums.begin(), nested_enums.end(),
            SortByName<EnumDescriptor>());

  for (const auto& entry : nested_enums) {
    PrintEnum(*entry);
    // Adds enum value to item_map which will be ordered and printed later
    AddEnumValue(*entry, &item_map, import_map);
  }

  // Prints nested messages
  std::vector<const Descriptor*> nested_messages;
  nested_messages.reserve(message_descriptor.nested_type_count());
  for (int i = 0; i < message_descriptor.nested_type_count(); ++i) {
    nested_messages.push_back(message_descriptor.nested_type(i));
  }
  std::sort(nested_messages.begin(), nested_messages.end(),
            SortByName<Descriptor>());

  for (const auto& entry : nested_messages) {
    PrintMessage(*entry, true, import_map);
  }

  // Adds extensions to item_map which will be ordered and printed later
  AddExtensions(message_descriptor, &item_map);

  // Adds field number and field descriptor to item_map
  for (int i = 0; i < message_descriptor.field_count(); ++i) {
    const FieldDescriptor& field_des = *message_descriptor.field(i);
    item_map[ToUpper(field_des.name()) + "_FIELD_NUMBER"] =
        "_ClassVar[int]";
    if (IsPythonKeyword(field_des.name())) {
      continue;
    }
    std::string field_type = "";
    if (field_des.is_map()) {
      const FieldDescriptor* key_des = field_des.message_type()->field(0);
      const FieldDescriptor* value_des = field_des.message_type()->field(1);
      field_type = (value_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE
                        ? "_containers.MessageMap["
                        : "_containers.ScalarMap[");
      field_type += GetFieldType(*key_des, message_descriptor, import_map);
      field_type += ", ";
      field_type += GetFieldType(*value_des, message_descriptor, import_map);
    } else {
      if (field_des.is_repeated()) {
        field_type = (field_des.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE
                          ? "_containers.RepeatedCompositeFieldContainer["
                          : "_containers.RepeatedScalarFieldContainer[");
      }
      field_type += GetFieldType(field_des, message_descriptor, import_map);
    }

    if (field_des.is_repeated()) {
      field_type += "]";
    }
    item_map[field_des.name()] = field_type;
  }

  // Prints all items in item_map
  PrintItemMap(item_map);

  // Prints __init__
  printer_->Print("def __init__(self");
  bool has_key_words = false;
  bool is_first = true;
  for (int i = 0; i < message_descriptor.field_count(); ++i) {
    const FieldDescriptor* field_des = message_descriptor.field(i);
    if (IsPythonKeyword(field_des->name())) {
      has_key_words = true;
      continue;
    }
    std::string field_name = field_des->name();
    if (is_first && field_name == "self") {
      // See b/144146793 for an example of real code that generates a (self,
      // self) method signature. Since repeating a parameter name is illegal in
      // Python, we rename the duplicate self.
      field_name = "self_";
    }
    is_first = false;
    printer_->Print(", $field_name$: ", "field_name", field_name);
    if (field_des->is_repeated() ||
        field_des->cpp_type() != FieldDescriptor::CPPTYPE_BOOL) {
      printer_->Print("_Optional[");
    }
    if (field_des->is_map()) {
      const Descriptor* map_entry = field_des->message_type();
      printer_->Print(
          "_Mapping[$key_type$, $value_type$]", "key_type",
          GetFieldType(*map_entry->field(0), message_descriptor, import_map),
          "value_type",
          GetFieldType(*map_entry->field(1), message_descriptor, import_map));
    } else {
      if (field_des->is_repeated()) {
        printer_->Print("_Iterable[");
      }
      if (field_des->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
        printer_->Print(
            "_Union[$type_name$, _Mapping]", "type_name",
            GetFieldType(*field_des, message_descriptor, import_map));
      } else {
        if (field_des->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
          printer_->Print("_Union[$type_name$, str]", "type_name",
                          ModuleLevelName(*field_des->enum_type(), import_map));
        } else {
          printer_->Print(
              "$type_name$", "type_name",
              GetFieldType(*field_des, message_descriptor, import_map));
        }
      }
      if (field_des->is_repeated()) {
        printer_->Print("]");
      }
    }
    if (field_des->is_repeated() ||
        field_des->cpp_type() != FieldDescriptor::CPPTYPE_BOOL) {
      printer_->Print("]");
    }
    printer_->Print(" = ...");
  }
  if (has_key_words) {
    printer_->Print(", **kwargs");
  }
  printer_->Print(") -> None: ...\n");

  printer_->Outdent();
  printer_->Outdent();
}

void PyiGenerator::PrintMessages(
    const std::map<std::string, std::string>& import_map) const {
  // Deterministically order the descriptors.
  std::vector<const Descriptor*> messages;
  messages.reserve(file_->message_type_count());
  for (int i = 0; i < file_->message_type_count(); ++i) {
    messages.push_back(file_->message_type(i));
  }
  std::sort(messages.begin(), messages.end(), SortByName<Descriptor>());

  for (const auto& entry : messages) {
    PrintMessage(*entry, false, import_map);
  }
}

void PyiGenerator::PrintServices() const {
  std::vector<const ServiceDescriptor*> services;
  services.reserve(file_->service_count());
  for (int i = 0; i < file_->service_count(); ++i) {
    services.push_back(file_->service(i));
  }
  std::sort(services.begin(), services.end(), SortByName<ServiceDescriptor>());

  // Prints $Service$ and $Service$_Stub classes
  for (const auto& entry : services) {
    printer_->Print("\n");
    printer_->Print(
        "class $service_name$(_service.service): ...\n\n"
        "class $service_name$_Stub($service_name$): ...\n",
        "service_name", entry->name());
  }
}

bool PyiGenerator::Generate(const FileDescriptor* file,
                            const std::string& parameter,
                            GeneratorContext* context,
                            std::string* error) const {
  MutexLock lock(&mutex_);
  // Calculate file name.
  file_ = file;
  std::string filename =
      parameter.empty() ? GetFileName(file, ".pyi") : parameter;

  std::unique_ptr<io::ZeroCopyOutputStream> output(context->Open(filename));
  GOOGLE_CHECK(output.get());
  io::Printer printer(output.get(), '$');
  printer_ = &printer;

  // item map will store "DESCRIPTOR", top level extensions, top level enum
  // values. The items will be sorted and printed later.
  std::map<std::string, std::string> item_map;

  // Adds "DESCRIPTOR" into item_map.
  item_map["DESCRIPTOR"] = "_descriptor.FileDescriptor";

  // import_map will be a mapping from filename to module alias, e.g.
  // "google3/foo/bar.py" -> "_bar"
  std::map<std::string, std::string> import_map;

  PrintImports(&item_map, &import_map);
  // Adds top level enum values to item_map.
  for (int i = 0; i < file_->enum_type_count(); ++i) {
    AddEnumValue(*file_->enum_type(i), &item_map, import_map);
  }
  // Adds top level extensions to item_map.
  AddExtensions(*file_, &item_map);
  // Prints item map
  PrintItemMap(item_map);

  PrintMessages(import_map);
  PrintTopLevelEnums();
  if (HasGenericServices(file)) {
    PrintServices();
  }
  return true;
}

}  // namespace python
}  // namespace compiler
}  // namespace protobuf
}  // namespace google