import os
import time
import torch

from src.args import parse_arguments
from src.datasets.common import get_dataloader, maybe_dictionarize
from src.datasets.registry import get_dataset
from src.eval import evaluate, eval_single_dataset
from src.modeling import ImageEncoder, ImageClassifier, MultiHeadImageClassifier
from src.utils import cosine_lr, LabelSmoothing, state_dict_to_vector
from src.heads import get_classification_head

def recycle(iterable):
    """Yield items from *iterable* endlessly, re-iterating it on every pass.

    Variant of :func:`itertools.cycle` that does not save iterates: each pass
    calls ``iter(iterable)`` again, so nothing is cached between passes and
    a re-iterable source is walked afresh every time.

    If a complete pass produces no items (an empty iterable, or a one-shot
    iterator that is already exhausted), the generator stops instead of
    spinning forever in a loop that never yields.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            # Nothing came out of this pass, so no later pass can yield
            # anything either; stop rather than busy-looping.
            return


def finetune(args, step):
    """Fine-tune the image encoder on ``args.train_dataset`` for one merge round.

    Starting from global step ``step``, trains until the step counter reaches
    the next multiple of ``args.update_every`` strictly greater than ``step``.
    It then saves the encoder to
    ``<args.save>/<args.train_dataset>/checkpoint_<step>.pt`` and returns.

    The starting encoder is ``<args.save>/uniform_soup_<step>.pt`` (the merged
    encoder from the previous round) when that file exists. Otherwise a fresh
    ``ImageEncoder(args, keep_lang=False)`` is built. For ``step > 0`` the
    merged file is required.

    Args:
        args: Parsed arguments. This function reads ``update_every``,
            ``train_dataset``, ``save``, ``data_location``, ``batch_size``,
            ``ls``, ``lr``, ``wd``, ``warmup_length`` and ``steps``. It also
            passes ``args`` through to the head, dataloader and evaluation
            helpers.
        step: Global step to start from. It drives the LR scheduler and the
            checkpoint file names.

    Raises:
        ValueError: if ``args.train_dataset`` is None.
        FileNotFoundError: if ``step > 0`` and the merged encoder is missing.
        RuntimeError: if the data loader stops yielding batches before a
            checkpoint boundary is reached.
    """
    update_every = args.update_every
    train_dataset = args.train_dataset
    # Validate before the name is used in any path: os.path.join(..., None)
    # would otherwise fail with an unrelated TypeError.
    if train_dataset is None:
        raise ValueError("Please provide a training dataset.")
    ckpdir = os.path.join(args.save, train_dataset)

    # Merged ("uniform soup") encoder written by the previous round; this is
    # the starting point for the current round.
    merge_path = os.path.join(args.save, f'uniform_soup_{step}.pt')
    if step > 0 and not os.path.exists(merge_path):
        raise FileNotFoundError("Please provide a valid merge path.")

    if os.path.exists(merge_path):
        image_encoder = ImageEncoder.load(merge_path)
    else:
        print('Building image encoder.')
        image_encoder = ImageEncoder(args, keep_lang=False)

    classification_head = get_classification_head(args, train_dataset)
    model = ImageClassifier(image_encoder, classification_head)
    # NOTE(review): freeze_head presumably disables requires_grad on the head,
    # so the `params` filter below leaves it out of the optimizer.
    model.freeze_head()

    preprocess_fn = model.train_preprocess
    print_every = 100

    dataset = get_dataset(
        train_dataset,
        preprocess_fn,
        location=args.data_location,
        batch_size=args.batch_size,
    )

    devices = list(range(torch.cuda.device_count()))
    print('Using devices', devices)
    model = torch.nn.DataParallel(model, device_ids=devices)

    if args.ls > 0:
        loss_fn = LabelSmoothing(args.ls)
    else:
        loss_fn = torch.nn.CrossEntropyLoss()

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)
    # The scheduler is called with the *global* step below, so each round
    # continues the schedule (defined over args.steps) instead of restarting it.
    scheduler = cosine_lr(optimizer, args.lr, args.warmup_length, args.steps)

    model = model.cuda()
    model.train()
    data_loader = get_dataloader(
        dataset, is_train=True, args=args, image_encoder=None)

    for batch in recycle(data_loader):
        start_time = time.time()

        scheduler(step)
        optimizer.zero_grad()

        batch = maybe_dictionarize(batch)
        inputs = batch['images'].to('cuda:0')
        labels = batch['labels'].to('cuda:0')
        data_time = time.time() - start_time

        logits = model(inputs)
        loss = loss_fn(logits, labels)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(params, 1.0)

        optimizer.step()
        step += 1
        batch_time = time.time() - start_time

        if step % print_every == 0:
            results = eval_single_dataset(model.module.image_encoder, train_dataset, args)
            # NOTE(review): the evaluation helper may put the shared encoder in
            # eval mode; switch back so the remaining steps train normally.
            model.train()

            percent_complete = 100 * step / args.steps
            print(
                f"Train Step: {step} [{percent_complete:.0f}% {step % update_every}/{update_every}]\t"
                f"Val Acc: {results['top1']}\t"
                f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}",
                flush=True,
            )

        if step % update_every == 0:
            os.makedirs(ckpdir, exist_ok=True)
            model_path = os.path.join(ckpdir, f'checkpoint_{step}.pt')
            model.module.image_encoder.save(model_path)
            break
    else:
        # Only reachable if iteration ends without hitting a checkpoint
        # boundary (e.g. an empty loader); fail loudly rather than returning
        # without having written the checkpoint the caller expects.
        raise RuntimeError(
            f"Data loader for {train_dataset} stopped before step {step} "
            f"reached a multiple of update_every={update_every}."
        )


if __name__ == '__main__':
    # Iterative merging. For each backbone, repeat until every task has
    # trained for `steps` steps:
    #   1. fine-tune every task for `update_every` steps, starting from the
    #      latest uniform soup;
    #   2. average the task checkpoints into a new soup;
    #   3. evaluate that soup on every task.
    args = parse_arguments()
    data_location = '/root/dataset'
    models = ['ViT-B-32', 'ViT-B-16', 'ViT-L-14']
    datasets = ['Cars', 'DTD', 'EuroSAT', 'GTSRB', 'MNIST', 'RESISC45', 'SUN397', 'SVHN']
    NUM_MODELS = len(datasets)
    update_every = args.update_every
    steps = args.steps

    for model in models:
        # Per-task global step counters, derived from `datasets` so the two
        # cannot drift apart.
        cur_steps = dict.fromkeys(datasets, 0)

        # initialize the split
        args.save = f'checkpoints_iterative_merge_{steps}_{update_every}/{model}'
        overall_merge_acc = []
        while all(v < steps for v in cur_steps.values()):
            for idx, dataset in enumerate(datasets):
                print('=' * 100)
                print(f'Finetuning {model} on {dataset} with index {idx}')
                print('=' * 100)

                args.lr = 1e-5
                args.data_location = data_location
                args.train_dataset = dataset + 'Val'
                args.batch_size = 128
                args.model = model

                finetune(args, cur_steps[dataset])
                cur_steps[dataset] += update_every

            # merge finetuned into one
            print("current step:", cur_steps)
            # All tasks advance in lockstep, so a single step number names
            # both this round's checkpoints and the resulting soup.
            step = cur_steps[datasets[0]]
            assert all(v == step for v in cur_steps.values()), "All steps should be the same"

            merged = ImageEncoder(args, keep_lang=False)
            # Build the uniform soup one checkpoint at a time, accumulating in
            # place so that only one extra state dict is alive at any moment.
            model_paths = [
                os.path.join(args.save, dataset + 'Val', f'checkpoint_{step}.pt')
                for dataset in datasets
            ]
            weight = 1. / NUM_MODELS
            uniform_soup = None
            for j, model_path in enumerate(model_paths):
                print(f'Adding model {j} of {NUM_MODELS - 1} to uniform soup.')

                assert os.path.exists(model_path)
                state_dict = ImageEncoder.load(model_path).state_dict()
                if uniform_soup is None:
                    uniform_soup = {k: v * weight for k, v in state_dict.items()}
                else:
                    for k, v in state_dict.items():
                        uniform_soup[k] += v * weight
                del state_dict

            merged.load_state_dict(uniform_soup)
            merged.save(os.path.join(args.save, f'uniform_soup_{step}.pt'))

            # Evaluate the merged encoder on every task.
            raw_acc = [eval_single_dataset(merged, dataset, args)['top1'] for dataset in datasets]

            print(f'Accuracy: {raw_acc}')
            overall_merge_acc.append(raw_acc)

        print(f'Overall Accuracy: {overall_merge_acc}')

