import copy
import os

import numpy as np
import torch
from sklearn.metrics import f1_score, accuracy_score
from torch import nn
from torch.optim import Adam
from tqdm import tqdm

from GCL.eval import BaseEvaluator


class LogisticRegression(nn.Module):
    """Linear classification head that maps feature vectors to per-class logits.

    No activation is applied here; the evaluators in this module apply
    ``LogSoftmax`` + ``NLLLoss`` to the returned scores.

    Args:
        num_features: size of the last dimension of the input.
        num_classes: number of output scores per input row.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        linear = nn.Linear(num_features, num_classes)
        # Replace nn.Linear's default weight init with Xavier/Glorot uniform;
        # the bias keeps its default initialization.
        nn.init.xavier_uniform_(linear.weight.data)
        self.fc = linear

    def forward(self, x):
        """Return the raw (unnormalized) class scores for ``x``."""
        return self.fc(x)


class LREvaluator(BaseEvaluator):
    """Linear-probe evaluator: fits a logistic-regression head on frozen features.

    The head is trained full-batch with Adam on a random subset of
    ``split['train']`` and evaluated every ``test_interval`` epochs. The test
    metrics reported are those of the evaluation step with the highest
    validation accuracy.

    Args:
        num_epochs: number of full-batch optimization steps.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        test_interval: run validation/test evaluation every this many epochs.
        train_percent: fraction of ``split['train']`` sampled (without
            replacement) for training; at least one sample is kept when the
            training split is non-empty.
    """

    def __init__(self, num_epochs: int = 5000, learning_rate: float = 0.01,
                 weight_decay: float = 0.0, test_interval: int = 20, train_percent: float = 1):
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.test_interval = test_interval
        self.train_percent = train_percent

    def _subsample_train_indices(self, train_idx, device):
        """Randomly keep ``train_percent`` of ``train_idx``; return them as a long tensor on ``device``."""
        indices = train_idx.cpu().numpy()
        num_samples = int(len(indices) * self.train_percent)
        if len(indices) > 0:
            # An empty training batch would make NLLLoss return NaN.
            num_samples = max(1, num_samples)
        indices = np.random.choice(indices, num_samples, replace=False)
        return torch.as_tensor(indices, dtype=torch.long, device=device)

    @staticmethod
    def _predict(classifier, x, index):
        """Return predicted class ids (numpy) for the rows of ``x`` selected by ``index``."""
        with torch.no_grad():  # inference only: don't build an autograd graph
            return classifier(x[index]).argmax(-1).cpu().numpy()

    @staticmethod
    def _per_class_accuracy(y_true, y_pred):
        """Map each label present in ``y_true`` to the fraction of its samples predicted correctly."""
        accuracies = {}
        for label in np.unique(y_true):
            mask = y_true == label
            accuracies[label] = (y_pred[mask] == y_true[mask]).sum() / mask.sum()
        return accuracies

    def evaluate(self, x: torch.FloatTensor, y: torch.LongTensor, split: dict):
        """Train the probe on ``x`` and return the best-validation test metrics.

        Args:
            x: feature matrix; row ``i`` is the representation of sample ``i``.
                It is detached, so no gradient flows back to whatever produced it.
            y: integer class labels aligned with the rows of ``x``; the number
                of classes is taken as ``y.max() + 1``.
            split: dict with ``'train'``, ``'valid'`` and ``'test'`` entries.
                These are assumed to be index tensors (not boolean masks), since
                the train entry is subsampled as a list of indices.

        Returns:
            dict with ``'acc'``, ``'micro_f1'`` and ``'macro_f1'`` (test metrics) and
            ``'class_accuracies'`` (test accuracy per label present in the test
            split), all from the evaluation step with the highest validation
            accuracy. If no step ever exceeded 0 validation accuracy (or no
            evaluation ran), the scalars are 0 and ``class_accuracies`` is empty.
        """
        device = x.device
        x = x.detach()
        y = y.to(device)
        num_classes = y.int().max().item() + 1
        classifier = LogisticRegression(x.size(1), num_classes).to(device)
        optimizer = Adam(classifier.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        output_fn = nn.LogSoftmax(dim=-1)
        criterion = nn.NLLLoss()

        train_indices = self._subsample_train_indices(split['train'], device)
        # Labels never change during training; extract them once.
        y_train = y[train_indices].long()
        y_test = y[split['test']].detach().cpu().numpy()
        y_val = y[split['valid']].detach().cpu().numpy()

        # Initialize every reported metric so the return never hits an unbound name.
        best_val_acc = 0
        best_acc = 0
        best_test_micro = 0
        best_test_macro = 0
        best_class_accuracies = {}

        with tqdm(total=self.num_epochs, desc='(LR)',
                  bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}{postfix}]') as pbar:
            for epoch in range(self.num_epochs):
                classifier.train()
                optimizer.zero_grad()
                output = classifier(x[train_indices])
                loss = criterion(output_fn(output), y_train)
                loss.backward()
                optimizer.step()

                if (epoch + 1) % self.test_interval != 0:
                    continue

                classifier.eval()
                test_pred = self._predict(classifier, x, split['test'])
                val_pred = self._predict(classifier, x, split['valid'])
                val_acc = accuracy_score(y_val, val_pred)

                # Model selection uses validation accuracy only. For single-label
                # multiclass data micro-F1 equals accuracy, so a second micro-F1
                # criterion is redundant. It could also leave the reported F1 and
                # accuracy coming from different epochs.
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    best_acc = accuracy_score(y_test, test_pred)
                    best_test_micro = f1_score(y_test, test_pred, average='micro')
                    best_test_macro = f1_score(y_test, test_pred, average='macro')
                    best_class_accuracies = self._per_class_accuracy(y_test, test_pred)

                pbar.set_postfix({'best_acc': best_acc, 'F1Mi': best_test_micro, 'F1Ma': best_test_macro})
                pbar.update(self.test_interval)

        return {
            'acc': best_acc,
            'micro_f1': best_test_micro,
            'macro_f1': best_test_macro,
            'class_accuracies': best_class_accuracies
        }

class LREvaluator_white(BaseEvaluator):
    """Linear-probe evaluator that also saves the best-validation classifier to disk.

    A logistic-regression head is trained full-batch with Adam on a random
    subset of ``split['train']`` and evaluated every ``test_interval`` epochs.
    The test metrics and the saved weights both come from the evaluation step
    with the highest validation accuracy.

    Args:
        num_epochs: number of full-batch optimization steps.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        test_interval: run validation/test evaluation every this many epochs.
        train_percent: fraction of ``split['train']`` sampled (without
            replacement) for training; at least one sample is kept when the
            training split is non-empty.
    """

    def __init__(self, num_epochs: int = 5000, learning_rate: float = 0.01,
                 weight_decay: float = 0.0, test_interval: int = 20, train_percent: float = 1):
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.test_interval = test_interval
        self.train_percent = train_percent

    def _subsample_train_indices(self, train_idx, device):
        """Randomly keep ``train_percent`` of ``train_idx``; return them as a long tensor on ``device``."""
        indices = train_idx.cpu().numpy()
        num_samples = int(len(indices) * self.train_percent)
        if len(indices) > 0:
            # An empty training batch would make NLLLoss return NaN.
            num_samples = max(1, num_samples)
        indices = np.random.choice(indices, num_samples, replace=False)
        return torch.as_tensor(indices, dtype=torch.long, device=device)

    @staticmethod
    def _predict(classifier, x, index):
        """Return predicted class ids (numpy) for the rows of ``x`` selected by ``index``."""
        with torch.no_grad():  # inference only: don't build an autograd graph
            return classifier(x[index]).argmax(-1).cpu().numpy()

    @staticmethod
    def _per_class_accuracy(y_true, y_pred):
        """Map each label present in ``y_true`` to the fraction of its samples predicted correctly."""
        accuracies = {}
        for label in np.unique(y_true):
            mask = y_true == label
            accuracies[label] = (y_pred[mask] == y_true[mask]).sum() / mask.sum()
        return accuracies

    def evaluate(self, x: torch.FloatTensor, y: torch.LongTensor, split: dict, class_time, dataset_name):
        """Train the probe, save the best-validation weights, and return its test metrics.

        Args:
            x: feature matrix; row ``i`` is the representation of sample ``i``.
                It is detached, so no gradient flows back to whatever produced it.
            y: integer class labels aligned with the rows of ``x``; the number
                of classes is taken as ``y.max() + 1``.
            split: dict with ``'train'``, ``'valid'`` and ``'test'`` entries.
                These are assumed to be index tensors (not boolean masks).
            class_time: value interpolated into the checkpoint filename.
            dataset_name: value interpolated into the checkpoint filename.

        Returns:
            dict with ``'acc'``, ``'micro_f1'``, ``'macro_f1'`` and
            ``'class_accuracies'`` from the best-validation evaluation step
            (zeros / empty dict if no step ever exceeded 0 validation accuracy).

        Side effects:
            Writes the best classifier ``state_dict`` with ``torch.save`` to
            ``pkl_classifier/classifier_{dataset_name}_split{class_time}.pkl``,
            creating the directory if needed. ``None`` is saved if no step
            improved on 0 validation accuracy.
        """
        device = x.device
        x = x.detach()
        y = y.to(device)
        num_classes = y.int().max().item() + 1
        classifier = LogisticRegression(x.size(1), num_classes).to(device)
        optimizer = Adam(classifier.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        output_fn = nn.LogSoftmax(dim=-1)
        criterion = nn.NLLLoss()

        train_indices = self._subsample_train_indices(split['train'], device)
        # Labels never change during training; extract them once.
        y_train = y[train_indices].long()
        y_test = y[split['test']].detach().cpu().numpy()
        y_val = y[split['valid']].detach().cpu().numpy()

        # Initialize every reported metric so the return never hits an unbound name.
        best_val_acc = 0
        best_acc = 0
        best_test_micro = 0
        best_test_macro = 0
        best_class_accuracies = {}
        best_model_dict = None

        with tqdm(total=self.num_epochs, desc='(LR)',
                  bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}{postfix}]') as pbar:
            for epoch in range(self.num_epochs):
                classifier.train()
                optimizer.zero_grad()
                output = classifier(x[train_indices])
                loss = criterion(output_fn(output), y_train)
                loss.backward()
                optimizer.step()

                if (epoch + 1) % self.test_interval != 0:
                    continue

                classifier.eval()
                test_pred = self._predict(classifier, x, split['test'])
                val_pred = self._predict(classifier, x, split['valid'])
                val_acc = accuracy_score(y_val, val_pred)

                # Model selection uses validation accuracy only. For single-label
                # multiclass data micro-F1 equals accuracy, so a second micro-F1
                # criterion is redundant and could desynchronise the reported metrics.
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    best_acc = accuracy_score(y_test, test_pred)
                    best_test_micro = f1_score(y_test, test_pred, average='micro')
                    best_test_macro = f1_score(y_test, test_pred, average='macro')
                    best_class_accuracies = self._per_class_accuracy(y_test, test_pred)
                    # state_dict() tensors share storage with the live parameters,
                    # which optimizer.step() keeps updating in place. Deep-copy so the
                    # snapshot really holds the best-validation weights.
                    best_model_dict = copy.deepcopy(classifier.state_dict())

                pbar.set_postfix({'best_acc': best_acc, 'F1Mi': best_test_micro, 'F1Ma': best_test_macro})
                pbar.update(self.test_interval)

        save_path = f"pkl_classifier/classifier_{dataset_name}_split{class_time}.pkl"
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(best_model_dict, save_path)
        return {
            'acc': best_acc,
            'micro_f1': best_test_micro,
            'macro_f1': best_test_macro,
            'class_accuracies': best_class_accuracies
        }