import torch
import copy
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import sys
import os
sys.path.append(os.path.abspath('fxh/seq_unlearn'))
from utils import EarlyStopping, create_pseudo_labels, evaluate_accuracy
from models import CustomModel
import time


def retrain_baseline(train_dataset, indices, subset_indexs, T, eta, batch_size, epochs, num_classes, 
                     model_type, pretrained_tag=True, early_stopping=True, save_interval=1, device="cuda:0", save_path="", test_loader=None):
    """Retrain-from-scratch baseline for sequential unlearning.

    For each step ``t`` in ``1..T`` a brand-new ``CustomModel`` is trained only on
    the retained set: ``indices`` minus every forget set seen so far
    (``subset_indexs[0]`` .. ``subset_indexs[t-1]``). The model is then evaluated
    on the retained set, the current forget set, the earlier forget sets and the
    test set, and the accuracies are printed.

    Args:
        train_dataset: Dataset indexed through ``torch.utils.data.Subset``.
        indices: All candidate training sample ids.
        subset_indexs: Sequence whose ``t-1``-th entry is the forget set of step ``t``;
            must have at least ``T`` entries.
        T: Number of sequential unlearning steps.
        eta: SGD learning rate.
        batch_size: Batch size for training and evaluation loaders.
        epochs: Maximum number of training epochs per step.
        num_classes: Passed to ``CustomModel``.
        model_type: Passed to ``CustomModel`` as ``model_name``.
        pretrained_tag: Passed to ``CustomModel`` as ``pretrained``.
        early_stopping: If true, stop a step's training via ``EarlyStopping`` on the
            mean epoch loss.
        save_interval: Keep a CPU copy of the model every ``save_interval`` steps.
        device: Torch device used for training.
        save_path: Prefix of the output file; models are saved to
            ``save_path + '-retrain.pth'``.
        test_loader: Loader for test accuracy; if ``None`` the test accuracy is
            reported as NaN.

    Returns:
        List of the kept models (moved to CPU), also written with ``torch.save``.
    """
    early_stopping_tag = early_stopping

    prev_F_t_1_indices = []  # union of the forget sets of steps 1..t-1
    T_opt_models = []
    for t in range(1, T + 1):

        start = time.time()

        # A fresh model each step, so no forgotten sample ever influenced its weights.
        baseline_model = CustomModel(model_name=model_type, num_classes=num_classes, pretrained=pretrained_tag).to(device)
        loss_fn = nn.CrossEntropyLoss()
        baseline_optimizer = optim.SGD(baseline_model.parameters(), lr=eta, momentum=0.9, weight_decay=5e-4)
        baseline_model.train()

        # Forget set of this step; the retained set excludes it and all earlier ones.
        Ft_indices = subset_indexs[t - 1]
        Rt_indices = np.setdiff1d(indices, [*prev_F_t_1_indices, *Ft_indices])

        if len(Rt_indices) > 0:
            Rt_loader = DataLoader(Subset(train_dataset, Rt_indices), batch_size=batch_size, shuffle=True)
            print(f'Current time {t}---------R loader length (batch num): {len(Rt_loader)}')

            # Separate local so the boolean ``early_stopping`` argument is not shadowed.
            stopper = EarlyStopping(patience=5, verbose=True, delta=0.01) if early_stopping_tag else None

            # retrain epochs
            for epoch in range(epochs):
                epoch_loss = torch.tensor(0.0).to(device)
                samples = 0
                for inputs, targets in Rt_loader:
                    inputs, targets = inputs.to(device), targets.to(device)
                    baseline_optimizer.zero_grad()
                    outputs = baseline_model(inputs)
                    batch_loss_mean = loss_fn(outputs, targets)
                    # Detach before accumulating: adding the graph-attached loss would
                    # chain every batch's autograd nodes into epoch_loss for the whole epoch.
                    epoch_loss += batch_loss_mean.detach() * inputs.size(0)
                    samples += inputs.size(0)
                    batch_loss_mean.backward()
                    baseline_optimizer.step()
                mean_epoch_loss = epoch_loss / samples
                print(f'Retrain model----------epoch: {epoch}---------total epoch loss: {epoch_loss}---------mean epoch loss: {mean_epoch_loss}')

                if stopper is not None:
                    stopper(mean_epoch_loss.item())
                    if stopper.early_stop:
                        print(f'Early stopping at epoch: {epoch}')
                        break

        end = time.time()

        # Wall-clock time of model construction + training for this step (evaluation excluded).
        efficiency = end - start

        # save current time optimal model
        if t % save_interval == 0:
            T_opt_models.append(copy.deepcopy(baseline_model).cpu())
            print(f'Optimal model of time {t} has saved.')

        # evaluate
        # NOTE(review): the model is still in train() mode here unless
        # evaluate_accuracy switches it to eval() itself — confirm in utils.
        print(f'indices length: {len(indices)}, Ft_indices length:{len(Ft_indices)},  Rt_indices length: {len( Rt_indices)}, prev_F_t_1_indices length: {len(prev_F_t_1_indices)}')
        if len(Rt_indices) == 0:
            Acc_Rt = 0
        else:
            Acc_Rt = evaluate_accuracy(baseline_model, Rt_loader)

        # Empty subsets are reported as 0.0 (same convention as Acc_Rt) instead of
        # evaluating on an empty loader.
        if len(Ft_indices) == 0:
            Acc_Ft = 0.0
        else:
            Ft_loader = DataLoader(Subset(train_dataset, Ft_indices), batch_size=batch_size, shuffle=False)
            Acc_Ft = evaluate_accuracy(baseline_model, Ft_loader)

        # At t == 1 there are no earlier forget samples, so this is 0.0 there.
        if not prev_F_t_1_indices:
            Acc_F_t_1 = 0.0
        else:
            Acc_F_t_1 = evaluate_accuracy(baseline_model, DataLoader(Subset(train_dataset, prev_F_t_1_indices), batch_size=batch_size, shuffle=False))

        prev_F_t_1_indices.extend(Ft_indices)

        # test_loader defaults to None; report NaN rather than evaluating on None.
        if test_loader is None:
            Acc_test = float('nan')
        else:
            Acc_test = evaluate_accuracy(baseline_model, test_loader)

        print(f"Retrain time {t}: Acc_Ft: {Acc_Ft:.4f}, Acc_Rt: {Acc_Rt:.4f}, Acc_F_t-1: {Acc_F_t_1:.4f}, Acc_test: {Acc_test:.4f}, time: {efficiency:.4f}")

    current_model_path = save_path + '-retrain' + '.pth'
    print(current_model_path)
    torch.save(T_opt_models, current_model_path)
    return T_opt_models