import os

import torch
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import numpy as np

'''1、函数导入'''
from src import MyDataset, cycle, summary, CM, FL, mae
from evaluate import evaluate
'''1、函数导入'''


def _validate(valid_dataset, model, writer, step, batch_size, hop_length, seq_l, device, pitch_th):
    """Run one validation pass, log its metrics to TensorBoard and return the summary scores.

    Every metric returned by ``evaluate`` is logged as ``validation/<key>`` (its mean),
    plus ``validation/GNSDR``. The model is put back into training mode afterwards,
    even if evaluation raises.

    Returns:
        Tuple ``(sdr, gnsdr, rpa, rca)`` rounded to two decimals; ``rpa`` and ``rca``
        are the mean RPA/RCA multiplied by 100.
    """
    model.eval()
    try:
        with torch.no_grad():
            # evaluate(dataset, model, batch_size, hop_length, seq_l, device, path=None, pitch_th=0.5)
            metrics = evaluate(valid_dataset, model, batch_size, hop_length, seq_l, device, None,
                               pitch_th=pitch_th)
    finally:
        model.train()

    for key, value in metrics.items():
        writer.add_scalar('validation/' + key, np.mean(value), global_step=step)
    # Sum of NSDR_W divided by total LENGTH; presumably NSDR_W is per-track NSDR
    # already weighted by track length — confirm against evaluate().
    gnsdr = np.round(np.sum(metrics['NSDR_W']) / np.sum(metrics['LENGTH']), 2)
    writer.add_scalar('validation/GNSDR', gnsdr, global_step=step)
    sdr = np.round(np.mean(metrics['SDR']), 2)
    rpa = np.round(np.mean(metrics['RPA']) * 100, 2)
    rca = np.round(np.mean(metrics['RCA']) * 100, 2)
    return sdr, gnsdr, rpa, rca


def _append_result(logdir, step, sdr, gnsdr, rpa, rca):
    """Append one tab-separated row ``step SDR GNSDR RPA RCA`` to ``<logdir>/result.txt``."""
    row = '\t'.join(str(v) for v in (step, sdr, gnsdr, rpa, rca))
    with open(os.path.join(logdir, 'result.txt'), 'a') as f:
        f.write(row + '\n')


