import datetime
import os
import time
import math
import logging
import pickle
import random
import warnings

import torch
import torch.utils.data
import numpy as np
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
import mytorch.utils as utils
from mymodels import MODEL_DICT, WEIGHTS_DICT
from train_target_model import FeatureData, weighted_cross_entropy
from settings import selected_classes
from tqdm import tqdm


def cosine_similarity(features_A, features_B):
    '''
    Pairwise cosine similarity between the rows of two feature matrices.

    features_A: [N_A, dim]
    features_B: [N_B, dim]
    Returns: [N_A, N_B] tensor whose (i, j) entry is cos(features_A[i], features_B[j]).

    Rows with (near-)zero L2 norm give similarity 0 instead of NaN:
    F.normalize divides by max(norm, eps), which is identical to plain
    division by the norm for every row whose norm exceeds eps.
    '''
    features_A = F.normalize(features_A, p=2, dim=-1)
    features_B = F.normalize(features_B, p=2, dim=-1)

    return features_A @ features_B.transpose(-2, -1)


class ConceptFeatureData(Dataset):
    '''
    In-memory dataset that pairs target-model features with CLIP similarities.

    concept_features:     tensor whose first axis enumerates samples and whose
                          last axis is the feature dimension.
    concept_similarities: per-sample similarity scores, indexed in parallel
                          with ``concept_features``.
    '''

    def __init__(self, concept_features, concept_similarities):
        self.all_features, self.all_similarities = concept_features, concept_similarities
        # Size of the last axis, used to build a CavModel of matching width.
        self.feature_dim = concept_features.shape[-1]

    def __len__(self):
        return self.all_features.size(0)

    def __getitem__(self, idx):
        feature = self.all_features[idx]
        similarity = self.all_similarities[idx]
        return feature, similarity


class CavModel(nn.Module):
    '''
    A single learnable concept activation vector (CAV).

    Holds one parameter row of shape (1, feature_dim). The forward pass scores
    every input row by its cosine similarity to that row.
    '''

    def __init__(self, feature_dim):
        super().__init__()
        # randn allocates the parameter (and advances the global RNG); the
        # in-place Kaiming-uniform initialiser then overwrites its values.
        initial_cav = torch.randn(1, feature_dim)
        self.cav = nn.Parameter(initial_cav)
        nn.init.kaiming_uniform_(self.cav, a=math.sqrt(5))

    def forward(self, x):
        '''
        Return: Cosine Similarities between the CAV and each row of ``x``, flattened to 1-D.
        '''
        similarities = cosine_similarity(self.cav, x)
        return similarities.reshape(-1)


def get_cav_loss(target_simi, clip_simi, target_features, temperature=1.0):
    '''
    Mean absolute error between the CAV similarities and rescaled CLIP similarities.

    With num_samples = target_simi.shape[0], the first num_samples // 2 entries of
    every input are treated as the positive samples, and only those contribute to
    the loss. The positives' CLIP similarities are standardised and then rescaled to
    the mean/std of the pairwise cosine similarities among the positive target
    features, which puts the two similarity scales on a comparable footing.

    target_simi:     [N] cosine similarities between the CAV and the samples.
    clip_simi:       [N] CLIP similarities for the same samples.
    target_features: [M, dim] target-model features; rows [:N // 2] are the positives.
    temperature:     multiplicative scale applied to the returned loss.

    Returns: scalar tensor.

    NOTE(review): the positive/negative split is purely positional. It matches the
    data only when rows are ordered positives-then-negatives and one call sees the
    whole concept set (e.g. batch_size >= len(dataset) with a SequentialSampler).
    Confirm this for other batch sizes.
    '''
    half = target_simi.shape[0] // 2

    target_posi_simi = target_simi[:half]
    clip_posi_simi = clip_simi[:half]
    target_posi_features = target_features[:half]

    # Reference distribution: pairwise similarities among the positive target
    # features (the diagonal self-similarities are included).
    target_posi_to_posi_simi = cosine_similarity(target_posi_features, target_posi_features)
    ref_mean, ref_std = target_posi_to_posi_simi.mean(), target_posi_to_posi_simi.std()

    clip_mean = clip_posi_simi.mean()
    # A constant CLIP similarity vector has std == 0; clamping avoids 0/0 = NaN
    # and maps every sample onto the reference mean instead.
    clip_std = clip_posi_simi.std().clamp_min(1e-12)
    clip_posi_simi = ref_std * ((clip_posi_simi - clip_mean) / clip_std) + ref_mean

    cav_cost = torch.abs(clip_posi_simi - target_posi_simi).mean()

    return cav_cost * temperature


