import os
import json
import argparse
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
from datetime import datetime

import sys

from libKMCUDA import kmeans_cuda
import faiss
import faiss.contrib.torch_utils
import kmc2

from ncp_acl.builder import get_base_model, load_pretrained_model
from ncp_acl.loader import TwoCropsTransform, get_transforms
from utils import AverageMeter

def _str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to a bool.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is ``True`` because
    every non-empty string is truthy, so ``--use_projector False`` used to
    *enable* the projector. This parser accepts the usual spellings and rejects
    anything else with a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# Data / augmentation
parser.add_argument('--dataset', type=str)
parser.add_argument('--data_dir', type=str)
# The default is fractional (0.5), so the option must be parsed as a float;
# ``type=int`` made e.g. ``--strength 0.3`` a parse error.
parser.add_argument('--strength', type=float, default=0.5)
parser.add_argument('--use_weak_transforms', type=_str2bool, default=True)

# Optimisation
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--epochs', type=int, default=100)
# Base learning rate for a batch of 256; main() rescales it linearly with --batch_size.
parser.add_argument('--lr', type=float, default=0.5)
parser.add_argument('--optimizer', type=str, default='sgd')
parser.add_argument('--weight_decay', type=float, default=5e-4)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--scheduler', type=str, default='warmup_cosine')
parser.add_argument("--min_lr", default=0.0, type=float)
parser.add_argument("--warmup_epochs", default=10, type=int)

# PGD attack: L-inf radius and per-step size (images live in [0, 1]), number of
# steps, and the weight of the clean/adversarial agreement term in the loss.
parser.add_argument("--epsilon", type=float, default=8 / 255)
parser.add_argument("--step_size", type=float, default=2 / 255)
parser.add_argument("--num_steps", type=int, default=5)
parser.add_argument("--trades_k", type=float, default=6)

# Model
parser.add_argument('--backbone', type=str, default='resnet18')
parser.add_argument('--hidden_dim', type=int, default=128)
parser.add_argument('--zero_init_residual', type=_str2bool, default=True)
parser.add_argument('--use_projector', type=_str2bool, default=False)
parser.add_argument('--student_scratch', type=_str2bool, default=False)

parser.add_argument('--teacher_method', type=str, default='SimCLR')
parser.add_argument('--device', type=str)
parser.add_argument('--print_freq', type=int, default=10)

# k-means clustering of the student's representations. --reload_freq N
# re-clusters every N epochs (-1: only before the first epoch); --spherical
# switches the clustering and the losses to cosine geometry.
parser.add_argument('--reload_freq', type=int, default=1)
parser.add_argument('--spherical', action='store_true')
parser.add_argument('--tolerance', type=float, default=1e-6)
parser.add_argument('--yinyang_t', type=float, default=0.)
parser.add_argument('--chain_length', type=int, default=500)
parser.add_argument('--n_centroid', type=int, default=10)

# These switches default to on, so ``store_true`` alone could never disable
# them; the ``--no_*`` counterparts (same dest) turn them off.
parser.add_argument('--multi_bn', action='store_true', default=True)
parser.add_argument('--no_multi_bn', dest='multi_bn', action='store_false')
parser.add_argument('--use_normalize', action='store_true', default=True)
parser.add_argument('--no_use_normalize', dest='use_normalize', action='store_false')
parser.add_argument('--use_amp', action='store_true', default=True)
parser.add_argument('--no_use_amp', dest='use_amp', action='store_false')

parser.add_argument('--seed', type=int, default=0)

# I/O
parser.add_argument('--pre_checkpoint_path', type=str)
parser.add_argument('--save_dir', type=str)

# The attack pulls towards the pert_centroid-th nearest centroid; with
# --use_away_target it also pushes away from the nearest centroid.
parser.add_argument('--pert_centroid', type=int, default=2)
parser.add_argument('--use_away_target', action='store_true')


