import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import os
import shutil
import argparse
from torch.optim.lr_scheduler import MultiStepLR
import random
import numpy as np
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
from ddg import DDG
from torchvision.models import ViT_B_16_Weights, resnet18, resnet50, vit_b_16
from tqdm import tqdm
from mixup import CutMix, mixup_data, MYCutMix, my_mixup_data
from Model.baseNet import get_model
from custom_crossentropy import *


def fix_seed(seed, deterministic=False):
    """Seed every random number generator used in training and configure cuDNN.

    Python's ``random``, PyTorch (CPU generator and all CUDA devices) and
    NumPy are all seeded with ``seed``. cuDNN autotuning (``benchmark``) is
    always switched off, and ``deterministic`` is forwarded unchanged to
    ``torch.backends.cudnn.deterministic``.
    """
    seeders = (
        random.seed,
        torch.manual_seed,
        torch.random.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = deterministic
    cudnn_backend.benchmark = False


def resume_training(net, checkpoint_path, optimizer=None):
    """Load model (and optionally optimizer) state from a checkpoint file.

    The checkpoint must be a dict holding the model state dict under
    ``'net'``. If it also has an ``'optimizer'`` entry and ``optimizer`` is
    given, that state is loaded into ``optimizer`` as well. A leading
    ``'module.'`` is stripped from every parameter name that has it (the
    prefix typically left by DataParallel/DDP-wrapped models), so the
    weights load into an unwrapped ``net``.

    Args:
        net: Model to load the weights into (modified in place).
        checkpoint_path: Path of the checkpoint file written by ``torch.save``.
        optimizer: Optional optimizer to restore.

    Returns:
        ``net``, with the loaded weights.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is not an existing file.
    """
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"No checkpoint found at '{checkpoint_path}'")

    print(f"Loading checkpoint '{checkpoint_path}'")
    # Load onto the CPU first: load_state_dict copies each tensor onto the
    # device of the matching parameter, so this also works on CPU-only hosts
    # for checkpoints that were saved from a GPU.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Strip the prefix per key instead of inspecting only the first key; this
    # also handles an empty state dict without an IndexError.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['net'].items()
    }
    net.load_state_dict(state_dict)

    if optimizer is not None and 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])

    print(f"Loaded checkpoint '{checkpoint_path}' (epoch {checkpoint.get('epoch', 0)})")
    return net



# Training
def train(epoch, net, trainloader, optimizer, criterion, device, log_file, mode, args,
          warmup_scheduler, scheduler=None):
    """Train ``net`` for one epoch and append a summary line to ``log_file``.

    Args:
        epoch: Epoch index, used only for printing/logging.
        net: Model to train (switched to train mode here).
        trainloader: Iterable of batches; the batch layout depends on ``mode``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss function. Its output is reduced with ``.mean()``; the
            lam-weighted modes presumably need a per-sample loss
            (``reduction='none'``).
        device: Device the batch tensors are moved to.
        log_file: Path of the log file; one line is appended per epoch.
        mode: ``'vanilla'`` -- batches are ``(inputs, targets)``.
            ``'mix'`` -- batches are ``(inputs, label1, label2, lam1, lam2)``
            and the loss is ``lam1 * loss(label1) + lam2 * loss(label2)``.
            ``'mixup'`` -- batches are ``(inputs, targets)``; when
            ``args.use_mixup`` is set, each batch is mixed with probability
            0.5 via ``my_mixup_data``, otherwise the plain loss is used.
        args: Parsed CLI namespace (``use_mixup`` and ``num_classes`` are read
            in ``'mixup'`` mode).
        warmup_scheduler: Optional ``pytorch_warmup`` scheduler, entered via
            ``dampening()`` once per batch, or ``None``.
        scheduler: Optional LR scheduler stepped once per batch inside
            ``warmup_scheduler.dampening()``; ignored when
            ``warmup_scheduler`` is ``None``.

    Raises:
        ValueError: if ``mode`` is unknown or ``trainloader`` yields no batches.
    """
    if mode not in ('vanilla', 'mix', 'mixup'):
        raise ValueError(f"unknown train mode: {mode!r}")
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for batch in tqdm(trainloader, ncols=80):
        mixed = False
        if mode == 'mix':
            inputs, label1, label2, lam1, lam2 = (t.to(device) for t in batch)
            acc_labels = label1
            mixed = True
        else:
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            acc_labels = targets
            # Unmixed batches fall through to the plain loss below instead of
            # reusing undefined (or stale) label1/label2/lam1/lam2.
            if mode == 'mixup' and args.use_mixup and np.random.rand() < 0.5:
                mixed_batch = my_mixup_data(inputs, targets, alpha=1.0, num_classes=args.num_classes)
                inputs, label1, label2, lam1, lam2 = (t.to(device) for t in mixed_batch)
                mixed = True

        optimizer.zero_grad()
        outputs = net(inputs)
        if mixed:
            loss = (lam1 * criterion(outputs, label1) + lam2 * criterion(outputs, label2)).mean()
        else:
            loss = criterion(outputs, targets).mean()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += acc_labels.size(0)
        correct += predicted.eq(acc_labels).sum().item()
        num_batches += 1

        if warmup_scheduler is not None:
            # Per-iteration schedule: the LR scheduler is passed in explicitly
            # (it used to be read from an undefined name, raising NameError).
            with warmup_scheduler.dampening():
                if scheduler is not None:
                    scheduler.step()

    if num_batches == 0:
        raise ValueError('trainloader yielded no batches')
    train_acc = 100. * correct / total
    train_loss = train_loss / num_batches
    current_lr = optimizer.param_groups[0]["lr"]
    summary = f'{epoch}, {current_lr:.4f}, {train_loss:.4f}, {train_acc:.3f}%, \n'
    print(summary)
    with open(log_file, 'a+') as file:
        file.write(summary)


