import argparse
import os
import random
from collections import OrderedDict
from datetime import datetime
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F
from scipy.io import savemat
from sklearn import metrics
from sklearn.neighbors import KNeighborsClassifier
from torch.utils.data import TensorDataset, DataLoader

import generate_pic
import utils
from augment import CenterResizeCrop
from data_read import readdata, get_val_num
from hyper_dataset import HyperData
from stage1_model import LPCnet


def train(args):
    """Run ``args.num_of_ex`` independent train/evaluate experiments with LPCnet.

    Each experiment:
      1. reloads the data split via ``readdata``;
      2. builds a fresh ``LPCnet`` on ``cuda:0`` and trains it for
         ``args.epochs`` epochs with Adam + cosine warm restarts, minimising
         cross-entropy plus three 0.005-weighted vision/text alignment losses
         from ``utils.get_loss_clip_trento``;
      3. saves a checkpoint to ``./model/trento_base_lidar.pt``;
      4. evaluates all test pixels with ``test_batch`` (batch size 512);
      5. if overall accuracy exceeds 88.5 (percent), writes the classification
         map to ``result/<dataset>/all/<OA>.mat`` and the checkpoint to
         ``model/<dataset>/all/<OA>net.pt``.

    Args:
        args: parsed command-line namespace. Reads ``num_of_ex``, ``dataset``,
            ``ratio``, ``type``, ``windowsize``, ``per_num_perclass``, ``seed``,
            ``augment``, ``scale``, ``batch_size``, ``encoder_dim``,
            ``encoder_depth``, ``patch_size``, ``encoder_num_heads``, ``lr``
            and ``epochs``.
    """
    device = torch.device('cuda', 0)

    # Text prompts fed to the model alongside the images. They are written for
    # the six Trento classes and do not depend on the experiment, so they are
    # built once. Runtime strings are kept verbatim (including spelling).
    class_name = ["Apple trees", "Buliding", "Ground", "Woods",
                  "Vineyard", "Road"]
    Share_text = [('A Sample Of ') + name for name in class_name]
    HSI_text = ['The trunk of an apple tree is usually hard wood and the branches are soft',
                'Buliding usually uses building materials such as concrete, bricks, and wood',
                'The ground is mainly composed of soil, and the ground composition in different regions is different, such as sand, loam, clay, soil',
                'Woods are primarily composed of living trees, including various species like oak, pine, maple, and many others',
                'Vineyard is mainly composed of grape vines. These plants are usually composed of woody stems, leaves and grape fruits',
                'Road is usually less reflective and have a gritty or rough surface']
    Lidar_text = ['Apple trees grow as small to medium-sized trees, usually between 2 and 4 meters in height',
                  'Building heights can range from a few meters to hundreds of meters above the horizon',
                  'The ground is usually flat or slightly sloped',
                  'The height of woods is determined by the trees present. Trees in woods can vary greatly in height',
                  'Vineyard vines can often grow several feet to several meters tall',
                  'Road is usually low and the surface is approximately level with the surrounding ground']
    clip_weight = 0.005
    halfsize = args.windowsize // 2

    for num in range(args.num_of_ex):
        num_of_samples = get_val_num(args.dataset, args.ratio)
        (train_image_HSI, train_image_LIDAR, train_label,
         validation_image, validation_image_LIDAR, validation_label,
         nTrain_perClass, nvalid_perClass,
         train_index, val_index, index, image, image_LiDAR, gt, s) = readdata(
            args.type, args.dataset, args.windowsize,
            args.per_num_perclass, num_of_samples, args.seed)

        # Labels are assumed to be 0-based, so the class count is max + 1.
        number_class = np.max(train_label).astype(np.int64) + 1
        input_channels = train_image_HSI.shape[1]
        input_channels_LIDAR = train_image_LIDAR.shape[1]
        if args.augment:
            transform_train = [CenterResizeCrop(scale_begin=args.scale, windowsize=args.windowsize)]
            train_dataset = HyperData((train_image_HSI, train_image_LIDAR, train_label), transform_train)
        else:
            train_dataset = TensorDataset(torch.tensor(train_image_HSI), torch.tensor(train_image_LIDAR),
                                          torch.tensor(train_label))
        train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True)

        feature_encoder = LPCnet(img_size=(args.windowsize, args.windowsize), in_chans=input_channels,
                                 in_chans_LIDAR=input_channels_LIDAR,
                                 norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                                 hid_chans=128, hid_chans_LIDAR=128, embed_dim=args.encoder_dim,
                                 depth=args.encoder_depth,
                                 patch_size=args.patch_size,
                                 num_heads=args.encoder_num_heads, mlp_ratio=2.0, num_classes=number_class,
                                 global_pool=False)
        feature_encoder.to(device)
        # optimizer, scheduler, loss
        optimizer = torch.optim.Adam(feature_encoder.parameters(), lr=args.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)
        crossEntropy = torch.nn.CrossEntropyLoss().to(device)

        for epoch in range(args.epochs):
            feature_encoder.train()
            total_loss = 0
            num_batches = 0
            for HSI_data, LiDAR_data, label in train_loader:
                HSI_data, LiDAR_data, label = HSI_data.to(device), LiDAR_data.to(device), label.to(device)

                # Classification loss plus vision/text alignment losses for the
                # shared, HSI-specific and LiDAR-specific projections.
                (logits, proj_Shared_vision_features, proj_HSI_vision_features, proj_Lidar_vision_features,
                 share_text_features, HSI_text_features, Lidar_text_features) = feature_encoder(
                    HSI_data, LiDAR_data, Share_text, HSI_text, Lidar_text)
                loss_cls = crossEntropy(logits, label)

                dim_feature = proj_Shared_vision_features.shape[-1]
                loss_clip_share = utils.get_loss_clip_trento(args.per_num_perclass, dim_feature, label,
                                                             proj_Shared_vision_features, share_text_features)
                loss_clip_HSI = utils.get_loss_clip_trento(args.per_num_perclass, dim_feature, label,
                                                           proj_HSI_vision_features, HSI_text_features)
                loss_clip_Lidar = utils.get_loss_clip_trento(args.per_num_perclass, dim_feature, label,
                                                             proj_Lidar_vision_features, Lidar_text_features)

                loss_all = loss_cls \
                           + loss_clip_share * clip_weight + loss_clip_HSI * clip_weight + loss_clip_Lidar * clip_weight
                feature_encoder.zero_grad()
                loss_all.backward()
                optimizer.step()
                # Detach before accumulating. Summing the graph-attached loss
                # would keep every batch's autograd graph alive until the end
                # of the epoch.
                total_loss = total_loss + loss_all.detach()
                num_batches += 1

            scheduler.step()
            total_loss = total_loss / num_batches
            print('epoch:', epoch,
                  'loss:', total_loss.cpu().numpy())

        # state_dict() returns references to the live tensors, so building the
        # checkpoint once after training matches the last-epoch snapshot. This
        # also avoids a NameError when args.epochs == 0.
        state = {'model': feature_encoder.state_dict(), 'optimizer': optimizer.state_dict(),
                 'epoch': args.epochs - 1}
        os.makedirs('./model', exist_ok=True)
        torch.save(state, './model/trento_base_lidar.pt')

        (true_cla, overall_accuracy, average_accuracy, kappa, true_label,
         test_pred, test_index, cm) = test_batch(
            feature_encoder.eval(), image, image_LiDAR, index, 512, nTrain_perClass, nvalid_perClass,
            halfsize, Share_text, HSI_text, Lidar_text)

        if overall_accuracy > 88.5:
            day_str = datetime.now().strftime('%m_%d_%H_%M')
            classification_map, gt_map = generate_pic.generate(
                image, gt, index, nTrain_perClass, nvalid_perClass, test_pred, overall_accuracy,
                halfsize, args.dataset, day_str, num)

            result_dir = os.path.join('result', args.dataset, 'all')
            model_dir = os.path.join('model', args.dataset, 'all')
            # Create the output folders so a long training run is not lost to a
            # missing directory at save time.
            os.makedirs(result_dir, exist_ok=True)
            os.makedirs(model_dir, exist_ok=True)
            savemat(os.path.join(result_dir, str(overall_accuracy) + '.mat'), {'map': classification_map})
            torch.save(state, os.path.join(model_dir, str(overall_accuracy) + 'net.pt'))


