Day 28 - @classmethod 装饰器详解

什么是类方法?

类方法(Class Method)是绑定到类而不是实例的方法。类方法使用 @classmethod 装饰器装饰,第一个参数是类本身(通常命名为 cls),而不是实例(self)。

类方法的主要用途包括:

  1. 工厂方法:创建类的实例的不同方式
  2. 访问或修改类属性
  3. 不依赖于特定实例的功能
class MyClass:
    class_attr = "类属性值"
    
    def __init__(self, value):
        self.instance_attr = value
    
    @classmethod
    def class_method(cls):
        """类方法"""
        print(f"类方法被调用")
        print(f"cls 是:{cls}")
        print(f"cls.class_attr = {cls.class_attr}")
        return cls("通过类方法创建")  # 可以使用 cls 创建实例

# 通过类调用类方法
MyClass.class_method()
# 类方法被调用
# cls 是:<class '__main__.MyClass'>
# cls.class_attr = 类属性值

# 通过实例调用类方法(不推荐,但可以)
obj = MyClass("实例值")
obj.class_method()  # cls 仍然是 MyClass,不是 obj 的类

实例方法 vs 类方法 vs 静态方法

Python 中有三种类型的方法:

  1. 实例方法:第一个参数是 self,只能通过实例调用
  2. 类方法:第一个参数是 cls,可以通过类或实例调用
  3. 静态方法:没有任何默认参数,可以通过类或实例调用
class Example:
    class_attr = "类属性"
    
    def __init__(self, value):
        self.value = value
    
    def instance_method(self):
        """实例方法:可以访问 self 和类"""
        return f"实例方法,value={self.value}"
    
    @classmethod
    def class_method(cls):
        """类方法:可以访问类属性,不能访问实例属性"""
        return f"类方法,class_attr={cls.class_attr}"
    
    @staticmethod
    def static_method():
        """静态方法:不能访问 self 或 cls"""
        return "静态方法"

# 调用方式
obj = Example("test")

print(Example.instance_method(obj))  # 实例方法,value=test
print(obj.instance_method())         # 实例方法,value=test

print(Example.class_method())        # 类方法,class_attr=类属性
print(obj.class_method())            # 类方法,class_attr=类属性

print(Example.static_method())       # 静态方法
print(obj.static_method())           # 静态方法

工厂方法模式

类方法最常见的用途是实现工厂方法(Factory Method),即用不同的方式创建类的实例。这比直接在 __init__ 中添加大量可选参数更加清晰和可扩展。

import math

class Point:
    """点类"""
    
    def __init__(self, x, y):
        self.x = x
        self.y = y
    
    def __repr__(self):
        return f"Point({self.x}, {self.y})"
    
    @classmethod
    def from_cartesian(cls, x, y):
        """从笛卡尔坐标创建点"""
        return cls(x, y)
    
    @classmethod
    def from_polar(cls, r, theta):
        """从极坐标创建点(弧度)"""
        x = r * math.cos(theta)
        y = r * math.sin(theta)
        return cls(x, y)
    
    @classmethod
    def from_tuple(cls, coords):
        """从元组创建点"""
        return cls(coords[0], coords[1])
    
    @classmethod
    def origin(cls):
        """创建原点"""
        return cls(0, 0)

# 使用不同的工厂方法创建点
p1 = Point.from_cartesian(3, 4)
p2 = Point.from_polar(5, math.atan2(4, 3))  # 距离5,角度 arctan(4/3)
p3 = Point.from_tuple((1, 2))
p4 = Point.origin()

print(p1)  # Point(3.0, 4.0)
print(p2)  # Point(4.999999999999999, 3.0000000000000004) ≈ Point(5, 3)
print(p3)  # Point(1, 2)
print(p4)  # Point(0, 0)

替代构造函数

类方法可以提供多个构造函数,让类的使用者以不同的方式创建实例。

from datetime import datetime

