import os
import sys

import moxing as mox
#
# os.system('pip install --upgrade pip')
# os.system('pip install sacred')
# os.system('pip install tqdm')
# os.system('pip install tensorboard')
# os.system('pip install mir_eval')
# os.system('pip install soundfile')
# os.system('pip install torchlibrosa')
# os.system('pip install librosa')
# os.system('pip install einops')
# current_path = os.path.dirname(__file__)
# sys.path.append(current_path)

import pandas as pd
import torch
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import numpy as np

'''1、函数导入'''
from src import MyDataset, cycle, summary, CM, FL, mae, pitchlabel2freqs, freqs2cents, DW, mse
from evaluate import evaluate

'''1、函数导入'''

# base = '/cache/whj/'
# os.makedirs(base, exist_ok=True)
# input_data = 'obs://bucket-5192-shanghai/weihaojie/CM_Music/code_END3/'
# full_data = "/cache/whj/"
# mox.file.remove(full_data, recursive=True)
# mox.file.copy_parallel(input_data, full_data)
# OBS destination that the local ./runs directory is copied to by train().
output_data = (
    'obs://bucket-5192-shanghai/weihaojie/'
    'CM_Music_0/Result0'
)


def _masked_mean(values, mask, dim):
    """Return the mean of ``values`` over the entries selected by ``mask``.

    ``mask`` holds 0/1 values and is multiplied element-wise with ``values``
    (broadcasting allowed).  The sum of that product is divided by
    ``mask.sum() * dim``.  When a mask with a trailing size-1 axis is
    broadcast over a trailing axis of ``values`` of size ``dim``, this is the
    mean over the selected frames and that axis.  An all-zero mask yields
    0 / 0 = NaN.
    """
    return (values * mask).sum() / (mask.sum() * dim)


def _correctness_groups(pitch_label, pitch_pred, v_pitch_pred, pitch_th):
    """Split frames into four groups by which pitch prediction is correct.

    A prediction counts as correct for a frame when, after conversion with
    ``pitchlabel2freqs(..., pitch_th)`` and ``freqs2cents``, it lies within
    50 cents of the label.  ``pitch_pred`` is the pitch-estimation output of
    ``CM``.  ``v_pitch_pred`` is its second pitch output (presumably pitch
    estimated from the separated vocals -- TODO confirm against ``CM``).

    Returns:
        A tuple of four float32 numpy arrays of 0/1 flags, with the shape of
        the per-frame correctness arrays (presumably (batch, frames) --
        TODO confirm):
        1. both predictions correct
        2. only ``pitch_pred`` correct
        3. only ``v_pitch_pred`` correct
        4. both wrong
    """
    target_cents = freqs2cents(pitchlabel2freqs(pitch_label, pitch_th))

    def within_50_cents(prediction):
        cents = freqs2cents(pitchlabel2freqs(prediction, pitch_th))
        return (np.abs(cents - target_cents) <= 50).astype(np.float32)

    pe_ok = within_50_cents(pitch_pred)
    v_ok = within_50_cents(v_pitch_pred)
    return (
        pe_ok * v_ok,
        pe_ok * (1 - v_ok),
        (1 - pe_ok) * v_ok,
        (1 - pe_ok) * (1 - v_ok),
    )


