Skip to content

cohere

CohereGen class for text generation using Cohere's API.

Classes:

  • CohereGen

    Cohere generation model for text generation.

CohereGen

CohereGen(
    model_name: Optional[str] = None,
    temperature: Optional[float] = None,
    prompt_template: str = '',
    output_max_length: int = 500,
    device: str = 'auto',
    structured_output: Optional[Type[BaseModel]] = None,
    system_message: str = '',
    api_params: dict[str, Any] = DEFAULT_API_PARAMS,
    api_key: str = '',
    cache: Optional[Cache] = None,
    logs: dict[str, Any] = DEFAULT_LOGS,
)

Bases: GenerationBase

Cohere generation model for text generation.

Methods:

  • generate

    Generate text using Cohere's API.

Source code in src/rago/generation/base.py
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def __init__(
    self,
    model_name: Optional[str] = None,
    temperature: Optional[float] = None,
    prompt_template: str = '',
    output_max_length: int = 500,
    device: str = 'auto',
    structured_output: Optional[Type[BaseModel]] = None,
    system_message: str = '',
    api_params: dict[str, Any] = DEFAULT_API_PARAMS,
    api_key: str = '',
    cache: Optional[Cache] = None,
    logs: dict[str, Any] = DEFAULT_LOGS,
) -> None:
    """Initialize Generation class.

    Parameters
    ----------
    model_name : Optional[str]
        Backend model identifier; ``None`` selects ``self.default_model_name``.
    temperature : Optional[float]
        Sampling temperature; ``None`` selects ``self.default_temperature``.
    prompt_template : str
        Template formatted with the query/context; an empty string selects
        ``self.default_prompt_template``.
    output_max_length : int
        Maximum generation length; ``0`` selects
        ``self.default_output_max_length``.
    device : str
        One of ``'cpu'``, ``'cuda'`` or ``'auto'``.
    structured_output : Optional[Type[BaseModel]]
        Pydantic model type used to parse structured (JSON) responses.
    system_message : str
        Optional system prompt for chat-style requests.
    api_params : dict[str, Any]
        Extra keyword arguments forwarded to the backend API.
    api_key : str
        Credential passed to the base class.
    cache : Optional[Cache]
        Optional response cache passed to the base class.
    logs : dict[str, Any]
        Mutable dict the instance writes debug information into.

    Raises
    ------
    ValueError
        If ``device`` is not one of ``'cpu'``, ``'cuda'``, ``'auto'``.
    """
    # DEFAULT_LOGS / DEFAULT_API_PARAMS are module-level sentinel objects;
    # replace them with fresh objects so the shared defaults are never
    # mutated across instances (mutable-default-argument protection).
    if logs is DEFAULT_LOGS:
        logs = {}
    if api_params is DEFAULT_API_PARAMS:
        api_params = deepcopy(self.default_api_params or {})

    # Validate user input early, before any state is set up.
    # ValueError (not bare Exception) is the conventional type for a bad
    # argument value; callers catching Exception still catch it.
    if device not in ('cpu', 'cuda', 'auto'):
        raise ValueError(
            f'Device {device} not supported. Options: cpu, cuda, auto.'
        )

    super().__init__(api_key=api_key, cache=cache, logs=logs)

    self.model_name: str = (
        model_name if model_name is not None else self.default_model_name
    )
    # NOTE: `or` also maps 0/'' to the class default — a zero length or
    # empty template is treated as "unset".
    self.output_max_length: int = (
        output_max_length or self.default_output_max_length
    )
    self.temperature: float = (
        temperature
        if temperature is not None
        else self.default_temperature
    )
    self.prompt_template: str = (
        prompt_template or self.default_prompt_template
    )
    self.structured_output: Optional[Type[BaseModel]] = structured_output
    self.system_message = system_message
    self.api_params = api_params

    # Fall back to CPU whenever CUDA is unavailable, even for
    # device='cuda'/'auto'.
    cuda_available = torch.cuda.is_available()
    self.device_name: str = (
        'cpu' if device == 'cpu' or not cuda_available else 'cuda'
    )
    self.device = torch.device(self.device_name)

    self._validate()
    self._setup()

generate

generate(query: str, context: list[str]) -> str | BaseModel

Generate text using Cohere's API.

Source code in src/rago/generation/cohere.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def generate(self, query: str, context: list[str]) -> str | BaseModel:
    """Generate text using Cohere's API.

    Parameters
    ----------
    query : str
        User question substituted into the prompt template.
    context : list[str]
        Retrieved passages, space-joined into the template's ``context``.

    Returns
    -------
    str | BaseModel
        Plain generated text, or an instance of ``self.structured_output``
        parsed from the model's JSON answer when structured output is set.
    """
    input_text = self.prompt_template.format(
        query=query, context=' '.join(context)
    )
    # Guard with a final `or {}`: if both the instance and class defaults
    # are unset/empty, `**api_params` below must not see None.
    api_params = self.api_params or self.default_api_params or {}

    if self.structured_output:
        messages = []
        # Explicit instruction to generate JSON output.
        system_instruction = (
            'Generate a JSON object that strictly follows the provided  '
            'JSON schema. Do not include any additional text.'
        )
        if self.system_message:
            system_instruction += ' ' + self.system_message
        messages.append({'role': 'system', 'content': system_instruction})
        messages.append({'role': 'user', 'content': input_text})

        response_format_config = {
            'type': 'json_object',
            'json_schema': (
                self.structured_output
                if isinstance(self.structured_output, dict)
                else self.structured_output.model_json_schema()
            ),
        }
        model_params = {
            'messages': messages,
            'max_tokens': self.output_max_length,
            'temperature': self.temperature,
            'model': self.model_name,
            'response_format': response_format_config,
            **api_params,
        }
        # Log before the request so a failed API call still leaves a
        # record of the parameters that were used.
        self.logs['model_params'] = model_params

        response = self.model.client.chat(**model_params)
        json_text = response.message.content[0].text
        parsed_dict = json.loads(json_text)
        parsed_model = self.structured_output(**parsed_dict)
        return parsed_model

    if self.system_message:
        messages = [
            {'role': 'system', 'content': self.system_message},
            {'role': 'user', 'content': input_text},
        ]
        model_params = {
            'model': self.model_name,
            'messages': messages,
            'max_tokens': self.output_max_length,
            'temperature': self.temperature,
            **api_params,
        }
        self.logs['model_params'] = model_params
        response = self.model.chat(**model_params)
        return cast(str, response.text)

    # Plain completion path: no structured output, no system message.
    model_params = {
        'model': self.model_name,
        'prompt': input_text,
        'max_tokens': self.output_max_length,
        'temperature': self.temperature,
        **api_params,
    }
    self.logs['model_params'] = model_params
    response = self.model.generate(**model_params)
    return cast(str, response.generations[0].text.strip())