class Person:
    def __init__(self, name, birth_year, birth_month, birth_day):
        self.name = name
        self.birth_date = datetime(birth_year, birth_month, birth_day)
    
    @classmethod
    def from_birthday_string(cls, name, birthday_str):
        """从 'YYYY-MM-DD' 格式的字符串创建"""
        parts = birthday_str.split('-')
        if len(parts) != 3:
            raise ValueError("日期格式必须是 YYYY-MM-DD")
        year, month, day = map(int, parts)
        return cls(name, year, month, day)
    
    @classmethod
    def from_age(cls, name, age):
        """从年龄创建(假设今年生日已过)"""
        current_year = datetime.now().year
        birth_year = current_year - age
        return cls(name, birth_year, 1, 1)  # 假设1月1日出生
    
    def age(self):
        """计算年龄"""
        today = datetime.now()
        age = today.year - self.birth_date.year
        if (today.month, today.day) < (self.birth_date.month, self.birth_date.day):
            age -= 1
        return age
    
    def __repr__(self):
        return f"Person(name={self.name}, birth={self.birth_date.date()}, age={self.age()})"

# 使用
p1 = Person("张三", 1990, 5, 15)
p2 = Person.from_birthday_string("李四", "1995-08-20")
p3 = Person.from_age("王五", 25)

print(p1)  # Person(name=张三, birth=1990-05-15, age=34)
print(p2)  # Person(name=李四, birth=1995-08-20, age=28)
print(p3)  # Person(name=王五, birth=1999-01-01, age=25)

类方法与类属性

类方法可以访问和修改类属性,但不能直接访问实例属性。

class Counter:
    count = 0  # 类属性
    instances = []  # 存储所有实例
    
    def __init__(self, name):
        self.name = name
        Counter.count += 1
        Counter.instances.append(self)
    
    @classmethod
    def get_count(cls):
        """获取创建的实例数量"""
        return cls.count
    
    @classmethod
    def get_instances(cls):
        """获取所有实例"""
        return cls.instances.copy()
    
    @classmethod
    def reset_counter(cls):
        """重置计数器"""
        cls.count = 0
        cls.instances = []

# 测试
print(f"当前计数:{Counter.get_count()}")  # 0

c1 = Counter("实例1")
c2 = Counter("实例2")
c3 = Counter("实例3")

print(f"当前计数:{Counter.get_count()}")  # 3
print(f"所有实例:{Counter.get_instances()}")

# 通过实例调用类方法
print(c1.get_count())  # 3

# 重置
Counter.reset_counter()
print(f"重置后计数:{Counter.get_count()}")  # 0

类方法继承

类方法可以被继承,子类继承父类的类方法时,cls 参数会自动变成子类。

class Base:
    @classmethod
    def factory(cls):
        obj = cls()
        obj.created_by = cls.__name__
        return obj

class Derived(Base):
    pass

# 测试
base_obj = Base.factory()
print(f"base_obj.created_by = {base_obj.created_by}")  # Base
print(f"type = {type(base_obj)}")  # <class '__main__.Base'>

derived_obj = Derived.factory()
print(f"derived_obj.created_by = {derived_obj.created_by}")  # Derived
print(f"type = {type(derived_obj)}")  # <class '__main__.Derived'>

类方法在单例模式中的应用

单例模式(Singleton Pattern)确保一个类只有一个实例。类方法可以实现线程安全的单例模式。

class Singleton:
    _instance = None
    
    def __init__(self):
        if Singleton._instance is not None:
            raise RuntimeError("请使用 get_instance() 方法获取实例")
    
    @classmethod
    def get_instance(cls):
        """获取单例实例"""
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance._initialized = False
        return cls._instance
    
    def __init__(self):
        if not self._initialized:
            self.value = None
            self._initialized = True
    
    def set_value(self, value):
        self.value = value
    
    def get_value(self):
        return self.value

# 测试
s1 = Singleton.get_instance()
s2 = Singleton.get_instance()

print(f"s1 is s2: {s1 is s2}")  # True
print(f"s1.value: {s1.get_value()}")

s1.set_value(42)
print(f"s2.value: {s2.get_value()}")  # 42

# 尝试直接创建实例会报错
try:
    s3 = Singleton()  # RuntimeError
except RuntimeError as e:
    print(f"错误:{e}")

类方法与不可变对象

类方法对于创建不可变对象特别有用,因为可以在类方法中控制对象的创建过程。

