import contextlib

import torch
from torch.utils.data import ConcatDataset
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Dataset wrapper around an in-memory, indexable collection of pairs."""

    def __init__(self, data_list):
        # data_list: list of (image, label)
        self.data = data_list

    def __len__(self):
        """Return the number of stored entries."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the entry at ``idx`` as an ``(image, label)`` tuple."""
        # Unpacking (rather than returning the entry as-is) makes an entry
        # that does not hold exactly two elements fail right here.
        sample, target = self.data[idx]
        return (sample, target)

def _amp_context(device):
    """Return a fresh bfloat16 CUDA autocast context, or a no-op context otherwise."""
    # str() accepts both a plain device string and a torch.device object.
    if torch.cuda.is_available() and str(device).startswith("cuda"):
        return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    return contextlib.nullcontext()


def _forward_logits(model, images, method):
    """Run ``model`` on ``images`` and return the output used for prediction."""
    outputs = model(images)
    # For these two methods the model output is indexed to pick the tensor
    # that predictions are taken from; any other method uses the output as-is.
    if method == 'FedCIL':
        return outputs[2]
    if method == 'AFFCL':
        return outputs[0]
    return outputs


def _accuracy(model, dataloader, args):
    """Return the fraction of samples in ``dataloader`` predicted correctly.

    Returns 0 when the loader yields no samples.
    """
    correct, total = 0, 0
    with torch.inference_mode(), _amp_context(args.device):
        for images, labels in dataloader:
            images = images.to(args.device)
            labels = labels.to(args.device)
            logits = _forward_logits(model, images, args.method)
            _, predicted = torch.max(logits, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total if total > 0 else 0


def evaluate_model(train_data, test_data, model, previous_test, args):
    """Compute accuracy of ``model`` on training, test and earlier-task data.

    The model is switched to eval mode and every pass runs under
    ``torch.inference_mode()``; on a CUDA device, bfloat16 autocast is
    enabled as well.

    Args:
        train_data: Iterable of iterables of ``(image, label)`` items; it is
            flattened into a single training set.
        test_data: Iterable of ``(image, label)`` items.
        model: Callable module. For ``args.method == 'FedCIL'`` element 2 of
            its output is used, for ``'AFFCL'`` element 0, otherwise the
            output itself.
        previous_test: Sequence of datasets, one per task, each directly
            consumable by ``DataLoader``.
        args: Object providing ``batch_size``, ``device`` and ``method``.

    Returns:
        A tuple ``(avg_train_acc, avg_test_acc, a_t_i, all_previous_test_acc,
        all_test_acc)`` where ``a_t_i`` lists the per-task accuracies of
        ``previous_test`` rounded to 4 decimals, ``all_previous_test_acc`` is
        the accuracy over all of ``previous_test`` except its last entry
        (0 if there is none), and ``all_test_acc`` is the mean of ``a_t_i``
        (0 if ``previous_test`` is empty).
    """
    train_set = [item for sublist in train_data for item in sublist]
    test_set = list(test_data)
    model.eval()

    def make_loader(dataset):
        return torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    avg_train_acc = _accuracy(model, make_loader(MyDataset(train_set)), args)
    avg_test_acc = _accuracy(model, make_loader(MyDataset(test_set)), args)

    # Per-task accuracy (used for forgetting measurements).
    a_t_i = [
        round(_accuracy(model, make_loader(task_dataset), args), 4)
        for task_dataset in previous_test
    ]
    # Guard the mean: with no previous tasks there is nothing to average.
    all_test_acc = sum(a_t_i) / len(a_t_i) if a_t_i else 0

    # Accuracy over every task except the last one in previous_test
    # (presumably the current task -- confirm against the caller).
    earlier_tasks = previous_test[:-1]
    if len(earlier_tasks) >= 1:
        all_previous_test_acc = _accuracy(
            model, make_loader(ConcatDataset(earlier_tasks)), args
        )
    else:
        all_previous_test_acc = 0

    return avg_train_acc, avg_test_acc, a_t_i, all_previous_test_acc, all_test_acc
