import argparse
import torch
import random
import numpy as np
from datetime import datetime

from scipy.io import savemat
from sklearn.neighbors import KNeighborsClassifier
from torch.utils.data import TensorDataset, DataLoader

import generate_pic
import stage3_model_hsi
import utils
from functools import partial
from augment import CenterResizeCrop
from augsubrg import stage2_model_hsi
from data_read import readdata, get_val_num
from hyper_dataset import HyperData
from stage1_model import LPCnet
from sklearn import metrics
from collections import OrderedDict
import torch.nn.functional as F
from pos_embed import interpolate_pos_embed

def train(args):
    """Fine-tune and evaluate the HSI model once per experiment repetition.

    For each of ``args.num_of_ex`` repetitions this function:

    1. reads the data split with ``readdata``;
    2. builds ``stage2_model_hsi.LPCnet_HSI`` and copies into it every entry of
       ``./model/augsburg_base.pt`` whose key exists in the model;
    3. freezes all parameters except a fixed list of names;
    4. trains for ``args.epochs`` epochs, minimising cross-entropy plus
       ``0.0001`` times the value of ``utils.get_loss_clip_augsburg``;
    5. evaluates with ``test_batch`` and, when the overall accuracy exceeds
       85.5, saves the classification map and the checkpoint.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``num_of_ex``, ``dataset``, ``ratio``, ``type``, ``windowsize``,
            ``per_num_perclass``, ``augment``, ``scale``, ``batch_size``,
            ``encoder_dim``, ``encoder_depth``, ``patch_size``,
            ``encoder_num_heads``, ``lr`` and ``epochs``.

    Note:
        The model and all batches are placed on CUDA device index 1.
    """
    import os  # local import: only needed to create the output directories

    # One text description per class; passed to the model together with each batch.
    SAR_text = ['Forest often exhibit complex textures and structures',
        'Residential area often contains houses and apartment buildings with distinct geometric structures',
        'Industrial area usually contains warehouses, factories, and have diverse structures',
        'Low Plants typically exhibit a relatively even distribution with no obvious structure or geometry',
        'Allotment is usually composed of vegetation and soil, and its ability to reflect microwave signals is relatively low',
        'Commercial area includes shops, offices, and parking facilities and reflects moderate microwave signals',
        'Water exhibits high reflection intensity because water has a relatively high reflectivity of electromagnetic waves']

    # Parameters that remain trainable; every other parameter is frozen.
    # A missing comma used to fuse 'Lidar_vision_net.prompt_LIDAR' and
    # 'hsi_lidar_dim.0.weight' into one bogus name, silently freezing both.
    # The running_mean / running_var / num_batches_tracked entries are kept
    # from the original list; named_parameters() does not yield buffers, so
    # they are not expected to match anything.
    trainable_names = {
        'mlp_stage3_hsi.weight', 'mlp_stage3_hsi.bias',
        'Shared_vision_net.mlp_hsi.0.weight', 'Shared_vision_net.mlp_hsi.0.bias',
        'Shared_vision_net.mlp_hsi.2.weight', 'Shared_vision_net.mlp_hsi.2.bias',
        'Lidar_vision_net.prompt_LIDAR',
        'hsi_lidar_dim.0.weight', 'hsi_lidar_dim.0.bias', 'hsi_lidar_dim.1.weight', 'hsi_lidar_dim.1.bias',
        'hsi_lidar_dim.1.running_mean', 'hsi_lidar_dim.1.running_var', 'hsi_lidar_dim.1.num_batches_tracked',
        'hsi_lidar_dim.3.weight', 'hsi_lidar_dim.3.bias', 'hsi_lidar_dim.4.weight', 'hsi_lidar_dim.4.bias',
        'hsi_lidar_dim.4.running_mean', 'hsi_lidar_dim.4.running_var', 'hsi_lidar_dim.4.num_batches_tracked',
    }

    for num in range(0, args.num_of_ex):

        num_of_samples = get_val_num(args.dataset, args.ratio)
        train_image_HSI, train_image_LIDAR, train_label, validation_image, validation_image_LIDAR, validation_label, nTrain_perClass, nvalid_perClass, \
        train_index, val_index, index, image, image_LiDAR, gt, s = readdata(args.type, args.dataset, args.windowsize,
                                                                            args.per_num_perclass, num_of_samples, num)
        # Labels are used as class indices, so the class count is max label + 1.
        number_class = np.max(train_label).astype(np.int64) + 1
        input_channels = train_image_HSI.shape[1]
        input_channels_LIDAR = train_image_LIDAR.shape[1]
        if args.augment:
            transform_train = [CenterResizeCrop(scale_begin = args.scale, windowsize = args.windowsize)]
            train_dataset = HyperData((train_image_HSI, train_image_LIDAR, train_label), transform_train)
        else:
            train_dataset = TensorDataset(torch.tensor(train_image_HSI), torch.tensor(train_image_LIDAR), torch.tensor(train_label))
        train_loader = DataLoader(dataset = train_dataset, batch_size = args.batch_size, shuffle = True)

        feature_encoder_hsi = stage2_model_hsi.LPCnet_HSI(img_size=(args.windowsize, args.windowsize), in_chans=input_channels,
                                               in_chans_LIDAR=input_channels_LIDAR, norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                                               hid_chans=128, hid_chans_LIDAR=128, embed_dim=args.encoder_dim,
                                               depth=args.encoder_depth,
                                               patch_size=args.patch_size,
                                               num_heads=args.encoder_num_heads, mlp_ratio=2.0, num_classes=number_class,
                                               global_pool=False)
        # Load onto the CPU so the checkpoint does not depend on the GPU it was
        # saved from; the model itself is still on the CPU at this point.
        checkpoint_all = torch.load('./model/augsburg_base.pt', map_location='cpu')
        checkpoint_model_all = checkpoint_all['model']

        # Copy only the checkpoint entries whose key exists in this model.
        state_dict = feature_encoder_hsi.state_dict()
        state_dict.update({k: v for k, v in checkpoint_model_all.items() if k in state_dict})
        feature_encoder_hsi.load_state_dict(state_dict)

        # NOTE(review): this runs after the weights are already loaded. If
        # interpolate_pos_embed only rewrites the checkpoint dict, the call has
        # no effect on the model here - confirm against pos_embed.py.
        interpolate_pos_embed(feature_encoder_hsi, checkpoint_model_all)

        for name, param in feature_encoder_hsi.named_parameters():
            param.requires_grad = name in trainable_names

        feature_encoder_hsi.cuda(1)
        # optimizer,scheduler,loss
        optimizer = torch.optim.Adam(feature_encoder_hsi.parameters(), lr=args.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)
        crossEntropy = torch.nn.CrossEntropyLoss().cuda(1)

        for epoch in range(args.epochs):

            feature_encoder_hsi.train()
            total_loss = 0
            # The second item of each batch is not used by the forward pass, so
            # it is not copied to the GPU.
            for HSI_data, _, label in train_loader:
                HSI_data, label = HSI_data.cuda(1), label.cuda(1)

                #Supervised comparison loss calculation process
                logits, proj_Lidar_vision_features, text_features = feature_encoder_hsi(HSI_data, SAR_text)
                loss_cls = crossEntropy(logits, label)
                dim_feature = proj_Lidar_vision_features.shape[-1]
                loss_clip_share = utils.get_loss_clip_augsburg(args.per_num_perclass, dim_feature, label,
                                                      proj_Lidar_vision_features, text_features)
                loss_all = loss_cls + loss_clip_share * 0.0001

                feature_encoder_hsi.zero_grad()
                loss_all.backward()
                optimizer.step()
                # Detach before accumulating so the running sum does not keep
                # every batch's autograd graph alive until the end of the epoch.
                total_loss = total_loss + loss_all.detach()

            scheduler.step()
            total_loss = total_loss / len(train_loader)
            print('epoch:', epoch,
                  'loss:', total_loss.cpu().numpy())

        true_cla, overall_accuracy, average_accuracy, kappa, true_label, \
        test_pred, test_index, cm =test_batch\
            (feature_encoder_hsi.eval(), image, image_LiDAR, index, args.batch_size, nTrain_perClass, nvalid_perClass, int(args.windowsize / 2), SAR_text)
        if overall_accuracy > 85.5:
            day = datetime.now()
            day_str = day.strftime('%m_%d_%H_%M')
            classification_map, gt_map = generate_pic.generate(image, gt, index, nTrain_perClass, nvalid_perClass, test_pred, overall_accuracy, int(args.windowsize / 2), args.dataset, day_str, num)

            # Build the checkpoint only when it is actually saved; 'epoch' is
            # the index of the last completed epoch.
            state = {'model': feature_encoder_hsi.state_dict(), 'optimizer': optimizer.state_dict(),
                     'epoch': args.epochs - 1}

            result_dir = 'result/' + args.dataset + '/hsi_missing/'
            model_dir = 'model/' + args.dataset + '/hsi_missing/'
            # Create the target directories so a long run is not lost to a
            # missing folder at save time.
            os.makedirs(result_dir, exist_ok=True)
            os.makedirs(model_dir, exist_ok=True)
            savemat(result_dir + str(overall_accuracy) + '.mat', {'map': classification_map})
            torch.save(state, model_dir + str(overall_accuracy) + 'net.pt')


