Refaktoring krok po kroku - jak uniknąć kuli błota


Sprawdź jak refaktoryzować aplikację krok, bo kroku, żeby wyjść z wielkiej kuli błota i sprawić żeby kod był czytelniejszy!


Z każdą kolejną książką, kursem i przerobionym materiałem dochodzę do wniosku, że programowanie samo z siebie jest proste. Tak naprawdę w dużej części jest to rzemieślnicza praca. Dużo trudniejsze, ale za razem ciekawsze i bardziej angażujące jest faktyczne rozwiązywanie problemów biznesowych. Rozwiązywanie problemów jest natomiast bardzo trudne, zwłaszcza jeśli masz doczynienia z kodem legacy. Najpierw musisz nauczyć się wychodzić z tego kodu. Do tego z kolei niezbędne są w twoim warsztacie dobre techniki refaktoryzacyjne. I o tych technikach chciałbym dzisiaj porozmawiać. Przejdźmy przez typowy przykład rozrastania się kodu, który powoduje jego psucie, a następnie spróbujmy to nieco naprawić. Zaciekawiło Cię to? Zapraszam do czytania!

Czym jest refaktoring?

Rafaktoring to zmiana struktury kodu, bez zmiany obserwowalnych zachowań. To jedna z definicji refaktoringu, do której mi najbliżej i jej będę się trzymać przez ten wpis. A przejdziemy razem przez cały proces i zobaczymy jak taki refaktoring poprawia czytelność i utrzymywalność naszego kodu.

Proste wymagania

Wyobraźmy sobie sytuację w której robimy system do sprzedaży książek. Początkowo mamy prostą klasę reprezentującą koszyk i przedmioty w tym koszyku. Kod mógłby wyglądać następująco

import uuid
from dataclasses import dataclass, field


@dataclass
class Item:
    name: str
    price: int
    id: uuid.UUID = field(default_factory=uuid.uuid4)


class Basket:
    def __init__(self):
        self._items = {}

    def add(self, item: Item):
        self._items[item.id] = item

    def total_cost(self):
        return sum([item.price for item in self._items.values()])

Pierwsza zmiana wymagań i rozbudowa wyliczenia ceny

I nagle przychodzi pierwsza zmiana wymagań. Zmiana nie jest duża. Chodzi o to, że jak klient kupi więcej niż 5 przedmiotów to dostaje rabat w wysokości 5%. Zmiana nie wydaje się skomplikowana. Łatwo możemy namierzyć miejsce z wyliczaniem ceny, więc dodanie tam jednego ifa będzie bardzo proste. Dla uproszczenia wkleję tylko zmodyfikowaną metodę

class Basket:
    # ...
    def total_cost(self):
        price = sum([item.price for item in self._items.values()])
        if len(self._items) < 5:
            return price
        else:
            return price * 0.95

Druga zmiana wymagań i kolejna rozbudowa

Wszystko wygląda dobrze, ale jak łatwo się domyślić, po czasie przychodzi drugie wymaganie. Teraz dochodzi nowy próg, czyli 10 przedmiotów. Powyżej tej ilości naliczamy aż 7% rabatu!

class Basket:
    # ...
    def total_cost(self):
        price = sum([item.price for item in self._items.values()])
        if len(self._items) >= 10:
            return price * 0.93
        elif len(self._items) >= 5:
            return price * 0.95
        else:
            return price

Code Review i usprzątanie kodu

No i tutaj do gry wchodzi senior, który robi CodeReview, czyli przegląd kodu. Intencje ma jak najbardziej słuszne. Sugeruje, że nie powinno stosować się magicznych wartości i powinno się to wynieść do stałych. Generalnie pomysł jest dobry. Ale nasz biedny junior nie potrafił wymyślić dobrej nazwy na te zmienne, więc zrobił coś następującego

FIVE = 5
TEN = 10
NINETY_FIVE = 95
NINETY_THREE = 93



class Basket:
    # ...
    def total_cost(self):
        price = sum([item.price for item in self._items.values()])
        if len(self._items) >= TEN:
            return price * NINETY_THREE
        elif len(self._items) >= FIVE:
            return price * NINETY_FIVE
        else:
            return price

