import os
import torch
import torch.nn as nn
import os.path as osp
import torch.nn.functional as F
from tqdm.auto import tqdm
from utils import *
from svdiffusion import SVDiffusion
from unet1d import Unet1d
import numpy as np
import csv
from data import monitor_erp
from torch.utils.data import DataLoader
from itertools import chain
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt



class Trainer():
    """Fine-tunes the diffuser's U-Net on batches of U matrices and evaluates it.

    The optimizer is AdamW (lr=5e-4, weight_decay=1e-2) with a cosine-annealing
    schedule spanning ``epochs * len(train_loader)`` optimizer steps.
    Checkpoints and the training log are written below
    ``results_monitor_erp_iclr/{pretrain_path}/{file}``.

    NOTE: the evaluation methods read the module-level ``args`` object (created
    in the ``__main__`` block of this file) for the mask ratio, and rely on
    helpers pulled in through ``from utils import *``.
    """

    def __init__(self, file, mask_ratio, epochs, diffuser, train_loader, device=None, pretrain_path=None):
        """Wire up the diffuser, output directory, optimizer and LR scheduler.

        Args:
            file: Identifier used as the last component of the output directory.
            mask_ratio: Stored on the instance for reference. Evaluation still
                uses ``args.mask`` so existing results are unchanged.
            epochs: Number of training epochs.
            diffuser: Object exposing ``time_steps``, ``forward``, ``model`` and
                ``sampling_sequence``.
            train_loader: Sized iterable of training batches.
            device: Device the sampled timesteps and evaluation tensors are moved to.
            pretrain_path: Name used as the middle component of the output directory.
        """
        self.device = device
        self.diffuser = diffuser
        self.T = self.diffuser.time_steps
        self.forward_diffusion_sample = self.diffuser.forward
        self.unet = self.diffuser.model
        self.sampler = self.diffuser.sampling_sequence
        self.mask_ratio = mask_ratio

        self.model_save_dir = f"results_monitor_erp_iclr/{pretrain_path}/{file}"
        # exist_ok avoids the check-then-create race when several runs start together.
        os.makedirs(self.model_save_dir, exist_ok=True)

        self.epochs = epochs
        self.train_loader = train_loader

        self.optimizer = torch.optim.AdamW(self.unet.parameters(), lr=5e-4, weight_decay=1e-2)
        # The scheduler is stepped once per batch, hence T_max counts batches, not epochs.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.epochs * len(self.train_loader))

    def get_loss(self, U, t):
        """Return the L1 loss between the injected noise and the U-Net's prediction.

        Args:
            U: Clean batch passed to the diffuser's forward process.
            t: Per-sample timesteps.
        """
        U_noisy, noise = self.forward_diffusion_sample(U, t)
        noise_pred = self.unet(U_noisy, t)
        return F.l1_loss(noise, noise_pred)

    def save_model_weight(self, epoch):
        """Save the U-Net state dict to ``{model_save_dir}/model_{epoch}.pt``."""
        torch.save({
            'unet': self.unet.state_dict()
        }, f'{self.model_save_dir}/model_{epoch}.pt')

    def train(self):
        """Run the training loop, logging the summed batch loss of every epoch.

        A checkpoint is written whenever the epoch number is a multiple of 50.
        """
        logfilename = osp.join(self.model_save_dir, 'train.log')
        logger = get_logger(logfilename)
        # NOTE(review): ravg is never fed any values in this method, so
        # ravg.info() only contributes whatever an empty average reports.
        ravg = RunningAverage()

        for epoch in tqdm(range(1, self.epochs+1)):
            epoch_loss = 0
            for data in self.train_loader:
                self.optimizer.zero_grad()
                eeg = data.float()
                # One random diffusion timestep in [0, T) per sample.
                t = torch.randint(0, self.T, (data.shape[0],)).to(self.device).long()

                loss = self.get_loss(eeg, t)
                loss.backward()
                epoch_loss += loss.item()
                self.optimizer.step()
                self.scheduler.step()

            if epoch % 50 == 0:
                self.save_model_weight(epoch)

            line = f'epoch_loss: {epoch_loss} '
            line += ravg.info()
            logger.info(line)
            ravg.reset()

    def _evaluate(self, eval_loader, model_path, print_target_shape=False):
        """Shared evaluation loop behind ``calculate`` and ``calculate_no_finetune``.

        Returns the ``(pcc, rrmse, rsnr, ssim)`` tuple of the LAST batch only,
        matching the original behaviour; callers are expected to pass a loader
        that yields a single batch.

        Raises:
            ValueError: If ``eval_loader`` yields no batches.
        """
        metrics = None
        with torch.no_grad():
            for data in eval_loader:
                eeg = data.float()
                B = eeg.shape[0]
                batch = get_batch(eeg, mask_ratio=args.mask, dataset='monitor')
                xc = batch.xc.to(self.device)
                batch.xt = batch.xt.to(self.device)
                _, _, VT = torch.linalg.svd(xc, full_matrices=True)
                # Keep the first 64 right-singular vectors of the context signal.
                VT = VT[:, :64, :]

                # Sampling starts from Gaussian noise shaped (batch, 95, 95).
                U0 = torch.randn((B, 95, 95))
                U = self.sampler(batch, VT, model_path, U0)

                xt_pred = torch.matmul(U, VT)
                print("xt_pred:", xt_pred.shape)
                if print_target_shape:
                    print(batch.yt.shape)

                # Gather, per sample, the rows selected by batch.yt.
                yt = batch.yt.to(self.device)
                indices = torch.arange(B).unsqueeze(1).expand(-1, yt.shape[1]).to(self.device)
                xt_true = batch.xt[indices, yt]
                xt_pred = xt_pred[indices, yt]
                metrics = (
                    compute_pcc(xt_true, xt_pred),
                    compute_rrmse(xt_true, xt_pred),
                    compute_rsnr(xt_true, xt_pred),
                    compute_ssim(xt_true, xt_pred),
                )

        if metrics is None:
            raise ValueError("eval_loader yielded no batches; nothing to evaluate")
        return metrics

    def calculate(self, eval_loader):
        """Evaluate, handing the sampler the ``model_50.pt`` checkpoint path.

        Returns:
            Tuple ``(pcc, rrmse, rsnr, ssim)`` for the last batch of ``eval_loader``.
        """
        model_path = f'{self.model_save_dir}/model_50.pt'
        return self._evaluate(eval_loader, model_path, print_target_shape=True)

    def calculate_no_finetune(self, eval_loader):
        """Evaluate with ``model_path=None`` handed to the sampler.

        Returns:
            Tuple ``(pcc, rrmse, rsnr, ssim)`` for the last batch of ``eval_loader``.
        """
        return self._evaluate(eval_loader, None)

    

