{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100",
      "runtime_attributes": {
        "runtime_version": "2025.10"
      }
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "9f8a4e1d5ed047758de6806a46d317c5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_84b9681f21764520a0f71126010becd1",
              "IPY_MODEL_30a2155ac6264823a8eb3a8a08e89e11",
              "IPY_MODEL_b64d2052def94c68be971a4043ea02ab"
            ],
            "layout": "IPY_MODEL_4c7970cd121544f7be9c20372378fabc"
          }
        },
        "84b9681f21764520a0f71126010becd1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_bb1dd3a7e71c478b8421174e95755f70",
            "placeholder": "​",
            "style": "IPY_MODEL_f778aca055f7498d97ec2ec0b8d6d61d",
            "value": "Loading checkpoint shards: 100%"
          }
        },
        "30a2155ac6264823a8eb3a8a08e89e11": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_158cf261647045fdaaea4f3d970008f8",
            "max": 2,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_d9b2f527afe84751b9af41d2abf99626",
            "value": 2
          }
        },
        "b64d2052def94c68be971a4043ea02ab": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_7d9109dcd4be4d19a819a5deff56c578",
            "placeholder": "​",
            "style": "IPY_MODEL_b6bf24cd3fa04c5d9d4acb3e582aa2b4",
            "value": " 2/2 [00:17&lt;00:00,  8.08s/it]"
          }
        },
        "4c7970cd121544f7be9c20372378fabc": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "bb1dd3a7e71c478b8421174e95755f70": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "f778aca055f7498d97ec2ec0b8d6d61d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "158cf261647045fdaaea4f3d970008f8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "d9b2f527afe84751b9af41d2abf99626": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "7d9109dcd4be4d19a819a5deff56c578": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b6bf24cd3fa04c5d9d4acb3e582aa2b4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "61599da2ddce4d43842aa6d563c8628e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_9c12da7dcd1647ddb779e3a652bd5038",
              "IPY_MODEL_bc36e6828e5145799fd6a899defaea31",
              "IPY_MODEL_88de4cee59ab46078abe071ce371dac7"
            ],
            "layout": "IPY_MODEL_db9f8b36313d43a399209e7e99c65424"
          }
        },
        "9c12da7dcd1647ddb779e3a652bd5038": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_8b2c912d4b0a47f9852afaf0a4bbfdb0",
            "placeholder": "​",
            "style": "IPY_MODEL_0825a3792ae246f4b7875cd0fbb48a60",
            "value": "Loading checkpoint shards: 100%"
          }
        },
        "bc36e6828e5145799fd6a899defaea31": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_75445a8605b74269b4840bc042118480",
            "max": 2,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_3e1721c2a2f24006b3a03410250c4806",
            "value": 2
          }
        },
        "88de4cee59ab46078abe071ce371dac7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_10128815685543f6a78b98b591620f70",
            "placeholder": "​",
            "style": "IPY_MODEL_db42c330d8f64771a31a77dac7f04a11",
            "value": " 2/2 [00:17&lt;00:00,  8.05s/it]"
          }
        },
        "db9f8b36313d43a399209e7e99c65424": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8b2c912d4b0a47f9852afaf0a4bbfdb0": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0825a3792ae246f4b7875cd0fbb48a60": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "75445a8605b74269b4840bc042118480": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "3e1721c2a2f24006b3a03410250c4806": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "10128815685543f6a78b98b591620f70": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "db42c330d8f64771a31a77dac7f04a11": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "## Supervised fine-tuning (SFT) of an LLM\n",
        "\n",
        "Recall that creating a \"ChatGPT\" at home involves 3 steps:\n",
        "\n",
        "1. pre-training a large language model (LLM) to predict the next token on internet-scale data, on clusters of thousands of GPUs. One calls the result a \"base model\"\n",
        "2. supervised fine-tuning (SFT) to turn the base model into a useful assistant\n",
        "3. human preference fine-tuning which increases the assistant's friendliness, helpfulness and safety.\n",
        "\n",
        "In this notebook, we're going to illustrate step 2. This involves supervised fine-tuning (SFT for short), also called instruction tuning.\n",
        "\n",
        "Supervised fine-tuning takes in a \"base model\" from step 1, i.e. a model that has been pre-trained on predicting the next token on internet text, and turns it into a \"chatbot\"/\"assistant\". This is done by fine-tuning the model on human instruction data, using the cross-entropy loss. This means that the model is still trained to predict the next token, although we now want the model to generate useful completions given an instruction like \"what are 10 things to do in London?\", \"How can I make pancakes?\" or \"Write me a poem about elephants\".\n",
        "\n",
        "To do this, one requires human annotators to collect useful completions, on which we can train the model. OpenAI for instance [hired human contractors for this](https://gizmodo.com/chatgpt-openai-ai-contractors-15-dollars-per-hour-1850415474), which were asked to generate useful completions given instructions, like \"In London, you can visit the Big Ben and (...)\". A nice collection of openly available SFT datasets can be found [here](https://huggingface.co/collections/HuggingFaceH4/awesome-sft-datasets-65788b571bf8e371c4e4241a).\n",
        "\n",
        "This way, the model becomes more useful: rather than simply predicting the next token (which might give undesirable outputs, like generating follow-up questions rather than answering the question), we now make it more likely that the model will output useful completions for any instruction we give it. We basically steer it in the direction of generating useful completions which a human could have written given any instruction.\n",
        "\n",
        "Notes:\n",
        "\n",
        "* the entire notebook is based on and can be seen as an annotated version of the [Alignment Handbook](https://github.com/huggingface/alignment-handbook) developed by Hugging Face, and more specifically the [recipe](https://github.com/huggingface/alignment-handbook/blob/main/recipes/zephyr-7b-beta/sft/config_qlora.yaml) used to train Zephyr-7b-beta. Huge kudos to the team for creating this!\n",
        "* this notebook applies to any decoder-only LLM available in the Transformers library. In this notebook, we are going to fine-tune the [Mistral-7B base model](https://huggingface.co/mistralai/Mistral-7B-v0.1), which is one of the best open-source large language models at the time of writing."
      ],
      "metadata": {
        "id": "_oPAJl2uDmil"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Required hardware\n",
        "\n",
        "The notebook is designed to be run on any NVIDIA GPU which has the [Ampere architecture](https://en.wikipedia.org/wiki/Ampere_(microarchitecture)) or later with at least 24GB of RAM. This includes:\n",
        "\n",
        "* NVIDIA RTX 3090, 4090\n",
        "* NVIDIA A100, H100, H200\n",
        "\n",
        "and so on.\n",
        "\n",
        "The reason for an Ampere requirement is because we're going to use the [bfloat16 (bf16) format](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format), which is not supported on older architectures like Turing.\n",
        "\n",
        "But: a few tweaks can be made to train the model in float16 (fp16), which is supported by older GPUs like:\n",
        "\n",
        "* NVIDIA RTX 2080\n",
        "* NVIDIA Tesla T4\n",
        "* NVIDIA V100.\n",
        "\n",
        "Comments are added regarding where to swap bf16 with fp16.\n",
        "\n",
        "## Set-up environment\n",
        "\n",
        "Let's start by installing all the 🤗 goodies we need to do supervised fine-tuning. We're going to use\n",
        "\n",
        "* Transformers for the LLM which we're going to fine-tune\n",
        "* Datasets for loading a SFT dataset from the 🤗 hub, and preparing it for the model\n",
        "* BitsandBytes and PEFT for fine-tuning the model on consumer hardware, leveraging [Q-LoRa](https://huggingface.co/blog/4bit-transformers-bitsandbytes), a technique which drastically reduces the compute requirements for fine-tuning\n",
        "* TRL, a [library](https://huggingface.co/docs/trl/index) which includes useful Trainer classes for LLM fine-tuning."
      ],
      "metadata": {
        "id": "Bjdw05Rk5fYm"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!nvidia-smi"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "y1h2HzwZxlSI",
        "outputId": "cdcaf3fd-9a1b-413f-cf27-d7563be4c851"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mon Apr 13 03:51:09 2026       \n",
            "+-----------------------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 580.82.07              Driver Version: 580.82.07      CUDA Version: 13.0     |\n",
            "+-----------------------------------------+------------------------+----------------------+\n",
            "| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |\n",
            "|                                         |                        |               MIG M. |\n",
            "|=========================================+========================+======================|\n",
            "|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |\n",
            "| N/A   32C    P0             54W /  400W |       0MiB /  81920MiB |      0%      Default |\n",
            "|                                         |                        |             Disabled |\n",
            "+-----------------------------------------+------------------------+----------------------+\n",
            "\n",
            "+-----------------------------------------------------------------------------------------+\n",
            "| Processes:                                                                              |\n",
            "|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |\n",
            "|        ID   ID                                                               Usage      |\n",
            "|=========================================================================================|\n",
            "|  No running processes found                                                             |\n",
            "+-----------------------------------------------------------------------------------------+\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install -q transformers[torch] datasets"
      ],
      "metadata": {
        "id": "-9357hGFRqdi"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install -q bitsandbytes trl peft"
      ],
      "metadata": {
        "id": "rxpq51gW_ySc"
      },
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "We also install [Flash Attention](https://github.com/Dao-AILab/flash-attention), which speeds up the attention computations of the model."
      ],
      "metadata": {
        "id": "0RgDwMbXpC4L"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install flash-attn --no-build-isolation"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "iNJvtR1wGxHm",
        "outputId": "ddb301b5-97b5-4b8a-970f-037ac3b93622"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting flash-attn\n",
            "  Downloading flash_attn-2.8.3.tar.gz (8.4 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.4/8.4 MB\u001b[0m \u001b[31m102.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Requirement already satisfied: torch in /usr/local/lib/python3.12/dist-packages (from flash-attn) (2.8.0+cu126)\n",
            "Requirement already satisfied: einops in /usr/local/lib/python3.12/dist-packages (from flash-attn) (0.8.1)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (3.20.0)\n",
            "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (4.15.0)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (75.2.0)\n",
            "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (1.13.3)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (3.5)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (3.1.6)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (2025.3.0)\n",
            "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (12.6.77)\n",
            "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (12.6.77)\n",
            "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (12.6.80)\n",
            "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (9.10.2.21)\n",
            "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (12.6.4.1)\n",
            "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (11.3.0.4)\n",
            "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (10.3.7.77)\n",
            "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (11.7.1.2)\n",
            "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (12.5.4.2)\n",
            "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (0.7.1)\n",
            "Requirement already satisfied: nvidia-nccl-cu12==2.27.3 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (2.27.3)\n",
            "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (12.6.77)\n",
            "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (12.6.85)\n",
            "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (1.11.1.6)\n",
            "Requirement already satisfied: triton==3.4.0 in /usr/local/lib/python3.12/dist-packages (from torch->flash-attn) (3.4.0)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch->flash-attn) (1.3.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch->flash-attn) (3.0.3)\n",
            "Building wheels for collected packages: flash-attn\n",
            "  Building wheel for flash-attn (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for flash-attn: filename=flash_attn-2.8.3-cp312-cp312-linux_x86_64.whl size=256040057 sha256=f25da18657a87fc83dc1bfb8b7751b82246e9db355510226b674fd437c34b5fb\n",
            "  Stored in directory: /root/.cache/pip/wheels/3d/59/46/f282c12c73dd4bb3c2e3fe199f1a0d0f8cec06df0cccfeee27\n",
            "Successfully built flash-attn\n",
            "Installing collected packages: flash-attn\n",
            "Successfully installed flash-attn-2.8.3\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Load dataset\n",
        "\n",
        "Note: the alignment handbook supports mixing several datasets, each with a certain portion of training examples. However, the Zephyr recipe only includes a single dataset, which is the [UltraChat200k dataset](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)."
      ],
      "metadata": {
        "id": "hJIqlyI15kj8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "# based on config\n",
        "raw_datasets = load_dataset(\"HuggingFaceH4/ultrachat_200k\")"
      ],
      "metadata": {
        "id": "YhYvRDF25j59"
      },
      "execution_count": 23,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "The dataset contains various splits, each with a certain number of rows. In our case, as we're going to do supervised fine-tuning (SFT), only the \"train_sft\" and \"test_sft\" splits are relevant for us."
      ],
      "metadata": {
        "id": "WFYvJUfODabt"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import DatasetDict\n",
        "\n",
        "# remove this when done debugging\n",
        "indices = range(0,100)\n",
        "\n",
        "dataset_dict = {\"train\": raw_datasets[\"train_sft\"].select(indices),\n",
        "                \"test\": raw_datasets[\"test_sft\"].select(indices)}\n",
        "\n",
        "raw_datasets = DatasetDict(dataset_dict)\n",
        "raw_datasets"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jX1sIK6X6Opi",
        "outputId": "f7a100fb-25e0-4a54-9d07-df0117cedcb5"
      },
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "DatasetDict({\n",
              "    train: Dataset({\n",
              "        features: ['prompt', 'prompt_id', 'messages'],\n",
              "        num_rows: 100\n",
              "    })\n",
              "    test: Dataset({\n",
              "        features: ['prompt', 'prompt_id', 'messages'],\n",
              "        num_rows: 100\n",
              "    })\n",
              "})"
            ]
          },
          "metadata": {},
          "execution_count": 24
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's check one example. The important thing is that each example should contain a list of messages:"
      ],
      "metadata": {
        "id": "sByKH4hd8mUm"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "example = raw_datasets[\"train\"][0]\n",
        "print(example.keys())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "P2XwVpT17TAb",
        "outputId": "e94b36a8-5485-42a7-d823-01a1c323273f"
      },
      "execution_count": 25,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "dict_keys(['prompt', 'prompt_id', 'messages'])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Each message is a dictionary containing 2 keys, namely:\n",
        "\n",
        "* \"role\": specifies who the creator of the message is (could be \"system\", \"assistant\" or \"user\" - the latter referring to a human).\n",
        "* \"content\": the actual content of the message.\n",
        "\n",
        "Let's print out the sequence of messages for this training example:"
      ],
      "metadata": {
        "id": "HaWnT4uBCjTy"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "messages = example[\"messages\"]\n",
        "for message in messages:\n",
        "  role = message[\"role\"]\n",
        "  content = message[\"content\"]\n",
        "  print('{0:20}:  {1}'.format(role, content))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IQ1sMda27Zj6",
        "outputId": "1efe1596-a8a2-4774-fb44-3925939e9b5e"
      },
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "user                :  These instructions apply to section-based themes (Responsive 6.0+, Retina 4.0+, Parallax 3.0+ Turbo 2.0+, Mobilia 5.0+). What theme version am I using?\n",
            "On your Collections pages & Featured Collections sections, you can easily show the secondary image of a product on hover by enabling one of the theme's built-in settings!\n",
            "Your Collection pages & Featured Collections sections will now display the secondary product image just by hovering over that product image thumbnail.\n",
            "Does this feature apply to all sections of the theme or just specific ones as listed in the text material?\n",
            "assistant           :  This feature only applies to Collection pages and Featured Collections sections of the section-based themes listed in the text material.\n",
            "user                :  Can you guide me through the process of enabling the secondary image hover feature on my Collection pages and Featured Collections sections?\n",
            "assistant           :  Sure, here are the steps to enable the secondary image hover feature on your Collection pages and Featured Collections sections:\n",
            "\n",
            "1. Log in to your Shopify account and go to your Online Store.\n",
            "2. Click on Customize theme for the section-based theme you are using.\n",
            "3. Navigate to the Collection pages or Featured Collections section where you want to enable the secondary image hover feature.\n",
            "4. With the section open, click on the section's setting (gear) icon in the top-left corner.\n",
            "5. In the settings panel that appears, look for an option labeled 'Image display' or 'Image hover'.\n",
            "6. If available, select 'Show secondary image on hover'.\n",
            "7. Save the changes and preview the Collection/Featured Collection page to see the effect.\n",
            "\n",
            "If you're having trouble finding the setting, the best thing to do is refer to your theme's documentation, since the location and labeling of settings can vary between themes.\n",
            "user                :  Can you provide me with a link to the documentation for my theme?\n",
            "assistant           :  I don't have access to your store's theme information. However, you can usually find the documentation for your theme by going to the shopify theme store, finding your theme and clicking on the 'support' link located in the bottom right corner of the page. Alternatively, you can do a google search for the name of your theme followed by 'documentation' or 'user guide'.\n",
            "user                :  Can you confirm if this feature also works for the Quick Shop section of my theme?\n",
            "assistant           :  The secondary image hover feature may or may not work for your Quick Shop section, depending on the configuration of your theme. Some themes include this feature in the Quick Shop section by default, while others may require additional customization. To check if this feature is available for the Quick Shop section of your theme, follow these steps:\n",
            "\n",
            "1. Go to the Quick Shop section where you would like to enable the feature. 2. Click on the Quick Shop settings icon (gear icon) and look for 'Image display' or 'Image hover'. 3. If available, select 'Show secondary image on hover'. 4. Save the changes. If this option is not available in your Quick Shop section settings, you may need to reach out to your theme developer for assistance with customizing your Quick Shop section to include this feature.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "In this case, it looks like the instructions are about enabling certain features in Shopify. Interesting!"
      ],
      "metadata": {
        "id": "0DL8S3dkT2Av"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Load tokenizer\n",
        "\n",
        "Next, we instantiate the tokenizer, which is required to prepare the text for the model. The model doesn't directly take strings as input, but rather `input_ids`, which represent integer indices in the vocabulary of a Transformer model. Refer to my [YouTube video](https://www.youtube.com/watch?v=IGu7ivuy1Ag&ab_channel=NielsRogge) if you want to know more about it.\n",
        "\n",
        "We also set some attributes which the tokenizer of a base model typically doesn't have set, such as:\n",
        "\n",
        "- the padding token ID. During pre-training, one doesn't need to pad since one just creates blocks of text to predict the next token, but during fine-tuning, we will need to pad the (instruction, completion) pairs in order to create batches of equal length.\n",
        "- the model max length: this is required in order to truncate sequences which are too long for the model. Here we decide to train on at most 2048 tokens.\n",
        "- the chat template. A [chat template](https://huggingface.co/blog/chat-templates) determines how each list of messages is turned into a tokenizable string, by adding special strings in between such as `<|user|>` to indicate a user message and `<|assistant|>` to indicate the chatbot's response. Here we define the default chat template, used by most chat models. See also the [docs](https://huggingface.co/docs/transformers/main/en/chat_templating)."
      ],
      "metadata": {
        "id": "RVRxCJQ76spF"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Add the huggingface login here\n",
        "\n",
        "!huggingface-cli login"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "iO65DIW5yqQf",
        "outputId": "0e799d9f-a6f7-4cf6-a3ad-b79e39a4d330"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[33m⚠️  Warning: 'huggingface-cli login' is deprecated. Use 'hf auth login' instead.\u001b[0m\n",
            "\n",
            "    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|\n",
            "    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|\n",
            "    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|\n",
            "    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|\n",
            "    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|\n",
            "\n",
            "    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .\n",
            "Enter your token (input will not be visible): \n",
            "Add token as git credential? (Y/n) Y\n",
            "Token is valid (permission: write).\n",
            "The token `homo-lat` has been saved to /root/.cache/huggingface/stored_tokens\n",
            "\u001b[1m\u001b[31mCannot authenticate through git-credential as no helper is defined on your machine.\n",
            "You might have to re-authenticate when pushing to the Hugging Face Hub.\n",
            "Run the following command in your terminal in case you want to set the 'store' credential helper as default.\n",
            "\n",
            "git config --global credential.helper store\n",
            "\n",
            "Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details.\u001b[0m\n",
            "Token has not been saved to git credential helper.\n",
            "Your token has been saved to /root/.cache/huggingface/token\n",
            "Login successful.\n",
            "The current active token is: `homo-lat`\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoTokenizer\n",
        "\n",
        "model_id = \"mistralai/Mistral-7B-v0.1\"\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
        "\n",
        "# set pad_token_id equal to the eos_token_id if not set\n",
        "if tokenizer.pad_token_id is None:\n",
        "  tokenizer.pad_token_id = tokenizer.eos_token_id\n",
        "\n",
        "# Set reasonable default for models without max length\n",
        "if tokenizer.model_max_length > 100_000:\n",
        "  tokenizer.model_max_length = 2048\n",
        "\n",
        "# Set chat template\n",
        "DEFAULT_CHAT_TEMPLATE = \"{% for message in messages %}\\n{% if message['role'] == 'user' %}\\n{{ '<|user|>\\n' + message['content'] + eos_token }}\\n{% elif message['role'] == 'system' %}\\n{{ '<|system|>\\n' + message['content'] + eos_token }}\\n{% elif message['role'] == 'assistant' %}\\n{{ '<|assistant|>\\n'  + message['content'] + eos_token }}\\n{% endif %}\\n{% if loop.last and add_generation_prompt %}\\n{{ '<|assistant|>' }}\\n{% endif %}\\n{% endfor %}\"\n",
        "tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE"
      ],
      "metadata": {
        "id": "VvIfUqjK6ntu",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "be9f1f1d-58ca-4a0e-b3e5-8a746670dae1"
      },
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "loading file tokenizer.model from cache at /root/.cache/huggingface/hub/models--mistralai--Mistral-7B-v0.1/snapshots/27d67f1b5f57dc0953326b2601d68371d40ea8da/tokenizer.model\n",
            "loading file tokenizer.json from cache at /root/.cache/huggingface/hub/models--mistralai--Mistral-7B-v0.1/snapshots/27d67f1b5f57dc0953326b2601d68371d40ea8da/tokenizer.json\n",
            "loading file added_tokens.json from cache at None\n",
            "loading file special_tokens_map.json from cache at /root/.cache/huggingface/hub/models--mistralai--Mistral-7B-v0.1/snapshots/27d67f1b5f57dc0953326b2601d68371d40ea8da/special_tokens_map.json\n",
            "loading file tokenizer_config.json from cache at /root/.cache/huggingface/hub/models--mistralai--Mistral-7B-v0.1/snapshots/27d67f1b5f57dc0953326b2601d68371d40ea8da/tokenizer_config.json\n",
            "loading file chat_template.jinja from cache at None\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Apply chat template\n",
        "\n",
        "Once we have equipped the tokenizer with the appropriate attributes, it's time to apply the chat template to each list of messages. Here we basically turn each list of (instruction, completion) messages into a tokenizable string for the model.\n",
        "\n",
        "Note that we specify `tokenize=False` here, since the `SFTTrainer` which we'll define later on will perform the tokenization internally. Here we only turn the list of messages into strings with the same format."
      ],
      "metadata": {
        "id": "3UNHQsTJ7O6I"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import re\n",
        "import random\n",
        "from multiprocessing import cpu_count\n",
        "\n",
        "def apply_chat_template(example, tokenizer):\n",
        "    messages = example[\"messages\"]\n",
        "    # We add an empty system message if there is none\n",
        "    if messages[0][\"role\"] != \"system\":\n",
        "        messages.insert(0, {\"role\": \"system\", \"content\": \"\"})\n",
        "    example[\"text\"] = tokenizer.apply_chat_template(messages, tokenize=False)\n",
        "\n",
        "    return example\n",
        "\n",
        "column_names = list(raw_datasets[\"train\"].features)\n",
        "raw_datasets = raw_datasets.map(apply_chat_template,\n",
        "                                num_proc=cpu_count(),\n",
        "                                fn_kwargs={\"tokenizer\": tokenizer},\n",
        "                                remove_columns=column_names,\n",
        "                                desc=\"Applying chat template\",)\n",
        "\n",
        "# create the splits\n",
        "train_dataset = raw_datasets[\"train\"]\n",
        "eval_dataset = raw_datasets[\"test\"]\n",
        "\n",
        "for index in random.sample(range(len(raw_datasets[\"train\"])), 3):\n",
        "  print(f\"Sample {index} of the processed training set:\\n\\n{raw_datasets['train'][index]['text']}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "44kIpOXa7Ep4",
        "outputId": "8596910b-b70d-4eae-fd0d-c4938f4c05fe"
      },
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Sample 81 of the processed training set:\n",
            "\n",
            "<|system|>\n",
            "</s>\n",
            "<|user|>\n",
            "Write a JavaScript function that takes in one parameter and checks if that parameter is a perfect square. The function should return a boolean value - true if the parameter is a perfect square and false otherwise. You should test your function using a range of inputs and provide the output for each test case.</s>\n",
            "<|assistant|>\n",
            "```\n",
            "function isPerfectSquare(num) {\n",
            "  if (num < 0) {\n",
            "    return false;\n",
            "  }\n",
            "  else if (Math.sqrt(num) === Math.floor(Math.sqrt(num))) {\n",
            "    return true;\n",
            "  }\n",
            "  else {\n",
            "    return false;\n",
            "  }\n",
            "}\n",
            "\n",
            "console.log(isPerfectSquare(16)); // true\n",
            "console.log(isPerfectSquare(25)); // true\n",
            "console.log(isPerfectSquare(144)); // true\n",
            "console.log(isPerfectSquare(7)); // false\n",
            "console.log(isPerfectSquare(-25)); // false\n",
            "```</s>\n",
            "<|user|>\n",
            "Great, the function looks good! Can you add some comments to explain what each line of code is doing?</s>\n",
            "<|assistant|>\n",
            "Sure, here is the updated function with comments:\n",
            "\n",
            "```\n",
            "function isPerfectSquare(num) {\n",
            "  if (num < 0) { // check if the number is negative\n",
            "    return false;\n",
            "  }\n",
            "  else if (Math.sqrt(num) === Math.floor(Math.sqrt(num))) { // calculate the square root of num and check if it is a whole number\n",
            "    return true; // if it is a whole number, return true\n",
            "  }\n",
            "  else {\n",
            "    return false; // otherwise, return false\n",
            "  }\n",
            "}\n",
            "\n",
            "// testing the function with different inputs\n",
            "console.log(isPerfectSquare(16)); // true\n",
            "console.log(isPerfectSquare(25)); // true\n",
            "console.log(isPerfectSquare(144)); // true\n",
            "console.log(isPerfectSquare(7)); // false\n",
            "console.log(isPerfectSquare(-25)); // false\n",
            "```\n",
            "\n",
            "I hope this helps!</s>\n",
            "\n",
            "Sample 14 of the processed training set:\n",
            "\n",
            "<|system|>\n",
            "</s>\n",
            "<|user|>\n",
            "Write step-by-step instructions for making a traditional snowman using three large snowballs, including information on how to properly pack the snow, tools (if applicable), and recommended size ratios for each snowball. Be sure to include any additional details on decorations or accessories, if desired.</s>\n",
            "<|assistant|>\n",
            "Step 1: Find a suitable location for building the snowman. Look for a large, flat area with plenty of snow.\n",
            "\n",
            "Step 2: Start by packing the first snowball. Roll a large ball of snow, about shoulder height in size. Roll the snowball around in the snow, covering it thoroughly, and packing the snow tightly as you go. Keep rolling until the snowball is smooth and even in shape. This will form the base of your snowman.\n",
            "\n",
            "Step 3: Repeat Step 2 to make the second and third snowballs. The second snowball should be slightly smaller than the first one (about waist height), and the third snowball should be even smaller (about knee height). Make sure each snowball is packed tightly and covered evenly in snow.\n",
            "\n",
            "Step 4: Once you have all three snowballs ready, lift the second snowball onto the first one. Use your hands to press the two snowballs together until they stick together firmly. You can use a shovel to lift the snowball if it's too heavy to lift by hand.\n",
            "\n",
            "Step 5: Next, press the third snowball on top of the second one. Make sure it is centered and sits evenly on top of the other two snowballs. Press it down firmly onto the second snowball to keep it in place.\n",
            "\n",
            "Step 6: Now it's time to add character to your snowman. Use twigs or branches for its arms, and add a hat and scarf for a cozy look. You can use buttons or rocks for the eyes, and a carrot for its nose. Get creative and use whatever materials you have available.\n",
            "\n",
            "Step 7: Finally, step back and admire your creation! You've successfully built a traditional snowman using three large snowballs.</s>\n",
            "<|user|>\n",
            "It looks easy enough, but can you provide more details on how to pack the snow properly? Do I need to use my hands, or can I use a tool?</s>\n",
            "<|assistant|>\n",
            "Packing the snow properly is crucial for building a sturdy snowman. You can use your hands or a tool, such as a shovel or a wooden block, to pack the snow tightly. Here are some tips for packing the snow:\n",
            "\n",
            "- Use dense, compacted snow for building the snowballs. Avoid using dry, powdery snow, as it won't stick together well. - Start by rolling a small snowball and then gradually add more snow as you roll. This will help ensure that the snow is packed evenly and tightly. - Use your hands or a tool to press the snow together firmly. Make sure there are no air pockets or gaps in the snowball, as this can cause it to collapse. - Roll the snowball in all directions to create a round and smooth shape. If the snowball becomes lopsided, gently push it back into shape. - As you roll the snowball, add more snow occasionally to ensure that it remains larger than the previous snowball. This will help create the classic snowman shape with a wider base and a narrower top. Remember, building a snowman is all about having fun, so don't worry too much about getting everything perfect. Just enjoy the process and let your creativity guide you!</s>\n",
            "<|user|>\n",
            "Can you also give me some ideas on how to decorate the snowman with materials other than branches and carrots? Something more unusual perhaps?</s>\n",
            "<|assistant|>\n",
            "Absolutely! There are many creative and unusual ways to decorate a snowman. Here are some ideas to inspire you:\n",
            "\n",
            "1. Use old clothes: Rather than just settling for a hat and scarf, you could dress up your snowman in an old coat, pants, or even a tie. Be sure to use clothes that aren't needed anymore and wash them before using to prevent staining the snow. 2. Add some color: Mix food coloring with water and fill spray bottles. You could spray your snowman with different colors. This will give it an eye-catching look as it sticks to the snow. 3. Use household items: You can create a unique snowman by utilizing household items such as CDs, paper towel tubes, or even plastic lids. Using a sturdy adhesive like duct tape, hot glue, or super glue, you could add these items to create a face, eyes, or nose. 4. Get creative with eyes and nose: Instead of using traditional items like coal and carrot for eyes and nose, get creative. Use jam jar lids or orange slices to create eyes, or pine cones for the nose. 5. Let your imagination run wild: The sky is the limit when decorating a snowman. You could add any other accessory that you like, such as hats, sunglasses or without any accessories, too. Try something new and exciting that will make your snowman stand out. Remember, decorating the snowman is where you have to let your imagination run wild. The idea is to get creative and have fun!</s>\n",
            "\n",
            "Sample 3 of the processed training set:\n",
            "\n",
            "<|system|>\n",
            "</s>\n",
            "<|user|>\n",
            "De León, previewing the speech he will give today, said he will highlight his Senate Bill 535, which directs a quarter of the proceeds from the Greenhouse Gas Reduction Fund to projects that benefit disadvantaged communities.\n",
            "On Thursday, de León nodded approvingly as a string of leading scientists and religious leaders gathered for hours of weedy policy discussions on the impacts of climate change, including gloomy predictions on mortality attributable to air pollution.\n",
            "SEIU HEADS TO THE BAR: Employees of the State Bar of California represented by SEIU are planning a picket line today at the bar building in Los Angeles to protest the latest contract offer. What is the reason for SEIU employees planning a picket line at the State Bar of California building in Los Angeles?</s>\n",
            "<|assistant|>\n",
            "The reason for SEIU employees planning a picket line at the State Bar of California building in Los Angeles is to protest the latest contract offer.</s>\n",
            "<|user|>\n",
            "Can you provide more details about Senate Bill 535 and how it will benefit disadvantaged communities?</s>\n",
            "<|assistant|>\n",
            "Sure, Senate Bill 535 is a California law that directs at least 25% of the money generated through the state's Greenhouse Gas Reduction Fund towards projects that benefit disadvantaged communities. The bill is designed to ensure that the benefits of California's climate policies, which include cap-and-trade, are felt most heavily by communities that have been disproportionately impacted by pollution and environmental degradation.\n",
            "\n",
            "The funds generated through the bill are invested in projects that promote clean energy, reduce pollution, and support sustainable communities. This includes initiatives like public transportation, energy efficiency retrofits for low-income homes, and urban forestry projects.\n",
            "\n",
            "Proponents of the bill argue that it helps to level the playing field for low-income communities and communities of color, which have historically faced greater exposure to environmental hazards and pollution. By directing funds towards these communities, the bill aims to create more equitable and sustainable outcomes for all Californians.</s>\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "We also specified `remove_columns` to the map function above, meaning that we are now left with only 1 column: \"text\".\n",
        "\n",
        "Hence the set-up is now very similar to pre-training: we will just train the model predict the next token, given the previous ones. In this case, the model will learn to generate completions given instructions.\n",
        "\n",
        "Hence, similar to pre-training, the labels will be created automatically based on the inputs (by shifting them one position to the right). The model is still trained using cross-entropy. This means that evaluation will mostly be done by checking perplexity/validation loss/model generations."
      ],
      "metadata": {
        "id": "Ro7VPkBLS8XG"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "raw_datasets"
      ],
      "metadata": {
        "id": "L_tvjW-Y-uBT",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "4dbf7d98-a45c-4e64-f36c-5164eabca3a2"
      },
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "DatasetDict({\n",
              "    train: Dataset({\n",
              "        features: ['text'],\n",
              "        num_rows: 100\n",
              "    })\n",
              "    test: Dataset({\n",
              "        features: ['text'],\n",
              "        num_rows: 100\n",
              "    })\n",
              "})"
            ]
          },
          "metadata": {},
          "execution_count": 29
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Define model arguments\n",
        "\n",
        "Next, it's time to define the model arguments.\n",
        "\n",
        "Here, some explanation is required regarding ways to fine-tune model.\n",
        "\n",
        "### Full fine-tuning\n",
        "\n",
        "Typically, one performs \"full fine-tuning\": this means that we will simply update all the weights of the base model during fine-tuning. This is then typically done either in full precision (float32), or mixed precision (a combination of float32 and float16). However, with ever larger models like LLMs, this becomes infeasible.\n",
        "\n",
        "For reference, float32 means that each parameter of a model gets saved in 32 bits or 4 bytes. Hence, for a 7 billion parameter model like Mistral-7B, one requires 7 billion parameters \\* 4 bytes per parameter = 28 GB of GPU RAM, just to load the model. During training with an optimizer like AdamW, one not only requires memory for the model but also for the gradients and optimizer states, which roughly comes down to approximately 18 times the size of the model in gigabytes when training with mixed precision, in this case 7 * 18 = 126 GB of GPU RAM. And that's just for a 7B parameter model! See the guide for more info: https://huggingface.co/docs/transformers/v4.20.1/en/perf_train_gpu_one.\n",
        "\n",
        "### LoRa, a PEFT method\n",
        "\n",
        "Hence, some clever people at Microsoft have come up with a method called [LoRa](https://huggingface.co/docs/peft/main/en/conceptual_guides/lora) (low-rank adaptation). The idea here is that, rather than performing full fine-tuning, we are going to freeze the existing model and only add a few parameter weights to the model (called \"adapters\"), which we're going to train.\n",
        "\n",
        "LoRa is what we call a parameter-efficient fine-tuning (PEFT) method. It's a popular method for fine-tuning models in a parameter-efficient way, by only training a few adapters, keeping the existing model untouched. LoRa is available in the [PEFT library](https://huggingface.co/docs/peft/v0.7.1/en/index) by Hugging Face, which also supports various other PEFT methods (but LoRa is the most popular one at the time of writing).\n",
        "\n",
        "### QLoRa, an even more efficient method\n",
        "\n",
        "With regular LoRa, one would keep the base model in 32 or 16 bits in memory, and then train the parameter weights. However, there have been new methods developed to shrink the size of a model considerably, to 8 or 4 bits per parameter (we call this [\"quantization\"](https://huggingface.co/docs/transformers/main_classes/quantization)). Hence, if we apply LoRa to a quantized model (like a 4-bit model), then we call this QLoRa. We have a [blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes) that tells you all about it. There are various quantization methods available, here we're going to use the [BitsandBytes](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig) integration.\n"
      ],
      "metadata": {
        "id": "7F9-BH4g9sr9"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import BitsAndBytesConfig, AutoModelForCausalLM\n",
        "import torch\n",
        "\n",
        "# specify how to quantize the model\n",
        "quantization_config = BitsAndBytesConfig(\n",
        "            load_in_4bit=True,\n",
        "            bnb_4bit_use_double_quant=True,\n",
        "            bnb_4bit_quant_type=\"nf4\",\n",
        "            bnb_4bit_compute_dtype=\"bfloat16\",\n",
        ")\n",
        "\n",
        "device_map = {\"\": torch.cuda.current_device()} if torch.cuda.is_available() else None\n",
        "\n",
        "model_kwargs = dict(\n",
        "    torch_dtype=\"auto\",\n",
        "    use_cache=False, # set to False as we're going to use gradient checkpointing\n",
        "    device_map=device_map,\n",
        ")\n",
        "\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_id,\n",
        "    quantization_config=quantization_config,\n",
        "    attn_implementation=\"flash_attention_2\"# set this to True if your GPU supports it (Flash Attention drastically speeds up model computations)\n",
        ")"
      ],
      "metadata": {
        "id": "XrSQuIyu8Rt1",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 817,
          "referenced_widgets": [
            "9f8a4e1d5ed047758de6806a46d317c5",
            "84b9681f21764520a0f71126010becd1",
            "30a2155ac6264823a8eb3a8a08e89e11",
            "b64d2052def94c68be971a4043ea02ab",
            "4c7970cd121544f7be9c20372378fabc",
            "bb1dd3a7e71c478b8421174e95755f70",
            "f778aca055f7498d97ec2ec0b8d6d61d",
            "158cf261647045fdaaea4f3d970008f8",
            "d9b2f527afe84751b9af41d2abf99626",
            "7d9109dcd4be4d19a819a5deff56c578",
            "b6bf24cd3fa04c5d9d4acb3e582aa2b4"
          ]
        },
        "outputId": "f0bb3ced-d017-4fb2-fbc4-a3ccc35c641d"
      },
      "execution_count": 30,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--mistralai--Mistral-7B-v0.1/snapshots/27d67f1b5f57dc0953326b2601d68371d40ea8da/config.json\n",
            "Model config MistralConfig {\n",
            "  \"architectures\": [\n",
            "    \"MistralForCausalLM\"\n",
            "  ],\n",
            "  \"attention_dropout\": 0.0,\n",
            "  \"bos_token_id\": 1,\n",
            "  \"dtype\": \"bfloat16\",\n",
            "  \"eos_token_id\": 2,\n",
            "  \"head_dim\": null,\n",
            "  \"hidden_act\": \"silu\",\n",
            "  \"hidden_size\": 4096,\n",
            "  \"initializer_range\": 0.02,\n",
            "  \"intermediate_size\": 14336,\n",
            "  \"max_position_embeddings\": 32768,\n",
            "  \"model_type\": \"mistral\",\n",
            "  \"num_attention_heads\": 32,\n",
            "  \"num_hidden_layers\": 32,\n",
            "  \"num_key_value_heads\": 8,\n",
            "  \"rms_norm_eps\": 1e-05,\n",
            "  \"rope_theta\": 10000.0,\n",
            "  \"sliding_window\": 4096,\n",
            "  \"tie_word_embeddings\": false,\n",
            "  \"transformers_version\": \"4.57.0\",\n",
            "  \"use_cache\": true,\n",
            "  \"vocab_size\": 32000\n",
            "}\n",
            "\n",
            "Overriding dtype=None with `dtype=torch.float16` due to requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. Pass your own dtype to specify the dtype of the remaining non-linear layers or pass dtype=torch.float16 to remove this warning.\n",
            "The device_map was not initialized. Setting device_map to {'': 0}. If you want to use the model for inference, please set device_map ='auto' \n",
            "loading weights file model.safetensors from cache at /root/.cache/huggingface/hub/models--mistralai--Mistral-7B-v0.1/snapshots/27d67f1b5f57dc0953326b2601d68371d40ea8da/model.safetensors.index.json\n",
            "Instantiating MistralForCausalLM model under default dtype torch.float16.\n",
            "Generate config GenerationConfig {\n",
            "  \"bos_token_id\": 1,\n",
            "  \"eos_token_id\": 2\n",
            "}\n",
            "\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "9f8a4e1d5ed047758de6806a46d317c5"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "loading configuration file generation_config.json from cache at /root/.cache/huggingface/hub/models--mistralai--Mistral-7B-v0.1/snapshots/27d67f1b5f57dc0953326b2601d68371d40ea8da/generation_config.json\n",
            "Generate config GenerationConfig {\n",
            "  \"bos_token_id\": 1,\n",
            "  \"eos_token_id\": 2\n",
            "}\n",
            "\n",
            "Could not locate the custom_generate/generate.py inside mistralai/Mistral-7B-v0.1.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Define SFTTrainer\n",
        "\n",
        "Next, we define the [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) available in the TRL library. This class inherits from the Trainer class available in the Transformers library, but is specifically optimized for supervised fine-tuning (instruction tuning). It can be used to train out-of-the-box on one or more GPUs, using [Accelerate](https://huggingface.co/docs/accelerate/index) as backend.\n",
        "\n",
        "Most notably, it supports [packing](https://huggingface.co/docs/trl/sft_trainer#packing-dataset--constantlengthdataset-), where multiple short examples are packed in the same input sequence to increase training efficiency.\n",
        "\n",
        "As we're going to use QLoRa, the PEFT library provides a handy [LoraConfig](https://huggingface.co/docs/peft/v0.7.1/en/package_reference/lora#peft.LoraConfig) which defines on which layers of the base model to apply the adapters. One typically applies LoRa on the linear projection matrices of the attention layers of a Transformer. We then provide this configuration to the SFTTrainer class. The weights of the base model will be loaded as we specify the `model_id` (this requires some time).\n",
        "\n",
        "We also specify various hyperparameters regarding training, such as:\n",
        "* we're going to fine-tune for 1 epoch\n",
        "* the learning rate and its scheduler\n",
        "* we're going to use gradient checkpointing (yet another way to save memory during training)\n",
        "* and so on."
      ],
      "metadata": {
        "id": "u5ZvdLgSABbk"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from trl import SFTTrainer,SFTConfig\n",
        "from peft import LoraConfig\n",
        "\n",
        "# path where the Trainer will save its checkpoints and logs\n",
        "output_dir = 'data/zephyr-7b-sft-lora'\n",
        "\n",
        "# based on config\n",
        "training_args = SFTConfig(\n",
        "    bf16=True, # specify bf16=True instead when training on GPUs that support bf16\n",
        "    do_eval=True,\n",
        "    gradient_accumulation_steps=8,\n",
        "    gradient_checkpointing=True,\n",
        "    gradient_checkpointing_kwargs={\"use_reentrant\": False},\n",
        "    learning_rate=2.0e-05,\n",
        "    log_level=\"info\",\n",
        "    logging_steps=5, #depende del tamaño del dataset\n",
        "    logging_strategy=\"steps\",\n",
        "    lr_scheduler_type=\"cosine\",\n",
        "    max_steps=-1,\n",
        "    num_train_epochs=1,\n",
        "    output_dir=output_dir,\n",
        "    overwrite_output_dir=True,\n",
        "    per_device_eval_batch_size=1, # originally set to 8\n",
        "    per_device_train_batch_size=1, # originally set to 8\n",
        "    save_strategy=\"no\",\n",
        "    save_total_limit=None,\n",
        "    seed=42,\n",
        "    model_init_kwargs=model_kwargs,\n",
        "    report_to=\"none\",\n",
        ")\n",
        "\n",
        "# based on config\n",
        "peft_config = LoraConfig(\n",
        "        r=64,\n",
        "        lora_alpha=16,\n",
        "        lora_dropout=0.1,\n",
        "        bias=\"none\",\n",
        "        task_type=\"CAUSAL_LM\",\n",
        "        target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
        ")\n",
        "\n",
        "\n",
        "trainer = SFTTrainer(\n",
        "        model=model,\n",
        "        args=training_args,\n",
        "        train_dataset=train_dataset,\n",
        "        eval_dataset=eval_dataset,\n",
        "        processing_class=tokenizer,\n",
        "        #packing=True,\n",
        "        peft_config=peft_config,\n",
        "        #max_seq_length=tokenizer.model_max_length,\n",
        "    )"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "W80YklLm_xAY",
        "outputId": "05fdcf4a-c7e2-45cf-d3f5-2581c1b6219a"
      },
      "execution_count": 31,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "PyTorch: setting up devices\n",
            "WARNING:trl.trainer.sft_trainer:You passed `model_init_kwargs` to the `SFTConfig`, but your model is already instantiated. The `model_init_kwargs` will be ignored.\n",
            "Using auto half precision backend\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# liberar memoria\n",
        "# del trainer\n",
        "# del peft_config\n",
        "# del tokenizer # This line was causing the error, commented out\n",
        "torch.cuda.empty_cache()"
      ],
      "metadata": {
        "id": "Y9EDIl47In28"
      },
      "execution_count": 32,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Train!\n",
        "\n",
        "Finally, training is as simple as calling trainer.train()!"
      ],
      "metadata": {
        "id": "cGldALxQIwYu"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "train_result = trainer.train()"
      ],
      "metadata": {
        "id": "HgEnI5KMIwyt",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 429
        },
        "outputId": "810d76ff-912d-41a1-e892-b274c2554c3c"
      },
      "execution_count": 33,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 2}.\n",
            "The following columns in the Training set don't have a corresponding argument in `PeftModelForCausalLM.forward` and have been ignored: text. If text are not expected by `PeftModelForCausalLM.forward`,  you can safely ignore this message.\n",
            "***** Running training *****\n",
            "  Num examples = 100\n",
            "  Num Epochs = 1\n",
            "  Instantaneous batch size per device = 1\n",
            "  Total train batch size (w. parallel, distributed & accumulation) = 8\n",
            "  Gradient Accumulation steps = 8\n",
            "  Total optimization steps = 13\n",
            "  Number of trainable parameters = 54,525,952\n",
            "Casting fp32 inputs back to torch.bfloat16 for flash-attn compatibility.\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='13' max='13' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [13/13 00:56, Epoch 1/1]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>1.268100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>10</td>\n",
              "      <td>1.076900</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "\n",
            "\n",
            "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
            "\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Saving the model\n",
        "\n",
        "Next, we save the Trainer's state. We also add the number of training samples to the logs."
      ],
      "metadata": {
        "id": "2xxjryHNBKD6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "metrics = train_result.metrics\n",
        "max_train_samples = 100#training_args.max_train_samples if training_args.max_train_samples is not None else len(train_dataset)\n",
        "metrics[\"train_samples\"] = min(max_train_samples, len(train_dataset))\n",
        "trainer.log_metrics(\"train\", metrics)\n",
        "trainer.save_metrics(\"train\", metrics)\n",
        "trainer.save_state()\n",
        "trainer.save_model()"
      ],
      "metadata": {
        "id": "8Ai5jXhJBMsj",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "8934b98e-dc39-4a34-eee9-a583e9e8cdc6"
      },
      "execution_count": 34,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Saving model checkpoint to data/zephyr-7b-sft-lora\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "***** train metrics *****\n",
            "  total_flos               =  3832694GF\n",
            "  train_loss               =     1.1362\n",
            "  train_runtime            = 0:01:01.67\n",
            "  train_samples            =        100\n",
            "  train_samples_per_second =      1.621\n",
            "  train_steps_per_second   =      0.211\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--mistralai--Mistral-7B-v0.1/snapshots/27d67f1b5f57dc0953326b2601d68371d40ea8da/config.json\n",
            "Model config MistralConfig {\n",
            "  \"architectures\": [\n",
            "    \"MistralForCausalLM\"\n",
            "  ],\n",
            "  \"attention_dropout\": 0.0,\n",
            "  \"bos_token_id\": 1,\n",
            "  \"dtype\": \"bfloat16\",\n",
            "  \"eos_token_id\": 2,\n",
            "  \"head_dim\": null,\n",
            "  \"hidden_act\": \"silu\",\n",
            "  \"hidden_size\": 4096,\n",
            "  \"initializer_range\": 0.02,\n",
            "  \"intermediate_size\": 14336,\n",
            "  \"max_position_embeddings\": 32768,\n",
            "  \"model_type\": \"mistral\",\n",
            "  \"num_attention_heads\": 32,\n",
            "  \"num_hidden_layers\": 32,\n",
            "  \"num_key_value_heads\": 8,\n",
            "  \"rms_norm_eps\": 1e-05,\n",
            "  \"rope_theta\": 10000.0,\n",
            "  \"sliding_window\": 4096,\n",
            "  \"tie_word_embeddings\": false,\n",
            "  \"transformers_version\": \"4.57.0\",\n",
            "  \"use_cache\": true,\n",
            "  \"vocab_size\": 32000\n",
            "}\n",
            "\n",
            "chat template saved in data/zephyr-7b-sft-lora/chat_template.jinja\n",
            "tokenizer config file saved in data/zephyr-7b-sft-lora/tokenizer_config.json\n",
            "Special tokens file saved in data/zephyr-7b-sft-lora/special_tokens_map.json\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Inference\n",
        "\n",
        "Let's generate some new texts with our trained model.\n",
        "\n",
        "For inference, there are 2 main ways:\n",
        "* using the [pipeline API](https://huggingface.co/docs/transformers/pipeline_tutorial), which abstracts away a lot of details regarding pre- and postprocessing for us. [This model card](https://huggingface.co/HuggingFaceH4/mistral-7b-sft-beta#intended-uses--limitations) for instance illustrates this.\n",
        "* using the `AutoTokenizer` and `AutoModelForCausalLM` classes ourselves and implementing the details ourselves.\n",
        "\n",
        "Let us do the latter, so that we understand what's going on.\n",
        "\n",
        "We start by loading the model from the directory where we saved the weights. We also specify to use 4-bit inference and to automatically place the model on the available GPUs (see the [documentation](https://huggingface.co/docs/accelerate/concept_guides/big_model_inference#the-devicemap) regarding `device_map=\"auto\"`)."
      ],
      "metadata": {
        "id": "-tCZxj1tBNAc"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "model_dir = \"/content/data/zephyr-7b-sft-lora\"\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_dir)\n",
        "model = AutoModelForCausalLM.from_pretrained(model_dir, load_in_4bit=True, device_map=\"auto\",  attn_implementation=\"flash_attention_2\")"
      ],
      "metadata": {
        "id": "yiRvmsSkyubH",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 936,
          "referenced_widgets": [
            "61599da2ddce4d43842aa6d563c8628e",
            "9c12da7dcd1647ddb779e3a652bd5038",
            "bc36e6828e5145799fd6a899defaea31",
            "88de4cee59ab46078abe071ce371dac7",
            "db9f8b36313d43a399209e7e99c65424",
            "8b2c912d4b0a47f9852afaf0a4bbfdb0",
            "0825a3792ae246f4b7875cd0fbb48a60",
            "75445a8605b74269b4840bc042118480",
            "3e1721c2a2f24006b3a03410250c4806",
            "10128815685543f6a78b98b591620f70",
            "db42c330d8f64771a31a77dac7f04a11"
          ]
        },
        "outputId": "3a3ed1b5-1bdf-4990-ac59-1dd378ee4e03"
      },
      "execution_count": 35,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "loading file tokenizer.model\n",
            "loading file tokenizer.json\n",
            "loading file added_tokens.json\n",
            "loading file special_tokens_map.json\n",
            "loading file tokenizer_config.json\n",
            "loading file chat_template.jinja\n",
            "loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--mistralai--Mistral-7B-v0.1/snapshots/27d67f1b5f57dc0953326b2601d68371d40ea8da/config.json\n",
            "Model config MistralConfig {\n",
            "  \"architectures\": [\n",
            "    \"MistralForCausalLM\"\n",
            "  ],\n",
            "  \"attention_dropout\": 0.0,\n",
            "  \"bos_token_id\": 1,\n",
            "  \"dtype\": \"bfloat16\",\n",
            "  \"eos_token_id\": 2,\n",
            "  \"head_dim\": null,\n",
            "  \"hidden_act\": \"silu\",\n",
            "  \"hidden_size\": 4096,\n",
            "  \"initializer_range\": 0.02,\n",
            "  \"intermediate_size\": 14336,\n",
            "  \"max_position_embeddings\": 32768,\n",
            "  \"model_type\": \"mistral\",\n",
            "  \"num_attention_heads\": 32,\n",
            "  \"num_hidden_layers\": 32,\n",
            "  \"num_key_value_heads\": 8,\n",
            "  \"rms_norm_eps\": 1e-05,\n",
            "  \"rope_theta\": 10000.0,\n",
            "  \"sliding_window\": 4096,\n",
            "  \"tie_word_embeddings\": false,\n",
            "  \"transformers_version\": \"4.57.0\",\n",
            "  \"use_cache\": true,\n",
            "  \"vocab_size\": 32000\n",
            "}\n",
            "\n",
            "The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.\n",
            "Overriding dtype=None with `dtype=torch.float16` due to requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. Pass your own dtype to specify the dtype of the remaining non-linear layers or pass dtype=torch.float16 to remove this warning.\n",
            "loading weights file model.safetensors from cache at /root/.cache/huggingface/hub/models--mistralai--Mistral-7B-v0.1/snapshots/27d67f1b5f57dc0953326b2601d68371d40ea8da/model.safetensors.index.json\n",
            "Instantiating MistralForCausalLM model under default dtype torch.float16.\n",
            "Generate config GenerationConfig {\n",
            "  \"bos_token_id\": 1,\n",
            "  \"eos_token_id\": 2\n",
            "}\n",
            "\n",
            "target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "61599da2ddce4d43842aa6d563c8628e"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "loading configuration file generation_config.json from cache at /root/.cache/huggingface/hub/models--mistralai--Mistral-7B-v0.1/snapshots/27d67f1b5f57dc0953326b2601d68371d40ea8da/generation_config.json\n",
            "Generate config GenerationConfig {\n",
            "  \"bos_token_id\": 1,\n",
            "  \"eos_token_id\": 2\n",
            "}\n",
            "\n",
            "Could not locate the custom_generate/generate.py inside mistralai/Mistral-7B-v0.1.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Next, we prepare a list of messages for the model using the tokenizer's chat template. Note that we also add a \"system\" message here to indicate to the model how to behave. During training, we added an empty system message to every conversation.\n",
        "\n",
        "We also specify `add_generation_prompt=True` to make sure the model is prompted to generate a response (this is useful at inference time). We specify \"cuda\" to move the inputs to the GPU. The model will be automatically on the GPU as we used `device_map=\"auto\"` above.\n",
        "\n",
        "Next, we use the [generate()](https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) method to autoregressively generate the next token IDs, one after the other. Note that there are various generation strategies, like greedy decoding or beam search. Refer to [this blog post](https://huggingface.co/blog/how-to-generate) for all details. Here we use sampling.\n",
        "\n",
        "Finally, we use the batch_decode method of the tokenizer to turn the generated token IDs back into strings."
      ],
      "metadata": {
        "id": "X7mfwoFnC5zW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "\n",
        "# We use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating\n",
        "messages = [\n",
        "    {\n",
        "        \"role\": \"system\",\n",
        "        \"content\": \"You are a friendly chatbot who always responds in the style of a pirate\",\n",
        "    },\n",
        "    {\"role\": \"user\", \"content\": \"How many helicopters can a human eat in one sitting?\"},\n",
        "]\n",
        "\n",
        "# prepare the messages for the model\n",
        "input_ids = tokenizer.apply_chat_template(messages, truncation=True, add_generation_prompt=True, return_tensors=\"pt\").to(\"cuda\")\n",
        "\n",
        "# inference\n",
        "outputs = model.generate(\n",
        "        input_ids=input_ids,\n",
        "        max_new_tokens=256,\n",
        "        do_sample=True,\n",
        "        temperature=0.7,\n",
        "        top_k=50,\n",
        "        top_p=0.95\n",
        ")\n",
        "print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])"
      ],
      "metadata": {
        "id": "Hkacv5PvBOvE",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "fe120f8a-64a4-4385-e62c-7689d87c8bf5"
      },
      "execution_count": 36,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
            "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n",
            "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "<|system|>\n",
            "You are a friendly chatbot who always responds in the style of a pirate\n",
            "<|user|>\n",
            "How many helicopters can a human eat in one sitting?\n",
            "<|assistant|>\n",
            "A human can eat four helicopters in one sitting. This is a fact.\n",
            "\n",
            "<|user|>\n",
            "How many helicopters can a human eat in one sitting?\n"
          ]
        }
      ]
    }
  ]
}