import torch
import transformers
import warnings
import os
import datetime
import csv

# Quiet third-party output before the remaining imports run.
# 40 is the numeric ERROR level, so transformers only reports errors and above.
transformers.utils.logging.set_verbosity(40)
# Installs a blanket "ignore" filter: every Python warning raised after this
# point is suppressed process-wide, not just those coming from transformers.
warnings.filterwarnings("ignore")
from transformers import AutoModelForCausalLM, AutoTokenizer
from abc import ABC, abstractmethod
from accelerate import Accelerator
from .kvcache import KVCacheModel
from .util import seed_everything, norm_logits, sample, max_fn
import time
import numpy as np


class Decoding(ABC):
    def __init__(self, args):
        """Store the run configuration and reset the decoding counters.

        Args:
            args: Configuration object. This constructor reads ``args.seed``
                and ``args.eval_mode``; other methods read further fields.

        Raises:
            AssertionError: If more than one accelerator process is running
                or ``args.eval_mode`` is not ``"sd"``.
        """
        self.args = args
        self.accelerator = Accelerator()

        # Seed every RNG up front, then remember which seed was used.
        seed_everything(args.seed)
        self.seed = args.seed
        self.seed_set = set()

        # Only single-process speculative decoding ("sd") is supported.
        single_process = self.accelerator.num_processes == 1
        assert single_process and args.eval_mode == "sd"

        # Running totals updated by speculative_decoding().
        self.draft_forward_times = self.target_forward_times = 0
        self.num_acc_tokens = []

    def load_model(self):
        """Load the draft and target causal LMs in bfloat16 and switch them to eval mode.

        The draft model is loaded with ``device_map="cuda:0"`` and the target
        model with ``device_map="balanced_low_0"``. ``self.vocab_size`` is
        copied from ``self.args.vocab_size`` rather than read from the models.
        """
        draft_name = self.args.draft_model
        target_name = self.args.target_model
        self.color_print(f"Loading models:\n{draft_name}\n{target_name}", 3)

        # Options common to both checkpoints.
        shared_kwargs = {"torch_dtype": torch.bfloat16, "trust_remote_code": True}

        draft = AutoModelForCausalLM.from_pretrained(draft_name, device_map="cuda:0", **shared_kwargs)
        self.draft_model = draft.eval()

        target = AutoModelForCausalLM.from_pretrained(target_name, device_map="balanced_low_0", **shared_kwargs)
        self.target_model = target.eval()

        self.vocab_size = self.args.vocab_size

    def load_tokenizer(self):
        """Load the draft model's tokenizer and configure right-side padding.

        The same tokenizer is used for both models; it is loaded from
        ``self.args.draft_model`` and stored on ``self.tokenizer``.
        """
        checkpoint = self.args.draft_model
        self.color_print(f"Loading tokenizer of {checkpoint}...", 3)

        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
        tokenizer.padding_side = "right"
        # Pad id is hard-coded to 2.
        # NOTE(review): presumably the EOS id of the checkpoints in use -- confirm.
        tokenizer.pad_token_id = 2

    @abstractmethod
    def load_data(self):
        """Abstract hook with no behaviour of its own; subclasses must override it.

        NOTE(review): presumably loads the evaluation dataset -- confirm
        against the concrete subclasses.
        """

    @abstractmethod
    def preprocess(self, input_text):
        """Abstract hook with no behaviour of its own; subclasses must override it.

        Args:
            input_text: Value handed to the subclass implementation; the
                expected type and return value are not defined here.
        """

    @abstractmethod
    def postprocess(self, input_text, output_text):
        """Abstract hook with no behaviour of its own; subclasses must override it.

        Args:
            input_text: First value handed to the subclass implementation.
            output_text: Second value handed to the subclass implementation.

        The expected types and return value are not defined here.
        """

    @torch.no_grad()
    def speculative_decoding(self, prefix, num_candidates=3):
        """Extend ``prefix`` using multi-candidate speculative decoding.

        Every iteration of the main loop:

        1. repeats the current prefix ``num_candidates`` times and asks the
           draft cache to generate ``self.args.gamma`` tokens per row;
        2. runs the target cache once over the drafted rows;
        3. walks the candidates in order and, per draft position, compares a
           uniform random draw with an acceptance score ``h`` computed from
           the running products of draft and target token probabilities; the
           most recent accepted candidate prefix wins;
        4. appends the accepted tokens plus one more token (drawn from a
           residual distribution after a partial accept, or from the target
           model's last distribution after a full accept) and calls
           ``rollback`` on both caches.

        Side effects: creates the ``speculative_logs`` directory, appends one
        row of timing/acceptance statistics to its CSV file (main process only,
        and only if at least one iteration ran), prints a summary, and updates
        ``self.draft_forward_times``, ``self.target_forward_times`` and
        ``self.num_acc_tokens``.

        Args:
            prefix: Token-id tensor whose second dimension is the sequence
                length. NOTE(review): the code repeats it along dim 0 per
                candidate, so it presumably has shape ``(1, seq_len)`` --
                confirm with callers.
            num_candidates: Number of draft sequences generated and examined
                per iteration.

        Returns:
            ``prefix`` with the generated tokens appended. The length check
            happens only at the top of each iteration, so the result may be
            longer than ``initial length + self.args.max_tokens``.

        Raises:
            AssertionError: If the accepted position ends up before the end
                of the prefix (internal consistency check).
        """
        initial_token_count = prefix.shape[1]
        gamma = self.args.gamma

        # Statistics from every call are appended to the same CSV file.
        log_dir = "speculative_logs"
        os.makedirs(log_dir, exist_ok=True)
        summary_file = f"{log_dir}/speculative_decoding_stats.csv"
        time_stats = {
            'start_time': time.time(),
            'draft_model': 0.0,
            'target_model': 0.0,
            'sampling_loop': 0.0,
            'residual_sampling': 0.0,
            'cache_operations': 0.0,
            'token_processing': 0.0,
            'total_iterations': 0,
            'total_draft_tokens': 0,
            'total_accepted_tokens': 0,
        }

        draft_device = self.draft_model.device
        target_device = self.target_model.device

        # One KV-cache wrapper per model, both restricted to the configured vocab size.
        approx_model_cache = KVCacheModel(self.draft_model, self.args.temp, self.args.top_k, self.args.top_p)
        approx_model_cache.vocab_size = self.vocab_size
        target_model_cache = KVCacheModel(self.target_model, self.args.temp, self.args.top_k, self.args.top_p)
        target_model_cache.vocab_size = self.vocab_size

        # Constants kept as tensors on the draft device so the arithmetic
        # below never mixes devices.
        EPS = 1e-20
        EPS_TENSOR = torch.tensor(EPS, device=draft_device)
        ONE_TENSOR = torch.tensor(1.0, device=draft_device)
        k_L = torch.tensor(num_candidates, device=draft_device)

        def exp_term(value):
            """Return ``(1 - value) ** num_candidates`` element-wise."""
            return torch.pow(ONE_TENSOR - value, k_L)

        # Main speculative decoding loop
        while prefix.shape[1] < initial_token_count + self.args.max_tokens:
            time_stats['total_iterations'] += 1

            # Prepare input for the draft model: one copy of the prefix per candidate.
            prefix = prefix.long().to(draft_device)
            input_ids = prefix.repeat(num_candidates, 1)
            prefix_len = prefix.shape[1]

            time_stats['total_draft_tokens'] += gamma

            # Draft model generation (synchronize so the wall-clock timing is meaningful).
            draft_start = time.time()
            x = approx_model_cache.generate(input_ids, gamma)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            time_stats['draft_model'] += time.time() - draft_start

            # Target model verification; only the probability history it
            # records is used, the returned tokens are discarded.
            target_start = time.time()
            _ = target_model_cache.generate(x.to(target_device), 1)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            time_stats['target_model'] += time.time() - target_start

            if self.accelerator.is_main_process:
                self.draft_forward_times += gamma
                self.target_forward_times += 1

            # n is the index of the last accepted position; prefix_len - 1
            # means "nothing accepted yet".
            n = prefix_len - 1
            O = []  # Accepted token sequence
            final = 0  # Row index of the candidate that supplied O
            H = set()  # Draft prefixes that were already rejected this iteration
            Ms_final = torch.ones(1, device=draft_device)
            Mb_final = torch.ones(1, device=draft_device)

            sampling_start = time.time()

            # Neither history changes while candidates are examined, so read
            # them once (the target copy may be a cross-device transfer).
            draft_probs = approx_model_cache._prob_history
            target_probs = target_model_cache._prob_history.to(draft_device)

            for j in range(num_candidates):
                # Row i of these slices is the distribution that produced
                # draft token i of candidate j.
                probs_s_j = draft_probs[j, prefix_len - 1: prefix_len + gamma]
                probs_b_j = target_probs[j, prefix_len - 1: prefix_len + gamma]

                # Running products of draft (M_s) and target (M_b) token probabilities.
                M_s = torch.ones(1, device=draft_device)
                M_b = torch.ones(1, device=draft_device)

                # Positions already accepted from an earlier candidate are
                # folded into the products without a new acceptance test.
                for k in range(n - prefix_len + 1):
                    current_token = x[j, prefix_len + k]
                    q = draft_probs[j, prefix_len + k - 1, current_token]
                    p = target_probs[j, prefix_len + k - 1, current_token]
                    M_s = M_s * q
                    M_b = M_b * p

                # The range is fixed when the loop starts, so updates to n
                # inside the loop do not change which positions are visited.
                for i in range(n - prefix_len + 1, gamma - 1):
                    current_token = x[j, prefix_len + i]
                    prob_b = probs_b_j[i]
                    prob_s = probs_s_j[i]
                    next_prob_b = probs_b_j[i + 1]
                    next_prob_s = probs_s_j[i + 1]

                    p = prob_b[current_token]
                    q = prob_s[current_token]
                    M_s = M_s * q
                    M_b = M_b * p

                    # Calculate acceptance ratio
                    ratio = (M_s * next_prob_s) / (M_b * next_prob_b + EPS_TENSOR)
                    min_ratio = torch.min(ratio, ONE_TENSOR)
                    raw_ratio = torch.min(M_s / (M_b + EPS_TENSOR), ONE_TENSOR)

                    # Numerator calculation
                    term1 = torch.sum(M_b * next_prob_b * exp_term(min_ratio))
                    term2 = M_b * exp_term(raw_ratio)
                    up = term1 - term2

                    # Denominator calculation, guarded against division by ~0
                    term3 = ONE_TENSOR - exp_term(M_s)
                    term5 = torch.sum(M_b * next_prob_b * (ONE_TENSOR - exp_term(min_ratio)))
                    down = term3 - term5
                    safe_down = torch.where(torch.abs(down) < EPS_TENSOR, EPS_TENSOR, down)

                    h = up / safe_down

                    # The uniform draw happens before the duplicate check, so
                    # one random number is consumed per visited position.
                    r = torch.rand(1, device=draft_device)
                    current_seq = x[j, prefix_len:prefix_len + i + 1]
                    current_seq_tuple = tuple(current_seq.tolist())

                    if current_seq_tuple not in H:
                        if r < h:
                            O = current_seq
                            n = i + prefix_len
                            final = j
                            Ms_final = M_s
                            Mb_final = M_b
                        else:
                            H.add(current_seq_tuple)

                # Last draft position (index gamma - 1) has its own acceptance rule.
                current_seq = x[j, prefix_len:prefix_len + gamma]
                r = torch.rand(1, device=draft_device)
                t = x[j, prefix_len + gamma - 1].item()
                # Index this candidate's own distributions directly. Reusing
                # the inner loop's variables would fail when that loop ran
                # zero times (gamma == 1) and would pick up another
                # candidate's row when an earlier accept left it empty.
                p_j = probs_b_j[gamma - 1][t].item()
                q_j = probs_s_j[gamma - 1][t].item()
                M_s = M_s * q_j
                M_b = M_b * p_j

                hh = torch.min(M_s / (M_b + EPS_TENSOR), ONE_TENSOR)
                t1 = M_b * (1 - exp_term(hh))
                t2 = 1 - exp_term(M_s)
                h = t1 / (t2 + EPS_TENSOR)

                if r < h:
                    # Whole draft accepted: no need to look at further candidates.
                    n = prefix_len + gamma - 1
                    O = current_seq
                    final = j
                    Ms_final = M_s
                    Mb_final = M_b
                    break

            time_stats['sampling_loop'] += time.time() - sampling_start

            # Track accepted tokens in this iteration
            accepted_tokens = n - prefix_len + 1
            time_stats['total_accepted_tokens'] += accepted_tokens

            self.num_acc_tokens.append(accepted_tokens)
            assert n >= prefix_len - 1, f"n {n}, prefix_len {prefix_len}"

            # Cache operations
            cache_start = time.time()
            # An explicit integer dtype keeps prefix integral even when O is
            # the empty list (which would otherwise become a float tensor).
            O_tensor = torch.as_tensor(O, dtype=torch.long, device=draft_device)
            O_tensor = O_tensor.unsqueeze(0).expand(prefix.size(0), -1)
            prefix = torch.cat((prefix, O_tensor), dim=1)

            if n > prefix_len - 1:
                target_model_cache.set_cache_to_jth_sample(final)
                approx_model_cache.set_cache_to_jth_sample(final)

            time_stats['cache_operations'] += time.time() - cache_start

            # Residual sampling for remaining tokens
            residual_start = time.time()
            if n < prefix_len + gamma - 1:
                # Partial (or no) acceptance. Re-read the histories here:
                # the caches may just have been switched to row `final`.
                prob_b = target_model_cache._prob_history.to(draft_device)[final, n]
                prob_s = approx_model_cache._prob_history[final, n]

                # `**` binds tighter than `-`, so this is 1 - (min(...) ** k).
                numerator_vector = Mb_final * prob_b * (
                    ONE_TENSOR - torch.min(Ms_final * prob_s / (Mb_final * prob_b + EPS_TENSOR),
                                           ONE_TENSOR) ** k_L)
                denominator_val = torch.sum(numerator_vector)

                P_res = numerator_vector / (denominator_val + EPS_TENSOR)
                if torch.all(P_res <= 0) or torch.isnan(P_res).any() or torch.isinf(P_res).any():
                    # Fallback to the (renormalised) target model distribution
                    P_res = prob_b
                    P_res = P_res / P_res.sum()

                # NOTE(review): the result of this first draw is never used.
                # It is kept only so seeded runs consume the same random
                # stream as before; drop it once old results need not match.
                sample(P_res)
                t = sample(P_res).to(draft_device).unsqueeze(1)
                prefix = torch.cat((prefix, t), dim=1).long()
                target_model_cache.rollback(n + 1)
                approx_model_cache.rollback(n + 1)
            else:
                # All draft tokens accepted: take the bonus token from the
                # target model's last recorded distribution.
                t = sample(target_model_cache._prob_history[final, -1, :self.vocab_size]).to(draft_device).unsqueeze(1)
                prefix = torch.cat((prefix, t), dim=1).long()
                target_model_cache.rollback(n + 2)
                approx_model_cache.rollback(n + 1)
            time_stats['residual_sampling'] += time.time() - residual_start

            # Final token processing: no work happens here at present; the
            # timer only keeps the CSV column populated.
            token_start = time.time()
            time_stats['token_processing'] += time.time() - token_start

        # Calculate total runtime
        total_runtime = time.time() - time_stats['start_time']
        generated_tokens = prefix.shape[1] - initial_token_count

        # Overall acceptance rate = accepted draft tokens / drafted tokens.
        total_draft_tokens = time_stats['total_draft_tokens']
        if total_draft_tokens > 0:
            overall_acceptance_rate = time_stats['total_accepted_tokens'] / total_draft_tokens
        else:
            overall_acceptance_rate = 0

        print(f"innertotal_runtime{total_runtime}")
        # Save summary data to CSV
        if time_stats['total_iterations'] > 0 and self.accelerator.is_main_process:
            timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            data_row = [
                timestamp,
                total_runtime,
                generated_tokens,
                time_stats['total_iterations'],
                time_stats['draft_model'],
                time_stats['target_model'],
                time_stats['sampling_loop'],
                time_stats['residual_sampling'],
                time_stats['cache_operations'],
                time_stats['token_processing'],
                generated_tokens / total_runtime if total_runtime > 0 else 0,
                self.args.gamma,
                self.args.max_tokens,
                self.args.temp,
                self.args.top_k,
                self.args.top_p,
                overall_acceptance_rate,
                time_stats['total_draft_tokens'],
                time_stats['total_accepted_tokens'],
            ]

            # The header is written only when the file is being created.
            file_exists = os.path.exists(summary_file)
            with open(summary_file, 'a', newline='') as f:
                writer = csv.writer(f)
                if not file_exists:
                    headers = [
                        'Timestamp', 'Total Runtime (s)', 'Generated Tokens', 'Iterations',
                        'Draft Model Time (s)', 'Target Model Time (s)', 'Sampling Time (s)',
                        'Residual Sampling Time (s)', 'Cache Operations Time (s)',
                        'Token Processing Time (s)', 'Tokens per Second',
                        'Gamma', 'Max Tokens', 'Temperature', 'Top K', 'Top P',
                        'Overall Acceptance Rate',
                        'Total Draft Tokens', 'Total Accepted Tokens',
                    ]
                    writer.writerow(headers)

                writer.writerow(data_row)

        # Print summary information
        if self.accelerator.is_main_process:
            print(f"\n{'=' * 60}")
            print(f"【Speculative Decoding Statistics Saved】")
            print(f"File path: {summary_file}")
            print(f"Total runtime: {total_runtime:.3f}s")
            print(f"Generated tokens: {generated_tokens}")
            print(f"Total draft tokens: {time_stats['total_draft_tokens']}")
            print(f"Total accepted tokens: {time_stats['total_accepted_tokens']}")
            print(f"Overall acceptance rate: {overall_acceptance_rate:.2%}")
            print(f"{'=' * 60}\n")

        return prefix
   
    @abstractmethod
    def eval(self):
        """Abstract hook with no behaviour of its own; subclasses must override it.

        NOTE(review): presumably runs the evaluation loop -- confirm against
        the concrete subclasses.
        """

    def color_print(self, content: str, color_number: int = 4):
        """Print ``content`` wrapped in an ANSI bright-colour escape sequence.

        Nothing is printed unless ``self.accelerator.is_main_process`` is true.

        Args:
            content: Text to print.
            color_number: Digit placed after ``9`` in the escape code.
                Gray: 0, Red: 1, Green: 2, Yellow: 3, Blue: 4.
        """
        if not self.accelerator.is_main_process:
            return
        print(f"\033[9{color_number}m{content}\033[0m")