import os
# import sys

# import moxing as mox
#
# os.system('pip install --upgrade pip')
# os.system('pip install sacred')
# os.system('pip install tqdm')
# os.system('pip install tensorboard')
# os.system('pip install mir_eval')
# os.system('pip install soundfile')
# os.system('pip install torchlibrosa')
# os.system('pip install librosa')
# os.system('pip install einops')
# current_path = os.path.dirname(__file__)
# sys.path.append(current_path)

import pandas as pd
import torch
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F

'''1、函数导入'''
from src import MyDataset, cycle, summary, CM, FL, mae, pitchlabel2freqs, freqs2cents
from evaluate import evaluate

'''1、函数导入'''

# base = '/cache/whj/'
# os.makedirs(base, exist_ok=True)
# input_data = 'obs://bucket-5192-shanghai/weihaojie/CM_Music/code_END3/'
# full_data = "/cache/whj/"
# mox.file.remove(full_data, recursive=True)
# mox.file.copy_parallel(input_data, full_data)
# OBS (object storage) destination for syncing the ./runs results.
# Currently referenced only by commented-out `mox.file.copy_parallel(...)` calls.
output_data = 'obs://bucket-5192-shanghai/weihaojie/CM_Music_0/Result'


def _dynamic_weights(pitch_label, pitch_pred, v_pitch_pred, pitch_th, p_weight, n_weight, hop_length, device):
    """Compute the DWHS frame categories and the per-element loss weights.

    A frame's pitch estimate counts as correct when it lies within 50 cents of
    the pitch decoded from ``pitch_label``. Two estimates are checked:
    ``pitch_pred`` (the PE output) and ``v_pitch_pred`` (presumably the pitch
    estimated from the separated vocal -- confirm against ``CM``).

    Returns:
        categories: 4-tuple of float64 0/1 NumPy arrays, one per frame category:
            1. both estimates correct, 2. only ``pitch_pred`` correct,
            3. only ``v_pitch_pred`` correct, 4. both wrong.
        weight_pe: tensor multiplying the pitch loss. Category-4 frames get the
            per-element weight ``clamp(1 / pt, 1, p_weight)``, category-2 frames
            get ``n_weight``, every other frame gets 1.
        weight_svs: tensor with a trailing axis of length ``hop_length``
            multiplying the separation loss. Category-3 frames get the frame's
            mean ``clamp(1 / pt, 1, p_weight)``, every other frame gets 1.
    """
    cents = freqs2cents(pitchlabel2freqs(pitch_label, pitch_th))
    cents_pred = freqs2cents(pitchlabel2freqs(pitch_pred, pitch_th))
    v_cents = freqs2cents(pitchlabel2freqs(v_pitch_pred, pitch_th))

    # np.float was removed in NumPy 1.24; np.float64 is the identical dtype.
    pitch_ok = (np.abs(cents_pred - cents) <= 50).astype(np.float64)
    v_ok = (np.abs(v_cents - cents) <= 50).astype(np.float64)
    categories = (
        pitch_ok * v_ok,  # 1: both correct
        pitch_ok * (1 - v_ok),  # 2: only the PE estimate correct
        (1 - pitch_ok) * v_ok,  # 3: only the vocal-derived estimate correct
        (1 - pitch_ok) * (1 - v_ok),  # 4: both wrong
    )

    # pt = probability assigned to the target (pred if label is 1, 1 - pred if
    # label is 0); detached so the weight itself carries no gradient.
    pt = (pitch_label * pitch_pred + (1 - pitch_label) * (1 - pitch_pred)).detach()
    d_weight = torch.clamp(1 / pt, 1, p_weight)

    # PE weight: d_weight on category 4, 1 elsewhere ...
    weight_pe = torch.from_numpy(categories[3]).to(device)
    weight_pe = weight_pe.unsqueeze(-1) * d_weight
    weight_pe[weight_pe == 0] = 1
    # ... then zero out category 2 and replace those zeros with n_weight.
    noise_index = torch.from_numpy(1 - categories[1]).to(device)
    weight_pe = noise_index.unsqueeze(-1) * weight_pe
    weight_pe[weight_pe == 0] = n_weight

    # SVS weight: frame-mean d_weight on category 3, 1 elsewhere, repeated
    # hop_length times (presumably one weight per audio sample of the frame,
    # since the caller flattens it and passes it to `mae` with the waveform).
    weight_svs = torch.from_numpy(categories[2]).to(device)
    weight_svs = (weight_svs * torch.mean(d_weight, dim=-1)).unsqueeze(-1)
    weight_svs[weight_svs == 0] = 1
    weight_svs = weight_svs.repeat((1, 1, hop_length))

    return categories, weight_pe, weight_svs