class Color:
    """不可变颜色类"""
    __slots__ = ('_r', '_g', '_b')
    
    def __init__(self, r, g, b):
        if not all(0 <= v <= 255 for v in (r, g, b)):
            raise ValueError("RGB 值必须在 0-255 之间")
        self._r = r
        self._g = g
        self._b = b
    
    @classmethod
    def from_hex(cls, hex_string):
        """从十六进制字符串创建颜色"""
        hex_string = hex_string.lstrip('#')
        if len(hex_string) != 6:
            raise ValueError("十六进制颜色必须是 #RRGGBB 格式")
        try:
            r, g, b = tuple(int(hex_string[i:i+2], 16) for i in (0, 2, 4))
        except ValueError:
            raise ValueError("无效的十六进制值")
        return cls(r, g, b)
    
    @classmethod
    def from_hsv(cls, h, s, v):
        """从 HSV 创建颜色"""
        import math
        c = v * s
        x = c * (1 - abs((h / 60) % 2 - 1))
        m = v - c
        
        if 0 <= h < 60:
            r, g, b = c, x, 0
        elif 60 <= h < 120:
            r, g, b = x, c, 0
        elif 120 <= h < 180:
            r, g, b = 0, c, x
        elif 180 <= h < 240:
            r, g, b = 0, x, c
        elif 240 <= h < 300:
            r, g, b = x, 0, c
        else:
            r, g, b = c, 0, x
        
        return cls(int((r + m) * 255), int((g + m) * 255), int((b + m) * 255))
    
    def to_hex(self):
        """转换为十六进制字符串"""
        return f"#{self._r:02x}{self._g:02x}{self._b:02x}"
    
    def __repr__(self):
        return f"Color({self._r}, {self._g}, {self._b})"

# 测试
red = Color(255, 0, 0)
print(red)  # Color(255, 0, 0)

blue = Color.from_hex("#0000FF")
print(blue)  # Color(0, 0, 255)

green = Color.from_hsv(120, 1, 1)
print(green)  # Color(0, 255, 0)
print(green.to_hex())  # #00ff00

练习题

练习 1:分数类

from math import gcd

class Fraction:
    """分数类"""
    
    def __init__(self, numerator, denominator=1):
        if denominator == 0:
            raise ValueError("分母不能为零")
        
        g = gcd(abs(numerator), abs(denominator))
        self._numerator = numerator // g
        self._denominator = denominator // g
        
        if self._denominator < 0:
            self._numerator = -self._numerator
            self._denominator = -self._denominator
    
    @classmethod
    def from_decimal(cls, decimal, precision=1000000):
        """从小数创建分数(使用给定精度)"""
        # 将小数转换为整数和分母
        numerator = int(decimal * precision)
        denominator = precision
        return cls(numerator, denominator)
    
    @classmethod
    def from_string(cls, s):
        """从字符串创建分数,如 '3/4' 或 '5'"""
        s = s.strip()
        if '/' in s:
            num, den = s.split('/')
            return cls(int(num.strip()), int(den.strip()))
        else:
            return cls(int(s))
    
    def __repr__(self):
        if self._denominator == 1:
            return f"Fraction({self._numerator})"
        return f"Fraction({self._numerator}/{self._denominator})"
    
    def __str__(self):
        if self._denominator == 1:
            return str(self._numerator)
        return f"{self._numerator}/{self._denominator}"
    
    def __eq__(self, other):
        if isinstance(other, int):
            other = Fraction(other)
        if not isinstance(other, Fraction):
            return False
        return (self._numerator == other._numerator and 
                self._denominator == other._denominator)

# 测试
f1 = Fraction(3, 6)  # 自动简化为 1/2
f2 = Fraction.from_decimal(0.5)
f3 = Fraction.from_string("1/2")
f4 = Fraction.from_string("3")

print(f1)  # 1/2
print(f2)  # 500000/1000000
print(f3)  # 1/2
print(f4)  # 3
print(f1 == f2)  # False(但值相同)

练习 2:配置管理器

import json
from pathlib import Path