def test_batch(model, image, image_LIDAR, index, BATCH_SIZE, nTrain_perClass, nvalid_perClass, halfsize, Share_text, HSI_text, Lidar_text):
    """Classify every test pixel with ``model`` and report accuracy metrics.

    For each class ``c`` (its position in ``index``), the rows of ``index[c]``
    after the first ``nTrain_perClass[c] + nvalid_perClass[c]`` rows are the
    test samples. The first two entries of each row are used as the
    (row, column) centre of a ``(2 * halfsize + 1)``-square patch, cut from
    ``image`` and ``image_LIDAR`` (indexed ``[row, col, channel]``). Patches
    are classified in batches of ``BATCH_SIZE``. The last batch is padded with
    randomly repeated test pixels drawn from NumPy's global RNG; their
    predictions are discarded.

    NOTE(review): patches near the border only fit if the images were padded
    by at least ``halfsize`` upstream — presumably done in ``readdata``;
    confirm there.

    Args:
        model: module called as ``model(hsi, lidar, Share_text, HSI_text,
            Lidar_text)`` and returning a 7-tuple whose first element is the
            logits (or a tuple of logits, of which the last is used). Inputs
            are moved to the device of the model's first parameter.
        image, image_LIDAR: 3-D arrays indexed ``[row, col, channel]``.
        index: per-class sequence of coordinate arrays.
        BATCH_SIZE: number of patches per forward pass.
        nTrain_perClass, nvalid_perClass: per-class counts of leading rows of
            ``index[c]`` to skip (training + validation samples).
        halfsize: patch half-width.
        Share_text, HSI_text, Lidar_text: text prompts forwarded to ``model``.

    Returns:
        ``(true_cla, overall_accuracy, average_accuracy, kappa, true_label,
        predict_label, test_index, confusion_matrix)``. ``true_cla`` is the
        per-class accuracy in percent, and OA, AA and kappa are also in
        percent. A class with no test pixels yields NaN in ``true_cla``, and
        therefore in AA.
    """
    nclass = len(index)

    # Collect test coordinates class by class; the class label is the position in `index`.
    test_parts = [index[c][nTrain_perClass[c] + nvalid_perClass[c]:, :] for c in range(nclass)]
    ind = np.concatenate(test_parts, axis=0)
    true_label = np.concatenate(
        [np.full(part.shape[0], c, dtype=np.int32) for c, part in enumerate(test_parts)], axis=0)
    test_index = np.copy(ind)
    length = ind.shape[0]

    remainder = length % BATCH_SIZE
    if remainder != 0:
        add_num = BATCH_SIZE - remainder
        # Sample without replacement normally (same RNG stream as before). Fall
        # back to replacement when there are fewer test pixels than padding
        # slots, where replace=False would raise ValueError.
        add_ind = np.random.choice(length, add_num, replace=add_num > length)
        ind = np.concatenate((ind, ind[add_ind]), axis=0)

    n_batches = ind.shape[0] // BATCH_SIZE
    windowsize = 2 * halfsize + 1
    pred_array = np.zeros([ind.shape[0], nclass], dtype=np.float32)
    image_batch = np.zeros([BATCH_SIZE, windowsize, windowsize, image.shape[2]], dtype=np.float32)
    image_LIDAR_batch = np.zeros([BATCH_SIZE, windowsize, windowsize, image_LIDAR.shape[2]], dtype=np.float32)
    device = next(model.parameters()).device

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for b in range(n_batches):
            start = b * BATCH_SIZE
            for j in range(BATCH_SIZE):
                m = ind[start + j, :]
                rows = slice(m[0] - halfsize, m[0] + halfsize + 1)
                cols = slice(m[1] - halfsize, m[1] + halfsize + 1)
                image_batch[j] = image[rows, cols, :]
                image_LIDAR_batch[j] = image_LIDAR[rows, cols, :]
            # (N, H, W, C) -> (N, C, H, W). torch.tensor copies, so the NumPy
            # buffers can be safely refilled for the next batch.
            hsi = torch.tensor(np.transpose(image_batch, (0, 3, 1, 2))).to(device)
            lidar = torch.tensor(np.transpose(image_LIDAR_batch, (0, 3, 1, 2))).to(device)
            logits, _, _, _, _, _, _ = model(hsi, lidar, Share_text, HSI_text, Lidar_text)
            if isinstance(logits, tuple):
                logits = logits[-1]
            pred_array[start:start + BATCH_SIZE] = torch.softmax(logits, dim=1).cpu().numpy()

    # Drop predictions for the padding rows.
    predict_label = np.argmax(pred_array[:length], axis=1)

    # labels= keeps the matrix nclass x nclass even when a class is absent from
    # both the true and predicted labels. Otherwise rows/cols would shift and
    # the diagonal lookup would misalign or go out of range.
    confusion_matrix = metrics.confusion_matrix(true_label, predict_label, labels=np.arange(nclass))
    overall_accuracy = metrics.accuracy_score(true_label, predict_label)

    test_num_class = np.sum(confusion_matrix, 1)
    test_num = np.sum(test_num_class)
    num1 = np.sum(confusion_matrix, 0)
    # Cohen's kappa: observed agreement vs. agreement expected by chance.
    po = overall_accuracy
    pe = np.sum(test_num_class * num1) / (test_num * test_num)
    kappa = (po - pe) / (1 - pe) * 100
    true_cla = np.true_divide(np.diag(confusion_matrix).astype(np.int64), test_num_class) * 100
    average_accuracy = np.average(true_cla)
    print('overall_accuracy: {0:f}'.format(overall_accuracy * 100))
    print('average_accuracy: {0:f}'.format(average_accuracy))
    print('AA: ', true_cla)
    print('kappa:{0:f}'.format(kappa))
    return true_cla, overall_accuracy * 100, average_accuracy, kappa, true_label, predict_label, test_index, confusion_matrix