def train_J0(svs_type, pe_type):
    """Naive joint learning of singing-voice separation and pitch estimation.

    The separation loss (``mae`` between ``out_audio`` and the target vocals
    ``audio_v``) and the pitch loss (``FL`` with ``alpha``/``gamma``) are simply
    summed. The model is validated once per epoch. Whenever the validation
    ``SDR + RPA`` does not decrease, the whole model and the optimizer state are
    saved under ``logdir`` and a row is appended to ``result.txt``. Training stops
    after ``train_epochs`` epochs, or after ``early_stop_epochs`` epochs without a
    new best checkpoint.

    Args:
        svs_type: separation backbone name forwarded to ``CM`` (e.g. 'UNet', 'ByteSep').
        pe_type: pitch-estimator name forwarded to ``CM`` (e.g. 'HARMOF0', 'CREPE').
            'crepe' (case-insensitive) selects 360 pitch classes, anything else 352.

    Raises:
        ValueError: if the training set yields no full batch (``drop_last=True``).
    """
    alpha = 5  # FL hyper-parameters
    gamma = 2
    logdir = os.path.join('runs', 'MDB', f'{svs_type}_{pe_type}_test')
    num_class = 360 if pe_type.lower() == 'crepe' else 352
    seq_l = 2.56
    pitch_th = 0.5
    hop_length = 320
    learning_rate = 1e-3
    batch_size = 16
    clip_grad_norm = 3
    learning_rate_decay_rate = 0.95
    learning_rate_decay_epochs = 3
    train_epochs = 250
    early_stop_epochs = 25
    data_para = True
    # Set to a saved iteration number to resume from '<logdir>/model-<n>.pt'.
    resume_iteration = None

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_dataset = MyDataset(path='./dataset/MDB', hop_length=hop_length, groups=['train'],
                              sequence_length=seq_l, num_class=num_class)
    print('train nums:', len(train_dataset))
    # NOTE(review): validation also reads the 'train' group, so the validation
    # metrics (and the checkpoint selection / early stopping they drive) are
    # measured on training data. Confirm whether a held-out group is intended.
    valid_dataset = MyDataset(path='./dataset/MDB', hop_length=hop_length, groups=['train'],
                              sequence_length=None, num_class=num_class)
    print('valid nums:', len(valid_dataset))
    data_loader = DataLoader(train_dataset, batch_size, shuffle=True, drop_last=True)
    epoch_nums = len(data_loader)  # optimizer steps per epoch
    print('epoch_nums:', epoch_nums)
    if epoch_nums == 0:
        # With drop_last=True a dataset smaller than one batch yields no batches;
        # StepLR(step_size=0) and `i % epoch_nums` would then divide by zero.
        raise ValueError(f'training set ({len(train_dataset)} items) is smaller than '
                         f'batch_size={batch_size}')
    learning_rate_decay_steps = epoch_nums * learning_rate_decay_epochs
    iterations = epoch_nums * train_epochs

    os.makedirs(logdir, exist_ok=True)

    if resume_iteration is None:
        # Model setup
        model = CM(1, 1024, hop_length, svs_type, pe_type)
        model = nn.DataParallel(model).to(device) if data_para else model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        resume_iteration = 0
    else:
        # Checkpoints are whole pickled modules (see torch.save below); map_location
        # lets a GPU checkpoint be resumed on CPU.
        # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which
        # rejects pickled modules; pass weights_only=False there if needed.
        # 'last-optimizer-state.pt' is overwritten at every new best checkpoint,
        # so it matches the most recent model-<n>.pt, not necessarily this one.
        model_path = os.path.join(logdir, f'model-{resume_iteration}.pt')
        model = torch.load(model_path, map_location=device)
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        optimizer.load_state_dict(torch.load(os.path.join(logdir, 'last-optimizer-state.pt'),
                                             map_location=device))

    # NOTE(review): on resume the scheduler's step counter restarts at 0, so the
    # decay schedule is shifted relative to the original run.
    scheduler = StepLR(optimizer, step_size=learning_rate_decay_steps, gamma=learning_rate_decay_rate)
    summary(model)

    writer = SummaryWriter(logdir)
    best_score = float('-inf')  # best validation SDR + RPA so far
    it = resume_iteration  # iteration of the last new-best checkpoint; drives early stopping
    model.train()  # a resumed checkpoint may have been pickled in eval mode
    loop = tqdm(range(resume_iteration + 1, iterations + 1))
    try:
        for i, data in zip(loop, cycle(data_loader)):
            audio_m = data['audio_m'].to(device)
            audio_v = data['audio_v'].to(device)
            pitch_label = data['pitch'].to(device)
            # Ignored outputs: v_pitch_pred, out_spec, spec_v.
            out_audio, pitch_pred, _, _, _ = model(audio_m, audio_v)

            # Loss: separation + pitch estimation, unweighted sum.
            loss_svs = mae(out_audio, audio_v)
            loss_pitch = FL(pitch_pred, pitch_label, alpha, gamma)
            loss_total = loss_svs + loss_pitch

            optimizer.zero_grad()
            loss_total.backward()
            if clip_grad_norm:
                clip_grad_norm_(model.parameters(), clip_grad_norm)
            optimizer.step()
            scheduler.step()

            losses = {
                'loss_total': loss_total.item(),
                'loss_svs': loss_svs.item(),
                'loss_pe': loss_pitch.item(),
            }
            # tqdm.write keeps the per-step log from corrupting the progress bar.
            tqdm.write(f'{i}\t' + '\t'.join(f'{k}:{v:.6f}' for k, v in losses.items()))
            for name, value in losses.items():
                writer.add_scalar(f'loss/{name}', value, global_step=i)

            if i % epoch_nums == 0:
                tqdm.write('*' * 50)
                tqdm.write(f'{i} \t {epoch_nums}')
                sdr, gnsdr, rpa, rca = _validate(valid_dataset, model, writer, i, batch_size,
                                                 hop_length, seq_l, device, pitch_th)
                # Checkpoint whenever the combined score does not get worse.
                if sdr + rpa >= best_score:
                    best_score, it = sdr + rpa, i
                    _append_result(logdir, i, sdr, gnsdr, rpa, rca)
                    torch.save(model, os.path.join(logdir, f'model-{i}.pt'))
                    torch.save(optimizer.state_dict(), os.path.join(logdir, 'last-optimizer-state.pt'))

            # Early stopping: no new best checkpoint for early_stop_epochs epochs.
            if i - it >= epoch_nums * early_stop_epochs:
                break
    finally:
        loop.close()
        writer.close()


if __name__ == '__main__':
    # Guarded so that importing this module (e.g. to reuse train_J0) does not
    # start a full training run.
    train_J0('ByteSep', 'CREPE')
