import meta_train_dataset 
import torch
import os
import numpy as np
from io_utils import parse_args_eposide_train
import ResNet10
import ProtoNet
import fuseNet
import torch.nn as nn
from torch.autograd import Variable
import utils
import random
import copy
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore", category=Warning)


def setup_seed(seed):
    """Seed every random number generator used by training.

    Seeds torch (CPU and all CUDA devices), NumPy and Python's ``random``
    with the same value, and asks cuDNN to use deterministic algorithms.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    

def _episode_to_batch(x_episode, params):
    """Move one view of an episode to the GPU and flatten its first two axes.

    Args:
        x_episode: Episode tensor whose first two axes are
            (n_way, n_support + n_query). The remaining axes are the image
            dimensions (presumably (3, 224, 224)).
        params: Parsed arguments providing ``n_way``, ``n_support`` and
            ``n_query``.

    Returns:
        CUDA tensor of shape (n_way * (n_support + n_query), *image_dims).
    """
    x_episode = x_episode.cuda()
    n_images = params.n_way * (params.n_support + params.n_query)
    return x_episode.contiguous().view(n_images, *x_episode.size()[2:])


def _proto_logits(features, head, params):
    """Score an episode's query embeddings against its class prototypes.

    Args:
        features: Flat episode embeddings, first axis of length
            n_way * (n_support + n_query), ordered class-major.
        head: Callable ``head(prototypes, queries)`` returning logits
            of shape (n_way * n_query, n_way).
        params: Parsed arguments providing ``n_way``, ``n_support`` and
            ``n_query``.

    Returns:
        The logits produced by ``head``.
    """
    per_class = features.view(params.n_way, params.n_support + params.n_query, -1)
    # Prototype = mean of each class's support embeddings: (n_way, D).
    z_proto = per_class[:, :params.n_support].mean(1)
    z_query = per_class[:, params.n_support:].contiguous().view(params.n_way * params.n_query, -1)
    return head(z_proto, z_query)


def _alignment_loss(p_raw, logits_other):
    """Compute mean_i sum_j -p_raw[i, j] * log(softmax(logits_other)[i, j]).

    ``log_softmax`` keeps the result finite even when a softmax probability
    would underflow to 0. Computing ``log(prob ** -p)`` directly would give
    inf losses and NaN gradients in that case.

    Args:
        p_raw: Target probabilities, shape (N, C).
        logits_other: Unnormalised logits of the other branch, shape (N, C).
    """
    return -(p_raw * F.log_softmax(logits_other, dim=1)).sum(dim=1).mean()


def train(train_loader, model, fuse_net, Siamese_model_low, Siamese_model_high, head, loss_fn, loss_fn_MSE, optimizer, params):
    """Run one epoch of episodic training.

    Each batch ``x`` from ``train_loader`` holds three views of one episode:
    x[0] (raw), x[1] (low) and x[2] (high). They are encoded as follows:

    * raw goes through the trainable ``model``;
    * low goes through ``Siamese_model_low`` under ``torch.no_grad()``;
    * high goes through ``Siamese_model_high`` under ``torch.no_grad()``.

    The per-episode loss is the sum of three terms:

    * ``loss_fn`` on the raw-branch prototype logits;
    * ``params.lamba_alignment_loss`` times the average cross-entropy of the
      low/high predictions against the raw-branch probabilities;
    * ``params.lamba_reconstruct_loss`` times ``loss_fn_MSE`` between the
      L2-normalised ``fuse_net(low) + fuse_net(high)`` and ``fuse_net(raw)``.

    After each optimizer step, the Siamese encoders' parameters are moved
    toward ``model``'s as exponential moving averages, using momenta
    ``params.m_low`` and ``params.m_high``.

    Returns:
        (avg_loss, acc): the mean total loss over episodes (0.0 if the loader
        yields nothing) and ``top1.avg`` from ``utils.AverageMeter``. The
        latter is presumably the sample-weighted raw-branch query accuracy in
        percent.
    """
    model.train()
    fuse_net.train()
    top1 = utils.AverageMeter()
    total_loss = 0.0
    n_batches = 0
    # Query labels are the same for every episode: each class index repeated
    # n_query times. torch.arange yields int64 on every platform, which is
    # what the classification loss expects.
    y = torch.arange(params.n_way).repeat_interleave(params.n_query).cuda()

    for x, _ in train_loader:
        optimizer.zero_grad()

        # Raw view: trainable backbone, gradients flow into `model`.
        out_raw = model(_episode_to_batch(x[0], params))  # (way*(shot+query), 512)
        pred_raw = _proto_logits(out_raw, head, params)  # (way*query, way)

        # Low/high views: EMA encoders, no gradients through the backbones.
        x_low = _episode_to_batch(x[1], params)
        with torch.no_grad():
            out_low = Siamese_model_low(x_low)
        pred_low = _proto_logits(out_low, head, params)

        x_high = _episode_to_batch(x[2], params)
        with torch.no_grad():
            out_high = Siamese_model_high(x_high)
        pred_high = _proto_logits(out_high, head, params)

        loss_ce = loss_fn(pred_raw, y)

        _, predicted = torch.max(pred_raw.data, 1)
        correct = predicted.eq(y).cpu().sum().item()
        top1.update(correct * 100 / (y.size(0) + 0.0), y.size(0))

        p_raw = F.softmax(pred_raw, dim=1)
        alignment_loss = 0.5 * _alignment_loss(p_raw, pred_low) + 0.5 * _alignment_loss(p_raw, pred_high)

        # fuse_net is applied in the order low, high, raw. Keep this order:
        # if fuse_net contains stateful layers (e.g. BN or dropout), it affects results.
        out_fuse = fuse_net(out_low) + fuse_net(out_high)
        out_raw_proj = fuse_net(out_raw)
        out_fuse = F.normalize(out_fuse, dim=1)
        out_raw_proj = F.normalize(out_raw_proj, dim=1)
        reconstruct_loss = loss_fn_MSE(out_fuse, out_raw_proj)

        loss = (loss_ce
                + alignment_loss * params.lamba_alignment_loss
                + reconstruct_loss * params.lamba_reconstruct_loss)

        loss.backward()
        optimizer.step()

        # EMA update: siamese = m * siamese + (1 - m) * model, done in place.
        with torch.no_grad():
            for param_q, param_k, param_v in zip(model.parameters(), Siamese_model_low.parameters(), Siamese_model_high.parameters()):
                param_k.mul_(params.m_low).add_(param_q, alpha=1. - params.m_low)
                param_v.mul_(params.m_high).add_(param_q, alpha=1. - params.m_high)

        total_loss += loss.item()
        n_batches += 1

    avg_loss = total_loss / n_batches if n_batches else 0.0
    return avg_loss, top1.avg
        
 
                
if __name__=='__main__':
    params = parse_args_eposide_train()
    setup_seed(params.seed)
    datamgr = meta_train_dataset.Eposide_DataManager(data_path=params.data_path, base_class=params.base_class, image_size=params.image_size, n_way=params.n_way, n_support=params.n_support, n_query=params.n_query, n_eposide=params.n_eposide)
    base_loader = datamgr.get_data_loader(aug=params.train_aug)
    model = ResNet10.ResNet(list_of_out_dims=params.list_of_out_dims, list_of_stride=params.list_of_stride, list_of_dilated_rate=params.list_of_dilated_rate)
    fuse_net = fuseNet.Fuse()
    head = ProtoNet.ProtoNet()
    # exist_ok=True avoids the race between an isdir() check and makedirs().
    os.makedirs(params.save_dir, exist_ok=True)
    # Load the pretrained backbone weights onto the CPU. The model is still
    # on the CPU here, and this also works when the checkpoint was saved on
    # a GPU index that is not present on this machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(params.model_path, map_location='cpu')
    model.load_state_dict(checkpoint['state'])
    # Both EMA encoders start as exact copies of the pretrained backbone.
    Siamese_model_low = copy.deepcopy(model).cuda()
    Siamese_model_high = copy.deepcopy(model).cuda()
    model = model.cuda()
    fuse_net = fuse_net.cuda()
    head = head.cuda()
    loss_fn = nn.CrossEntropyLoss().cuda()
    loss_fn_MSE = nn.MSELoss().cuda()
    # Only the backbone and the fusion net are optimised. The Siamese
    # encoders are updated by EMA inside train().
    optimizer = torch.optim.Adam([{"params":model.parameters()}, {"params":fuse_net.parameters()}], lr=params.lr)
    for epoch in range(params.epoch):
        train_loss, train_acc = train(base_loader, model, fuse_net, Siamese_model_low, Siamese_model_high, head, loss_fn, loss_fn_MSE, optimizer, params)
        print('train:', epoch+1, 'current epoch train loss:', train_loss, 'current epoch train acc:', train_acc)
        # Save a checkpoint every params.save_freq epochs.
        if (epoch+1) % params.save_freq==0:
            outfile = os.path.join(params.save_dir, '{:d}.tar'.format(epoch+1))
            torch.save({
                'epoch':epoch+1,
                'state_model':model.state_dict(),
                'state_fuse_net':fuse_net.state_dict()},
                outfile)


    
    
    
    