def _aux_loss(weight_svs, weight_pe, masks, t):
    """Auxiliary loss that trains the dynamic-weighting (DW) model.

    Both weights are compared against an all-ones reference.  For the four
    frame groups produced by ``_correctness_groups`` (in that order):

    1. both correct:            L1(w_svs, 1) + L1(w_pe, 1)
    2. only PE correct:         L1(w_svs, 1) + t * -log(sigmoid(1 - w_pe))
       (the second term pushes w_pe below 1)
    3. only v-branch correct:   t * -log(sigmoid(w_svs - 1)) + L1(w_pe, 1)
       (the first term pushes w_svs above 1)
    4. both wrong:              L1(w_svs, 1) + t * -log(sigmoid(w_pe - 1))
       (the second term pushes w_pe above 1)

    Each term is a ``_masked_mean`` over its group's frames and the weight's
    trailing axis.  Groups with no frames contribute nothing.

    Args:
        weight_svs: separation-loss weights predicted by DW.
        weight_pe: pitch-loss weights predicted by DW.
        masks: four 0/1 tensors broadcastable against both weights.
        t: multiplier applied to the -log(sigmoid(.)) terms.

    Returns:
        A 0-dim tensor (zero when every group is empty).
    """
    log_sigmoid = nn.functional.logsigmoid
    ones_svs, ones_pe = torch.ones_like(weight_svs), torch.ones_like(weight_pe)
    d_svs, d_pe = weight_svs.shape[-1], weight_pe.shape[-1]

    def l1(weight, ones, mask, dim):
        return _masked_mean(torch.abs(weight - ones), mask, dim)

    def bpr(diff, mask, dim):
        # logsigmoid is the numerically stable form of sigmoid().log(),
        # which underflows to -inf for large negative arguments.
        return -_masked_mean(log_sigmoid(diff), mask, dim) * t

    m1, m2, m3, m4 = masks
    terms = []
    if m1.sum().item():
        terms.append(l1(weight_svs, ones_svs, m1, d_svs) + l1(weight_pe, ones_pe, m1, d_pe))
    if m2.sum().item():
        terms.append(l1(weight_svs, ones_svs, m2, d_svs) + bpr(ones_pe - weight_pe, m2, d_pe))
    if m3.sum().item():
        terms.append(bpr(weight_svs - ones_svs, m3, d_svs) + l1(weight_pe, ones_pe, m3, d_pe))
    if m4.sum().item():
        terms.append(l1(weight_svs, ones_svs, m4, d_svs) + bpr(weight_pe - ones_pe, m4, d_pe))

    # Start from a tensor so ``.item()`` on the result always works.
    loss = weight_svs.new_zeros(())
    for term in terms:
        loss = loss + term
    return loss


