import torch
import torch.nn as nn
import torch.utils.data
import numpy as np
from sacred import Experiment
import os
import sys
from collections import Counter
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(root_dir)
from config import ex
from tqdm import tqdm

class Classifier(nn.Module):
    """MLP mapping a flat ``num_drf * k`` input to ``k`` sigmoid outputs.

    The layer stack is selected by ``_config['classifier_arch']``:

    * ``'mini'``  -- one hidden layer with SiLU; the hidden width is
      ``input_dim // 2`` when ``_config['classifier_vocab'] == 'top-k'``,
      otherwise 100.
    * ``'small'`` -- two hidden layers of width ``input_dim * 4`` with GELU.
    * anything else -- one hidden layer of width ``input_dim // 2`` with ReLU.

    Args:
        num_drf: Number of draftings; multiplied by ``k`` for the input width.
        k: Output width (and per-drafting input width).
        _config: Mapping read for ``classifier_arch``, ``classifier_vocab``
            (``'mini'`` only), ``load_classifier`` and, when that is truthy,
            ``classifier_ckpt_path``.

    Raises:
        KeyError: If a required configuration key is missing.
    """

    def __init__(self, num_drf, k, _config):
        super().__init__()
        input_dim = num_drf * k
        output_dim = k
        arch = _config['classifier_arch']

        if arch == 'mini':
            hidden_dim = input_dim // 2 if _config['classifier_vocab'] == 'top-k' else 100
            self.classifier = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, output_dim),
                nn.Sigmoid(),
            )
        elif arch == 'small':
            hidden_dim_high = input_dim * 4
            self.classifier = nn.Sequential(
                nn.Linear(input_dim, hidden_dim_high),
                nn.GELU(),
                nn.Linear(hidden_dim_high, hidden_dim_high),
                nn.GELU(),
                nn.Linear(hidden_dim_high, output_dim),
                nn.Sigmoid(),
            )
        else:
            # Default stack. This used to be built unconditionally, which
            # silently replaced the 'mini'/'small' networks selected above.
            # NOTE(review): checkpoints saved while that override was in place
            # always hold this default layout, whatever arch was configured.
            hidden_dim = input_dim // 2
            self.classifier = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim),
                nn.Sigmoid(),
            )

        if _config['load_classifier']:
            # Parameters still live on the CPU here; mapping the checkpoint to
            # the CPU lets a file saved from a GPU load on a CPU-only host.
            state_dict = torch.load(_config['classifier_ckpt_path'], map_location="cpu")
            self.load_state_dict(state_dict)

    def forward(self, x):
        """Return the sigmoid outputs of the selected layer stack for ``x``."""
        return self.classifier(x)
        
def init_classifier(_config):
    """Build a ``Classifier`` sized from the experiment configuration.

    The number of draftings is ``len(_config['drafting'])`` and the candidate
    count is ``_config['classifier_top_k']``; the configuration itself is
    forwarded unchanged to the model constructor.
    """
    return Classifier(
        num_drf=len(_config['drafting']),
        k=_config['classifier_top_k'],
        _config=_config,
    )

