import logging
from typing import Dict
import torch
import numpy as np
import copy

# Per-task tables consumed by cache_init / cache_release below. The outer key
# of each table is the task identifier looked up through ``self.task``.

# Sub-module names tracked in each stream. cache_init uses these names as the
# keys of the per-stream "dynamic_error" / "loaded_alpha" tables.
MODULE_LIST = {
    't2v-1.3B': {
        'cond_stream': ['self-attention', 'cross-attention', 'ffn'],
        'uncond_stream': ['self-attention', 'cross-attention', 'ffn']
    },
    't2v-14B': { # Wan2.1
        'cond_stream': ['self-attention', 'cross-attention', 'ffn'],
        'uncond_stream': ['self-attention', 'cross-attention', 'ffn']
    },
    'i2v-A14B': { # Wan2.2
        'cond_stream': ['self-attention', 'cross-attention', 'ffn'],
        'uncond_stream': ['self-attention', 'cross-attention', 'ffn']
    },
    'ti2v-5B': {
        'cond_stream': ['self-attention', 'cross-attention', 'ffn'],
        'uncond_stream': ['self-attention', 'cross-attention', 'ffn']
    },
    'hunyuan': {
        'single_stream': ['total'],
        'double_stream': ['img-attn', 'img-mlp', 'txt-attn', 'txt-mlp']
    },
    'flux-dev': {
        'single_stream': ['total'],
        'double_stream': ['img-attn', 'img-mlp', 'txt-attn', 'txt-mlp']
    },
}

# Stream names of each task. Every stream listed here is expected to have a
# matching entry in MODULE_LIST[task] and LAYER_NUM[task], because cache_init
# indexes both tables with these names.
STREAM_LIST = {
    't2v-1.3B': ['cond_stream', 'uncond_stream'],
    'ti2v-5B': ['cond_stream', 'uncond_stream'],
    't2v-14B': ['cond_stream', 'uncond_stream'],
    'i2v-A14B': ['cond_stream', 'uncond_stream'],
    'hunyuan': ['single_stream', 'double_stream'],
    'flux-dev': ['single_stream', 'double_stream']
}

# Number of layers in each stream. Sizes the per-layer cache and, multiplied
# by the step count, gives the 'sum_amount' denominator used by calculate_rate.
LAYER_NUM = {
    't2v-14B': {
        'cond_stream': 40,
        'uncond_stream': 40
    },
    'i2v-A14B': {
        'cond_stream': 40,
        'uncond_stream': 40
    },
    't2v-1.3B': {
        'cond_stream': 30,
        'uncond_stream': 30
    },
    'ti2v-5B': {
        'cond_stream': 30,
        'uncond_stream': 30
    },
    'hunyuan': {
        'single_stream': 40,
        'double_stream': 20,
    },
    'flux-dev': {
        'single_stream': 38,
        'double_stream': 19
    },
}

# Per-task value copied into cache_dic['error_rate'] by cache_init.
# NOTE(review): how this value is consumed is not visible in this part of the
# file -- presumably a scale factor for the dynamic-cache error threshold;
# confirm against the code that reads cache_dic['error_rate'].
ERROR_RATE = {
    't2v-14B': 1.0,
    'i2v-A14B': 1.0,
    't2v-1.3B': 1.0,
    'ti2v-5B': 0.8,
    'hunyuan': 1.0,
    'flux-dev': 1.7
}

# Module-level accumulators filled by calculate_rate, one entry per call made
# with a falsy 'update_alpha': the mean computed/total ratio of the run and the
# run's final 'error_threshold'. Nothing in the code shown here clears them.
CAL_AMOUNT_LIST = []
ERROR_THRESHOLD_LIST = []

