import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.utils.data.distributed
import torch.backends.cudnn as cudnn
import torch.distributed as dist

import models_vit
from Dataset import redshift
from timm.utils import ModelEmaV2
from timm.models.layers import trunc_normal_

from util.misc import colorstr, SaveCheckpoint
from util.lr_sched import adjust_learning_rate, param_groups_lrd
from util.pos_embed import interpolate_pos_embed
from collections import OrderedDict

from pathlib import Path
import numpy as np
import os
import time
import argparse
from torch.utils.tensorboard import SummaryWriter


def setup_for_distributed(is_master):
    """
    Replace the builtin ``print`` with a version that is silent on
    non-master processes.

    After this call, ``print(...)`` only produces output when ``is_master``
    is true or when the caller passes ``force=True``. The ``force`` keyword
    is consumed here and never forwarded to the real ``print``.
    """
    import builtins

    original_print = builtins.print

    def print(*args, **kwargs):
        # Always strip "force" so the real print never sees it.
        forced = kwargs.pop("force", False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = print


def init_distributed_mode(args):
    """
    Fill in the distributed fields of ``args`` and join the process group.

    Two launch styles are recognised from the environment:
      * ``RANK`` + ``WORLD_SIZE`` (``LOCAL_RANK`` is then read as the GPU index);
      * ``SLURM_PROCID`` (the GPU index is the rank modulo the visible GPU
        count; ``args.world_size`` is left at whatever the caller provided).

    If neither is present, ``args.distributed`` is set to False and the
    function returns without touching anything else. Otherwise
    ``args.distributed`` is set to True, the CUDA device is selected, the
    process group is initialised from ``args.dist_backend`` /
    ``args.dist_url``, and printing is muted on every rank except 0.
    """
    env = os.environ
    env_launch = "RANK" in env and "WORLD_SIZE" in env
    slurm_launch = "SLURM_PROCID" in env

    if not (env_launch or slurm_launch):
        print("Not using distributed mode")
        args.distributed = False
        return

    if env_launch:
        args.rank = int(env["RANK"])
        args.world_size = int(env["WORLD_SIZE"])
        args.gpu = int(env["LOCAL_RANK"])
        print("Use GPU: {} for training".format(args.gpu))
    else:
        args.rank = int(env["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()

    args.distributed = True
    torch.cuda.set_device(args.gpu)

    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    # Wait for every rank before silencing output on the non-master ones.
    dist.barrier()
    setup_for_distributed(args.rank == 0)


def create_dataset(args):
    """
    Build the distributed train/validation data loaders.

    Side effect: ``args.batch_size`` is overwritten with the per-process
    batch size (the given batch size divided by ``args.world_size``).

    Returns:
        ``(train_loader, val_loader, output_size, train_sampler)`` where
        ``output_size`` is the third value returned by ``redshift()``.

    Raises:
        ValueError: if ``args.distributed`` is false; only distributed runs
            are supported here.
    """
    train_dataset, val_dataset, output_size = redshift()

    # The batch size on the command line is global; split it across processes.
    args.batch_size = int(args.batch_size / args.world_size)

    print(args)
    summary = "epochs: {}, images per gpu: {}, absolute lr: {}".format(
        args.epochs, args.batch_size, args.lr)
    print(colorstr('green', summary))

    if not args.distributed:
        raise ValueError("Distributed init error.")

    make_sampler = torch.utils.data.distributed.DistributedSampler
    shard = dict(num_replicas=args.world_size, rank=args.rank)
    train_sampler = make_sampler(train_dataset, shuffle=True, **shard)
    val_sampler = make_sampler(val_dataset, shuffle=False, **shard)

    common = dict(batch_size=args.batch_size,
                  num_workers=args.workers,
                  pin_memory=True)
    # Only the training loader drops the trailing partial batch.
    train_loader = DataLoader(train_dataset, sampler=train_sampler, drop_last=True, **common)
    val_loader = DataLoader(val_dataset, sampler=val_sampler, drop_last=False, **common)

    return train_loader, val_loader, output_size, train_sampler


def create_model(args, num_class):
    """
    Build the ViT, optionally load pre-trained weights, and install a
    5-channel patch embedding.

    Args:
        args: Namespace providing ``model_name``, ``input_size``,
            ``drop_path``, ``global_pool`` and ``finetune`` (checkpoint path,
            empty to skip loading).
        num_class: Size of the model's output head.

    Returns:
        The constructed model with ``model.patch_embed`` replaced by a
        freshly initialised 5-input-channel convolutional patch embedding.

    Raises:
        AssertionError: if ``args.model_name`` is not a supported ViT, or if
            the missing keys after loading differ from the expected head /
            ``fc_norm`` set when ``args.global_pool`` is on.
    """
    vit = ['vit_base_patch16', 'vit_large_patch16', 'vit_huge_patch14']
    assert args.model_name in vit

    model = models_vit.__dict__[args.model_name](
        img_size=args.input_size,
        num_classes=num_class,
        drop_path_rate=args.drop_path,
        global_pool=args.global_pool,
    )

    if args.finetune:
        ckpt = torch.load(args.finetune, map_location='cpu')['model']

        # interpolate position embedding
        # NOTE(review): this runs on the raw checkpoint keys. If they carry a
        # 'module.' prefix the helper may not find the position embedding --
        # confirm against util.pos_embed.
        interpolate_pos_embed(model, ckpt)

        prefix = 'module.'
        new_state_dict = OrderedDict()
        for k, v in ckpt.items():
            if 'decoder' in k:
                # Decoder weights are not part of the fine-tuning model.
                continue
            # Strip the DDP prefix only when it is really there; slicing
            # unconditionally would corrupt keys saved without it.
            name = k[len(prefix):] if k.startswith(prefix) else k
            new_state_dict[name] = v
        msg = model.load_state_dict(new_state_dict, strict=False)
        print(msg)

        if args.global_pool:
            assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}

        trunc_normal_(model.head.weight, std=2e-5)
        print(colorstr('green', "Load pre-trained checkpoint from: %s" % args.finetune))

    class PatchEmbed(nn.Module):
        """Non-overlapping convolutional patch embedding for 5-channel inputs."""

        def __init__(self, embed_dim=768, k=16):
            super(PatchEmbed, self).__init__()
            self.proj = nn.Conv2d(in_channels=5, out_channels=embed_dim, kernel_size=k, stride=k)
            self.norm = nn.Identity()

        def forward(self, x):
            # Conv output -> flatten the two spatial dims -> put channels last.
            x = self.proj(x).flatten(2).transpose(1, 2)
            return self.norm(x)

    # Size the new embedding from the model that was actually built instead of
    # hard-coding the ViT-Base values, so the large/huge variants accepted
    # above get a matching width and patch size. The fallbacks keep the
    # previous constants when the attributes are absent.
    embed_dim = getattr(model, 'embed_dim', 768)
    patch_size = getattr(getattr(model, 'patch_embed', None), 'patch_size', 16)

    # NOTE: this replaces the patch embedding after any checkpoint load, so
    # the new projection always starts from random initialisation.
    model.patch_embed = PatchEmbed(embed_dim=embed_dim, k=patch_size)

    return model


def main(args):
    """
    Run the whole fine-tuning job: set-up, optional resume, then one
    train + validate pass per epoch.

    In order: initialise distributed mode, build the data loaders and the
    model, optionally create an EMA copy, wrap the model in DDP, build an
    AdamW optimizer over layer-wise-decayed parameter groups, optionally
    restore state from ``args.resume``, and loop over epochs. Rank 0
    additionally creates ``args.save_dir`` (weights, logs, model.txt),
    saves checkpoints and writes TensorBoard scalars.

    Args:
        args: Parsed command-line namespace. Fields such as ``rank``,
            ``gpu`` and ``distributed`` are added by
            ``init_distributed_mode``; ``batch_size`` is rewritten by
            ``create_dataset``; ``start_epoch`` is rewritten on resume.

    NOTE: ``dist.get_rank()`` is called unconditionally below, so this
    function expects the process group to have been initialised.
    """
    init_distributed_mode(args)

    cudnn.benchmark = True

    device = torch.device(args.device)

    # data loaders
    train_loader, val_loader, num_class, train_sampler = create_dataset(args=args)

    # create model
    model = create_model(args=args, num_class=num_class)
    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
        model_ema = ModelEmaV2(model, decay=args.model_ema_decay)
    
    # Keep a handle on the unwrapped module: it is what gets checkpointed and
    # what the optimizer parameter groups are built from.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params (M): %.4f' % (n_parameters / 1.e6))      

    # build optimizer with layer-wise lr decay (lrd)
    param_groups = param_groups_lrd(model_without_ddp, args.weight_decay,
        no_weight_decay_list=model_without_ddp.no_weight_decay(),
        layer_decay=args.layer_decay,
    )
    optimizer = torch.optim.AdamW(params=param_groups,
                                  lr=args.lr,
                                  betas=(0.9, 0.999))
    # loss
    criterion = nn.MSELoss().to(device)
    # The scaler is always created; train() only uses it when args.use_amp is set.
    scaler = torch.cuda.amp.GradScaler()

    # file path
    # Only rank 0 creates output directories and the TensorBoard writer; the
    # names defined in this branch (last, best, summary_writer) exist on rank 0 only.
    if dist.get_rank() == 0:
        # weights
        save_dir = Path(args.save_dir)
        weights = save_dir / 'weights'
        weights.mkdir(parents=True, exist_ok=True)
        last = weights / 'last'
        best = weights / 'best'

        # tensorboard
        logdir = save_dir / 'logs'
        logdir.mkdir(parents=True, exist_ok=True)
        summary_writer = SummaryWriter(logdir, flush_secs=120)

        # result
        # Append the model structure and the run arguments to model.txt.
        model_file = str(save_dir / 'model.txt')
        with open(model_file, "a") as f:
            print(model_without_ddp, file=f)
            print(args, file=f)
    
    if args.resume:
        # NOTE(review): if args.gpu is set but CUDA is unavailable, neither
        # branch runs and `checkpoint` is left undefined.
        if args.gpu is None:
            checkpoint = torch.load(args.resume)
        elif torch.cuda.is_available():
            # Map model to be loaded to specified single gpu.
            loc = 'cuda:{}'.format(args.gpu)
            checkpoint = torch.load(args.resume, map_location=loc)
           
        # The stored 'epoch' is already "last finished epoch + 1" (see `state` below).
        args.start_epoch = checkpoint['epoch']
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # On resume best_mse becomes a tensor, whereas a fresh run uses a float.
        best_mse = torch.tensor(checkpoint['best_mse'])
        best_bias = checkpoint.get('best_bias', 0.0)
        if args.gpu is not None:
            # best_mse may come from a checkpoint written on a different GPU
            best_mse = best_mse.to(args.gpu)

        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])

        print(colorstr('green', 'Resuming training from {} epoch'.format(args.start_epoch)))
    else:
        best_mse = float('inf')
        best_bias = 0.0
    
    print("Start training")
    for epoch in range(args.start_epoch, args.epochs):
        print("Epoch {}/{}".format(epoch + 1, args.epochs))
        if args.distributed:
            # Pass the epoch to the sampler so its shuffling changes each epoch.
            train_sampler.set_epoch(epoch)
        
        train_epoch_loss, train_bias = train(model=model,
                                             train_loader=train_loader,
                                             optimizer=optimizer,
                                             criterion=criterion,
                                             scaler=scaler,
                                             args=args,
                                             epoch=epoch,
                                             model_ema=model_ema)
        
        # `delta`, `mad`, `eta` are the three values validate() derives from
        # its collected predictions (see validate for their definitions).
        val_epoch_loss, val_bias, delta, mad, eta = validate(model=model,
                                                             val_loader=val_loader,
                                                             criterion=criterion, 
                                                             args=args)
        
        # The learning rate reported is that of the last parameter group.
        s = "Train Loss: {:.8f}, Train Bias: {:.8f}, Test Loss: {:.8f}, Test Bias: {:.8f}, Test delta: {:.8f}, lr: {:.1e}".format(
            train_epoch_loss, train_bias, val_epoch_loss, val_bias, delta, optimizer.param_groups[-1]['lr'])
        print(colorstr('green', s))

        if dist.get_rank() == 0:
            # save model
            # "Best" is decided on validation MSE alone; best_bias just
            # records the validation bias of that best epoch.
            is_best = val_epoch_loss < best_mse
            best_mse = min(best_mse, val_epoch_loss)
            best_bias = val_bias if is_best else best_bias
            state = {
                'epoch': epoch + 1,
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_mse': best_mse,
                'best_bias': best_bias,
            }
            if model_ema:
                state["model_ema"] = model_ema.state_dict()
            
            last_path = last / 'epoch_{}_mse{:.4f}'.format(
                epoch + 1, val_epoch_loss)
            best_path = best / 'epoch_{}_mse_{:.4f}'.format(
                epoch + 1, best_mse)

            SaveCheckpoint(state, last, last_path, best, best_path, is_best, interval=None, interval_path=None)
                       
            summary_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)
            summary_writer.add_scalar('train_loss', train_epoch_loss, epoch)
            summary_writer.add_scalar('train_bias', train_bias, epoch)
            summary_writer.add_scalar('val_loss', val_epoch_loss, epoch)
            summary_writer.add_scalar('val_bias', val_bias, epoch)

            summary_writer.add_scalar('delta', delta, epoch)
            summary_writer.add_scalar('mad', mad, epoch)
            summary_writer.add_scalar('eta', eta, epoch)
   
    if dist.get_rank() == 0:
        summary_writer.close()


