import argparse
import os
import random
from collections import OrderedDict
from datetime import datetime
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F
from scipy.io import savemat
from sklearn import metrics
from sklearn.neighbors import KNeighborsClassifier
from torch.utils.data import TensorDataset, DataLoader

import generate_pic
import utils
from augment import CenterResizeCrop
from data_read import readdata, get_val_num
from hyper_dataset import HyperData
from stage1_model import LPCnet


def _houston2013_prompts():
    """Return the (shared, HSI, LiDAR) text prompt lists, one prompt per class.

    The three lists are index-aligned: entry ``i`` of each list describes
    class ``i``.  The class names are the 15 land-cover classes listed below.
    NOTE(review): the prompts are hard-coded for these 15 classes; confirm
    they match the labels before training on another ``args.dataset``.
    """
    class_name = ["Healthy Grass", "Stressed Grass", "Synthetic Grass", "Tree",
                  "Soil", "Water", "Residential", "Commercial", "Road", "Highway",
                  "Railway", "Parking Lot1", "Parking Lot2", "Tennis Court",
                  "Running Track"]
    share_text = ['A Sample Of ' + name for name in class_name]
    hsi_text = ['Healthy grass typically appears green on the spectrum and has a uniform, soft turf surface ',
                'The spectral albedo of stressed grass is low in the visible and near-infrared bands, and the leaves are uneven or yellow',
                'Synthetic grass is usually dark or light green in color and is made from synthetic fibers',
                'Different parts of the tree have unique material properties, including wood, cellulose, and a waxy coating on the leaves',
                'Soil usually consists of sand, rock, or other natural features',
                'Water usually appears as a dark blue or black liquid',
                'Residential areas usually use residential construction materials such as bricks, concrete, wood',
                'Commercial areas usually use commercial construction materials, such as steel structures, glass curtain walls, masonry',
                'Road is usually less reflective and have a gritty or rough surface',
                'Highway pavements are usually made of solid pavement materials such as asphalt or concrete durable materials',
                'Railway pavements are usually made of rails and gravel',
                'Parking lot 1 are usually made of durable materials',
                'The surface of a parking lot 2 is usually natural soil or grass',
                'Tennis courts typically have relatively high albedo',
                'Running track is usually covered with rubber elastic materials']
    lidar_text = ['Healthy grass will remain between a few centimeters to tens of centimeters tall',
                  'Stressed grass has lower height',
                  'Artificial grass typically ranges from 2 to 40 millimeter',
                  'The height of a tree can range from a few meters to tens of meters',
                  'The soil is generally level with the surrounding ground, with no significant changes in height',
                  'The height of surface water is usually the same as the height of the surface',
                  'The height of the residential is usually several stories or low-rise buildings',
                  'Commercial areas are usually high-rise buildings',
                  'Road is usually low and the surface is approximately level with the surrounding ground',
                  'Highways have relatively high pavements to cope with geographical requirements such as crossing bridges, tunnels and intersections',
                  'The railway road surface is relatively high to ensure sufficient height above the ground so that train wheels, rails and carriages can pass smoothly',
                  'Parking lot 1 tends to be smoother, reducing bumps and discomfort',
                  'Parking lot 2 are often uneven and may have dips, bumps or uneven areas',
                  'Tennis courts usually require a solid, flat surface',
                  'Running track usually has a certain slope']
    return share_text, hsi_text, lidar_text


