import os
import warnings

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

from utils import *


# Suppress every Python warning from here on (process-wide filter).
warnings.simplefilter('ignore')
# Absolute directory of this source file (symlinks resolved); not referenced
# elsewhere in the visible part of this file.
abs_path = os.path.split(os.path.realpath(__file__))[0]




class TSSR(nn.Module):
    """Time-series super-resolution model built from two rectified-flow generators.

    The high-resolution residual (``hr`` minus a linear interpolation of ``lr``)
    is split by ``decomp_f`` into an ``s`` and a ``t`` component. Each component
    is handled by its own velocity-prediction generator, and both generators are
    conditioned on the ``decomp_c`` decomposition of the low-resolution input.
    NOTE(review): ``s``/``t`` presumably mean seasonal/trend -- confirm.
    """

    def __init__(self, gen_model_s, gen_model_t, decomp_f, decomp_c, sr_ratio, task_type, n_step, device="cuda"):
        """
        Args:
            gen_model_s: velocity-prediction model for the ``s`` component.
            gen_model_t: velocity-prediction model for the ``t`` component.
            decomp_f: fine-grained decomposition callable returning ``(s, t)``.
            decomp_c: coarse-grained decomposition callable returning ``(s, t)``.
            sr_ratio: scale factor, i.e. the stride between low-resolution
                samples on the high-resolution time axis.
            task_type: ``"SSR"`` (low-res = subsampled points) or ``"ASR"``
                (low-res = averages of consecutive windows).
            n_step: number of sampling steps used by :meth:`generation`.
            device: torch device that inputs are moved to.
        """
        super().__init__()
        assert task_type in ["SSR", "ASR"]
        self.gen_model_s = gen_model_s      # v predict model for s
        self.gen_model_t = gen_model_t      # v predict model for t
        self.decomp_f = decomp_f            # decomposition model, fine-grained
        self.decomp_c = decomp_c            # decomposition model, coarse-grained
        self.sr_ratio = sr_ratio
        print("scale factor : ", self.sr_ratio)
        self.task_type = task_type
        self.n_step = n_step                # sampling steps
        self.device = device
        print("use device : ", self.device)

    def forward(self, d, lr, target_mask):
        """Run both generators on the decomposed residual.

        Args:
            d: residual sequence ``hr - interpolated(lr)``, ``[B, T, D]``.
            lr: low-resolution sequence, ``[B, T', D]``.
            target_mask: forwarded unchanged to both generators.

        Returns:
            ``(v_s, v_t, v_s_hat, v_t_hat, itf_s, itf_t)``. The ``*_hat`` and
            ``itf_*`` tensors are moved to ``self.device``; ``v_s``/``v_t`` are
            returned exactly as the generators produce them.
        """
        d = d.to(self.device)
        # Keep lr on the same device as d so direct callers can pass CPU tensors.
        lr = lr.to(self.device)

        # decomposition
        d_s, d_t = self.decomp_f(d)
        lr_s, lr_t = self.decomp_c(lr)

        # rectified flow: both branches are conditioned on the coarse lr components.
        # A fresh list per call, in case a generator mutates its condition list.
        v_s_hat, v_s, itf_s = self.gen_model_s(d_s, [lr_s, lr_t], lr, target_mask)
        v_s_hat = v_s_hat.to(self.device)
        itf_s = itf_s.to(self.device)
        v_t_hat, v_t, itf_t = self.gen_model_t(d_t, [lr_s, lr_t], lr, target_mask)
        v_t_hat = v_t_hat.to(self.device)
        itf_t = itf_t.to(self.device)
        return v_s, v_t, v_s_hat, v_t_hat, itf_s, itf_t

    def _downsample(self, hr, lr_t_idx):
        """Build the low-resolution sequence ``[B, T', D]`` from ``hr`` ``[B, T, D]``.

        SSR keeps the samples at the indices in ``lr_t_idx``. ASR averages, for
        each index ``e`` in ``lr_t_idx``, the window that ends at ``e``
        (inclusive) and starts right after the previous index. The first window
        starts at 0.
        """
        if self.task_type == "SSR":
            return hr[:, lr_t_idx, :]
        # Plain int bounds: the previous tensor-based bounds became float (and
        # unusable as slice indices) when lr_t_idx had a single element.
        starts = [0] + [idx + 1 for idx in lr_t_idx[:-1]]
        windows = [hr[:, start:end + 1, :].mean(dim=1) for start, end in zip(starts, lr_t_idx)]
        return torch.stack(windows, dim=1)

    def fit(self, train_data, sample_drift, args, verbose=True):
        """Train both generators and return the per-epoch mean total loss.

        Args:
            train_data: numpy array of high-resolution sequences, ``[N, T, D]``.
            sample_drift: offset of the first low-resolution sample.
            args: namespace providing ``batch_size``, ``epoch`` and ``re_loss``.
            verbose: print losses every 10 epochs.

        Returns:
            List with the mean total loss of each epoch.
        """
        n_epochs = args.epoch
        use_re_loss = args.re_loss
        train_dataset = TensorDataset(torch.from_numpy(train_data).to(torch.float))
        train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=min(args.batch_size, len(train_dataset)),
            shuffle=True,
            drop_last=True,
        )
        reconstruct_loss = nn.MSELoss()
        optimizer_s = torch.optim.Adam(self.gen_model_s.parameters(), lr=0.001)
        optimizer_t = torch.optim.Adam(self.gen_model_t.parameters(), lr=0.001)
        scheduler_s = torch.optim.lr_scheduler.StepLR(optimizer_s, step_size=10, gamma=0.8)
        scheduler_t = torch.optim.lr_scheduler.StepLR(optimizer_t, step_size=10, gamma=0.8)

        # Every batch has the same length T, so the low-res time indices are
        # loop-invariant.
        seq_len = train_data.shape[1]
        lr_t_idx = range(sample_drift, seq_len, self.sr_ratio)

        loss_log = []
        for epoch in range(n_epochs):
            cum_loss = cum_loss_s = cum_loss_t = cum_loss_re = 0.0
            n_epoch_iters = 0
            for (hr,) in train_loader:
                hr = hr.to(self.device)                                 # [B, T, D]
                if use_re_loss:
                    # Decomposed targets are only needed for the reconstruction loss.
                    hr_s, hr_t = self.decomp_f(hr)
                lr = self._downsample(hr, lr_t_idx)                     # [B, T', D]
                # Built fresh per batch, so a consumer mutating it cannot leak
                # changes into later batches.
                target_mask = np.zeros(seq_len, dtype=int)
                target_mask[lr_t_idx] = 1
                lr_impu = lr_linear_impu(lr, seq_len, self.device)      # [B, T, D]
                d = hr - lr_impu                                        # [B, T, D]

                optimizer_s.zero_grad()
                optimizer_t.zero_grad()
                v_s, v_t, v_s_hat, v_t_hat, itf_s, itf_t = self.forward(d, lr, target_mask)
                loss_s = vpred_loss(v_s_hat, v_s)
                loss_t = vpred_loss(v_t_hat, v_t)
                loss = loss_s + loss_t
                if use_re_loss:
                    # MSELoss already reduces to a scalar (reduction="mean").
                    loss_re = reconstruct_loss(itf_s, hr_s) + reconstruct_loss(itf_t, hr_t)
                    loss = loss + loss_re
                    cum_loss_re += loss_re.item()
                loss.backward()
                optimizer_s.step()
                optimizer_t.step()

                cum_loss += loss.item()
                cum_loss_s += loss_s.item()
                cum_loss_t += loss_t.item()
                n_epoch_iters += 1
            scheduler_s.step()
            scheduler_t.step()

            cum_loss /= n_epoch_iters
            cum_loss_s /= n_epoch_iters
            cum_loss_t /= n_epoch_iters
            if use_re_loss:
                cum_loss_re /= n_epoch_iters
            loss_log.append(cum_loss)
            if verbose and epoch % 10 == 0:
                if use_re_loss:
                    print("Epoch {:03d} : {:.6f} = ({:.6f} + {:.6f} + {:.6f})".format(epoch, cum_loss, cum_loss_s, cum_loss_t, cum_loss_re))
                else:
                    print("Epoch {:03d} : {:.6f} = ({:.6f} + {:.6f})".format(epoch, cum_loss, cum_loss_s, cum_loss_t))
        return loss_log

    def generation(self, lr, sr_times, batch_size, sample_drift, target_mask=None, refine=None):
        """Super-resolve ``lr`` by sampling both generators.

        Args:
            lr: numpy array of low-resolution sequences, ``[N, T', D]``.
            sr_times: upsampling factor. The output length is
                ``T = (T' - 1) * sr_times + 1``.
            batch_size: sampling batch size (capped at ``N``).
            sample_drift: offset of the first observed point in the output grid.
            target_mask: optional length-``T`` mask passed to the generators and
                ``get_diff``. When ``None`` (the default), a mask with ones at
                ``sample_drift::sr_times`` is built for each batch.
            refine: value forwarded to ``get_diff``. When ``None``, falls back to
                a module-global ``args.refine`` (the previous behavior).

        Returns:
            Numpy arrays ``(gen_seq, gen_s, gen_t, cond_s, cond_t)``, each
            concatenated over batches along axis 0. ``gen_seq`` is the linear
            interpolation of ``lr`` plus ``get_diff(...)``.

        Raises:
            TypeError: ``refine`` is ``None`` and no global ``args`` exists.
        """
        if refine is None:
            # Backward compatibility: earlier code read the global ``args``,
            # which is only available when the ``__main__`` block's star imports
            # (or ``utils``) define it.
            try:
                refine = args.refine  # noqa: F821
            except NameError as err:
                raise TypeError(
                    "generation() needs `refine` when no global `args` is defined"
                ) from err

        test_dataset = TensorDataset(torch.from_numpy(lr).to(torch.float))
        test_loader = DataLoader(
            dataset=test_dataset,
            batch_size=min(batch_size, len(test_dataset)),
            shuffle=False,
            drop_last=False,
        )
        seq_len = (lr.shape[1] - 1) * sr_times + 1

        gen_seq_b, gen_s_b, gen_t_b = [], [], []
        cond_s_b, cond_t_b = [], []
        # NOTE(review): sampling runs with autograd enabled. torch.no_grad()
        # would save memory if gen_model_*.generation never needs gradients -- confirm.
        for (lr_batch,) in test_loader:
            lr_batch = lr_batch.to(self.device)
            B, D = lr_batch.shape[0], lr_batch.shape[2]
            lr_impu = lr_linear_impu(lr_batch, seq_len, self.device)
            lr_s, lr_t = self.decomp_c(lr_batch)
            cond_s, cond_t = lr_s.to(self.device), lr_t.to(self.device)
            # Noise is drawn on CPU then moved, preserving the RNG stream.
            xs_0 = torch.randn(B, seq_len, D).to(self.device)
            xt_0 = torch.randn(B, seq_len, D).to(self.device)

            if target_mask is None:
                mask = np.zeros(seq_len, dtype=int)
                mask[sample_drift::sr_times] = 1
            else:
                mask = target_mask

            gen_s, cond_itf_s = self.gen_model_s.generation(xs_0, lr_batch, [cond_s, cond_t], self.n_step, mask)
            gen_t, cond_itf_t = self.gen_model_t.generation(xt_0, lr_batch, [cond_s, cond_t], self.n_step, mask)
            # NOTE(review): get_diff receives self.sr_ratio while the output grid
            # and mask use sr_times; they agree in the __main__ usage -- confirm
            # which one is intended.
            gen_diff = get_diff(gen_s.detach().cpu(), gen_t.detach().cpu(), lr_batch, self.sr_ratio, mask, refine)

            gen_seq_b.append(lr_impu.detach().cpu() + gen_diff)
            gen_s_b.append(gen_s.detach().cpu())
            gen_t_b.append(gen_t.detach().cpu())
            cond_s_b.append(cond_itf_s.detach().cpu())
            cond_t_b.append(cond_itf_t.detach().cpu())

        return tuple(
            torch.cat(chunks, dim=0).numpy()
            for chunks in (gen_seq_b, gen_s_b, gen_t_b, cond_s_b, cond_t_b)
        )