def train(model, train_loader, optimizer, criterion, scaler, args, epoch, model_ema):
    """
    Train the model for one epoch.

    Args:
        model: The (possibly DDP-wrapped) model to optimise.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer whose learning rate is adjusted every step.
        criterion: Loss applied to ``(logits, labels)``.
        scaler: Gradient scaler, used only when ``args.use_amp`` is set.
        args: Namespace providing ``gpu``, ``use_amp`` and the fields read by
            ``adjust_learning_rate``.
        epoch: Zero-based index of the current epoch.
        model_ema: Optional EMA wrapper updated after every step, or None.

    Returns:
        ``(mean_loss, mean_bias)`` for the epoch, each weighted by batch size.
    """
    train_mse = AverageMeter()
    train_bias = AverageMeter()

    # Model on train mode
    model.train()
    step_per_epoch = len(train_loader)
    for step, (images, labels) in enumerate(train_loader):
        # Synchronise so the per-step timing below covers finished GPU work.
        torch.cuda.synchronize()
        start = time.time()

        optimizer.zero_grad()
        # The schedule is driven by a fractional epoch, so it changes every step.
        adjust_learning_rate(optimizer, step / step_per_epoch + epoch, args)

        if args.gpu is not None and torch.cuda.is_available():
            images = images.cuda(args.gpu, non_blocking=True)
            labels = labels.cuda(args.gpu, non_blocking=True)

        # compute output
        if args.use_amp:
            with torch.cuda.amp.autocast():
                logits = model(images)
                loss = criterion(logits, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

        # measure bias and record loss
        # compute_bias expects (true_labels, predicted_values); passing them
        # the other way round flips the sign and normalises by the prediction.
        bias = compute_bias(labels, logits)

        train_mse.update(loss.item(), images.size(0))
        train_bias.update(bias.item(), images.size(0))

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        # Single-line progress display, rewritten in place via '\r'.
        t = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
        s1 = '\r{} [{}/{}]'.format(t, step+1, step_per_epoch)
        s2 = ' - {:.2f}ms/step - train_mse: {:.7f} - train_bias: {:.7f}'.format(
             1000 * (time.time()-start), train_mse.val, train_bias.val)
        print(s1+s2, end='', flush=True)

    # Terminate the in-place progress line.
    print()
    return train_mse.avg, train_bias.avg


def validate(model, val_loader, criterion, args):
    """
    Evaluate the model on the validation loader.

    Args:
        model: The (possibly DDP-wrapped) model to evaluate.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss applied to ``(logits, labels)``.
        args: Namespace providing ``gpu`` and ``distributed``.

    Returns:
        ``(mean_loss, mean_bias, bias, sigma_mad, eta)``:
          * ``mean_loss`` / ``mean_bias`` -- running averages of the per-batch
            loss and bias (averaged across ranks when distributed);
          * ``bias`` -- mean of ``(pred - truth) / (1 + truth)`` over all
            samples seen by this process;
          * ``sigma_mad`` -- 1.4826 times the median absolute deviation of
            that quantity;
          * ``eta`` -- percentage of samples whose absolute normalised
            residual exceeds three standard deviations.
    """
    val_mse = AverageMeter()
    val_bias = AverageMeter()

    truth_labels = []
    predictions = []

    def compute_results(truth, pred):
        """Summarise normalised residuals as (bias, sigma_mad, outlier %)."""
        z_phot = np.array(pred)
        z_spec = np.array(truth)

        assert len(z_phot) == len(z_spec)

        delta_z = (z_phot - z_spec) / (1 + z_spec)

        # 1. bias
        bias = np.mean(delta_z)

        # 2. median absolute deviation
        mad = np.median(np.abs(delta_z - np.median(delta_z)))
        sigma_mad = 1.4826 * mad

        # 3. outlier fraction
        outlier_threshold = 3 * np.std(delta_z)
        outliers = np.abs(delta_z) > outlier_threshold
        eta = np.sum(outliers) * 100 / len(z_phot)

        return bias, sigma_mad, eta

    # model to evaluate mode
    model.eval()
    with torch.no_grad():
        for step, (images, labels) in enumerate(val_loader):
            if args.gpu is not None and torch.cuda.is_available():
                images = images.cuda(args.gpu, non_blocking=True)
                labels = labels.cuda(args.gpu, non_blocking=True)

            # compute output
            logits = model(images)
            loss = criterion(logits, labels)

            # These lists are not gathered across ranks: they only hold the
            # samples served to this process.
            truth_labels.extend(labels.cpu().numpy())
            predictions.extend(logits.cpu().numpy())

            # measure bias
            # compute_bias expects (true_labels, predicted_values), matching
            # the (pred - truth) / (1 + truth) formula in compute_results.
            bias = compute_bias(labels, logits)

            if args.distributed:
                loss = reduce_tensor(loss, args)
                bias = reduce_tensor(bias, args)

            val_mse.update(loss.item(), images.size(0))
            val_bias.update(bias.item(), images.size(0))

    delta, mad, eta = compute_results(truth=truth_labels, pred=predictions)

    return val_mse.avg, val_bias.avg, delta, mad, eta


class AverageMeter(object):
    """Keep the most recent value of a metric and its weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, the mean, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def reduce_tensor(tensor, args):
    """
    Return the mean of ``tensor`` across all processes in the group.

    The input is left untouched: a clone is summed over every rank and then
    divided in place by the world size. ``args`` is accepted but not used.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged.div_(dist.get_world_size())
    return averaged


def compute_mse(output, target):
    """
    Mean squared error between two tensors of identical shape.

    Parameters:
    - output: Tensor of predicted values.
    - target: Tensor of ground-truth values, same size as ``output``.

    Returns:
    - mse: Scalar tensor holding the mean of the squared differences,
      computed with gradient tracking disabled.

    Raises:
    - ValueError: if the two tensors differ in size.
    """
    with torch.no_grad():
        # No broadcasting is allowed: shapes have to agree exactly.
        if output.size() != target.size():
            raise ValueError("Output and target sizes must match.")

        return (output - target).pow(2).mean()


def compute_bias(true_labels, predicted_values):
    """
    Mean normalised residual, ``mean((predicted - true) / (1 + true))``.

    The residuals are computed with gradient tracking disabled, so the
    returned scalar tensor carries no gradient.
    """
    with torch.no_grad():
        residual = predicted_values - true_labels
        normalised = residual / (1 + true_labels)
    return normalised.mean()


def testmodel(model, test_data, args):
    # model to evaluate mode
    model.eval()

    test_dataloader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                                 num_workers=args.workers, pin_memory=True)
    truth_labels = []
    predictions = []
    with torch.no_grad():
        for step, (images, labels) in enumerate(test_dataloader):
            images, labels = images.cuda(), labels.cuda()

            logits = model(images)

            truth_labels.extend(labels.cpu().numpy())
            predictions.extend(logits.cpu().numpy())
    
    truth_labels = list(truth_labels)
    predictions = list(predictions)

    z_phot = np.array(predictions)
    z_spec = np.array(truth_labels)

    assert len(z_phot) == len(z_spec)

    delta_z = (z_phot - z_spec) / (1 + z_spec)

    # 1. bias
    bias = np.mean(delta_z)

    # 2. median absolute deviation
    mad = np.median(np.abs(delta_z - np.median(delta_z)))
    sigma_mad = 1.4826 * mad

    # outlier fraction
    outlier_threshold = 3 * np.std(delta_z)
    outliers = np.abs(delta_z) > outlier_threshold
    eta = np.sum(outliers) * 100 / len(z_phot)

    return bias, sigma_mad, eta