def test_batch(model, image, image_LIDAR, index, BATCH_SIZE, nTrain_perClass, nvalid_perClass, halfsize, SAR_text):
    """Classify every test pixel in batches and compute accuracy metrics.

    The test pixels of class ``i`` are the rows of ``index[i]`` that follow the
    first ``nTrain_perClass[i] + nvalid_perClass[i]`` rows. A square patch of
    side ``2 * halfsize + 1`` centred on each pixel is cut from ``image`` and
    fed to ``model`` together with ``SAR_text``.

    Args:
        model: Callable invoked as ``model(patches, SAR_text)`` that returns a
            3-tuple whose first element holds the class logits.
        image: Array indexed as ``image[row, col, channel]``.
            Assumes it is padded so that every patch lies inside the array -
            TODO confirm against ``readdata``.
        image_LIDAR: Accepted for interface compatibility; not read, because
            the model is called with HSI patches and text only.
        index: Sequence with one 2-D array per class; the first two columns of
            each row are used as the pixel's row and column.
        BATCH_SIZE: Number of patches per forward pass.
        nTrain_perClass: Per-class count of training rows to skip.
        nvalid_perClass: Per-class count of validation rows to skip.
        halfsize: Half of the patch side length.
        SAR_text: Text inputs forwarded unchanged to ``model``.

    Returns:
        Tuple ``(true_cla, overall_accuracy, average_accuracy, kappa,
        true_label, predict_label, test_index, confusion_matrix)`` where
        ``true_cla`` is the per-class accuracy in percent and
        ``overall_accuracy``, ``average_accuracy`` and ``kappa`` are also
        scaled by 100.

    Note:
        Inputs are moved to CUDA device index 1.
    """
    nclass = len(index)

    # Collect the test coordinates of every class and the matching labels.
    per_class_rows = [index[c][nTrain_perClass[c] + nvalid_perClass[c]:, :] for c in range(nclass)]
    ind = np.concatenate(per_class_rows, axis=0)
    true_label = np.concatenate(
        [np.full(rows.shape[0], c, dtype=np.int32) for c, rows in enumerate(per_class_rows)], axis=0)
    test_index = np.copy(ind)
    length = ind.shape[0]

    # Pad with randomly chosen test pixels so the count is a multiple of
    # BATCH_SIZE; predictions for the padding are discarded below.
    if length % BATCH_SIZE != 0:
        add_num = BATCH_SIZE - length % BATCH_SIZE
        # Sampling without replacement is impossible when more padding than
        # samples is needed (test set much smaller than the batch size).
        add_ind = np.random.choice(length, add_num, replace=add_num > length)
        ind = np.concatenate((ind, ind[add_ind]), axis=0)

    pred_array = np.zeros([ind.shape[0], nclass], dtype=np.float32)
    n = ind.shape[0] // BATCH_SIZE
    windowsize = 2 * halfsize + 1
    image_batch = np.zeros([BATCH_SIZE, windowsize, windowsize, image.shape[2]], dtype=np.float32)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for i in range(n):
            start = BATCH_SIZE * i
            for j in range(BATCH_SIZE):
                m = ind[start + j, :]
                image_batch[j, :, :, :] = image[(m[0] - halfsize):(m[0] + halfsize + 1),
                                          (m[1] - halfsize):(m[1] + halfsize + 1), :]
            # Reorder once per batch from (N, H, W, C) to (N, C, H, W).
            image_b = np.transpose(image_batch, (0, 3, 1, 2))
            logits, _, _ = model(torch.tensor(image_b).cuda(1), SAR_text)
            if isinstance(logits, tuple):
                logits = logits[-1]
            pred_array[start:start + BATCH_SIZE] = torch.softmax(logits, dim=1).cpu().numpy()

    # Drop the predictions that belong to the padding.
    pred_array = pred_array[:length]
    predict_label = np.argmax(pred_array, axis=1)

    # Passing labels keeps the matrix nclass x nclass even when a class is
    # absent from both the true and the predicted labels.
    confusion_matrix = metrics.confusion_matrix(true_label, predict_label, labels=np.arange(nclass))
    overall_accuracy = metrics.accuracy_score(true_label, predict_label)

    correct_per_class = np.diag(confusion_matrix).astype(np.int64)
    test_num_class = np.sum(confusion_matrix, 1)
    test_num = np.sum(test_num_class)
    num1 = np.sum(confusion_matrix, 0)
    po = overall_accuracy
    pe = np.sum(test_num_class * num1) / (test_num * test_num)
    kappa = (po - pe) / (1 - pe) * 100
    # NOTE: a class with no test samples yields NaN here and in the average.
    true_cla = np.true_divide(correct_per_class, test_num_class) * 100
    average_accuracy = np.average(true_cla)
    print('overall_accuracy: {0:f}'.format(overall_accuracy * 100))
    print('average_accuracy: {0:f}'.format(average_accuracy))
    print('kappa:{0:f}'.format(kappa))
    return true_cla, overall_accuracy * 100, average_accuracy, kappa, true_label, predict_label, test_index, confusion_matrix


