import time
import numpy as np
import os
import datetime
import torch
import clip
from torch import optim, nn
from tqdm import tqdm
from support.utils import get_features, init_seeds, get_data_provider_by_name
from torch.utils.data import DataLoader
from data_providers.my_dataloader.my_ext_data_loader import MYEXTDATASET

from external.classification.train import train_one_epoch, evaluate

# Use the first CUDA device when torch reports one, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Name of the CLIP backbone passed to ``clip.load``.
clip_backbone = "ViT-B/32"
# Input width of the classifier head: 512 for "ViT-B/32", 1024 for any other
# backbone name.
fea_dim = 512 if clip_backbone == "ViT-B/32" else 1024


class MLP(nn.Module):
    """Fully connected network: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Maps ``in_num`` input features to ``out_num`` outputs through two hidden
    layers of width ``hidden_num``. The output of the last linear layer is
    returned as-is (no activation is applied to it).
    """

    def __init__(self, in_num, out_num=64, hidden_num=128):
        super().__init__()
        self.in_fea_dim, self.out_fea_dim, self.hidden_num = in_num, out_num, hidden_num
        # The modules stay in this order inside ``self.layers`` so the
        # parameter names (``layers.0``, ``layers.2``, ``layers.4``) are stable.
        stack = [
            nn.Linear(in_num, hidden_num),
            nn.ReLU(),
            nn.Linear(hidden_num, hidden_num),
            nn.ReLU(),
            nn.Linear(hidden_num, out_num),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the raw outputs."""
        out = self.layers(x)
        return out
    

def train_model(model, lab_loader, test_loader, train_epochs, output_dir, args, save_file_name='', replace=False):
    """
    Trains a model using the provided data loaders for a specified number of epochs.

    If ``replace`` is False and a checkpoint named
    ``model_{train_epochs}[_{save_file_name}].pth`` already exists in
    ``output_dir``, its weights are loaded into ``model`` and training is
    skipped.

    Otherwise the model is trained with SGD (lr 1e-2, momentum 0.9, weight
    decay 5e-4, Nesterov) and a StepLR schedule (gamma 0.2). A checkpoint is
    written every ``args.step_size`` epochs (never after the very first
    epoch unless it is also the last) and always after the final epoch, so
    that the cache lookup above can find it on the next call.

    Args:
        model: torch model
        lab_loader (torch.utils.data.DataLoader): Data loader for the training dataset.
        test_loader (torch.utils.data.DataLoader): Data loader for the testing dataset.
        train_epochs (int): Number of epochs to train the model.
        output_dir (str): Directory to save the trained model checkpoints.
            Created if it does not exist.
        args: Object providing ``step_size`` (StepLR period and checkpoint
            interval); it is also forwarded to ``train_one_epoch``.
        save_file_name (str): Optional suffix appended to the checkpoint file name.
        replace (bool): When True, ignore an existing checkpoint and retrain.

    Returns:
        tuple: ``(model, criterion)`` where ``criterion`` is the
        ``nn.CrossEntropyLoss`` instance used for training/evaluation.

    """
    criterion = nn.CrossEntropyLoss()
    suffix = f"_{save_file_name}" if save_file_name else ""

    def checkpoint_path(n_epochs):
        # Single source of truth for the checkpoint naming scheme.
        return os.path.join(output_dir, f"model_{n_epochs}{suffix}.pth")

    final_path = checkpoint_path(train_epochs)
    if (not replace) and os.path.exists(final_path):
        ckpt = torch.load(final_path, map_location=device)
        model.load_state_dict(ckpt['model'])
        return model, criterion

    # torch.save does not create missing directories.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=5e-4, nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.2)
    print("Start training")
    start_time = time.time()
    for epoch in range(train_epochs):
        train_one_epoch(model, criterion, optimizer, lab_loader, device, epoch, args)
        lr_scheduler.step()
        evaluate(model, criterion, test_loader, device=device)

        finished = epoch + 1
        at_step_boundary = finished % args.step_size == 0 and epoch > 0
        # Always persist the last epoch: the cache check above looks for
        # model_{train_epochs}, which would otherwise never exist when
        # train_epochs is not a multiple of step_size.
        if at_step_boundary or finished == train_epochs:
            checkpoint = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
            }
            torch.save(checkpoint, checkpoint_path(finished))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
    return model, criterion


def get_train_lr(dataloader):
    """Choose a learning rate from the dataset root of ``dataloader``.

    Returns 1e-2 when ``dataloader.dataset.root`` contains ``'cifar100'`` or
    ``'imagenet'``, and 1e-3 for anything else.
    """
    root = dataloader.dataset.root
    high_lr_markers = ('cifar100', 'imagenet')
    if any(marker in root for marker in high_lr_markers):
        return 1e-2
    return 1e-3


def get_val_loss(cla_model, val_x, val_y):
    """Compute cross-entropy loss and accuracy of ``cla_model`` on one batch.

    The model is switched to eval mode and is left in eval mode on return.

    Args:
        cla_model: Classifier called as ``cla_model(val_x)``; its output is
            arg-maxed over dim 1 to obtain predicted class indices.
        val_x: Input tensor for the classifier.
        val_y: Target class indices, compared element-wise with the predictions.

    Returns:
        tuple: ``(loss, accuracy)`` - the cross-entropy loss tensor and the
        fraction of correct predictions as a CPU tensor.
    """
    cla_model.eval()
    loss_ce = nn.CrossEntropyLoss()
    with torch.no_grad():
        predictions = cla_model(val_x)
        error_rate = loss_ce(predictions, val_y)
        # Tensor.sum() reduces in one op; the builtin sum() would iterate the
        # tensor element by element in Python (and fail on an empty batch).
        n_correct = (torch.argmax(predictions, dim=1) == val_y).sum()
        accuracy = n_correct.cpu() / len(val_y)
    return error_rate, accuracy


def train_classifier(clip_model, cla_model, train_loader, test_loader, lr):
    """Extract features for both loaders and fit ``cla_model`` on them.

    ``get_features`` is called once per loader and each call is unpacked
    into a (features, labels) pair; the four results are handed to
    ``train_with_fea_val`` together with ``lr``, and its return value is
    passed through unchanged.
    """
    # Features for the training split first, then for the test split.
    tr_x, tr_y = get_features(clip_model, train_loader)
    te_x, te_y = get_features(clip_model, test_loader)

    fitted = train_with_fea_val(cla_model, tr_x, tr_y, te_x, te_y, lr)
    return fitted


def train_with_fea_val(classifier, train_fea, train_lab, test_fea, test_label, lr, max_train_epochs=0):
    """Fit ``classifier`` full-batch with Adam, then evaluate once on the test data.

    Every epoch is a single optimizer step on the whole training set. The
    number of epochs is ``train_fea.shape[0]``, optionally capped by
    ``max_train_epochs`` (0 means no cap). After the last step the
    classifier is evaluated on the test features; it is returned in train
    mode.

    Args:
        classifier: Module to train; called on the full feature tensor.
        train_fea: Training features (must expose ``.shape`` and be
            convertible with ``torch.tensor``).
        train_lab: Training labels.
        test_fea: Test features.
        test_label: Test labels.
        lr: Learning rate for Adam.
        max_train_epochs (int): Upper bound on the number of epochs; 0 disables it.

    Returns:
        tuple: ``(classifier, test_loss, accuracy, loss_ce, trf, trl, tef, tel)``
        where ``test_loss`` is the detached cross-entropy on the test set,
        ``accuracy`` is a CPU tensor with the fraction of correct test
        predictions, ``loss_ce`` is the loss module, and the last four items
        are the train/test feature and label tensors on ``device``.
    """
    classifier.train()
    loss_ce = nn.CrossEntropyLoss()
    optimizer = optim.Adam(classifier.parameters(), lr=lr)
    trf = torch.tensor(train_fea).to(device)
    trl = torch.tensor(train_lab).to(device)
    tef = torch.tensor(test_fea).to(device)
    tel = torch.tensor(test_label).to(device)
    if max_train_epochs == 0:
        epochs_total = train_fea.shape[0]
    else:
        epochs_total = min(max_train_epochs, train_fea.shape[0])

    lce = None
    # The context manager closes the progress bar even if training raises.
    with tqdm(total=epochs_total) as pbar:
        for _ in range(epochs_total):
            pred = classifier(trf)
            optimizer.zero_grad()
            lce = loss_ce(pred, trl)
            lce.backward()
            optimizer.step()
            pbar.update(1)

        # Evaluate once after the final step. Doing this outside the loop
        # keeps error_rate/accuracy defined even when epochs_total is 0.
        classifier.eval()
        with torch.no_grad():
            predictions = classifier(tef)
            error_rate = loss_ce(predictions, tel)
            accuracy = (torch.argmax(predictions, dim=1) == tel).sum().cpu() / len(tel)
        classifier.train()

        if lce is not None:
            # Report plain floats rather than tensors carrying grad_fn.
            pbar.set_postfix({"loss": lce.item(), "acc": accuracy.item(), "lr": lr})
    return classifier, error_rate.detach(), accuracy, loss_ce, trf, trl, tef, tel


def get_cla_clip_models(NCLASSES:int):
    """Build the MLP classifier head and load the CLIP backbone.

    Prints the available CLIP model names, calls ``init_seeds(0)``
    (presumably to fix the RNG state before the head is initialised -
    confirm against ``support.utils``), loads ``clip_backbone`` on the CPU
    and then moves both models to ``device``.

    Args:
        NCLASSES: Output width of the classifier head.

    Returns:
        tuple: ``(cla_model, clip_model, clip_preprocess)``.
    """
    print(clip.available_models())
    init_seeds(0)
    # CLIP is loaded on the CPU first and moved to the target device below.
    backbone, preprocess = clip.load(clip_backbone, device="cpu", jit=False)
    head = MLP(in_num=fea_dim, hidden_num=128, out_num=NCLASSES)
    head = head.to(device)
    backbone = backbone.to(device)
    return head, backbone, preprocess


def get_data_loader(iter:int, dataset:str, output_dir:str, clip_preprocess:object, split_valid_from_train=True, init_lab_size:int=500, validset_size:int=500):
    """Create the labelled/unlabelled/validation/test loaders for one iteration.

    For ``iter == 0`` a provider is first built without indexes so that it
    produces the split; the resulting index arrays are written to
    ``{output_dir}/{iter}/(valid|unlab|lab)_idx.txt``. For any other ``iter``
    the index arrays are read back from those files (a missing
    ``valid_idx.txt`` yields ``valid_idx = None``). A second provider is then
    built from the indexes with ``clip_preprocess`` as train and valid
    transform.

    Args:
        iter: Iteration number; also the sub-directory name under ``output_dir``.
        dataset: Dataset name passed to ``get_data_provider_by_name``.
        output_dir: Root directory holding the per-iteration index files.
        clip_preprocess: Transform applied to train and validation images.
        split_valid_from_train: Forwarded to the data provider.
        init_lab_size: Forwarded to the data provider.
        validset_size: Forwarded as ``valid_size`` to the splitting provider
            (only used when ``iter == 0``).

    Returns:
        tuple: ``(lab_loader, unlab_loader, val_loader, test_loader,
        lab_idx, unlab_idx, valid_idx)``.
    """
    DATA_PROVIDER = get_data_provider_by_name(dataset)
    iter_dir = os.path.join(output_dir, str(iter))
    if iter == 0:
        # split the data and get the indexes
        dpv = DATA_PROVIDER(
            train_batch_size=32,
            test_batch_size=16,
            init_lab_size=init_lab_size,
            n_worker=0,
            resize_scale=0.08,
            distort_color="tf",
            image_size=224,
            num_replicas=None,
            lab_idx=None,
            unlab_idx=None,
            split_valid_from_train=split_valid_from_train,
            valid_size=validset_size,
        )
        valid_idx = dpv.val_indexes
        unlab_idx = dpv.unlab_indexes
        lab_idx = dpv.lab_indexes

        # np.savetxt does not create missing directories.
        os.makedirs(iter_dir, exist_ok=True)
        if valid_idx is not None:
            np.savetxt(os.path.join(iter_dir, 'valid_idx.txt'), valid_idx, fmt="%d")
        np.savetxt(os.path.join(iter_dir, 'unlab_idx.txt'), unlab_idx, fmt="%d")
        np.savetxt(os.path.join(iter_dir, 'lab_idx.txt'), lab_idx, fmt="%d")

    else:
        # ndmin=1 keeps a file holding a single index as a 1-D array instead
        # of letting loadtxt squeeze it to a 0-d scalar array.
        try:
            valid_idx = np.loadtxt(os.path.join(iter_dir, 'valid_idx.txt'), dtype=int, ndmin=1)
        except FileNotFoundError:
            # No validation split was saved for this iteration.
            valid_idx = None
        unlab_idx = np.loadtxt(os.path.join(iter_dir, 'unlab_idx.txt'), dtype=int, ndmin=1)
        lab_idx = np.loadtxt(os.path.join(iter_dir, 'lab_idx.txt'), dtype=int, ndmin=1)

    # create data loaders
    dpv = DATA_PROVIDER(
        train_batch_size=128,
        test_batch_size=64,
        init_lab_size=init_lab_size,
        n_worker=0 if dataset != 'inaturalist21' else 8,
        resize_scale=0.08,
        distort_color="tf",
        image_size=224,
        num_replicas=None,
        lab_idx=lab_idx,
        val_idx=valid_idx,
        unlab_idx=unlab_idx,
        train_transform=clip_preprocess,
        valid_transform=clip_preprocess,
        split_valid_from_train=split_valid_from_train,
    )
    lab_loader = dpv.train
    unlab_loader = dpv.unlab
    val_loader = dpv.valid
    test_loader = dpv.test
    return lab_loader, unlab_loader, val_loader, test_loader, lab_idx, unlab_idx, valid_idx


def get_ext_dataloader(dataset_name, method_dir_name, iter, query_batch_size, clip_preprocess, class_to_idx, train_batch_size=128, ext_type="gen"):
    """Build a shuffled DataLoader over the external image directory.

    The directory is resolved relative to the current working directory:
    ``gen_img_save/...`` when ``ext_type`` contains ``"gen"``, otherwise
    ``rand_gen_img_save/...`` when it contains ``"randomtxt"``.

    Raises:
        NotImplementedError: If ``ext_type`` contains neither marker.
    """
    # "gen" is checked first, so it wins if both markers are present.
    if "gen" in ext_type:
        rel_dir = f"gen_img_save/{dataset_name}/{method_dir_name}/"
    elif "randomtxt" in ext_type:
        rel_dir = f"rand_gen_img_save/{dataset_name}/{method_dir_name}/"
    else:
        raise NotImplementedError()
    data_dir = os.path.join(os.getcwd(), rel_dir)

    ext_dataset = MYEXTDATASET(
        class_to_idx,
        data_root=data_dir,
        max_iter=iter,
        sampling_num=query_batch_size,
        transforms=clip_preprocess,
    )
    return DataLoader(ext_dataset, batch_size=train_batch_size, shuffle=True)