def cache_init(self, num_steps=50):
    '''
    Set up the cache bookkeeping for a new sampling run.

    Creates ``self.cache_dic`` (configuration flags, counters and the nested
    ``'cache'`` storage) and ``self.current`` (step tracking), using the
    per-task tables STREAM_LIST, MODULE_LIST, LAYER_NUM and ERROR_RATE.

    Attributes read from ``self``: ``task``, ``mode``, ``dynamic_cache``,
    ``use_alpha``, ``update_alpha`` and ``first_enhance``.

    Args:
        num_steps: number of steps of the run; sizes ``sum_amount`` and the
            per-step ``activated_layers`` table.

    Raises:
        KeyError: if ``self.task`` is missing from one of the per-task tables.

    Mode handling:
        * 'Original' and 'Taylor' start with an empty ``loaded_alpha`` table;
          'Taylor' additionally sets ``max_order`` to 1 and always takes
          ``first_enhance`` from ``self.first_enhance``.
        * 'Scaling' loads ``loaded_alpha`` from a ``.pth`` file unless
          ``self.update_alpha`` is truthy, in which case it starts empty.
        * Any other mode leaves ``loaded_alpha`` unset.
    '''
    logging.info(f"[CacheInit]: mode {self.mode}, dynamic_cache {self.dynamic_cache}, use_alpha {self.use_alpha}, first enhance = {self.first_enhance}")

    # Per-task lookups, resolved once and reused below.
    task = self.task
    streams = STREAM_LIST[task]
    modules = MODULE_LIST[task]
    depths = LAYER_NUM[task]
    mode = self.mode
    is_scaling = mode == 'Scaling'

    def empty_module_table():
        # Fresh {stream: {module_name: {}}} nest; each call returns new dicts.
        return {s: {m: {} for m in modules[s]} for s in streams}

    # Flags, counters and error-tracking state shared by every mode.
    self.cache_dic = {
        'cache_counter': 0,
        'stream_list': streams,
        'module_list': modules,
        'layer_num': depths,
        'scaling_cache': is_scaling,
        'taylor_cache': mode == 'Taylor',
        'error_rate': ERROR_RATE[task],
        'test_FLOPs': False,
        'update_alpha': self.update_alpha,
        'use_alpha': self.use_alpha,
        'dynamic_cache': self.dynamic_cache and is_scaling,
        'fresh_threshold': 2,
        'max_order': 0,
        'num_steps': num_steps,
        # While alpha is being updated, first_enhance is pinned to 2.
        'first_enhance': 2 if self.update_alpha else self.first_enhance,
        'last_enhance': 2,
        'accumulative_error': 0.0,
        'error_threshold': 0.0,
        'history_error': [],
        'sum_amount': {s: depths[s] * num_steps for s in streams},
        'cal_amount': {s: 0 for s in streams},
    }

    # Per-stream storage: one empty slot per layer, and for every step the
    # index of the last layer of that stream.
    layer_slots = {}
    last_layer_by_step = {}
    for s in streams:
        depth = depths[s]
        layer_slots[s] = {layer: {} for layer in range(depth)}
        last_layer_by_step[s] = dict.fromkeys(range(num_steps), depth - 1)

    storage = {
        -1: layer_slots,
        'activated_layers': last_layer_by_step,
        "dynamic_error": empty_module_table(),
    }
    self.cache_dic['cache'] = storage

    # Mode-specific overrides.
    if mode == 'Taylor':
        self.cache_dic.update(
            fresh_threshold=2,
            first_enhance=self.first_enhance,
            max_order=1,
        )

    if mode in ('Original', 'Taylor'):
        storage["loaded_alpha"] = empty_module_table()
    elif is_scaling:
        if self.update_alpha:
            storage["loaded_alpha"] = empty_module_table()
        else:
            # NOTE: relative path, resolved against the current working directory.
            storage["loaded_alpha"] = torch.load(f"../assets/alpha_dict/alpha_dict_{self.task}.pth")

    # Step-tracking state: the run starts on a full step at step 0.
    self.current = {
        'activated_steps': [0],
        'step': 0,
        'type': 'full',
        'num_steps': num_steps
    }

