Skip to content

cohere

CohereGen class for text generation using Cohere's API.

Classes:

  • CohereGen

    Cohere generation model for text generation.

CohereGen

CohereGen(
    model_name: Optional[str] = None,
    temperature: Optional[float] = None,
    prompt_template: str = '',
    output_max_length: int = 500,
    device: str = 'auto',
    structured_output: Optional[Type[BaseModel]] = None,
    system_message: str = '',
    api_params: dict[str, Any] = DEFAULT_API_PARAMS,
    api_key: str = '',
    cache: Cache | None = None,
    logs: dict[str, Any] | None = None,
)

Bases: GenerationBase

Cohere generation model for text generation.

Methods:

  • apply

    Apply attached configuration to the step.

  • generate

    Generate text using Cohere's API.

  • process

    Generate a result from the current pipeline content.

Source code in src/rago/generation/base.py
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def __init__(
    self,
    model_name: Optional[str] = None,
    temperature: Optional[float] = None,
    prompt_template: str = '',
    output_max_length: int = 500,
    device: str = 'auto',
    structured_output: Optional[Type[BaseModel]] = None,
    system_message: str = '',
    api_params: dict[str, Any] = DEFAULT_API_PARAMS,
    api_key: str = '',
    cache: Cache | None = None,
    logs: dict[str, Any] | None = None,
) -> None:
    """Initialize the Cohere generation model.

    Parameters
    ----------
    model_name : Optional[str]
        Model identifier; ``None`` falls back to ``self.default_model_name``.
    temperature : Optional[float]
        Sampling temperature; ``None`` falls back to
        ``self.default_temperature``.
    prompt_template : str
        Prompt template; empty string selects
        ``self.default_prompt_template``.
    output_max_length : int
        Maximum output length; a falsy value selects
        ``self.default_output_max_length``.
    device : str
        One of ``'cpu'``, ``'cuda'`` or ``'auto'``.
    structured_output : Optional[Type[BaseModel]]
        Pydantic model class used to parse structured responses.
    system_message : str
        Optional system message prepended to chat requests.
    api_params : dict[str, Any]
        Extra keyword arguments forwarded to the API call. Passing the
        ``DEFAULT_API_PARAMS`` sentinel yields a deep copy of
        ``self.default_api_params`` so the shared default is never mutated.
    api_key : str
        Cohere API key.
    cache : Cache | None
        Optional cache backend.
    logs : dict[str, Any] | None
        Mutable dict collecting runtime diagnostics; ``None`` becomes ``{}``.

    Raises
    ------
    ValueError
        If ``device`` is not one of the supported options.
    """
    super().__init__()
    self.api_key = api_key
    self.cache = cache
    self.logs = logs if logs is not None else {}

    self.model_name = (
        model_name if model_name is not None else self.default_model_name
    )
    # `or` (not an `is not None` check): an explicit falsy value such as 0
    # also falls back to the default, since a zero-length output is useless.
    self.output_max_length = (
        output_max_length or self.default_output_max_length
    )
    self.temperature = (
        temperature
        if temperature is not None
        else self.default_temperature
    )
    self.prompt_template = prompt_template or self.default_prompt_template
    self.structured_output = structured_output

    # Identity check against the shared mutable sentinel: take a deep copy
    # of the class-level defaults so per-instance mutation cannot leak
    # across instances.
    if api_params is DEFAULT_API_PARAMS:
        api_params = deepcopy(self.default_api_params or {})

    self.system_message = system_message
    self.api_params = api_params

    if device not in ('cpu', 'cuda', 'auto'):
        # ValueError is the idiomatic error for a bad argument value; it is
        # a subclass of Exception, so existing broad handlers still match.
        raise ValueError(
            f'Device {device} not supported. Options: cpu, cuda, auto.'
        )

    # Resolve 'auto' (and 'cuda') to CUDA only when it is actually
    # available; otherwise fall back to CPU.
    cuda_available = torch.cuda.is_available()
    self.device_name = (
        'cpu' if device == 'cpu' or not cuda_available else 'cuda'
    )
    self.device = torch.device(self.device_name)

    self._validate()
    self._load_optional_modules()
    self._setup()

apply

apply(parameters: Any) -> None

Apply attached configuration to the step.