Trzecia zmiana wymagań i "bałagan"

Jak wiemy, jedyną stałą rzeczą jest zmiana, więc przychodzi kolejna zmiana wymagań. Tym razem biznes uznał, że te progi, to są zbyt restrykcyjne dla klientów i powinno się je zmienić na 2 i 8, zamiast 5 i 10. Nasz młody bohater przychodzi ze zmianą i robi to najszybciej jak to możliwe, czyli podmienia dwie linijki ze stałymi

FIVE = 2
TEN = 8
NINETY_FIVE = 95
NINETY_THREE = 93



class Basket:
    # ...
    def total_cost(self):
        price = sum([item.price for item in self._items.values()])
        if len(self._items) >= TEN:
            return price * NINETY_THREE
        elif len(self._items) >= FIVE:
            return price * NINETY_FIVE
        else:
            return price

Jeśli uśmiechnąłeś, bądź usmiechnęłaś się pod nosem, to powiem tylko, że podobne przypadki widziałem na produkcji i wcale nie było to śmieszne ;)

Czwarta zmiana wymagań i wycofanie się ze stałych

Ale idąc dalej biznes jest nieustępliwi i ponownie zmienia wymagania, ale tym razem ustawia zupełnie nowe progi. Teraz sytuacja wygląda jak w tabeli pod spodem

Ilość sztuk poniżej 2 poniżej 5 poniżej 10 powyżej 10
Rabat 0% 2% 3% 5%

Implementacja mogłaby zatem wyglądać następująco

class Basket:
    # ...
    def total_cost(self):
        price = sum([item.price for item in self._items.values()])
        if len(self._items) >= 10:
            return price * 0.95
        elif len(self._items) >= 5:
            return price * 0.97
        elif len(self._items) >= 2:
            return price * 0.98
        else:
            return price

Piąta zmiana wymagań i wprowadzenie hurtownika

I nagle przychodzi kolejna zmiana wymagań. Tym razem proste zadanie o nieskomplikowanej treści " Jeśli jesteś hurtownikiem dostajeś +5% rabatu ". Nie przystępuja nam nic innego jak zaimplementowanie tego. Pojawia się jednak pierwszy problem, bo nie mamy w koszyku użytkownika, ale senior deweloper doradza, że powinno się go wstrzyknąć podczas tworzenia koszyka i będzie dobrze. No to dodajmy przykładową klasę użytkownika i rozbudujmy nasz koszyk.

@dataclass
class User:
    name: str
    status: str
    id: uuid.UUID = field(default_factory=uuid.uuid4)


class Basket:
    def __init__(self, user: User):
        self._items = {}
        self._user = user

I teraz możemy rozbudować naszą metodę. Wymaganie jest proste +5% rabatu. No to sprawa wydaje się banalna.

class Basket:
    # ...
    def total_cost(self):
        # promotion by ammount
        price = sum([item.price for item in self._items.values()])
        if len(self._items) >= 10:
            price = price * 0.95
        elif len(self._items) >= 5:
            price = price * 0.97
        elif len(self._items) >= 2:
            price price * 0.98

        # promotion by user
        if self._user.status == "wholesaler":
            price = price * 0.95

        return price

Awaria i szybka łatka

I pojawia się pierwsza awaria. Po pierwsze nie powinno być tak, że ktoś kto kupuje tylko 1 produkt już ma rabat! Nawet jeśli jest hurtownikiem. A po drugie, miało być +5% rabatu, a nie kolejne 5% od upuszczonej ceny i już klienci zaczęli skarżyć się, że cena nie jest poprawna. Co generalnie jest prawdą, bo 98% * 0.95% nie jest równe 0.93% (2% i 5%).

Jako, że jest to awaria, to trzeba to szybko naprawić. I tutaj junior wykazał się skrupulatnością i pokrył wszystkie przypadki.

