import torch
import torch.nn as nn
import logging

from typing import List, Dict, Optional, Tuple


class Group:
    """A named bundle of modules that the scheduler treats as one unit.

    Attributes:
        name: human-readable label of the group.
        modules: the nn.Module objects belonging to the group.
        params: the parameters belonging to the group, as passed in.
    """

    def __init__(self, name: str, modules: List[nn.Module], params: List[nn.Parameter]):
        self.name, self.modules, self.params = name, modules, params


class State:
    """Scheduler state over groups numbered 1..n.

    Fields:
        l: first trainable group index
        r: last trainable group index
        e: last executed group index (length of the executed prefix)
        n: total number of schedulable groups

    Groups in [l, r] are trainable, groups in [1, e] but outside [l, r] are
    frozen, and groups in (e, n] are exited (not executed).
    """

    def __init__(self, l:int, r:int, e:int, n:int):
        self.l, self.r, self.e, self.n = l, r, e, n

    def clone(self) -> "State":
        """Return an independent copy with the same (l, r, e, n)."""
        return State(l=self.l, r=self.r, e=self.e, n=self.n)

    def __repr__(self) -> str:
        fields = ", ".join(f"{key}={getattr(self, key)}" for key in ("l", "r", "e", "n"))
        return "State(" + fields + ")"


class UtilityEMA:
    """Per-group exponential moving averages of three utility signals.

    Every tracked entry v is updated as ``v <- beta * v + (1 - beta) * x``;
    all entries start at 0.0.
    """

    def __init__(self, n_groups:int, beta:float=0.9):
        self.beta = beta
        self.grad_ema = [0.0 for _ in range(n_groups)]
        self.drift_ema = [0.0 for _ in range(n_groups)]
        self.conf_ema = [0.0 for _ in range(n_groups)]

    def _blend(self, old: float, new: float) -> float:
        # single EMA step shared by the three update methods
        return self.beta*old + (1-self.beta)*new

    def update_grad(self, j:int, g_norm:float):
        """Fold gradient norm ``g_norm`` into entry ``j`` of grad_ema."""
        self.grad_ema[j] = self._blend(self.grad_ema[j], g_norm)

    def update_drift(self, j:int, drift:float):
        """Fold ``drift`` into entry ``j`` of drift_ema."""
        self.drift_ema[j] = self._blend(self.drift_ema[j], drift)

    def update_conf(self, j:int, conf:float):
        """Fold ``conf`` into entry ``j`` of conf_ema."""
        self.conf_ema[j] = self._blend(self.conf_ema[j], conf)



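#--- Three-state scheduler (TriMS)
# Seven actions: shrink/expand either side of the trainable window, push/pop the executed prefix, or continue.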
import time
import math
import types
import random
import numpy as np
from collections import deque

_EPS = 1e-12
# integer action ids, 0..6 in this order
LEFT_SHRINK, LEFT_EXPAND, RIGHT_SHRINK, RIGHT_EXPAND, PUSH, POP, CONT = range(7)
# action id -> action name
action_dict = dict(enumerate(
    ('LEFT_SHRINK', 'LEFT_EXPAND', 'RIGHT_SHRINK', 'RIGHT_EXPAND', 'PUSH', 'POP', 'CONT')
))

class TriMS:
    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        M_max: float,                          # memory budget (GiB)
        safe_ratio: float=0.1,                 # fraction of the budget kept in reserve
        group_size: int=1,                     # number of layers per group
        offload_exit_to_cpu: bool=True,        # move exited groups to CPU
        init_method: str='head',               # how to build the first state s_0
        switch_period: int=10,                 # steps per decision window
        lambda_M: float=0.2,                   # memory weight in reward
        lambda_T: float=0.1,                   # time weight in reward
        explore_eps: float=0.05,               # probability of a random action
        device: Optional[torch.device] = None,
    ):
        """Create the scheduler and apply the initial state to ``model``.

        The model class name selects the layer layout ("roberta" or "llama").
        The model's blocks are split into groups of ``group_size``. The
        initial state comes from ``init_method``, and a heuristic cost table
        is derived from parameter sizes. The state is then applied:
        requires_grad flags, optimizer param groups, optional CPU offload,
        and the forward patch.

        Args:
            model: model to schedule.
            optimizer: optimizer whose param groups are rebuilt on each switch.
            M_max: memory budget in GiB (stored in bytes as M_max * 1024**3).
            safe_ratio: candidate states must stay under M_max * (1 - safe_ratio).
            group_size: number of consecutive blocks per group.
            offload_exit_to_cpu: move groups beyond the executed prefix to CPU.
            init_method: 'head' or 'head@2' (see _first_state).
            switch_period: number of step() calls per decision window.
            lambda_M: reward weight for memory (stored as self.lambda_M).
            lambda_T: reward weight for time (stored as self.lambda_T).
            explore_eps: probability of picking a random feasible action.
            device: CUDA device used for memory statistics; None means the
                current device.

        Raises:
            ValueError: if the model class name contains neither "roberta"
                nor "llama" (raised by _build_groups).
        """
        self.model = model
        self.optimizer   = optimizer
        self.M_max       = M_max*(1024**3)          # GiB -> bytes
        self.safe_ratio  = safe_ratio
        self.device      = device
        self.offload_exit_to_cpu = offload_exit_to_cpu

        self.n_switch      = 0
        self.init_method   = init_method
        self.switch_period = switch_period
        self.lambda_M      = lambda_M
        self.lambda_T      = lambda_T
        self.explore_eps   = explore_eps

        self.beta  = 0.9
        self.decay = 0.98

        # Forward-patch bookkeeping. It is initialised exactly once, here,
        # before _apply_state_to_model() below calls _patch_forward_for_exit().
        self._forward_patched = False
        self._orig_forward = None

        # cost-model coefficients, fitted by calibrate()
        self.mem_base_bytes  = 0.0
        self.aP = 1.0
        self.aA = 1.0
        self.aG = 1.0
        self.aO = 1.0
        self.time_base       = 0.0
        self.time_ratio_ema  = 1.0
        self.calib_beta      = 0.9

        # decide model & optimizer type from the class names
        name = model.__class__.__name__.lower()
        if "roberta" in name:
            self.model_type = "roberta"
        elif "llama" in name:
            self.model_type = "llama"
        else:
            self.model_type = "other"

        self.optimizer_type = optimizer.__class__.__name__.lower()

        # 1. split the module groups
        # NOTE(review): n = len(groups) - 1, and states never exceed e = n.
        # The last group (1-based index n+1, e.g. the classifier/lm_head when
        # present) is therefore always treated as exited by
        # _apply_state_to_model. Confirm this is intended.
        self.groups: List[Group] = self._build_groups(self.model_type, group_size)
        self.n = len(self.groups) - 1

        # 2. init the first state
        self.state: State = self._first_state(self.init_method)

        # 3. estimate the cost
        self.cost_table = self._init_cost_from_model()

        # 4. apply the state to the model (this also calls _patch_forward_for_exit)
        self._apply_state_to_model(self.state)

        # 5. per-window buffers
        self._buf_loss = deque(maxlen=self.switch_period)
        self._buf_time = deque(maxlen=self.switch_period)
        self._buf_mem  = deque(maxlen=self.switch_period)

        self._prev_win = None   # previous window summary (loss, time, mem)
        self._prev_act = None   # action chosen at the previous switch
        self._last_ts  = time.perf_counter()
        torch.cuda.reset_peak_memory_stats(self.device)

        self.ue = UtilityEMA(n_groups=self.n, beta=0.9)
        self._act_prev = [None] * self.n
        self._act_cur  = [None] * self.n
        self._register_activation_hooks()

        # _forward_patched / _orig_forward are deliberately NOT reset here.
        # Step 4 may already have patched the forward (presumably recording the
        # original there), and clearing them would lose that bookkeeping.
        self._bind_trims_to_target()
    

    def _get_forward_target(self):
        t = self.model
        # DataParallel / DDP
        if hasattr(t, "module"):
            t = t.module
        return t

    def _bind_trims_to_target(self):
        target = self._get_forward_target()
        setattr(target, "_trims", self)




    # ------------------------------------
    # generate the first state (left, right, end) 1...n
    # ------------------------------------
    def _first_state(self, init_method) -> State:
        """Build the initial state s_0 = State(l, r, e, n) with 1-based indices.

        'head'   : only group n is trainable (l = r = e = n).
        'head@2' : the last two groups n-1..n are trainable. With a single
                   group (n == 1), l is clamped to 1 rather than becoming the
                   invalid index 0.
        Any other value falls back to 'head'.
        """
        if init_method == 'head@2':
            return State(max(1, self.n - 1), self.n, self.n, self.n)
        return State(self.n, self.n, self.n, self.n)


    # ------------------------------------
    # split the layers into several blocks
    # ------------------------------------
    def _build_groups(self, model_type:str, group_size:int) -> List[Group]:
        """Split the model's transformer blocks into consecutive groups.

        Blocks are taken ``group_size`` at a time (the last chunk may be
        shorter). If the model has an output head (``classifier`` for RoBERTa,
        ``lm_head`` for LLaMA), it is appended once, as a separate final group.

        Args:
            model_type: "roberta" or "llama".
            group_size: number of consecutive blocks per group (>= 1).

        Returns:
            The list of groups, block groups first, then the optional head.

        Raises:
            ValueError: for an unsupported model_type or group_size < 1.
        """
        if group_size < 1:
            raise ValueError(f"group_size must be >= 1, got {group_size}")

        if model_type == "roberta":
            blocks: List[nn.Module] = list(self.model.roberta.encoder.layer)
            name_fmt = "enc[{}:{}]"
            head_attr, head_name = "classifier", "head"
        elif model_type == 'llama':
            blocks = list(self.model.model.layers)
            name_fmt = "blk[{}:{})"
            head_attr, head_name = "lm_head", "lm_head"
        else:
            raise ValueError(f"Unsupported model_type: {model_type}")

        groups: List[Group] = []
        for start in range(0, len(blocks), group_size):
            mods = blocks[start:start + group_size]
            params: List[nn.Parameter] = []
            for m in mods:
                params.extend(m.parameters())
            groups.append(Group(name=name_fmt.format(start, start + group_size),
                                modules=mods, params=params))

        # The head is added exactly once, after all block groups. (The llama
        # branch previously appended it inside the block loop, once per block
        # group.)
        if hasattr(self.model, head_attr):
            head = getattr(self.model, head_attr)
            groups.append(Group(name=head_name, modules=[head],
                                params=list(head.parameters())))

        return groups


    # ------------------------------------
    # estimate the memory / time cost
    # - M_P: module parameters
    # - M_G: gradients
    # - M_O: optimizer states
    # - M_A: intermediate activations
    # - T_FP: forward
    # - T_GP: gradient propagation
    # - T_UP: update process
    # ------------------------------------
    def _init_cost_from_model(self) -> Dict[str, float]:

        # detect the save format
        def dtype_bytes(dt: torch.dtype) -> int:
            if dt == torch.float16 or dt == torch.bfloat16:
                return 2
            if dt == torch.float32:
                return 4
            if dt == torch.float64:
                return 8
            else:
                return 4

        # estimate the M_P with the module parameters
        M_Ps = []
        for g in self.groups:
            p_bytes = 0
            for p in g.params:
                p_bytes += p.numel()*dtype_bytes(p.dtype)
            M_Ps.append(p_bytes)
        
        # estimate M_G, M_A & M_O
        M_Gs = M_Ps
        M_As = [0.5 * m for m in M_Ps] # adjust by calibrate() in training

        if self.optimizer_type in ['adam', 'adamw']:
            M_Os = [2.0 * m for m in M_Ps]
        elif self.optimizer_type in ['sgd']:
            M_Os = [0.0 for m in M_Ps]
            for param_group in self.optimizer.param_groups:
                if param_group.get('momentum', 0) > 0:
                    M_Os = [1.0 * m for m in M_Ps]
        elif self.optimizer_type in ['adagrad', 'rmsprop']:
            M_Os = [1.0 * m for m in M_Ps]
        else:
            M_Os = [0.5 * m for m in M_Ps]
        
        # estimate time cost, adjust by calibrate() in training
        T_FPs = [1.0 for m in M_Ps]
        T_GPs = [1.0 for m in M_Ps]
        T_UPs = [0.5 for m in M_Ps]

        return {'M_P': M_Ps, 'M_G': M_Gs, 'M_A': M_As, 'M_O': M_Os,
                'T_FP': T_FPs, 'T_GP': T_GPs, 'T_UP': T_UPs}
        


    # ------------------------------------
    # estimate time cost based on (l,r,e)
    # ------------------------------------
    def estimate_time(self, st: Optional[State]=None)->float:
        s = st or self.state
        Tfp = sum(self.cost_table['T_FP'][:s.e])
        Tgp = sum(self.cost_table['T_GP'][s.l-1:s.e])
        Tup = sum(self.cost_table['T_UP'][s.l-1:s.r])
        core = Tfp + Tgp + Tup
        # 比例 + 基线
        return float(self.time_base + self.time_ratio_ema * core)



    # ------------------------------------
    # estimate memory cost based on (l,r,e)
    # ------------------------------------
    def estimate_memory(self, st: Optional[State]=None) -> int:
        s = st or self.state
        MP = sum(self.cost_table['M_P'][:s.e])
        MA = sum(self.cost_table['M_A'][s.l-1:s.e])
        MG = sum(self.cost_table['M_G'][s.l-1:s.r])
        MO = sum(self.cost_table['M_O'][s.l-1:s.r])
        est = (self.mem_base_bytes
            + self.aP * MP
            + self.aA * MA
            + self.aG * MG
            + self.aO * MO)
        return int(est)



    # ------------------------------------
    # check whether the action is safe (structure)
    # ------------------------------------
    def _check_action(self, st: Optional[State]=None) -> List[bool]:
        """Return a 7-entry feasibility mask indexed by action id.

        An action is feasible when it is structurally allowed in ``st``
        (defaults to self.state) AND the state it would produce has an
        estimated memory within M_max * (1 - safe_ratio). CONT is always
        structurally allowed, but it still has to pass the memory check.
        """
        s = st or self.state
        l, r, e, n = s.l, s.r, s.e, s.n

        structural = {
            LEFT_SHRINK:  l < r,
            LEFT_EXPAND:  l > 1 and (l - 1) <= e,
            RIGHT_SHRINK: r > l,
            RIGHT_EXPAND: (r < e) or (r == e and e < n),
            PUSH:         e < n,
            POP:          (e > 1) and (r <= e - 1),
            CONT:         True,
        }
        budget = self.M_max * (1 - self.safe_ratio)

        feasible: List[bool] = []
        for action in range(7):
            # memory is only estimated for structurally allowed actions
            ok = structural[action] and \
                self.estimate_memory(self._apply_action_preview(s, action)) <= budget
            feasible.append(bool(ok))
        return feasible



    # ------------------------------------
    # action apply
    # ------------------------------------
    def _apply_action_preview(self, s: State, a: int) -> State:
        """Return the state that action ``a`` would produce from ``s``.

        ``s`` is not modified. An action whose precondition fails leaves the
        corresponding index unchanged. The result is clamped so that
        1 <= l <= r <= e <= n.
        """
        n = s.n
        left, right, end = s.l, s.r, s.e

        if a == LEFT_SHRINK:
            if left < right:
                left += 1
        elif a == LEFT_EXPAND:
            if left > 1 and (left - 1) <= end:
                left -= 1
        elif a == RIGHT_SHRINK:
            if right > left:
                right -= 1
        elif a == RIGHT_EXPAND:
            if right < end:
                right += 1
            elif right == end and end < n:
                # grow the executed prefix together with the trainable window
                right, end = right + 1, end + 1
        elif a == PUSH:
            if end < n:
                end += 1
        elif a == POP:
            if end > 1:
                if right == end:
                    right, end = right - 1, end - 1
                elif right <= end - 1:
                    end -= 1
        # CONT (and any unknown id) keeps the indices as they are

        left = max(1, min(left, n))
        right = max(left, min(right, end))
        end = max(right, min(end, n))
        return State(left, right, end, n)



    # ------------------------------------
    # apply state to model
    # ------------------------------------
    def _apply_state_to_model(self, st: Optional[State]=None):
        """Make the model reflect state ``st`` (defaults to self.state).

        Groups are numbered from 1. The steps, in order:
          1. if offload_exit_to_cpu, move every module of groups with index > e
             to CPU;
          2. set requires_grad on each group's params: True only for
             l <= i <= r with i <= e;
          3. rebuild the optimizer param groups;
          4. call _patch_forward_for_exit().

        NOTE(review): nothing here moves a group back off the CPU when it
        re-enters the executed prefix (PUSH / RIGHT_EXPAND). Confirm that
        _patch_forward_for_exit or a caller handles that device placement.
        """
        s = st or self.state
        cpu = torch.device('cpu')

        # 1. offload exited groups
        if self.offload_exit_to_cpu:
            exited = (g for idx, g in enumerate(self.groups, start=1) if idx > s.e)
            for g in exited:
                for m in g.modules:
                    m.to(cpu)

        # 2. trainable = inside [l, r] and inside the executed prefix
        train_idx = set(range(s.l, s.r + 1))
        for idx, g in enumerate(self.groups, start=1):
            trainable = (idx in train_idx) and (idx <= s.e)
            for p in g.params:
                p.requires_grad = bool(trainable)

        # 3./4. optimizer groups and forward patch
        self._rebuild_optimizer_param_groups(train_idx, s)
        self._patch_forward_for_exit()


    # ------------------------------------
    # rebuild optimizer
    # ------------------------------------
    def _rebuild_optimizer_param_groups(self, trainable_indices:set, st: Optional[State]=None):
        s = st or self.state
        trainable_params = []

        for i,g in enumerate(self.groups, start=1):
            if (i in trainable_indices) and (i <= s.e):
                trainable_params.extend([p for p in g.params if p.requires_grad])

        base = dict(self.optimizer.param_groups[0].items()); base.pop('params', None)
        self.optimizer.param_groups.clear()
        if len(trainable_params)>0:
            self.optimizer.add_param_group({'params': trainable_params, **base})

        # clean state
        trainable_ids = {id(p) for p in trainable_params}
        
        to_del = []
        for p in list(self.optimizer.state.keys()):
            if (not getattr(p, "requires_grad", False)) or (id(p) not in trainable_ids):
                to_del.append(p)

        for p in to_del:
            self.optimizer.state.pop(p, None)



    # ------------------------------------
    # calibrate
    # ------------------------------------
    def calibrate(self, measured_mem: float, measured_time: float, alpha: float=0.8):
        beta = 1.0 - alpha
        s = self.state

        # ---------- Memory ----------
        MP = sum(self.cost_table['M_P'][:s.e])
        MA = sum(self.cost_table['M_A'][s.l-1:s.e])
        MG = sum(self.cost_table['M_G'][s.l-1:s.r])
        MO = sum(self.cost_table['M_O'][s.l-1:s.r])
        core_est = self.aP * MP + self.aA * MA + self.aG * MG + self.aO * MO

        est_now  = self.mem_base_bytes + core_est
        if est_now > 0:
            resid = max(0.0, float(measured_mem) - float(self.mem_base_bytes))
            denom = (MP + MA + MG + MO) + 1e-12
            wP, wA, wG, wO = MP/denom, MA/denom, MG/denom, MO/denom
            r_core = (resid / max(core_est, 1e-12))

            r_core = float(max(0.5, min(2.0, r_core)))

            self.aP = beta * self.aP + alpha * (self.aP * (wP * r_core + (1 - wP)))
            self.aA = beta * self.aA + alpha * (self.aA * (wA * r_core + (1 - wA)))
            self.aG = beta * self.aG + alpha * (self.aG * (wG * r_core + (1 - wG)))
            self.aO = beta * self.aO + alpha * (self.aO * (wO * r_core + (1 - wO)))

            core_est2 = self.aP * MP + self.aA * MA + self.aG * MG + self.aO * MO
            resid2 = float(measured_mem) - float(self.mem_base_bytes + core_est2)
            if resid2 > 0:
                self.mem_base_bytes = beta * self.mem_base_bytes + alpha * resid2

        # ---------- Time ----------
        Tfp = sum(self.cost_table['T_FP'][:s.e])
        Tgp = sum(self.cost_table['T_GP'][s.l-1:s.e])
        Tup = sum(self.cost_table['T_UP'][s.l-1:s.r])
        core_T = Tfp + Tgp + Tup

        est_T = self.time_base + self.time_ratio_ema * core_T
        if est_T > 0:
            rT = float(measured_time) / float(est_T)
            self.time_ratio_ema = beta * self.time_ratio_ema + alpha * rT
            est_T2 = self.time_base + self.time_ratio_ema * core_T
            residT = float(measured_time) - float(est_T2)
            if residT > 0:
                self.time_base = beta * self.time_base + alpha * residT
   


    # ------------------------------------
    # step
    # ------------------------------------
    def step(self, loss: Optional[float]=None):
        """Record one training step and, once per window, switch state.

        Every call:
          - refreshes the utility proxies (_update_proxy());
          - buffers the loss, the wall time since the previous call, and the
            CUDA peak allocated memory.

        Once switch_period samples are buffered:
          - the window is summarised (mean loss, mean step time, max peak
            memory);
          - the kappa sensitivity estimates are updated from the previous
            action's effect;
          - an action is chosen: the first window uses LEFT_EXPAND; later
            windows call calibrate() and pick epsilon-greedily;
          - the action is applied, and the CUDA cache and peak statistics are
            reset.

        Args:
            loss: loss of the current step (anything accepted by float()).

        Raises:
            ValueError: if ``loss`` is None; every window needs loss values.
        """
        if loss is None:
            raise ValueError("TriMS.step() requires the loss value of the current step")

        def _sensitivity(delta_loss: float, utility: float) -> float:
            # |delta loss| / |utility|. The denominator is floored at _EPS so a
            # zero utility entry (UtilityEMA starts every entry at 0.0) cannot
            # raise ZeroDivisionError.
            return abs(delta_loss) / max(abs(utility), _EPS)

        self._update_proxy()

        # sample time and peak memory of this step
        cur_time = time.perf_counter()
        cur_mem  = torch.cuda.max_memory_allocated(device=self.device)

        self._buf_loss.append(float(loss))
        self._buf_time.append(float(cur_time-self._last_ts))
        self._buf_mem.append(int(cur_mem))

        self._last_ts = cur_time

        # decide only once a full window has been collected
        if len(self._buf_time) < self.switch_period:
            return

        self.n_switch += 1

        # window summary
        cur_win = {
            'loss': float(np.mean(self._buf_loss)),
            'time': float(np.mean(self._buf_time)),
            'mem':  float(np.max(self._buf_mem)),
        }

        # kappa: loss change per unit of gradient utility
        if self.n_switch==2:
            d_loss = cur_win['loss'] - self._prev_win['loss']
            self.kappa = _sensitivity(d_loss, self.ue.grad_ema[-2])
            self.lambda_frz = self.kappa * 0.2
            self.kappa_shr  = self.kappa + self.lambda_frz
            self.kappa_exit = self.kappa * 0.75
        elif self.n_switch>2:
            d_loss = cur_win['loss'] - self._prev_win['loss']
            if self._prev_act in [LEFT_EXPAND, RIGHT_EXPAND]:
                tot_s = sum(self.ue.grad_ema[i] for i in range(self.state.l-1, self.state.r))
                k_hat = _sensitivity(d_loss, tot_s)
                self.kappa = self.kappa * self.beta + k_hat * (1-self.beta)
            elif self._prev_act in [LEFT_SHRINK]:
                k_shr = _sensitivity(d_loss, self.ue.grad_ema[self.state.l-2])
                self.kappa_shr = self.kappa_shr * self.beta + k_shr * (1-self.beta)
                self.lambda_frz = self.kappa_shr - self.kappa
            elif self._prev_act in [RIGHT_SHRINK]:
                # NOTE(review): the group frozen by RIGHT_SHRINK is 1-based r+1,
                # i.e. index r if grad_ema is 0-based as in the LEFT_* branches.
                # This reads index r-1; confirm against _update_proxy.
                k_shr = _sensitivity(d_loss, self.ue.grad_ema[self.state.r-1])
                self.kappa_shr = self.kappa_shr * self.beta + k_shr * (1-self.beta)
                self.lambda_frz = self.kappa_shr - self.kappa
            elif self._prev_act in [POP]:
                # NOTE(review): the popped group is 1-based e+1 (0-based e);
                # this reads index e-1. Confirm the intended indexing.
                k_ext = _sensitivity(d_loss, self.ue.grad_ema[self.state.e-1])
                self.kappa_exit = self.kappa_exit * self.beta + k_ext * (1-self.beta)
            elif self._prev_act in [PUSH]:
                # NOTE(review): the pushed group is 1-based e (0-based e-1);
                # this reads index e-2. Confirm the intended indexing.
                k_ext = _sensitivity(d_loss, self.ue.grad_ema[self.state.e-2])
                self.kappa_exit = self.kappa_exit * self.beta + k_ext * (1-self.beta)
            elif self._prev_act in [CONT]:
                self.kappa = self.kappa * (1-1e-3)
                self.lambda_frz = self.lambda_frz * (1-1e-3)
                self.kappa_exit = self.kappa_exit * (1-1e-3)

        # reset buffers for the next window
        self._buf_loss.clear()
        self._buf_time.clear()
        self._buf_mem.clear()

        # candidate actions: structurally valid and within the memory budget
        mask = self._check_action(self.state)
        actions = [a for a, ok in enumerate(mask) if ok]
        if len(actions)==0:
            actions = [CONT]

        if self._prev_win is None:
            # First window: no history yet. Seed mem_base_bytes from the
            # measured peak and expand the trainable window to the left.
            # NOTE(review): LEFT_EXPAND is taken without consulting ``mask``.
            a = LEFT_EXPAND
            core_now = sum(self.cost_table['M_P'][:self.state.e]) \
                    + sum(self.cost_table['M_A'][self.state.l-1:self.state.e]) \
                    + sum(self.cost_table['M_G'][self.state.l-1:self.state.r]) \
                    + sum(self.cost_table['M_O'][self.state.l-1:self.state.r])
            alloc_now = cur_win['mem']
            self.mem_base_bytes = max(0.0, float(alloc_now) - core_now)
        elif random.random() < self.explore_eps:
            self.calibrate(cur_win['mem'], cur_win['time'])
            a = random.choice(actions)
        else:
            self.calibrate(cur_win['mem'], cur_win['time'])
            a = self._select_action(mask, self.state, self._prev_win, cur_win)

        # apply action
        s2 = self._apply_action_preview(self.state, a)
        if a != CONT:
            self._restore_forward_if_needed()
            self.state = s2
            print('est mem. ', int(self.estimate_memory(s2)/1024/1024))
            self._apply_state_to_model(self.state)
            # NOTE(review): _apply_state_to_model already calls
            # _patch_forward_for_exit; this second call relies on it being
            # idempotent.
            self._patch_forward_for_exit()

        self._prev_win = cur_win
        self._prev_act = a

        # drop stale gradients of groups that are no longer trainable
        for i, g in enumerate(self.groups, start=1):
            train = (i >= self.state.l and i <= self.state.r and i <= self.state.e)
            for p in g.params:
                if not train:
                    p.grad = None

        import gc
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize(self.device)

        torch.cuda.reset_peak_memory_stats(self.device)

        self._logout(action=a, state=s2, loss=loss, mem=cur_mem)




    # ------------------------------------
    # select action
    # ------------------------------------
    def _select_action(self, mask:List[bool], state: State, prev_win, cur_win) -> int:


    # ------------------------------------
    # regist activation hook
    # ------------------------------------
    def _register_activation_hooks(self):


    def remove_activation_hooks(self):

    def reinstall_activation_hooks(self):
    
    
    XXX partial