def set_seed(seed):
    '''
    Seed every RNG this script uses and request deterministic cuDNN kernels.

    torch.manual_seed seeds the CPU generator and every CUDA device (it calls
    torch.cuda.manual_seed_all internally), so no separate CUDA seeding is needed.
    '''
    torch.manual_seed(seed)  # CPU + all CUDA devices
    np.random.seed(seed)  # Numpy module
    random.seed(seed)  # Python random module
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def train_one_epoch(model, criterion, optimizer, dataset, data_loader, device, epoch, args, model_ema=None, scaler=None):
    '''
    Run one optimisation pass of ``model`` over ``data_loader``.

    model:       module mapping a feature batch to per-sample similarities; it must
                 expose a ``cav`` parameter (its sum is printed at the end).
    criterion:   unused; kept for signature compatibility.
    optimizer:   optimiser over ``model``'s parameters.
    dataset:     dataset whose ``all_features`` are passed to ``get_cav_loss``.
    data_loader: yields ``(feature, clip_simi)`` batches.
    device:      device the batches are moved to.
    epoch:       current epoch index (only used to reset the EMA during warmup).
    args:        namespace providing clip_grad_norm, model_ema_steps, lr_warmup_epochs.
    model_ema:   optional EMA wrapper exposing ``update_parameters`` and ``n_averaged``.
    scaler:      optional GradScaler; when given, autocast and scaled backward are used.
    '''
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    loss = None
    for i, (feature, clip_simi) in enumerate(data_loader):
        start_time = time.time()
        feature, clip_simi = feature.to(device), clip_simi.to(device)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(feature)
            # NOTE(review): get_cav_loss splits its inputs positionally at len(output) // 2,
            # but the features passed here are the whole dataset, not this batch. The rows
            # are only guaranteed to line up (and the first half of the batch to be the
            # positives) when one in-order batch covers the whole dataset. Confirm batch_size.
            loss = get_cav_loss(output, clip_simi, dataset.all_features)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            if args.clip_grad_norm is not None:
                # Unscale first so that the clipping threshold applies to the true gradients.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if args.clip_grad_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()

        if model_ema is not None and i % args.model_ema_steps == 0:
            model_ema.update_parameters(model)
            if epoch < args.lr_warmup_epochs:
                # Reset ema buffer to keep copying weights during warmup period
                model_ema.n_averaged.fill_(0)

        batch_size = feature.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))

    if loss is None:
        # Empty loader: no step was taken, so there is no loss to report.
        return
    print("Loss: %.4f, Sum: %.4f" % (loss.item(), model.cav.sum().item()))


def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    '''
    Evaluate a classifier's loss and top-1 accuracy over ``data_loader``.

    model:       classifier; if it returns a tuple, the first element is used as logits.
    criterion:   loss function called as ``criterion(output, target)``.
    data_loader: yields ``(image, target)`` batches.
    device:      device the batches are moved to.
    print_freq, log_suffix: currently unused; kept for signature compatibility.

    Returns: global average of the ``acc1`` meter tracked by ``utils.MetricLogger``.
    '''
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")

    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in data_loader:
            image = image.to(device, non_blocking=True)
            # .long() casts on the tensor's current device; .type(torch.LongTensor)
            # would first force a copy to CPU.
            target = target.long().to(device, non_blocking=True)
            output = model(image)
            if isinstance(output, tuple):
                output = output[0]
            loss = criterion(output, target)

            acc1 = utils.accuracy(output, target, topk=(1,))[0]
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            num_processed_samples += batch_size

    # Gather the processed-sample count from all processes.
    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        # get_rank() raises when no process group exists (plain single-process run,
        # e.g. with drop_last=True); treat that case as rank 0.
        and (
            not torch.distributed.is_available()
            or not torch.distributed.is_initialized()
            or torch.distributed.get_rank() == 0
        )
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    metric_logger.synchronize_between_processes()

    return metric_logger.acc1.global_avg