def eeg_to_u(data, device=None):
    """Project EEG onto its right-singular vectors and embed it in the standard layout.

    The 64 recorded ("monitor") channels are reordered to follow the 95-entry
    standard channel list. Each sample is multiplied by the transpose of its
    first ``n_channels`` right-singular vectors, and the resulting square
    matrix is placed into a 95 x 95 matrix at the rows/columns of the recorded
    channels. Rows and columns of channels that were not recorded stay zero.

    Args:
        data: Tensor indexed as ``(batch, channel, time)`` whose channel axis
            follows ``monitor_channel_list``.
        device: Device to compute on. Defaults to CUDA when available (the
            original behaviour) and falls back to CPU otherwise.

    Returns:
        Tensor of shape ``(batch, 95, 95)`` on ``device``.
    """
    standard_channel_list = ['Fp1', 'Fpz', 'Fp2', 'Fp9', 'Fp10', 'Nz', 'AF1', 'AF2', 'AFz', 'AF3', 'AF4', 'AF5', 'AF6',
                         'AF7', 'AF8', 'AF9', 'AF10', 'F1', 'F2', 'Fz', 'F3', 'F4', 'F5', 'F6', 'F7', 'F8', 'F9', 'F10',
                         'FC1', 'FC2', 'FCz', 'FC3', 'FC4', 'FC5', 'FC6', 'FT7', 'FT8', 'FT9', 'FT10', 'C1', 'C2', 'Cz',
                         'C3', 'C4', 'C5', 'C6', 'T7', 'T8', 'T9', 'T10', 'I1', 'I2', 'CP1', 'CP2', 'CPz', 'CP3', 'CP4',
                         'CP5', 'CP6', 'TP7', 'TP8', 'TP9', 'TP10', 'P1', 'P2', 'Pz', 'P3', 'P4', 'P5', 'P6', 'P7', 'P8',
                         'P9', 'P10', 'PO1', 'PO2', 'POz', 'PO3', 'PO4', 'PO5', 'PO6', 'PO7', 'PO8', 'PO9', 'PO10',
                         'O1', 'O2', 'Oz', 'O9', 'O10', 'Iz', 'CB1', 'CB2', 'A1', 'A2']

    monitor_channel_list = ['Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7', 'FC5', 'FC3', 'FC1', 'C1','C3','C5','T7'
                  ,'TP7','CP5','CP3','CP1','P1','P3','P5','P7','P9','PO7','PO3','O1','Iz','Oz','POz','Pz','CPz','Fpz','Fp2','AF8','AF4'
                  ,'AFz','Fz','F2','F4','F6','F8','FT8','FC6','FC4','FC2','FCz','Cz','C2','C4','C6','T8','TP8','CP6','CP4','CP2','P2','P4','P6','P8','P10','PO8','PO4','O2']

    # Channel names are compared case-insensitively.
    standard_channel_list = [s.upper() for s in standard_channel_list]
    monitor_channel_list = [s.upper() for s in monitor_channel_list]

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # One pass with a dict lookup instead of repeated list.index / `in list` scans.
    # Each pair is (position in the standard list, position in the recording).
    monitor_position = {ch: i for i, ch in enumerate(monitor_channel_list)}
    shared = [(std_idx, monitor_position[ch])
              for std_idx, ch in enumerate(standard_channel_list)
              if ch in monitor_position]
    channel_index = torch.tensor([std_idx for std_idx, _ in shared], device=device)
    reordered_index_map = [mon_idx for _, mon_idx in shared]

    all_channel_embedding = torch.eye(len(standard_channel_list), requires_grad=False, device=device)
    data = data.float().to(device)
    # Reorder the recorded channels so they follow the standard-list order.
    eeg = data[:, reordered_index_map, :]

    _, _, VT = torch.linalg.svd(eeg, full_matrices=True)
    VT = VT[:, :eeg.shape[1], :]
    U = torch.matmul(eeg, VT.transpose(2, 1))

    # One-hot rows selecting the recorded channels; multiplying on both sides
    # scatters the square matrix U into the standard-layout square matrix.
    channel_embedding = all_channel_embedding[channel_index]
    channel_embedding = channel_embedding.unsqueeze(0).expand(eeg.shape[0], -1, -1)
    U = torch.matmul(channel_embedding.transpose(2, 1), U)
    U = torch.matmul(U, channel_embedding)
    return U