class ClassifierDataset(torch.utils.data.Dataset):
    """In-memory dataset of ``(features, one-hot labels)`` pairs.

    One sample is produced per accepted position and per first-rejected
    position. Samples whose ground-truth token is not among the selected
    candidates are dropped.

    Args:
        sequences: Indexed as ``sequences[seq_idx][pos]`` to obtain the
            ground-truth token of a position.
        ids_accepted_tokens / ids_first_rejected: Per sequence, an iterable of
            positions into that sequence.
        tokens_*_topk / value_probability_*_topk: Indexed as
            ``x[seq_idx][drafting_key][i]`` where ``i`` is the running index of
            the position inside its per-sequence list; each entry is a
            sequence of candidate tokens (or their probabilities).
        num_drf: Kept for interface compatibility; the number of draftings
            actually used is the number of keys of ``tokens_topk[0]``.
        k: Number of candidate tokens per sample.
    """

    def __init__(self, sequences, ids_accepted_tokens, tokens_accepted_tokens_topk, value_probability_accepted_topk,
                 ids_first_rejected, tokens_rejected_tokens_topk, value_probability_rejected_topk, num_drf, k):

        self.samples = []

        # Accepted and first-rejected positions go through the same pipeline;
        # only the top-k dumps they index into differ.
        self._collect(sequences, ids_accepted_tokens, tokens_accepted_tokens_topk,
                      value_probability_accepted_topk, num_drf, k)
        self._collect(sequences, ids_first_rejected, tokens_rejected_tokens_topk,
                      value_probability_rejected_topk, num_drf, k)

    def _collect(self, sequences, ids_per_sequence, tokens_topk, probs_topk, num_drf, k):
        """Append one sample for every listed position that yields a label."""
        for seq_idx, token_positions in enumerate(ids_per_sequence):
            for tokens_idx, pos in enumerate(token_positions):
                sample = self.process_sample(seq_idx, pos, sequences, tokens_topk, probs_topk,
                                             num_drf, k, sequences[seq_idx][pos], tokens_idx)
                if sample is not None:
                    self.samples.append(sample)

    def process_sample(self, seq_idx, pos, sequences, tokens_topk, probs_topk, num_drf, k, ground_truth_token, tokens_idx):
        """Build the feature and label vectors for a single position.

        Candidate tokens are the ``k`` tokens with the largest probability
        summed over all draftings. The features hold, for each candidate and
        each drafting, the probability that drafting assigned to the candidate
        (0.0 when the drafting did not propose it), flattened row-major from
        shape ``(k, num_draftings)``. The label vector marks the candidate
        equal to ``ground_truth_token``.

        Both vectors are zero-padded to the fixed sizes ``k * num_draftings``
        and ``k`` when fewer than ``k`` distinct candidates exist, so every
        sample can be batched and matches the classifier's input width.

        ``pos``, ``sequences`` and ``num_drf`` are accepted for interface
        compatibility but not read here.

        Returns:
            ``(features, labels)`` as NumPy arrays, or ``None`` when the
            ground-truth token is not among the candidates.
        """
        # Column order is fixed by the drafting keys of the first sequence so
        # that every sample uses the same feature layout.
        drf_keys = list(tokens_topk[0].keys())

        per_drafting = []   # one {token: probability} lookup per drafting
        combined = {}       # token -> probability summed over draftings
        for drf in drf_keys:
            tokens_list = tokens_topk[seq_idx][drf][tokens_idx]
            probs_list = probs_topk[seq_idx][drf][tokens_idx]
            lookup = {}
            for idx, token in enumerate(tokens_list):
                prob = probs_list[idx]
                combined[token] = combined.get(token, 0.0) + prob
                # Keep the first occurrence, matching a list ``.index`` lookup
                # while also working for NumPy arrays (which lack ``.index``).
                lookup.setdefault(token, prob)
            per_drafting.append(lookup)

        # Rank by combined probability (stable for ties) and keep the top k.
        ranked = sorted(combined.items(), key=lambda item: item[1], reverse=True)
        ref_tokens = [token for token, _ in ranked[:k]]

        # One-hot label over the k candidate slots.
        labels = np.zeros(k, dtype=np.float32)
        for idx, token in enumerate(ref_tokens):
            if token == ground_truth_token:
                labels[idx] = 1.0

        if labels.sum() == 0:
            # Ground truth is not among the candidates: skip this sample.
            return None

        n_drf = len(per_drafting)
        rows = [[lookup.get(token, 0.0) for lookup in per_drafting] for token in ref_tokens]
        # Pad missing candidate slots with all-zero rows.
        rows.extend([0.0] * n_drf for _ in range(k - len(ref_tokens)))

        # (k, n_drf) -> (k * n_drf,)
        features = np.array(rows).flatten()

        return features, labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as two float32 tensors."""
        features, labels = self.samples[idx]
        return (torch.tensor(features, dtype=torch.float32),
                torch.tensor(labels, dtype=torch.float32))

def get_label_distribution(dataset):
    """Count how often each label position is the arg-max across ``dataset.samples``.

    Every entry of ``dataset.samples`` is a ``(features, labels)`` pair; the
    position of the largest label value is taken as the sample's class index.

    Returns:
        A ``Counter`` mapping class index to number of samples.
    """
    return Counter(
        label_vector.argmax().item()
        for _features, label_vector in dataset.samples
    )

def train_classifier(_config):
    """Train the classifier on dumped top-k statistics and save its weights.

    Loads seven ``.npy`` dumps from a fixed experiment directory, builds a
    ``ClassifierDataset`` from them, trains the model returned by
    ``init_classifier`` with Adam, and writes the resulting ``state_dict`` to
    a fixed checkpoint directory as ``{lr}-{epochs}-{k}.pth``.

    Args:
        _config: Mapping read for ``classifier_top_k`` and ``drafting``
            (required) and ``classifier_batch_size`` (default 32),
            ``classifier_lr`` (default 1e-4) and ``classifier_epochs``
            (default 1); it is also forwarded to ``init_classifier``.

    Raises:
        ValueError: If no training sample could be built from the dumps.
    """
    root = "/XXXX-5/home-XXXX-3/data/MSD/npy"
    exp_title = "fp16-mm-weight-cascade-MTC-A6000-CAPTION-240924"
    ckpt_name = "sd_llava-68m_llava-llama-7b_caption-florence2-0.77b-C-multimodal-text-only-mm-weight-1-cascade-drafting_DC100_EN_mtl-128_gamma-5_t0_fp16-16_2024"
    k = _config['classifier_top_k']

    npy_names = [
        "sequences",
        "ids_accepted_tokens",
        "tokens_accepted_tokens_topk",
        "value_probability_accepted_topk",
        "ids_first_rejected",
        "tokens_rejected_tokens_topk",
        "value_probability_rejected_topk",
    ]

    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
    # dumps produced by trusted runs.
    map_npy = {
        filename: np.load(f"{root}/{exp_title}/{ckpt_name}/{filename}.npy", allow_pickle=True) for filename in npy_names
    }

    num_drf = len(_config['drafting'])

    # Prefer the first GPU but fall back to the CPU instead of crashing on
    # hosts without CUDA.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    classifier = init_classifier(_config).to(device)

    dataset = ClassifierDataset(
        sequences=map_npy["sequences"],
        ids_accepted_tokens=map_npy["ids_accepted_tokens"],
        tokens_accepted_tokens_topk=map_npy["tokens_accepted_tokens_topk"],
        value_probability_accepted_topk=map_npy["value_probability_accepted_topk"],
        ids_first_rejected=map_npy["ids_first_rejected"],
        tokens_rejected_tokens_topk=map_npy["tokens_rejected_tokens_topk"],
        value_probability_rejected_topk=map_npy["value_probability_rejected_topk"],
        num_drf=num_drf,
        k=k
    )

    if len(dataset) == 0:
        # Fail with a clear message rather than an opaque sampler error.
        raise ValueError("ClassifierDataset is empty: no position has its ground-truth token among the top-k candidates")

    label_distribution = get_label_distribution(dataset)
    print(f"label_distribution: {label_distribution}")

    batch_size = _config.get('classifier_batch_size', 32)
    lr = _config.get('classifier_lr', 1e-4)
    epochs = _config.get('classifier_epochs', 1)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # NOTE(review): CrossEntropyLoss applies log-softmax itself and expects raw
    # scores, while the classifier already ends in a Sigmoid. Kept as is so
    # the training objective does not change silently; confirm whether
    # BCELoss on the one-hot labels (or dropping the Sigmoid) was intended.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)

    for epoch in tqdm(range(epochs), desc="Epochs"):
        total_loss = 0

        for batch_features, batch_labels in dataloader:
            batch_features = batch_features.to(device)
            batch_labels = batch_labels.to(device)
            optimizer.zero_grad()
            outputs = classifier(batch_features)
            # Convert labels from one-hot vectors to class indices
            target = batch_labels.argmax(dim=1)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)

        # Only every 1000th epoch (including the first) is reported.
        if epoch % 1000 == 0:
            print(f"lr: {lr}, epochs: {epoch}, k: {k}, loss: {avg_loss:.4f}")

    # Save the classifier
    ckpt_dir = f"/XXXX-5/home-XXXX-3/data/MSD/checkpoint/classifier/{ckpt_name}"
    classifier_name = f"{lr}-{epochs}-{k}.pth"
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(classifier.state_dict(), f"{ckpt_dir}/{classifier_name}")
    print(f"Save dir: {ckpt_dir}")
    print(f"Classifier saved to {ckpt_dir}/{classifier_name}")



@ex.automain
def main(_config):
    """Script entry point: run classifier training with the given ``_config``."""
    # ``ex`` is imported from the project's ``config`` module -- presumably a
    # Sacred Experiment whose ``automain`` invokes this function when the file
    # is executed as a script and supplies ``_config``; confirm in config.py.
    train_classifier(_config)