import os
import sys
import datetime
import time
import math
import json
from pathlib import Path

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.utils.data import Subset
from torchvision import datasets, transforms
from torchvision import models as torchvision_models
from tqdm import tqdm

from dreamsim.dataset_nights.dataset import TwoAFCDataset
from dreamsim.util.utils import get_preprocess
from objectives import objective_utils
import models.vision_transformer as vits
from models.vision_transformer import DINOHead
from objectives.objective_utils import HingeLoss
from scripts import pidfile

# Sorted names of every lowercase, non-dunder, callable attribute exposed by
# torchvision.models.
torchvision_archs = sorted(
    arch_name
    for arch_name, candidate in torchvision_models.__dict__.items()
    if arch_name.islower()
    and not arch_name.startswith("__")
    and callable(candidate)
)

def train_dreamsim(args, writer):
    """Train a backbone on 2AFC triplets with a hinge loss.

    Builds a random subset of the ``TwoAFCDataset`` train split, wraps the
    selected backbone in ``DistributedDataParallel`` and optimises it with
    Adam (lr=3e-4) for ``args.epochs`` epochs, calling ``train_one_epoch``
    once per epoch.

    Args:
        args: Namespace-like object. Fields read here: ``seed``, ``samples``,
            ``data_name``, ``arch``, ``output_dir``, ``data_path``,
            ``batch_size_per_gpu``, ``num_workers``, ``patch_size``,
            ``drop_path_rate``, ``gpu``, ``use_fp16`` and ``epochs``.
            ``args.arch`` is rewritten in place ("deit" -> "vit").
        writer: Object exposing ``add_scalar(tag, value, step)`` and
            ``flush()`` (presumably a TensorBoard ``SummaryWriter`` --
            confirm against the caller).

    Raises:
        ValueError: If ``args.arch`` matches no known architecture.
        Exception: Whatever ``init_distributed_mode`` raised, when no
            process group is available afterwards.

    Side effects:
        Creates ``<output_dir>/dreamsim-n<samples>-<data_name>-<arch>/`` and,
        on the main process only, overwrites ``checkpoint.pth`` and appends a
        JSON line to ``log.txt`` after every epoch.
    """
    try:
        objective_utils.init_distributed_mode(args)
    except Exception as err:
        # A failed (re-)initialisation is tolerated only when a process group
        # already exists. Without one, the DistributedSampler and DDP wrapper
        # below cannot work, so surface the original error instead of hiding it.
        # Catching Exception (not a bare except) lets SystemExit and
        # KeyboardInterrupt propagate.
        if not (dist.is_available() and dist.is_initialized()):
            raise
        print(f"init_distributed_mode failed ({err!r}); reusing the existing process group.")
    objective_utils.fix_random_seeds(args.seed)
    print("git:\n  {}\n".format(objective_utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True
    # The experiment name uses args.arch as given, before the deit->vit rename.
    exp_name = f'dreamsim-n{args.samples}-{args.data_name}-{args.arch}'
    log_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(log_dir, exist_ok=True)
    # pidfile.exit_if_job_done(log_dir)

    # ============ preparing data ... ============
    # NOTE(review): preprocessing is fixed to the 'resnet18' preset regardless
    # of args.arch -- confirm this is intended.
    transform = get_preprocess('resnet18')
    dataset = TwoAFCDataset(root_dir=args.data_path, split="train", preprocess=transform, include_prompt=False)
    # Keep a random subset of at most args.samples triplets.
    random_idcs = torch.randperm(len(dataset))[:args.samples]
    dataset = Subset(dataset, random_idcs)
    sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    print(f"Data loaded: there are {len(dataset)} images.")

    # ============ building the student network ... ============
    # we changed the name DeiT-S for ViT-S to avoid confusions
    args.arch = args.arch.replace("deit", "vit")
    if args.arch in vits.__dict__:
        # Vision Transformer (i.e. vit_tiny, vit_small, vit_base)
        student = vits.__dict__[args.arch](
            patch_size=args.patch_size,
            drop_path_rate=args.drop_path_rate,  # stochastic depth
        )
    elif args.arch in torch.hub.list("facebookresearch/xcit:main"):
        # XCiT, fetched through torch.hub
        student = torch.hub.load('facebookresearch/xcit:main', args.arch,
                                 pretrained=False, drop_path_rate=args.drop_path_rate)
    elif args.arch in torchvision_models.__dict__:
        # torchvision model: start from pretrained weights and replace the
        # final ``fc`` layer with an identity so the model outputs features.
        student = torchvision_models.__dict__[args.arch](pretrained=True)
        student.fc = nn.Identity()
    else:
        # Fail here with a clear message rather than with a NameError on
        # ``student`` a few lines further down.
        raise ValueError(f"Unknown architecture: {args.arch}")

    # move the network to gpu
    student.train()
    student = student.cuda()
    # synchronize batch norms (if any)
    if objective_utils.has_batchnorms(student):
        print('Synchronizing batch norms')
        student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
    student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu], broadcast_buffers=False)
    print(f"Student is built: they are both {args.arch} network.")

    # ============ preparing loss ... ============
    dreamsim_loss = HingeLoss('cuda', 0.05).cuda()

    # ============ preparing optimizer ... ============
    optimizer = torch.optim.Adam(student.parameters(), lr=3e-4)
    # gradient scaler, only for mixed precision training
    fp16_scaler = torch.cuda.amp.GradScaler() if args.use_fp16 else None

    print(f"Loss, optimizer and schedulers ready.")

    # Training always starts from scratch; no checkpoint is restored here.
    start_epoch = 0

    start_time = time.time()
    print("Starting DINO training !")
    for epoch in range(start_epoch, args.epochs):
        # Reseed the sampler's shuffling for this epoch.
        data_loader.sampler.set_epoch(epoch)

        # ============ training one epoch ... ============
        train_stats = train_one_epoch(student, dreamsim_loss,
            data_loader, optimizer,
            epoch, fp16_scaler, args)

        # ============ writing logs ... ============
        for k, v in train_stats.items():
            writer.add_scalar(k, v, epoch)
        writer.flush()

        # Every rank shares the same log_dir, so only the main process writes
        # the checkpoint and the text log; concurrent writes to the same
        # checkpoint file from several ranks could corrupt it.
        if objective_utils.is_main_process():
            save_dict = {
                'student': student.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch + 1,
                'args': args,
                'dreamsim_loss': dreamsim_loss.state_dict(),
            }
            if fp16_scaler is not None:
                save_dict['fp16_scaler'] = fp16_scaler.state_dict()
            torch.save(save_dict, os.path.join(log_dir, 'checkpoint.pth'))

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch}
            with (Path(log_dir) / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    # pidfile.mark_job_done(log_dir)


def _print_on_all_ranks(message):
    """Print ``message``, requesting output on every rank when supported."""
    # NOTE(review): the original call passed ``force=True``, which the builtin
    # ``print`` rejects with TypeError. Presumably the distributed setup swaps
    # in a rank-aware ``print`` that accepts it -- confirm in objective_utils.
    # Fall back to a plain print so the message is never lost.
    try:
        print(message, force=True)
    except TypeError:
        print(message)


def train_one_epoch(student, dreamsim_loss,
            data_loader, optimizer,
            epoch, fp16_scaler, args):
    """Run one optimisation pass over ``data_loader``.

    Each batch is unpacked as ``(img_ref, img_left, img_right, p, id)``. The
    three images are embedded by ``student``, the cosine distances
    ref-left and ref-right are computed, and their difference is passed to
    ``dreamsim_loss`` together with ``p``.

    Args:
        student: Network being trained; called once per image tensor.
        dreamsim_loss: Callable taking ``(logit, p)`` and returning a loss
            tensor, which is then divided by the batch size.
        data_loader: Iterable with ``len()`` yielding the 5-tuples above.
        optimizer: Optimizer holding ``student``'s parameters.
        epoch: Current epoch index, forwarded to
            ``cancel_gradients_last_layer``.
        fp16_scaler: Gradient scaler, or ``None`` to train without autocast.
        args: Namespace; reads ``clip_grad`` and ``freeze_last_layer``.

    Returns:
        dict mapping each meter name (``loss``, ``lr``, ``wd``) to its
        ``global_avg`` after synchronisation between processes.

    Exits the process with status 1 if the loss becomes non-finite.
    """
    metric_logger = objective_utils.MetricLogger(delimiter="  ")
    use_amp = fp16_scaler is not None
    for batch in tqdm(data_loader, total=len(data_loader)):
        # The fifth element (sample id) is not used during training.
        img_ref, img_left, img_right, p, _sample_id = batch

        # move images and targets to gpu
        img_ref = img_ref.cuda(non_blocking=True)
        img_left = img_left.cuda(non_blocking=True)
        img_right = img_right.cuda(non_blocking=True)
        p = p.cuda(non_blocking=True)

        # student forward passes + loss (autocast only when a scaler is given)
        with torch.cuda.amp.autocast(use_amp):
            embed_ref = student(img_ref)
            embed_left = student(img_left)
            embed_right = student(img_right)

            dist_left = 1 - F.cosine_similarity(embed_ref, embed_left, dim=-1)
            dist_right = 1 - F.cosine_similarity(embed_ref, embed_right, dim=-1)

            logit = dist_left - dist_right
            loss = dreamsim_loss(logit.squeeze(), p)
            loss /= p.shape[0]

        # Read the scalar once: every .item() call forces a device sync.
        loss_value = loss.item()
        if not math.isfinite(loss_value):
            _print_on_all_ranks("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # student update
        optimizer.zero_grad()
        if use_amp:
            fp16_scaler.scale(loss).backward()
        else:
            loss.backward()
        if args.clip_grad:
            if use_amp:
                # unscale the gradients of optimizer's assigned params in-place
                # so that clipping operates on the true gradient magnitudes
                fp16_scaler.unscale_(optimizer)
            objective_utils.clip_gradients(student, args.clip_grad)
        objective_utils.cancel_gradients_last_layer(epoch, student,
                                          args.freeze_last_layer)
        if use_amp:
            fp16_scaler.step(optimizer)
            fp16_scaler.update()
        else:
            optimizer.step()

        # logging
        torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