class Dataset():
    """Minimal map-style dataset that serves slices of ``U`` along its first axis."""

    def __init__(self, U) -> None:
        # Keep a reference to the data; no copy is made.
        self.U = U
        # The sample count is read once from the leading dimension.
        sample_count = U.shape[0]
        self.len = sample_count

    def __len__(self) -> int:
        """Return the number of samples recorded at construction time."""
        return self.len

    def __getitem__(self, index):
        """Return ``U[index]``."""
        sample = self.U[index]
        return sample


import argparse

if __name__ == '__main__':
    # NOTE: this script body deliberately stays at module level. Trainer's
    # evaluation methods read the global `args` defined here, so wrapping this
    # in a function would break them.
    parser = argparse.ArgumentParser()
    # parser.add_argument('--time', type=int, default=2, choices=[2, 3, 4, 5, 10, 20])
    parser.add_argument('--mask', type=float, default=0.3, choices=[0.1, 0.3, 0.5])
    # Parsed as a list of ints (e.g. `--dim_mults 1 2 3`); without type/nargs a
    # CLI value would arrive as a raw string.
    parser.add_argument("--dim_mults", type=int, nargs='+', default=(1, 2, 3))
    parser.add_argument("--init_dim", default=128, type=int)
    args = parser.parse_args()
    args.dim_mults = tuple(args.dim_mults)

    file_list = list(range(1, 7))
    gpu_count = torch.cuda.device_count()
    if gpu_count > 1:
        print("Let's use", gpu_count, "GPUs!")

    channels = 95

    device = torch.device("cuda:0")
    pretrain_path = 'results_iclr/pretrain_40epoch/model_40.pt'
    # At most four GPUs, as before, but never more than the host actually has.
    device_ids = list(range(min(4, gpu_count)))

    # Per-file metric accumulators, used by the optional evaluation code below.
    PCC = []
    RRMSE = []
    RSNR = []
    SSIM = []

    for file in file_list:
        print(file)
        data = monitor_erp(file, type='train')
        test_data = monitor_erp(file, type='test')
        U = eeg_to_u(data.eeg)
        train_data = Dataset(U)
        train_loader = DataLoader(train_data, batch_size=2048, shuffle=True)
        # A single batch holding the whole test set.
        test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=False)

        unet = Unet1d(dim=args.init_dim, T=1000, channels=channels, dim_mults=args.dim_mults).to(device)
        # The model is wrapped before the pretrained state dict is loaded.
        unet = nn.DataParallel(unet, device_ids=device_ids)
        unet.load_state_dict(torch.load(pretrain_path, map_location=device)['unet'])
        diffuser = SVDiffusion(time_steps=1000, unet=unet, w=2, device=device) # change 'time_steps' to 100 when sampling
        diffuser_trainer = Trainer(file, mask_ratio=0.4, epochs=50, diffuser=diffuser, train_loader=train_loader,
                                   device=device, pretrain_path='pretrain_40epoch')

        diffuser_trainer.train()

        # Optional evaluation: uncomment one of the two `calculate` calls and
        # the accumulation/summary lines to report metrics per file.
        # pcc, rrmse, rsnr, ssim = diffuser_trainer.calculate_no_finetune(test_loader)
        # pcc, rrmse, rsnr, ssim = diffuser_trainer.calculate(test_loader)
        # PCC.append(pcc)
        # RRMSE.append(rrmse)
        # RSNR.append(rsnr)
        # SSIM.append(ssim)

    # print("------------")
    # print("PCC:", np.mean(PCC), np.std(PCC))
    # print("RRMSE:", np.mean(RRMSE), np.std(RRMSE))
    # print("RSNR:", np.mean(RSNR), np.std(RSNR))
    # print("SSIM:", np.mean(SSIM), np.std(SSIM))
    




        
