
import os
import random
import time
from argparse import Namespace
from datetime import datetime

import hydra
import numpy as np
import omegaconf
import torch

import wandb
from clip import clip
from clip.modified_clip import make_model
from dataset.aircraft import SplitAircraft
from dataset.birdsnap import SplitBirdsnap
from dataset.cars import SplitCars
from dataset.cifar100 import SplitCifar100
from dataset.country211 import SplitCountry211
from dataset.cub import CUB
from dataset.domainnet import DomainNetDataset
from dataset.eurosat import SplitEuroSAT
from dataset.gtsrb import SplitGTSRB
from dataset.seqset import SeqSet
from distributed import init_distributed_device
from trainer import METHOD


def random_seed(seed=42, rank=0):
    """Seed the Python, NumPy and PyTorch RNGs with ``seed + rank``.

    Args:
        seed: Base seed.
        rank: Offset added to ``seed`` (presumably a distributed process
            rank, so each process gets its own reproducible stream; the
            call site in this file uses the default of 0).
    """
    effective_seed = seed + rank
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(effective_seed)


def _apply_method_presets(args):
    """Rewrite method aliases in ``args`` into concrete trainer settings.

    ``lwf`` and ``prd`` are mapped onto the ``distillboth`` trainer (with a
    halved batch size and alias-specific distillation settings), ``flyp`` is
    mapped onto ``Finetune``, and ``mas`` / ``maskedit`` keep their method
    name but receive fixed hyper-parameter overrides. Any other method is
    left untouched.

    Args:
        args: Run configuration namespace; modified in place.
    """
    method = args.method
    if method in ('lwf', 'prd'):
        args.method = 'distillboth'
        args.batch_size //= 2
        args.lr = 1e-6
        args.epochs = 5
        if method == 'lwf':
            args.distill_loss = 'visual'
            args.tem = 5.0
            args.scale = 0.01
        else:
            args.distill_loss = 'text'
            args.tem = 1.0
            args.scale = 0.1
    elif method == 'mas':
        args.buffer_size = 0.0
        args.lr = 2.5e-6
        args.epochs = 10
    elif method == 'flyp':
        args.method = 'Finetune'
        args.buffer_size = 0
    elif method == 'maskedit':
        args.cur_importance_batch_percentage = 0.0
        args.scale = 0.0


def _build_dataset(args, transform):
    """Instantiate the dataset named by ``args.dataset``.

    Exact names are matched first; any other name containing ``-`` is
    treated as a sequence of datasets and handled by ``SeqSet``.

    Args:
        args: Run configuration namespace (``dataset`` and, for some
            datasets, ``data`` are read).
        transform: Image preprocessing returned by ``clip.load``.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If ``args.dataset`` matches no known dataset.
    """
    name = args.dataset
    # These two constructors take the data root positionally.
    if name == 'domainnet':
        return DomainNetDataset(args.data, args, transform)
    if name == 'cifar100':
        return SplitCifar100(args, args.data, transform)

    # Constructors sharing the ``(args, transform=...)`` signature.
    simple_datasets = {
        'cars': SplitCars,
        'cub': CUB,
        'aircraft': SplitAircraft,
        'eurosat': SplitEuroSAT,
        'birdsnap': SplitBirdsnap,
        'country211': SplitCountry211,
        'gtsrb': SplitGTSRB,
    }
    if name in simple_datasets:
        return simple_datasets[name](args, transform=transform)

    if '-' in name:
        return SeqSet(args, args.data, transform=transform)

    raise ValueError(f"Unknown dataset: {name!r}")


@hydra.main(version_base=None, config_path="config", config_name="base")
def main(args):
    """Configure a run from the Hydra config, then train/evaluate per task.

    Args:
        args: Config object injected by ``hydra.main``; converted to an
            ``argparse.Namespace`` before use.
    """
    args = Namespace(**omegaconf.OmegaConf.to_container(args))
    _apply_method_presets(args)

    start = time.time()

    random_seed(args.seed)

    init_distributed_device(args)
    if torch.cuda.is_available():
        # This enables tf32 on Ampere GPUs which is only 8% slower than
        # float16 and almost as accurate as float32
        # This was a default in pytorch until 1.12
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    # Derive the run name and set up the logging directory.
    # NOTE(review): without an explicit name this is deterministic
    # (method, dataset, SLURM job id or ""), so repeated non-SLURM runs
    # share one log directory.
    if args.name is None:
        args.name = '-'.join([
            args.method,
            args.dataset,
            os.environ.get("SLURM_JOB_ID", ""),
        ])
    log_base_path = os.path.join(args.logs, args.name)
    os.makedirs(log_base_path, exist_ok=True)
    args.log_path = log_base_path

    args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
    if args.wandb:
        # TODO: pass project/entity/config for your wandb setup here.
        wandb.init()

    print("{}".format(args).replace(', ', ',\n'))

    # Set up the model.
    model, transform = clip.load(args.model, download_root='./clip_models/', args=args)
    if args.method == 'masklearn':
        # Wrap the loaded CLIP weights in the masked variant.
        model = make_model(args, weight=model)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(torch.device(device))

    args.hidden_size = model.visual.proj.shape[0]
    args.visual_layers = len(model.visual.transformer.resblocks)

    dataset = _build_dataset(args, transform)

    args.num_classes = dataset.num_classes
    args.num_tasks = dataset.num_tasks
    args.scenario = dataset.scenario
    trainer = METHOD[args.method](args)

    for task in range(dataset.num_tasks):
        # Sweeps only run the first three tasks (0, 1, 2).
        if args.sweep and task == 3:
            break
        print(f'Train task {task}')
        if args.evaluation:
            # Evaluation-only mode: a single only_evaluation call at task 0.
            trainer.only_evaluation(model, dataset, task)
            break
        trainer.train(model, dataset, task)
        trainer.evaluation(model, dataset, task)
        trainer.save_checkpoint(model, task, args)

    print(f'Total training time in hours: {(time.time() - start) / 3600: .3f}')

    if args.wandb:
        wandb.finish()


if __name__ == "__main__":
    # hydra.main composes the config and injects it, so main() takes no args here.
    main()