# Command-line entry point: parse the options and start the fine-tuning run.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MAE fine-tuning for image classification.')
    # model parameters
    # The default must be one of the names create_model() accepts; the former
    # default "vit_base_patch16_224" failed its assertion on every default run.
    parser.add_argument("--model_name", type=str, default="vit_base_patch16", help="model architecture")
    parser.add_argument('--input_size', default=224, type=int, help='images input size')
    parser.add_argument("--drop_path", type=float, default=0.1, help='Drop path rate')
    parser.add_argument('--model_ema', action='store_true')
    # argparse stores this as args.model_ema_decay.
    parser.add_argument('--model-ema-decay', type=float, default=0.9999)

    # Finetuning parameters
    parser.add_argument('--finetune', type=str, default='', help='finetune from pretrained checkpoint')
    parser.add_argument('--global_pool', action='store_true')
    parser.set_defaults(global_pool=True)
    # global_pool defaults to True, so --global_pool alone can never turn it
    # off; this companion flag provides the way to disable it.
    parser.add_argument('--cls_token', action='store_false', dest='global_pool',
                        help='use the class token instead of global pooling')
    parser.add_argument('--use_amp', action='store_true')

    # optimizer parameters
    parser.add_argument("--dataset", type=str, default='galaxy')
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--start_epoch", default=0, type=int, help="start epoch")
    # Global batch size; create_dataset() divides it by the world size.
    parser.add_argument("--batch_size", type=int, default=64, help="batch size")
    parser.add_argument('--workers', default=8, type=int, help='number of data loading workers')
    parser.add_argument("--weight_decay", type=float, default=0.05)
    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', help='learning rate (absolute lr)')
    parser.add_argument('--layer_decay', type=float, default=0.75, help='layer-wise lr decay from ELECTRA/BEiT')
    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', help='epochs to warmup LR')

    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")

    # distributed training parameters
    parser.add_argument('--dist_url', default='env://', type=str, help='url used to set up distributed training')
    parser.add_argument('--dist_backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')

    parser.add_argument("--resume",type=str, help="ckpt's path to resume most recent training")
    parser.add_argument("--save_dir", type=str, default="./run", help="save path, eg, acc_loss, weights, tensorboard, and so on")
    args = parser.parse_args()

    print(colorstr('green', 'Fine-tuning ' + args.model_name + ' on ' + args.dataset + ' ...'))
    main(args=args)