def train(args):
    """Train and evaluate an ``LPCnet`` once per experiment.

    For each of ``args.num_of_ex`` experiments the function loads the data
    with ``readdata``, trains a fresh model for ``args.epochs`` epochs on GPU
    0, saves the final checkpoint to ``./model/base_lidar.pt`` and evaluates
    it with ``test_batch``.  When the overall accuracy exceeds 60 (percent) a
    classification map and a second checkpoint are written under
    ``result/<dataset>/lidar/`` and ``model/<dataset>/lidar/``.

    Args:
        args: Parsed command-line namespace.  Reads ``num_of_ex``,
            ``dataset``, ``ratio``, ``type``, ``windowsize``,
            ``per_num_perclass``, ``augment``, ``scale``, ``batch_size``,
            ``encoder_dim``, ``encoder_depth``, ``patch_size``,
            ``encoder_num_heads``, ``lr`` and ``epochs``.
    """
    # The prompts do not depend on the experiment, so build them once.
    Share_text, HSI_text, Lidar_text = _houston2013_prompts()
    halfsize = args.windowsize // 2

    for num in range(0, args.num_of_ex):
        num_of_samples = get_val_num(args.dataset, args.ratio)
        train_image_HSI, train_image_LIDAR, train_label, validation_image, validation_image_LIDAR, validation_label, nTrain_perClass, nvalid_perClass, \
        train_index, val_index, index, image, image_LiDAR, gt, s = readdata(args.type, args.dataset, args.windowsize,
                                                                            args.per_num_perclass, num_of_samples, num)

        number_class = np.max(train_label).astype(np.int64) + 1
        input_channels = train_image_HSI.shape[1]
        input_channels_LIDAR = train_image_LIDAR.shape[1]
        if args.augment:
            transform_train = [CenterResizeCrop(scale_begin=args.scale, windowsize=args.windowsize)]
            train_dataset = HyperData((train_image_HSI, train_image_LIDAR, train_label), transform_train)
        else:
            train_dataset = TensorDataset(torch.tensor(train_image_HSI), torch.tensor(train_image_LIDAR), torch.tensor(train_label))
        train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True)

        feature_encoder = LPCnet(img_size=(args.windowsize, args.windowsize), in_chans=input_channels,
                                 in_chans_LIDAR=input_channels_LIDAR, norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                                 hid_chans=128, hid_chans_LIDAR=128, embed_dim=args.encoder_dim,
                                 depth=args.encoder_depth,
                                 patch_size=args.patch_size,
                                 num_heads=args.encoder_num_heads, mlp_ratio=2.0, num_classes=number_class,
                                 global_pool=False)
        feature_encoder.cuda(0)
        # optimizer, scheduler, loss
        optimizer = torch.optim.Adam(feature_encoder.parameters(), lr=args.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)
        crossEntropy = torch.nn.CrossEntropyLoss().cuda(0)

        for epoch in range(args.epochs):
            feature_encoder.train()
            total_loss = 0.0
            num_batches = 0
            for HSI_data, LiDAR_data, label in train_loader:
                HSI_data, LiDAR_data, label = HSI_data.cuda(0), LiDAR_data.cuda(0), label.cuda(0)

                logits, proj_Shared_vision_features, proj_HSI_vision_features, proj_Lidar_vision_features, \
                share_text_features, HSI_text_features, Lidar_text_features = feature_encoder(HSI_data, LiDAR_data, Share_text, HSI_text, Lidar_text)
                loss_cls = crossEntropy(logits, label)

                # Auxiliary vision/text losses from utils.get_loss_clip, one
                # per branch (shared, HSI, LiDAR), each weighted by 0.005.
                dim_feature = proj_Shared_vision_features.shape[-1]
                loss_clip_share = utils.get_loss_clip(args.per_num_perclass, dim_feature, label, proj_Shared_vision_features, share_text_features)
                loss_clip_HSI = utils.get_loss_clip(args.per_num_perclass, dim_feature, label, proj_HSI_vision_features, HSI_text_features)
                loss_clip_Lidar = utils.get_loss_clip(args.per_num_perclass, dim_feature, label, proj_Lidar_vision_features, Lidar_text_features)
                loss_all = loss_cls \
                           + loss_clip_share * 0.005 + loss_clip_HSI * 0.005 + loss_clip_Lidar * 0.005

                feature_encoder.zero_grad()
                loss_all.backward()
                optimizer.step()

                # .item() yields a plain float, so the running sum does not keep
                # every batch's autograd graph alive until the end of the epoch.
                total_loss += loss_all.item()
                num_batches += 1

            scheduler.step()
            # Guard against an empty loader instead of relying on a loop variable.
            mean_loss = total_loss / max(num_batches, 1)
            print('epoch:', epoch,
                  'loss:', mean_loss)

        # Built once after training; also defined when args.epochs == 0.
        state = {'model': feature_encoder.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': args.epochs - 1}

        # torch.save does not create missing directories.
        os.makedirs('./model', exist_ok=True)
        torch.save(state, './model/base_lidar.pt')

        feature_encoder.eval()
        # No gradients are needed for evaluation.
        with torch.no_grad():
            _, overall_accuracy, _, _, _, test_pred, _, _ = test_batch(
                feature_encoder, image, image_LiDAR, index, 512, nTrain_perClass, nvalid_perClass,
                halfsize, Share_text, HSI_text, Lidar_text)

        if overall_accuracy > 60:
            day_str = datetime.now().strftime('%m_%d_%H_%M')
            classification_map, gt_map = generate_pic.generate(image, gt, index, nTrain_perClass, nvalid_perClass, test_pred, overall_accuracy, halfsize, args.dataset, day_str, num)

            result_dir = 'result/' + args.dataset + '/lidar/'
            model_dir = 'model/' + args.dataset + '/lidar/'
            os.makedirs(result_dir, exist_ok=True)
            os.makedirs(model_dir, exist_ok=True)
            savemat(result_dir + str(overall_accuracy) + '.mat', {'map': classification_map})
            torch.save(state, model_dir + str(overall_accuracy) + 'net.pt')


def test_batch(model, image, image_LIDAR, index, BATCH_SIZE, nTrain_perClass, nvalid_perClass, halfsize, Share_text, HSI_text, Lidar_text):
    """Classify the held-out test pixels and report accuracy metrics.

    For every class ``c`` the rows of ``index[c]`` after the first
    ``nTrain_perClass[c] + nvalid_perClass[c]`` entries are treated as test
    pixels.  A ``(2 * halfsize + 1)``-wide window is cut around each pixel
    from ``image`` and ``image_LIDAR`` and passed through ``model`` on GPU 0
    in batches of ``BATCH_SIZE``.

    Args:
        model: Callable invoked as ``model(hsi, lidar, Share_text, HSI_text,
            Lidar_text)``; the first returned element is the class logits.
            The model's train/eval mode is not changed here.
        image: Array indexed as ``[row, col, channel]``.
        image_LIDAR: Array indexed as ``[row, col, channel]``.
        index: Per-class sequence of 2-D arrays whose first two columns are
            the (row, col) position of a pixel.
            NOTE(review): assumes the images are padded so every window
            stays in bounds; confirm against ``readdata``.
        BATCH_SIZE: Number of windows per forward pass.
        nTrain_perClass: Per-class count of training rows to skip.
        nvalid_perClass: Per-class count of validation rows to skip.
        halfsize: Half-width of the square window.
        Share_text, HSI_text, Lidar_text: Text prompts forwarded to ``model``.

    Returns:
        Tuple ``(per_class_accuracy, overall_accuracy, average_accuracy,
        kappa, true_label, predict_label, test_index, confusion_matrix)``.
        The three accuracies and kappa are percentages.  A class without any
        test pixel yields ``nan`` for its per-class accuracy.
    """
    nclass = len(index)
    per_class_rows = [index[c][nTrain_perClass[c] + nvalid_perClass[c]:, :] for c in range(nclass)]
    ind = np.concatenate(per_class_rows, axis=0)
    true_label = np.concatenate(
        [np.full(rows.shape[0], c, dtype=np.int32) for c, rows in enumerate(per_class_rows)], axis=0)
    test_index = np.copy(ind)
    length = ind.shape[0]

    # Pad with randomly re-drawn test pixels so the last batch is full; the
    # padded predictions are dropped again below.
    remainder = length % BATCH_SIZE
    if remainder != 0:
        add_num = BATCH_SIZE - remainder
        # Drawing without replacement is impossible when more pad slots than
        # test pixels are needed, so only then fall back to replacement.
        add_ind = np.random.choice(np.arange(length), add_num, replace=add_num > length)
        ind = np.concatenate((ind, ind[add_ind]), axis=0)

    pred_array = np.zeros([ind.shape[0], nclass], dtype=np.float32)
    n = ind.shape[0] // BATCH_SIZE
    windowsize = 2 * halfsize + 1
    image_batch = np.zeros([BATCH_SIZE, windowsize, windowsize, image.shape[2]], dtype=np.float32)
    image_LIDAR_batch = np.zeros([BATCH_SIZE, windowsize, windowsize, image_LIDAR.shape[2]], dtype=np.float32)
    # Channel-first views of the two buffers; they share memory with the
    # buffers, so they only need to be created once.
    image_b = np.transpose(image_batch, (0, 3, 1, 2))
    image_LIDAR_b = np.transpose(image_LIDAR_batch, (0, 3, 1, 2))

    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        for i in range(n):
            for j in range(BATCH_SIZE):
                m = ind[BATCH_SIZE * i + j, :]
                rows = slice(m[0] - halfsize, m[0] + halfsize + 1)
                cols = slice(m[1] - halfsize, m[1] + halfsize + 1)
                image_batch[j, :, :, :] = image[rows, cols, :]
                image_LIDAR_batch[j, :, :, :] = image_LIDAR[rows, cols, :]
            logits, proj_Shared_vision_features, proj_HSI_vision_features, proj_Lidar_vision_features, \
            share_text_features, HSI_text_features, Lidar_text_features = model(
                torch.tensor(image_b).cuda(0), torch.tensor(image_LIDAR_b).cuda(0), Share_text, HSI_text, Lidar_text)
            if isinstance(logits, tuple):
                logits = logits[-1]
            pred_array[i * BATCH_SIZE:(i + 1) * BATCH_SIZE] = torch.softmax(logits, dim=1).cpu().numpy()

    # Drop the predictions that belong to the padding.
    pred_array = pred_array[:length]
    predict_label = np.argmax(pred_array, axis=1)

    # Passing labels keeps the matrix nclass x nclass even when a class is
    # absent from both the true and the predicted labels.
    confusion_matrix = metrics.confusion_matrix(true_label, predict_label, labels=np.arange(nclass))
    overall_accuracy = metrics.accuracy_score(true_label, predict_label)

    true_cla = np.diag(confusion_matrix).astype(np.int64)
    test_num_class = np.sum(confusion_matrix, 1)
    test_num = np.sum(test_num_class)
    num1 = np.sum(confusion_matrix, 0)
    # Cohen's kappa from observed (po) and chance (pe) agreement.
    po = overall_accuracy
    pe = np.sum(test_num_class * num1) / (test_num * test_num)
    kappa = (po - pe) / (1 - pe) * 100
    true_cla = np.true_divide(true_cla, test_num_class) * 100
    average_accuracy = np.average(true_cla)
    print('overall_accuracy: {0:f}'.format(overall_accuracy * 100))
    print('average_accuracy: {0:f}'.format(average_accuracy))
    print('AA: ', true_cla)
    print('kappa:{0:f}'.format(kappa))
    return true_cla, overall_accuracy * 100, average_accuracy, kappa, true_label, predict_label, test_index, confusion_matrix


def main(args):
    """Print the configuration, seed every random generator and start training.

    The number of experiments (``args.num_of_ex``) is fixed to 10 here.  When
    ``args.cuda`` is truthy the CUDA generator is seeded as well and cuDNN is
    switched to deterministic, non-benchmark mode.
    """
    print(args)
    args.num_of_ex = 10

    # Seed NumPy, the stdlib RNG and PyTorch, in that order.
    for seed_everything in (np.random.seed, random.seed, torch.manual_seed):
        seed_everything(args.seed)

    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    train(args)


if __name__ == "__main__":
    def _str2bool(value):
        """Convert a command-line string to a bool.

        ``type=bool`` would turn every non-empty string, including
        ``"False"``, into ``True``.  The empty string stays falsy, matching
        what ``bool("")`` returned before.
        """
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))

    parser = argparse.ArgumentParser()
    # Data split
    parser.add_argument('--per_num_perclass', type=int, default=40)

    parser.add_argument('--ratio', default=0.2, type=float,
                        help='ratio of val (default: 0.2)')
    # Training
    parser.add_argument('--seed', dest='seed', default=2, type=int,
                        help='Random seed')
    parser.add_argument('--windowsize', type=int, default=11)
    parser.add_argument('--patch_size', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--epochs', type=int, default=601)

    parser.add_argument('--lr', type=float, default=1e-4)

    parser.add_argument('--type', type=str, default='PCA')
    parser.add_argument('--dataset', type=str, default='2013houston')
    # Augmentation
    parser.add_argument('--augment', default=True, type=_str2bool,
                        help='either use data augmentation or not (default: True)')
    parser.add_argument('--scale', default=9, type=int,
                        help='the minimum scale for center crop (default: 9)')

    # MAE encoder specifics
    parser.add_argument('--encoder_dim', default=64, type=int,
                        help='feature dimension for encoder (default: 64)')
    parser.add_argument('--encoder_depth', default=4, type=int,
                        help='encoder_depth; number of blocks ')
    parser.add_argument('--encoder_num_heads', default=8, type=int,
                        help='number of heads of encoder (default: 8)')

    args = parser.parse_args()
    # Record GPU availability so main() can decide whether to seed CUDA.
    args.cuda = torch.cuda.is_available()
    main(args)