def cache_release(self):
    '''
    Drop cached tensors and reset the per-run cache state.

    Steps, in order:
      1. Record the finished run's statistics via ``calculate_rate``.
      2. Log CUDA allocator usage *before* anything is released.
      3. Replace every tensor stored under ``self.cache_dic['cache'][-1]``
         with ``None`` and rebuild the empty per-layer cache, the
         ``activated_layers`` table and the ``dynamic_error`` table.
      4. Reset the error/computation counters and ``self.current``.
      5. Call ``torch.cuda.empty_cache()`` and log allocator usage again.

    ``loaded_alpha`` and the configuration entries of ``self.cache_dic``
    (mode flags, ``num_steps``, ...) are left untouched.
    '''
    num_steps = self.cache_dic["num_steps"]
    calculate_rate(self.cache_dic)

    def log_cuda_memory(stage):
        # Log allocated / reserved CUDA memory in MB, labelled with `stage`.
        allocated = torch.cuda.memory_allocated()
        reserved = torch.cuda.memory_reserved()
        logging.info(
            f"[CacheRelease] {stage} cache release & empty_cache - "
            f"Allocated: {allocated / (1024 ** 2):.2f} MB, "
            f"Reserved: {reserved / (1024 ** 2):.2f} MB"
        )

    def destroy_tensors_in_dict(d):
        # Recursively overwrite tensor values with None so this cache no
        # longer holds a reference to them. Only values are replaced, so
        # mutating `d` while iterating over it is safe.
        for k, v in d.items():
            if isinstance(v, torch.Tensor):
                if v.is_cuda:
                    logging.debug(f"[Destroy] Found CUDA Tensor: {k}, shape: {v.shape}, device: {v.device}")
                else:
                    logging.debug(f"[Destroy] Found CPU Tensor: {k}, shape: {v.shape}")
                d[k] = None
            elif isinstance(v, dict):
                destroy_tensors_in_dict(v)

    # Take the "before" sample while the cached tensors are still referenced,
    # so the before/after pair reflects what the release actually freed.
    log_cuda_memory("Before")

    destroy_tensors_in_dict(self.cache_dic['cache'][-1])

    # Rebuild the empty per-run structures.
    self.cache_dic["cache"][-1] = {
        stream: {j: {} for j in range(self.cache_dic['layer_num'][stream])}
        for stream in self.cache_dic['stream_list']
    }
    self.cache_dic['cache']["activated_layers"] = {
        stream: {j: self.cache_dic['layer_num'][stream] - 1 for j in range(num_steps)}
        for stream in self.cache_dic['stream_list']
    }
    self.cache_dic['cache']["dynamic_error"] = {
        stream: {key: {} for key in self.cache_dic['module_list'][stream]} for stream in self.cache_dic['stream_list']
    }

    # Reset the error tracking and computation counters.
    self.cache_dic['accumulative_error'] = 0.0
    self.cache_dic['error_threshold'] = 0.0
    self.cache_dic['history_error'] = []
    self.cache_dic['sum_amount'] = {stream: LAYER_NUM[self.task][stream] * num_steps for stream in STREAM_LIST[self.task]}
    self.cache_dic['cal_amount'] = {stream: 0 for stream in STREAM_LIST[self.task]}
    self.current = {
        'activated_steps': [0],
        'step': 0,
        'type': 'full',
        'num_steps': num_steps
    }

    torch.cuda.empty_cache()
    # synchronize() needs a usable CUDA device, hence the availability guard.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    log_cuda_memory("After")

def calculate_rate(cache_dic):
    '''
    Record how much of the finished run was actually computed.

    Does nothing when ``cache_dic['update_alpha']`` is truthy. Otherwise it
    appends the mean computed/total ratio over all streams to
    CAL_AMOUNT_LIST, logs the reciprocal of the mean of every ratio recorded
    so far, and appends ``cache_dic['error_threshold']`` to
    ERROR_THRESHOLD_LIST.

    Args:
        cache_dic: cache state dict providing ``update_alpha``,
            ``stream_list``, ``cal_amount``, ``sum_amount`` and
            ``error_threshold``.
    '''
    if cache_dic['update_alpha']:
        return

    # Fraction of the work that was computed, one entry per stream.
    ratios = []
    for name in cache_dic['stream_list']:
        done = cache_dic['cal_amount'][name]
        budget = cache_dic['sum_amount'][name]
        ratios.append(done / budget)

    CAL_AMOUNT_LIST.append(np.mean(ratios))

    # The logged rate is averaged over every run recorded so far.
    speedup = 1 / np.mean(CAL_AMOUNT_LIST)
    logging.info(f"[Accelerate Rate] : {speedup}")

    ERROR_THRESHOLD_LIST.append(cache_dic['error_threshold'])