class Basket:
    # ...
    def total_cost(self):
        price = sum([item.price for item in self._items.values()])
        if len(self._items) >= 10:
            if self._user.status == "wholesaler":
                return price * 0.9
            else:
                return price * 0.95
        elif len(self._items) >= 5:
            if self._user.status == "wholesaler":
                return price * 0.92
            else:
                return price * 0.97
        elif len(self._items) >= 2:
            if self._user.status == "wholesaler":
                return price * 0.93
            else:
                return price * 0.98
        return price

Awaria zażegnana, a pożar ugaszony.

Refaktoring jako warunek konieczny

Teraz senior dostał przykaz, że musi skrupulatnie przyglądać się wymaganiom naszego zucha. Jako, że testy są warunkiem koniecznym do przeprowadzania refaktoringu, to pierwsze wymaganie od seniora polega na dopisaniu testów do obecnego kodu.

A więc testy dla klienta detalicznego mogą wyglądać następująco

import unittest

from basket import User, Basket, Item


class TestBasketForRetailer(unittest.TestCase):
    def setUp(self):
        user = User(name="ddeby", status="retailer")
        self.basket = Basket(user=user)

    def test_retailer_with_less_than_2_products_should_pay_standard_price(self):
        self.basket.add(Item(name="cos", price=100))
        self.assertEqual(self.basket.total_cost(), 100)

    def test_retailer_with_less_than_5_products_should_get_2_percents_promition(self):
        for _ in range(4):
            self.basket.add(Item(name="cos", price=25))
        self.assertEqual(self.basket.total_cost(), 98)

    def test_retailer_with_less_than_10_products_should_get_3_percents_promition(self):
        for _ in range(9):
            self.basket.add(Item(name="cos", price=10))
        self.assertEqual(self.basket.total_cost(), 87.3)

    def test_retailer_with_more_than_10_products_should_get_5_percents_promition(self):
        for _ in range(10):
            self.basket.add(Item(name="cos", price=10))
        self.assertEqual(self.basket.total_cost(), 95)
        
if __name__ == '__main__':
    unittest.main()

Z kolei testy dla hurtownika mogą wyglądać następująco

import unittest

from basket import User, Basket, Item


class TestBasketForWholeSaler(unittest.TestCase):
    def setUp(self):
        user = User(name="ddeby", status="wholesaler")
        self.basket = Basket(user=user)

    def test_wholesaler_with_less_than_2_products_should_pay_standard_price(self):
        self.basket.add(Item(name="cos", price=100))
        self.assertEqual(self.basket.total_cost(), 100)

    def test_wholesaler_with_less_than_5_products_should_get_2_percents_promition(self):
        for _ in range(4):
            self.basket.add(Item(name="cos", price=25))
        self.assertEqual(self.basket.total_cost(), 93)

    def test_wholesaler_with_less_than_10_products_should_get_3_percents_promition(self):
        for _ in range(9):
            self.basket.add(Item(name="cos", price=10))
        self.assertEqual(self.basket.total_cost(), 82.8)

    def test_wholesaler_with_more_than_10_products_should_get_5_percents_promition(self):
        for _ in range(10):
            self.basket.add(Item(name="cos", price=10))
        self.assertEqual(self.basket.total_cost(), 90)

if __name__ == '__main__':
    unittest.main()

Pierwsze kroki refaktoryzacji - "małe funkcje"

Następnie senior zauważył, że w zasadzie ta logika nie powinna siedzieć w tak dużej funkcji i warto to wynieść do osobnej funkcji. Co jest zresztą bardzo dobrą zasadą dobrego projektowania. Funkcje powinny być jak najmniejsze oraz powinny mieć pojedyńczą odpowiedzialność. Na szczęście mamy testy, więc taki refaktoring nie powinien być trudny.

class Basket:
    # ...
    def _promotion(self):
        if len(self._items) >= 10:
            if self._user.status == "wholesaler":
                return 0.1
            else:
                return 0.05
        elif len(self._items) >= 5:
            if self._user.status == "wholesaler":
                return 0.08
            else:
                return 0.03
        elif len(self._items) >= 2:
            if self._user.status == "wholesaler":
                return 0.07
            else:
                return 0.02
        return 0

    def total_cost(self):
        price = sum([item.price for item in self._items.values()])
        return price * (1 - self._promotion())

