import argparse
import random
from tqdm import tqdm

import numpy as np
from sklearn.cluster import KMeans
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

from datasets.augmentations import get_transform
from datasets import make_dataloader, get_class_splits
from loss.cl_loss import info_nce_logits
from loss.supcon_loss import SupConLoss
from model.backbones import vision_transformer as vits

from utils.optimizer import build_optimizer, get_mean_lr
from utils.utils import AverageMeter, str2bool, init_experiment, ContrastiveLearningViewGenerator, set_seed
from utils.cluster_utils import log_accs_from_preds
from config import cfg

# TODO(debug): DeprecationWarnings are silenced globally below; revisit this filter.
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)


class PretrainModelManager:
    """Fine-tune a backbone and projection head with contrastive objectives.

    Each training step mixes an unsupervised InfoNCE loss with a supervised
    contrastive (SupCon) loss on the labelled part of the batch. After every
    epoch the backbone features are clustered with K-Means on the unlabelled
    training split and on the disjoint test split, metrics are logged, and
    checkpoints of the backbone and projection head are written.
    """

    def __init__(self, model, projection_head, train_loader, train_loader_unlabeled, test_loader, args, cfg):
        """Store the training components and move both networks to ``args.device``.

        Args:
            model: Backbone network mapping images to features.
            projection_head: Head whose forward returns a ``(features, logits)`` pair.
            train_loader: Loader yielding ``(views, labels, idx, mask_lab)`` batches,
                where ``views`` is a list of augmented image tensors.
            train_loader_unlabeled: Loader used for K-Means evaluation on the
                unlabelled training data.
            test_loader: Loader used for K-Means evaluation on the test set.
            args: Namespace providing ``device``, ``writer``, ``model_path``,
                ``contrast_unlabel_only``, ``num_known_classes``,
                ``num_all_classes`` and ``eval_funcs`` (among others that are
                forwarded to ``info_nce_logits``).
            cfg: Config node; the ``SOLVER`` and ``LOSS`` sections are read.
        """
        self.args = args
        self.cfg = cfg
        self.device = args.device
        self.model = model.to(self.device)
        self.projection_head = projection_head.to(self.device)
        self.train_loader = train_loader
        self.train_loader_unlabeled = train_loader_unlabeled
        self.test_loader = test_loader
        self.writer = args.writer

    def train(self):
        """Train for ``cfg.SOLVER.MAX_EPOCHS`` epochs, evaluating and checkpointing each epoch.

        The latest weights are always saved to ``args.model_path`` (and a
        ``_proj_head.pt`` sibling). ``_best`` copies are written whenever the
        old-class accuracy on the test set improves.
        """
        # Use the manager's own head, not a module-level global of the same name.
        optimizer = build_optimizer(
            self.cfg,
            list(self.projection_head.parameters()) + list(self.model.parameters()),
        )
        exp_lr_scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.cfg.SOLVER.MAX_EPOCHS,
            eta_min=0.002 * self.cfg.SOLVER.BASE_LR,
        )

        sup_con_crit = SupConLoss()
        best_test_acc_lab = 0

        for epoch in range(self.cfg.SOLVER.MAX_EPOCHS):
            loss_record, train_acc_record = self._train_one_epoch(optimizer, sup_con_crit)
            print('Train Epoch: {} Avg Loss: {:.4f} | Seen Class Acc: {:.4f} '.format(
                epoch, loss_record.avg, train_acc_record.avg))

            with torch.no_grad():
                print('Testing on unlabelled examples in the training data...')
                all_acc, old_acc, new_acc = self.test_kmeans(
                    epoch=epoch, loader=self.train_loader_unlabeled, save_name='Train ACC Unlabelled')
                print('Testing on disjoint test set...')
                all_acc_test, old_acc_test, new_acc_test = self.test_kmeans(
                    epoch=epoch, loader=self.test_loader, save_name='Test ACC')

            # ----------------
            # LOG
            # ----------------
            self.writer.add_scalar('Loss', loss_record.avg, epoch)
            self.writer.add_scalar('Train Acc Labelled Data', train_acc_record.avg, epoch)
            self.writer.add_scalar('LR', get_mean_lr(optimizer), epoch)

            print('Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(
                all_acc, old_acc, new_acc))
            print('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(
                all_acc_test, old_acc_test, new_acc_test))

            # Step schedule
            exp_lr_scheduler.step()

            self._save_checkpoint(suffix='')

            if old_acc_test > best_test_acc_lab:
                print(f'Best ACC on old Classes on disjoint test set: {old_acc_test:.4f}...')
                print('Best Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(
                    all_acc, old_acc, new_acc))
                self._save_checkpoint(suffix='_best')
                best_test_acc_lab = old_acc_test

    def _train_one_epoch(self, optimizer, sup_con_crit):
        """Run one pass over ``self.train_loader``; return ``(loss_meter, acc_meter)``."""
        loss_record = AverageMeter()
        train_acc_record = AverageMeter()

        self.projection_head.train()
        self.model.train()
        for images, class_labels, _idx, mask_lab in tqdm(self.train_loader):
            # Only the first column of the per-sample label mask is used.
            mask_lab = mask_lab[:, 0].to(self.device).bool()
            class_labels = class_labels.to(self.device)
            # ``images`` is a list of augmented views; stack them along the batch dim.
            images = torch.cat(images, dim=0).to(self.device)

            ####  Extract features with base model  ####
            features = self.model(images)
            features, _ = self.projection_head(features)
            features = torch.nn.functional.normalize(features, dim=-1)  # L2-normalize features

            # ``chunk(2)`` splits the stacked views back apart (assumes two views).
            # SupCon needs labelled samples in both views; batches without any
            # are skipped entirely. Check before computing InfoNCE to avoid waste.
            f1_lab, f2_lab = [f[mask_lab] for f in features.chunk(2)]
            if f1_lab.shape[0] == 0 or f2_lab.shape[0] == 0:
                continue

            ####  Choose which instances to run the contrastive loss on  ####
            if self.args.contrast_unlabel_only:
                # Contrastive loss only on unlabelled instances
                f1, f2 = [f[~mask_lab] for f in features.chunk(2)]
                con_feats = torch.cat([f1, f2], dim=0)
            else:
                # Contrastive loss for all examples
                con_feats = features

            ###  ALL Contrastive Loss  ####
            contrastive_logits, contrastive_labels = info_nce_logits(features=con_feats, args=self.args)
            contrastive_loss = torch.nn.CrossEntropyLoss()(contrastive_logits, contrastive_labels)

            ####  Supervised contrastive loss  ####
            sup_con_feats = torch.stack([f1_lab, f2_lab], dim=1)
            sup_con_labels = class_labels[mask_lab]
            sup_con_loss = sup_con_crit(sup_con_feats, labels=sup_con_labels)

            ####  Total loss  ####
            w_scl = self.cfg.LOSS.W_SCL
            loss = (1 - w_scl) * contrastive_loss + w_scl * sup_con_loss
            # tqdm.write keeps the progress bar intact, unlike a bare print.
            tqdm.write('CL Loss: {}; SCL Loss: {}'.format(contrastive_loss.item(), sup_con_loss.item()))

            # Train acc (accuracy of the InfoNCE positive-pair prediction)
            _, pred = contrastive_logits.max(1)
            acc = (pred == contrastive_labels).float().mean().item()
            train_acc_record.update(acc, pred.size(0))
            loss_record.update(loss.item(), class_labels.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return loss_record, train_acc_record

    def _save_checkpoint(self, suffix=''):
        """Save backbone and head weights; ``suffix`` (e.g. ``'_best'``) goes before ``.pt``."""
        # Naming strips the last three characters of ``args.model_path``
        # (assumed to be ``.pt``), matching the historical checkpoint layout.
        stem = self.args.model_path[:-3]
        model_file = self.args.model_path if not suffix else stem + suffix + '.pt'
        head_file = stem + '_proj_head' + suffix + '.pt'

        torch.save(self.model.state_dict(), model_file)
        print("model saved to {}.".format(model_file))

        torch.save(self.projection_head.state_dict(), head_file)
        print("projection head saved to {}.".format(head_file))

    def test_kmeans(self, epoch, loader, save_name, mode='test'):
        """Cluster L2-normalized backbone features with K-Means and score them.

        Args:
            epoch: Epoch index forwarded to ``log_accs_from_preds`` as ``T``.
            loader: Loader yielding ``(images, label, _)`` batches in ``'test'``
                mode, or ``(views, label, _, _)`` in ``'train'`` mode (only the
                first view is used).
            save_name: Tag used when logging the accuracies.
            mode: ``'test'`` or ``'train'``; selects the batch layout.

        Returns:
            ``(all_acc, old_acc, new_acc)`` as returned by ``log_accs_from_preds``.

        Raises:
            ValueError: If ``mode`` is unknown or ``loader`` yields no batches.
        """
        if mode not in ('test', 'train'):
            raise ValueError(f"mode must be 'test' or 'train', got {mode!r}")

        self.model.eval()
        num_known = self.args.num_known_classes

        all_feats, all_targets, all_masks = [], [], []

        print('Collating features...')
        # First extract all features
        for batch in tqdm(loader):
            if mode == 'test':
                images, label, _ = batch
            else:
                images, label, _, _ = batch
                images = images[0]
            images = images.to(self.device)
            feats = torch.nn.functional.normalize(self.model(images), dim=-1)
            all_feats.append(feats.cpu().numpy())

            label_np = label.cpu().numpy().ravel()
            all_targets.append(label_np)
            # "Old"/known samples have an (integer) label in [0, num_known).
            all_masks.append((label_np >= 0) & (label_np < num_known))

        if not all_feats:
            raise ValueError('loader yielded no batches; cannot fit K-Means')

        # -----------------------
        # K-MEANS
        # -----------------------
        print('Fitting K-Means...')
        all_feats = np.concatenate(all_feats)
        targets = np.concatenate(all_targets)
        mask = np.concatenate(all_masks)
        kmeans = KMeans(n_clusters=self.args.num_all_classes,
                        n_init='auto', random_state=0).fit(all_feats)
        preds = kmeans.labels_
        print('Done!')

        # -----------------------
        # EVALUATE
        # -----------------------
        all_acc, old_acc, new_acc = log_accs_from_preds(
            y_true=targets, y_pred=preds, mask=mask,
            T=epoch, eval_funcs=self.args.eval_funcs, save_name=save_name,
            writer=self.writer
        )

        return all_acc, old_acc, new_acc


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
            description='cluster',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--config_file", default="", help="path to config file", type=str
    )
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v1', 'v2'])
    parser.add_argument('--grad_from_block', type=int, default=11)
    parser.add_argument('--transform', type=str, default='imagenet')
    parser.add_argument('--contrast_unlabel_only', type=str2bool, default=False)

    # ----------------------
    # INIT
    # ----------------------
    # Merge CLI args with the config file and command-line overrides.
    args = parser.parse_args()
    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    args.dataset_name = cfg.DATASETS.NAMES
    args = get_class_splits(args)
    args.batch_size = cfg.SOLVER.IMS_PER_BATCH
    args.n_views = cfg.DATALOADER.N_VIEWS
    args.temperature = cfg.LOSS.TEMP_INFO
    args.exp_root = cfg.OUTPUT_DIR
    # Load experiment settings
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()
    print('Num of GPUs: {}'.format(args.n_gpu))
    init_experiment(args, runner_name=['metric_learn_gcd'])
    print(f'Using evaluation function {args.eval_funcs[0]} to print results')
    set_seed(cfg.SOLVER.SEED)

    # NOTE: Hardcoded image size as we do not finetune the entire ViT model
    args.image_size = 224
    args.feat_dim = 768
    args.num_mlp_layers = 3
    args.mlp_out_dim = 65536
    args.interpolation = 3
    args.crop_pct = 0.875

    # --------------------
    # CONTRASTIVE TRANSFORM
    # --------------------
    train_transforms, test_transforms = get_transform(args.transform, image_size=args.image_size, args=args)
    train_transforms = ContrastiveLearningViewGenerator(base_transform=train_transforms, n_views=args.n_views)

    # --------------------
    # DATALOADER
    # --------------------
    train_loader, train_loader_unlabelled, test_loader, num_known_classes, num_unknown_classes, num_all_classes = \
        make_dataloader(cfg, args, train_transforms=train_transforms, test_transforms=test_transforms)
    args.num_private_classes = len(args.private_classes)
    args.num_known_classes = num_known_classes
    args.num_unknown_classes = num_unknown_classes
    args.num_all_classes = num_all_classes
    cfg.MODEL.PROTO_NUM = num_unknown_classes

    # ----------------------
    # BASE MODEL
    # ----------------------
    if cfg.MODEL.NAME != 'dino':
        raise NotImplementedError(
            "Unsupported cfg.MODEL.NAME {!r}; only 'dino' is implemented.".format(cfg.MODEL.NAME))

    # ----------------------
    # PRETRAIN MODEL
    # ----------------------
    pretrain_path = cfg.MODEL.PRETRAIN_PATH
    print("cfg.MODEL.PRETRAIN_PATH: {}".format(cfg.MODEL.PRETRAIN_PATH))
    model = vits.__dict__['vit_base'](
        img_size=cfg.INPUT.SIZE_CROP,
        stride_size=cfg.MODEL.STRIDE_SIZE,
        drop_path_rate=cfg.MODEL.DROP_PATH,
        cfg=cfg,
        args=args,
    )
    model.load_param_finetune(pretrain_path)
    if args.n_gpu > 1:
        model = nn.DataParallel(model)

    # ----------------------
    # HOW MUCH OF BASE MODEL TO FINETUNE
    # ----------------------
    # Freeze everything except parameters of blocks numbered
    # >= args.grad_from_block. The block index is the second dotted component
    # of the name, or the third when DataParallel adds a ``module.`` prefix.
    block_idx_pos = 2 if args.n_gpu > 1 else 1
    for name, m in model.named_parameters():
        m.requires_grad = False
        if 'block' in name:
            block_num = int(name.split('.')[block_idx_pos])
            if block_num >= args.grad_from_block:
                m.requires_grad = True

    # ----------------------
    # PROJECTION HEAD
    # ----------------------
    projection_head = vits.__dict__['DINOHead'](in_dim=args.feat_dim,
                                                out_dim=args.mlp_out_dim,
                                                nlayers=args.num_mlp_layers,
                                                args=args)
    if cfg.MODEL.PRETRAIN_PROJ_PATH != '':
        projection_head.load_param_finetune(cfg.MODEL.PRETRAIN_PROJ_PATH)
    if args.n_gpu > 1:
        projection_head = nn.DataParallel(projection_head)
    projection_head.to(args.device)

    pt_manager = PretrainModelManager(model, projection_head, train_loader, train_loader_unlabelled, test_loader, args, cfg)
    # ----------------------
    # PRETRAIN
    # ----------------------
    print('Pre-training begin...')
    pt_manager.train()
    print('Pre-training finished!')