Source code in src/rago/base.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
def apply(self, parameters: Any) -> None:
    """Apply attached configuration to the step."""
    if parameters is None:
        return

    # Recognized backend objects attach directly onto a well-known
    # attribute; the first matching predicate wins.
    backend_slots = (
        (_is_cache_backend, 'cache'),
        (_is_vector_db, 'db'),
        (_is_text_splitter, 'splitter'),
    )
    for matches, attribute in backend_slots:
        if matches(parameters):
            setattr(self, attribute, parameters)
            return

    # Anything else is treated as a generic configuration mapping.
    for name, setting in config_to_dict(parameters).items():
        if name == 'logs':
            # Keep `logs` a dict so later diagnostic writes never hit None.
            self.logs = setting if setting is not None else {}
        else:
            setattr(self, name, setting)

generate

generate(query: str, data: list[str]) -> str | BaseModel

Generate text using Cohere's API.

Source code in src/rago/generation/cohere.py
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def generate(self, query: str, data: list[str]) -> str | BaseModel:
    """Generate text using Cohere's API.

    Dispatches to one of three request shapes, checked in order:

    * structured output — chat with a JSON-schema ``response_format``,
      parsed into ``self.structured_output``;
    * system-message chat — plain chat when a system message is set;
    * completion — ``model.generate`` with a raw prompt otherwise.

    Parameters
    ----------
    query : str
        The user query to answer.
    data : list[str]
        Retrieved context passages interpolated into the prompt.

    Returns
    -------
    str | BaseModel
        The generated text, or a parsed ``structured_output`` instance.
    """
    input_text = self._format_prompt(query, data)
    # Fall back to the class defaults when no per-instance params are set.
    api_params = self.api_params or self.default_api_params

    if self.structured_output:
        # Explicit instruction to generate JSON output.
        system_instruction = (
            'Generate a JSON object that strictly follows the provided  '
            'JSON schema. Do not include any additional text.'
        )
        if self.system_message:
            system_instruction += ' ' + self.system_message
        messages = [
            {'role': 'system', 'content': system_instruction},
            {'role': 'user', 'content': input_text},
        ]

        response_format_config = {
            'type': 'json_object',
            'json_schema': (
                self.structured_output
                if isinstance(self.structured_output, dict)
                else self.structured_output.model_json_schema()
            ),
        }
        model_params = {
            'messages': messages,
            'max_tokens': self.output_max_length,
            'temperature': self.temperature,
            'model': self.model_name,
            'response_format': response_format_config,
            **api_params,
        }

        response = self.model.client.chat(**model_params)
        # Record the exact request for debugging; `self.logs` is always a
        # dict (guaranteed by __init__), so this write is safe.
        self.logs['model_params'] = model_params
        json_text = response.message.content[0].text
        parsed_dict = json.loads(json_text)
        # NOTE(review): if `structured_output` is a plain dict schema (a
        # case the response_format branch above allows) this call fails —
        # a dict instance is not callable. Confirm callers only pass
        # BaseModel subclasses here.
        parsed_model = self.structured_output(**parsed_dict)
        return parsed_model

    if self.system_message:
        messages = [
            {'role': 'system', 'content': self.system_message},
            {'role': 'user', 'content': input_text},
        ]
        model_params = {
            'model': self.model_name,
            'messages': messages,
            'max_tokens': self.output_max_length,
            'temperature': self.temperature,
            **api_params,
        }
        response = self.model.chat(**model_params)
        self.logs['model_params'] = model_params
        return cast(str, response.text)

    # No structured output and no system message: legacy completion API.
    model_params = {
        'model': self.model_name,
        'prompt': input_text,
        'max_tokens': self.output_max_length,
        'temperature': self.temperature,
        **api_params,
    }
    response = self.model.generate(**model_params)
    self.logs['model_params'] = model_params
    return cast(str, response.generations[0].text.strip())

process

process(inp: Input) -> Output

Generate a result from the current pipeline content.

Source code in src/rago/generation/base.py
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
def process(self, inp: Input) -> Output:
    """Generate a result from the current pipeline content."""
    # Pull the context from whichever field the pipeline populated,
    # preferring 'content', then 'data', then 'source'.
    raw_context = inp.get('content', inp.get('data', inp.get('source')))
    documents = [str(entry) for entry in ensure_list(raw_context)]

    generated = self.generate(str(inp.query), documents)

    out = Output.from_input(inp)
    out.result = generated
    serialized = _serialize_generation_result(generated)
    out.content = serialized
    out.data = serialized
    return out