import os
import time
import random
import argparse
import numpy as np
from tqdm import trange
import pickle
import copy

import wandb

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torchvision import datasets, transforms
import kornia.augmentation as K

from utils import evaluate_cl, evaluate_cl_debug, evaluate_analysis_cl
from data import DiffAugment, ParamDiffAug
from models.wrapper import get_model
from generator import SyntheticImageGenerator


def incremental_net(tg_net, num_new_classes=20):
    """Snapshot ``tg_net``, then widen its linear head in place for a new task.

    A deep copy of ``tg_net`` is taken *before* any modification and returned.
    ``tg_net`` itself is then mutated: its output layer (``fc`` if present,
    otherwise ``classifier``) is replaced by a wider ``nn.Linear`` with
    ``num_new_classes`` extra outputs. Rows for the existing classes are copied
    from the old layer. Rows for the new classes keep ``nn.Linear``'s default
    initialization.

    Args:
        tg_net: model to expand in place. It must expose an ``nn.Linear`` head
            as ``fc`` or ``classifier``.
        num_new_classes: number of output units to append. The default of 20
            matches the per-phase class count used by ``main`` in this file.

    Returns:
        A deep copy of ``tg_net`` taken before its head was widened.

    Raises:
        AttributeError: if ``tg_net`` has neither an ``fc`` nor a
            ``classifier`` attribute.
    """
    ref_net = copy.deepcopy(tg_net)

    # When both attributes exist, the original code took the layer shapes from
    # ``fc``, so ``fc`` keeps precedence. Only that single head is replaced.
    if hasattr(tg_net, 'fc'):
        head_name = 'fc'
    elif hasattr(tg_net, 'classifier'):
        head_name = 'classifier'
    else:
        raise AttributeError(
            "incremental_net: model has neither an 'fc' nor a 'classifier' head")

    old_head = getattr(tg_net, head_name)
    out_features = old_head.out_features
    has_bias = old_head.bias is not None

    # Build on CPU first (same RNG stream as before), then match the old head's
    # device and dtype so the copies below stay on one device.
    new_head = nn.Linear(old_head.in_features, out_features + num_new_classes,
                         bias=has_bias)
    new_head = new_head.to(device=old_head.weight.device,
                           dtype=old_head.weight.dtype)

    # Carry over the learned parameters for the already-seen classes.
    with torch.no_grad():
        new_head.weight[:out_features].copy_(old_head.weight)
        if has_bias:
            new_head.bias[:out_features].copy_(old_head.bias)

    setattr(tg_net, head_name, new_head)
    return ref_net

def _load_test_tensors(dst_test):
    """Materialize a map-style dataset of (image_tensor, int_label) pairs into stacked tensors."""
    # Index each sample exactly once. The transform runs on every __getitem__.
    samples = [dst_test[i] for i in range(len(dst_test))]
    images = torch.stack([image for image, _ in samples], dim=0)
    labels = torch.LongTensor([label for _, label in samples])
    return images, labels


def _phase_testloader(test_images, test_labels, class_map):
    """Return a loader over test samples whose class is in ``class_map``, relabelled to ``class_map`` positions."""
    # First occurrence wins, mirroring list.index semantics.
    position = {}
    for idx, cls in enumerate(class_map):
        position.setdefault(cls, idx)

    label_list = test_labels.tolist()
    keep = [i for i, lab in enumerate(label_list) if lab in position]
    remapped = torch.LongTensor([position[label_list[i]] for i in keep])
    images = test_images[torch.tensor(keep, dtype=torch.long)]
    return torch.utils.data.DataLoader(
        TensorDataset(images, remapped), batch_size=128, shuffle=False, num_workers=0)


