import argparse
import datetime
import numpy as np
import time
import sys

import torch
import torch.backends.cudnn as cudnn
import json
import os
import socket

from pathlib import Path

from timm.data import Mixup
try:
    from timm.data import DatasetTar
except ImportError:
    from timm.data import ImageDataset as DatasetTar
from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import get_state_dict, ModelEma

from datasets import build_dataset, build_transform
from engine import train_one_epoch, train_scale, evaluate, NativeScalerGA
from dct_to_space_3d_compress import ToDCT,reconstruct_weights
from losses import DistillationLoss
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from samplers import RASampler
import utils
import torch.distributed as dist

import shutil

class Logger(object):
    """Tee for ``sys.stdout``: every write goes to the terminal and a log file.

    The log file is opened in append mode (UTF-8), so repeated runs that use
    the same path accumulate into one file.

    Args:
        filename: path of the log file (``str`` or ``os.PathLike``).
    """

    def __init__(self, filename="Default.log"):
        # Capture whatever stdout is at construction time (normally the real
        # terminal) before the caller replaces sys.stdout with this object.
        self.terminal = sys.stdout
        # Explicit encoding so non-ASCII output cannot fail under a non-UTF-8 locale.
        self.log = open(filename, "a", encoding="utf-8")

    def write(self, message):
        """Write ``message`` to both streams and return the terminal's result."""
        written = self.terminal.write(message)
        self.log.write(message)
        return written

    def flush(self):
        """Flush both the terminal and the log file."""
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file; the wrapped terminal stream is left open."""
        if not self.log.closed:
            self.log.close()

    def __getattr__(self, name):
        # Only reached for attributes not defined on this class: delegate to
        # the terminal so code probing sys.stdout (isatty, fileno, encoding,
        # ...) keeps working. The two instance attributes are guarded to avoid
        # infinite recursion when they have not been set yet.
        if name in ("terminal", "log"):
            raise AttributeError(name)
        return getattr(self.terminal, name)


def get_args_parser():
    """Return the base argument parser (``add_help=False``) for this script.

    It is meant to be used as a parent parser; see the ``__main__`` block.

    NOTE(review): ``main`` reads many attributes that are not registered here
    (e.g. ``batch_size``, ``num_workers``, ``pin_mem``, ``drop_last``,
    ``mixup*``/``cutmix*``, ``drop``, ``drop_path``, ``weight_init``, ``dct``,
    ``use_dct``, ``use_ratio``, ``select_n``, ``dct_mode``, ``eval``,
    ``epochs``, ``start_epoch``, ``clip_grad``, ``accumulation_step``).
    Confirm how those are supplied before running the script.
    """
    parser = argparse.ArgumentParser('TLEG training and evaluation script', add_help=False)

    parser.add_argument('--data-path', default='/ImageNet-1K', type=str,
                        help='dataset path')

    # Model parameters
    parser.add_argument('--model', default='deit_tiny_patch16_224_L4', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--pretrained', action='store_true', default=False,
                        help='Start with pretrained version of specified network (if avail)')

    # Base model parameters
    parser.add_argument('--basemodel', default='deit_tiny_patch16_224_L12', type=str, metavar='MODEL',
                        help='Name of base model to train')
    parser.add_argument('--basepretrained', action='store_true', default=False,
                        help='Start with pretrained version of specified network (if avail)')
    parser.add_argument('--basenb_classes', default=1000, type=int, metavar='N', help='number of base classes')
    parser.add_argument('--basemodel_pretrain_pth', default='', type=str, help='pth')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    # The help text previously contained a stray `" + \` inside the literal,
    # which spliced the next line's indentation into the message.
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.0, help='Label smoothing (default: 0.0)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    # Paired on/off flags; the set_defaults calls decide the value when neither is given.
    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.add_argument('--load-tar', default=False, action='store_true', help='Loading *.tar files for dataset')
    parser.set_defaults(repeated_aug=True)

    parser.add_argument('--train-mode', action='store_true')
    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
    parser.set_defaults(train_mode=True)

    # Dataset parameters
    parser.add_argument('--data-set', default='IMNET',
                        choices=['CIFAR100', 'CIFAR10', 'IMNET', 'flowers102', 'stanfordcar', 'food101', 'cub200', 'INAT19'],
                        type=str, help='Dataset to use')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')

    return parser


def main(args):
    """Train or evaluate ``args.model``, optionally DCT-initialised from ``args.basemodel``.

    Steps:
      1. Mirror stdout to ``<output_dir>/training_log.txt`` via ``Logger``.
      2. Build train/val datasets and loaders (random / sequential sampling).
      3. Create the target model. If ``args.dct`` is set:
         - build the base model, loading ``args.basemodel_pretrain_pth`` when given;
         - run ``ToDCT`` on it;
         - load the ``dct_<basemodel>_pretrained_*`` files it is expected to
           have written to the current working directory;
         - pass them to ``reconstruct_weights``, whose result replaces the model.
      4. With ``args.eval``, report validation accuracy and return. Otherwise
         train for epochs ``args.start_epoch .. args.epochs - 1``, evaluating
         after each epoch. On the main process, JSON stats are appended to
         ``<output_dir>/log.txt``.

    NOTE(review): many ``args`` attributes read here are not registered by
    ``get_args_parser`` in this file, including:
      - data/loader: batch_size, num_workers, pin_mem, drop_last;
      - augmentation: mixup/cutmix options;
      - model: drop, drop_path, weight_init;
      - DCT: dct, use_dct, use_ratio, select_n, dct_mode;
      - schedule: eval, epochs, start_epoch, clip_grad, accumulation_step;
      - the optimizer/scheduler options consumed by timm's factories.
    Confirm how they are supplied.
    """
    log_dir = Path(args.output_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    sys.stdout = Logger(log_dir / 'training_log.txt')

    print("Python executable:", sys.executable)
    print("CUDA_VISIBLE_DEVICES:", os.environ.get('CUDA_VISIBLE_DEVICES'))
    print("CUDA available:", torch.cuda.is_available())
    print("CUDA device count:", torch.cuda.device_count())

    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # Seed the torch and numpy RNGs from `--seed`.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    # Plain (non-distributed) samplers; DDP wrapping below is disabled as well.
    sampler_train = torch.utils.data.RandomSampler(dataset_train)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=args.drop_last,
    )

    # Validation uses a 1.5x larger batch and keeps the last partial batch.
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        weight_init=args.weight_init,
    )
    if args.dct:
        model_base = create_model(
            args.basemodel,
            pretrained=args.basepretrained,
            num_classes=args.basenb_classes,
            drop_rate=args.drop,
            drop_path_rate=args.drop_path,
            weight_init=args.weight_init,
        )
        model_path = args.basemodel_pretrain_pth
        if model_path:
            # Decompose the supplied base weights instead of the freshly created
            # ones (this path was previously read but never used).
            # torch.load unpickles, so only pass trusted checkpoint files.
            checkpoint = torch.load(model_path, map_location='cpu')
            # NOTE(review): accepts a bare state_dict or a DeiT-style dict with
            # the weights under 'model' — confirm the checkpoint format.
            state_dict = checkpoint.get('model', checkpoint)
            model_base.load_state_dict(state_dict)
            print(f"Loaded base model weights from {model_path}")
        print("modelbase", model_base)

        reg_name = f'dct_{args.basemodel}_pretrained'
        save_dir_main = Path(os.getcwd())

        todct = ToDCT(keep_ratio=1.0, savepth=save_dir_main, model_name=reg_name)

        # File names (including the 'presevered' spelling) must match what ToDCT writes.
        savepth_blocks_norm = save_dir_main / f'{reg_name}_layer_blocks_norm.pt'
        savepth_blocks_freq = save_dir_main / f'{reg_name}_layer_blocks_freq.pth'
        savepth_presevered_freq = save_dir_main / f'{reg_name}_layer_presevered_freq.pth'

        # NOTE(review): ToDCT runs on every process. If several processes share
        # this working directory, guard it with `args.rank == 0` and follow it
        # with `dist.barrier()` before the loads below.
        todct(model_base)
        layer_blocks = torch.load(savepth_blocks_freq, map_location='cpu')
        layer_norms = torch.load(savepth_blocks_norm, map_location='cpu')
        layer_freq = torch.load(savepth_presevered_freq, map_location='cpu')

        model = reconstruct_weights(
            model, layer_blocks, layer_norms, layer_freq,
            use_dct=args.use_dct, use_ratio=args.use_ratio,
            select_n=args.select_n, mode=args.dct_mode,
        )

        del model_base

    print(model)

    model.to(device)

    model_without_ddp = model
    # DDP wrapping is currently switched off (re-enable with `args.distributed`).
    if False:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters / (1000 * 1000))

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    # Training-only objects are built after the eval early-return, so that
    # evaluation does not need any optimizer/scheduler options.
    # With Mixup/CutMix active, targets are soft and Mixup has already applied
    # `label_smoothing=args.smoothing`. Otherwise use (optionally smoothed)
    # hard-label cross-entropy.
    if mixup_active:
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    optimizer = create_optimizer(args, model_without_ddp)
    lr_scheduler, _ = create_scheduler(args, optimizer)
    # NOTE(review): assumes NativeScalerGA() takes no constructor arguments
    # (like DeiT's NativeScaler) — confirm against engine.py.
    loss_scaler = NativeScalerGA()

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0

    for epoch in range(args.start_epoch, args.epochs):

        lr_scheduler.step(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad,
            mixup_fn,
            set_training_mode=args.train_mode,  # keep in eval mode during finetuning
            accumulation_step=args.accumulation_step,
        )

        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        max_accuracy = max(max_accuracy, test_stats['acc1'])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        if args.output_dir and utils.is_main_process():
            # NOTE(review): assumes both stats objects are dicts of
            # JSON-serialisable scalars — confirm train_one_epoch's return type.
            log_stats = {
                **{f'train_{k}': v for k, v in train_stats.items()},
                **{f'test_{k}': v for k, v in test_stats.items()},
                'epoch': epoch,
                'n_parameters': n_parameters,
            }
            with (log_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    # Select the 'spawn' start method for torch.multiprocessing; this raises
    # if a start method has already been fixed for this process.
    torch.multiprocessing.set_start_method('spawn')

    # The shared option definitions live in get_args_parser(); wrap them in a
    # top-level parser that also provides -h/--help.
    parser = argparse.ArgumentParser(
        'FRONT training and evaluation script',
        parents=[get_args_parser()],
    )
    args = parser.parse_args()

    # Make sure the requested output directory exists before main() runs.
    if args.output_dir:
        Path(args.output_dir).mkdir(exist_ok=True, parents=True)

    main(args)