def fix_seed(SEED):
    """Seed every random number generator the script relies on.

    Args:
        SEED: integer seed applied to Python's ``random``, NumPy's global RNG,
            the PyTorch CPU generator and the generators of all CUDA devices.

    cuDNN is additionally switched to deterministic, non-autotuned kernels so
    that repeated runs with the same seed produce the same results.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(SEED)

    # Reproducibility over speed: no autotuning, only deterministic algorithms.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic, cudnn_backend.benchmark = True, False

def main():
    """Script entry point: train a student encoder against a frozen teacher.

    Steps:
      1. Parse CLI arguments and validate them.
      2. Build the frozen teacher and the student. Both are loaded from
         ``--pre_checkpoint_path`` unless ``--student_scratch`` is set.
      3. Build the data loaders, the optimiser and the warmup-cosine schedule.
      4. For each epoch, (re)cluster the student's representations with
         k-means into a faiss index. Seeding uses AFK-MC2 via ``kmc2`` and the
         refinement uses libKMCUDA. Then run one ``train`` epoch.

    The run configuration is written to ``--save_dir`` before training starts,
    and a single checkpoint is saved after the final epoch.

    Raises:
        FileNotFoundError: ``--pre_checkpoint_path`` is missing or not a file.
        FileExistsError: ``--save_dir`` already exists.
        ValueError: one of the following:
            - unsupported dataset, optimizer or scheduler;
            - invalid ``--epochs``, ``--reload_freq``, ``--pert_centroid``
              or ``--save_dir``.
    """
    args = parser.parse_args()

    # ``--device cuda:N`` masks the process to GPU N; after masking, that GPU
    # is addressed simply as 'cuda'. Without an index the visible-device mask
    # is left untouched; previously a bare 'cuda' or a missing --device crashed.
    if args.device and ':' in args.device:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.device.split(':')[1]
    args.device = 'cuda'

    args.teacher_ckpt_path = args.pre_checkpoint_path

    # Validate up front with explicit exceptions: ``assert`` is stripped under -O,
    # and failing here avoids building models for a run that cannot finish.
    if not args.teacher_ckpt_path or not os.path.isfile(args.teacher_ckpt_path):
        raise FileNotFoundError('You should input valid checkpoint path!!!')
    if args.dataset not in ['cifar10', 'cifar100', 'stl10']:
        raise ValueError("Only CIFAR10, CIFAR100 and STL10 are supported")
    if args.epochs <= 0:
        # The final checkpoint is named after the last epoch, which must exist.
        raise ValueError('--epochs must be positive, got {}'.format(args.epochs))
    if args.reload_freq == 0:
        # ``epoch % 0`` raises ZeroDivisionError; -1 already means "cluster once".
        raise ValueError('--reload_freq must be non-zero (use -1 to cluster only once)')
    if not 1 <= args.pert_centroid <= args.n_centroid:
        # train() asks the index for pert_centroid neighbours among n_centroid
        # centroids and uses column pert_centroid-1 of the result.
        raise ValueError('--pert_centroid must be in [1, {}], got {}'.format(args.n_centroid, args.pert_centroid))
    if not args.save_dir:
        raise ValueError('--save_dir is required')

    fix_seed(args.seed)

    # The teacher never uses multi-BN. The student may, in which case it
    # skips zero-init of residual branches.
    teacher_model = get_base_model(dataset=args.dataset,
                                   backbone=args.backbone,
                                   hidden_dim=args.hidden_dim,
                                   use_normalize=args.use_normalize,
                                   zero_init_residual=args.zero_init_residual,
                                   use_projector=args.use_projector,
                                   multi_bn=False)

    student_model = get_base_model(dataset=args.dataset,
                                   backbone=args.backbone,
                                   hidden_dim=args.hidden_dim,
                                   use_normalize=args.use_normalize,
                                   zero_init_residual=False if args.multi_bn else args.zero_init_residual,
                                   use_projector=args.use_projector,
                                   multi_bn=args.multi_bn)

    teacher_model = load_pretrained_model(teacher_model,
                                          args.teacher_ckpt_path,
                                          args.use_projector,
                                          multi_bn=False)
    # CPU snapshot of the teacher, compared against after the first epoch.
    initial_teacher_model_state_dict = {k: v.cpu() for k, v in teacher_model.eval().state_dict().items()}

    if not args.student_scratch:
        student_model = load_pretrained_model(student_model,
                                              args.teacher_ckpt_path,
                                              args.use_projector,
                                              multi_bn=args.multi_bn)

    train_transforms = TwoCropsTransform(get_transforms(args))
    # Clustering sees un-augmented images (ToTensor only).
    cluster_transforms = transforms.Compose([transforms.ToTensor(), ])

    if args.dataset in ['cifar10', 'cifar100']:
        train_dataset = vars(torchvision.datasets)[args.dataset.upper()](args.data_dir,
                                                                         train=True,
                                                                         download=True,
                                                                         transform=train_transforms)
        cluster_dataset = vars(torchvision.datasets)[args.dataset.upper()](args.data_dir,
                                                                         train=True,
                                                                         download=True,
                                                                         transform=cluster_transforms)
    else:
        train_dataset = vars(torchvision.datasets)[args.dataset.upper()](args.data_dir,
                                                                         split='train+unlabeled',
                                                                         download=True,
                                                                         transform=train_transforms)
        cluster_dataset = vars(torchvision.datasets)[args.dataset.upper()](args.data_dir,
                                                                         split='train+unlabeled',
                                                                         download=True,
                                                                         transform=cluster_transforms)
    train_sampler = None
    cluster_sampler = None

    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=False if train_sampler else True,
                                                   sampler=train_sampler,
                                                   num_workers=args.num_workers,
                                                   pin_memory=True,
                                                   drop_last=False)

    cluster_dataloader = torch.utils.data.DataLoader(cluster_dataset,
                                                     batch_size=512,
                                                     shuffle=False if cluster_sampler else True,
                                                     sampler=cluster_sampler,
                                                     num_workers=4,
                                                     pin_memory=True,
                                                     drop_last=False)

    # Cosine similarity is maximised (train() negates it); MSE is minimised.
    if args.spherical:
        criterion = nn.CosineSimilarity(dim=-1).cuda(args.device)
    else:
        criterion = nn.MSELoss().cuda(args.device)

    # Linear LR scaling relative to a reference batch size of 256.
    args.lr = args.lr * args.batch_size / 256

    extra_optimizer_args = {}
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD
        extra_optimizer_args['momentum'] = args.momentum
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam
    elif args.optimizer == 'adamw':
        optimizer = torch.optim.AdamW
    else:
        raise ValueError("Only SGD, Adam, and AdamW are supported")

    optimizer = optimizer(student_model.parameters(),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **extra_optimizer_args)

    if args.scheduler == 'warmup_cosine':
        # LambdaLR multiplies the base LR by the lambda's value. The schedule
        # runs from 1 down to min_lr / lr and is stepped once per epoch
        # (at the end of train()).
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lr_lambda=lambda step: cosine_annealing(step,
                                                                                              args.epochs,
                                                                                              1,
                                                                                              args.min_lr / args.lr,
                                                                                              warmup_steps=args.warmup_epochs)
                                                      )
    else:
        raise ValueError("Only Cosine Annealing with Warmup is supported")

    model_save_path = args.save_dir

    # exist_ok=False: refuse to overwrite the output of a previous run.
    os.makedirs(model_save_path, exist_ok=False)

    with open(os.path.join(model_save_path, 'train_kmacl_args_{}.txt'.format(datetime.now().strftime("%Y_%m_%d"))), 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    start_epoch = 0

    if args.use_amp:
        scaler = torch.cuda.amp.GradScaler()
    else:
        scaler = None

    avg_loss_list = []
    for epoch in range(start_epoch, args.epochs):
        # Re-cluster every reload_freq epochs; with reload_freq == -1 only at epoch 0.
        if epoch % args.reload_freq == 0 and (args.reload_freq != -1 or epoch == 0):
            cluster_representations = generate_representations(pretrained_model=student_model,
                                                               multi_bn=args.multi_bn,
                                                               device=args.device,
                                                               cluster_dataloader=cluster_dataloader,
                                                               spherical=args.spherical)
            # AFK-MC2 seeding for the GPU k-means below.
            c_init = kmc2.kmc2(cluster_representations.numpy(),
                               k=args.n_centroid,
                               chain_length=args.chain_length,
                               afkmc2=True,
                               random_state=args.seed,
                               weights=None).astype(np.float32)

            centroids, _ = kmeans_cuda(samples=cluster_representations.numpy(),
                                       clusters=args.n_centroid,
                                       tolerance=args.tolerance,
                                       init=c_init,
                                       yinyang_t=args.yinyang_t,
                                       metric='cos' if args.spherical else 'L2',
                                       average_distance=False,
                                       seed=args.seed,
                                       device=0,
                                       verbosity=1)
            # Exact nearest-centroid index: inner product for cosine, else L2.
            if args.spherical:
                kmeans_model = faiss.IndexFlatIP(centroids.shape[-1])
            else:
                kmeans_model = faiss.IndexFlatL2(centroids.shape[-1])

            kmeans_model.add(centroids)

        avg_loss = train(train_dataloader, teacher_model, student_model, kmeans_model, centroids, criterion, optimizer, scheduler, epoch, args, scaler)

        avg_loss_list.append(avg_loss)
        if epoch == 0:
            sanity_check(teacher_model.state_dict(), initial_teacher_model_state_dict)

    # ``epoch`` is always bound here because --epochs > 0 was validated above.
    save_checkpoint({
        'epoch': epoch,
        'arch': args.backbone,
        'state_dict': student_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'avg_loss_history': avg_loss_list
    }, filename=os.path.join(model_save_path, 'checkpoint_{:04d}.pth.tar'.format(epoch)))

def generate_representations(pretrained_model, multi_bn, device, cluster_dataloader, spherical):
    """Embed every sample yielded by ``cluster_dataloader`` with ``pretrained_model``.

    The model is moved to ``device`` and left in eval mode. Gradients are disabled.

    Args:
        pretrained_model: encoder to evaluate. When ``multi_bn`` is set it is
            called with ``bn_name='normal'``.
        multi_bn: whether the model's forward takes a ``bn_name`` argument.
        device: device the input batches are moved to.
        cluster_dataloader: iterable of ``(images, label)`` pairs; labels are ignored.
        spherical: if true, each output row is L2-normalised.

    Returns:
        A CPU tensor: the per-batch outputs concatenated along the first dimension,
        in loader order.
    """
    pretrained_model.to(device)
    pretrained_model.eval()

    # Only multi-BN encoders accept the bn_name keyword.
    forward_kwargs = {'bn_name': 'normal'} if multi_bn else {}

    with torch.no_grad():
        per_batch = [pretrained_model(images.to(device), **forward_kwargs).cpu()
                     for images, _ in cluster_dataloader]

    stacked = torch.cat(per_batch)
    return F.normalize(stacked, p=2, dim=-1) if spherical else stacked

def train(train_dataloader, teacher_model, student_model, kmeans_model, centroids, criterion, optimizer, scheduler, epoch, args, scaler):
    """Run one training epoch of the student, then step the LR scheduler once.

    Per batch:
      1. Look up the nearest centroids of the student's representation in
         ``kmeans_model``.
      2. Craft PGD examples towards the ``args.pert_centroid``-th nearest
         centroid (optionally also away from the nearest one).
      3. Update the student on a teacher-alignment term plus ``args.trades_k``
         times a clean/adversarial agreement term.

    Only the second element of each ``(views, label)`` item is used as input.
    This assumes the dataset yields a pair of augmented views; TODO confirm
    against ``TwoCropsTransform``.

    Args:
        train_dataloader: loader over the augmented training set.
        teacher_model: frozen encoder providing alignment targets.
        student_model: encoder being optimised.
        kmeans_model: faiss index over ``centroids``, queried with float32
            NumPy arrays.
        centroids: NumPy array of cluster centres, row-aligned with the ids
            returned by ``kmeans_model.search``.
        criterion: cosine similarity (``args.spherical``) or an MSE loss module.
        optimizer: optimiser over the student's parameters.
        scheduler: LR scheduler, stepped once per epoch.
        epoch: current epoch index, used only for logging.
        args: parsed CLI namespace.
        scaler: ``torch.cuda.amp.GradScaler``, or ``None`` when AMP is disabled.

    Returns:
        ``losses.avg`` of an ``AverageMeter`` updated with each batch's loss,
        weighted by the batch size.
    """
    losses = AverageMeter()
    losses.reset()

    teacher_model.to(args.device)
    student_model.to(args.device)

    # Wrap the loader itself (not the enumerate object) so tqdm knows the epoch length.
    for batch_idx, ((_, data), _) in enumerate(tqdm(train_dataloader)):
        # generate_training_AE switches the student's mode, so reset it every batch.
        teacher_model.eval()
        student_model.train()

        data = data.to(args.device)

        with torch.cuda.amp.autocast(enabled=args.use_amp):
            # The teacher only supplies targets: don't build an autograd graph for it.
            with torch.no_grad():
                teacher_targets = teacher_model(data)

            if args.multi_bn:
                student_outputs = student_model(data, bn_name='normal')
            else:
                student_outputs = student_model(data)

            # The index is queried with detached float32 NumPy copies of the outputs.
            if args.spherical:
                representations = F.normalize(student_outputs.detach().clone().float(), p=2, dim=-1).cpu().numpy()
            else:
                representations = student_outputs.detach().clone().float().cpu().numpy()
            _, I = kmeans_model.search(representations, args.pert_centroid)

            # Column 0 is the nearest centroid (pushed away from with
            # --use_away_target). Column pert_centroid-1 is the centroid the
            # attack pulls towards.
            first_centroids = I[:, 0]
            away_targets = torch.from_numpy(centroids[first_centroids]).to(args.device)

            target_centroids = I[:, args.pert_centroid - 1]
            targets = torch.from_numpy(centroids[target_centroids]).to(args.device)

            adv_images = generate_training_AE(student_model, data, away_targets.detach(), targets.detach(), args)

            if args.multi_bn:
                adv_outputs = student_model.train()(adv_images, bn_name='pgd')
            else:
                adv_outputs = student_model.train()(adv_images)

            # With cosine similarity both terms are negated so that minimising
            # maximises similarity. With MSE, the sign flip below turns the
            # expression back into a positive sum of distances.
            loss = -criterion(student_outputs, teacher_targets).mean() + \
                args.trades_k * (-criterion(student_outputs, adv_outputs).mean())

            if not args.spherical:
                loss *= -1

        optimizer.zero_grad()

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        losses.update(loss.item(), data.shape[0])

        if batch_idx % args.print_freq == 0:
            print(f'Epoch: [{epoch}][{batch_idx}/{len(train_dataloader)}] \t'
                  f'Average Loss : {losses.avg}')

    scheduler.step()
    return losses.avg


def generate_training_AE(model, data, away_targets, targets, args):
    """Craft L-inf PGD adversarial images that steer ``model``'s representation.

    The attack starts from a uniform random perturbation in
    ``[-args.epsilon, args.epsilon]``. It then takes ``args.num_steps``
    signed-gradient ascent steps of size ``args.step_size`` on this objective:

    - always pull the output towards ``targets``;
    - with ``args.use_away_target``, also push it away from ``away_targets``;
    - with ``args.use_away_target`` and ``args.pert_centroid == 1``, only push away.

    After every step the perturbation is clipped to the epsilon ball and the
    perturbed image to ``[0, 1]``.

    The model runs in eval mode during the attack. It is called with
    ``bn_name='pgd'`` when ``args.multi_bn`` is set. Its previous train/eval
    mode is restored afterwards. Gradients are taken only with respect to the
    perturbation, so the model's parameter ``.grad`` fields are not modified.

    Returns:
        Detached adversarial images with the same shape as ``data``.
    """
    images = data.clone().detach()
    delta = torch.zeros_like(images).uniform_(-args.epsilon, args.epsilon)

    was_training = model.training
    model.eval()

    criterion = F.cosine_similarity if args.spherical else F.mse_loss
    forward_kwargs = {'bn_name': 'pgd'} if args.multi_bn else {}

    for _ in range(args.num_steps):
        delta.requires_grad_(True)
        outputs = model(images + delta, **forward_kwargs)

        if args.use_away_target:
            away_term = -criterion(outputs, away_targets).mean()
            if args.pert_centroid == 1:
                loss = away_term
            else:
                loss = (away_term + criterion(outputs, targets).mean()) * 1 / 2
        else:
            loss = criterion(outputs, targets).mean()

        # MSE is a distance, not a similarity: flip the sign so that ascent
        # still means "closer to targets / farther from away_targets".
        if not args.spherical:
            loss = -loss

        # Scaling only changes the gradient's magnitude, never its sign. It
        # presumably keeps fp16 gradients from underflowing when the caller
        # runs this under autocast.
        if args.use_amp:
            loss = loss * 65536.

        # Differentiate w.r.t. the perturbation only; unlike loss.backward()
        # this neither computes nor accumulates parameter gradients.
        grad, = torch.autograd.grad(loss, delta)

        with torch.no_grad():
            delta = delta + args.step_size * grad.sign()
            delta = torch.clamp(delta, min=-args.epsilon, max=args.epsilon)
            delta = torch.clamp(images + delta, min=0, max=1) - images

    model.train(was_training)
    return (images + delta).detach()


def cosine_annealing(step, total_steps, lr_max, lr_min, warmup_steps=0):
    """Return the value at ``step`` of a linear-warmup + half-cosine schedule.

    For ``step < warmup_steps`` the value rises linearly and reaches ``lr_max``
    at step ``warmup_steps - 1``. Afterwards it follows a half cosine from
    ``lr_max`` down to ``lr_min`` at ``step == total_steps``. In this script it
    is used as a ``LambdaLR`` multiplier (``lr_max=1``) stepped once per epoch.

    Raises:
        ValueError: if ``warmup_steps`` is negative.
    """
    if warmup_steps < 0:
        raise ValueError('warmup_steps must be non-negative, got {}'.format(warmup_steps))

    if step < warmup_steps:
        return lr_max * (step + 1) / warmup_steps

    decay_steps = total_steps - warmup_steps
    if decay_steps <= 0:
        # Warmup consumed the whole schedule (e.g. warmup_epochs == epochs).
        # The cosine formula would divide by zero; the schedule is at its end.
        return lr_min

    progress = (step - warmup_steps) / decay_steps
    return lr_min + (lr_max - lr_min) * 0.5 * (1 + np.cos(progress * np.pi))


def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Serialise a training checkpoint with ``torch.save``.

    Args:
        state: object to persist. In this script it is a dict holding the epoch,
            the backbone name, the student/optimizer state dicts and the loss history.
        filename: destination path.
    """
    torch.save(obj=state, f=filename)


def sanity_check(state_dict, initial_state_dict):
    """Assert that every entry of ``state_dict`` still equals its snapshot.

    Used after the first epoch to confirm the frozen teacher was not modified.

    Args:
        state_dict: current state dict. Tensors may live on any device; they
            are moved to CPU before comparison.
        initial_state_dict: reference snapshot with the same keys. In this
            script ``main`` builds it from CPU copies.

    Raises:
        AssertionError: naming the first entry whose values changed.
    """
    print("=> sanity check for teacher model")

    for name, current in state_dict.items():
        assert ((current.cpu() == initial_state_dict[name]).all()), '{} is changed in training.'.format(name)

    print("=> sanity check passed.")


if __name__ == '__main__':
    # Start training only when executed as a script, not when imported.
    main()