if __name__ == "__main__" :

    import numpy as np
    from preprocessing import *
    from init import *
    from set_param import *
    from utils import evaluate_metrics

    # Dataset/task-specific settings shared by model construction, training
    # and generation.
    task_cfg = data_parser[args.dataset][args.task_type]
    scale = task_cfg["scale"]
    sample_drift = task_cfg["sample_drift"]

    # init TSSR
    tssr = TSSR(
        gen_model_s,
        gen_model_t,
        decomp_f,
        decomp_c,
        sr_ratio=scale,
        task_type=args.task_type,
        n_step=args.sample_step,
    )

    # The large VP variant is cached as a whole pickled module. It is retrained
    # only when `args.retrain is True`; otherwise the checkpoint is loaded.
    use_large_vp = args.predictor == "VP" and args.version == "large"
    ckpt_path = "SRT-large/SRTlarge_{}.pth".format(args.task_type)

    # train
    if use_large_vp and args.retrain is not True:
        # NOTE: weights_only=False unpickles arbitrary Python objects; only load
        # checkpoints from a trusted source.
        tssr = torch.load(ckpt_path, weights_only=False)
    else:
        loss = tssr.fit(
            train_data,
            sample_drift=sample_drift,
            args=args,
        )
        if use_large_vp:
            # torch.save does not create missing parent directories.
            os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
            torch.save(tssr, ckpt_path)

    # generation
    if use_large_vp:
        # Generate each channel separately as a univariate [N, T', 1] series,
        # then stack the channels back along the last axis.
        gen_seq_list = []
        for d in range(test_lr.shape[2]):
            test_lr_d = test_lr[:, :, d].reshape(-1, test_lr.shape[1], 1)
            gen_seq_d, _, _, _, _ = tssr.generation(
                lr=test_lr_d,
                sr_times=scale,
                sample_drift=sample_drift,
                batch_size=args.sample_batch_size,
            )
            gen_seq_list.append(gen_seq_d)
        gen_seq = np.concatenate(gen_seq_list, axis=2)
    else:
        gen_seq, gen_s, gen_t, cond_s, cond_t = tssr.generation(
            lr=test_lr,
            sr_times=scale,
            sample_drift=sample_drift,
            batch_size=args.sample_batch_size,
        )
    print("generated data : ", gen_seq.shape)
    print("test data : ", tssr_gt.shape)

    res_metric = evaluate_metrics(gen_seq, tssr_gt)
    print("mse : ", res_metric["mse"])
    print("dtw : ", res_metric["dtw"])