def test(epoch, net, testloader, criterion, device, log_file):
    """Evaluate ``net`` on ``testloader`` and return top-1 accuracy in percent.

    The accuracy is printed and appended to ``log_file``. The loss
    (``criterion(...).mean()`` averaged over batches) is printed as well.

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    net.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            test_loss += criterion(outputs, targets).mean().item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1
    if total == 0:
        raise ValueError('testloader yielded no samples')
    acc = 100. * correct / total
    print(f'epoch {epoch} Test Accuracy: {acc:.2f}%, Test Loss: {test_loss / num_batches:.4f}')
    with open(log_file, 'a+') as file:
        file.write(f'{epoch}, Test acc = {acc:.3f}%\n')
    return acc


def save_checkpoint(state, exp_dir, filename='checkpoint.pth'):
    """Save ``state`` with ``torch.save`` to ``exp_dir/filename``, creating ``exp_dir`` if needed."""
    os.makedirs(exp_dir, exist_ok=True)
    filepath = os.path.join(exp_dir, filename)
    torch.save(state, filepath)
    print(f'Model saved to {filepath}')


def main():
    """Parse CLI arguments, build data/model/optimizer and run the training loop.

    Logs, a copy of this script, and the best checkpoint (``max_acc.pth``)
    are written under ``results/vit_results/<experiment name>``.
    """
    parser = argparse.ArgumentParser(description='PyTorch FGVC Training')
    parser.add_argument('--mode', default='train')
    parser.add_argument('--dataset', default='cub', choices=['cub', 'car', 'aircraft'])
    parser.add_argument('--seed', type=int, default=0, help='fix random seed')
    parser.add_argument("-g", "--gpu", default=-1, type=int,
                        help='CUDA device index (negative: keep the current device)')
    parser.add_argument('--model', type=str, choices=['resnet18', 'resnet50', 'densenet121', 'vit'],
                        default='resnet18', help='choose NeuralNetwork model')
    parser.add_argument('--resume', '-r', action='store_true',
                        help='resume from checkpoint')
    parser.add_argument('--train_mode', choices=['vanilla', 'ddg', 'diffusemix', 'dafusion'], default='vanilla')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='set learning rate (CNN models; the vit model uses 0.001)')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='SGD weight decay')
    parser.add_argument('--prob_aug', default=0.5, type=float,
                        help='probability of loading synthetic image')
    parser.add_argument('--prob_syn', default=0.25, type=float,
                        help='probability of loading generated fore')
    parser.add_argument('--prob_mix', default=0.5, type=float,
                        help='probability of mixing foreground')
    parser.add_argument('--num_syn', default=3, type=int,
                        help='the number of synthetic images')
    parser.add_argument('--size', default=448, type=int, choices=[224, 384, 448],
                        help='the size of image (forced to 384 for vit)')
    parser.add_argument('--strength', default=0.4, type=float, help='generation strength')
    parser.add_argument("--use_cutmix", default=False, action="store_true")
    parser.add_argument("--use_mixup", default=False, action="store_true")
    parser.add_argument('--mixup_type', type=str, default='plain', help='select mixup type (hidden or plain)',
                        choices=['hidden', 'plain', 'hook_hidden'])
    parser.add_argument('--randaug', action="store_true", default=False, help='apply randaugment')

    args = parser.parse_args()

    # Validate flag combinations before any output directory is created.
    if args.train_mode in ('diffusemix', 'dafusion'):
        # No dataset loader is wired up for these modes (the AugImageFolder
        # branch is not enabled), so training could not start.
        parser.error(f"--train_mode {args.train_mode} is not supported by this script yet")
    if args.use_mixup and args.use_cutmix:
        parser.error('--use_mixup and --use_cutmix are mutually exclusive')
    if args.use_mixup and args.train_mode != 'vanilla':
        # Mixup is applied inside train() to plain (inputs, targets) batches.
        parser.error('--use_mixup is only supported with --train_mode vanilla')

    if args.model == 'vit':
        # The SWAG E2E ViT-B/16 weights expect 384x384 inputs. Set this before
        # the experiment name is built so the directory reflects the real size.
        args.size = 384

    configs = vars(args)
    torch.cuda.set_device(args.gpu)
    fix_seed(seed=args.seed, deterministic=True)

    if args.train_mode == 'vanilla':
        exp_dir = f"vanilla_{args.dataset}_{args.model}_{args.size}"
    elif args.train_mode == 'diffusemix':
        exp_dir = f"diffusemix_{args.dataset}_{args.model}_{args.size}_probaug{args.prob_aug}_numsyn{args.num_syn}"
    elif args.train_mode == 'dafusion':
        exp_dir = f"dafusion_{args.dataset}_{args.model}_{args.size}_probaug{args.prob_aug}_numsyn{args.num_syn}"
    elif args.train_mode == 'ddg':
        exp_dir = f"ddg_{args.dataset}_{args.model}_{args.size}_probaug{args.prob_aug}_probsyn{args.prob_syn}_numsyn{args.num_syn}_probmix{args.prob_mix}_strength{args.strength}"
    else:
        raise ValueError(f"{args.train_mode} is not valid")

    if args.use_mixup:
        exp_dir += '_mixup'
    elif args.use_cutmix:
        exp_dir += "_cutmix"

    # Prepare logging
    exp_dir = os.path.join("results", 'vit_results', exp_dir)
    print(f"save to {exp_dir}")
    os.makedirs(exp_dir, exist_ok=True)
    shutil.copyfile(__file__, os.path.join(exp_dir, "train.py"))

    log_file = os.path.join(exp_dir, 'train_log.csv')
    if not args.resume or not os.path.exists(log_file):
        # A resumed run appends to the existing log instead of truncating it.
        with open(log_file, 'w+') as file:
            file.write('Epoch, lr, Train_Loss, Train_Acc, Test_Acc\n')

    NUM_CLASSES_DICT = {'cub': 200, 'aircraft': 100, 'car': 196}
    nb_class = NUM_CLASSES_DICT[args.dataset]
    # ``configs`` is vars(args), i.e. the namespace's own dict, so this also
    # sets ``args.num_classes`` (read by train() in mixup mode).
    configs['num_classes'] = nb_class

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if args.model == 'vit':
        net = vit_b_16(
            weights=ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1,
        )
        net.heads.head = nn.Linear(net.heads.head.in_features, nb_class)
    else:
        net = get_model(configs)

    for param in net.parameters():
        param.requires_grad = True  # make parameters in model learnable
    # Move the model before building the optimizer so it references the
    # on-device parameters.
    net.to(device)

    IMG_SIZE_DICT = {
        224: {"resize": 256, "crop_size": 224},
        384: {"resize": 440, "crop_size": 384},
        448: {"resize": 512, "crop_size": 448},
    }
    image_size = IMG_SIZE_DICT[args.size]['resize']
    crop_size = IMG_SIZE_DICT[args.size]['crop_size']
    train_transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.RandomCrop(crop_size, padding=8),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.CenterCrop((crop_size, crop_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )

    DATASETNAME_DICT = {'cub': 'CUB_200_2011', 'aircraft': 'Aircraft', 'car': 'StandfordCar'}
    dataset_name = DATASETNAME_DICT[args.dataset]
    orig_dir = f'path/to/iclr25/orig_dataset/{dataset_name}'
    cdp_dir = f'path/to/iclr25/aug_dataset/cdp_and_cip/{dataset_name}/cdp'
    cip_dir = f'path/to/iclr25/aug_dataset/cdp_and_cip/{dataset_name}/cip_pad'
    syn_cdp_dir = f'path/to/iclr25/cluster_inversion/generated_fore/{dataset_name}-{args.strength}-1clusters'

    orig_train_dir, orig_test_dir = os.path.join(orig_dir, 'train'), os.path.join(orig_dir, 'test')

    if args.train_mode == 'vanilla':
        train_set = ImageFolder(orig_train_dir, transform=train_transform)
        if args.use_cutmix:
            train_set = MYCutMix(
                train_set, prob=0.5
            )
    elif args.train_mode == 'ddg':
        train_set = DDG(root_orig=orig_train_dir, root_cdp=cdp_dir, root_cip=cip_dir, root_syncdp=syn_cdp_dir,
                        prob_aug=args.prob_aug, prob_syn=args.prob_syn, prob_mix=args.prob_mix,
                        num_syn=args.num_syn, transform=train_transform)
    else:
        raise ValueError(f'{args.train_mode} not implemented')

    batch_size = 8 if args.model == 'densenet121' else 16
    trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=16)
    test_set = ImageFolder(orig_test_dir, transform=test_transform)
    testloader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False, num_workers=16)

    if args.model == 'vit':
        # The training of vit model could not converge if using CrossEntropyLoss
        lr = 0.001
        criterion = LabelSmoothingLoss(classes=nb_class, smoothing=0.1)
        optimizer = torch.optim.SGD(
            net.parameters(),
            lr=lr,
            momentum=0.9,
            weight_decay=args.weight_decay,
        )
        epochs = 120
        # Cosine decay and linear warmup are both stepped once per iteration
        # inside train(), so T_max and the warmup period are iteration counts.
        total_steps = epochs * len(trainloader)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

        import pytorch_warmup
        warmup_scheduler = pytorch_warmup.LinearWarmup(
            optimizer, warmup_period=max(int(0.1 * total_steps), 1)
        )
    else:
        optimizer = torch.optim.SGD(
            net.parameters(),
            lr=args.lr,
            momentum=0.9,
            weight_decay=args.weight_decay,
        )
        criterion = nn.CrossEntropyLoss(reduction='none')
        # Milestones are epochs: this scheduler is stepped once per epoch below.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 225], gamma=0.1)
        warmup_scheduler = None
        epochs = 300

    best_acc = 0.0
    if args.resume:
        checkpoint_path = os.path.join(exp_dir, 'max_acc.pth')
        resume_training(net, checkpoint_path, optimizer)
        # Keep the saved best accuracy so a worse epoch cannot overwrite the
        # checkpoint. The LR schedule restarts from epoch 0: max_acc.pth stores
        # no optimizer or scheduler state.
        best_acc = torch.load(checkpoint_path, map_location='cpu').get('acc', 0.0)

    if args.train_mode == 'ddg' or args.use_cutmix:
        mode = 'mix'
    elif args.use_mixup:
        mode = 'mixup'
    else:
        mode = 'vanilla'
    print(f"train mode: {args.train_mode} use train mode {mode}")

    # Only the warmup (ViT) setup steps its scheduler per iteration in train().
    per_iter_scheduler = scheduler if warmup_scheduler is not None else None
    for epoch in range(epochs):
        train(epoch, net, trainloader, optimizer, criterion, device, log_file, mode=mode, args=args,
              warmup_scheduler=warmup_scheduler, scheduler=per_iter_scheduler)
        acc = test(epoch, net, testloader, criterion, device, log_file)
        if warmup_scheduler is None:
            scheduler.step()
        ##### only save model with highest acc
        if acc > best_acc:
            print('Saving..')
            state = {
                'net': net.state_dict(),
                'acc': acc,
                'epoch': epoch,
            }
            save_checkpoint(state, exp_dir, 'max_acc.pth')
            best_acc = acc

    # Logging
    with open(log_file, 'a+') as file:
        file.write('Dataset {}\tACC:{:.2f}\n'.format(args.dataset, best_acc))
    # Create an empty marker file whose name records the best accuracy.
    with open(os.path.join(exp_dir, 'acc_{}_{:.2f}'.format(args.dataset, best_acc)), 'a+'):
        pass


if __name__ == "__main__":
    main()