def main(args):
    """Evaluate synthetic-data replay on a 5-phase class-incremental CIFAR-100 split.

    For each of ``args.num_eval`` independent runs, a fresh model is built. For
    every phase:
      1. that phase's synthetic images are loaded from a saved generator
         checkpoint and appended to the replay buffer;
      2. the model is evaluated with ``evaluate_cl`` on the test images of all
         classes seen so far;
      3. the model's head is widened for the next phase, keeping a reference
         copy.

    The mean and std of the per-phase accuracies across runs are printed at
    the end.

    Args:
        args: parsed CLI namespace. It is mutated here to add ``device``,
            ``dsa_param`` and ``dsa_strategy``.
    """
    args.device = torch.device(f"cuda:{args.gpu_id}")
    args.dsa_param = ParamDiffAug()
    args.dsa_strategy = 'color_crop_cutout_flip_scale_rotate'

    num_phases = 5

    # --- data set ---
    # Class order defining the incremental split. This is a local,
    # self-produced file; never point pickle.load at untrusted data.
    with open("./cifar100_order.pkl", "rb") as f:
        order = pickle.load(f)
    # num_classes is the number of classes introduced per phase.
    channel, im_size, num_classes = 3, (32, 32), 20
    normalize = K.Normalize(mean=[0.5071, 0.4866, 0.4409], std=[0.2673, 0.2564, 0.2762])
    dst_test = datasets.CIFAR100(args.data_path, train=False, download=True, transform=transforms.ToTensor())
    # Transform the test set once. Each phase below only filters and relabels it.
    test_images, test_labels = _load_test_tensors(dst_test)

    accs = [[] for _ in range(num_phases)]
    for _ in range(args.num_eval):
        image_syn = torch.FloatTensor([])
        label_syn = torch.LongTensor([])

        tg_net = get_model(args, args.model, channel, num_classes, im_size).to(args.device)
        ref_net = None

        for phase in range(num_phases):
            # Classes seen up to and including this phase, in incremental order.
            class_map = order[:(phase + 1) * num_classes].tolist()

            # --- synthetic data ---
            # NOTE(review): these constructor args presumably must match the
            # configuration the checkpoints were trained with; confirm.
            generator = SyntheticImageGenerator(
                num_classes, im_size, 200, 16, [8, 13, 18],
                4, 2, 1).to(args.device)
            # Encoders are dropped before load_state_dict, so the checkpoint
            # is presumably saved without them.
            del generator.encoders
            if phase == 0:
                file_path = f"ANONYMIZED"
            else:
                file_path = f"ANONYMIZED"
            generator.load_state_dict(torch.load(file_path, map_location="cpu"))
            img_syn, lab_syn = generator.get_all_cpu()
            # Drop the generator so its device memory is free during evaluation.
            del generator
            image_syn = torch.cat([image_syn, img_syn], dim=0)
            # Generator labels are per-phase (0..num_classes-1); shift them to
            # global incremental indices.
            label_syn = torch.cat([label_syn, lab_syn + phase * num_classes])

            # --- test loader ---
            testloader = _phase_testloader(test_images, test_labels, class_map)

            _, acc = evaluate_cl(args, tg_net, ref_net, image_syn, label_syn, testloader, normalize)
            accs[phase].append(acc)

            ref_net = incremental_net(tg_net)

    print('\n==================== Final Results ====================\n')
    for phase, phase_accs in enumerate(accs):
        print(f"phase: {phase}, mean = {np.mean(phase_accs)}, std = {np.std(phase_accs)}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Parameter Processing')

    # (flag, type, default) for every CLI option, listed in registration order
    # so `--help` output is unchanged. Note `--gpu-id` is stored as `args.gpu_id`.
    cli_options = (
        # data
        ('--data_path', str, 'ANONYMIZED'),
        # save
        ('--save_path', str, 'results'),
        ('--exp_name', str, None),
        # repeat
        ('--num_eval', int, 3),
        # eval
        ('--model', str, 'ConvNet'),
        ('--epoch', int, 200),
        ('--lr', float, 0.01),
        ('--batch', int, 256),
        ('--T', float, 2.0),
        ('--beta', float, 1.0),
        # device
        ('--gpu-id', int, 0),
    )
    for flag, arg_type, default_value in cli_options:
        parser.add_argument(flag, type=arg_type, default=default_value)

    args = parser.parse_args()

    # Settings fixed for this CIFAR-100 continual-learning evaluation.
    fixed_settings = {
        "dataset": "CIFAR100_cl",
        "ipc": 20,
        "model_eval_pool": args.model,
    }
    for attr_name, attr_value in fixed_settings.items():
        setattr(args, attr_name, attr_value)

    main(args)