def _build_parameter_groups(model, args):
    '''Build the optimiser parameter list via utils.set_weight_decay, applying the weight-decay overrides from ``args``.'''
    custom_keys_weight_decay = []
    if args.bias_weight_decay is not None:
        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
    if args.transformer_embedding_decay is not None:
        for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
    return utils.set_weight_decay(
        model,
        args.weight_decay,
        norm_weight_decay=args.norm_weight_decay,
        custom_keys_weight_decay=custom_keys_weight_decay or None,
    )


def _build_optimizer(parameters, args):
    '''Create the optimiser named by ``args.opt`` (any "sgd*" name, with Nesterov if it contains "nesterov", or "adamw").'''
    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        return torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    if opt_name == "adamw":
        return torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD and AdamW are supported.")


def _build_lr_scheduler(optimizer, args):
    '''Create the LR scheduler named by ``args.lr_scheduler`` (expected to be lower-cased already).'''
    if args.lr_scheduler == "steplr":
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    raise RuntimeError(f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR is supported.")


def main(args):
    '''
    Fit one CAV per class on the target model's training features and save the results.

    For each class index c, the 500 images with the highest and the 500 with the
    lowest CLIP similarity (column c) are selected, and a CavModel is trained on
    their target-model features. The learned CAVs are then stacked and saved to
    <output_dir>/<model>/target_cavs.pth. Their cosine similarities to all training
    features are saved under saved_activations/<model>/.
    '''
    set_seed(args.seed)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        # NOTE(review): this re-enables the cuDNN autotuning that set_seed() just
        # disabled, so runs are not bit-reproducible without --use-deterministic-algorithms.
        torch.backends.cudnn.benchmark = True

    log_dir = os.path.join(args.output_dir, args.model)
    utils.mkdir(log_dir)
    filename = 'target_cav_img_similarities'
    log_file = os.path.join(log_dir, filename + '.log')
    cav_file = os.path.join(log_dir, 'target_cavs.pth')
    logging.basicConfig(
        filename=log_file,
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    all_features_path = os.path.join('saved_activations', args.model, 'target_model_all_train_features.pth')
    all_target_path = os.path.join('saved_activations', args.model, 'target_model_all_train_target.pth')
    clip_img_similarities_path = os.path.join('saved_contents', args.clip_model, 'clip_model_all_train_img_similarites(cls-relative-text).pth')
    # Load datasets
    all_features = torch.load(all_features_path)
    all_target = torch.load(all_target_path)
    clip_img_similarities = torch.load(clip_img_similarities_path).to(torch.float32)
    # NOTE(review): assumes class ids are 0..K-1 and match the columns of clip_img_similarities.
    num_concepts = torch.unique(all_target).shape[0]

    logging.info("Start running with args: \n{}".format(args))
    # Loop-invariant: normalise the scheduler name once (args is still mutated as before).
    args.lr_scheduler = args.lr_scheduler.lower()

    print("Start training")
    start_time = time.time()
    num_select_samples = 500
    all_cavs = []
    for concept_idx in tqdm(range(num_concepts)):
        concept_similarities = clip_img_similarities[:, concept_idx]

        # Positives: the most CLIP-similar images; negatives: the least similar ones.
        # NOTE(review): with fewer than 2 * num_select_samples images the two sets overlap.
        all_sort_ids = concept_similarities.argsort(descending=True)
        posi_ids, nega_ids = all_sort_ids[:num_select_samples], all_sort_ids[-num_select_samples:]
        # Rows are ordered positives-then-negatives; get_cav_loss relies on that order.
        concept_features = torch.cat((all_features[posi_ids], all_features[nega_ids]), dim=0)
        concept_similarities = torch.cat((concept_similarities[posi_ids], concept_similarities[nega_ids]), dim=0)

        dataset = ConceptFeatureData(concept_features=concept_features, concept_similarities=concept_similarities)
        train_sampler = torch.utils.data.SequentialSampler(dataset)
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=train_sampler,
            num_workers=args.workers,
            pin_memory=True,
            collate_fn=default_collate,
        )

        model = CavModel(feature_dim=dataset.feature_dim)
        model.to(device)

        optimizer = _build_optimizer(_build_parameter_groups(model, args), args)
        scaler = torch.cuda.amp.GradScaler() if args.amp else None
        lr_scheduler = _build_lr_scheduler(optimizer, args)
        model_ema = None

        for epoch in range(args.start_epoch, args.epochs):
            # SequentialSampler has no set_epoch(); only call it on samplers that support it.
            if args.distributed and hasattr(train_sampler, "set_epoch"):
                train_sampler.set_epoch(epoch)
            train_one_epoch(model, None, optimizer, dataset, data_loader, device, epoch, args, model_ema, scaler)
            lr_scheduler.step()

        all_cavs.append(model.cav.detach().cpu())
    all_cavs = torch.cat(all_cavs, dim=0)
    torch.save(all_cavs, cav_file)

    # Save similarities between CAVs and image features from the target model.
    # The features were already loaded from this same path above, so reuse them.
    save_simi_path = os.path.join('saved_activations', args.model, 'target_model_all_train_img_similarites(cls-relative-text)_del.pth')
    features_cavs_simi = cosine_similarity(all_features, all_cavs)
    torch.save(features_cavs_simi, save_simi_path)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    '''
    Build the command-line parser for CAV training.

    Many options (data paths, augmentation, crop sizes, ...) are not read by
    ``main`` in this file; they appear to be kept for command-line compatibility.

    add_help: forwarded to ``argparse.ArgumentParser``.
    Returns: the configured ``argparse.ArgumentParser``.
    '''
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
    add = parser.add_argument

    # --- data & model -------------------------------------------------------
    add("--data-set", type=str, default="ImageNet", help="dataset name")
    add("--data-path", type=str, default="datasets/ILSVRC2012", help="dataset path")
    add("--model", type=str, default="resnet18", help="model name")
    add("--device", type=str, default="cuda", help="device (Use cuda or cpu Default: cuda)")
    add("-b", "--batch-size", type=int, default=32,
        help="images per gpu, the total batch size is $NGPU x batch_size")
    add("--epochs", type=int, default=90, metavar="N", help="number of total epochs to run")
    add("-j", "--workers", type=int, default=16, metavar="N",
        help="number of data loading workers (default: 16)")

    # --- optimiser & regularisation -----------------------------------------
    add("--opt", type=str, default="sgd", help="optimizer")
    add("--lr", type=float, default=0.1, help="initial learning rate")
    add("--momentum", type=float, default=0.9, metavar="M", help="momentum")
    add("--wd", "--weight-decay", dest="weight_decay", type=float, default=1e-4, metavar="W",
        help="weight decay (default: 1e-4)")
    add("--norm-weight-decay", type=float, default=None,
        help="weight decay for Normalization layers (default: None, same value as --wd)")
    add("--bias-weight-decay", type=float, default=None,
        help="weight decay for bias parameters of all layers (default: None, same value as --wd)")
    add("--transformer-embedding-decay", type=float, default=None,
        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)")
    add("--label-smoothing", dest="label_smoothing", type=float, default=0.0,
        help="label smoothing (default: 0.0)")
    add("--mixup-alpha", type=float, default=0.0, help="mixup alpha (default: 0.0)")
    add("--cutmix-alpha", type=float, default=0.0, help="cutmix alpha (default: 0.0)")

    # --- learning-rate schedule ---------------------------------------------
    add("--lr-scheduler", type=str, default="steplr", help="the lr scheduler (default: steplr)")
    add("--lr-warmup-epochs", type=int, default=0, help="the number of epochs to warmup (default: 0)")
    add("--lr-warmup-method", type=str, default="constant", help="the warmup method (default: constant)")
    add("--lr-warmup-decay", type=float, default=0.01, help="the decay for lr")
    add("--lr-step-size", type=int, default=30, help="decrease lr every step-size epochs")
    add("--lr-gamma", type=float, default=0.1, help="decrease lr by a factor of lr-gamma")
    add("--lr-min", type=float, default=0.0, help="minimum lr of lr schedule (default: 0.0)")

    # --- logging, checkpoints, run mode -------------------------------------
    add("--print-freq", type=int, default=10, help="print frequency")
    add("--output-dir", type=str, default=".", help="path to save outputs")
    add("--resume", type=str, default="", help="path of checkpoint")
    add("--start-epoch", type=int, default=0, metavar="N", help="start epoch")
    add("--cache-dataset", dest="cache_dataset", action="store_true",
        help="Cache the datasets for quicker initialization. It also serializes the transforms")
    add("--sync-bn", dest="sync_bn", action="store_true", help="Use sync batch norm")
    add("--test-only", dest="test_only", action="store_true", help="Only test the model")

    # --- augmentation -------------------------------------------------------
    add("--auto-augment", type=str, default=None, help="auto augment policy (default: None)")
    add("--ra-magnitude", type=int, default=9, help="magnitude of auto augment policy")
    add("--augmix-severity", type=int, default=3, help="severity of augmix policy")
    add("--random-erase", type=float, default=0.0, help="random erasing probability (default: 0.0)")

    # --- mixed precision ----------------------------------------------------
    add("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    # --- distributed training, EMA, determinism -----------------------------
    add("--world-size", type=int, default=1, help="number of distributed processes")
    add("--dist-url", type=str, default="env://", help="url used to set up distributed training")
    add("--model-ema", action="store_true",
        help="enable tracking Exponential Moving Average of model parameters")
    add("--model-ema-steps", type=int, default=32,
        help="the number of iterations that controls how often to update the EMA model (default: 32)")
    add("--model-ema-decay", type=float, default=0.99998,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)")
    add("--use-deterministic-algorithms", action="store_true",
        help="Forces the use of deterministic algorithms only.")

    # --- preprocessing, clipping, sampling, weights -------------------------
    add("--interpolation", type=str, default="bilinear",
        help="the interpolation method (default: bilinear)")
    add("--val-resize-size", type=int, default=256,
        help="the resize size used for validation (default: 256)")
    add("--val-crop-size", type=int, default=224,
        help="the central crop size used for validation (default: 224)")
    add("--train-crop-size", type=int, default=224,
        help="the random crop size used for training (default: 224)")
    add("--clip-grad-norm", type=float, default=None, help="the maximum gradient norm (default None)")
    add("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
    add("--ra-reps", type=int, default=3,
        help="number of repetitions for Repeated Augmentation (default: 3)")
    add("--weights", type=str, default=None, help="the weights enum name to load")
    add("--backend", type=str.lower, default="PIL", help="PIL or tensor - case insensitive")
    add("--use-v2", action="store_true", help="Use V2 transforms")

    # --- CAV / CLIP specific options ----------------------------------------
    add("--clip-model", type=str, default="ViT-L-14", help="CLIP model name")
    add("--seed", type=int, default=1028, help="the random seed")
    add("--use-clip", action="store_true", help="Use CLIP cavs")
    add("--probe-ratio", type=float, default=0.25, help="the ratio of probe imgs")
    add("--train-concept-relation", action="store_true", help="Use ensemble concepts")

    return parser


if __name__ == "__main__":
    # Script entry point: parse the command line and run the CAV training.
    main(get_args_parser().parse_args())