import torch
from torch.utils.data import DataLoader, Subset
import numpy as np
import sys
import os
sys.path.append(os.path.abspath('fxh/seq_unlearn'))
from utils import evaluate_accuracy
from models import CustomModel

def _subset_accuracy(model, dataset, subset_indices, batch_size):
    """Return ``evaluate_accuracy`` of ``model`` on the given subset of ``dataset``.

    An empty index collection yields 0.0 without building a loader, so every
    split (forget / retain / previously forgotten) treats the empty case the
    same way.
    """
    # len() rather than truthiness: the indices may be a NumPy array, whose
    # truth value is ambiguous for more than one element.
    if len(subset_indices) == 0:
        return 0.0
    loader = DataLoader(Subset(dataset, subset_indices), batch_size=batch_size, shuffle=False)
    return evaluate_accuracy(model, loader)


def original_baseline(train_dataset, indices, subset_indexs, T, batch_size, 
                      num_classes, model_type, resume_model_path, device, test_loader):
    """Report accuracies of the baseline model over T sequential forget rounds.

    A single ``CustomModel`` is constructed once and moved to ``device``; this
    function itself performs no training step on it. For each round
    ``t = 1..T`` it prints the accuracy on:

    * ``Ft``    -- the forget set of the current round, ``subset_indexs[t-1]``;
    * ``Rt``    -- ``indices`` minus every forget set seen so far (including
                   the current one);
    * ``F_t-1`` -- the union of the forget sets of all earlier rounds
                   (0.0 when nothing has been forgotten yet);
    * the data served by ``test_loader``.

    Args:
        train_dataset: Dataset indexed by the entries of ``indices`` and
            ``subset_indexs``.
        indices: All training indices under consideration.
        subset_indexs: Sequence of index collections, one per round; must
            hold at least ``T`` entries (otherwise indexing raises).
        T: Number of rounds to evaluate.
        batch_size: Batch size used for every subset loader built here.
        num_classes: Forwarded to ``CustomModel``.
        model_type: Forwarded to ``CustomModel`` as ``model_name``.
        resume_model_path: Forwarded to ``CustomModel`` as ``model_path``
            (presumably the checkpoint to load -- confirm in ``CustomModel``).
        device: Target passed to ``.to(device)`` on the constructed model.
        test_loader: Loader handed directly to ``evaluate_accuracy``.

    Returns:
        The constructed baseline model.
    """
    print(model_type, num_classes, resume_model_path)
    baseline_model = CustomModel(
        model_name=model_type,
        num_classes=num_classes,
        pretrained=False,
        model_path=resume_model_path,
    ).to(device)

    # Union of the forget sets from all rounds before the current one.
    prev_F_t_1_indices = []

    for t in range(1, T + 1):
        Ft_indices = subset_indexs[t - 1]

        # Retained data: everything not forgotten in this or any earlier round.
        Rt_indices = np.setdiff1d(indices, [*prev_F_t_1_indices, *Ft_indices])
        Acc_Rt = _subset_accuracy(baseline_model, train_dataset, Rt_indices, batch_size)

        print(f'indices length: {len(indices)}, Ft_indices length:{len(Ft_indices)},  Rt_indices length: {len(Rt_indices)}, prev_F_t_1_indices length: {len(prev_F_t_1_indices)}')

        Acc_Ft = _subset_accuracy(baseline_model, train_dataset, Ft_indices, batch_size)

        # Guarded on emptiness rather than on ``t == 1`` so that an empty
        # earlier forget set cannot lead to evaluating an empty loader.
        Acc_F_t_1 = _subset_accuracy(baseline_model, train_dataset, prev_F_t_1_indices, batch_size)

        # Only now does the current forget set join the "previous" union.
        prev_F_t_1_indices.extend(Ft_indices)

        Acc_test = evaluate_accuracy(baseline_model, test_loader)

        print(f"Original time {t}: Acc_Ft: {Acc_Ft:.4f}, Acc_Rt: {Acc_Rt:.4f}, Acc_F_t-1: {Acc_F_t_1:.4f}, Acc_test: {Acc_test:.4f}")
    return baseline_model