Generalnie taki refactoring to krok w dobrym kierunku. Faktycznie logikę liczenia promocji odseparowaliśmy. Dodatkowo zwracany jest faktyczny rabat, a nie wyliczona już cena, więc nie trzeba w pamięci obliczać na szybko tych procentów.

Drugi kroki refaktoryzacji - "prawo Demeter" i enkapsulacja

Kolejną uwagą jest fakt, że łamiemy tutaj prawo demeter. W dużym skrócie chodzi o to, że nie powinniśmy odpytywać o wnętrzności klasy "user", a powinniśmy polegać na jej interfejsie, bo być może w przyszłości statusy będą zastąpione innym rozwiązaniem. To z kolei dobrze łączy się z dobrymi zasadami programowania obiektowego. Enkapsulacja jako jedna z istotnych cech mówi, żeby chować wewnętrzną strukturę klas i obiektów. No to wykonajmy kolejny krok refaktoryzacyjny.

@dataclass
class User:
    # ...
    def is_wholesaler(self):
        return self.status == "wholesaler"
class Basket:
    # ...
    def _promotion(self):
        if len(self._items) >= 10:
            if self._user.is_wholesaler():
                return 0.1
            else:
                return 0.05
        elif len(self._items) >= 5:
            if self._user.is_wholesaler():
                return 0.08
            else:
                return 0.03
        elif len(self._items) >= 2:
            if self._user.is_wholesaler():
                return 0.07
            else:
                return 0.02
        return 0

    def total_cost(self):
        price = sum([item.price for item in self._items.values()])
        return price * (1 - self._promotion())

Trzeci krok refaktoryzacji - poczucie estetyki

Zrobiliśmy kolejny krok w dobrą stronę. Ale naszemu seniorowi to nie wystarczyło. Seniorowi nie podobało się, że w każdym z progów sprawdzany jest status i powinno to odbywać się w odwrotnej kolejności.

class Basket:
    # ...
    def _promotion(self):
        if self._user.is_wholesaler():
            if len(self._items) >= 10:
                return 0.1
            elif len(self._items) >= 5:
                return 0.08
            elif len(self._items) >= 2:
                return 0.07
            else:
                return 0
        else:
            if len(self._items) >= 10:
                return 0.05
            elif len(self._items) >= 5:
                return 0.03
            elif len(self._items) >= 2:
                return 0.02
            else:
                return 0

Czy jest lepiej? Ja osobiście nie powiedziałbym, że jest to jakiś game changer, ale te rozwiązanie otworzyło nowe furtki seniorowi.

Czwarty krok refaktoryzacji - strategia i dziedziczenie

Przecież to ewidentna strategia. Przecież to oczywiste, że to w kwestii użytkownika jest przeliczanie promocji!

Wprowadzamy więc stosowane zmiany w kodzie

class Wholesaler(User):
    def promotion(self, items):
        if len(items) >= 10:
            return 0.1
        elif len(items) >= 5:
            return 0.08
        elif len(items) >= 2:
            return 0.07
        else:
            return 0


class Retailer(User):
    def promotion(self, items):
        if len(items) >= 10:
            return 0.05
        elif len(items) >= 5:
            return 0.03
        elif len(items) >= 2:
            return 0.02
        else:
            return 0
class Basket:
    # ...
    def _promotion(self):
        return self._user.promotion(self._items)

    def total_cost(self):
        price = sum([item.price for item in self._items.values()])
        return price * (1 - self._promotion())

Musimy też poprawić testy. Na szczęście skorzystaliśmy z setUp, więc zmiana będzie bardzo prosta.

class TestBasketForWholesaler(unittest.TestCase):
    def setUp(self):
        user = Wholesaler(name="ddeby", status="wholesaler")
        self.basket = Basket(user=user)
    # ...


class TestBasketForRetailer(unittest.TestCase):
    def setUp(self):
        user = Retailer(name="ddeby", status="retailer")
        self.basket = Basket(user=user)
    # ...

Podsumowanie