class Config:
    """配置管理器"""
    
    _instance = None
    _config = {}
    
    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
    
    def __init__(self):
        self._initialized = False
    
    @classmethod
    def load_from_file(cls, filepath):
        """从文件加载配置"""
        path = Path(filepath)
        if not path.exists():
            raise FileNotFoundError(f"配置文件不存在:{filepath}")
        
        with open(path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)
        
        instance = cls()
        instance._config = config_dict
        instance._initialized = True
        return instance
    
    @classmethod
    def load_from_dict(cls, config_dict):
        """从字典加载配置"""
        instance = cls()
        instance._config = config_dict
        instance._initialized = True
        return instance
    
    @classmethod
    def get_instance(cls):
        """获取配置实例"""
        if cls._instance is None or not cls._instance._initialized:
            raise RuntimeError("配置未初始化,请先调用 load_from_file 或 load_from_dict")
        return cls._instance
    
    def get(self, key, default=None):
        """获取配置值"""
        keys = key.split('.')
        value = self._config
        for k in keys:
            if isinstance(value, dict):
                value = value.get(k)
                if value is None:
                    return default
            else:
                return default
        return value
    
    def __getitem__(self, key):
        """支持字典式访问"""
        value = self.get(key)
        if value is None:
            raise KeyError(f"配置项不存在:{key}")
        return value

# 测试
config_data = {
    "database": {
        "host": "localhost",
        "port": 3306,
        "name": "testdb"
    },
    "app": {
        "debug": True,
        "port": 8080
    }
}

config = Config.load_from_dict(config_data)
print(config.get("database.host"))  # localhost
print(config.get("app.debug"))       # True
print(config.get("nonexistent", "default"))  # default

练习 3:事件系统

class Event:
    """事件类"""
    
    def __init__(self, name, **kwargs):
        self.name = name
        self.data = kwargs
        self.timestamp = None
    
    def __repr__(self):
        return f"Event({self.name}, {self.data})"

class EventEmitter:
    """事件发射器"""
    
    def __init__(self):
        self._listeners = {}
    
    def on(self, event_name, listener):
        """注册事件监听器"""
        if event_name not in self._listeners:
            self._listeners[event_name] = []
        self._listeners[event_name].append(listener)
        return listener  # 返回监听器,方便后续移除
    
    def off(self, event_name, listener):
        """移除事件监听器"""
        if event_name in self._listeners:
            self._listeners[event_name].remove(listener)
    
    def emit(self, event):
        """发射事件"""
        if isinstance(event, str):
            event = Event(event)
        
        if event.name in self._listeners:
            for listener in self._listeners[event.name]:
                listener(event)
        
        # 也触发 '*' 监听器
        if '*' in self._listeners:
            for listener in self._listeners['*']:
                listener(event)
    
    @classmethod
    def create_with_handlers(cls, handlers):
        """工厂方法:从处理器字典创建发射器"""
        emitter = cls()
        for event_name, handler in handlers.items():
            emitter.on(event_name, handler)
        return emitter

# 使用
def on_click(event):
    print(f"点击事件:按钮={event.data.get('button')}")

def on_keypress(event):
    print(f"按键事件:键={event.data.get('key')}")

def on_any(event):
    print(f"任意事件:{event.name}")

emitter = EventEmitter.create_with_handlers({
    'click': on_click,
    'keypress': on_keypress,
    '*': on_any
})

emitter.emit(Event('click', button='submit'))
emitter.emit(Event('keypress', key='Enter'))
emitter.emit('custom_event')

# 输出:
# 点击事件:按钮=submit
# 任意事件:Event(click, {'button': 'submit'})
# 按键事件:键=Enter
# 任意事件:Event(keypress, {'key': 'Enter'})
# 任意事件:Event(custom_event, {})

总结

@classmethod 是 Python 中重要的装饰器:

  1. 第一个参数是 cls:指向类本身,而非实例
  2. 可用于创建工厂方法:提供多种方式创建实例
  3. 可访问类属性:但不能直接访问实例属性
  4. 可被继承:子类调用时 cls 是子类
  5. 应用场景:替代构造函数、单例模式、配置管理等

在下一节中,我们将学习 @staticmethod 装饰器和单例模式的更多内容。