from vnpy_ctastrategy import (
    CtaTemplate,
    StopOrder,
    TickData,
    BarData,
    TradeData,
    OrderData,
    BarGenerator,
    ArrayManager,
)


class MacdStrategy(CtaTemplate):
    """
    MACD line / signal line crossover strategy.

    On every completed bar the MACD indicator is computed from the
    ArrayManager history. When the MACD line crosses above the signal line
    the strategy goes long 1 lot (covering an existing short first); when it
    crosses below, the strategy goes short 1 lot (selling an existing long
    first). Orders are sent at the close price of the bar that triggered
    the signal.
    """

    author = "siye"

    # Indicator parameters passed to ArrayManager.macd.
    fast_window = 10
    slow_window = 20
    signal_period = 20

    # MACD line: value on the latest bar (0) and on the previous bar (1).
    # The "fast_"/"slow_" attribute names are kept because they are
    # published through `variables`.
    fast_macd0 = 0.0
    fast_macd1 = 0.0

    # Signal line: value on the latest bar (0) and on the previous bar (1).
    slow_macd0 = 0.0
    slow_macd1 = 0.0

    # signal_period is used in the calculation, so it is exposed as a
    # parameter alongside the two windows.
    parameters = ["fast_window", "slow_window", "signal_period"]
    variables = ["fast_macd0", "fast_macd1", "slow_macd0", "slow_macd1"]

    def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
        """Create the strategy with its bar generator and array manager."""
        super().__init__(cta_engine, strategy_name, vt_symbol, setting)

        self.bg = BarGenerator(self.on_bar)
        self.am = ArrayManager()

    def on_init(self) -> None:
        """
        Callback when strategy is inited.
        """
        self.write_log("策略初始化")
        self.load_bar(10)

    def on_start(self) -> None:
        """
        Callback when strategy is started.
        """
        self.write_log("策略启动")
        self.put_event()

    def on_stop(self) -> None:
        """
        Callback when strategy is stopped.
        """
        self.write_log("策略停止")

        self.put_event()

    def on_tick(self, tick: TickData) -> None:
        """
        Callback of new tick data update.
        """
        self.bg.update_tick(tick)

    def on_bar(self, bar: BarData) -> None:
        """
        Callback of new bar data update.

        Updates the indicator history, detects a MACD/signal crossover
        between the previous and the latest bar, and sends the matching
        orders at the bar's close price.
        """
        # Drop orders left unfilled from earlier bars so they cannot fill
        # later at a stale price or stack up with the new ones.
        self.cancel_all()

        am = self.am
        am.update_bar(bar)
        if not am.inited:
            return

        # With array=True, macd() returns three arrays (macd, signal, hist).
        # Unpack them explicitly: indexing the returned tuple with [-1]/[-2]
        # would select whole arrays rather than the latest two values.
        macd_line, signal_line, _ = am.macd(
            self.fast_window,
            self.slow_window,
            self.signal_period,
            array=True,
        )

        self.fast_macd0 = macd_line[-1]
        self.fast_macd1 = macd_line[-2]

        self.slow_macd0 = signal_line[-1]
        self.slow_macd1 = signal_line[-2]

        # A cross happens when the ordering of the two lines flips between
        # the previous bar and the latest bar.
        cross_over = (
            self.fast_macd0 > self.slow_macd0
            and self.fast_macd1 < self.slow_macd1
        )
        cross_below = (
            self.fast_macd0 < self.slow_macd0
            and self.fast_macd1 > self.slow_macd1
        )

        if cross_over:
            if self.pos == 0:
                self.buy(bar.close_price, 1)
            elif self.pos < 0:
                # Close the short first, then open the long.
                self.cover(bar.close_price, 1)
                self.buy(bar.close_price, 1)

        elif cross_below:
            if self.pos == 0:
                self.short(bar.close_price, 1)
            elif self.pos > 0:
                # Close the long first, then open the short.
                self.sell(bar.close_price, 1)
                self.short(bar.close_price, 1)

        self.put_event()

    def on_order(self, order: OrderData) -> None:
        """
        Callback of new order data update.
        """
        pass

    def on_trade(self, trade: TradeData) -> None:
        """
        Callback of new trade data update.
        """
        self.put_event()

    def on_stop_order(self, stop_order: StopOrder) -> None:
        """
        Callback of stop order update.
        """
        pass