def main(args):
    """Print the configuration, seed the random generators and start training.

    Args:
        args: Parsed command-line namespace. ``seed`` and ``cuda`` are read;
            ``num_of_ex`` is set to 100 before ``train`` is called.
    """
    print(args)
    args.num_of_ex = 100

    # set up seed
    seed = args.seed
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)

    if args.cuda:
        torch.cuda.manual_seed(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    train(args)


if __name__ == "__main__":

    def _str2bool(value):
        """Convert a command-line token to a bool.

        ``type=bool`` treats every non-empty string (including ``"False"``) as
        True, so the flag could never be switched off from the command line.
        """
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        # The empty string stays falsy, as it was with ``type=bool``.
        if text in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser()
    # Sampling
    parser.add_argument('--per_num_perclass', type=int, default=40)

    parser.add_argument('--ratio', default=0.2, type=float,
                        help='ratio of val (default: 0.2)')
    # Training
    parser.add_argument('--seed', dest='seed', default=2, type=int,
                        help='Random seed')
    parser.add_argument('--windowsize', type=int, default=11)
    parser.add_argument('--patch_size', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--epochs', type=int, default=301)
    parser.add_argument('--lr', type=float, default=1e-3)

    parser.add_argument('--type', type=str, default='PCA')
    parser.add_argument('--dataset', type=str, default='Augsburg_SAR')
    # Augmentation
    parser.add_argument('--augment', default=True, type=_str2bool,
                        help='either use data augmentation or not (default: True)')
    parser.add_argument('--scale', default=9, type=int,
                        help='the minimum scale for center crop (default: 9)')

    parser.add_argument('--encoder_dim', default=64, type=int,
                        help='feature dimension for encoder (default: 64)')
    parser.add_argument('--encoder_depth', default=4, type=int,
                        help='encoder_depth; number of blocks ')
    parser.add_argument('--encoder_num_heads', default=8, type=int,
                        help='number of heads of encoder (default: 8)')

    args = parser.parse_args()
    # Recorded on the namespace so main() can decide whether to seed CUDA.
    args.cuda = torch.cuda.is_available()
    main(args)