def train(svs_type, pe_type, weight_type, t):
    """Jointly train ``CM`` (separation + pitch) with a learned weighting model ``DW``.

    Dynamic weighting ("learned DWHS"): ``DW`` takes the two pitch outputs of
    ``CM`` and predicts weights for the separation (SVS) loss and the pitch
    estimation (PE) loss.  ``DW`` is trained through ``_aux_loss``, which uses
    all-ones reference labels: an L1 pull toward 1, plus a BPR-style
    ``-log(sigmoid(.))`` push away from 1 for frames where only one branch
    predicts the pitch correctly, or neither does.

    Args:
        svs_type: separation backbone name passed to ``CM`` (e.g. 'ByteSep').
        pe_type: pitch-estimation backbone name passed to ``CM``.  'crepe'
            (case-insensitive) selects 360 pitch classes; anything else
            selects 352.
        weight_type: 'more' builds ``DW(2, 64, num_class, num_class,
            hop_length)``; any other value builds ``DW(2, 64, num_class, 1, 1)``
            (presumably per-class/per-sample weights vs. one weight per frame
            -- TODO confirm against ``DW``).
        t: multiplier on the BPR terms of the auxiliary loss.

    Side effects:
        Writes TensorBoard events, weight_pe.csv, weight_svs.csv, nums.csv,
        result.txt and *.pt checkpoints under ``logdir``.  Copies ./runs to
        ``output_data`` after the first epoch, every 10 epochs and on early
        stop, and deletes the local *.pt files after each copy.
    """
    alpha = 5
    gamma = 2
    scale = 2
    logdir = 'runs/MIR1K0/' + svs_type + '_' + pe_type + '_J1_DW_10lr_BPR_w' + str(t) + '_' + str(scale) + '_' + weight_type
    num_class = 360 if pe_type.lower() == 'crepe' else 352
    seq_l = 2.56
    pitch_th = 0.5
    hop_length = 320
    learning_rate = 1e-3
    # DW trains with a 10x larger learning rate (the "10lr" tag in logdir).
    dw_learning_rate = learning_rate * 10
    batch_size = 16
    clip_grad_norm = 3
    learning_rate_decay_rate = 0.95
    learning_rate_decay_epochs = 3
    train_epochs = 250
    early_stop_epochs = 5
    data_para = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    columns = ['itter', '1', '2', '3', '4']
    # Per-iteration rows are buffered and written out once per epoch, before
    # each upload and on exit.  Rewriting every CSV on every step made
    # logging cost grow quadratically with the number of iterations.
    pe_rows, svs_rows, num_rows = [], [], []

    train_dataset = MyDataset(path='./dataset/MIR1K', hop_length=hop_length, groups=['train'],
                              sequence_length=seq_l, num_class=num_class, pitch_th=pitch_th)
    print('train nums:', len(train_dataset))
    valid_dataset = MyDataset(path='./dataset/MIR1K', hop_length=hop_length, groups=['test'],
                              sequence_length=None, num_class=num_class, pitch_th=pitch_th)
    print('valid nums:', len(valid_dataset))
    data_loader = DataLoader(train_dataset, batch_size, shuffle=True, drop_last=True)
    epoch_nums = len(data_loader)  # iterations per epoch
    print('epoch_nums:', epoch_nums)
    learning_rate_decay_steps = epoch_nums * learning_rate_decay_epochs
    iterations = epoch_nums * train_epochs

    resume_iteration = None
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)

    def flush_logs():
        """Write the buffered per-iteration rows to the three CSV files."""
        for rows, name in ((pe_rows, 'weight_pe.csv'),
                           (svs_rows, 'weight_svs.csv'),
                           (num_rows, 'nums.csv')):
            pd.DataFrame(rows, columns=columns).to_csv(os.path.join(logdir, name), index=False)

    def sync_to_obs():
        """Flush logs, copy ./runs to OBS, then delete local *.pt checkpoints."""
        flush_logs()
        writer.flush()  # make sure the uploaded event files are current
        mox.file.copy_parallel('./runs', output_data)
        # NOTE(review): this also removes the best checkpoint locally, so a
        # resume from logdir needs the files copied back from OBS first.
        for name in os.listdir(logdir):
            if name.endswith('.pt'):
                os.remove(os.path.join(logdir, name))

    # 4. Model setup.
    if resume_iteration is None:
        model = CM(1, 1024, hop_length, svs_type, pe_type)
        if weight_type == 'more':
            dw_model = DW(2, 64, num_class, num_class, hop_length)
        else:
            dw_model = DW(2, 64, num_class, 1, 1)
        if data_para:
            model = nn.DataParallel(model).to(device)
            dw_model = nn.DataParallel(dw_model).to(device)
        else:
            model = model.to(device)
            dw_model = dw_model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        optimizer_dw = torch.optim.Adam(dw_model.parameters(), dw_learning_rate)
        resume_iteration = 0
    else:
        # NOTE(review): whole pickled modules are loaded.  Recent PyTorch
        # releases default torch.load to weights_only=True, which rejects
        # them; pass weights_only=False there if needed.
        model = torch.load(os.path.join(logdir, f'model-{resume_iteration}.pt'))
        # The DW model has its own checkpoint (it was previously re-loaded
        # from the CM checkpoint by mistake).
        dw_model = torch.load(os.path.join(logdir, 'dw_model.pt'))
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        optimizer_dw = torch.optim.Adam(dw_model.parameters(), dw_learning_rate)
        optimizer.load_state_dict(torch.load(os.path.join(logdir, 'last-optimizer-state.pt')))
        dw_state_path = os.path.join(logdir, 'last-optimizer-dw-state.pt')
        if os.path.exists(dw_state_path):
            optimizer_dw.load_state_dict(torch.load(dw_state_path))
        # NOTE(review): the StepLR schedules below restart from step 0 on resume.

    scheduler = StepLR(optimizer, step_size=learning_rate_decay_steps, gamma=learning_rate_decay_rate)
    scheduler_dw = StepLR(optimizer_dw, step_size=learning_rate_decay_steps, gamma=learning_rate_decay_rate)
    summary(model)
    # Best validation scores so far and the iteration ``it`` they were reached at.
    SDR, RPA, GNSDR, RCA, it = 0, 0, 0, 0, 0
    loop = tqdm(range(resume_iteration + 1, iterations + 1))

    try:
        for i, data in zip(loop, cycle(data_loader)):
            audio_m = data['audio_m'].to(device)
            audio_v = data['audio_v'].to(device)
            pitch_label = data['pitch'].to(device)
            out_audio, pitch_pred, v_pitch_pred, out_spec, spec_v = model(audio_m, audio_v)
            weight_pe, weight_svs = dw_model(torch.stack([pitch_pred, v_pitch_pred], dim=1), scale=scale)

            # Dynamic weighting: group frames by which pitch prediction is correct.
            groups = _correctness_groups(pitch_label, pitch_pred, v_pitch_pred, pitch_th)
            masks = [torch.from_numpy(g).unsqueeze(-1).to(device) for g in groups]

            # Logging only: mean predicted weight per group (NaN for an empty group).
            d_pe, d_svs = weight_pe.shape[-1], weight_svs.shape[-1]
            with torch.no_grad():
                pe_means = [round(_masked_mean(weight_pe, m, d_pe).item(), 2) for m in masks]
                svs_means = [round(_masked_mean(weight_svs, m, d_svs).item(), 2) for m in masks]
            group_sizes = [np.sum(g) for g in groups]
            pe_rows.append([i, *pe_means])
            svs_rows.append([i, *svs_means])
            num_rows.append([i, *group_sizes])
            print(' \t '.join(map(str, group_sizes)), end='\t')

            loss_aux = _aux_loss(weight_svs, weight_pe, masks, t)

            # 5. Losses.  Repeat a per-frame SVS weight along the last axis up
            # to hop_length before flattening for the waveform loss
            # (presumably one weight per output sample -- TODO confirm mae).
            if weight_svs.shape[-1] != hop_length:
                weight_svs = weight_svs.repeat((1, 1, hop_length))
            loss_svs = mae(out_audio, audio_v, torch.flatten(weight_svs, start_dim=1))
            loss_pitch = FL(pitch_pred, pitch_label, alpha, gamma, weight_pe)
            loss_total = loss_svs + loss_pitch + loss_aux
            print(i, end='\t')
            print('loss_svs:{:.6f}'.format(loss_svs.item()), end='\t')
            print('loss_pe:{:.6f}'.format(loss_pitch.item()), end='\t')
            print('loss_aux:{:.6f}'.format(loss_aux.item()))

            optimizer.zero_grad()
            optimizer_dw.zero_grad()
            loss_total.backward()
            if clip_grad_norm:
                clip_grad_norm_(model.parameters(), clip_grad_norm)
                clip_grad_norm_(dw_model.parameters(), clip_grad_norm)
            optimizer.step()
            optimizer_dw.step()
            scheduler.step()
            scheduler_dw.step()

            writer.add_scalar('loss/loss_total', loss_total.item(), global_step=i)
            writer.add_scalar('loss/loss_svs', loss_svs.item(), global_step=i)
            writer.add_scalar('loss/loss_pe', loss_pitch.item(), global_step=i)

            # 6. End-of-epoch validation; keep the checkpoint with the best SDR + RPA.
            if i % epoch_nums == 0:
                flush_logs()
                print('*' * 50)
                print(i, '\t', epoch_nums)
                model.eval()
                with torch.no_grad():
                    # evaluate(dataset, model, batch_size, hop_length, seq_l, device, path=None, pitch_th=0.5)
                    metrics = evaluate(valid_dataset, model, batch_size, hop_length, seq_l, device, None, pitch_th)
                    for key, value in metrics.items():
                        writer.add_scalar('validation/' + key, np.mean(value), global_step=i)
                    # GNSDR: NSDR_W summed over clips, divided by total LENGTH.
                    gnsdr = np.round((np.sum(metrics["NSDR_W"]) / np.sum(metrics["LENGTH"])), 2)
                    writer.add_scalar('validation/GNSDR', gnsdr, global_step=i)
                    sdr = np.round(np.mean(metrics['SDR']), 2)
                    rpa = np.round(np.mean(metrics['RPA']) * 100, 2)
                    rca = np.round(np.mean(metrics['RCA']) * 100, 2)
                    if sdr + rpa >= SDR + RPA:
                        SDR, GNSDR, RPA, RCA, it = sdr, gnsdr, rpa, rca, i
                        with open(os.path.join(logdir, 'result.txt'), 'a') as f:
                            f.write('\t'.join(map(str, (i, SDR, GNSDR, RPA, RCA))) + '\n')
                        torch.save(model, os.path.join(logdir, f'model-{i}.pt'))
                        torch.save(dw_model, os.path.join(logdir, 'dw_model.pt'))
                        torch.save(optimizer.state_dict(), os.path.join(logdir, 'last-optimizer-state.pt'))
                        torch.save(optimizer_dw.state_dict(), os.path.join(logdir, 'last-optimizer-dw-state.pt'))
                model.train()

            # Upload after the first epoch and then every 10 epochs.
            if i % (epoch_nums * 10) == 0 or i == epoch_nums:
                sync_to_obs()

            # Early stop when validation has not improved for early_stop_epochs epochs.
            if i - it >= epoch_nums * early_stop_epochs:
                sync_to_obs()
                break
    finally:
        flush_logs()
        loop.close()
        writer.close()


# Sweep the multiplier applied to the BPR terms of the auxiliary weighting loss.
for bpr_t in (1, 5, 10):
    train('ByteSep', 'CREPE', 'more', bpr_t)
    # train('ByteSep', 'CREPE', 'one', bpr_t)
# Alternative single-run configuration, kept for reference:
# mox.file.copy_parallel('./runs', output_data)
#
# train('ByteSep', 'CREPE', 'one')
# mox.file.copy_parallel('./runs', output_data)