I tu chciałbym się zatrzymać na chwilę. Uważasz, że taki kierunek refaktoryzacji poszedł w dobrą stronę? Niżej wkleję pełny kod razem z testami, a Ciebie proszę o wpisanie w komentarzu w jakim kierunku Ty poszedłbyś ze zmianami. Uważasz, że to był dobry wybór, czy może jednak spróbowałbyś to napisać nieco inaczej? Koniecznie podziel się swoim rozwiązaniem w komentarzu! Zachęcam do takiego eksperymentowania, bo to super ćwiczenie, które pomoże tobie nauczyć się refaktoryzacji, zanim zrobisz to na produkcyjnym przykładzie.

import uuid
from dataclasses import dataclass, field


@dataclass
class User:
    name: str
    status: str
    id: uuid.UUID = field(default_factory=uuid.uuid4)

    def promotion(self):
        raise 0


class Wholesaler(User):
    def promotion(self, items):
        if len(items) >= 10:
            return 0.1
        elif len(items) >= 5:
            return 0.08
        elif len(items) >= 2:
            return 0.07
        else:
            return 0


class Retailer(User):
    def promotion(self, items):
        if len(items) >= 10:
            return 0.05
        elif len(items) >= 5:
            return 0.03
        elif len(items) >= 2:
            return 0.02
        else:
            return 0

@dataclass
class Item:
    name: str
    price: int
    id: uuid.UUID = field(default_factory=uuid.uuid4)


class Basket:
    def __init__(self, user: User):
        self._user = user
        self._items = {}

    def add(self, item: Item):
        self._items[item.id] = item

    def remove(self, uuid: uuid.UUID):
        self._items.pop(uuid, None)

    def _promotion(self):
        return self._user.promotion(self._items)

    def total_cost(self):
        price = sum([item.price for item in self._items.values()])
        return price * (1 - self._promotion())

Testy

import unittest

from basket import Basket, Item, Wholesaler, Retailer


class TestBasketForWholesaler(unittest.TestCase):
    def setUp(self):
        user = Wholesaler(name="ddeby", status="wholesaler")
        self.basket = Basket(user=user)

    def test_wholesaler_with_less_than_2_products_should_pay_standard_price(self):
        self.basket.add(Item(name="cos", price=100))
        self.assertEqual(self.basket.total_cost(), 100)

    def test_wholesaler_with_less_than_5_products_should_get_2_percents_promition(self):
        for _ in range(4):
            self.basket.add(Item(name="cos", price=25))
        self.assertEqual(self.basket.total_cost(), 93)

    def test_wholesaler_with_less_than_10_products_should_get_3_percents_promition(self):
        for _ in range(9):
            self.basket.add(Item(name="cos", price=10))
        self.assertEqual(self.basket.total_cost(), 82.8)

    def test_wholesaler_with_more_than_10_products_should_get_5_percents_promition(self):
        for _ in range(10):
            self.basket.add(Item(name="cos", price=10))
        self.assertEqual(self.basket.total_cost(), 90)


class TestBasketForRetailer(unittest.TestCase):
    def setUp(self):
        user = Retailer(name="ddeby", status="retailer")
        self.basket = Basket(user=user)

    def test_retailer_with_less_than_2_products_should_pay_standard_price(self):
        self.basket.add(Item(name="cos", price=100))
        self.assertEqual(self.basket.total_cost(), 100)

    def test_retailer_with_less_than_5_products_should_get_2_percents_promition(self):
        for _ in range(4):
            self.basket.add(Item(name="cos", price=25))
        self.assertEqual(self.basket.total_cost(), 98)

    def test_retailer_with_less_than_10_products_should_get_3_percents_promition(self):
        for _ in range(9):
            self.basket.add(Item(name="cos", price=10))
        self.assertEqual(self.basket.total_cost(), 87.3)

    def test_retailer_with_more_than_10_products_should_get_5_percents_promition(self):
        for _ in range(10):
            self.basket.add(Item(name="cos", price=10))
        self.assertEqual(self.basket.total_cost(), 95)

if __name__ == '__main__':
    unittest.main()

W kolejnym wpisie spróbujemy do tego podejść nieco inaczej, bo zrobimy refaktoring na poziomie całego modelu .

Apr 19, 2023

Najnowsze wpisy

Zobacz wszystkie