import argparse
import os
import random
from time import time

import torch
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
from dataset import IEEGDataset, SubDataset
import numpy as np
from sklearn.metrics import fbeta_score
from exp_settings import ExpSettings_MAYO, ExpSettings_FNUSA
from utils.evaluation import Metrics
import warnings

warnings.filterwarnings('ignore')

torch.autograd.set_detect_anomaly(True)
username = os.getlogin()


def init_seed(seed=42):
    """Seed every random number generator the script relies on.

    Seeds, in order: PyTorch's CPU generator, the current CUDA device, all
    CUDA devices, NumPy's global generator and Python's ``random`` module.

    Args:
        seed: integer seed applied to every generator (default 42).
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# Module-level alias for the convolution class that conv_block instantiates.
Conv2d = nn.Conv2d


def conv_block(in_channels, out_channels, kernel_size, cnn_norm):
    """Build a ``Conv2d -> normalisation -> ReLU`` stack.

    The convolution has no bias and uses PyTorch's default stride 1 and
    padding 0.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution and normalised.
        kernel_size: convolution kernel size, e.g. ``(3, 3)``.
        cnn_norm: ``'instance_norm'`` selects ``nn.InstanceNorm2d``; any other
            value selects ``nn.BatchNorm2d``.

    Returns:
        ``nn.Sequential`` holding the three layers in that order.
    """
    norm_layer = nn.InstanceNorm2d if cnn_norm == 'instance_norm' else nn.BatchNorm2d
    convolution = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=False)
    return nn.Sequential(convolution, norm_layer(out_channels), nn.ReLU())


class ConvolutionLayer(nn.Module):
    """Small CNN encoder that globally pools several stages into one vector.

    Every convolution outputs ``2 * h_c`` channels. Two conv blocks (each
    followed by 2x2 max pooling) are always applied. When both
    ``is_channel_level`` and ``is_contextual`` equal 1, a third branch (a bias-
    free 1x2 convolution on the second stage) is added. Each stage is
    average-pooled to 1x1 and flattened, and the results are concatenated, so
    the output width is ``4 * h_c`` (two stages) or ``6 * h_c`` (three stages).

    Args:
        in_c: number of input channels.
        h_c: base width; convolutions use ``2 * h_c`` output channels.
        cnn_norm: forwarded to ``conv_block`` ('instance_norm' or BatchNorm).
        is_contextual, is_channel_level: flags (compared with 1) that enable
            the third branch together.
    """

    def __init__(self, in_c, h_c=64, cnn_norm='instance_norm', is_contextual=1, is_channel_level=1):
        super().__init__()
        self.channels = h_c
        self.is_contextual = is_contextual
        self.is_channel_level = is_channel_level
        width = self.channels * 2
        # Submodules are registered in the original order (conv0..conv3) so
        # seeded initialisation and state_dict keys stay the same.
        self.conv0 = conv_block(in_c, width, (3, 3), cnn_norm=cnn_norm)
        self.conv1 = conv_block(width, width, (3, 3), cnn_norm=cnn_norm)
        # NOTE(review): conv2 is never used in forward(); it is kept so saved
        # checkpoints still load and the parameter count is unchanged.
        self.conv2 = conv_block(width, width, (3, 3), cnn_norm=cnn_norm)
        self.conv3 = nn.Conv2d(width, width, (1, 2), bias=False)
        self.transform = nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten())

    def forward(self, x):
        """Encode ``x`` (DMNet passes an ``(N, in_c, H, W)`` batch) into pooled features."""
        stage0 = F.max_pool2d(self.conv0(x), kernel_size=(2, 2))
        stage1 = F.max_pool2d(self.conv1(stage0), kernel_size=(2, 2))
        pooled = [self.transform(stage0), self.transform(stage1)]
        three_branches = self.is_channel_level == 1 and self.is_contextual == 1
        if three_branches:
            pooled.append(self.transform(self.conv3(stage1)))
        return torch.cat(pooled, dim=-1)


class DMNet(nn.Module):
    """Classifier built on pairwise differences between per-segment spectra.

    ``forward(x, gs)`` computes binned log-magnitude spectra for each segment
    of ``x`` (``cal_metrics``). It optionally concatenates ``k - 1`` segments
    from each end of ``gs``, builds a pairwise segment-difference "image" per
    spectral bin with ``self.differential``, and maps every (batch, channel)
    pair to 2 logits. The mapping goes through ``ConvolutionLayer`` + ``fc``
    when ``is_cnn == 1``, or through one linear layer on the flattened images
    (``fc_no_cnn``) otherwise.

    Args:
        args: namespace providing segment_length, segment_num, sum_up_length,
            is_contextual, is_channel_level, is_cnn, is_dm, is_fr, diff_norm,
            diff_sign, h_c, cnn_norm and k.
    """

    def __init__(self, args):
        super().__init__()
        self.segment_length = args.segment_length
        self.segment_num = args.segment_num
        self.sum_up_length = args.sum_up_length
        self.is_contextual = args.is_contextual
        self.is_channel_level = args.is_channel_level
        self.is_cnn = args.is_cnn
        self.is_dm = args.is_dm
        self.is_fr = args.is_fr
        # Number of rfft bins of a segment_length-point signal; the expression
        # equals segment_length // 2 + 1 for both even and odd lengths.
        c_N = (args.segment_length+1) // 2 + (args.segment_length+1) % 2
        # CNN input channels: rfft bins summed in groups of sum_up_length
        # (matches the binning done in cal_metrics).
        in_c = c_N // self.sum_up_length
        if args.diff_norm == 'min_max':
            if args.diff_sign == 1:
                self.differential = self.differential_signed_min_max
            else:
                self.differential = self.differential_abs_min_max
        else:
            self.differential = self.differential_std_norm
        # Submodules are created in the original order so seeded parameter
        # initialisation and state_dict keys are unchanged.
        self.cnn = ConvolutionLayer(in_c=in_c, h_c=args.h_c, cnn_norm=args.cnn_norm, is_contextual=args.is_contextual, is_channel_level=args.is_channel_level)
        self.k = args.k
        if self.is_channel_level == 1 and self.is_contextual == 1:
            # ConvolutionLayer emits three pooled vectors of 2*h_c each.
            self.fc = nn.Sequential(
                nn.Linear(in_features=args.h_c*2*3, out_features=2),
            )
        else:
            # ConvolutionLayer emits two pooled vectors of 2*h_c each.
            self.fc = nn.Sequential(
                nn.Linear(in_features=args.h_c * 2 * 2, out_features=2),
            )
        # NOTE(review): the in_features below are hard-coded, so fc_no_cnn only
        # fits configurations whose flattened difference images have exactly
        # that size. fc_no_dm / fc_no_fr (like is_dm / is_fr) are not used by
        # forward(). All are kept so existing checkpoints still load.
        self.fc_no_cnn = nn.Sequential(
            nn.Linear(in_features=27225, out_features=2),
        )
        self.fc_no_dm = nn.Sequential(
            nn.Linear(825, 2)
        )
        self.fc_no_fr = nn.Sequential(
            nn.Linear(5250, 2)
        )

    @staticmethod
    def differential_signed_min_max(x):
        """Signed pairwise segment differences scaled by their max magnitude.

        Args:
            x: tensor of shape (b, c, m, n). b: batch_size, c: channel_num,
               m: segment_num, n: feature_dim.

        Returns:
            Tensor of shape (b, c, n, m, m) whose [.., f, i, j] entry is
            (x[.., i, f] - x[.., j, f]) divided by the largest absolute
            difference for that (b, c, f). If every difference for a feature is
            zero, its entries are 0 rather than NaN.
        """
        b, c, m, n = x.shape
        A = torch.unsqueeze(x, dim=-2)  # (b, c, m, 1, n)
        B = torch.unsqueeze(x, dim=-3)  # (b, c, 1, m, n)
        val = A - B                     # (b, c, m, m, n): x[i] - x[j]
        neg = val < 0
        abs_c = torch.abs(val).view(b, c, -1, n)
        max_val, _ = torch.max(abs_c, dim=-2)
        max_val = torch.unsqueeze(max_val, dim=-2)
        # A zero maximum means all differences for that feature are zero;
        # divide by 1 there so the result is 0 instead of 0/0 = NaN.
        max_val = max_val.masked_fill(max_val == 0, 1.0)
        norm_val = abs_c / max_val
        norm_val = norm_val.view(b, c, m, m, -1)
        norm_val[neg] *= -1  # restore the sign of the original difference
        norm_val = norm_val.permute(0, 1, 4, 2, 3)
        return norm_val

    @staticmethod
    def differential_abs_min_max(x):
        """Absolute pairwise segment differences scaled by their max value.

        Args:
            x: tensor of shape (b, c, m, n). b: batch_size, c: channel_num,
               m: segment_num, n: feature_dim.

        Returns:
            Tensor of shape (b, c, n, m, m) whose [.., f, i, j] entry is
            |x[.., i, f] - x[.., j, f]| divided by the largest such value for
            that (b, c, f). If every difference for a feature is zero, its
            entries are 0 rather than NaN.
        """
        b, c, m, n = x.shape
        A = torch.unsqueeze(x, dim=-2)  # (b, c, m, 1, n)
        B = torch.unsqueeze(x, dim=-3)  # (b, c, 1, m, n)
        val = A - B
        abs_c = torch.abs(val).view(b, c, -1, n)
        max_val, _ = torch.max(abs_c, dim=-2)
        max_val = torch.unsqueeze(max_val, dim=-2)
        # Guard the all-zero case (0/0 = NaN), as in the signed variant.
        max_val = max_val.masked_fill(max_val == 0, 1.0)
        norm_val = abs_c / max_val
        norm_val = norm_val.view(b, c, m, m, -1)
        norm_val = norm_val.permute(0, 1, 4, 2, 3)
        return norm_val

    @staticmethod
    def differential_std_norm(x):
        """Pairwise segment differences standardised per feature.

        The absolute differences |x[i] - x[j]| are z-scored over all (i, j)
        pairs of each (b, c, feature) (with 1e-9 added to the std). Entries
        whose original difference was negative then have their sign flipped.

        Args:
            x: tensor of shape (b, c, m, n). b: batch_size, c: channel_num,
               m: segment_num, n: feature_dim.

        Returns:
            Tensor of shape (b, c, n, m, m).
        """
        b, c, m, n = x.shape
        A = torch.unsqueeze(x, dim=-2)
        B = torch.unsqueeze(x, dim=-3)
        val = A - B
        neg = val < 0
        val = torch.abs(val).view(b, c, -1, n)
        sigma, mu = torch.std_mean(val, dim=-2, keepdim=True)
        norm_val = (val-mu) / (sigma + 1e-9)
        norm_val = norm_val.view(b, c, m, m, -1)
        norm_val[neg] *= -1
        norm_val = norm_val.permute(0, 1, 4, 2, 3)
        return norm_val

    def cal_metrics(self, x):
        """Binned log-magnitude spectrum of every segment.

        Args:
            x: tensor of shape (b, c, t), where t must equal
               (2 * segment_num + 1) * segment_length.

        Returns:
            Tensor of shape (b, c, 2 * segment_num + 1, n_bins). Each bin is
            log(|rfft| + 1) summed over sum_up_length consecutive frequency
            bins; trailing bins that do not fill a whole group are dropped.
        """
        b, c, t = x.size()
        x = x.view(b, c, -1, self.segment_length)
        x = torch.log(torch.abs(torch.fft.rfft(x, dim=-1)) + 1)
        c_N = x.shape[-1]
        avai_c_N = c_N // self.sum_up_length * self.sum_up_length
        x = x[:, :, :, :avai_c_N].view(b, c, self.segment_num * 2 + 1, -1, self.sum_up_length).sum(dim=-1)
        return x

    def forward(self, x, gs):
        """Return logits of shape (b * c, 2) for a (b, c, t) input ``x``.

        Args:
            x: raw signal of shape (b, c, t); see ``cal_metrics``.
            gs: 4-D tensor (b, c, S, n) whose last dim matches the output of
                ``cal_metrics``. When ``is_channel_level == 1``, its first and
                last ``k - 1`` entries along dim 2 are concatenated around the
                segments of ``x`` (NOTE(review): presumably precomputed
                context features; confirm against the dataset).
        """
        b, c = x.shape[0], x.shape[1]
        gs = gs.float()
        metrics = self.cal_metrics(x)
        if self.is_contextual == 0:
            # Keep only the centre segment (index segment_num of 2*segment_num+1).
            metrics = metrics[:, :, self.segment_num * 2, :].unsqueeze(2)
        context = self.k - 1  # number of gs entries taken from each end
        # context == 0 (k == 1) is handled explicitly: a ``-0:`` slice would
        # select *all* of gs, and ``0:-0`` below would select nothing.
        if self.is_channel_level == 1 and context != 0:
            metrics = torch.cat([gs[:, :, :context, :], metrics, gs[:, :, -context:, :]], dim=2)
        diff_images = self.differential(metrics)
        _, _, ic, m, _ = diff_images.size()
        if self.is_cnn == 1:
            diff_images = diff_images.view(-1, ic, m, m)
            if self.is_channel_level == 1 and self.is_contextual == 1:
                # Drop the rows that belong to the gs context, keep all columns.
                if context != 0:
                    diff_images = diff_images[:, :, context:-context, :]
                cnn_hidden = self.cnn(diff_images)
            else:
                cnn_hidden = self.cnn(diff_images)
            cnn_hidden = cnn_hidden.view(b, c, -1)
            logits = self.fc(cnn_hidden)
            return logits.view(b * c, -1)
        diff_images = diff_images.reshape(b * c, -1)
        return self.fc_no_cnn(diff_images)



def validation(validation_loader, model):
    """Evaluate ``model`` on ``validation_loader`` and return the binary F1 score.

    The model is switched to eval mode and left there; callers that continue
    training must call ``model.train()`` again. Each batch is an
    ``(x, y, ssp)`` triple. ``x`` and ``ssp`` are moved to the GPU and passed
    as ``model(x, ssp)``, and the arg-max over the last dim is the prediction.
    The confusion matrix and F1 are printed.

    Returns:
        ``fbeta_score(labels, predictions, beta=1)``.
    """
    model.eval()
    with torch.no_grad():
        tot_labels = []
        tot_preds = []
        for x, y, ssp in tqdm(validation_loader):
            outputs = model(x.cuda(), ssp.cuda())
            pred = torch.max(outputs, dim=-1)[1]
            # Labels are only needed on the host (for sklearn), so they are
            # not copied to the GPU and back.
            tot_labels += y.view(-1).cpu().numpy().tolist()
            tot_preds += pred.cpu().numpy().tolist()
        print('Validation Confusion Matrix: \n', confusion_matrix(tot_labels, tot_preds))
        f1_score = fbeta_score(tot_labels, tot_preds, beta=1)
        print('F1:', f1_score)
        return f1_score


def test(test_loader, model: DMNet, args):
    """Evaluate the best saved checkpoint on ``test_loader`` and save a metric dict.

    Loads ``{args.save_path}/ckpts/{args.exp_id}_best.pth`` (mapped to CPU,
    then the model is moved to the GPU) and predicts every ``(x, y, ssp)``
    batch without gradients. It builds a dict with keys 'acc', 'pre', 'rec',
    'f1', 'f2' taken from the ``Metrics`` attributes acc / prec / rec /
    f_one / f_doub. The dict is printed and written with ``torch.save`` to
    ``./result/{args.dataset_name}/{args.exp_id}.pt``, and the elapsed time is
    printed.
    """
    ckpt_save_path = f'{args.save_path}/ckpts'
    model.load_state_dict(
        torch.load(os.path.join(ckpt_save_path, f'{args.exp_id}_best.pth'), map_location='cpu'))
    model.cuda()
    model.eval()
    with torch.no_grad():
        start_time = time()
        tot_labels = []
        tot_preds = []
        for x, y, ssp in tqdm(test_loader):
            outputs = model(x.cuda(), ssp.cuda())
            pred = torch.max(outputs, dim=-1)[1]
            # Only host-side labels/predictions are kept; per-batch GPU output
            # tensors are no longer accumulated (they were never used).
            tot_labels += y.view(-1).cpu().numpy().tolist()
            tot_preds += pred.cpu().numpy().tolist()

        tot_labels = np.array(tot_labels).reshape(-1)
        tot_preds = np.array(tot_preds).reshape(-1)
        # NOTE(review): this directory is created but nothing here writes to
        # it; kept to preserve the existing on-disk layout.
        os.makedirs(f'{args.save_path}/{args.exp_id}', exist_ok=True)
        metric = Metrics(tot_preds, tot_labels)
        res = {
            'acc': metric.acc,
            'pre': metric.prec,
            'rec': metric.rec,
            'f1': metric.f_one,
            'f2': metric.f_doub
        }
        os.makedirs(f'./result/{args.dataset_name}', exist_ok=True)
        print(res)
        torch.save(res, f'./result/{args.dataset_name}/{args.exp_id}.pt')
        end_time = time()
        # Prints the elapsed wall-clock seconds ("时间消耗" = "time consumed").
        print('时间消耗:', end_time - start_time)


def test2(test_loaders, model: DMNet, args):
    """Evaluate the best saved checkpoint separately on each loader in ``test_loaders``.

    Loads ``{args.save_path}/ckpts/{args.exp_id}_best.pth`` once. For every
    loader it predicts each ``(x, y, ssp)`` batch without gradients, then
    prints ``loader.dataset.filename``, the ``Metrics`` object and the elapsed
    time.
    """
    ckpt_save_path = f'{args.save_path}/ckpts'
    model.load_state_dict(
        torch.load(os.path.join(ckpt_save_path, f'{args.exp_id}_best.pth'), map_location='cpu'))
    model.cuda()
    model.eval()
    # NOTE(review): this directory is created (once any loader is processed)
    # but nothing here writes to it; kept for on-disk layout compatibility.
    result_save_path = f'{args.save_path}/{args.exp_id}'
    with torch.no_grad():
        for test_loader in test_loaders:
            start_time = time()
            filename = test_loader.dataset.filename
            tot_labels = []
            tot_preds = []
            for x, y, ssp in tqdm(test_loader):
                outputs = model(x.cuda(), ssp.cuda())
                pred = torch.max(outputs, dim=-1)[1]
                # Labels stay on the host; GPU outputs are not accumulated.
                tot_labels += y.view(-1).cpu().numpy().tolist()
                tot_preds += pred.cpu().numpy().tolist()

            tot_labels = np.array(tot_labels).reshape(-1)
            tot_preds = np.array(tot_preds).reshape(-1)
            os.makedirs(result_save_path, exist_ok=True)
            metric = Metrics(tot_preds, tot_labels)
            print(f'{filename}: ')
            print(metric)
            end_time = time()
            # Prints the elapsed wall-clock seconds ("时间消耗" = "time consumed").
            print('时间消耗:', end_time - start_time)


def fit(train_loader, valid_loader, optimizer, model, args):
    """Train ``model`` for ``args.num_epochs`` epochs, keeping the best checkpoint.

    The loss is cross-entropy with class weights ``[args.ratio, 1]`` (class 0,
    class 1). After each epoch the learning rate, training confusion matrix
    and mean training loss are printed, and ``validation`` is run. Whenever
    validation F1 strictly improves, ``model.state_dict()`` is saved to
    ``{args.save_path}/ckpts/{args.exp_id}_best.pth``.
    """
    # Loop-invariant: build the class-weight tensor and copy it to the GPU once
    # instead of on every batch.
    class_weight = torch.Tensor([args.ratio, 1]).cuda()
    ckpt_save_path = f'{args.save_path}/ckpts'
    best_f1 = -1
    for epoch in range(args.num_epochs):
        model.train()
        tot_loss = 0
        tot_labels = []
        tot_preds = []
        cnt = 0
        for x, y, ssp in tqdm(train_loader):
            x = x.cuda()
            y = y.cuda()
            ssp = ssp.cuda()
            outputs = model(x, ssp)
            pred = torch.max(outputs, dim=-1)[1]
            tot_labels += y.view(-1).cpu().numpy().tolist()
            tot_preds += pred.cpu().numpy().tolist()
            loss = F.cross_entropy(outputs, y.view(-1), weight=class_weight)
            tot_loss += loss.item()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            cnt += 1

        print(f'EPOCH-{epoch}:')
        print('LR:', optimizer.state_dict()['param_groups'][0]['lr'])
        print('Train Confusion Matrix: \n', confusion_matrix(tot_labels, tot_preds))
        print('Train Loss: ', (tot_loss / cnt))

        dev_f1 = validation(validation_loader=valid_loader, model=model)
        if dev_f1 > best_f1:
            print(f"F1 increase: {best_f1 : .4f} -> {dev_f1 : .4f}.")
            best_f1 = dev_f1
            os.makedirs(ckpt_save_path, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(ckpt_save_path, f'{args.exp_id}_best.pth'))


def class_ratio(loader):
    """Return (#labels == 1) / (#labels == 0) over every batch of ``loader``.

    Each batch is expected to be an ``(x, y, ssp)`` triple; only ``y`` is
    read. The result is a 0-d tensor (``inf`` if there are no zero labels).
    """
    labels = torch.concat([batch_y.reshape(-1) for _, batch_y, _ in tqdm(loader)]).reshape(-1)
    n_pos = (labels == 1).sum()
    n_neg = (labels == 0).sum()
    return n_pos / n_neg


def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total



if __name__ == '__main__':
    # Command-line configuration. Integer flags such as is_contextual,
    # is_channel_level and is_cnn are compared against 1 inside DMNet.
    # NOTE(review): --gpu_id is not referenced in this file; .cuda() calls
    # use the current default CUDA device.
    parser = argparse.ArgumentParser(description='DM')
    parser.add_argument("--gpu_id", type=int, default=1)
    parser.add_argument("--exp_id", type=int, default=1)
    parser.add_argument("--segment_num", type=int, default=7)
    parser.add_argument("--is_training", type=int, default=1)
    parser.add_argument("--segment_length", type=int, default=500)
    parser.add_argument("--sum_up_length", type=int, default=5)
    parser.add_argument("--is_contextual", type=int, default=1)
    parser.add_argument("--is_channel_level", type=int, default=1)
    parser.add_argument("--diff_norm", type=str, default='min_max')
    parser.add_argument("--dataset_name", type=str, default='MAYO')
    parser.add_argument("--diff_sign", type=int, default=1)
    parser.add_argument("--cnn_norm", type=str, default='batch_normalization')
    parser.add_argument("--is_cnn", type=int, default=1)
    parser.add_argument("--k", type=int, default=7)
    parser.add_argument("--h_c", type=int, default=8)
    parser.add_argument("--is_dm", type=int, default=0)
    parser.add_argument("--is_fr", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--num_epochs", type=int, default=5)
    # Seed all RNGs (fixed seed 42) before the datasets, loaders and model
    # are created.
    init_seed()
    args = parser.parse_args()
    # Output directory name encodes the model/data hyper-parameters
    # (learning_rate and num_epochs are not part of it; exp_id is used in
    # the checkpoint/result file names instead).
    args.save_path = f'/data/{username}/path to result/{args.dataset_name}_sn_{args.segment_num}_sl_{args.segment_length}_sul_{args.sum_up_length}_h_c_{args.h_c}_k_{args.k}_is_contextual_{args.is_contextual}_is_channel_level_{args.is_channel_level}_diff_sign_{args.diff_sign}_diff_norm_{args.diff_norm}_is_cnn_{args.is_cnn}_cnn_norm_{args.cnn_norm}_is_dm_{args.is_dm}_is_fr_{args.is_fr}/'
    args.data_path = f'/data/{username}/path to data/{args.dataset_name}/'
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    # NOTE(review): any dataset_name other than 'MAYO' silently falls back
    # to the FNUSA experiment settings.
    exp_settings = ExpSettings_MAYO() if args.dataset_name == 'MAYO' else ExpSettings_FNUSA()
    # Each experiment id lists source (train), valid and target (test) files.
    train_filenames = exp_settings.exps[args.exp_id]['source']
    valid_filenames = exp_settings.exps[args.exp_id]['valid']
    test_filenames = exp_settings.exps[args.exp_id]['target']

    train_loader = IEEGDataset(filenames=train_filenames, args=args).get_dataloader()
    valid_loader = IEEGDataset(filenames=valid_filenames, args=args).get_dataloader()
    test_loader = IEEGDataset(filenames=test_filenames, args=args).get_dataloader()

    # One loader per test file, for the per-file report produced by test2.
    test_loaders2 = [SubDataset(filename=str(_), args=args).get_dataloader() for _ in test_filenames]

    # Positive/negative label ratio of the training set; fit() uses it as
    # the cross-entropy weight of class 0. This iterates the whole train
    # loader before the model is built, so keep it in this position.
    args.ratio = class_ratio(train_loader)

    print(args.ratio)

    print('Ratio:', args.ratio)

    model = DMNet(args).cuda()
    print("n_params: ", count_parameters(model))
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    # Training is optional; evaluation always reloads the best checkpoint
    # saved under args.save_path/ckpts.
    if args.is_training == 1:
        fit(train_loader=train_loader,
            valid_loader=valid_loader,
            model=model,
            optimizer=optimizer,
            args=args)

    test(test_loader=test_loader, model=model, args=args)
    test2(test_loaders=test_loaders2, model=model, args=args)