def _category_means(weight, indices, scale):
    """Return the mean of ``weight`` over each category mask, rounded to 2 d.p.

    The denominator is ``(#frames in the mask) * scale``, where ``scale`` is the
    number of entries each frame contributes (num_class for the PE weights,
    hop_length for the SVS weights). An empty mask yields NaN (0 / 0).
    """
    return [round((torch.sum(weight * idx) / (torch.sum(idx) * scale)).item(), 2) for idx in indices]


def _write_csv_row(path, columns, row, first):
    """Write one row to the CSV at ``path``.

    On the first step the file is truncated and the header is written.
    Afterwards rows are appended. This replaces rewriting the whole log every
    iteration, whose cost grew quadratically with training length.
    """
    pd.DataFrame([row], columns=columns).to_csv(
        path, mode='w' if first else 'a', header=first, index=False)


def train_J1(svs_type, pe_type, p_weight, n_weight):
    """Train ``CM`` jointly for singing-voice separation (SVS) and pitch
    estimation (PE) on MDB with dynamic weighting (naive DWHS).

    Args:
        svs_type: SVS backbone name forwarded to ``CM`` (e.g. 'UNet', 'ByteSep').
        pe_type: PE backbone name forwarded to ``CM``. 'crepe' (any case)
            selects 360 pitch classes, anything else selects 352.
        p_weight: upper bound of the per-element weight ``1 / pt``.
        n_weight: pitch-loss weight for frames where only the PE estimate is
            correct.

    Side effects: writes TensorBoard events, ``weight_pe.csv``,
    ``weight_svs.csv``, ``nums.csv``, ``result.txt`` and the best checkpoints
    into ``runs/MDB/<svs>_<pe>_J1_<p_weight>_<n_weight>``.
    """
    alpha = 5  # forwarded to FL (presumably the focal-loss alpha)
    gamma = 2  # forwarded to FL (presumably the focal-loss gamma)
    logdir = 'runs/MDB/' + svs_type + '_' + pe_type + '_J1_' + str(p_weight) + '_' + str(n_weight)
    num_class = 360 if pe_type.lower() == 'crepe' else 352
    seq_l = 2.56
    pitch_th = 0.5
    hop_length = 320
    learning_rate = 1e-3
    batch_size = 16
    clip_grad_norm = 3
    learning_rate_decay_rate = 0.95
    learning_rate_decay_epochs = 3
    train_epochs = 250
    early_stop_epochs = 25
    data_para = False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    columns = ['itter', '1', '2', '3', '4']  # CSV header, kept verbatim for downstream readers
    csv_pe = os.path.join(logdir, 'weight_pe.csv')
    csv_svs = os.path.join(logdir, 'weight_svs.csv')
    csv_nums = os.path.join(logdir, 'nums.csv')

    train_dataset = MyDataset(path='./dataset/MDB', hop_length=hop_length, groups=['train'],
                              sequence_length=seq_l, num_class=num_class)
    print('train nums:', len(train_dataset))
    valid_dataset = MyDataset(path='./dataset/MDB', hop_length=hop_length, groups=['test'],
                              sequence_length=None, num_class=num_class)
    print('valid nums:', len(valid_dataset))
    data_loader = DataLoader(train_dataset, batch_size, shuffle=True, drop_last=True)
    epoch_nums = len(data_loader)
    print('epoch_nums:', epoch_nums)
    learning_rate_decay_steps = epoch_nums * learning_rate_decay_epochs
    iterations = epoch_nums * train_epochs

    resume_iteration = None
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)
    try:
        if resume_iteration is None:
            model = CM(1, 1024, hop_length, svs_type, pe_type)
            if data_para:
                model = nn.DataParallel(model).to(device)
            else:
                model = model.to(device)
            optimizer = torch.optim.Adam(model.parameters(), learning_rate)
            resume_iteration = 0
        else:
            # NOTE(review): unreachable while resume_iteration is hard-coded to
            # None. Loading a whole pickled model may also need
            # weights_only=False on recent PyTorch. The scheduler state is not restored.
            model_path = os.path.join(logdir, f'model-{resume_iteration}.pt')
            model = torch.load(model_path)
            optimizer = torch.optim.Adam(model.parameters(), learning_rate)
            optimizer.load_state_dict(torch.load(os.path.join(logdir, 'last-optimizer-state.pt')))

        scheduler = StepLR(optimizer, step_size=learning_rate_decay_steps, gamma=learning_rate_decay_rate)
        summary(model)
        SDR, RPA, GNSDR, RCA, it = 0, 0, 0, 0, 0
        loop = tqdm(range(resume_iteration + 1, iterations + 1))

        for i, data in zip(loop, cycle(data_loader)):
            audio_m = data['audio_m'].to(device)
            audio_v = data['audio_v'].to(device)
            pitch_label = data['pitch'].to(device)
            out_audio, pitch_pred, v_pitch_pred, _, _ = model(audio_m, audio_v)

            # Dynamic weighting and per-category bookkeeping.
            categories, weight_pe, weight_svs = _dynamic_weights(
                pitch_label, pitch_pred, v_pitch_pred, pitch_th, p_weight, n_weight, hop_length, device)
            indices = [torch.from_numpy(c).unsqueeze(-1).to(device) for c in categories]
            first = i == resume_iteration + 1  # overwrite logs left over from an earlier run
            _write_csv_row(csv_pe, columns, [i] + _category_means(weight_pe, indices, num_class), first)
            _write_csv_row(csv_svs, columns, [i] + _category_means(weight_svs, indices, hop_length), first)
            sums = [np.sum(c) for c in categories]
            _write_csv_row(csv_nums, columns, [i] + sums, first)
            print(*sums, sep=' \t ', end='\t')

            weight_svs = torch.flatten(weight_svs, start_dim=1)
            loss_svs = mae(out_audio, audio_v, weight_svs)
            loss_pitch = FL(pitch_pred, pitch_label, alpha, gamma, weight_pe)
            loss_total = loss_svs + loss_pitch
            print(i, end='\t')
            print('loss_total:{:.6f}'.format(loss_total.item()), end='\t')
            print('loss_svs:{:.6f}'.format(loss_svs.item()), end='\t')
            print('loss_pe:{:.6f}'.format(loss_pitch.item()))

            optimizer.zero_grad()
            loss_total.backward()
            if clip_grad_norm:
                clip_grad_norm_(model.parameters(), clip_grad_norm)
            optimizer.step()
            scheduler.step()

            writer.add_scalar('loss/loss_total', loss_total.item(), global_step=i)
            writer.add_scalar('loss/loss_svs', loss_svs.item(), global_step=i)
            writer.add_scalar('loss/loss_pe', loss_pitch.item(), global_step=i)

            # Validate once per epoch and keep the checkpoint with the best SDR + RPA.
            if i % epoch_nums == 0:
                print('*' * 50)
                print(i, '\t', epoch_nums)
                model.eval()
                with torch.no_grad():
                    # evaluate(dataset, model, batch_size, hop_length, seq_l, device, path, pitch_th)
                    # NOTE(review): pitch_th=0 here vs 0.5 in training -- confirm intended.
                    metrics = evaluate(valid_dataset, model, batch_size, hop_length, seq_l, device, None, 0)
                    for key, value in metrics.items():
                        writer.add_scalar('validation/' + key, np.mean(value), global_step=i)
                    gnsdr = np.round((np.sum(metrics["NSDR_W"]) / np.sum(metrics["LENGTH"])), 2)
                    writer.add_scalar('validation/GNSDR', gnsdr, global_step=i)
                    sdr = np.round(np.mean(metrics['SDR']), 2)
                    rpa = np.round(np.mean(metrics['RPA']) * 100, 2)
                    rca = np.round(np.mean(metrics['RCA']) * 100, 2)
                    if sdr + rpa >= SDR + RPA:
                        SDR, GNSDR, RPA, RCA, it = sdr, gnsdr, rpa, rca, i
                        with open(os.path.join(logdir, 'result.txt'), 'a') as f:
                            f.write('\t'.join(str(v) for v in (i, SDR, GNSDR, RPA, RCA)) + '\n')
                        torch.save(model, os.path.join(logdir, f'model-{i}.pt'))
                        torch.save(optimizer.state_dict(), os.path.join(logdir, 'last-optimizer-state.pt'))
                model.train()

            # NOTE: on ModelArts the results were synced to `output_data` here via
            # mox.file.copy_parallel('./runs', output_data).

            # Early stop when the best checkpoint has not improved for early_stop_epochs epochs.
            if i - it >= epoch_nums * early_stop_epochs:
                break
    finally:
        # Flush pending TensorBoard events even if training raises or stops early.
        writer.close()


if __name__ == '__main__':
    # Run both SVS backbones with the CREPE pitch estimator. The guard keeps
    # importing this module from launching two full training runs.
    train_J1('UNet', 'CREPE', 5, 0.2)
    train_J1('ByteSep', 'CREPE', 5, 0.2)

    # mox.file.copy_parallel('./runs', output_data)