def main(args):
    """Print the configuration, seed every RNG, and launch training.

    The number of repeated experiments is pinned to 10 here, overwriting any
    ``num_of_ex`` already present on ``args``. When ``args.cuda`` is true,
    the CUDA generator is seeded as well and cuDNN is switched to
    deterministic, non-benchmarking mode.
    """
    print(args)
    args.num_of_ex = 10

    seed = args.seed
    # NumPy, Python's `random`, and torch each keep their own generator.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)

    if args.cuda:
        torch.cuda.manual_seed(seed)
        cudnn = torch.backends.cudnn
        # Favour run-to-run reproducibility over cuDNN autotuning speed.
        cudnn.deterministic = True
        cudnn.benchmark = False

    train(args)


def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into True, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling (argparse reports this as a usage error).
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Sampling / data split
    parser.add_argument('--per_num_perclass', type=int, default=40)
    parser.add_argument('--ratio', default=0.2, type=float,
                        help='ratio of val (default: 0.2)')
    parser.add_argument('--seed', dest='seed', default=2, type=int,
                        help='Random seed')
    # Training
    parser.add_argument('--windowsize', type=int, default=11)
    parser.add_argument('--patch_size', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--epochs', type=int, default=301)
    parser.add_argument('--lr', type=float, default=1e-4)

    parser.add_argument('--type', type=str, default='PCA')
    parser.add_argument('--dataset', type=str, default='Trento')
    # Augmentation
    parser.add_argument('--augment', default=True, type=str2bool,
                        help='either use data augmentation or not (default: True)')
    parser.add_argument('--scale', default=9, type=int,
                        help='the minimum scale for center crop (default: 9)')
    # Encoder
    parser.add_argument('--encoder_dim', default=64, type=int,
                        help='feature dimension for encoder (default: 64)')
    parser.add_argument('--encoder_depth', default=4, type=int,
                        help='encoder_depth; number of blocks ')
    parser.add_argument('--encoder_num_heads', default=8, type=int,
                        help='number of heads of encoder (default: 8)')

    args = parser.parse_args()
    args.cuda = torch.cuda.is_available()
    main(args)